Skip to content

Commit 5c78ab6

Browse files
author
久氢
committed
feat(memory_collection): add MemoryConversation class for dialogue history management
Change-Id: I099853b5465234a34db594c9dea1210743174932 Co-developed-by: Cursor <noreply@cursor.com> Signed-off-by: 久氢 <mapenghui.mph@alibaba-inc.com>
1 parent c4bd09e commit 5c78ab6

File tree

4 files changed

+374
-80
lines changed

4 files changed

+374
-80
lines changed

agentrun/memory_collection/memory_conversation.py

Lines changed: 109 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,19 @@
22
33
提供与 TableStore Memory 的集成能力,自动存储用户和 Agent 的对话历史。
44
5-
Example (基本使用):
6-
>>> from agentrun.server import AgentRunServer, AgentRequest
7-
>>> from agentrun.memory_collection import MemoryConversation
8-
>>>
9-
>>> # 初始化 Memory Conversation
10-
>>> memory = MemoryConversation(memory_collection_name="my-memory")
11-
>>>
12-
>>> # 包装 invoke_agent 函数
13-
>>> async def invoke_agent(req: AgentRequest):
14-
... async for event in memory.wrap_invoke_agent(req, my_agent_handler):
15-
... yield event
16-
>>>
17-
>>> server = AgentRunServer(invoke_agent=invoke_agent)
18-
>>> server.start()
195
"""
206

217
import json
228
import os
23-
from typing import Any, AsyncIterator, Callable, Dict, Optional, TYPE_CHECKING
9+
from typing import (
10+
Any,
11+
AsyncIterator,
12+
Callable,
13+
Dict,
14+
List,
15+
Optional,
16+
TYPE_CHECKING,
17+
)
2418
import uuid
2519

2620
import tablestore
@@ -88,13 +82,17 @@ def _default_user_id_extractor(req: Any) -> str:
8882
"""默认的 user_id 提取器
8983
9084
优先级:
91-
1. X-User-ID 请求头
85+
1. X-User-ID 请求头(支持多种格式)
9286
2. user_id 查询参数
9387
3. 默认值 "default_user"
9488
"""
9589
if req.raw_request:
9690
# 从请求头获取
97-
user_id = req.raw_request.headers.get("X-User-ID")
91+
user_id = (
92+
req.raw_request.headers.get("X-AgentRun-User-ID")
93+
or req.raw_request.headers.get("x-agentrun-user-id")
94+
or req.raw_request.headers.get("X-Agentrun-User-Id")
95+
)
9896
if user_id:
9997
return user_id
10098

@@ -110,26 +108,20 @@ def _default_session_id_extractor(req: Any) -> str:
110108
"""默认的 session_id 提取器
111109
112110
优先级:
113-
1. X-Session-ID 请求头
114-
2. sessionId 查询参数
115-
3. 从最后一条消息的 id 生成
116-
4. 生成新的 UUID
111+
1. X-Session-ID 请求头(支持多种格式)
112+
2. 生成新的 UUID
117113
"""
118114
if req.raw_request:
119-
# 从请求头获取
120-
session_id = req.raw_request.headers.get("X-Conversation-ID")
121-
if session_id:
122-
return session_id
123-
124-
# 从查询参数获取
125-
session_id = req.raw_request.query_params.get("sessionId")
115+
# 从请求头获取(兼容多种格式)
116+
# 支持:X-AgentRun-Session-ID, x-agentrun-session-id, X-Agentrun-Session-Id
117+
session_id = (
118+
req.raw_request.headers.get("X-AgentRun-Session-ID")
119+
or req.raw_request.headers.get("x-agentrun-session-id")
120+
or req.raw_request.headers.get("X-Agentrun-Session-Id")
121+
)
126122
if session_id:
127123
return session_id
128124

129-
# 从消息 ID 生成(如果有)
130-
if req.messages and req.messages[-1].id:
131-
return f"session_{req.messages[-1].id}"
132-
133125
# 生成新的 session_id
134126
return f"session_{uuid.uuid4().hex[:16]}"
135127

@@ -138,13 +130,18 @@ def _default_agent_id_extractor(req: Any) -> str:
138130
"""默认的 agent_id 提取器
139131
140132
优先级:
141-
1. X-Agent-ID 请求头
133+
1. X-Agent-ID 请求头(支持多种格式)
142134
2. 从 URL 路径中提取 /agent-runtimes/{agent_id}/... 格式
143135
3. 默认值 "default_agent"
144136
"""
145137
if req.raw_request:
146-
# 从请求头获取
147-
agent_id = req.raw_request.headers.get("X-Agent-ID")
138+
# 从请求头获取(兼容多种格式)
139+
# 支持:X-AgentRun-Agent-ID, x-agentrun-agent-id, X-Agentrun-Agent-Id, X-Agent-ID, x-agent-id
140+
agent_id = (
141+
req.raw_request.headers.get("X-AgentRun-Agent-ID")
142+
or req.raw_request.headers.get("x-agentrun-agent-id")
143+
or req.raw_request.headers.get("X-Agentrun-Agent-Id")
144+
)
148145
if agent_id:
149146
return agent_id
150147

@@ -407,8 +404,12 @@ async def wrap_invoke_agent(
407404
"content": self._extract_message_content(msg.content),
408405
})
409406

410-
# 收集 Agent 响应
407+
# 收集 Agent 响应(包括文本和工具调用)
411408
agent_response_content = ""
409+
tool_calls: Dict[str, Dict[str, Any]] = (
410+
{}
411+
) # tool_call_id -> tool_call_info
412+
tool_results: List[Dict[str, Any]] = [] # 工具执行结果列表
412413

413414
try:
414415
# 流式处理 Agent 响应
@@ -420,17 +421,80 @@ async def wrap_invoke_agent(
420421
if event.event == EventType.TEXT and "delta" in event.data:
421422
agent_response_content += event.data["delta"]
422423

424+
# 收集工具调用信息
425+
elif event.event == EventType.TOOL_CALL:
426+
# 完整的工具调用
427+
tool_id = event.data.get("id", "")
428+
if tool_id:
429+
tool_calls[tool_id] = {
430+
"id": tool_id,
431+
"type": "function",
432+
"function": {
433+
"name": event.data.get("name", ""),
434+
"arguments": event.data.get("args", ""),
435+
},
436+
}
437+
438+
elif event.event == EventType.TOOL_CALL_CHUNK:
439+
# 工具调用片段(流式场景)
440+
tool_id = event.data.get("id", "")
441+
if tool_id:
442+
if tool_id not in tool_calls:
443+
tool_calls[tool_id] = {
444+
"id": tool_id,
445+
"type": "function",
446+
"function": {
447+
"name": event.data.get("name", ""),
448+
"arguments": "",
449+
},
450+
}
451+
# 累积参数片段
452+
if "args_delta" in event.data:
453+
tool_calls[tool_id]["function"][
454+
"arguments"
455+
] += event.data["args_delta"]
456+
457+
# 收集工具执行结果
458+
elif event.event == EventType.TOOL_RESULT:
459+
tool_id = event.data.get("id", "")
460+
if tool_id:
461+
tool_results.append({
462+
"role": "tool",
463+
"tool_call_id": tool_id,
464+
"content": str(event.data.get("result", "")),
465+
})
466+
423467
# 透传事件
424468
yield event
425469

426470
# 保存完整的对话轮次(输入 + 输出)
427-
if agent_response_content:
471+
# 只有当有文本内容或工具调用时才保存
472+
if agent_response_content or tool_calls or tool_results:
428473
try:
429-
# 将助手响应添加到消息列表
430-
output_messages = input_messages + [{
474+
# 构建助手响应消息
475+
assistant_message: Dict[str, Any] = {
431476
"role": "assistant",
432-
"content": agent_response_content,
433-
}]
477+
}
478+
479+
# 添加文本内容(如果有)
480+
if agent_response_content:
481+
assistant_message["content"] = agent_response_content
482+
else:
483+
# OpenAI 格式要求:如果有 tool_calls,content 可以为 null
484+
assistant_message["content"] = None
485+
486+
# 添加工具调用(如果有)
487+
if tool_calls:
488+
assistant_message["tool_calls"] = list(
489+
tool_calls.values()
490+
)
491+
492+
# 构建完整的消息列表
493+
output_messages = input_messages + [assistant_message]
494+
495+
# 添加工具执行结果(如果有)
496+
if tool_results:
497+
output_messages.extend(tool_results)
434498

435499
# 将完整的对话历史存储为一条消息
436500
# content 字段存储 JSON 格式的消息列表
@@ -446,8 +510,10 @@ async def wrap_invoke_agent(
446510
await memory_store.update_session(session)
447511

448512
logger.debug(
449-
f"Saved conversation: {len(output_messages)} messages, "
450-
f"response length: {len(agent_response_content)} chars"
513+
f"Saved conversation: {len(output_messages)} messages,"
514+
f" text length: {len(agent_response_content)} chars,"
515+
f" tool_calls: {len(tool_calls)}, tool_results:"
516+
f" {len(tool_results)}"
451517
)
452518
except Exception as e:
453519
logger.error(

agentrun/server/server.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,22 @@ class AgentRunServer:
8383
... invoke_agent=invoke_agent,
8484
... config=ServerConfig(cors_origins=["http://localhost:3000"])
8585
... )
86+
87+
Example (启用会话历史记录):
88+
>>> server = AgentRunServer(
89+
... invoke_agent=invoke_agent,
90+
... memory_collection_name="my-memory-collection"
91+
... )
92+
>>> server.start(port=8000)
93+
# 会话历史将自动保存到 TableStore
8694
"""
8795

