diff --git a/astrbot/core/agent/agent.py b/astrbot/core/agent/agent.py index 70536ca88..061ffde09 100644 --- a/astrbot/core/agent/agent.py +++ b/astrbot/core/agent/agent.py @@ -9,5 +9,5 @@ class Agent(Generic[TContext]): name: str instructions: str | None = None - tools: list[str, FunctionTool] | None = None + tools: list[str | FunctionTool] | None = None run_hooks: BaseAgentRunHooks[TContext] | None = None diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index f22a222a0..c2ed246b0 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -92,7 +92,7 @@ def __init__(self): self.session: Optional[mcp.ClientSession] = None self.exit_stack = AsyncExitStack() - self.name = None + self.name: str | None = None self.active: bool = True self.tools: list[mcp.Tool] = [] self.server_errlogs: list[str] = [] @@ -198,6 +198,8 @@ def callback(msg: str): async def list_tools_and_save(self) -> mcp.ListToolsResult: """List all tools from the server and save them to self.tools""" + if not self.session: + raise Exception("MCP Client is not initialized") response = await self.session.list_tools() self.tools = response.tools return response diff --git a/astrbot/core/agent/tool.py b/astrbot/core/agent/tool.py index 743deae1f..ae0ab761c 100644 --- a/astrbot/core/agent/tool.py +++ b/astrbot/core/agent/tool.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from deprecated import deprecated -from typing import Awaitable, Literal, Any, Optional +from typing import Awaitable, Callable, Literal, Any, Optional from .mcp_client import MCPClient @@ -8,10 +8,10 @@ class FunctionTool: """A class representing a function tool that can be used in function calling.""" - name: str | None = None + name: str parameters: dict | None = None description: str | None = None - handler: Awaitable | None = None + handler: Callable[..., Awaitable[Any]] | None = None """处理函数, 当 origin 为 mcp 时,这个为空""" handler_module_path: str | None = None """处理函数的模块路径,当 origin 为 mcp 时,这个为空 @@ -51,7 +51,7 @@ class ToolSet: This class provides methods to add, remove, and retrieve tools, as well as convert the tools to different API formats (OpenAI, Anthropic, Google GenAI).""" - def __init__(self, tools: list[FunctionTool] = None): + def __init__(self, tools: list[FunctionTool] | None = None): self.tools: list[FunctionTool] = tools or [] def empty(self) -> bool: @@ -79,7 +79,13 @@ def get_tool(self, name: str) -> Optional[FunctionTool]: return None @deprecated(reason="Use add_tool() instead", version="4.0.0") - def add_func(self, name: str, func_args: list, desc: str, handler: Awaitable): + def add_func( + self, + name: str, + func_args: list, + desc: str, + handler: Callable[..., Awaitable[Any]], + ): """Add a function tool to the set.""" params = { "type": "object", # hard-coded here @@ -104,7 +110,7 @@ def remove_func(self, name: str): self.remove_tool(name) @deprecated(reason="Use get_tool() instead", version="4.0.0") - def get_func(self, name: str) -> list[FunctionTool]: + def get_func(self, name: str) -> FunctionTool | None: """Get all function tools.""" return self.get_tool(name) @@ -125,7 +131,11 @@ def openai_schema(self, omit_empty_parameter_field: bool = False) -> list[dict]: }, } - if tool.parameters.get("properties") or not omit_empty_parameter_field: + if ( + tool.parameters + and tool.parameters.get("properties") + or not omit_empty_parameter_field + ): func_def["function"]["parameters"] = tool.parameters result.append(func_def) @@ -135,14 +145,14 @@ def anthropic_schema(self) -> list[dict]: """Convert tools to Anthropic API format.""" result = [] for tool in self.tools: + input_schema = {"type": "object"} + if tool.parameters: + input_schema["properties"] = tool.parameters.get("properties", {}) + input_schema["required"] = tool.parameters.get("required", []) tool_def = { "name": tool.name, "description": tool.description, - "input_schema": { - "type": "object", - "properties": tool.parameters.get("properties", {}), - "required": tool.parameters.get("required", []), - }, + "input_schema": input_schema, } result.append(tool_def) return result @@ -210,14 +220,15 @@ def convert_schema(schema: dict) -> dict: return result - tools = [ - { + tools = [] + for tool in self.tools: + d = { "name": tool.name, "description": tool.description, - "parameters": convert_schema(tool.parameters), } - for tool in self.tools - ] + if tool.parameters: + d["parameters"] = convert_schema(tool.parameters) + tools.append(d) declarations = {} if tools: diff --git a/astrbot/core/pipeline/content_safety_check/stage.py b/astrbot/core/pipeline/content_safety_check/stage.py index bafef1b05..e6ecd995c 100644 --- a/astrbot/core/pipeline/content_safety_check/stage.py +++ b/astrbot/core/pipeline/content_safety_check/stage.py @@ -19,7 +19,7 @@ async def initialize(self, ctx: PipelineContext): self.strategy_selector = StrategySelector(config) async def process( - self, event: AstrMessageEvent, check_text: str = None + self, event: AstrMessageEvent, check_text: str | None = None ) -> Union[None, AsyncGenerator[None, None]]: """检查内容安全""" text = check_text if check_text else event.get_message_str() diff --git a/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py b/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py index 73296b90e..26284e1a1 100644 --- a/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +++ b/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py @@ -13,7 +13,7 @@ def __init__(self, appid: str, ak: str, sk: str) -> None: self.secret_key = sk self.client = AipContentCensor(self.app_id, self.api_key, self.secret_key) - def check(self, content: str): + def check(self, content: str) -> tuple[bool, str]: res = self.client.textCensorUserDefined(content) if "conclusionType" not in res: return False, "" diff --git a/astrbot/core/pipeline/content_safety_check/strategies/keywords.py b/astrbot/core/pipeline/content_safety_check/strategies/keywords.py index de5b9c456..c65faa000 100644 --- a/astrbot/core/pipeline/content_safety_check/strategies/keywords.py +++ b/astrbot/core/pipeline/content_safety_check/strategies/keywords.py @@ -16,7 +16,7 @@ def __init__(self, extra_keywords: list) -> None: # json.loads(base64.b64decode(f.read()).decode("utf-8"))["keywords"] # ) - def check(self, content: str) -> bool: + def check(self, content: str) -> tuple[bool, str]: for keyword in self.keywords: if re.search(keyword, content): return False, "内容安全检查不通过,匹配到敏感词。" diff --git a/astrbot/core/pipeline/context_utils.py b/astrbot/core/pipeline/context_utils.py index 27b47cbe3..325abcf86 100644 --- a/astrbot/core/pipeline/context_utils.py +++ b/astrbot/core/pipeline/context_utils.py @@ -10,7 +10,7 @@ async def call_handler( event: AstrMessageEvent, - handler: T.Awaitable, + handler: T.Callable[..., T.Awaitable[T.Any]], *args, **kwargs, ) -> T.AsyncGenerator[T.Any, None]: @@ -36,6 +36,9 @@ async def call_handler( except TypeError: logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True) + if not ready_to_call: + return + if inspect.isasyncgen(ready_to_call): _has_yielded = False try: diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 6035b21e9..c4b8a7c28 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -7,6 +7,7 @@ import json import traceback from typing import AsyncGenerator, Union +from astrbot.core.conversation_mgr import Conversation from astrbot.core import logger from astrbot.core.message.components import Image from astrbot.core.message.message_event_result import ( @@ -133,6 +134,15 @@ async def _execute_handoff( if agent_runner.done(): llm_response = agent_runner.get_final_llm_resp() + + if not llm_response: + text_content = mcp.types.TextContent( + type="text", + text=f"error when deligate task to {tool.agent.name}", + ) + yield mcp.types.CallToolResult(content=[text_content]) + return + logger.debug( f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}" ) @@ -148,7 +158,7 @@ async def _execute_handoff( ) yield mcp.types.CallToolResult(content=[text_content]) else: - yield mcp.types.TextContent( + text_content = mcp.types.TextContent( type="text", text=f"error when deligate task to {tool.agent.name}", ) @@ -200,7 +210,11 @@ async def _execute_mcp( ): if not tool.mcp_client: raise ValueError("MCP client is not available for MCP function tools.") - res = await tool.mcp_client.session.call_tool( + + session = tool.mcp_client.session + if not session: + raise ValueError("MCP session is not available for MCP function tools.") + res = await session.call_tool( name=tool.name, arguments=tool_args, ) @@ -325,7 +339,7 @@ def _select_provider(self, event: AstrMessageEvent) -> Provider | None: return _ctx.get_using_provider(umo=event.unified_msg_origin) - async def _get_session_conv(self, event: AstrMessageEvent): + async def _get_session_conv(self, event: AstrMessageEvent) -> Conversation: umo = event.unified_msg_origin conv_mgr = self.conv_manager @@ -337,6 +351,8 @@ async def _get_session_conv(self, event: AstrMessageEvent): if not conversation: cid = await conv_mgr.new_conversation(umo, event.get_platform_id()) conversation = await conv_mgr.get_conversation(umo, cid) + if not conversation: + raise RuntimeError("无法创建新的对话。") return conversation async def process( @@ -444,7 +460,10 @@ async def process( if event.plugins_name is not None and req.func_tool: new_tool_set = ToolSet() for tool in req.func_tool.tools: - plugin = star_map.get(tool.handler_module_path) + mp = tool.handler_module_path + if not mp: + continue + plugin = star_map.get(mp) if not plugin: continue if plugin.name in event.plugins_name or plugin.reserved: diff --git a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py index c5c0f5738..42990aae5 100644 --- a/astrbot/core/pipeline/process_stage/method/star_request.py +++ b/astrbot/core/pipeline/process_stage/method/star_request.py @@ -34,12 +34,14 @@ async def process( for handler in activated_handlers: params = handlers_parsed_params.get(handler.handler_full_name, {}) - try: - if handler.handler_module_path not in star_map: - continue - logger.debug( - f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}" + md = star_map.get(handler.handler_module_path) + if not md: + logger.warning( + f"Cannot find plugin for given handler module path: {handler.handler_module_path}" ) + continue + logger.debug(f"plugin -> {md.name} - {handler.handler_name}") + try: wrapper = call_handler(event, handler.handler, **params) async for ret in wrapper: yield ret @@ -49,7 +51,7 @@ async def process( logger.error(f"Star {handler.handler_full_name} handle error: {e}") if event.is_at_or_wake_command: - ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}" + ret = f":(\n\n在调用插件 {md.name} 的处理函数 {handler.handler_name} 时出现异常:{e}" event.set_result(MessageEventResult().message(ret)) yield event.clear_result() diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index 8ece29a2b..85687c417 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -65,13 +65,16 @@ class AssistantMessageSegment: role: str = "assistant" def to_dict(self): - ret = { + ret: dict[str, str | list[dict]] = { "role": self.role, } if self.content: ret["content"] = self.content if self.tool_calls: - ret["tool_calls"] = self.tool_calls + tool_calls_dict = [ + tc if isinstance(tc, dict) else tc.to_dict() for tc in self.tool_calls + ] + ret["tool_calls"] = tool_calls_dict return ret @@ -117,7 +120,14 @@ class ProviderRequest: """模型名称,为 None 时使用提供商的默认模型""" def __repr__(self): - return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})" + return ( + f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, " + f"image_count={len(self.image_urls or [])}, " + f"func_tool={self.func_tool}, " + f"contexts={self._print_friendly_context()}, " + f"system_prompt={self.system_prompt}, " + f"conversation_id={self.conversation.cid if self.conversation else 'N/A'}, " + ) def __str__(self): return self.__repr__() diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 509975556..51cde0eb9 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -4,7 +4,7 @@ import asyncio import aiohttp -from typing import Dict, List, Awaitable +from typing import Dict, List, Awaitable, Callable, Any from astrbot import logger from astrbot.core import sp @@ -109,7 +109,7 @@ def spec_to_func( name: str, func_args: list, desc: str, - handler: Awaitable, + handler: Callable[..., Awaitable[Any]], ) -> FuncTool: params = { "type": "object", # hard-coded here @@ -132,7 +132,7 @@ def add_func( name: str, func_args: list, desc: str, - handler: Awaitable, + handler: Callable[..., Awaitable[Any]], ) -> None: """添加函数调用工具 @@ -220,7 +220,7 @@ async def _init_mcp_client_task_wrapper( name: str, cfg: dict, event: asyncio.Event, - ready_future: asyncio.Future = None, + ready_future: asyncio.Future | None = None, ) -> None: """初始化 MCP 客户端的包装函数,用于捕获异常""" try: diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 3b50e4976..4b9204a68 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -38,7 +38,7 @@ def __init__( """加载的 Text To Speech Provider 的实例""" self.embedding_provider_insts: List[EmbeddingProvider] = [] """加载的 Embedding Provider 的实例""" - self.inst_map: dict[str, Provider] = {} + self.inst_map: dict[str, Provider | STTProvider | TTSProvider] = {} """Provider 实例映射. key: provider_id, value: Provider 实例""" self.llm_tools = llm_tools @@ -87,19 +87,31 @@ async def set_provider( ) return # 不启用提供商会话隔离模式的情况 - self.curr_provider_inst = self.inst_map[provider_id] - if provider_type == ProviderType.TEXT_TO_SPEECH: + + prov = self.inst_map[provider_id] + if provider_type == ProviderType.TEXT_TO_SPEECH and isinstance( + prov, TTSProvider + ): + self.curr_tts_provider_inst = prov sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global") - elif provider_type == ProviderType.SPEECH_TO_TEXT: + elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance( + prov, STTProvider + ): + self.curr_stt_provider_inst = prov sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global") - elif provider_type == ProviderType.CHAT_COMPLETION: + elif provider_type == ProviderType.CHAT_COMPLETION and isinstance( + prov, Provider + ): + self.curr_provider_inst = prov sp.put("curr_provider", provider_id, scope="global", scope_id="global") async def get_provider_by_id(self, provider_id: str) -> Provider | None: """根据提供商 ID 获取提供商实例""" return self.inst_map.get(provider_id) - def get_using_provider(self, provider_type: ProviderType, umo=None): + def get_using_provider( + self, provider_type: ProviderType, umo=None + ) -> Provider | STTProvider | TTSProvider | None: """获取正在使用的提供商实例。 Args: @@ -303,12 +315,14 @@ async def load_provider(self, provider_config: dict): provider_metadata = provider_cls_map[provider_config["type"]] try: # 按任务实例化提供商 + cls_type = provider_metadata.cls_type + if not cls_type: + logger.error(f"无法找到 {provider_metadata.type} 的类") + return if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT: # STT 任务 - inst = provider_metadata.cls_type( - provider_config, self.provider_settings - ) + inst = cls_type(provider_config, self.provider_settings) if getattr(inst, "initialize", None): await inst.initialize() @@ -327,9 +341,7 @@ async def load_provider(self, provider_config: dict): elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH: # TTS 任务 - inst = provider_metadata.cls_type( - provider_config, self.provider_settings - ) + inst = cls_type(provider_config, self.provider_settings) if getattr(inst, "initialize", None): await inst.initialize() @@ -345,7 +357,7 @@ async def load_provider(self, provider_config: dict): elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION: # 文本生成任务 - inst = provider_metadata.cls_type( + inst = cls_type( provider_config, self.provider_settings, self.selected_default_persona, @@ -370,9 +382,7 @@ async def load_provider(self, provider_config: dict): ProviderType.EMBEDDING, ProviderType.RERANK, ]: - inst = provider_metadata.cls_type( - provider_config, self.provider_settings - ) + inst = cls_type(provider_config, self.provider_settings) if getattr(inst, "initialize", None): await inst.initialize() self.embedding_provider_insts.append(inst) @@ -430,11 +440,17 @@ async def terminate_provider(self, provider_id: str): ) if self.inst_map[provider_id] in self.provider_insts: - self.provider_insts.remove(self.inst_map[provider_id]) + prov_inst = self.inst_map[provider_id] + if isinstance(prov_inst, Provider): + self.provider_insts.remove(prov_inst) if self.inst_map[provider_id] in self.stt_provider_insts: - self.stt_provider_insts.remove(self.inst_map[provider_id]) + prov_inst = self.inst_map[provider_id] + if isinstance(prov_inst, STTProvider): + self.stt_provider_insts.remove(prov_inst) if self.inst_map[provider_id] in self.tts_provider_insts: - self.tts_provider_insts.remove(self.inst_map[provider_id]) + prov_inst = self.inst_map[provider_id] + if isinstance(prov_inst, TTSProvider): + self.tts_provider_insts.remove(prov_inst) if self.inst_map[provider_id] == self.curr_provider_inst: self.curr_provider_inst = None diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 76db898aa..005266a02 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -23,7 +23,7 @@ from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType from .filter.command import CommandFilter from .filter.regex import RegexFilter -from typing import Awaitable +from typing import Awaitable, Any, Callable from astrbot.core.conversation_mgr import ConversationManager from astrbot.core.star.filter.platform_adapter_type import ( PlatformAdapterType, @@ -105,7 +105,10 @@ def register_provider(self, provider: Provider): def get_provider_by_id(self, provider_id: str) -> Provider | None: """通过 ID 获取对应的 LLM Provider(Chat_Completion 类型)。""" - return self.provider_manager.inst_map.get(provider_id) + prov = self.provider_manager.inst_map.get(provider_id) + if prov and not isinstance(prov, Provider): + raise ValueError("返回的 Provider 不是 Provider 类型") + return prov def get_all_providers(self) -> List[Provider]: """获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。""" @@ -130,34 +133,43 @@ def get_using_provider(self, umo: str | None = None) -> Provider | None: Args: umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。 """ - return self.provider_manager.get_using_provider( + prov = self.provider_manager.get_using_provider( provider_type=ProviderType.CHAT_COMPLETION, umo=umo, ) + if prov and not isinstance(prov, Provider): + raise ValueError("返回的 Provider 不是 Provider 类型") + return prov - def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider: + def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider | None: """ 获取当前使用的用于 TTS 任务的 Provider。 Args: umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。 """ - return self.provider_manager.get_using_provider( + prov = self.provider_manager.get_using_provider( provider_type=ProviderType.TEXT_TO_SPEECH, umo=umo, ) + if prov and not isinstance(prov, TTSProvider): + raise ValueError("返回的 Provider 不是 TTSProvider 类型") + return prov - def get_using_stt_provider(self, umo: str | None = None) -> STTProvider: + def get_using_stt_provider(self, umo: str | None = None) -> STTProvider | None: """ 获取当前使用的用于 STT 任务的 Provider。 Args: umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。 """ - return self.provider_manager.get_using_provider( + prov = self.provider_manager.get_using_provider( provider_type=ProviderType.SPEECH_TO_TEXT, umo=umo, ) + if prov and not isinstance(prov, STTProvider): + raise ValueError("返回的 Provider 不是 STTProvider 类型") + return prov def get_config(self, umo: str | None = None) -> AstrBotConfig: """获取 AstrBot 的配置。""" @@ -245,7 +257,11 @@ async def send_message( """ def register_llm_tool( - self, name: str, func_args: list, desc: str, func_obj: Awaitable + self, + name: str, + func_args: list, + desc: str, + func_obj: Callable[..., Awaitable[Any]], ) -> None: """ 为函数调用(function-calling / tools-use)添加工具。 @@ -267,9 +283,7 @@ def register_llm_tool( desc=desc, ) star_handlers_registry.append(md) - self.provider_manager.llm_tools.add_func( - name, func_args, desc, func_obj, func_obj - ) + self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj) def unregister_llm_tool(self, name: str) -> None: """删除一个函数调用工具。如果再要启用,需要重新注册。""" @@ -281,7 +295,7 @@ def register_commands( command_name: str, desc: str, priority: int, - awaitable: Awaitable, + awaitable: Callable[..., Awaitable[Any]], use_regex=False, ignore_prefix=False, ): diff --git a/astrbot/core/star/filter/command_group.py b/astrbot/core/star/filter/command_group.py index 88d8ae64d..0b8cd6e86 100755 --- a/astrbot/core/star/filter/command_group.py +++ b/astrbot/core/star/filter/command_group.py @@ -13,8 +13,8 @@ class CommandGroupFilter(HandlerFilter): def __init__( self, group_name: str, - alias: set = None, - parent_group: CommandGroupFilter = None, + alias: set | None = None, + parent_group: CommandGroupFilter | None = None, ): self.group_name = group_name self.alias = alias if alias else set() @@ -54,8 +54,8 @@ def print_cmd_tree( self, sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]], prefix: str = "", - event: AstrMessageEvent = None, - cfg: AstrBotConfig = None, + event: AstrMessageEvent | None = None, + cfg: AstrBotConfig | None = None, ) -> str: result = "" for sub_filter in sub_command_filters: diff --git a/astrbot/core/star/filter/platform_adapter_type.py b/astrbot/core/star/filter/platform_adapter_type.py index 1634001f3..7e9dda5ba 100644 --- a/astrbot/core/star/filter/platform_adapter_type.py +++ b/astrbot/core/star/filter/platform_adapter_type.py @@ -2,7 +2,6 @@ from . import HandlerFilter from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.config import AstrBotConfig -from typing import Union class PlatformAdapterType(enum.Flag): @@ -54,11 +53,14 @@ class PlatformAdapterType(enum.Flag): class PlatformAdapterTypeFilter(HandlerFilter): - def __init__(self, platform_adapter_type_or_str: Union[PlatformAdapterType, str]): - self.type_or_str = platform_adapter_type_or_str + def __init__(self, platform_adapter_type_or_str: PlatformAdapterType | str): + if isinstance(platform_adapter_type_or_str, str): + self.platform_type = ADAPTER_NAME_2_TYPE.get(platform_adapter_type_or_str) + else: + self.platform_type = platform_adapter_type_or_str def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: adapter_name = event.get_platform_name() - if adapter_name in ADAPTER_NAME_2_TYPE: - return ADAPTER_NAME_2_TYPE[adapter_name] & self.type_or_str + if adapter_name in ADAPTER_NAME_2_TYPE and self.platform_type is not None: + return bool(ADAPTER_NAME_2_TYPE[adapter_name] & self.platform_type) return False diff --git a/astrbot/core/star/register/star.py b/astrbot/core/star/register/star.py index d4814d07c..a5190dd5c 100644 --- a/astrbot/core/star/register/star.py +++ b/astrbot/core/star/register/star.py @@ -5,7 +5,9 @@ _warned_register_star = False -def register_star(name: str, author: str, desc: str, version: str, repo: str = None): +def register_star( + name: str, author: str, desc: str, version: str, repo: str | None = None +): """注册一个插件(Star)。 [DEPRECATED] 该装饰器已废弃,将在未来版本中移除。 diff --git a/astrbot/core/star/register/star_handler.py b/astrbot/core/star/register/star_handler.py index b5ebd3f50..419c874bf 100644 --- a/astrbot/core/star/register/star_handler.py +++ b/astrbot/core/star/register/star_handler.py @@ -12,7 +12,7 @@ from ..filter.permission import PermissionTypeFilter, PermissionType from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr from ..filter.regex import RegexFilter -from typing import Awaitable +from typing import Awaitable, Any, Callable from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES from astrbot.core.provider.register import llm_tools from astrbot.core.agent.agent import Agent @@ -20,15 +20,19 @@ from astrbot.core.agent.handoff import HandoffTool from astrbot.core.agent.hooks import BaseAgentRunHooks from astrbot.core.astr_agent_context import AstrAgentContext +from astrbot.core import logger -def get_handler_full_name(awaitable: Awaitable) -> str: +def get_handler_full_name(awaitable: Callable[..., Awaitable[Any]]) -> str: """获取 Handler 的全名""" return f"{awaitable.__module__}_{awaitable.__name__}" def get_handler_or_create( - handler: Awaitable, event_type: EventType, dont_add=False, **kwargs + handler: Callable[..., Awaitable[Any]], + event_type: EventType, + dont_add=False, + **kwargs, ) -> StarHandlerMetadata: """获取 Handler 或者创建一个新的 Handler""" handler_full_name = get_handler_full_name(handler) @@ -59,22 +63,35 @@ def get_handler_or_create( def register_command( - command_name: str = None, sub_command: str = None, alias: set = None, **kwargs + command_name: str | None = None, + sub_command: str | None = None, + alias: set | None = None, + **kwargs, ): """注册一个 Command.""" new_command = None add_to_event_filters = False if isinstance(command_name, RegisteringCommandable): # 子指令 - parent_command_names = command_name.parent_group.get_complete_command_names() - new_command = CommandFilter( - sub_command, alias, None, parent_command_names=parent_command_names - ) - command_name.parent_group.add_sub_command_filter(new_command) + if sub_command is not None: + parent_command_names = ( + command_name.parent_group.get_complete_command_names() + ) + new_command = CommandFilter( + sub_command, alias, None, parent_command_names=parent_command_names + ) + command_name.parent_group.add_sub_command_filter(new_command) + else: + logger.warning( + f"注册指令{command_name} 的子指令时未提供 sub_command 参数。" + ) else: # 裸指令 - new_command = CommandFilter(command_name, alias, None) - add_to_event_filters = True + if command_name is None: + logger.warning("注册裸指令时未提供 command_name 参数。") + else: + new_command = CommandFilter(command_name, alias, None) + add_to_event_filters = True def decorator(awaitable): if not add_to_event_filters: @@ -84,8 +101,9 @@ def decorator(awaitable): handler_md = get_handler_or_create( awaitable, EventType.AdapterMessageEvent, **kwargs ) - new_command.init_handler_md(handler_md) - handler_md.event_filters.append(new_command) + if new_command: + new_command.init_handler_md(handler_md) + handler_md.event_filters.append(new_command) return awaitable return decorator @@ -163,26 +181,38 @@ def decorator(awaitable): def register_command_group( - command_group_name: str = None, sub_command: str = None, alias: set = None, **kwargs + command_group_name: str | None = None, + sub_command: str | None = None, + alias: set | None = None, + **kwargs, ): """注册一个 CommandGroup""" new_group = None if isinstance(command_group_name, RegisteringCommandable): # 子指令组 - new_group = CommandGroupFilter( - sub_command, alias, parent_group=command_group_name.parent_group - ) - command_group_name.parent_group.add_sub_command_filter(new_group) + if sub_command is None: + logger.warning(f"{command_group_name} 指令组的子指令组 sub_command 未指定") + else: + new_group = CommandGroupFilter( + sub_command, alias, parent_group=command_group_name.parent_group + ) + command_group_name.parent_group.add_sub_command_filter(new_group) else: # 根指令组 - new_group = CommandGroupFilter(command_group_name, alias) + if command_group_name is None: + logger.warning("根指令组的名称未指定") + else: + new_group = CommandGroupFilter(command_group_name, alias) def decorator(obj): # 根指令组 - handler_md = get_handler_or_create(obj, EventType.AdapterMessageEvent, **kwargs) - handler_md.event_filters.append(new_group) + if new_group: + handler_md = get_handler_or_create( + obj, EventType.AdapterMessageEvent, **kwargs + ) + handler_md.event_filters.append(new_group) - return RegisteringCommandable(new_group) + return RegisteringCommandable(new_group) return decorator @@ -323,7 +353,7 @@ def decorator(awaitable): return decorator -def register_llm_tool(name: str = None, **kwargs): +def register_llm_tool(name: str | None = None, **kwargs): """为函数调用(function-calling / tools-use)添加工具。 请务必按照以下格式编写一个工具(包括函数注释,AstrBot 会尝试解析该函数注释) @@ -361,9 +391,10 @@ async def get_weather(event: AstrMessageEvent, location: str): if kwargs.get("registering_agent"): registering_agent = kwargs["registering_agent"] - def decorator(awaitable: Awaitable): + def decorator(awaitable: Callable[..., Awaitable[Any]]): llm_tool_name = name_ if name_ else awaitable.__name__ - docstring = docstring_parser.parse(awaitable.__doc__) + func_doc = awaitable.__doc__ or "" + docstring = docstring_parser.parse(func_doc) args = [] for arg in docstring.params: if arg.type_name not in SUPPORTED_TYPES: @@ -379,20 +410,18 @@ def decorator(awaitable: Awaitable): ) # print(llm_tool_name, registering_agent) if not registering_agent: + doc_desc = docstring.description.strip() if docstring.description else "" md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent) - llm_tools.add_func( - llm_tool_name, args, docstring.description.strip(), md.handler - ) + llm_tools.add_func(llm_tool_name, args, doc_desc, md.handler) else: assert isinstance(registering_agent, RegisteringAgent) # print(f"Registering tool {llm_tool_name} for agent", registering_agent._agent.name) if registering_agent._agent.tools is None: registering_agent._agent.tools = [] - registering_agent._agent.tools.append( - llm_tools.spec_to_func( - llm_tool_name, args, docstring.description.strip(), awaitable - ) - ) + + desc = docstring.description.strip() if docstring.description else "" + tool = llm_tools.spec_to_func(llm_tool_name, args, desc, awaitable) + registering_agent._agent.tools.append(tool) return awaitable @@ -413,8 +442,8 @@ def __init__(self, agent: Agent[AstrAgentContext]): def register_agent( name: str, instruction: str, - tools: list[str | FunctionTool] = None, - run_hooks: BaseAgentRunHooks[AstrAgentContext] = None, + tools: list[str | FunctionTool] | None = None, + run_hooks: BaseAgentRunHooks[AstrAgentContext] | None = None, ): """注册一个 Agent @@ -426,7 +455,7 @@ def register_agent( """ tools_ = tools or [] - def decorator(awaitable: Awaitable): + def decorator(awaitable: Callable[..., Awaitable[Any]]): AstrAgent = Agent[AstrAgentContext] agent = AstrAgent( name=name, diff --git a/astrbot/core/star/session_plugin_manager.py b/astrbot/core/star/session_plugin_manager.py index d1fdf77c8..94a0c8a4d 100644 --- a/astrbot/core/star/session_plugin_manager.py +++ b/astrbot/core/star/session_plugin_manager.py @@ -140,6 +140,9 @@ def filter_handlers_by_session(event: AstrMessageEvent, handlers: List) -> List: filtered_handlers.append(handler) continue + if plugin.name is None: + continue + # 检查插件是否在当前会话中启用 if SessionPluginManager.is_plugin_enabled_for_session( session_id, plugin.name diff --git a/astrbot/core/star/star_handler.py b/astrbot/core/star/star_handler.py index e0a97e50a..80b5adb60 100644 --- a/astrbot/core/star/star_handler.py +++ b/astrbot/core/star/star_handler.py @@ -1,7 +1,7 @@ from __future__ import annotations import enum from dataclasses import dataclass, field -from typing import Awaitable, List, Dict, TypeVar, Generic +from typing import Callable, Awaitable, Any, List, Dict, TypeVar, Generic from .filter import HandlerFilter from .star import star_map @@ -60,7 +60,7 @@ def get_handlers_by_event_type( handlers.append(handler) return handlers - def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata: + def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata | None: return self.star_handlers_map.get(full_name, None) def get_handlers_by_module_name( @@ -87,7 +87,7 @@ def __len__(self): return len(self._handlers) -star_handlers_registry = StarHandlerRegistry() +star_handlers_registry = StarHandlerRegistry() # type: ignore class EventType(enum.Enum): @@ -123,7 +123,7 @@ class StarHandlerMetadata: handler_module_path: str """Handler 所在的模块路径。""" - handler: Awaitable + handler: Callable[..., Awaitable[Any]] """Handler 的函数对象,应当是一个异步函数""" event_filters: List[HandlerFilter] diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 91e7ef0d7..417006730 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -43,7 +43,7 @@ def __init__(self, context: Context, config: AstrBotConfig): self.updator = PluginUpdator() self.context = context - self.context._star_manager = self + self.context._star_manager = self # type: ignore self.config = config self.plugin_store_path = get_astrbot_plugin_path() @@ -478,9 +478,10 @@ async def load(self, specified_module_path=None, specified_dir_name=None): if isinstance(func_tool, HandoffTool): need_apply = [] sub_tools = func_tool.agent.tools - for sub_tool in sub_tools: - if isinstance(sub_tool, FunctionTool): - need_apply.append(sub_tool) + if sub_tools: + for sub_tool in sub_tools: + if isinstance(sub_tool, FunctionTool): + need_apply.append(sub_tool) else: need_apply = [func_tool] @@ -686,6 +687,9 @@ async def uninstall_plugin(self, plugin_name: str): ) # 从 star_registry 和 star_map 中删除 + if plugin.module_path is None or root_dir_name is None: + raise Exception(f"插件 {plugin_name} 数据不完整,无法卸载。") + await self._unbind_plugin(plugin_name, plugin.module_path) try: @@ -800,6 +804,8 @@ async def _terminate_plugin(star_metadata: StarMetadata): async def turn_on_plugin(self, plugin_name: str): plugin = self.context.get_registered_star(plugin_name) + if plugin is None: + raise Exception(f"插件 {plugin_name} 不存在。") inactivated_plugins: list = await sp.global_get("inactivated_plugins", []) inactivated_llm_tools: list = await sp.global_get("inactivated_llm_tools", []) if plugin.module_path in inactivated_plugins: diff --git a/astrbot/core/star/star_tools.py b/astrbot/core/star/star_tools.py index 14bd1ac9b..6f9dfe2fa 100644 --- a/astrbot/core/star/star_tools.py +++ b/astrbot/core/star/star_tools.py @@ -22,7 +22,7 @@ import os import uuid from pathlib import Path -from typing import Union, Awaitable, List, Optional, ClassVar +from typing import Union, Awaitable, Callable, Any, List, Optional, ClassVar from astrbot.core.message.components import BaseMessageComponent from astrbot.core.message.message_event_result import MessageChain from astrbot.api.platform import MessageMember, AstrBotMessage, MessageType @@ -221,7 +221,11 @@ def deactivate_llm_tool(cls, name: str) -> bool: @classmethod def register_llm_tool( - cls, name: str, func_args: list, desc: str, func_obj: Awaitable + cls, + name: str, + func_args: list, + desc: str, + func_obj: Callable[..., Awaitable[Any]], ) -> None: """ 为函数调用(function-calling/tools-use)添加工具 diff --git a/astrbot/core/star/updator.py b/astrbot/core/star/updator.py index 9896115bb..a22455377 100644 --- a/astrbot/core/star/updator.py +++ b/astrbot/core/star/updator.py @@ -32,6 +32,9 @@ async def update(self, plugin: StarMetadata, proxy="") -> str: if not repo_url: raise Exception(f"插件 {plugin.name} 没有指定仓库地址。") + if not plugin.root_dir_name: + raise Exception(f"插件 {plugin.name} 的根目录名未指定。") + plugin_path = os.path.join(self.plugin_store_path, plugin.root_dir_name) logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}") diff --git a/packages/web_searcher/main.py b/packages/web_searcher/main.py index c9ce6908c..4e817564c 100644 --- a/packages/web_searcher/main.py +++ b/packages/web_searcher/main.py @@ -178,7 +178,7 @@ async def _extract_tavily(self, cfg: AstrBotConfig, payload: dict) -> list[dict] return results @filter.command("websearch") - async def websearch(self, event: AstrMessageEvent, oper: str = None) -> str: + async def websearch(self, event: AstrMessageEvent, oper: str | None = None): event.set_result( MessageEventResult().message( "此指令已经被废弃,请在 WebUI 中开启或关闭网页搜索功能。" @@ -210,7 +210,7 @@ async def search_from_search_engine( processed_results = await asyncio.gather(*tasks, return_exceptions=True) ret = "" for processed_result in processed_results: - if isinstance(processed_result, Exception): + if isinstance(processed_result, BaseException): logger.error(f"Error processing search result: {processed_result}") continue ret += processed_result @@ -335,7 +335,7 @@ async def tavily_extract_web_page( @filter.on_llm_request(priority=-10000) async def edit_web_search_tools( self, event: AstrMessageEvent, req: ProviderRequest - ) -> str: + ): """Get the session conversation for the given event.""" cfg = self.context.get_config(umo=event.unified_msg_origin) prov_settings = cfg.get("provider_settings", {}) @@ -347,6 +347,9 @@ async def edit_web_search_tools( req.func_tool = tool_set.get_full_tool_set() tool_set = req.func_tool + if not tool_set: + return + if not websearch_enable: # pop tools for tool_name in self.TOOLS: @@ -372,3 +375,5 @@ async def edit_web_search_tools( tool_set.add_tool(tavily_extract_web_page) tool_set.remove_tool("web_search") tool_set.remove_tool("fetch_url") + + print(req.func_tool)