Skip to content

Commit 0d879a9

Browse files
author
久氢
committed
feat(memory_collection): add MemoryConversation class for dialogue history management
Change-Id: I020c9d634d3db70e1b1fedc1f5c25e4407f63b4e Co-developed-by: Cursor <noreply@cursor.com> Signed-off-by: 久氢 <mapenghui.mph@alibaba-inc.com>
1 parent 5c78ab6 commit 0d879a9

File tree

2 files changed

+58
-36
lines changed

2 files changed

+58
-36
lines changed

agentrun/memory_collection/memory_conversation.py

Lines changed: 31 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -108,20 +108,26 @@ def _default_session_id_extractor(req: Any) -> str:
108108
"""默认的 session_id 提取器
109109
110110
优先级:
111-
1. X-Session-ID 请求头(支持多种格式)
112-
2. 生成新的 UUID
111+
1. X-Conversation-ID 请求头(支持多种格式)
112+
2. sessionId 查询参数
113+
3. 生成新的 UUID
113114
"""
114115
if req.raw_request:
115116
# 从请求头获取(兼容多种格式)
116-
# 支持:X-AgentRun-Session-ID, x-agentrun-session-id, X-Agentrun-Session-Id
117+
# 支持:X-AgentRun-Conversation-ID, x-agentrun-conversation-id, X-Agentrun-Conversation-Id
117118
session_id = (
118-
req.raw_request.headers.get("X-AgentRun-Session-ID")
119-
or req.raw_request.headers.get("x-agentrun-session-id")
120-
or req.raw_request.headers.get("X-Agentrun-Session-Id")
119+
req.raw_request.headers.get("X-AgentRun-Conversation-ID")
120+
or req.raw_request.headers.get("x-agentrun-conversation-id")
121+
or req.raw_request.headers.get("X-Agentrun-Conversation-Id")
121122
)
122123
if session_id:
123124
return session_id
124125

126+
# 从查询参数获取
127+
session_id = req.raw_request.query_params.get("sessionId")
128+
if session_id:
129+
return session_id
130+
125131
# 生成新的 session_id
126132
return f"session_{uuid.uuid4().hex[:16]}"
127133