8896
def __init__(
8997
self,
9098
invoke_agent: InvokeAgentHandler,
9199
protocols: Optional[List[ProtocolHandler]] = None,
92100
config: Optional[ServerConfig] = None,
101+
memory_collection_name: Optional[str] = None,
93102
):
94103
"""初始化 AgentRun Server
95104
@@ -107,8 +116,20 @@ def __init__(
107116
- cors_origins: CORS 允许的源列表
108117
- openai: OpenAI 协议配置
109118
- agui: AG-UI 协议配置
119+
120+
memory_collection_name: MemoryCollection 名称(可选)
121+
- 如果提供,将自动启用会话历史记录功能
122+
- 会话历史将保存到指定的 MemoryCollection 中
110123
"""
111124
self.app = FastAPI(title="AgentRun Server")
125+
126+
# 如果启用了 memory,包装 invoke_agent
127+
if memory_collection_name:
128+
invoke_agent = self._wrap_with_memory(
129+
invoke_agent,
130+
memory_collection_name,
131+
)
132+
112133
self.agent_invoker = AgentInvoker(invoke_agent)
113134

114135
# 配置 CORS
@@ -124,6 +145,39 @@ def __init__(
124145
# 挂载所有协议的 Router
125146
self._mount_protocols(protocols)
126147

148+
def _wrap_with_memory(
149+
self,
150+
invoke_agent: InvokeAgentHandler,
151+
memory_collection_name: str,
152+
) -> InvokeAgentHandler:
153+
"""使用 MemoryConversation 包装 invoke_agent
154+
155+
Args:
156+
invoke_agent: 原始的 invoke_agent 函数
157+
memory_collection_name: MemoryCollection 名称
158+
159+
Returns:
160+
包装后的 invoke_agent 函数
161+
"""
162+
from agentrun.memory_collection import MemoryConversation
163+
164+
# 创建 MemoryConversation 实例
165+
memory = MemoryConversation(
166+
memory_collection_name=memory_collection_name,
167+
)
168+
169+
logger.info(
170+
"Memory integration enabled for collection:"
171+
f" {memory_collection_name}"
172+
)
173+
174+
# 包装 invoke_agent
175+
async def wrapped_invoke_agent(request: Any):
176+
async for event in memory.wrap_invoke_agent(request, invoke_agent):
177+
yield event
178+
179+
return wrapped_invoke_agent
180+
127181
def _setup_cors(self, cors_origins: Optional[Sequence[str]] = None):
128182
"""配置 CORS 中间件
129183

examples/server_with_memory.py

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -32,56 +32,60 @@
3232
# 配置参数
3333
MODEL_SERVICE = os.getenv("MODEL_SERVICE", "qwen3-max")
3434
MODEL_NAME = os.getenv("MODEL_NAME", "qwen3-max")
35-
SANDBOX_NAME = os.getenv("SANDBOX_NAME", "")
35+
SANDBOX_NAME = os.getenv("SANDBOX_NAME", "sandbox-browser-BmUyyD")
3636
MEMORY_COLLECTION_NAME = os.getenv("MEMORY_COLLECTION_NAME", "mem-ots0129")
3737

3838
# 创建 Agent
3939
agent = create_agent(
4040
# 使用 AgentRun 注册的模型
4141
model=model(MODEL_SERVICE, model=MODEL_NAME),
4242
system_prompt="""
43-
你是一个诗人,根据用户输入内容写一个20字以内的诗文
43+
你是 AgentRun 的 AI 助手,可以通过网络搜索帮助用户解决问题
44+
45+
46+
你的工作流程如下
47+
- 当用户向你提问概念性问题时,不要直接回答,而是先进行网络搜索
48+
- 使用 Browser 工具打开百度搜索。如果要搜索 AgentRun,对应的搜索链接为: `https://www.baidu.com/s?ie=utf-8&wd=agentrun`。为了节省 token 使用,不要使用 `snapshot` 获取完整页面内容,而是通过 `evaluate` 获取你需要的部分
49+
- 获取百度搜索的结果,根据相关性分别打开子页面获取内容
50+
- 如果子页面的相关度较低,则可以直接忽略
51+
- 如果子页面的相关度较高,则将其记录为可参考的资料,记录页面标题和实时的 url
52+
- 当你获得至少 3 条网络信息后,可以结束搜索,并根据搜索到的结果回答用户的问题。
53+
- 如果某一部分回答引用了网络的信息,需要进行标注,并在回答的最后给出跳转链接
4454
""",
4555
# 使用 AgentRun 的 Sandbox 工具
46-
# tools=[*sandbox_toolset(SANDBOX_NAME, template_type=TemplateType.BROWSER)],
56+
tools=[*sandbox_toolset(SANDBOX_NAME, template_type=TemplateType.BROWSER)],
4757
)
4858

4959
# 初始化 Memory Integration
5060
memory = MemoryConversation(memory_collection_name=MEMORY_COLLECTION_NAME)
5161

5262

5363
async def invoke_agent(req: AgentRequest):
54-
"""Agent 调用函数,集成了记忆存储功能"""
64+
"""Agent 调用函数"""
5565
try:
5666
converter = AgentRunConverter()
57-
58-
# 定义原始的 agent 处理函数
59-
async def agent_handler(request: AgentRequest):
60-
result = agent.astream_events(
61-
{
62-
"messages": [
63-
{"role": msg.role, "content": msg.content}
64-
for msg in request.messages
65-
]
66-
},
67-
config={"recursion_limit": 1000},
68-
)
69-
async for event in result:
70-
for agentrun_event in converter.convert(event):
71-
yield agentrun_event
72-
73-
# 使用 MemoryIntegration 包装,自动存储对话历史
74-
async for event in memory.wrap_invoke_agent(req, agent_handler):
75-
yield event
76-
67+
result = agent.astream_events(
68+
{
69+
"messages": [
70+
{"role": msg.role, "content": msg.content}
71+
for msg in req.messages
72+
]
73+
},
74+
config={"recursion_limit": 1000},
75+
)
76+
async for event in result:
77+
for agentrun_event in converter.convert(event):
78+
yield agentrun_event
7779
except Exception as e:
78-
print(f"Error in invoke_agent: {e}")
80+
print(e)
7981
raise Exception("Internal Error")
8082

8183

8284
# 创建并启动 Server
8385
if __name__ == "__main__":
84-
server = AgentRunServer(invoke_agent=invoke_agent)
86+
server = AgentRunServer(
87+
invoke_agent=invoke_agent, memory_collection_name=MEMORY_COLLECTION_NAME
88+
)
8589
print(f"Server starting with memory collection: {MEMORY_COLLECTION_NAME}")
8690
print("Memory will be automatically saved to TableStore")
8791
server.start(port=9000)

0 commit comments

Comments
 (0)