Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -730,7 +730,7 @@ async def _on_turn(self, context: TurnContext):
await self._on_activity(context, turn_state)

logger.debug("Running after turn middleware")
if not await self._run_after_turn_middleware(context, turn_state):
if await self._run_after_turn_middleware(context, turn_state):
await turn_state.save(context)
return
except ApplicationError as err:
Expand Down Expand Up @@ -798,7 +798,7 @@ async def _run_before_turn_middleware(self, context: TurnContext, state: StateT)
for before_turn in self._internal_before_turn:
is_ok = await before_turn(context, state)
if not is_ok:
await state.save(context, self._options.storage)
await state.save(context)
return False
return True

Expand All @@ -824,7 +824,7 @@ async def _run_after_turn_middleware(self, context: TurnContext, state: StateT):
for after_turn in self._internal_after_turn:
is_ok = await after_turn(context, state)
if not is_ok:
await state.save(context, self._options.storage)
await state.save(context)
return False
return True

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,6 @@ def get_storage_key(
raise ValueError("Invalid activity: missing conversation_id.")

return f"{channel_id}/conversations/{conversation_id}"

def clear(self, turn_context):
return super().clear(turn_context)
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def input_files(self) -> List[InputFile]:
def input_files(self, value: List[InputFile]) -> None:
self.set_value(self.INPUT_FILES_KEY, value)

def clear(self) -> None:
def clear(self, turn_context: TurnContext) -> None:
"""Clears all state values"""
self._state.clear()

Expand Down Expand Up @@ -96,6 +96,10 @@ async def load(self, turn_context: TurnContext, force: bool = False, **_) -> Non
"""Loads the state asynchronously"""
pass

async def save(self, turn_context, force=False):
"""Saves the state asynchronously"""
pass

async def save_changes(
self, turn_context: TurnContext, force: bool = False, **_
) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,16 +246,21 @@ async def load(self, turn_context: TurnContext, force: bool = False) -> None:
]
await asyncio.gather(*tasks)

def clear(self, scope: str) -> None:
def clear(self, turn_context: TurnContext, scope: str = None) -> None:
"""
Clears a state scope.

Args:
scope: The name of the scope to clear.
"""
scope_obj = self.get_scope_by_name(scope)
if hasattr(scope_obj, "clear"):
scope_obj.clear()
if scope:
scope_obj = self.get_scope_by_name(scope)
if hasattr(scope_obj, "clear"):
scope_obj.clear(turn_context)
else:
for scope in self._scopes.values():
if hasattr(scope, "clear"):
scope.clear(turn_context)

async def save(self, turn_context: TurnContext, force: bool = False) -> None:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def get_value(
else None
)

if not value and default_value_factory:
if not value and default_value_factory is not None:
# If the value is None and a factory is provided, call the factory to get a default value
return default_value_factory()

Expand Down