diff --git a/src/agentlab/agents/tool_use_agent/tool_use_agent.py b/src/agentlab/agents/tool_use_agent/tool_use_agent.py index b7494693..6ac61180 100644 --- a/src/agentlab/agents/tool_use_agent/tool_use_agent.py +++ b/src/agentlab/agents/tool_use_agent/tool_use_agent.py @@ -101,8 +101,10 @@ def flatten(self) -> list[MessageBuilder]: messages.extend(group.messages) # Mark all summarized messages for caching if i == len(self.groups) - keep_last_n_obs: - if not isinstance(messages[i], ToolCalls): - messages[i].mark_all_previous_msg_for_caching() + for msg in messages: # unset previous cache breakpoints + msg._cache_breakpoint = False + # set new cache breakpoint + messages[i].mark_all_previous_msg_for_caching() return messages def set_last_summary(self, summary: MessageBuilder):