@@ -131,12 +137,11 @@ def _default_agent_id_extractor(req: Any) -> str:
131137
132138
优先级:
133139
1. X-Agent-ID 请求头(支持多种格式)
134-
2. 从 URL 路径中提取 /agent-runtimes/{agent_id}/... 格式
135-
3. 默认值 "default_agent"
140+
2. 默认值 "default_agent"
136141
"""
137142
if req.raw_request:
138143
# 从请求头获取(兼容多种格式)
139-
# 支持:X-AgentRun-Agent-ID, x-agentrun-agent-id, X-Agentrun-Agent-Id, X-Agent-ID, x-agent-id
144+
# 支持:X-AgentRun-Agent-ID, x-agentrun-agent-id, X-Agentrun-Agent-Id
140145
agent_id = (
141146
req.raw_request.headers.get("X-AgentRun-Agent-ID")
142147
or req.raw_request.headers.get("x-agentrun-agent-id")
@@ -145,25 +150,6 @@ def _default_agent_id_extractor(req: Any) -> str:
145150
if agent_id:
146151
return agent_id
147152

148-
# 从 URL 路径中提取
149-
# 例如:/agent-runtimes/agent-quick-xFGD/invoke -> agent-quick-xFGD
150-
try:
151-
path = (
152-
req.raw_request.url.path
153-
if hasattr(req.raw_request.url, "path")
154-
else str(req.raw_request.url)
155-
)
156-
if "/agent-runtimes/" in path:
157-
# 提取 /agent-runtimes/ 后面的部分
158-
parts = path.split("/agent-runtimes/", 1)
159-
if len(parts) > 1:
160-
# 获取下一个路径段
161-
agent_part = parts[1].split("/")[0]
162-
if agent_part:
163-
return agent_part
164-
except Exception:
165-
pass
166-
167153
return "default_agent"
168154

169155
async def _get_memory_store(self):
@@ -193,12 +179,19 @@ async def _get_memory_store(self):
193179
ots_config = await self._get_ots_config_from_memory_collection()
194180

195181
# 创建 AsyncOTSClient
196-
self._ots_client = tablestore.AsyncOTSClient(
197-
end_point=ots_config["endpoint"],
198-
access_key_id=ots_config["access_key_id"],
199-
access_key_secret=ots_config["access_key_secret"],
200-
instance_name=ots_config["instance_name"],
201-
)
182+
# 支持使用 STS 临时凭证访问 TableStore
183+
client_kwargs = {
184+
"end_point": ots_config["endpoint"],
185+
"access_key_id": ots_config["access_key_id"],
186+
"access_key_secret": ots_config["access_key_secret"],
187+
"instance_name": ots_config["instance_name"],
188+
}
189+
190+
# 如果提供了 security_token,则添加到参数中(支持 STS 临时凭证)
191+
if ots_config.get("security_token"):
192+
client_kwargs["sts_token"] = ots_config["security_token"]
193+
194+
self._ots_client = tablestore.AsyncOTSClient(**client_kwargs)
202195

203196
# 配置会话表的二级索引元数据字段
204197
# agent_id 字段用于标识会话所属的 Agent
@@ -258,6 +251,7 @@ async def _get_ots_config_from_memory_collection(self) -> Dict[str, Any]:
258251
- endpoint: OTS endpoint
259252
- access_key_id: 访问密钥 ID
260253
- access_key_secret: 访问密钥 Secret
254+
- security_token: STS 安全令牌(可选,用于临时凭证)
261255
- instance_name: OTS 实例名称
262256
"""
263257
from agentrun.memory_collection import MemoryCollection
@@ -309,6 +303,9 @@ async def _get_ots_config_from_memory_collection(self) -> Dict[str, Any]:
309303
"instance_name": vs_config.instance_name or "",
310304
"access_key_id": self.config.get_access_key_id(),
311305
"access_key_secret": self.config.get_access_key_secret(),
306+
"security_token": (
307+
self.config.get_security_token()
308+
), # 支持 STS 临时凭证
312309
}
313310

314311
return ots_config

tests/unittests/memory_collection/test_memory_conversation.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,14 @@ def test_default_user_id_extractor_fallback(self):
9292

9393
def test_default_session_id_extractor(self):
9494
"""Test default session_id extraction"""
95-
# Test with X-Session-ID header
95+
# Test with X-AgentRun-Conversation-ID header
9696
mock_req = Mock()
9797
mock_headers = Mock()
98-
mock_headers.get = Mock(return_value="session456")
98+
mock_headers.get = Mock(
99+
side_effect=lambda k: {
100+
"X-AgentRun-Conversation-ID": "session456"
101+
}.get(k)
102+
)
99103
mock_query = Mock()
100104
mock_query.get = Mock(return_value=None)
101105

@@ -110,6 +114,27 @@ def test_default_session_id_extractor(self):
110114
session_id = MemoryConversation._default_session_id_extractor(request)
111115
assert session_id == "session456"
112116

117+
def test_default_session_id_extractor_from_query(self):
118+
"""Test session_id extraction from query parameter"""
119+
mock_req = Mock()
120+
mock_headers = Mock()
121+
mock_headers.get = Mock(return_value=None)
122+
mock_query = Mock()
123+
mock_query.get = Mock(
124+
side_effect=lambda k: {"sessionId": "query_session789"}.get(k)
125+
)
126+
127+
mock_req.headers = mock_headers
128+
mock_req.query_params = mock_query
129+
130+
request = AgentRequest.model_construct(
131+
messages=[],
132+
raw_request=mock_req,
133+
)
134+
135+
session_id = MemoryConversation._default_session_id_extractor(request)
136+
assert session_id == "query_session789"
137+
113138
def test_default_session_id_extractor_generate(self):
114139
"""Test session_id generation"""
115140
request = AgentRequest(messages=[])

0 commit comments

Comments
 (0)