Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions astrbot/core/agent/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,15 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
timeout = cfg.get("timeout", 10)

try:
if "transport" in cfg:
transport_type = cfg["transport"]
elif "type" in cfg:
transport_type = cfg["type"]
else:
raise Exception("MCP 连接配置缺少 transport 或 type 字段")

async with aiohttp.ClientSession() as session:
if cfg.get("transport") == "streamable_http":
if transport_type == "streamable_http":
test_payload = {
"jsonrpc": "2.0",
"method": "initialize",
Expand Down Expand Up @@ -121,7 +128,14 @@ def logging_callback(msg: str):
if not success:
raise Exception(error_msg)

if cfg.get("transport") != "streamable_http":
if "transport" in cfg:
transport_type = cfg["transport"]
elif "type" in cfg:
transport_type = cfg["type"]
else:
raise Exception("MCP 连接配置缺少 transport 或 type 字段")

if transport_type != "streamable_http":
# SSE transport method
self._streams_context = sse_client(
url=cfg["url"],
Expand All @@ -134,7 +148,7 @@ def logging_callback(msg: str):
)

# Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(
*streams,
Expand All @@ -159,7 +173,7 @@ def logging_callback(msg: str):
)

# Create a new client session
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(
read_stream=read_s,
Expand Down
1 change: 1 addition & 0 deletions astrbot/core/astr_agent_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ class AstrAgentContext:
first_provider_request: ProviderRequest
curr_provider_request: ProviderRequest
streaming: bool
tool_call_timeout: int = 60 # Default tool call timeout in seconds
9 changes: 9 additions & 0 deletions astrbot/core/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
"show_tool_use_status": False,
"streaming_segmented": False,
"max_agent_step": 30,
"tool_call_timeout": 60,
},
"provider_stt_settings": {
"enable": False,
Expand Down Expand Up @@ -1873,6 +1874,10 @@
"description": "工具调用轮数上限",
"type": "int",
},
"tool_call_timeout": {
"description": "工具调用超时时间(秒)",
"type": "int",
},
},
},
"provider_stt_settings": {
Expand Down Expand Up @@ -2145,6 +2150,10 @@
"description": "工具调用轮数上限",
"type": "int",
},
"provider_settings.tool_call_timeout": {
"description": "工具调用超时时间(秒)",
"type": "int",
},
"provider_settings.streaming_response": {
"description": "流式回复",
"type": "bool",
Expand Down
46 changes: 32 additions & 14 deletions astrbot/core/pipeline/process_stage/method/llm_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import copy
import json
import traceback
from datetime import timedelta
from typing import AsyncGenerator, Union
from astrbot.core.conversation_mgr import Conversation
from astrbot.core import logger
Expand Down Expand Up @@ -185,21 +186,33 @@ async def _execute_local(
handler=awaitable,
**tool_args,
)
async for resp in wrapper:
if resp is not None:
if isinstance(resp, mcp.types.CallToolResult):
yield resp
# async for resp in wrapper:
while True:
try:
resp = await asyncio.wait_for(
anext(wrapper),
timeout=run_context.context.tool_call_timeout,
)
if resp is not None:
if isinstance(resp, mcp.types.CallToolResult):
yield resp
else:
text_content = mcp.types.TextContent(
type="text",
text=str(resp),
)
yield mcp.types.CallToolResult(content=[text_content])
else:
text_content = mcp.types.TextContent(
type="text",
text=str(resp),
)
yield mcp.types.CallToolResult(content=[text_content])
else:
# NOTE: Tool 在这里直接请求发送消息给用户
# TODO: 是否需要判断 event.get_result() 是否为空?
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
yield None
# NOTE: Tool 在这里直接请求发送消息给用户
# TODO: 是否需要判断 event.get_result() 是否为空?
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
yield None
except asyncio.TimeoutError:
raise Exception(
f"tool {tool.name} execution timeout after {run_context.context.tool_call_timeout} seconds."
)
except StopAsyncIteration:
break

@classmethod
async def _execute_mcp(
Expand All @@ -217,6 +230,9 @@ async def _execute_mcp(
res = await session.call_tool(
name=tool.name,
arguments=tool_args,
read_timeout_seconds=timedelta(
seconds=run_context.context.tool_call_timeout
),
)
if not res:
return
Expand Down Expand Up @@ -307,6 +323,7 @@ async def initialize(self, ctx: PipelineContext) -> None:
)
self.streaming_response: bool = settings["streaming_response"]
self.max_step: int = settings.get("max_agent_step", 30)
self.tool_call_timeout: int = settings.get("tool_call_timeout", 60)
if isinstance(self.max_step, bool): # workaround: #2622
self.max_step = 30
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
Expand Down Expand Up @@ -473,6 +490,7 @@ async def process(
first_provider_request=req,
curr_provider_request=req,
streaming=self.streaming_response,
tool_call_timeout=self.tool_call_timeout,
)
await agent_runner.reset(
provider=provider,
Expand Down
14 changes: 14 additions & 0 deletions astrbot/dashboard/routes/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,20 @@ async def test_mcp_connection(self):
server_data = await request.json
config = server_data.get("mcp_server_config", None)

if not isinstance(config, dict) or not config:
return Response().error("无效的 MCP 服务器配置").__dict__

if "mcpServers" in config:
keys = list(config["mcpServers"].keys())
if not keys:
return Response().error("MCP 服务器配置不能为空").__dict__
if len(keys) > 1:
return Response().error("一次只能配置一个 MCP 服务器配置").__dict__
config = config["mcpServers"][keys[0]]
else:
if not config:
return Response().error("MCP 服务器配置不能为空").__dict__

tools_name = await self.tool_mgr.test_mcp_server_connection(config)
return (
Response().ok(data=tools_name, message="🎉 MCP 服务器可用!").__dict__
Expand Down
3 changes: 3 additions & 0 deletions dashboard/src/i18n/locales/en-US/features/tool-use.json
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@
"save": "Save",
"testConnection": "Test Connection",
"sync": "Sync"
},
"tips": {
"timeoutConfig": "Please configure tool call timeout separately in the configuration page"
}
},
"serverDetail": {
Expand Down
3 changes: 3 additions & 0 deletions dashboard/src/i18n/locales/zh-CN/features/tool-use.json
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@
"save": "保存",
"testConnection": "测试连接",
"sync": "同步"
},
"tips": {
"timeoutConfig": "工具调用的超时时间请前往配置页面单独配置"
}
},
"serverDetail": {
Expand Down
8 changes: 6 additions & 2 deletions dashboard/src/views/ToolUsePage.vue
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@
</v-btn>
</div>

<small style="color: grey">*{{ tm('dialogs.addServer.tips.timeoutConfig') }}</small>

<div class="monaco-container" style="margin-top: 16px;">
<VueMonacoEditor v-model:value="serverConfigJson" theme="vs-dark" language="json" :options="{
minimap: {
Expand Down Expand Up @@ -524,14 +526,16 @@ export default {
transport: "streamable_http",
url: "your mcp server url",
headers: {},
timeout: 30,
timeout: 5,
sse_read_timeout: 300,
};
} else if (type === 'sse') {
template = {
transport: "sse",
url: "your mcp server url",
headers: {},
timeout: 30,
timeout: 5,
sse_read_timeout: 300,
};
} else {
template = {
Expand Down