Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion astrbot/core/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
class Agent(Generic[TContext]):
name: str
instructions: str | None = None
tools: list[str, FunctionTool] | None = None
tools: list[str | FunctionTool] | None = None
run_hooks: BaseAgentRunHooks[TContext] | None = None
4 changes: 3 additions & 1 deletion astrbot/core/agent/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __init__(self):
self.session: Optional[mcp.ClientSession] = None
self.exit_stack = AsyncExitStack()

self.name = None
self.name: str | None = None
self.active: bool = True
self.tools: list[mcp.Tool] = []
self.server_errlogs: list[str] = []
Expand Down Expand Up @@ -198,6 +198,8 @@ def callback(msg: str):

async def list_tools_and_save(self) -> mcp.ListToolsResult:
"""List all tools from the server and save them to self.tools"""
if not self.session:
raise Exception("MCP Client is not initialized")
response = await self.session.list_tools()
self.tools = response.tools
return response
Expand Down
45 changes: 28 additions & 17 deletions astrbot/core/agent/tool.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from dataclasses import dataclass
from deprecated import deprecated
from typing import Awaitable, Literal, Any, Optional
from typing import Awaitable, Callable, Literal, Any, Optional
from .mcp_client import MCPClient


@dataclass
class FunctionTool:
"""A class representing a function tool that can be used in function calling."""

name: str | None = None
name: str
parameters: dict | None = None
description: str | None = None
handler: Awaitable | None = None
handler: Callable[..., Awaitable[Any]] | None = None
"""处理函数, 当 origin 为 mcp 时,这个为空"""
handler_module_path: str | None = None
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
Expand Down Expand Up @@ -51,7 +51,7 @@ class ToolSet:
This class provides methods to add, remove, and retrieve tools, as well as
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI)."""

def __init__(self, tools: list[FunctionTool] = None):
def __init__(self, tools: list[FunctionTool] | None = None):
self.tools: list[FunctionTool] = tools or []

def empty(self) -> bool:
Expand Down Expand Up @@ -79,7 +79,13 @@ def get_tool(self, name: str) -> Optional[FunctionTool]:
return None

@deprecated(reason="Use add_tool() instead", version="4.0.0")
def add_func(self, name: str, func_args: list, desc: str, handler: Awaitable):
def add_func(
self,
name: str,
func_args: list,
desc: str,
handler: Callable[..., Awaitable[Any]],
):
"""Add a function tool to the set."""
params = {
"type": "object", # hard-coded here
Expand All @@ -104,7 +110,7 @@ def remove_func(self, name: str):
self.remove_tool(name)

@deprecated(reason="Use get_tool() instead", version="4.0.0")
def get_func(self, name: str) -> list[FunctionTool]:
def get_func(self, name: str) -> FunctionTool | None:
"""Get all function tools."""
return self.get_tool(name)

Expand All @@ -125,7 +131,11 @@ def openai_schema(self, omit_empty_parameter_field: bool = False) -> list[dict]:
},
}

if tool.parameters.get("properties") or not omit_empty_parameter_field:
if (
tool.parameters
and tool.parameters.get("properties")
or not omit_empty_parameter_field
):
func_def["function"]["parameters"] = tool.parameters

result.append(func_def)
Expand All @@ -135,14 +145,14 @@ def anthropic_schema(self) -> list[dict]:
"""Convert tools to Anthropic API format."""
result = []
for tool in self.tools:
input_schema = {"type": "object"}
if tool.parameters:
input_schema["properties"] = tool.parameters.get("properties", {})
input_schema["required"] = tool.parameters.get("required", [])
tool_def = {
"name": tool.name,
"description": tool.description,
"input_schema": {
"type": "object",
"properties": tool.parameters.get("properties", {}),
"required": tool.parameters.get("required", []),
},
"input_schema": input_schema,
}
result.append(tool_def)
return result
Expand Down Expand Up @@ -210,14 +220,15 @@ def convert_schema(schema: dict) -> dict:

return result

tools = [
{
tools = []
for tool in self.tools:
d = {
"name": tool.name,
"description": tool.description,
"parameters": convert_schema(tool.parameters),
}
for tool in self.tools
]
if tool.parameters:
d["parameters"] = convert_schema(tool.parameters)
tools.append(d)

declarations = {}
if tools:
Expand Down
2 changes: 1 addition & 1 deletion astrbot/core/pipeline/content_safety_check/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ async def initialize(self, ctx: PipelineContext):
self.strategy_selector = StrategySelector(config)

async def process(
self, event: AstrMessageEvent, check_text: str = None
self, event: AstrMessageEvent, check_text: str | None = None
) -> Union[None, AsyncGenerator[None, None]]:
"""检查内容安全"""
text = check_text if check_text else event.get_message_str()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self, appid: str, ak: str, sk: str) -> None:
self.secret_key = sk
self.client = AipContentCensor(self.app_id, self.api_key, self.secret_key)

def check(self, content: str):
def check(self, content: str) -> tuple[bool, str]:
res = self.client.textCensorUserDefined(content)
if "conclusionType" not in res:
return False, ""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, extra_keywords: list) -> None:
# json.loads(base64.b64decode(f.read()).decode("utf-8"))["keywords"]
# )

def check(self, content: str) -> bool:
def check(self, content: str) -> tuple[bool, str]:
for keyword in self.keywords:
if re.search(keyword, content):
return False, "内容安全检查不通过,匹配到敏感词。"
Expand Down
5 changes: 4 additions & 1 deletion astrbot/core/pipeline/context_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

async def call_handler(
event: AstrMessageEvent,
handler: T.Awaitable,
handler: T.Callable[..., T.Awaitable[T.Any]],
*args,
**kwargs,
) -> T.AsyncGenerator[T.Any, None]:
Expand All @@ -36,6 +36,9 @@ async def call_handler(
except TypeError:
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)

if not ready_to_call:
return

if inspect.isasyncgen(ready_to_call):
_has_yielded = False
try:
Expand Down
27 changes: 23 additions & 4 deletions astrbot/core/pipeline/process_stage/method/llm_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import json
import traceback
from typing import AsyncGenerator, Union
from astrbot.core.conversation_mgr import Conversation
from astrbot.core import logger
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import (
Expand Down Expand Up @@ -133,6 +134,15 @@ async def _execute_handoff(

if agent_runner.done():
llm_response = agent_runner.get_final_llm_resp()

if not llm_response:
text_content = mcp.types.TextContent(
type="text",
text=f"error when deligate task to {tool.agent.name}",
)
yield mcp.types.CallToolResult(content=[text_content])
return

logger.debug(
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}"
)
Expand All @@ -148,7 +158,7 @@ async def _execute_handoff(
)
yield mcp.types.CallToolResult(content=[text_content])
else:
yield mcp.types.TextContent(
text_content = mcp.types.TextContent(
type="text",
text=f"error when deligate task to {tool.agent.name}",
)
Expand Down Expand Up @@ -200,7 +210,11 @@ async def _execute_mcp(
):
if not tool.mcp_client:
raise ValueError("MCP client is not available for MCP function tools.")
res = await tool.mcp_client.session.call_tool(

session = tool.mcp_client.session
if not session:
raise ValueError("MCP session is not available for MCP function tools.")
res = await session.call_tool(
name=tool.name,
arguments=tool_args,
)
Expand Down Expand Up @@ -325,7 +339,7 @@ def _select_provider(self, event: AstrMessageEvent) -> Provider | None:

return _ctx.get_using_provider(umo=event.unified_msg_origin)

async def _get_session_conv(self, event: AstrMessageEvent):
async def _get_session_conv(self, event: AstrMessageEvent) -> Conversation:
umo = event.unified_msg_origin
conv_mgr = self.conv_manager

Expand All @@ -337,6 +351,8 @@ async def _get_session_conv(self, event: AstrMessageEvent):
if not conversation:
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
conversation = await conv_mgr.get_conversation(umo, cid)
if not conversation:
raise RuntimeError("无法创建新的对话。")
return conversation

async def process(
Expand Down Expand Up @@ -444,7 +460,10 @@ async def process(
if event.plugins_name is not None and req.func_tool:
new_tool_set = ToolSet()
for tool in req.func_tool.tools:
plugin = star_map.get(tool.handler_module_path)
mp = tool.handler_module_path
if not mp:
continue
plugin = star_map.get(mp)
if not plugin:
continue
if plugin.name in event.plugins_name or plugin.reserved:
Expand Down
14 changes: 8 additions & 6 deletions astrbot/core/pipeline/process_stage/method/star_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,14 @@ async def process(

for handler in activated_handlers:
params = handlers_parsed_params.get(handler.handler_full_name, {})
try:
if handler.handler_module_path not in star_map:
continue
logger.debug(
f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}"
md = star_map.get(handler.handler_module_path)
if not md:
logger.warning(
f"Cannot find plugin for given handler module path: {handler.handler_module_path}"
)
continue
logger.debug(f"plugin -> {md.name} - {handler.handler_name}")
try:
wrapper = call_handler(event, handler.handler, **params)
async for ret in wrapper:
yield ret
Expand All @@ -49,7 +51,7 @@ async def process(
logger.error(f"Star {handler.handler_full_name} handle error: {e}")

if event.is_at_or_wake_command:
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
ret = f":(\n\n在调用插件 {md.name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
event.set_result(MessageEventResult().message(ret))
yield
event.clear_result()
Expand Down
16 changes: 13 additions & 3 deletions astrbot/core/provider/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,16 @@ class AssistantMessageSegment:
role: str = "assistant"

def to_dict(self):
ret = {
ret: dict[str, str | list[dict]] = {
"role": self.role,
}
if self.content:
ret["content"] = self.content
if self.tool_calls:
ret["tool_calls"] = self.tool_calls
tool_calls_dict = [
tc if isinstance(tc, dict) else tc.to_dict() for tc in self.tool_calls
]
ret["tool_calls"] = tool_calls_dict
return ret


Expand Down Expand Up @@ -117,7 +120,14 @@ class ProviderRequest:
"""模型名称,为 None 时使用提供商的默认模型"""

def __repr__(self):
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
return (
f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, "
f"image_count={len(self.image_urls or [])}, "
f"func_tool={self.func_tool}, "
f"contexts={self._print_friendly_context()}, "
f"system_prompt={self.system_prompt}, "
f"conversation_id={self.conversation.cid if self.conversation else 'N/A'}, "
)

def __str__(self):
return self.__repr__()
Expand Down
8 changes: 4 additions & 4 deletions astrbot/core/provider/func_tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import asyncio
import aiohttp

from typing import Dict, List, Awaitable
from typing import Dict, List, Awaitable, Callable, Any
from astrbot import logger
from astrbot.core import sp

Expand Down Expand Up @@ -109,7 +109,7 @@ def spec_to_func(
name: str,
func_args: list,
desc: str,
handler: Awaitable,
handler: Callable[..., Awaitable[Any]],
) -> FuncTool:
params = {
"type": "object", # hard-coded here
Expand All @@ -132,7 +132,7 @@ def add_func(
name: str,
func_args: list,
desc: str,
handler: Awaitable,
handler: Callable[..., Awaitable[Any]],
) -> None:
"""添加函数调用工具

Expand Down Expand Up @@ -220,7 +220,7 @@ async def _init_mcp_client_task_wrapper(
name: str,
cfg: dict,
event: asyncio.Event,
ready_future: asyncio.Future = None,
ready_future: asyncio.Future | None = None,
) -> None:
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
try:
Expand Down
Loading