From c530489744b9d61cd0a7fcb853979ddaa7087a78 Mon Sep 17 00:00:00 2001 From: Dt8333 Date: Tue, 28 Oct 2025 09:33:15 +0800 Subject: [PATCH 01/44] chore(cli): ruff rewrite --- astrbot/cli/commands/cmd_conf.py | 4 +++- astrbot/cli/commands/cmd_plug.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/astrbot/cli/commands/cmd_conf.py b/astrbot/cli/commands/cmd_conf.py index fea654f20..b556536be 100644 --- a/astrbot/cli/commands/cmd_conf.py +++ b/astrbot/cli/commands/cmd_conf.py @@ -2,7 +2,9 @@ import click import hashlib import zoneinfo -from typing import Any, Callable +from typing import Any + +from collections.abc import Callable from ..utils import get_astrbot_root, check_astrbot_root diff --git a/astrbot/cli/commands/cmd_plug.py b/astrbot/cli/commands/cmd_plug.py index b250ede4b..7513a7389 100644 --- a/astrbot/cli/commands/cmd_plug.py +++ b/astrbot/cli/commands/cmd_plug.py @@ -86,7 +86,7 @@ def new(name: str): f.write(f"# {name}\n\n{desc}\n\n# 支持\n\n[帮助文档](https://astrbot.app)\n") # 重写 main.py - with open(plug_path / "main.py", "r", encoding="utf-8") as f: + with open(plug_path / "main.py", encoding="utf-8") as f: content = f.read() new_content = content.replace( From 8b6d235107f1cbe783ab484d3b5831444ebddb2b Mon Sep 17 00:00:00 2001 From: Dt8333 Date: Tue, 28 Oct 2025 09:33:42 +0800 Subject: [PATCH 02/44] chore(cli): add missing type annotations --- astrbot/cli/commands/cmd_conf.py | 10 +++++----- astrbot/cli/commands/cmd_init.py | 4 +++- astrbot/cli/commands/cmd_plug.py | 20 ++++++++++++-------- astrbot/cli/commands/cmd_run.py | 2 +- astrbot/cli/utils/plugin.py | 2 +- astrbot/cli/utils/version_comparator.py | 4 ++-- 6 files changed, 24 insertions(+), 18 deletions(-) diff --git a/astrbot/cli/commands/cmd_conf.py b/astrbot/cli/commands/cmd_conf.py index b556536be..090b28ec6 100644 --- a/astrbot/cli/commands/cmd_conf.py +++ b/astrbot/cli/commands/cmd_conf.py @@ -102,7 +102,7 @@ def _save_config(config: dict[str, Any]) -> None: ) -def _set_nested_item(obj: dict[str, Any], path: str, value: Any) -> None: +def _set_nested_item(obj: dict[str, Any], path: str, value: object) -> None: """设置嵌套字典中的值""" parts = path.split(".") for part in parts[:-1]: @@ -116,7 +116,7 @@ def _set_nested_item(obj: dict[str, Any], path: str, value: Any) -> None: obj[parts[-1]] = value -def _get_nested_item(obj: dict[str, Any], path: str) -> Any: +def _get_nested_item(obj: dict[str, Any], path: str) -> object: """获取嵌套字典中的值""" parts = path.split(".") for part in parts: @@ -125,7 +125,7 @@ def _get_nested_item(obj: dict[str, Any], path: str) -> Any: @click.group(name="conf") -def conf(): +def conf() -> None: """配置管理命令 支持的配置项: @@ -148,7 +148,7 @@ def conf(): @conf.command(name="set") @click.argument("key") @click.argument("value") -def set_config(key: str, value: str): +def set_config(key: str, value: str) -> None: """设置配置项的值""" if key not in CONFIG_VALIDATORS.keys(): raise click.ClickException(f"不支持的配置项: {key}") @@ -177,7 +177,7 @@ def set_config(key: str, value: str): @conf.command(name="get") @click.argument("key", required=False) -def get_config(key: str = None): +def get_config(key: str | None = None) -> None: """获取配置项的值,不提供key则显示所有可配置项""" config = _load_config() diff --git a/astrbot/cli/commands/cmd_init.py b/astrbot/cli/commands/cmd_init.py index d9a42f822..f9724c55a 100644 --- a/astrbot/cli/commands/cmd_init.py +++ b/astrbot/cli/commands/cmd_init.py @@ -1,3 +1,5 @@ +from pathlib import Path + import asyncio import click @@ -6,7 +8,7 @@ from ..utils import check_dashboard, get_astrbot_root -async def initialize_astrbot(astrbot_root) -> None: +async def initialize_astrbot(astrbot_root: Path) -> None: """执行 AstrBot 初始化逻辑""" dot_astrbot = astrbot_root / ".astrbot" diff --git a/astrbot/cli/commands/cmd_plug.py b/astrbot/cli/commands/cmd_plug.py index 7513a7389..d554fff59 100644 --- a/astrbot/cli/commands/cmd_plug.py +++ b/astrbot/cli/commands/cmd_plug.py @@ -16,7 +16,7 @@ @click.group() -def plug(): +def plug() -> None: """插件管理""" pass @@ -30,7 +30,11 @@ def _get_data_path() -> Path: return (base / "data").resolve() -def display_plugins(plugins, title=None, color=None): +def display_plugins( + plugins: list[dict], + title: str | None = None, + color: int | tuple[int, int, int] | str | None = None, +) -> None: if title: click.echo(click.style(title, fg=color, bold=True)) @@ -47,7 +51,7 @@ def display_plugins(plugins, title=None, color=None): @plug.command() @click.argument("name") -def new(name: str): +def new(name: str) -> None: """创建新插件""" base_path = _get_data_path() plug_path = base_path / "plugins" / name @@ -102,7 +106,7 @@ def new(name: str): @plug.command() @click.option("--all", "-a", is_flag=True, help="列出未安装的插件") -def list(all: bool): +def list(all: bool) -> None: """列出插件""" base_path = _get_data_path() plugins = build_plug_list(base_path / "plugins") @@ -143,7 +147,7 @@ def list(all: bool): @plug.command() @click.argument("name") @click.option("--proxy", help="代理服务器地址") -def install(name: str, proxy: str | None): +def install(name: str, proxy: str | None) -> None: """安装插件""" base_path = _get_data_path() plug_path = base_path / "plugins" @@ -166,7 +170,7 @@ def install(name: str, proxy: str | None): @plug.command() @click.argument("name") -def remove(name: str): +def remove(name: str) -> None: """卸载插件""" base_path = _get_data_path() plugins = build_plug_list(base_path / "plugins") @@ -189,7 +193,7 @@ def remove(name: str): @plug.command() @click.argument("name", required=False) @click.option("--proxy", help="Github代理地址") -def update(name: str, proxy: str | None): +def update(name: str, proxy: str | None) -> None: """更新插件""" base_path = _get_data_path() plug_path = base_path / "plugins" @@ -227,7 +231,7 @@ def update(name: str, proxy: str | None): @plug.command() @click.argument("query") -def search(query: str): +def search(query: str) -> None: """搜索插件""" base_path = _get_data_path() plugins = build_plug_list(base_path / "plugins") diff --git a/astrbot/cli/commands/cmd_run.py b/astrbot/cli/commands/cmd_run.py index 38113744f..d707e401c 100644 --- a/astrbot/cli/commands/cmd_run.py +++ b/astrbot/cli/commands/cmd_run.py @@ -11,7 +11,7 @@ from ..utils import check_dashboard, check_astrbot_root, get_astrbot_root -async def run_astrbot(astrbot_root: Path): +async def run_astrbot(astrbot_root: Path) -> None: """运行 AstrBot""" from astrbot.core import logger, LogManager, LogBroker, db_helper from astrbot.core.initial_loader import InitialLoader diff --git a/astrbot/cli/utils/plugin.py b/astrbot/cli/utils/plugin.py index cd1fcd97b..12d0af03c 100644 --- a/astrbot/cli/utils/plugin.py +++ b/astrbot/cli/utils/plugin.py @@ -19,7 +19,7 @@ class PluginStatus(str, Enum): NOT_PUBLISHED = "未发布" -def get_git_repo(url: str, target_path: Path, proxy: str | None = None): +def get_git_repo(url: str, target_path: Path, proxy: str | None = None) -> None: """从 Git 仓库下载代码并解压到指定路径""" temp_dir = Path(tempfile.mkdtemp()) try: diff --git a/astrbot/cli/utils/version_comparator.py b/astrbot/cli/utils/version_comparator.py index fecab885e..505d01e0d 100644 --- a/astrbot/cli/utils/version_comparator.py +++ b/astrbot/cli/utils/version_comparator.py @@ -17,7 +17,7 @@ def compare_version(v1: str, v2: str) -> int: v1 = v1.lower().replace("v", "") v2 = v2.lower().replace("v", "") - def split_version(version): + def split_version(version: str) -> tuple[list[int], list[int | str] | None]: match = re.match( r"^([0-9]+(?:\.[0-9]+)*)(?:-([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+(.+))?$", version, @@ -79,7 +79,7 @@ def split_version(version): return 0 # 数字部分和预发布标签都相同 @staticmethod - def _split_prerelease(prerelease): + def _split_prerelease(prerelease: str) -> list[int | str] | None: if not prerelease: return None parts = prerelease.split(".") From ffc535f3b4beef10de1b41555765a10c0e2344f3 Mon Sep 17 00:00:00 2001 From: Dt8333 Date: Tue, 28 Oct 2025 11:05:04 +0800 Subject: [PATCH 03/44] chore(core.agent): ruff rewrite --- astrbot/core/agent/mcp_client.py | 3 +-- astrbot/core/agent/tool.py | 6 ++++-- astrbot/core/agent/tool_executor.py | 4 +++- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index 8db9d6f26..eb9c738e7 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -1,7 +1,6 @@ import asyncio import logging from datetime import timedelta -from typing import Optional from contextlib import AsyncExitStack from astrbot import logger from astrbot.core.utils.log_pipe import LogPipe @@ -96,7 +95,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: class MCPClient: def __init__(self): # Initialize session and client objects - self.session: Optional[mcp.ClientSession] = None + self.session: mcp.ClientSession | None = None self.exit_stack = AsyncExitStack() self.name: str | None = None diff --git a/astrbot/core/agent/tool.py b/astrbot/core/agent/tool.py index ae0ab761c..c78922649 100644 --- a/astrbot/core/agent/tool.py +++ b/astrbot/core/agent/tool.py @@ -1,6 +1,8 @@ from dataclasses import dataclass from deprecated import deprecated -from typing import Awaitable, Callable, Literal, Any, Optional +from typing import Literal, Any + +from collections.abc import Awaitable, Callable from .mcp_client import MCPClient @@ -71,7 +73,7 @@ def remove_tool(self, name: str): """Remove a tool by its name.""" self.tools = [tool for tool in self.tools if tool.name != name] - def get_tool(self, name: str) -> Optional[FunctionTool]: + def get_tool(self, name: str) -> FunctionTool | None: """Get a tool by its name.""" for tool in self.tools: if tool.name == name: diff --git a/astrbot/core/agent/tool_executor.py b/astrbot/core/agent/tool_executor.py index 34a2f5e77..793e37239 100644 --- a/astrbot/core/agent/tool_executor.py +++ b/astrbot/core/agent/tool_executor.py @@ -1,5 +1,7 @@ import mcp -from typing import Any, Generic, AsyncGenerator +from typing import Any, Generic + +from collections.abc import AsyncGenerator from .run_context import TContext, ContextWrapper from .tool import FunctionTool From 1e592257efede551c8bbbd6561eeefc2c01452fa Mon Sep 17 00:00:00 2001 From: Dt8333 Date: Tue, 28 Oct 2025 11:08:02 +0800 Subject: [PATCH 04/44] chore(core.agent): add missing type annotations --- astrbot/core/agent/handoff.py | 9 ++++-- astrbot/core/agent/hooks.py | 8 ++--- astrbot/core/agent/mcp_client.py | 10 +++--- astrbot/core/agent/runners/base.py | 2 +- .../agent/runners/tool_loop_agent_runner.py | 2 +- astrbot/core/agent/tool.py | 32 +++++++++++-------- astrbot/core/agent/tool_executor.py | 5 ++- 7 files changed, 39 insertions(+), 29 deletions(-) diff --git a/astrbot/core/agent/handoff.py b/astrbot/core/agent/handoff.py index d26463147..d218bd503 100644 --- a/astrbot/core/agent/handoff.py +++ b/astrbot/core/agent/handoff.py @@ -1,4 +1,4 @@ -from typing import Generic +from typing import Any, Generic from .tool import FunctionTool from .agent import Agent from .run_context import TContext @@ -8,8 +8,11 @@ class HandoffTool(FunctionTool, Generic[TContext]): """Handoff tool for delegating tasks to another agent.""" def __init__( - self, agent: Agent[TContext], parameters: dict | None = None, **kwargs - ): + self, + agent: Agent[TContext], + parameters: dict | None = None, + **kwargs: Any, # noqa: ANN401 + ) -> None: self.agent = agent super().__init__( name=f"transfer_to_{agent.name}", diff --git a/astrbot/core/agent/hooks.py b/astrbot/core/agent/hooks.py index 884fe6bd4..f26baa45e 100644 --- a/astrbot/core/agent/hooks.py +++ b/astrbot/core/agent/hooks.py @@ -8,20 +8,20 @@ @dataclass class BaseAgentRunHooks(Generic[TContext]): - async def on_agent_begin(self, run_context: ContextWrapper[TContext]): ... + async def on_agent_begin(self, run_context: ContextWrapper[TContext]) -> None: ... async def on_tool_start( self, run_context: ContextWrapper[TContext], tool: FunctionTool, tool_args: dict | None, - ): ... + ) -> None: ... async def on_tool_end( self, run_context: ContextWrapper[TContext], tool: FunctionTool, tool_args: dict | None, tool_result: mcp.types.CallToolResult | None, - ): ... + ) -> None: ... async def on_agent_done( self, run_context: ContextWrapper[TContext], llm_response: LLMResponse - ): ... + ) -> None: ... diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index eb9c738e7..7bbb67235 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -93,7 +93,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: class MCPClient: - def __init__(self): + def __init__(self) -> None: # Initialize session and client objects self.session: mcp.ClientSession | None = None self.exit_stack = AsyncExitStack() @@ -104,7 +104,7 @@ def __init__(self): self.server_errlogs: list[str] = [] self.running_event = asyncio.Event() - async def connect_to_server(self, mcp_server_config: dict, name: str): + async def connect_to_server(self, mcp_server_config: dict, name: str) -> None: """连接到 MCP 服务器 如果 `url` 参数存在: @@ -117,7 +117,7 @@ async def connect_to_server(self, mcp_server_config: dict, name: str): """ cfg = _prepare_config(mcp_server_config.copy()) - def logging_callback(msg: str): + def logging_callback(msg: str) -> None: # 处理 MCP 服务的错误日志 print(f"MCP Server {name} Error: {msg}") self.server_errlogs.append(msg) @@ -187,7 +187,7 @@ def logging_callback(msg: str): **cfg, ) - def callback(msg: str): + def callback(msg: str) -> None: # 处理 MCP 服务的错误日志 self.server_errlogs.append(msg) @@ -217,7 +217,7 @@ async def list_tools_and_save(self) -> mcp.ListToolsResult: self.tools = response.tools return response - async def cleanup(self): + async def cleanup(self) -> None: """Clean up resources""" await self.exit_stack.aclose() self.running_event.set() # Set the running event to indicate cleanup is done diff --git a/astrbot/core/agent/runners/base.py b/astrbot/core/agent/runners/base.py index 83821ae29..182e498d5 100644 --- a/astrbot/core/agent/runners/base.py +++ b/astrbot/core/agent/runners/base.py @@ -26,7 +26,7 @@ async def reset( run_context: ContextWrapper[TContext], tool_executor: BaseFunctionToolExecutor[TContext], agent_hooks: BaseAgentRunHooks[TContext], - **kwargs: T.Any, + **kwargs: T.Any, # noqa: ANN401 ) -> None: """ Reset the agent to its initial state. diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 33298e895..dc0230e29 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -69,7 +69,7 @@ async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]: yield await self.provider.text_chat(**self.req.__dict__) @override - async def step(self): + async def step(self) -> T.AsyncGenerator[AgentResponse, None]: """ Process a single step of the agent. This method should return the result of the step. diff --git a/astrbot/core/agent/tool.py b/astrbot/core/agent/tool.py index c78922649..f4aa1ae4d 100644 --- a/astrbot/core/agent/tool.py +++ b/astrbot/core/agent/tool.py @@ -2,6 +2,8 @@ from deprecated import deprecated from typing import Literal, Any +from collections.abc import Iterator + from collections.abc import Awaitable, Callable from .mcp_client import MCPClient @@ -32,7 +34,7 @@ class FunctionTool: mcp_client: MCPClient | None = None """MCP 客户端,当 origin 为 mcp 时有效""" - def __repr__(self): + def __repr__(self) -> str: return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})" def __dict__(self) -> dict[str, Any]: @@ -53,14 +55,14 @@ class ToolSet: This class provides methods to add, remove, and retrieve tools, as well as convert the tools to different API formats (OpenAI, Anthropic, Google GenAI).""" - def __init__(self, tools: list[FunctionTool] | None = None): + def __init__(self, tools: list[FunctionTool] | None = None) -> None: self.tools: list[FunctionTool] = tools or [] def empty(self) -> bool: """Check if the tool set is empty.""" return len(self.tools) == 0 - def add_tool(self, tool: FunctionTool): + def add_tool(self, tool: FunctionTool) -> None: """Add a tool to the set.""" # 检查是否已存在同名工具 for i, existing_tool in enumerate(self.tools): @@ -69,7 +71,7 @@ def add_tool(self, tool: FunctionTool): return self.tools.append(tool) - def remove_tool(self, name: str): + def remove_tool(self, name: str) -> None: """Remove a tool by its name.""" self.tools = [tool for tool in self.tools if tool.name != name] @@ -87,7 +89,7 @@ def add_func( func_args: list, desc: str, handler: Callable[..., Awaitable[Any]], - ): + ) -> None: """Add a function tool to the set.""" params = { "type": "object", # hard-coded here @@ -107,7 +109,7 @@ def add_func( self.add_tool(_func) @deprecated(reason="Use remove_tool() instead", version="4.0.0") - def remove_func(self, name: str): + def remove_func(self, name: str) -> None: """Remove a function tool by its name.""" self.remove_tool(name) @@ -238,32 +240,34 @@ def convert_schema(schema: dict) -> dict: return declarations @deprecated(reason="Use openai_schema() instead", version="4.0.0") - def get_func_desc_openai_style(self, omit_empty_parameter_field: bool = False): + def get_func_desc_openai_style( + self, omit_empty_parameter_field: bool = False + ) -> list[dict]: return self.openai_schema(omit_empty_parameter_field) @deprecated(reason="Use anthropic_schema() instead", version="4.0.0") - def get_func_desc_anthropic_style(self): + def get_func_desc_anthropic_style(self) -> list[dict]: return self.anthropic_schema() @deprecated(reason="Use google_schema() instead", version="4.0.0") - def get_func_desc_google_genai_style(self): + def get_func_desc_google_genai_style(self) -> dict: return self.google_schema() def names(self) -> list[str]: """获取所有工具的名称列表""" return [tool.name for tool in self.tools] - def __len__(self): + def __len__(self) -> int: return len(self.tools) - def __bool__(self): + def __bool__(self) -> bool: return len(self.tools) > 0 - def __iter__(self): + def __iter__(self) -> Iterator: return iter(self.tools) - def __repr__(self): + def __repr__(self) -> str: return f"ToolSet(tools={self.tools})" - def __str__(self): + def __str__(self) -> str: return f"ToolSet(tools={self.tools})" diff --git a/astrbot/core/agent/tool_executor.py b/astrbot/core/agent/tool_executor.py index 793e37239..39578c4ed 100644 --- a/astrbot/core/agent/tool_executor.py +++ b/astrbot/core/agent/tool_executor.py @@ -9,5 +9,8 @@ class BaseFunctionToolExecutor(Generic[TContext]): @classmethod async def execute( - cls, tool: FunctionTool, run_context: ContextWrapper[TContext], **tool_args + cls, + tool: FunctionTool, + run_context: ContextWrapper[TContext], + **tool_args: Any, # noqa: ANN401 ) -> AsyncGenerator[Any | mcp.types.CallToolResult, None]: ... From a60d6fe94817b479cfcdd2ebf0ebf7e7edb2b447 Mon Sep 17 00:00:00 2001 From: Dt8333 Date: Tue, 28 Oct 2025 11:11:24 +0800 Subject: [PATCH 05/44] =?UTF-8?q?fix(core):=20=E9=87=8D=E5=91=BD=E5=90=8D?= =?UTF-8?q?=5F=5Fdict=5F=5F=E6=96=B9=E6=B3=95=E9=81=BF=E5=85=8D=E7=B1=BB?= =?UTF-8?q?=E5=9E=8B=E5=86=B2=E7=AA=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 将FunctionTool类的__dict__方法重命名为to_dict,修复Pylance报告的类型不兼容错误 --- astrbot/core/agent/tool.py | 2 +- astrbot/dashboard/routes/tools.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/astrbot/core/agent/tool.py b/astrbot/core/agent/tool.py index f4aa1ae4d..cdaeba028 100644 --- a/astrbot/core/agent/tool.py +++ b/astrbot/core/agent/tool.py @@ -37,7 +37,7 @@ class FunctionTool: def __repr__(self) -> str: return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})" - def __dict__(self) -> dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """将 FunctionTool 转换为字典格式""" return { "name": self.name, diff --git a/astrbot/dashboard/routes/tools.py b/astrbot/dashboard/routes/tools.py index 8fd89919a..38c3defd2 100644 --- a/astrbot/dashboard/routes/tools.py +++ b/astrbot/dashboard/routes/tools.py @@ -300,7 +300,7 @@ async def get_tool_list(self): """获取所有注册的工具列表""" try: tools = self.tool_mgr.func_list - tools_dict = [tool.__dict__() for tool in tools] + tools_dict = [tool.to_dict() for tool in tools] return Response().ok(data=tools_dict).__dict__ except Exception as e: logger.error(traceback.format_exc()) From 5ea6e13f9f66827ffbc411b72cb8e8d46f9eee94 Mon Sep 17 00:00:00 2001 From: Dt8333 Date: Tue, 28 Oct 2025 11:40:12 +0800 Subject: [PATCH 06/44] chore(core.config): ruff rewrite --- astrbot/core/config/astrbot_config.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py index 5d1f6fbe7..48f76d6e6 100644 --- a/astrbot/core/config/astrbot_config.py +++ b/astrbot/core/config/astrbot_config.py @@ -3,7 +3,6 @@ import logging import enum from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP -from typing import Dict from astrbot.core.utils.astrbot_path import get_astrbot_data_path ASTRBOT_CONFIG_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json") @@ -45,7 +44,7 @@ def __init__( json.dump(default_config, f, indent=4, ensure_ascii=False) object.__setattr__(self, "first_deploy", True) # 标记第一次部署 - with open(config_path, "r", encoding="utf-8-sig") as f: + with open(config_path, encoding="utf-8-sig") as f: conf_str = f.read() conf = json.loads(conf_str) @@ -82,7 +81,7 @@ def _parse_schema(schema: dict, conf: dict): return conf - def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""): + def check_config_integrity(self, refer_conf: dict, conf: dict, path=""): """检查配置完整性,如果有新的配置项或顺序不一致则返回 True""" has_new = False @@ -140,7 +139,7 @@ def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""): return has_new - def save_config(self, replace_config: Dict = None): + def save_config(self, replace_config: dict = None): """将配置写入文件 如果传入 replace_config,则将配置替换为 replace_config From c414039fb85f248de847151287c96cc9cfaa545e Mon Sep 17 00:00:00 2001 From: Dt8333 Date: Tue, 28 Oct 2025 11:44:41 +0800 Subject: [PATCH 07/44] chore(core.config): add missing type annotations --- astrbot/core/astrbot_config_mgr.py | 4 ++-- astrbot/core/config/astrbot_config.py | 25 ++++++++++++++++--------- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/astrbot/core/astrbot_config_mgr.py b/astrbot/core/astrbot_config_mgr.py index 0ee3f4fe6..c99a67296 100644 --- a/astrbot/core/astrbot_config_mgr.py +++ b/astrbot/core/astrbot_config_mgr.py @@ -35,7 +35,7 @@ def __init__( default_config: AstrBotConfig, ucr: UmopConfigRouter, sp: SharedPreferences, - ): + ) -> None: self.sp = sp self.ucr = ucr self.confs: dict[str, AstrBotConfig] = {} @@ -52,7 +52,7 @@ def _get_abconf_data(self) -> dict: ) return self.abconf_data - def _load_all_configs(self): + def _load_all_configs(self) -> None: """Load all configurations from the shared preferences.""" abconf_data = self._get_abconf_data() self.abconf_data = abconf_data diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py index 48f76d6e6..cd8a89196 100644 --- a/astrbot/core/config/astrbot_config.py +++ b/astrbot/core/config/astrbot_config.py @@ -22,14 +22,19 @@ class AstrBotConfig(dict): - 如果传入了 schema,将会通过 schema 解析出 default_config,此时传入的 default_config 会被忽略。 """ + # Class-level 属性注解,帮助类型检查器正确推断实例属性类型 + config_path: str + default_config: dict + schema: dict | None + first_deploy: bool | None + def __init__( self, config_path: str = ASTRBOT_CONFIG_PATH, default_config: dict = DEFAULT_CONFIG, - schema: dict = None, - ): + schema: dict | None = None, + ) -> None: super().__init__() - # 调用父类的 __setattr__ 方法,防止保存配置时将此属性写入配置文件 object.__setattr__(self, "config_path", config_path) object.__setattr__(self, "default_config", default_config) @@ -60,7 +65,7 @@ def _config_schema_to_default_config(self, schema: dict) -> dict: """将 Schema 转换成 Config""" conf = {} - def _parse_schema(schema: dict, conf: dict): + def _parse_schema(schema: dict, conf: dict) -> None: for k, v in schema.items(): if v["type"] not in DEFAULT_VALUE_MAP: raise TypeError( @@ -81,7 +86,9 @@ def _parse_schema(schema: dict, conf: dict): return conf - def check_config_integrity(self, refer_conf: dict, conf: dict, path=""): + def check_config_integrity( + self, refer_conf: dict, conf: dict, path: str = "" + ) -> bool: """检查配置完整性,如果有新的配置项或顺序不一致则返回 True""" has_new = False @@ -139,7 +146,7 @@ def check_config_integrity(self, refer_conf: dict, conf: dict, path=""): return has_new - def save_config(self, replace_config: dict = None): + def save_config(self, replace_config: dict | None = None) -> None: """将配置写入文件 如果传入 replace_config,则将配置替换为 replace_config @@ -149,20 +156,20 @@ def save_config(self, replace_config: dict = None): with open(self.config_path, "w", encoding="utf-8-sig") as f: json.dump(self, f, indent=2, ensure_ascii=False) - def __getattr__(self, item): + def __getattr__(self, item: str) -> object: try: return self[item] except KeyError: return None - def __delattr__(self, key): + def __delattr__(self, key: str) -> None: try: del self[key] self.save_config() except KeyError: raise AttributeError(f"没有找到 Key: '{key}'") - def __setattr__(self, key, value): + def __setattr__(self, key: str, value: object) -> None: self[key] = value def check_exist(self) -> bool: From 27c815dcf82888f9d8cf4976af4c209af9eddcdd Mon Sep 17 00:00:00 2001 From: Dt8333 Date: Tue, 28 Oct 2025 11:48:14 +0800 Subject: [PATCH 08/44] chore(core.convmgr): ruff rewrite --- astrbot/core/conversation_mgr.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index 8f8e2e0e9..beb61c663 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -7,7 +7,8 @@ import json from astrbot.core import sp -from typing import Dict, List, Callable, Awaitable + +from collections.abc import Callable, Awaitable from astrbot.core.db import BaseDatabase from astrbot.core.db.po import Conversation, ConversationV2 @@ -16,12 +17,12 @@ class ConversationManager: """负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。""" def __init__(self, db_helper: BaseDatabase): - self.session_conversations: Dict[str, str] = {} + self.session_conversations: dict[str, str] = {} self.db = db_helper self.save_interval = 60 # 每 60 秒保存一次 # 会话删除回调函数列表(用于级联清理,如知识库配置) - self._on_session_deleted_callbacks: List[Callable[[str], Awaitable[None]]] = [] + self._on_session_deleted_callbacks: list[Callable[[str], Awaitable[None]]] = [] def register_on_session_deleted( self, callback: Callable[[str], Awaitable[None]] @@ -182,7 +183,7 @@ async def get_conversation( async def get_conversations( self, unified_msg_origin: str | None = None, platform_id: str | None = None - ) -> List[Conversation]: + ) -> list[Conversation]: """获取对话列表 Args: From d5759697edf829ca860ecf48527e001f2f274bf5 Mon Sep 17 00:00:00 2001 From: Dt8333 Date: Tue, 28 Oct 2025 12:09:08 +0800 Subject: [PATCH 09/44] chore(core.convmgr): add missing type annotations --- astrbot/core/conversation_mgr.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index beb61c663..a4262c998 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -6,6 +6,7 @@ """ import json +from typing import Any from astrbot.core import sp from collections.abc import Callable, Awaitable @@ -16,7 +17,7 @@ class ConversationManager: """负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。""" - def __init__(self, db_helper: BaseDatabase): + def __init__(self, db_helper: BaseDatabase) -> None: self.session_conversations: dict[str, str] = {} self.db = db_helper self.save_interval = 60 # 每 60 秒保存一次 @@ -101,7 +102,9 @@ async def new_conversation( await sp.session_put(unified_msg_origin, "sel_conv_id", conv.conversation_id) return conv.conversation_id - async def switch_conversation(self, unified_msg_origin: str, conversation_id: str): + async def switch_conversation( + self, unified_msg_origin: str, conversation_id: str + ) -> None: """切换会话的对话 Args: @@ -113,7 +116,7 @@ async def switch_conversation(self, unified_msg_origin: str, conversation_id: st async def delete_conversation( self, unified_msg_origin: str, conversation_id: str | None = None - ): + ) -> None: """删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话 Args: @@ -129,7 +132,7 @@ async def delete_conversation( self.session_conversations.pop(unified_msg_origin, None) await sp.session_remove(unified_msg_origin, "sel_conv_id") - async def delete_conversations_by_user_id(self, unified_msg_origin: str): + async def delete_conversations_by_user_id(self, unified_msg_origin: str) -> None: """删除会话的所有对话 Args: @@ -207,7 +210,7 @@ async def get_filtered_conversations( page_size: int = 20, platform_ids: list[str] | None = None, search_query: str = "", - **kwargs, + **kwargs: Any, # noqa: ANN401 ) -> tuple[list[Conversation], int]: """获取过滤后的对话列表 @@ -239,7 +242,7 @@ async def update_conversation( history: list[dict] | None = None, title: str | None = None, persona_id: str | None = None, - ): + ) -> None: """更新会话的对话 Args: @@ -260,7 +263,7 @@ async def update_conversation( async def update_conversation_title( self, unified_msg_origin: str, title: str, conversation_id: str | None = None - ): + ) -> None: """更新会话的对话标题 Args: @@ -281,7 +284,7 @@ async def update_conversation_persona_id( unified_msg_origin: str, persona_id: str, conversation_id: str | None = None, - ): + ) -> None: """更新会话的对话 Persona ID Args: @@ -298,8 +301,12 @@ async def update_conversation_persona_id( ) async def get_human_readable_context( - self, unified_msg_origin, conversation_id, page=1, page_size=10 - ): + self, + unified_msg_origin: str, + conversation_id: str, + page: int = 1, + page_size: int = 10, + ) -> tuple[list[str], int]: """获取人类可读的上下文 Args: From b17f50f243dedeaf698c50eb4d7a982907316b59 Mon Sep 17 00:00:00 2001 From: Dt8333 Date: Tue, 28 Oct 2025 12:15:00 +0800 Subject: [PATCH 10/44] chore(core.db): ruff rewrite --- .../core/db/migration/shared_preferences_v3.py | 2 +- astrbot/core/db/migration/sqlite_v3.py | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/astrbot/core/db/migration/shared_preferences_v3.py b/astrbot/core/db/migration/shared_preferences_v3.py index 6a661bd3d..4e2dd916f 100644 --- a/astrbot/core/db/migration/shared_preferences_v3.py +++ b/astrbot/core/db/migration/shared_preferences_v3.py @@ -16,7 +16,7 @@ def __init__(self, path=None): def _load_preferences(self): if os.path.exists(self.path): try: - with open(self.path, "r") as f: + with open(self.path) as f: return json.load(f) except json.JSONDecodeError: os.remove(self.path) diff --git a/astrbot/core/db/migration/sqlite_v3.py b/astrbot/core/db/migration/sqlite_v3.py index ad86c51f3..af0cbb35d 100644 --- a/astrbot/core/db/migration/sqlite_v3.py +++ b/astrbot/core/db/migration/sqlite_v3.py @@ -1,7 +1,7 @@ import sqlite3 import time from astrbot.core.db.po import Platform, Stats -from typing import Tuple, List, Dict, Any +from typing import Any from dataclasses import dataclass @@ -126,7 +126,7 @@ def _get_conn(self, db_path: str) -> sqlite3.Connection: conn.text_factory = str return conn - def _exec_sql(self, sql: str, params: Tuple = None): + def _exec_sql(self, sql: str, params: tuple = None): conn = self.conn try: c = self.conn.cursor() @@ -257,7 +257,7 @@ def new_conversation(self, user_id: str, cid: str): (user_id, cid, history, updated_at, created_at), ) - def get_conversations(self, user_id: str) -> Tuple: + def get_conversations(self, user_id: str) -> tuple: try: c = self.conn.cursor() except sqlite3.ProgrammingError: @@ -320,7 +320,7 @@ def delete_conversation(self, user_id: str, cid: str): def get_all_conversations( self, page: int = 1, page_size: int = 20 - ) -> Tuple[List[Dict[str, Any]], int]: + ) -> tuple[list[dict[str, Any]], int]: """获取所有对话,支持分页,按更新时间降序排序""" try: c = self.conn.cursor() @@ -381,12 +381,12 @@ def get_filtered_conversations( self, page: int = 1, page_size: int = 20, - platforms: List[str] = None, - message_types: List[str] = None, + platforms: list[str] = None, + message_types: list[str] = None, search_query: str = None, - exclude_ids: List[str] = None, - exclude_platforms: List[str] = None, - ) -> Tuple[List[Dict[str, Any]], int]: + exclude_ids: list[str] = None, + exclude_platforms: list[str] = None, + ) -> tuple[list[dict[str, Any]], int]: """获取筛选后的对话列表""" try: c = self.conn.cursor() From 504c6d82806288b2aec8ce1a8016408d08eabe0a Mon Sep 17 00:00:00 2001 From: Dale Null Date: Tue, 28 Oct 2025 17:46:31 +0800 Subject: [PATCH 11/44] chore(core.db): add missing type annotations --- astrbot/core/db/__init__.py | 20 ++- astrbot/core/db/migration/helper.py | 2 +- astrbot/core/db/migration/migra_3_to_4.py | 10 +- astrbot/core/db/migration/migra_45_to_46.py | 2 +- .../db/migration/shared_preferences_v3.py | 14 +- astrbot/core/db/migration/sqlite_v3.py | 34 ++-- astrbot/core/db/sqlite.py | 166 ++++++++++-------- astrbot/core/db/vec_db/base.py | 7 +- .../db/vec_db/faiss_impl/document_storage.py | 21 +-- .../db/vec_db/faiss_impl/embedding_storage.py | 10 +- astrbot/core/db/vec_db/faiss_impl/vec_db.py | 13 +- 11 files changed, 167 insertions(+), 132 deletions(-) diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index 0abd3ad49..cb2269235 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -35,7 +35,7 @@ def __init__(self) -> None: self.engine, class_=AsyncSession, expire_on_commit=False ) - async def initialize(self): + async def initialize(self) -> None: """初始化数据库连接""" pass @@ -100,7 +100,7 @@ async def get_conversations( ... @abc.abstractmethod - async def get_conversation_by_id(self, cid: str) -> ConversationV2: + async def get_conversation_by_id(self, cid: str) -> ConversationV2 | None: """Get a specific conversation by its ID.""" ... @@ -118,7 +118,7 @@ async def get_filtered_conversations( page_size: int = 20, platform_ids: list[str] | None = None, search_query: str = "", - **kwargs, + **kwargs: T.Any, # noqa: ANN401 ) -> tuple[list[ConversationV2], int]: """Get conversations filtered by platform IDs and search query.""" ... @@ -145,7 +145,7 @@ async def update_conversation( title: str | None = None, persona_id: str | None = None, content: list[dict] | None = None, - ) -> None: + ) -> ConversationV2 | None: """Update a conversation's history.""" ... @@ -167,7 +167,7 @@ async def insert_platform_message_history( content: dict, sender_id: str | None = None, sender_name: str | None = None, - ) -> None: + ) -> PlatformMessageHistory | None: """Insert a new platform message history record.""" ... @@ -195,12 +195,12 @@ async def insert_attachment( path: str, type: str, mime_type: str, - ): + ) -> Attachment: """Insert a new attachment record.""" ... @abc.abstractmethod - async def get_attachment_by_id(self, attachment_id: str) -> Attachment: + async def get_attachment_by_id(self, attachment_id: str) -> Attachment | None: """Get an attachment by its ID.""" ... @@ -216,7 +216,7 @@ async def insert_persona( ... @abc.abstractmethod - async def get_persona_by_id(self, persona_id: str) -> Persona: + async def get_persona_by_id(self, persona_id: str) -> Persona | None: """Get a persona by its ID.""" ... @@ -249,7 +249,9 @@ async def insert_preference_or_update( ... @abc.abstractmethod - async def get_preference(self, scope: str, scope_id: str, key: str) -> Preference: + async def get_preference( + self, scope: str, scope_id: str, key: str + ) -> Preference | None: """Get a preference by scope ID and key.""" ... diff --git a/astrbot/core/db/migration/helper.py b/astrbot/core/db/migration/helper.py index 901cdc4ed..257d69ab6 100644 --- a/astrbot/core/db/migration/helper.py +++ b/astrbot/core/db/migration/helper.py @@ -32,7 +32,7 @@ async def do_migration_v4( db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]], astrbot_config: AstrBotConfig, -): +) -> None: """ 执行数据库迁移 迁移旧的 webchat_conversation 表到新的 conversation 表。 diff --git a/astrbot/core/db/migration/migra_3_to_4.py b/astrbot/core/db/migration/migra_3_to_4.py index 4aa5082db..21874aa12 100644 --- a/astrbot/core/db/migration/migra_3_to_4.py +++ b/astrbot/core/db/migration/migra_3_to_4.py @@ -37,7 +37,7 @@ def get_platform_type( async def migration_conversation_table( db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]] -): +) -> None: db_helper_v3 = SQLiteV3DatabaseV3( db_path=DB_PATH.replace("data_v4.db", "data_v3.db") ) @@ -91,7 +91,7 @@ async def migration_conversation_table( async def migration_platform_table( db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]] -): +) -> None: db_helper_v3 = SQLiteV3DatabaseV3( db_path=DB_PATH.replace("data_v4.db", "data_v3.db") ) @@ -166,7 +166,7 @@ async def migration_platform_table( async def migration_webchat_data( db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]] -): +) -> None: """迁移 WebChat 的历史记录到新的 PlatformMessageHistory 表中""" db_helper_v3 = SQLiteV3DatabaseV3( db_path=DB_PATH.replace("data_v4.db", "data_v3.db") @@ -219,7 +219,7 @@ async def migration_webchat_data( async def migration_persona_data( db_helper: BaseDatabase, astrbot_config: AstrBotConfig -): +) -> None: """ 迁移 Persona 数据到新的表中。 旧的 Persona 数据存储在 preference 中,新的 Persona 数据存储在 persona 表中。 @@ -261,7 +261,7 @@ async def migration_persona_data( async def migration_preferences( db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]] -): +) -> None: # 1. global scope migration keys = [ "inactivated_llm_tools", diff --git a/astrbot/core/db/migration/migra_45_to_46.py b/astrbot/core/db/migration/migra_45_to_46.py index 8a1dc5de7..c01bdfc1d 100644 --- a/astrbot/core/db/migration/migra_45_to_46.py +++ b/astrbot/core/db/migration/migra_45_to_46.py @@ -3,7 +3,7 @@ from astrbot.core.umop_config_router import UmopConfigRouter -async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter): +async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter) -> None: abconf_data = acm.abconf_data if not isinstance(abconf_data, dict): diff --git a/astrbot/core/db/migration/shared_preferences_v3.py b/astrbot/core/db/migration/shared_preferences_v3.py index 4e2dd916f..e137078ee 100644 --- a/astrbot/core/db/migration/shared_preferences_v3.py +++ b/astrbot/core/db/migration/shared_preferences_v3.py @@ -7,13 +7,13 @@ class SharedPreferences: - def __init__(self, path=None): + def __init__(self, path: str | None = None) -> None: if path is None: path = os.path.join(get_astrbot_data_path(), "shared_preferences.json") self.path = path self._data = self._load_preferences() - def _load_preferences(self): + def _load_preferences(self) -> dict: if os.path.exists(self.path): try: with open(self.path) as f: @@ -22,24 +22,24 @@ def _load_preferences(self): os.remove(self.path) return {} - def _save_preferences(self): + def _save_preferences(self) -> None: with open(self.path, "w") as f: json.dump(self._data, f, indent=4, ensure_ascii=False) f.flush() - def get(self, key, default: _VT = None) -> _VT: + def get(self, key: str, default: _VT = None) -> _VT: return self._data.get(key, default) - def put(self, key, value): + def put(self, key: str, value: object) -> None: self._data[key] = value self._save_preferences() - def remove(self, key): + def remove(self, key: str) -> None: if key in self._data: del self._data[key] self._save_preferences() - def clear(self): + def clear(self) -> None: self._data.clear() self._save_preferences() diff --git a/astrbot/core/db/migration/sqlite_v3.py b/astrbot/core/db/migration/sqlite_v3.py index af0cbb35d..4a258f4c9 100644 --- a/astrbot/core/db/migration/sqlite_v3.py +++ b/astrbot/core/db/migration/sqlite_v3.py @@ -126,7 +126,7 @@ def _get_conn(self, db_path: str) -> sqlite3.Connection: conn.text_factory = str return conn - def _exec_sql(self, sql: str, params: tuple = None): + def _exec_sql(self, sql: str, params: tuple | None = None) -> None: conn = self.conn try: c = self.conn.cursor() @@ -143,7 +143,7 @@ def _exec_sql(self, sql: str, params: tuple = None): conn.commit() - def insert_platform_metrics(self, metrics: dict): + def insert_platform_metrics(self, metrics: dict) -> None: for k, v in metrics.items(): self._exec_sql( """ @@ -152,7 +152,7 @@ def insert_platform_metrics(self, metrics: dict): (k, v, int(time.time())), ) - def insert_llm_metrics(self, metrics: dict): + def insert_llm_metrics(self, metrics: dict) -> None: for k, v in metrics.items(): self._exec_sql( """ @@ -225,7 +225,9 @@ def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats: return Stats(platform, [], []) - def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation: + def get_conversation_by_user_id( + self, user_id: str, cid: str + ) -> Conversation | None: try: c = self.conn.cursor() except sqlite3.ProgrammingError: @@ -246,7 +248,7 @@ def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation: return Conversation(*res) - def new_conversation(self, user_id: str, cid: str): + def new_conversation(self, user_id: str, cid: str) -> None: history = "[]" updated_at = int(time.time()) created_at = updated_at @@ -257,7 +259,7 @@ def new_conversation(self, user_id: str, cid: str): (user_id, cid, history, updated_at, created_at), ) - def get_conversations(self, user_id: str) -> tuple: + def get_conversations(self, user_id: str) -> list[Conversation]: try: c = self.conn.cursor() except sqlite3.ProgrammingError: @@ -284,7 +286,7 @@ def get_conversations(self, user_id: str) -> tuple: ) return conversations - def update_conversation(self, user_id: str, cid: str, history: str): + def update_conversation(self, user_id: str, cid: str, history: str) -> None: """更新对话,并且同时更新时间""" updated_at = int(time.time()) self._exec_sql( @@ -294,7 +296,7 @@ def update_conversation(self, user_id: str, cid: str, history: str): (history, updated_at, user_id, cid), ) - def update_conversation_title(self, user_id: str, cid: str, title: str): + def update_conversation_title(self, user_id: str, cid: str, title: str) -> None: self._exec_sql( """ UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ? @@ -302,7 +304,9 @@ def update_conversation_title(self, user_id: str, cid: str, title: str): (title, user_id, cid), ) - def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str): + def update_conversation_persona_id( + self, user_id: str, cid: str, persona_id: str + ) -> None: self._exec_sql( """ UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ? @@ -310,7 +314,7 @@ def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str (persona_id, user_id, cid), ) - def delete_conversation(self, user_id: str, cid: str): + def delete_conversation(self, user_id: str, cid: str) -> None: self._exec_sql( """ DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ? @@ -381,11 +385,11 @@ def get_filtered_conversations( self, page: int = 1, page_size: int = 20, - platforms: list[str] = None, - message_types: list[str] = None, - search_query: str = None, - exclude_ids: list[str] = None, - exclude_platforms: list[str] = None, + platforms: list[str] | None = None, + message_types: list[str] | None = None, + search_query: str | None = None, + exclude_ids: list[str] | None = None, + exclude_platforms: list[str] | None = None, ) -> tuple[list[dict[str, Any]], int]: """获取筛选后的对话列表""" try: diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index f9faede19..2c469c92b 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -46,10 +46,10 @@ async def initialize(self) -> None: async def insert_platform_stats( self, - platform_id, - platform_type, - count=1, - timestamp=None, + platform_id: str, + platform_type: str, + count: int = 1, + timestamp: datetime | None = None, ) -> None: """Insert a new platform statistic record.""" async with self.get_db() as session: @@ -108,7 +108,9 @@ async def get_platform_stats(self, offset_sec: int = 86400) -> T.List[PlatformSt # Conversation Management # ==== - async def get_conversations(self, user_id=None, platform_id=None): + async def get_conversations( + self, user_id: str | None = None, platform_id: str | None = None + ) -> list[ConversationV2]: async with self.get_db() as session: session: AsyncSession query = select(ConversationV2) @@ -121,16 +123,18 @@ async def get_conversations(self, user_id=None, platform_id=None): query = query.order_by(desc(ConversationV2.created_at)) result = await session.execute(query) - return result.scalars().all() + return result.scalars().all() # type: ignore - async def get_conversation_by_id(self, cid): + async def get_conversation_by_id(self, cid: str) -> ConversationV2 | None: async with self.get_db() as session: session: AsyncSession query = select(ConversationV2).where(ConversationV2.conversation_id == cid) result = await session.execute(query) return result.scalar_one_or_none() - async def get_all_conversations(self, page=1, page_size=20): + async def get_all_conversations( + self, page: int = 1, page_size: int = 20 + ) -> list[ConversationV2]: async with self.get_db() as session: session: AsyncSession offset = (page - 1) * page_size @@ -140,16 +144,16 @@ async def get_all_conversations(self, page=1, page_size=20): .offset(offset) .limit(page_size) ) - return result.scalars().all() + return result.scalars().all() # type:ignore async def get_filtered_conversations( self, - page=1, - page_size=20, - platform_ids=None, - search_query="", - **kwargs, - ): + page: int = 1, + page_size: int = 20, + platform_ids: list[str] | None = None, + search_query: str = "", + **kwargs: T.Any, # noqa:ANN401 + ) -> tuple[list[ConversationV2], int]: async with self.get_db() as session: session: AsyncSession # Build the base query with filters @@ -194,19 +198,19 @@ async def get_filtered_conversations( result = await session.execute(result_query) conversations = result.scalars().all() - return conversations, total + return conversations, total # type:ignore async def create_conversation( self, - user_id, - platform_id, - content=None, - title=None, - persona_id=None, - cid=None, - created_at=None, - updated_at=None, - ): + user_id: str, + platform_id: str, + content: list[dict] | None = None, + title: str | None = None, + persona_id: str | None = None, + cid: str | None = None, + created_at: datetime | None = None, + updated_at: datetime | None = None, + ) -> ConversationV2: kwargs = {} if cid: kwargs["conversation_id"] = cid @@ -228,7 +232,13 @@ async def create_conversation( session.add(new_conversation) return new_conversation - async def update_conversation(self, cid, title=None, persona_id=None, content=None): + async def update_conversation( + self, + cid: str, + title: str | None = None, + persona_id: str | None = None, + content: list[dict] | None = None, + ) -> ConversationV2 | None: async with self.get_db() as session: session: AsyncSession async with session.begin(): @@ -248,7 +258,7 @@ async def update_conversation(self, cid, title=None, persona_id=None, content=No await session.execute(query) return await self.get_conversation_by_id(cid) - async def delete_conversation(self, cid): + async def delete_conversation(self, cid: str) -> None: async with self.get_db() as session: session: AsyncSession async with session.begin(): @@ -268,10 +278,10 @@ async def delete_conversations_by_user_id(self, user_id: str) -> None: async def get_session_conversations( self, - page=1, - page_size=20, - search_query=None, - platform=None, + page: int = 1, + page_size: int = 20, + search_query: str | None = None, + platform: str | None = None, ) -> tuple[list[dict], int]: """Get paginated session conversations with joined conversation and persona details.""" async with self.get_db() as session: @@ -375,12 +385,12 @@ async def get_session_conversations( async def insert_platform_message_history( self, - platform_id, - user_id, - content, - sender_id=None, - sender_name=None, - ): + platform_id: str, + user_id: str, + content: dict, + sender_id: str | None = None, + sender_name: str | None = None, + ) -> PlatformMessageHistory: """Insert a new platform message history record.""" async with self.get_db() as session: session: AsyncSession @@ -396,8 +406,8 @@ async def insert_platform_message_history( return new_history async def delete_platform_message_offset( - self, platform_id, user_id, offset_sec=86400 - ): + self, platform_id: str, user_id: str, offset_sec: int = 86400 + ) -> None: """Delete platform message history records older than the specified offset.""" async with self.get_db() as session: session: AsyncSession @@ -413,8 +423,8 @@ async def delete_platform_message_offset( ) async def get_platform_message_history( - self, platform_id, user_id, page=1, page_size=20 - ): + self, platform_id: str, user_id: str, page: int = 1, page_size: int = 20 + ) -> list[PlatformMessageHistory]: """Get platform message history records.""" async with self.get_db() as session: session: AsyncSession @@ -428,9 +438,11 @@ async def get_platform_message_history( .order_by(desc(PlatformMessageHistory.created_at)) ) result = await session.execute(query.offset(offset).limit(page_size)) - return result.scalars().all() + return result.scalars().all() # type:ignore - async def insert_attachment(self, path, type, mime_type): + async def insert_attachment( + self, path: str, type: str, mime_type: str + ) -> Attachment: """Insert a new attachment record.""" async with self.get_db() as session: session: AsyncSession @@ -443,7 +455,7 @@ async def insert_attachment(self, path, type, mime_type): session.add(new_attachment) return new_attachment - async def get_attachment_by_id(self, attachment_id): + async def get_attachment_by_id(self, attachment_id: str) -> Attachment | None: """Get an attachment by its ID.""" async with self.get_db() as session: session: AsyncSession @@ -452,8 +464,12 @@ async def get_attachment_by_id(self, attachment_id): return result.scalar_one_or_none() async def insert_persona( - self, persona_id, system_prompt, begin_dialogs=None, tools=None - ): + self, + persona_id: str, + system_prompt: str, + begin_dialogs: list[str] | None = None, + tools: list[str] | None = None, + ) -> Persona: """Insert a new persona record.""" async with self.get_db() as session: session: AsyncSession @@ -467,7 +483,7 @@ async def insert_persona( session.add(new_persona) return new_persona - async def get_persona_by_id(self, persona_id): + async def get_persona_by_id(self, persona_id: str) -> Persona | None: """Get a persona by its ID.""" async with self.get_db() as session: session: AsyncSession @@ -475,17 +491,21 @@ async def get_persona_by_id(self, persona_id): result = await session.execute(query) return result.scalar_one_or_none() - async def get_personas(self): + async def get_personas(self) -> list[Persona]: """Get all personas for a specific bot.""" async with self.get_db() as session: session: AsyncSession query = select(Persona) result = await session.execute(query) - return result.scalars().all() + return result.scalars().all() # type:ignore async def update_persona( - self, persona_id, system_prompt=None, begin_dialogs=None, tools=NOT_GIVEN - ): + self, + persona_id: str, + system_prompt: str | None = None, + begin_dialogs: list[str] | None = None, + tools: list[str] | None = None, + ) -> Persona | None: """Update a persona's system prompt or begin dialogs.""" async with self.get_db() as session: session: AsyncSession @@ -504,7 +524,7 @@ async def update_persona( await session.execute(query) return await self.get_persona_by_id(persona_id) - async def delete_persona(self, persona_id): + async def delete_persona(self, persona_id: str) -> None: """Delete a persona by its ID.""" async with self.get_db() as session: session: AsyncSession @@ -513,7 +533,9 @@ async def delete_persona(self, persona_id): delete(Persona).where(col(Persona.persona_id) == persona_id) ) - async def insert_preference_or_update(self, scope, scope_id, key, value): + async def insert_preference_or_update( + self, scope: str, scope_id: str, key: str, value: dict + ) -> Preference: """Insert a new preference record or update if it exists.""" async with self.get_db() as session: session: AsyncSession @@ -534,7 +556,9 @@ async def insert_preference_or_update(self, scope, scope_id, key, value): session.add(new_preference) return existing_preference or new_preference - async def get_preference(self, scope, scope_id, key): + async def get_preference( + self, scope: str, scope_id: str, key: str + ) -> Preference | None: """Get a preference by key.""" async with self.get_db() as session: session: AsyncSession @@ -546,7 +570,9 @@ async def get_preference(self, scope, scope_id, key): result = await session.execute(query) return result.scalar_one_or_none() - async def get_preferences(self, scope, scope_id=None, key=None): + async def get_preferences( + self, scope: str, scope_id: str | None = None, key: str | None = None + ) -> list[Preference]: """Get all preferences for a specific scope ID or key.""" async with self.get_db() as session: session: AsyncSession @@ -556,9 +582,9 @@ async def get_preferences(self, scope, scope_id=None, key=None): if key is not None: query = query.where(Preference.key == key) result = await session.execute(query) - return result.scalars().all() + return result.scalars().all() # type:ignore - async def remove_preference(self, scope, scope_id, key): + async def remove_preference(self, scope: str, scope_id: str, key: str) -> None: """Remove a preference by scope ID and key.""" async with self.get_db() as session: session: AsyncSession @@ -572,7 +598,7 @@ async def remove_preference(self, scope, scope_id, key): ) await session.commit() - async def clear_preferences(self, scope, scope_id): + async def clear_preferences(self, scope: str, scope_id: str) -> None: """Clear all preferences for a specific scope ID.""" async with self.get_db() as session: session: AsyncSession @@ -589,10 +615,10 @@ async def clear_preferences(self, scope, scope_id): # Deprecated Methods # ==== - def get_base_stats(self, offset_sec=86400): + def get_base_stats(self, offset_sec: int = 86400) -> DeprecatedStats: """Get base statistics within the specified offset in seconds.""" - async def _inner(): + async def _inner() -> DeprecatedStats: async with self.get_db() as session: session: AsyncSession now = datetime.now() @@ -614,19 +640,19 @@ async def _inner(): result = None - def runner(): + def runner() -> None: nonlocal result result = asyncio.run(_inner()) t = threading.Thread(target=runner) t.start() t.join() - return result + return result # type:ignore - def get_total_message_count(self): + def get_total_message_count(self) -> int: """Get the total message count from platform statistics.""" - async def _inner(): + async def _inner() -> int: async with self.get_db() as session: session: AsyncSession result = await session.execute( @@ -637,18 +663,18 @@ async def _inner(): result = None - def runner(): + def runner() -> None: nonlocal result result = asyncio.run(_inner()) t = threading.Thread(target=runner) t.start() t.join() - return result + return result # type:ignore - def get_grouped_base_stats(self, offset_sec=86400): + def get_grouped_base_stats(self, offset_sec: int = 86400) -> DeprecatedStats: # group by platform_id - async def _inner(): + async def _inner() -> DeprecatedStats: async with self.get_db() as session: session: AsyncSession now = datetime.now() @@ -672,11 +698,11 @@ async def _inner(): result = None - def runner(): + def runner() -> None: nonlocal result result = asyncio.run(_inner()) t = threading.Thread(target=runner) t.start() t.join() - return result + return result # type:ignore diff --git a/astrbot/core/db/vec_db/base.py b/astrbot/core/db/vec_db/base.py index 27fc9f3fb..46242881d 100644 --- a/astrbot/core/db/vec_db/base.py +++ b/astrbot/core/db/vec_db/base.py @@ -1,5 +1,6 @@ import abc from dataclasses import dataclass +from types import FunctionType @dataclass @@ -9,7 +10,7 @@ class Result: class BaseVecDB: - async def initialize(self): + async def initialize(self) -> None: """ 初始化向量数据库 """ @@ -33,7 +34,7 @@ async def insert_batch( batch_size: int = 32, tasks_limit: int = 3, max_retries: int = 3, - progress_callback=None, + progress_callback: FunctionType | None = None, ) -> int: """ 批量插入文本和其对应向量,自动生成 ID 并保持一致性。 @@ -74,4 +75,4 @@ async def delete(self, doc_id: str) -> bool: ... @abc.abstractmethod - async def close(self): ... + async def close(self) -> None: ... diff --git a/astrbot/core/db/vec_db/faiss_impl/document_storage.py b/astrbot/core/db/vec_db/faiss_impl/document_storage.py index 265c0cc43..629c895c9 100644 --- a/astrbot/core/db/vec_db/faiss_impl/document_storage.py +++ b/astrbot/core/db/vec_db/faiss_impl/document_storage.py @@ -2,6 +2,7 @@ import json from datetime import datetime from contextlib import asynccontextmanager +from collections.abc import AsyncGenerator from sqlalchemy import Text, Column from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine @@ -30,7 +31,7 @@ class Document(BaseDocModel, table=True): class DocumentStorage: - def __init__(self, db_path: str): + def __init__(self, db_path: str) -> None: self.db_path = db_path self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}" self.engine: AsyncEngine | None = None @@ -39,7 +40,7 @@ def __init__(self, db_path: str): os.path.dirname(__file__), "sqlite_init.sql" ) - async def initialize(self): + async def initialize(self) -> None: """Initialize the SQLite database and create the documents table if it doesn't exist.""" await self.connect() async with self.engine.begin() as conn: # type: ignore @@ -76,7 +77,7 @@ async def initialize(self): await conn.commit() - async def connect(self): + async def connect(self) -> None: """Connect to the SQLite database.""" if self.engine is None: self.engine = create_async_engine( @@ -91,7 +92,7 @@ async def connect(self): ) # type: ignore @asynccontextmanager - async def get_session(self): + async def get_session(self) -> AsyncGenerator[AsyncSession, None]: """Context manager for database sessions.""" async with self.async_session_maker() as session: # type: ignore yield session @@ -203,7 +204,7 @@ async def insert_documents_batch( await session.flush() # Flush to get all IDs return [doc.id for doc in documents] # type: ignore - async def delete_document_by_doc_id(self, doc_id: str): + async def delete_document_by_doc_id(self, doc_id: str) -> None: """Delete a document by its doc_id. Args: @@ -220,7 +221,7 @@ async def delete_document_by_doc_id(self, doc_id: str): if document: await session.delete(document) - async def get_document_by_doc_id(self, doc_id: str): + async def get_document_by_doc_id(self, doc_id: str) -> dict | None: """Retrieve a document by its doc_id. Args: @@ -240,7 +241,7 @@ async def get_document_by_doc_id(self, doc_id: str): return self._document_to_dict(document) return None - async def update_document_by_doc_id(self, doc_id: str, new_text: str): + async def update_document_by_doc_id(self, doc_id: str, new_text: str) -> None: """Update a document by its doc_id. Args: @@ -260,7 +261,7 @@ async def update_document_by_doc_id(self, doc_id: str, new_text: str): document.updated_at = datetime.now() session.add(document) - async def delete_documents(self, metadata_filters: dict): + async def delete_documents(self, metadata_filters: dict) -> None: """Delete documents by their metadata filters. Args: @@ -351,7 +352,7 @@ def _document_to_dict(self, document: Document) -> dict: else document.updated_at, } - async def tuple_to_dict(self, row): + async def tuple_to_dict(self, row: tuple) -> dict: """Convert a tuple to a dictionary. Args: @@ -371,7 +372,7 @@ async def tuple_to_dict(self, row): "updated_at": row[5], } - async def close(self): + async def close(self) -> None: """Close the connection to the SQLite database.""" if self.engine: await self.engine.dispose() diff --git a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py index 2c0cc8dfe..d2e7a0f06 100644 --- a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +++ b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py @@ -9,7 +9,7 @@ class EmbeddingStorage: - def __init__(self, dimension: int, path: str | None = None): + def __init__(self, dimension: int, path: str | None = None) -> None: self.dimension = dimension self.path = path self.index = None @@ -19,7 +19,7 @@ def __init__(self, dimension: int, path: str | None = None): base_index = faiss.IndexFlatL2(dimension) self.index = faiss.IndexIDMap(base_index) - async def insert(self, vector: np.ndarray, id: int): + async def insert(self, vector: np.ndarray, id: int) -> None: """插入向量 Args: @@ -36,7 +36,7 @@ async def insert(self, vector: np.ndarray, id: int): self.index.add_with_ids(vector.reshape(1, -1), np.array([id])) await self.save_index() - async def insert_batch(self, vectors: np.ndarray, ids: list[int]): + async def insert_batch(self, vectors: np.ndarray, ids: list[int]) -> None: """批量插入向量 Args: @@ -67,7 +67,7 @@ async def search(self, vector: np.ndarray, k: int) -> tuple: distances, indices = self.index.search(vector, k) return distances, indices - async def delete(self, ids: list[int]): + async def delete(self, ids: list[int]) -> None: """删除向量 Args: @@ -78,7 +78,7 @@ async def delete(self, ids: list[int]): self.index.remove_ids(id_array) await self.save_index() - async def save_index(self): + async def save_index(self) -> None: """保存索引 Args: diff --git a/astrbot/core/db/vec_db/faiss_impl/vec_db.py b/astrbot/core/db/vec_db/faiss_impl/vec_db.py index 8a21538ec..3a08ed3f7 100644 --- a/astrbot/core/db/vec_db/faiss_impl/vec_db.py +++ b/astrbot/core/db/vec_db/faiss_impl/vec_db.py @@ -1,3 +1,4 @@ +from types import FunctionType import uuid import time import numpy as np @@ -20,7 +21,7 @@ def __init__( index_store_path: str, embedding_provider: EmbeddingProvider, rerank_provider: RerankProvider | None = None, - ): + ) -> None: self.doc_store_path = doc_store_path self.index_store_path = index_store_path self.embedding_provider = embedding_provider @@ -31,7 +32,7 @@ def __init__( self.embedding_provider = embedding_provider self.rerank_provider = rerank_provider - async def initialize(self): + async def initialize(self) -> None: await self.document_storage.initialize() async def insert( @@ -61,7 +62,7 @@ async def insert_batch( batch_size: int = 32, tasks_limit: int = 3, max_retries: int = 3, - progress_callback=None, + progress_callback: FunctionType | None = None, ) -> list[int]: """ 批量插入文本和其对应向量,自动生成 ID 并保持一致性。 @@ -158,7 +159,7 @@ async def retrieve( return top_k_results - async def delete(self, doc_id: str): + async def delete(self, doc_id: str) -> None: """ 删除一条文档块(chunk) """ @@ -172,7 +173,7 @@ async def delete(self, doc_id: str): await self.document_storage.delete_document_by_doc_id(doc_id) await self.embedding_storage.delete([int_id]) - async def close(self): + async def close(self) -> None: await self.document_storage.close() async def count_documents(self, metadata_filter: dict | None = None) -> int: @@ -187,7 +188,7 @@ async def count_documents(self, metadata_filter: dict | None = None) -> int: ) return count - async def delete_documents(self, metadata_filters: dict): + async def delete_documents(self, metadata_filters: dict) -> None: """ 根据元数据过滤器删除文档 """ From 4a54442dc02d3b165cadf81fd03adecae823d1a2 Mon Sep 17 00:00:00 2001 From: Dale Null Date: Tue, 28 Oct 2025 17:32:35 +0800 Subject: [PATCH 12/44] chore(core.kb): ruff rewrite --- astrbot/core/knowledge_base/retrieval/manager.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/astrbot/core/knowledge_base/retrieval/manager.py b/astrbot/core/knowledge_base/retrieval/manager.py index 278e4da20..05a7719c6 100644 --- a/astrbot/core/knowledge_base/retrieval/manager.py +++ b/astrbot/core/knowledge_base/retrieval/manager.py @@ -6,7 +6,6 @@ import time from dataclasses import dataclass -from typing import List from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion @@ -61,11 +60,11 @@ def __init__( async def retrieve( self, query: str, - kb_ids: List[str], + kb_ids: list[str], kb_id_helper_map: dict[str, KBHelper], top_k_fusion: int = 20, top_m_final: int = 5, - ) -> List[RetrievalResult]: + ) -> list[RetrievalResult]: """混合检索 流程: @@ -188,7 +187,7 @@ async def retrieve( async def _dense_retrieve( self, query: str, - kb_ids: List[str], + kb_ids: list[str], kb_options: dict, ): """稠密检索 (向量相似度) @@ -233,10 +232,10 @@ async def _dense_retrieve( async def _rerank( self, query: str, - results: List[RetrievalResult], + results: list[RetrievalResult], top_k: int, rerank_provider: RerankProvider, - ) -> List[RetrievalResult]: + ) -> list[RetrievalResult]: """Rerank 重排序 Args: From 556cf555ba66f92fb4f0f707f23da9f41633e240 Mon Sep 17 00:00:00 2001 From: Dale Null Date: Tue, 28 Oct 2025 18:10:55 +0800 Subject: [PATCH 13/44] chore(core.kb): add missing type annotations --- astrbot/core/knowledge_base/chunking/base.py | 3 ++- .../knowledge_base/chunking/fixed_size.py | 5 +++-- .../core/knowledge_base/chunking/recursive.py | 5 +++-- astrbot/core/knowledge_base/kb_db_sqlite.py | 5 +++-- astrbot/core/knowledge_base/kb_helper.py | 19 ++++++++++--------- astrbot/core/knowledge_base/kb_mgr.py | 10 +++++----- .../core/knowledge_base/retrieval/manager.py | 4 ++-- .../knowledge_base/retrieval/rank_fusion.py | 2 +- .../retrieval/sparse_retriever.py | 2 +- 9 files changed, 30 insertions(+), 25 deletions(-) diff --git a/astrbot/core/knowledge_base/chunking/base.py b/astrbot/core/knowledge_base/chunking/base.py index 5aaf84ba1..3cf1796dc 100644 --- a/astrbot/core/knowledge_base/chunking/base.py +++ b/astrbot/core/knowledge_base/chunking/base.py @@ -4,6 +4,7 @@ """ from abc import ABC, abstractmethod +from typing import Any class BaseChunker(ABC): @@ -13,7 +14,7 @@ class BaseChunker(ABC): """ @abstractmethod - async def chunk(self, text: str, **kwargs) -> list[str]: + async def chunk(self, text: str, **kwargs: Any) -> list[str]: # noqa:ANN401 """将文本分块 Args: diff --git a/astrbot/core/knowledge_base/chunking/fixed_size.py b/astrbot/core/knowledge_base/chunking/fixed_size.py index c9b35d7d8..f36fb86e4 100644 --- a/astrbot/core/knowledge_base/chunking/fixed_size.py +++ b/astrbot/core/knowledge_base/chunking/fixed_size.py @@ -3,6 +3,7 @@ 按照固定的字符数将文本分块,支持重叠区域。 """ +from typing import Any from .base import BaseChunker @@ -12,7 +13,7 @@ class FixedSizeChunker(BaseChunker): 按照固定的字符数分块,并支持块之间的重叠。 """ - def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50): + def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50) -> None: """初始化分块器 Args: @@ -22,7 +23,7 @@ def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50): self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap - async def chunk(self, text: str, **kwargs) -> list[str]: + async def chunk(self, text: str, **kwargs: Any) -> list[str]: # noqa:ANN401 """固定大小分块 Args: diff --git a/astrbot/core/knowledge_base/chunking/recursive.py b/astrbot/core/knowledge_base/chunking/recursive.py index 21b76cba5..11b31b165 100644 --- a/astrbot/core/knowledge_base/chunking/recursive.py +++ b/astrbot/core/knowledge_base/chunking/recursive.py @@ -1,4 +1,5 @@ from collections.abc import Callable +from typing import Any from .base import BaseChunker @@ -10,7 +11,7 @@ def __init__( length_function: Callable[[str], int] = len, is_separator_regex: bool = False, separators: list[str] | None = None, - ): + ) -> None: """ 初始化递归字符文本分割器 @@ -38,7 +39,7 @@ def __init__( "", # 字符 ] - async def chunk(self, text: str, **kwargs) -> list[str]: + async def chunk(self, text: str, **kwargs: Any) -> list[str]: # noqa:ANN401 """ 递归地将文本分割成块 diff --git a/astrbot/core/knowledge_base/kb_db_sqlite.py b/astrbot/core/knowledge_base/kb_db_sqlite.py index 827d621d3..cc8dd6d0c 100644 --- a/astrbot/core/knowledge_base/kb_db_sqlite.py +++ b/astrbot/core/knowledge_base/kb_db_sqlite.py @@ -1,5 +1,6 @@ from contextlib import asynccontextmanager from pathlib import Path +from collections.abc import AsyncGenerator from sqlmodel import col, desc from sqlalchemy import text, func, select, update, delete @@ -45,7 +46,7 @@ def __init__(self, db_path: str = "data/knowledge_base/kb.db") -> None: ) @asynccontextmanager - async def get_db(self): + async def get_db(self) -> AsyncGenerator[AsyncSession, None]: """获取数据库会话 用法: @@ -249,7 +250,7 @@ async def get_document_with_metadata(self, doc_id: str) -> dict | None: "knowledge_base": row[1], } - async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB): + async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB) -> None: """删除单个文档及其相关数据""" # 在知识库表中删除 async with self.get_db() as session: diff --git a/astrbot/core/knowledge_base/kb_helper.py b/astrbot/core/knowledge_base/kb_helper.py index 09b9c9fc8..793e41a8d 100644 --- a/astrbot/core/knowledge_base/kb_helper.py +++ b/astrbot/core/knowledge_base/kb_helper.py @@ -1,3 +1,4 @@ +from types import FunctionType import uuid import aiofiles import json @@ -24,7 +25,7 @@ def __init__( provider_manager: ProviderManager, kb_root_dir: str, chunker: BaseChunker, - ): + ) -> None: self.kb_db = kb_db self.kb = kb self.prov_mgr = provider_manager @@ -38,7 +39,7 @@ def __init__( self.kb_medias_dir.mkdir(parents=True, exist_ok=True) self.kb_files_dir.mkdir(parents=True, exist_ok=True) - async def initialize(self): + async def initialize(self) -> None: await self._ensure_vec_db() async def get_ep(self) -> EmbeddingProvider: @@ -82,7 +83,7 @@ async def _ensure_vec_db(self) -> FaissVecDB: self.vec_db = vec_db return vec_db - async def delete_vec_db(self): + async def delete_vec_db(self) -> None: """删除知识库的向量数据库和所有相关文件""" import shutil @@ -90,7 +91,7 @@ async def delete_vec_db(self): if self.kb_dir.exists(): shutil.rmtree(self.kb_dir) - async def terminate(self): + async def terminate(self) -> None: if self.vec_db: await self.vec_db.close() @@ -104,7 +105,7 @@ async def upload_document( batch_size: int = 32, tasks_limit: int = 3, max_retries: int = 3, - progress_callback=None, + progress_callback: FunctionType | None = None, ) -> KBDocument: """上传并处理文档(带原子性保证和失败清理) @@ -180,7 +181,7 @@ async def upload_document( await progress_callback("chunking", 100, 100) # 阶段3: 生成向量(带进度回调) - async def embedding_progress_callback(current, total): + async def embedding_progress_callback(current: int, total: int) -> None: if progress_callback: await progress_callback("embedding", current, total) @@ -245,7 +246,7 @@ async def get_document(self, doc_id: str) -> KBDocument | None: doc = await self.kb_db.get_document_by_id(doc_id) return doc - async def delete_document(self, doc_id: str): + async def delete_document(self, doc_id: str) -> None: """删除单个文档及其相关数据""" await self.kb_db.delete_document_by_id( doc_id=doc_id, @@ -257,7 +258,7 @@ async def delete_document(self, doc_id: str): ) await self.refresh_kb() - async def delete_chunk(self, chunk_id: str, doc_id: str): + async def delete_chunk(self, chunk_id: str, doc_id: str) -> None: """删除单个文本块及其相关数据""" vec_db: FaissVecDB = self.vec_db # type: ignore await vec_db.delete(chunk_id) @@ -268,7 +269,7 @@ async def delete_chunk(self, chunk_id: str, doc_id: str): await self.refresh_kb() await self.refresh_document(doc_id) - async def refresh_kb(self): + async def refresh_kb(self) -> None: if self.kb: kb = await self.kb_db.get_kb_by_id(self.kb.kb_id) if kb: diff --git a/astrbot/core/knowledge_base/kb_mgr.py b/astrbot/core/knowledge_base/kb_mgr.py index c1c63d08a..1f0a8b68d 100644 --- a/astrbot/core/knowledge_base/kb_mgr.py +++ b/astrbot/core/knowledge_base/kb_mgr.py @@ -28,14 +28,14 @@ class KnowledgeBaseManager: def __init__( self, provider_manager: ProviderManager, - ): + ) -> None: Path(DB_PATH).parent.mkdir(parents=True, exist_ok=True) self.provider_manager = provider_manager self._session_deleted_callback_registered = False self.kb_insts: dict[str, KBHelper] = {} - async def initialize(self): + async def initialize(self) -> None: """初始化知识库模块""" try: logger.info("正在初始化知识库模块...") @@ -60,13 +60,13 @@ async def initialize(self): logger.error(f"知识库模块初始化失败: {e}") logger.error(traceback.format_exc()) - async def _init_kb_database(self): + async def _init_kb_database(self) -> None: self.kb_db = KBSQLiteDatabase(DB_PATH.as_posix()) await self.kb_db.initialize() await self.kb_db.migrate_to_v1() logger.info(f"KnowledgeBase database initialized: {DB_PATH}") - async def load_kbs(self): + async def load_kbs(self) -> None: """加载所有知识库实例""" kb_records = await self.kb_db.list_kbs() for record in kb_records: @@ -269,7 +269,7 @@ def _format_context(self, results: list[RetrievalResult]) -> str: return "\n".join(lines) - async def terminate(self): + async def terminate(self) -> None: """终止所有知识库实例,关闭数据库连接""" for kb_id, kb_helper in self.kb_insts.items(): try: diff --git a/astrbot/core/knowledge_base/retrieval/manager.py b/astrbot/core/knowledge_base/retrieval/manager.py index 05a7719c6..8d9447f42 100644 --- a/astrbot/core/knowledge_base/retrieval/manager.py +++ b/astrbot/core/knowledge_base/retrieval/manager.py @@ -44,7 +44,7 @@ def __init__( sparse_retriever: SparseRetriever, rank_fusion: RankFusion, kb_db: KBSQLiteDatabase, - ): + ) -> None: """初始化检索管理器 Args: @@ -189,7 +189,7 @@ async def _dense_retrieve( query: str, kb_ids: list[str], kb_options: dict, - ): + ) -> list[Result]: """稠密检索 (向量相似度) 为每个知识库使用独立的向量数据库进行检索,然后合并结果。 diff --git a/astrbot/core/knowledge_base/retrieval/rank_fusion.py b/astrbot/core/knowledge_base/retrieval/rank_fusion.py index 3ceba4ff8..2e1b90d62 100644 --- a/astrbot/core/knowledge_base/retrieval/rank_fusion.py +++ b/astrbot/core/knowledge_base/retrieval/rank_fusion.py @@ -31,7 +31,7 @@ class RankFusion: - 使用 Reciprocal Rank Fusion (RRF) 算法 """ - def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60): + def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60) -> None: """初始化结果融合器 Args: diff --git a/astrbot/core/knowledge_base/retrieval/sparse_retriever.py b/astrbot/core/knowledge_base/retrieval/sparse_retriever.py index 315930b3e..03735348a 100644 --- a/astrbot/core/knowledge_base/retrieval/sparse_retriever.py +++ b/astrbot/core/knowledge_base/retrieval/sparse_retriever.py @@ -32,7 +32,7 @@ class SparseRetriever: - 使用 BM25 算法计算相关度 """ - def __init__(self, kb_db: KBSQLiteDatabase): + def __init__(self, kb_db: KBSQLiteDatabase) -> None: """初始化稀疏检索器 Args: From a6bd81433824b5604ed874a13345fa4927e05444 Mon Sep 17 00:00:00 2001 From: Dale Null Date: Wed, 29 Oct 2025 08:28:41 +0800 Subject: [PATCH 14/44] chore(core.db): ruff rewrite missing file --- astrbot/core/db/po.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index 24a05f947..fa6fe8fee 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -9,7 +9,7 @@ UniqueConstraint, Field, ) -from typing import Optional, TypedDict +from typing import TypedDict class PlatformStat(SQLModel, table=True): @@ -50,14 +50,14 @@ class ConversationV2(SQLModel, table=True): ) platform_id: str = Field(nullable=False) user_id: str = Field(nullable=False) - content: Optional[list] = Field(default=None, sa_type=JSON) + content: list | None = Field(default=None, sa_type=JSON) created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) updated_at: datetime = Field( default_factory=lambda: datetime.now(timezone.utc), sa_column_kwargs={"onupdate": datetime.now(timezone.utc)}, ) - title: Optional[str] = Field(default=None, max_length=255) - persona_id: Optional[str] = Field(default=None) + title: str | None = Field(default=None, max_length=255) + persona_id: str | None = Field(default=None) __table_args__ = ( UniqueConstraint( @@ -80,9 +80,9 @@ class Persona(SQLModel, table=True): ) persona_id: str = Field(max_length=255, nullable=False) system_prompt: str = Field(sa_type=Text, nullable=False) - begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON) + begin_dialogs: list | None = Field(default=None, sa_type=JSON) """a list of strings, each representing a dialog to start with""" - tools: Optional[list] = Field(default=None, sa_type=JSON) + tools: list | None = Field(default=None, sa_type=JSON) """None means use ALL tools for default, empty list means no tools, otherwise a list of tool names.""" created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) updated_at: datetime = Field( @@ -142,10 +142,8 @@ class PlatformMessageHistory(SQLModel, table=True): ) platform_id: str = Field(nullable=False) user_id: str = Field(nullable=False) # An id of group, user in platform - sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform - sender_name: Optional[str] = Field( - default=None - ) # Name of the sender in the platform + sender_id: str | None = Field(default=None) # ID of the sender in the platform + sender_name: str | None = Field(default=None) # Name of the sender in the platform content: dict = Field(sa_type=JSON, nullable=False) # a message chain list created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) updated_at: datetime = Field( From f0b80204c1c414baa9ef9d062347a7ee31cd5b4c Mon Sep 17 00:00:00 2001 From: Dale Null Date: Wed, 29 Oct 2025 10:11:01 +0800 Subject: [PATCH 15/44] chore(core.message): ruff rewrite --- astrbot/core/message/components.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index d9ec4b41b..db1e8fa8f 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -81,7 +81,7 @@ def toString(self): k = "type" if isinstance(v, bool): v = 1 if v else 0 - output += ",%s=%s" % ( + output += ",{}={}".format( k, str(v) .replace("&", "&") @@ -757,11 +757,9 @@ def file(self) -> str: loop = asyncio.get_event_loop() if loop.is_running(): logger.warning( - ( - "不可以在异步上下文中同步等待下载! " - "这个警告通常发生于某些逻辑试图通过 .file 获取文件消息段的文件内容。" - "请使用 await get_file() 代替直接获取 .file 字段" - ) + "不可以在异步上下文中同步等待下载! " + "这个警告通常发生于某些逻辑试图通过 .file 获取文件消息段的文件内容。" + "请使用 await get_file() 代替直接获取 .file 字段" ) return "" else: From ef631eea83d94b06a2e82a4a9987f534ab36567d Mon Sep 17 00:00:00 2001 From: Dale Null Date: Wed, 29 Oct 2025 10:13:17 +0800 Subject: [PATCH 16/44] chore(core.message): add missing type annotations --- astrbot/core/message/components.py | 110 +++++++++---------- astrbot/core/message/message_event_result.py | 36 +++--- 2 files changed, 74 insertions(+), 72 deletions(-) diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index db1e8fa8f..386e944a7 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -72,7 +72,7 @@ class ComponentType(str, Enum): class BaseMessageComponent(BaseModel): type: ComponentType - def toString(self): + def toString(self) -> str: output = f"[CQ:{self.type.lower()}" for k, v in self.__dict__.items(): if k == "type" or v is None: @@ -92,7 +92,7 @@ def toString(self): output += "]" return output - def toDict(self): + def toDict(self) -> dict: data = {} for k, v in self.__dict__.items(): if k == "type" or v is None: @@ -112,20 +112,20 @@ class Plain(BaseMessageComponent): text: str convert: T.Optional[bool] = True # 若为 False 则直接发送未转换 CQ 码的消息 - def __init__(self, text: str, convert: bool = True, **_): + def __init__(self, text: str, convert: bool = True, **_: T.Any) -> None: # noqa: ANN401 super().__init__(text=text, convert=convert, **_) - def toString(self): # 没有 [CQ:plain] 这种东西,所以直接导出纯文本 + def toString(self) -> str: # 没有 [CQ:plain] 这种东西,所以直接导出纯文本 if not self.convert: return self.text return ( self.text.replace("&", "&").replace("[", "[").replace("]", "]") ) - def toDict(self): + def toDict(self) -> dict: return {"type": "text", "data": {"text": self.text.strip()}} - async def to_dict(self): + async def to_dict(self) -> dict: return {"type": "text", "data": {"text": self.text}} @@ -133,7 +133,7 @@ class Face(BaseMessageComponent): type = ComponentType.Face id: int - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) @@ -148,7 +148,7 @@ class Record(BaseMessageComponent): # 额外 path: T.Optional[str] - def __init__(self, file: T.Optional[str], **_): + def __init__(self, file: T.Optional[str], **_: T.Any) -> None: # noqa: ANN401 for k in _.keys(): if k == "url": pass @@ -156,17 +156,17 @@ def __init__(self, file: T.Optional[str], **_): super().__init__(file=file, **_) @staticmethod - def fromFileSystem(path, **_): + def fromFileSystem(path: str, **_: T.Any) -> "Record": # noqa: ANN401 return Record(file=f"file:///{os.path.abspath(path)}", path=path, **_) @staticmethod - def fromURL(url: str, **_): + def fromURL(url: str, **_: T.Any) -> "Record": # noqa: ANN401 if url.startswith("http://") or url.startswith("https://"): return Record(file=url, **_) raise Exception("not a valid url") @staticmethod - def fromBase64(bs64_data: str, **_): + def fromBase64(bs64_data: str, **_: T.Any) -> "Record": # noqa: ANN401 return Record(file=f"base64://{bs64_data}", **_) async def convert_to_file_path(self) -> str: @@ -250,15 +250,15 @@ class Video(BaseMessageComponent): # 额外 path: T.Optional[str] = "" - def __init__(self, file: str, **_): + def __init__(self, file: str, **_: T.Any) -> None: # noqa: ANN401 super().__init__(file=file, **_) @staticmethod - def fromFileSystem(path, **_): + def fromFileSystem(path: str, **_: T.Any) -> "Video": # noqa: ANN401 return Video(file=f"file:///{os.path.abspath(path)}", path=path, **_) @staticmethod - def fromURL(url: str, **_): + def fromURL(url: str, **_: T.Any) -> "Video": # noqa: ANN401 if url.startswith("http://") or url.startswith("https://"): return Video(file=url, **_) raise Exception("not a valid url") @@ -285,7 +285,7 @@ async def convert_to_file_path(self) -> str: else: raise Exception(f"not a valid file: {url}") - async def register_to_file_service(self): + async def register_to_file_service(self) -> str: """ 将视频注册到文件服务。 @@ -308,7 +308,7 @@ async def register_to_file_service(self): return f"{callback_host}/api/file/{token}" - async def to_dict(self): + async def to_dict(self) -> dict: """需要和 toDict 区分开,toDict 是同步方法""" url_or_path = self.file if url_or_path.startswith("http"): @@ -333,10 +333,10 @@ class At(BaseMessageComponent): qq: T.Union[int, str] # 此处str为all时代表所有人 name: T.Optional[str] = "" - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) - def toDict(self): + def toDict(self) -> dict: return { "type": "at", "data": {"qq": str(self.qq)}, @@ -346,28 +346,28 @@ def toDict(self): class AtAll(At): qq: str = "all" - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) class RPS(BaseMessageComponent): # TODO type = ComponentType.RPS - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) class Dice(BaseMessageComponent): # TODO type = ComponentType.Dice - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) class Shake(BaseMessageComponent): # TODO type = ComponentType.Shake - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) @@ -375,7 +375,7 @@ class Anonymous(BaseMessageComponent): # TODO type = ComponentType.Anonymous ignore: T.Optional[bool] = False - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) @@ -386,7 +386,7 @@ class Share(BaseMessageComponent): content: T.Optional[str] = "" image: T.Optional[str] = "" - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) @@ -395,7 +395,7 @@ class Contact(BaseMessageComponent): # TODO _type: str # type 字段冲突 id: T.Optional[int] = 0 - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) @@ -406,7 +406,7 @@ class Location(BaseMessageComponent): # TODO title: T.Optional[str] = "" content: T.Optional[str] = "" - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) @@ -420,7 +420,7 @@ class Music(BaseMessageComponent): content: T.Optional[str] = "" image: T.Optional[str] = "" - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 # for k in _.keys(): # if k == "_type" and _[k] not in ["qq", "163", "xm", "custom"]: # logger.warn(f"Protocol: {k}={_[k]} doesn't match values") @@ -440,29 +440,29 @@ class Image(BaseMessageComponent): path: T.Optional[str] = "" file_unique: T.Optional[str] = "" # 某些平台可能有图片缓存的唯一标识 - def __init__(self, file: T.Optional[str], **_): + def __init__(self, file: T.Optional[str], **_: T.Any) -> None: # noqa: ANN401 super().__init__(file=file, **_) @staticmethod - def fromURL(url: str, **_): + def fromURL(url: str, **_: T.Any) -> "Image": # noqa: ANN401 if url.startswith("http://") or url.startswith("https://"): return Image(file=url, **_) raise Exception("not a valid url") @staticmethod - def fromFileSystem(path, **_): + def fromFileSystem(path: str, **_: T.Any) -> "Image": # noqa: ANN401 return Image(file=f"file:///{os.path.abspath(path)}", path=path, **_) @staticmethod - def fromBase64(base64: str, **_): + def fromBase64(base64: str, **_: T.Any) -> "Image": # noqa: ANN401 return Image(f"base64://{base64}", **_) @staticmethod - def fromBytes(byte: bytes): + def fromBytes(byte: bytes) -> "Image": return Image.fromBase64(base64.b64encode(byte).decode()) @staticmethod - def fromIO(IO): + def fromIO(IO: T.BinaryIO) -> "Image": return Image.fromBytes(IO.read()) async def convert_to_file_path(self) -> str: @@ -562,7 +562,7 @@ class Reply(BaseMessageComponent): seq: T.Optional[int] = 0 """deprecated""" - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) @@ -570,16 +570,16 @@ class RedBag(BaseMessageComponent): type = ComponentType.RedBag title: str - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) class Poke(BaseMessageComponent): - type: str = ComponentType.Poke + type = ComponentType.Poke id: T.Optional[int] = 0 qq: T.Optional[int] = 0 - def __init__(self, type: str, **_): + def __init__(self, type: str, **_: T.Any) -> None: # noqa: ANN401 type = f"Poke:{type}" super().__init__(type=type, **_) @@ -588,7 +588,7 @@ class Forward(BaseMessageComponent): type = ComponentType.Forward id: str - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) @@ -603,13 +603,13 @@ class Node(BaseMessageComponent): seq: T.Optional[T.Union[str, list]] = "" # 忽略 time: T.Optional[int] = 0 # 忽略 - def __init__(self, content: list[BaseMessageComponent], **_): + def __init__(self, content: list[BaseMessageComponent], **_: T.Any) -> None: # noqa: ANN401 if isinstance(content, Node): # back content = [content] super().__init__(content=content, **_) - async def to_dict(self): + async def to_dict(self) -> dict: data_content = [] for comp in self.content: if isinstance(comp, (Image, Record)): @@ -650,10 +650,10 @@ class Nodes(BaseMessageComponent): type = ComponentType.Nodes nodes: T.List[Node] - def __init__(self, nodes: T.List[Node], **_): + def __init__(self, nodes: T.List[Node], **_: T.Any) -> None: # noqa: ANN401 super().__init__(nodes=nodes, **_) - def toDict(self): + def toDict(self) -> dict: """Deprecated. Use to_dict instead""" ret = { "messages": [], @@ -663,7 +663,7 @@ def toDict(self): ret["messages"].append(d) return ret - async def to_dict(self): + async def to_dict(self) -> dict: """将 Nodes 转换为字典格式,适用于 OneBot JSON 格式""" ret = {"messages": []} for node in self.nodes: @@ -677,7 +677,7 @@ class Xml(BaseMessageComponent): data: str resid: T.Optional[int] = 0 - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) @@ -686,7 +686,7 @@ class Json(BaseMessageComponent): data: T.Union[str, dict] resid: T.Optional[int] = 0 - def __init__(self, data, **_): + def __init__(self, data: T.Union[str, dict], **_: T.Any) -> None: # noqa: ANN401 if isinstance(data, dict): data = json.dumps(data) super().__init__(data=data, **_) @@ -703,11 +703,11 @@ class CardImage(BaseMessageComponent): source: T.Optional[str] = "" icon: T.Optional[str] = "" - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) @staticmethod - def fromFileSystem(path, **_): + def fromFileSystem(path: str, **_: T.Any) -> "CardImage": # noqa: ANN401 return CardImage(file=f"file:///{os.path.abspath(path)}", **_) @@ -715,7 +715,7 @@ class TTS(BaseMessageComponent): type = ComponentType.TTS text: str - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) @@ -723,7 +723,7 @@ class Unknown(BaseMessageComponent): type = ComponentType.Unknown text: str - def toString(self): + def toString(self) -> str: return "" @@ -737,7 +737,7 @@ class File(BaseMessageComponent): file_: T.Optional[str] = "" # 本地路径 url: T.Optional[str] = "" # url - def __init__(self, name: str, file: str = "", url: str = ""): + def __init__(self, name: str, file: str = "", url: str = "") -> None: # noqa: ANN401 """文件消息段。""" super().__init__(name=name, file_=file, url=url) @@ -774,7 +774,7 @@ def file(self) -> str: return "" @file.setter - def file(self, value: str): + def file(self, value: str) -> None: """ 向前兼容, 设置file属性, 传入的参数可能是文件路径或URL @@ -807,7 +807,7 @@ async def get_file(self, allow_return_url: bool = False) -> str: return "" - async def _download_file(self): + async def _download_file(self) -> None: """下载文件""" download_dir = os.path.join(get_astrbot_data_path(), "temp") os.makedirs(download_dir, exist_ok=True) @@ -815,7 +815,7 @@ async def _download_file(self): await download_file(self.url, file_path) self.file_ = os.path.abspath(file_path) - async def register_to_file_service(self): + async def register_to_file_service(self) -> str: """ 将文件注册到文件服务。 @@ -838,7 +838,7 @@ async def register_to_file_service(self): return f"{callback_host}/api/file/{token}" - async def to_dict(self): + async def to_dict(self) -> dict: """需要和 toDict 区分开,toDict 是同步方法""" url_or_path = await self.get_file(allow_return_url=True) if url_or_path.startswith("http"): @@ -865,7 +865,7 @@ class WechatEmoji(BaseMessageComponent): md5_len: T.Optional[int] = 0 cdnurl: T.Optional[str] = "" - def __init__(self, **_): + def __init__(self, **_: T.Any) -> None: # noqa: ANN401 super().__init__(**_) diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py index 7bfdd34c8..d058fedd4 100644 --- a/astrbot/core/message/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -1,6 +1,8 @@ import enum -from typing import List, Optional, Union, AsyncGenerator +from typing import Optional + +from collections.abc import AsyncGenerator from dataclasses import dataclass, field from astrbot.core.message.components import ( BaseMessageComponent, @@ -22,12 +24,12 @@ class MessageChain: `use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。 """ - chain: List[BaseMessageComponent] = field(default_factory=list) - use_t2i_: Optional[bool] = None # None 为跟随用户设置 - type: Optional[str] = None + chain: list[BaseMessageComponent] = field(default_factory=list) + use_t2i_: bool | None = None # None 为跟随用户设置 + type: str | None = None """消息链承载的消息的类型。可选,用于让消息平台区分不同业务场景的消息链。""" - def message(self, message: str): + def message(self, message: str) -> "MessageChain": """添加一条文本消息到消息链 `chain` 中。 Example: @@ -39,7 +41,7 @@ def message(self, message: str): self.chain.append(Plain(message)) return self - def at(self, name: str, qq: Union[str, int]): + def at(self, name: str, qq: str | int) -> "MessageChain": """添加一条 At 消息到消息链 `chain` 中。 Example: @@ -51,7 +53,7 @@ def at(self, name: str, qq: Union[str, int]): self.chain.append(At(name=name, qq=qq)) return self - def at_all(self): + def at_all(self) -> "MessageChain": """添加一条 AtAll 消息到消息链 `chain` 中。 Example: @@ -64,7 +66,7 @@ def at_all(self): return self @deprecated("请使用 message 方法代替。") - def error(self, message: str): + def error(self, message: str) -> "MessageChain": """添加一条错误消息到消息链 `chain` 中 Example: @@ -75,7 +77,7 @@ def error(self, message: str): self.chain.append(Plain(message)) return self - def url_image(self, url: str): + def url_image(self, url: str) -> "MessageChain": """添加一条图片消息(https 链接)到消息链 `chain` 中。 Note: @@ -89,7 +91,7 @@ def url_image(self, url: str): self.chain.append(Image.fromURL(url)) return self - def file_image(self, path: str): + def file_image(self, path: str) -> "MessageChain": """添加一条图片消息(本地文件路径)到消息链 `chain` 中。 Note: @@ -100,7 +102,7 @@ def file_image(self, path: str): self.chain.append(Image.fromFileSystem(path)) return self - def base64_image(self, base64_str: str): + def base64_image(self, base64_str: str) -> "MessageChain": """添加一条图片消息(base64 编码字符串)到消息链 `chain` 中。 Example: @@ -109,7 +111,7 @@ def base64_image(self, base64_str: str): self.chain.append(Image.fromBase64(base64_str)) return self - def use_t2i(self, use_t2i: bool): + def use_t2i(self, use_t2i: bool) -> "MessageChain": """设置是否使用文本转图片服务。 Args: @@ -122,10 +124,10 @@ def get_plain_text(self) -> str: """获取纯文本消息。这个方法将获取 chain 中所有 Plain 组件的文本并拼接成一条消息。空格分隔。""" return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)]) - def squash_plain(self): + def squash_plain(self) -> Optional["MessageChain"]: """将消息链中的所有 Plain 消息段聚合到第一个 Plain 消息段中。""" if not self.chain: - return + return None new_chain = [] first_plain = None @@ -183,15 +185,15 @@ class MessageEventResult(MessageChain): `result_type` (EventResultType): 事件处理的结果类型。 """ - result_type: Optional[EventResultType] = field( + result_type: EventResultType | None = field( default_factory=lambda: EventResultType.CONTINUE ) - result_content_type: Optional[ResultContentType] = field( + result_content_type: ResultContentType | None = field( default_factory=lambda: ResultContentType.GENERAL_RESULT ) - async_stream: Optional[AsyncGenerator] = None + async_stream: AsyncGenerator | None = None """异步流""" def stop_event(self) -> "MessageEventResult": From 9a6dd403da7f076356c9d3a069ef4fdbcac6b936 Mon Sep 17 00:00:00 2001 From: Dale Null Date: Wed, 29 Oct 2025 13:14:54 +0800 Subject: [PATCH 17/44] chore(core.pipeline): ruff rewrite --- astrbot/core/pipeline/content_safety_check/stage.py | 4 ++-- .../content_safety_check/strategies/__init__.py | 3 +-- .../content_safety_check/strategies/strategy.py | 5 ++--- astrbot/core/pipeline/preprocess_stage/stage.py | 5 +++-- .../core/pipeline/process_stage/method/star_request.py | 10 ++++++---- astrbot/core/pipeline/process_stage/stage.py | 6 +++--- astrbot/core/pipeline/rate_limit_check/stage.py | 6 ++++-- astrbot/core/pipeline/respond/stage.py | 5 +++-- astrbot/core/pipeline/result_decorate/stage.py | 5 +++-- astrbot/core/pipeline/scheduler.py | 2 +- astrbot/core/pipeline/session_status_check/stage.py | 5 +++-- astrbot/core/pipeline/stage.py | 9 ++++----- astrbot/core/pipeline/waking_check/stage.py | 4 ++-- astrbot/core/pipeline/whitelist_check/stage.py | 5 +++-- 14 files changed, 40 insertions(+), 34 deletions(-) diff --git a/astrbot/core/pipeline/content_safety_check/stage.py b/astrbot/core/pipeline/content_safety_check/stage.py index e6ecd995c..93416eca2 100644 --- a/astrbot/core/pipeline/content_safety_check/stage.py +++ b/astrbot/core/pipeline/content_safety_check/stage.py @@ -1,4 +1,4 @@ -from typing import Union, AsyncGenerator +from collections.abc import AsyncGenerator from ..stage import Stage, register_stage from ..context import PipelineContext from astrbot.core.platform.astr_message_event import AstrMessageEvent @@ -20,7 +20,7 @@ async def initialize(self, ctx: PipelineContext): async def process( self, event: AstrMessageEvent, check_text: str | None = None - ) -> Union[None, AsyncGenerator[None, None]]: + ) -> None | AsyncGenerator[None, None]: """检查内容安全""" text = check_text if check_text else event.get_message_str() ok, info = self.strategy_selector.check(text) diff --git a/astrbot/core/pipeline/content_safety_check/strategies/__init__.py b/astrbot/core/pipeline/content_safety_check/strategies/__init__.py index 5701f0634..f0a34e73f 100644 --- a/astrbot/core/pipeline/content_safety_check/strategies/__init__.py +++ b/astrbot/core/pipeline/content_safety_check/strategies/__init__.py @@ -1,8 +1,7 @@ import abc -from typing import Tuple class ContentSafetyStrategy(abc.ABC): @abc.abstractmethod - def check(self, content: str) -> Tuple[bool, str]: + def check(self, content: str) -> tuple[bool, str]: raise NotImplementedError diff --git a/astrbot/core/pipeline/content_safety_check/strategies/strategy.py b/astrbot/core/pipeline/content_safety_check/strategies/strategy.py index af960328f..27494d455 100644 --- a/astrbot/core/pipeline/content_safety_check/strategies/strategy.py +++ b/astrbot/core/pipeline/content_safety_check/strategies/strategy.py @@ -1,11 +1,10 @@ from . import ContentSafetyStrategy -from typing import List, Tuple from astrbot import logger class StrategySelector: def __init__(self, config: dict) -> None: - self.enabled_strategies: List[ContentSafetyStrategy] = [] + self.enabled_strategies: list[ContentSafetyStrategy] = [] if config["internal_keywords"]["enable"]: from .keywords import KeywordsStrategy @@ -26,7 +25,7 @@ def __init__(self, config: dict) -> None: ) ) - def check(self, content: str) -> Tuple[bool, str]: + def check(self, content: str) -> tuple[bool, str]: for strategy in self.enabled_strategies: ok, info = strategy.check(content) if not ok: diff --git a/astrbot/core/pipeline/preprocess_stage/stage.py b/astrbot/core/pipeline/preprocess_stage/stage.py index 5c075687f..936e9f3e7 100644 --- a/astrbot/core/pipeline/preprocess_stage/stage.py +++ b/astrbot/core/pipeline/preprocess_stage/stage.py @@ -1,7 +1,8 @@ import traceback import asyncio import random -from typing import Union, AsyncGenerator + +from collections.abc import AsyncGenerator from ..stage import Stage, register_stage from ..context import PipelineContext from astrbot.core.platform.astr_message_event import AstrMessageEvent @@ -21,7 +22,7 @@ async def initialize(self, ctx: PipelineContext) -> None: async def process( self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: + ) -> None | AsyncGenerator[None, None]: """在处理事件之前的预处理""" # 平台特异配置:platform_specific..pre_ack_emoji supported = {"telegram", "lark"} diff --git a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py index 42990aae5..7a9d2045f 100644 --- a/astrbot/core/pipeline/process_stage/method/star_request.py +++ b/astrbot/core/pipeline/process_stage/method/star_request.py @@ -4,7 +4,9 @@ from ...context import PipelineContext, call_handler from ..stage import Stage -from typing import Dict, Any, List, AsyncGenerator, Union +from typing import Any + +from collections.abc import AsyncGenerator from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.message.message_event_result import MessageEventResult from astrbot.core import logger @@ -22,11 +24,11 @@ async def initialize(self, ctx: PipelineContext) -> None: async def process( self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: - activated_handlers: List[StarHandlerMetadata] = event.get_extra( + ) -> None | AsyncGenerator[None, None]: + activated_handlers: list[StarHandlerMetadata] = event.get_extra( "activated_handlers" ) - handlers_parsed_params: Dict[str, Dict[str, Any]] = event.get_extra( + handlers_parsed_params: dict[str, dict[str, Any]] = event.get_extra( "handlers_parsed_params" ) if not handlers_parsed_params: diff --git a/astrbot/core/pipeline/process_stage/stage.py b/astrbot/core/pipeline/process_stage/stage.py index f653a9fb9..87906d653 100644 --- a/astrbot/core/pipeline/process_stage/stage.py +++ b/astrbot/core/pipeline/process_stage/stage.py @@ -1,4 +1,4 @@ -from typing import List, Union, AsyncGenerator +from collections.abc import AsyncGenerator from ..stage import Stage, register_stage from ..context import PipelineContext from .method.llm_request import LLMRequestSubStage @@ -23,9 +23,9 @@ async def initialize(self, ctx: PipelineContext) -> None: async def process( self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: + ) -> None | AsyncGenerator[None, None]: """处理事件""" - activated_handlers: List[StarHandlerMetadata] = event.get_extra( + activated_handlers: list[StarHandlerMetadata] = event.get_extra( "activated_handlers" ) # 有插件 Handler 被激活 diff --git a/astrbot/core/pipeline/rate_limit_check/stage.py b/astrbot/core/pipeline/rate_limit_check/stage.py index b36a2fbd0..f3f2b28de 100644 --- a/astrbot/core/pipeline/rate_limit_check/stage.py +++ b/astrbot/core/pipeline/rate_limit_check/stage.py @@ -1,7 +1,9 @@ import asyncio from datetime import datetime, timedelta from collections import defaultdict, deque -from typing import DefaultDict, Deque, Union, AsyncGenerator +from typing import DefaultDict, Deque + +from collections.abc import AsyncGenerator from ..stage import Stage, register_stage from ..context import PipelineContext from astrbot.core.platform.astr_message_event import AstrMessageEvent @@ -43,7 +45,7 @@ async def initialize(self, ctx: PipelineContext) -> None: async def process( self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: + ) -> None | AsyncGenerator[None, None]: """ 检查并处理限流逻辑。如果触发限流,流水线会 stall 并在窗口期后自动恢复。 diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index dc6a67e2f..d5b0a5b60 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -2,7 +2,8 @@ import asyncio import math import astrbot.core.message.components as Comp -from typing import Union, AsyncGenerator + +from collections.abc import AsyncGenerator from ..stage import register_stage, Stage from ..context import PipelineContext, call_event_hook from astrbot.core.platform.astr_message_event import AstrMessageEvent @@ -151,7 +152,7 @@ def _extract_comp( async def process( self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: + ) -> None | AsyncGenerator[None, None]: result = event.get_result() if result is None: return diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index c1f893baf..55e0a88f3 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -1,7 +1,8 @@ import re import time import traceback -from typing import AsyncGenerator, Union + +from collections.abc import AsyncGenerator from astrbot.core import file_token_service, html_renderer, logger from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply @@ -72,7 +73,7 @@ async def initialize(self, ctx: PipelineContext): async def process( self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: + ) -> None | AsyncGenerator[None, None]: result = event.get_result() if result is None or not result.chain: return diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py index 7a38ec03f..f09aa7f15 100644 --- a/astrbot/core/pipeline/scheduler.py +++ b/astrbot/core/pipeline/scheduler.py @@ -1,7 +1,7 @@ from . import STAGES_ORDER from .stage import registered_stages from .context import PipelineContext -from typing import AsyncGenerator +from collections.abc import AsyncGenerator from astrbot.core.platform import AstrMessageEvent from astrbot.core import logger diff --git a/astrbot/core/pipeline/session_status_check/stage.py b/astrbot/core/pipeline/session_status_check/stage.py index 3c451e26a..731441f3a 100644 --- a/astrbot/core/pipeline/session_status_check/stage.py +++ b/astrbot/core/pipeline/session_status_check/stage.py @@ -1,6 +1,7 @@ from ..stage import Stage, register_stage from ..context import PipelineContext -from typing import AsyncGenerator, Union + +from collections.abc import AsyncGenerator from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.star.session_llm_manager import SessionServiceManager from astrbot.core import logger @@ -16,7 +17,7 @@ async def initialize(self, ctx: PipelineContext) -> None: async def process( self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: + ) -> None | AsyncGenerator[None, None]: # 检查会话是否整体启用 if not SessionServiceManager.is_session_enabled(event.unified_msg_origin): logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。") diff --git a/astrbot/core/pipeline/stage.py b/astrbot/core/pipeline/stage.py index c4550495a..93122edb8 100644 --- a/astrbot/core/pipeline/stage.py +++ b/astrbot/core/pipeline/stage.py @@ -1,10 +1,11 @@ from __future__ import annotations import abc -from typing import List, AsyncGenerator, Union, Type + +from collections.abc import AsyncGenerator from astrbot.core.platform.astr_message_event import AstrMessageEvent from .context import PipelineContext -registered_stages: List[Type[Stage]] = [] # 维护了所有已注册的 Stage 实现类类型 +registered_stages: list[type[Stage]] = [] # 维护了所有已注册的 Stage 实现类类型 def register_stage(cls): @@ -26,9 +27,7 @@ async def initialize(self, ctx: PipelineContext) -> None: raise NotImplementedError @abc.abstractmethod - async def process( - self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: + async def process(self, event: AstrMessageEvent) -> None | AsyncGenerator[None]: """处理事件 Args: diff --git a/astrbot/core/pipeline/waking_check/stage.py b/astrbot/core/pipeline/waking_check/stage.py index de6ad5e35..161b206a9 100644 --- a/astrbot/core/pipeline/waking_check/stage.py +++ b/astrbot/core/pipeline/waking_check/stage.py @@ -1,4 +1,4 @@ -from typing import AsyncGenerator, Union +from collections.abc import AsyncGenerator from astrbot import logger from astrbot.core.message.components import At, AtAll, Reply @@ -49,7 +49,7 @@ async def initialize(self, ctx: PipelineContext) -> None: async def process( self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: + ) -> None | AsyncGenerator[None, None]: if ( self.ignore_bot_self_message and event.get_self_id() == event.get_sender_id() diff --git a/astrbot/core/pipeline/whitelist_check/stage.py b/astrbot/core/pipeline/whitelist_check/stage.py index b140d23ba..ce3616b27 100644 --- a/astrbot/core/pipeline/whitelist_check/stage.py +++ b/astrbot/core/pipeline/whitelist_check/stage.py @@ -1,6 +1,7 @@ from ..stage import Stage, register_stage from ..context import PipelineContext -from typing import AsyncGenerator, Union + +from collections.abc import AsyncGenerator from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.message_type import MessageType from astrbot.core import logger @@ -28,7 +29,7 @@ async def initialize(self, ctx: PipelineContext) -> None: async def process( self, event: AstrMessageEvent - ) -> Union[None, AsyncGenerator[None, None]]: + ) -> None | AsyncGenerator[None, None]: if not self.enable_whitelist_check: # 白名单检查未启用 return From d85a3953b7b643b27d6f2a5083c5a4b2843066c2 Mon Sep 17 00:00:00 2001 From: Dale Null Date: Wed, 29 Oct 2025 16:29:15 +0800 Subject: [PATCH 18/44] chore(core.pipeline): add missing type annotations Co-author: GPT-5-mini --- .../pipeline/content_safety_check/stage.py | 2 +- astrbot/core/pipeline/context_utils.py | 8 ++--- .../process_stage/method/llm_request.py | 30 ++++++++++++------- .../core/pipeline/rate_limit_check/stage.py | 2 +- astrbot/core/pipeline/respond/stage.py | 6 ++-- .../core/pipeline/result_decorate/stage.py | 2 +- astrbot/core/pipeline/scheduler.py | 10 ++++--- astrbot/core/pipeline/stage.py | 2 +- 8 files changed, 36 insertions(+), 26 deletions(-) diff --git a/astrbot/core/pipeline/content_safety_check/stage.py b/astrbot/core/pipeline/content_safety_check/stage.py index 93416eca2..d94e97ea9 100644 --- a/astrbot/core/pipeline/content_safety_check/stage.py +++ b/astrbot/core/pipeline/content_safety_check/stage.py @@ -14,7 +14,7 @@ class ContentSafetyCheckStage(Stage): 当前只会检查文本的。 """ - async def initialize(self, ctx: PipelineContext): + async def initialize(self, ctx: PipelineContext) -> None: config = ctx.astrbot_config["content_safety"] self.strategy_selector = StrategySelector(config) diff --git a/astrbot/core/pipeline/context_utils.py b/astrbot/core/pipeline/context_utils.py index e7ac120b7..10582c69a 100644 --- a/astrbot/core/pipeline/context_utils.py +++ b/astrbot/core/pipeline/context_utils.py @@ -11,8 +11,8 @@ async def call_handler( event: AstrMessageEvent, handler: T.Callable[..., T.Awaitable[T.Any]], - *args, - **kwargs, + *args: T.Any, # noqa: ANN401 + **kwargs: T.Any, # noqa: ANN401 ) -> T.AsyncGenerator[T.Any, None]: """执行事件处理函数并处理其返回结果 @@ -73,8 +73,8 @@ async def call_handler( async def call_event_hook( event: AstrMessageEvent, hook_type: EventType, - *args, - **kwargs, + *args: T.Any, # noqa: ANN401 + **kwargs: T.Any, # noqa: ANN401 ) -> bool: """调用事件钩子函数 diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 703b3681c..747b5562b 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -6,6 +6,7 @@ import copy import json import traceback +import typing as T from datetime import timedelta from collections.abc import AsyncGenerator from astrbot.core.conversation_mgr import Conversation @@ -50,7 +51,12 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): @classmethod - async def execute(cls, tool, run_context, **tool_args): + async def execute( + cls, + tool: FunctionTool, + run_context: ContextWrapper[AstrAgentContext], + **tool_args: T.Any, # noqa:ANN401 + ) -> AsyncGenerator[T.Union[None, "mcp.types.CallToolResult"], None]: """执行函数调用。 Args: @@ -82,8 +88,8 @@ async def _execute_handoff( cls, tool: HandoffTool, run_context: ContextWrapper[AstrAgentContext], - **tool_args, - ): + **tool_args: T.Any, # noqa: ANN401 + ) -> AsyncGenerator[T.Union[None, "mcp.types.CallToolResult"], None]: input_ = tool_args.get("input", "agent") agent_runner = AgentRunner() @@ -172,8 +178,8 @@ async def _execute_local( cls, tool: FunctionTool, run_context: ContextWrapper[AstrAgentContext], - **tool_args, - ): + **tool_args: T.Any, # noqa: ANN401 + ) -> AsyncGenerator[T.Union[None, "mcp.types.CallToolResult"], None]: if not run_context.event: raise ValueError("Event must be provided for local function tools.") @@ -220,8 +226,8 @@ async def _execute_mcp( cls, tool: FunctionTool, run_context: ContextWrapper[AstrAgentContext], - **tool_args, - ): + **tool_args: T.Any, # noqa: ANN401 + ) -> AsyncGenerator[T.Union[None, "mcp.types.CallToolResult"], None]: if not tool.mcp_client: raise ValueError("MCP client is not available for MCP function tools.") @@ -241,7 +247,9 @@ async def _execute_mcp( class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]): - async def on_agent_done(self, run_context, llm_response): + async def on_agent_done( + self, run_context: ContextWrapper[AstrAgentContext], llm_response: LLMResponse + ) -> None: # 执行事件钩子 await call_event_hook( run_context.event, EventType.OnLLMResponseEvent, llm_response @@ -338,7 +346,7 @@ async def initialize(self, ctx: PipelineContext) -> None: self.conv_manager = ctx.plugin_manager.context.conversation_manager - def _select_provider(self, event: AstrMessageEvent): + def _select_provider(self, event: AstrMessageEvent) -> Provider | None: """选择使用的 LLM 提供商""" sel_provider = event.get_extra("selected_provider") _ctx = self.ctx.plugin_manager.context @@ -565,7 +573,7 @@ async def process( async def _handle_webchat( self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider - ): + ) -> None: """处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title""" if not req.conversation: return @@ -623,7 +631,7 @@ async def _save_to_history( event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse | None, - ): + ) -> None: if ( not req or not req.conversation diff --git a/astrbot/core/pipeline/rate_limit_check/stage.py b/astrbot/core/pipeline/rate_limit_check/stage.py index f3f2b28de..e16a1ca48 100644 --- a/astrbot/core/pipeline/rate_limit_check/stage.py +++ b/astrbot/core/pipeline/rate_limit_check/stage.py @@ -20,7 +20,7 @@ class RateLimitStage(Stage): 如果触发限流,将 stall 流水线,直到下一个时间窗口来临时自动唤醒。 """ - def __init__(self): + def __init__(self) -> None: # 存储每个会话的请求时间队列 self.event_timestamps: DefaultDict[str, Deque[datetime]] = defaultdict(deque) # 为每个会话设置一个锁,避免并发冲突 diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index d5b0a5b60..98afecefe 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -35,7 +35,7 @@ class RespondStage(Stage): Comp.WechatEmoji: lambda comp: comp.md5 is not None, # 微信表情 } - async def initialize(self, ctx: PipelineContext): + async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx self.config = ctx.astrbot_config self.platform_settings: dict = self.config.get("platform_settings", {}) @@ -93,7 +93,7 @@ async def _calc_comp_interval(self, comp: BaseMessageComponent) -> float: # random return random.uniform(self.interval[0], self.interval[1]) - async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]): + async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]) -> bool: """检查消息链是否为空 Args: @@ -135,7 +135,7 @@ def _extract_comp( raw_chain: list[BaseMessageComponent], extract_types: set[ComponentType], modify_raw_chain: bool = True, - ): + ) -> list[BaseMessageComponent]: extracted = [] if modify_raw_chain: remaining = [] diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index 55e0a88f3..9dddd6d1a 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -19,7 +19,7 @@ @register_stage class ResultDecorateStage(Stage): - async def initialize(self, ctx: PipelineContext): + async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx self.reply_prefix = ctx.astrbot_config["platform_settings"]["reply_prefix"] self.reply_with_mention = ctx.astrbot_config["platform_settings"][ diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py index f09aa7f15..1b2f7dccc 100644 --- a/astrbot/core/pipeline/scheduler.py +++ b/astrbot/core/pipeline/scheduler.py @@ -9,21 +9,23 @@ class PipelineScheduler: """管道调度器,负责调度各个阶段的执行""" - def __init__(self, context: PipelineContext): + def __init__(self, context: PipelineContext) -> None: registered_stages.sort( key=lambda x: STAGES_ORDER.index(x.__name__) ) # 按照顺序排序 self.ctx = context # 上下文对象 self.stages = [] # 存储阶段实例 - async def initialize(self): + async def initialize(self) -> None: """初始化管道调度器时, 初始化所有阶段""" for stage_cls in registered_stages: stage_instance = stage_cls() # 创建实例 await stage_instance.initialize(self.ctx) self.stages.append(stage_instance) - async def _process_stages(self, event: AstrMessageEvent, from_stage=0): + async def _process_stages( + self, event: AstrMessageEvent, from_stage: int = 0 + ) -> None: """依次执行各个阶段 Args: @@ -65,7 +67,7 @@ async def _process_stages(self, event: AstrMessageEvent, from_stage=0): logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。") break - async def execute(self, event: AstrMessageEvent): + async def execute(self, event: AstrMessageEvent) -> None: """执行 pipeline Args: diff --git a/astrbot/core/pipeline/stage.py b/astrbot/core/pipeline/stage.py index 93122edb8..fb7dec9e9 100644 --- a/astrbot/core/pipeline/stage.py +++ b/astrbot/core/pipeline/stage.py @@ -8,7 +8,7 @@ registered_stages: list[type[Stage]] = [] # 维护了所有已注册的 Stage 实现类类型 -def register_stage(cls): +def register_stage(cls: type[Stage]) -> type[Stage]: """一个简单的装饰器,用于注册 pipeline 包下的 Stage 实现类""" registered_stages.append(cls) return cls From 9da5d1d1d9eb99c2118fc4676fbf3c209167b633 Mon Sep 17 00:00:00 2001 From: Dale Null Date: Thu, 30 Oct 2025 07:05:08 +0800 Subject: [PATCH 19/44] chore(core.platform): ruff rewrite --- astrbot/core/platform/astr_message_event.py | 18 +-- astrbot/core/platform/astrbot_message.py | 7 +- astrbot/core/platform/manager.py | 3 +- astrbot/core/platform/platform.py | 4 +- astrbot/core/platform/register.py | 5 +- .../aiocqhttp/aiocqhttp_message_event.py | 5 +- .../aiocqhttp/aiocqhttp_platform_adapter.py | 4 +- .../platform/sources/discord/components.py | 5 +- .../discord/discord_platform_adapter.py | 4 +- .../sources/discord/discord_platform_event.py | 7 +- .../core/platform/sources/lark/lark_event.py | 3 +- .../sources/misskey/misskey_adapter.py | 36 ++--- .../platform/sources/misskey/misskey_api.py | 130 +++++++++--------- .../platform/sources/misskey/misskey_event.py | 2 +- .../platform/sources/misskey/misskey_utils.py | 76 +++++----- .../qqofficial/qqofficial_message_event.py | 3 +- .../qqofficial/qqofficial_platform_adapter.py | 7 +- .../platform/sources/satori/satori_adapter.py | 19 ++- astrbot/core/platform/sources/slack/client.py | 7 +- .../platform/sources/slack/slack_adapter.py | 4 +- .../platform/sources/slack/slack_event.py | 2 +- .../sources/webchat/webchat_adapter.py | 4 +- .../wechatpadpro/wechatpadpro_adapter.py | 7 +- .../core/platform/sources/wecom/wecom_kf.py | 2 - .../sources/wecom_ai_bot/WXBizJsonMsgCrypt.py | 5 +- .../platform/sources/wecom_ai_bot/ierror.py | 1 - .../sources/wecom_ai_bot/wecomai_adapter.py | 14 +- .../sources/wecom_ai_bot/wecomai_api.py | 20 +-- .../sources/wecom_ai_bot/wecomai_queue_mgr.py | 14 +- .../sources/wecom_ai_bot/wecomai_server.py | 9 +- .../sources/wecom_ai_bot/wecomai_utils.py | 6 +- 31 files changed, 219 insertions(+), 214 deletions(-) diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 3a4b8c128..448a0733f 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -4,7 +4,9 @@ import hashlib import uuid -from typing import List, Union, Optional, AsyncGenerator, Any +from typing import Any + +from collections.abc import AsyncGenerator from astrbot import logger from astrbot.core.db.po import Conversation @@ -90,7 +92,7 @@ def get_message_str(self) -> str: """ return self.message_str - def _outline_chain(self, chain: Optional[List[BaseMessageComponent]]) -> str: + def _outline_chain(self, chain: list[BaseMessageComponent] | None) -> str: outline = "" if not chain: return outline @@ -127,7 +129,7 @@ def get_message_outline(self) -> str: """ return self._outline_chain(self.message_obj.message) - def get_messages(self) -> List[BaseMessageComponent]: + def get_messages(self) -> list[BaseMessageComponent]: """ 获取消息链。 """ @@ -240,7 +242,7 @@ async def _pre_send(self): async def _post_send(self): """调度器会在执行 send() 后调用该方法 deprecated in v3.5.18""" - def set_result(self, result: Union[MessageEventResult, str]): + def set_result(self, result: MessageEventResult | str): """设置消息事件的结果。 Note: @@ -344,7 +346,7 @@ def image_result(self, url_or_path: str) -> MessageEventResult: return MessageEventResult().url_image(url_or_path) return MessageEventResult().file_image(url_or_path) - def chain_result(self, chain: List[BaseMessageComponent]) -> MessageEventResult: + def chain_result(self, chain: list[BaseMessageComponent]) -> MessageEventResult: """ 创建一个空的消息事件结果,包含指定的消息链。 """ @@ -359,8 +361,8 @@ def request_llm( prompt: str, func_tool_manager=None, session_id: str = None, - image_urls: List[str] = [], - contexts: List = [], + image_urls: list[str] = [], + contexts: list = [], system_prompt: str = "", conversation: Conversation = None, ) -> ProviderRequest: @@ -427,7 +429,7 @@ async def react(self, emoji: str): """ await self.send(MessageChain([Plain(emoji)])) - async def get_group(self, group_id: str = None, **kwargs) -> Optional[Group]: + async def get_group(self, group_id: str = None, **kwargs) -> Group | None: """获取一个群聊的数据, 如果不填写 group_id: 如果是私聊消息,返回 None。如果是群聊消息,返回当前群聊的数据。 适配情况: diff --git a/astrbot/core/platform/astrbot_message.py b/astrbot/core/platform/astrbot_message.py index 1808c2911..33ae120e0 100644 --- a/astrbot/core/platform/astrbot_message.py +++ b/astrbot/core/platform/astrbot_message.py @@ -1,5 +1,4 @@ import time -from typing import List from dataclasses import dataclass from astrbot.core.message.components import BaseMessageComponent from .message_type import MessageType @@ -28,9 +27,9 @@ class Group: """群头像""" group_owner: str = None """群主 id""" - group_admins: List[str] = None + group_admins: list[str] = None """群管理员 id""" - members: List[MessageMember] = None + members: list[MessageMember] = None """所有群成员""" def __str__(self): @@ -57,7 +56,7 @@ class AstrBotMessage: message_id: str # 消息id group: Group # 群组 sender: MessageMember # 发送者 - message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式 + message: list[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式 message_str: str # 最直观的纯文本消息字符串 raw_message: object timestamp: int # 消息时间戳 diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index 7090c669c..436185296 100644 --- a/astrbot/core/platform/manager.py +++ b/astrbot/core/platform/manager.py @@ -2,7 +2,6 @@ import asyncio from astrbot.core.config.astrbot_config import AstrBotConfig from .platform import Platform -from typing import List from asyncio import Queue from .register import platform_cls_map from astrbot.core import logger @@ -12,7 +11,7 @@ class PlatformManager: def __init__(self, config: AstrBotConfig, event_queue: Queue): - self.platform_insts: List[Platform] = [] + self.platform_insts: list[Platform] = [] """加载的 Platform 的实例""" self._inst_map = {} diff --git a/astrbot/core/platform/platform.py b/astrbot/core/platform/platform.py index c109f29b4..31afba005 100644 --- a/astrbot/core/platform/platform.py +++ b/astrbot/core/platform/platform.py @@ -1,6 +1,8 @@ import abc import uuid -from typing import Awaitable, Any +from typing import Any + +from collections.abc import Awaitable from asyncio import Queue from .platform_metadata import PlatformMetadata from .astr_message_event import AstrMessageEvent diff --git a/astrbot/core/platform/register.py b/astrbot/core/platform/register.py index 97c33a43e..b2a43cf15 100644 --- a/astrbot/core/platform/register.py +++ b/astrbot/core/platform/register.py @@ -1,10 +1,9 @@ -from typing import List, Dict, Type from .platform_metadata import PlatformMetadata from astrbot.core import logger -platform_registry: List[PlatformMetadata] = [] +platform_registry: list[PlatformMetadata] = [] """维护了通过装饰器注册的平台适配器""" -platform_cls_map: Dict[str, Type] = {} +platform_cls_map: dict[str, type] = {} """维护了平台适配器名称和适配器类的映射""" diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index b8bb723d5..50dc4e070 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -1,6 +1,7 @@ import asyncio import re -from typing import AsyncGenerator, Dict, List + +from collections.abc import AsyncGenerator from aiocqhttp import CQHttp, Event from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.message_components import ( @@ -198,7 +199,7 @@ async def get_group(self, group_id=None, **kwargs): group_id=group_id, ) - members: List[Dict] = await self.bot.call_action( + members: list[dict] = await self.bot.call_action( "get_group_member_list", group_id=group_id, ) diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index d1992b6c3..2bbd885e9 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -3,7 +3,9 @@ import logging import uuid import itertools -from typing import Awaitable, Any +from typing import Any + +from collections.abc import Awaitable from aiocqhttp import CQHttp, Event from astrbot.api.platform import ( Platform, diff --git a/astrbot/core/platform/sources/discord/components.py b/astrbot/core/platform/sources/discord/components.py index 07e712161..79b89704a 100644 --- a/astrbot/core/platform/sources/discord/components.py +++ b/astrbot/core/platform/sources/discord/components.py @@ -1,5 +1,4 @@ import discord -from typing import List from astrbot.api.message_components import BaseMessageComponent @@ -18,7 +17,7 @@ def __init__( thumbnail: str = None, image: str = None, footer: str = None, - fields: List[dict] = None, + fields: list[dict] = None, ): self.title = title self.description = description @@ -96,7 +95,7 @@ class DiscordView(BaseMessageComponent): type: str = "discord_view" def __init__( - self, components: List[BaseMessageComponent] = None, timeout: float = None + self, components: list[BaseMessageComponent] = None, timeout: float = None ): self.components = components or [] self.timeout = timeout diff --git a/astrbot/core/platform/sources/discord/discord_platform_adapter.py b/astrbot/core/platform/sources/discord/discord_platform_adapter.py index 6764eda61..fa82d0da7 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_adapter.py +++ b/astrbot/core/platform/sources/discord/discord_platform_adapter.py @@ -19,7 +19,7 @@ from .client import DiscordBotClient from .discord_platform_event import DiscordPlatformEvent -from typing import Any, Tuple +from typing import Any from astrbot.core.star.filter.command import CommandFilter from astrbot.core.star.filter.command_group import CommandGroupFilter from astrbot.core.star.star import star_map @@ -420,7 +420,7 @@ async def dynamic_callback(ctx: discord.ApplicationContext, params: str = None): @staticmethod def _extract_command_info( event_filter: Any, handler_metadata: StarHandlerMetadata - ) -> Tuple[str, str, CommandFilter] | None: + ) -> tuple[str, str, CommandFilter] | None: """从事件过滤器中提取指令信息""" cmd_name = None # is_group = False diff --git a/astrbot/core/platform/sources/discord/discord_platform_event.py b/astrbot/core/platform/sources/discord/discord_platform_event.py index 2c8d055fc..46a5df51b 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_event.py +++ b/astrbot/core/platform/sources/discord/discord_platform_event.py @@ -3,7 +3,6 @@ import base64 from io import BytesIO from pathlib import Path -from typing import Optional import sys from astrbot.api.event import AstrMessageEvent, MessageChain @@ -41,7 +40,7 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, client: DiscordBotClient, - interaction_followup_webhook: Optional[discord.Webhook] = None, + interaction_followup_webhook: discord.Webhook | None = None, ): super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client @@ -98,7 +97,7 @@ async def send(self, message: MessageChain): await super().send(message) - async def _get_channel(self) -> Optional[discord.abc.Messageable]: + async def _get_channel(self) -> discord.abc.Messageable | None: """获取当前事件对应的频道对象""" try: channel_id = int(self.session_id) @@ -112,7 +111,7 @@ async def _get_channel(self) -> Optional[discord.abc.Messageable]: async def _parse_to_discord( self, message: MessageChain, - ) -> tuple[str, list[discord.File], Optional[discord.ui.View], list[discord.Embed]]: + ) -> tuple[str, list[discord.File], discord.ui.View | None, list[discord.Embed]]: """将 MessageChain 解析为 Discord 发送所需的内容""" content = "" files = [] diff --git a/astrbot/core/platform/sources/lark/lark_event.py b/astrbot/core/platform/sources/lark/lark_event.py index 2174c497c..d9e96566e 100644 --- a/astrbot/core/platform/sources/lark/lark_event.py +++ b/astrbot/core/platform/sources/lark/lark_event.py @@ -4,7 +4,6 @@ import base64 import lark_oapi as lark from io import BytesIO -from typing import List from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.message_components import Plain, Image as AstrBotImage, At from astrbot.core.utils.io import download_image_by_url @@ -21,7 +20,7 @@ def __init__( self.bot = bot @staticmethod - async def _convert_to_lark(message: MessageChain, lark_client: lark.Client) -> List: + async def _convert_to_lark(message: MessageChain, lark_client: lark.Client) -> list: ret = [] _stage = [] for comp in message.chain: diff --git a/astrbot/core/platform/sources/misskey/misskey_adapter.py b/astrbot/core/platform/sources/misskey/misskey_adapter.py index 981d05c82..8d8cc18b8 100644 --- a/astrbot/core/platform/sources/misskey/misskey_adapter.py +++ b/astrbot/core/platform/sources/misskey/misskey_adapter.py @@ -1,6 +1,8 @@ import asyncio import random -from typing import Dict, Any, Optional, Awaitable, List +from typing import Any + +from collections.abc import Awaitable from astrbot.api import logger from astrbot.api.event import MessageChain @@ -87,7 +89,7 @@ def __init__( self.unique_session = platform_settings["unique_session"] - self.api: Optional[MisskeyAPI] = None + self.api: MisskeyAPI | None = None self._running = False self.client_self_id = "" self._bot_username = "" @@ -168,7 +170,7 @@ async def _send_text_only_message( from .misskey_utils import extract_user_id_from_session_id user_id = extract_user_id_from_session_id(session_id) - payload: Dict[str, Any] = {"toUserId": user_id, "text": text} + payload: dict[str, Any] = {"toUserId": user_id, "text": text} await self.api.send_message(payload) elif session_id and is_valid_room_session_id(session_id): from .misskey_utils import extract_room_id_from_session_id @@ -180,7 +182,7 @@ async def _send_text_only_message( return await super().send_by_session(session, message_chain) def _process_poll_data( - self, message: AstrBotMessage, poll: Dict[str, Any], message_parts: List[str] + self, message: AstrBotMessage, poll: dict[str, Any], message_parts: list[str] ): """处理投票数据,将其添加到消息中""" try: @@ -196,7 +198,7 @@ def _process_poll_data( message.message.append(Comp.Plain(poll_text)) message_parts.append(poll_text) - def _extract_additional_fields(self, session, message_chain) -> Dict[str, Any]: + def _extract_additional_fields(self, session, message_chain) -> dict[str, Any]: """从会话和消息链中提取额外字段""" fields = {"cw": None, "poll": None, "renote_id": None, "channel_id": None} @@ -267,7 +269,7 @@ async def _start_websocket_connection(self): await asyncio.sleep(sleep_time) backoff_delay = min(backoff_delay * backoff_multiplier, max_backoff) - async def _handle_notification(self, data: Dict[str, Any]): + async def _handle_notification(self, data: dict[str, Any]): try: notification_type = data.get("type") logger.debug( @@ -291,7 +293,7 @@ async def _handle_notification(self, data: Dict[str, Any]): except Exception as e: logger.error(f"[Misskey] 处理通知失败: {e}") - async def _handle_chat_message(self, data: Dict[str, Any]): + async def _handle_chat_message(self, data: dict[str, Any]): try: sender_id = str( data.get("fromUserId", "") or data.get("fromUser", {}).get("id", "") @@ -326,13 +328,13 @@ async def _handle_chat_message(self, data: Dict[str, Any]): except Exception as e: logger.error(f"[Misskey] 处理聊天消息失败: {e}") - async def _debug_handler(self, data: Dict[str, Any]): + async def _debug_handler(self, data: dict[str, Any]): event_type = data.get("type", "unknown") logger.debug( f"[Misskey] 收到未处理事件: type={event_type}, channel={data.get('channel', 'unknown')}" ) - def _is_bot_mentioned(self, note: Dict[str, Any]) -> bool: + def _is_bot_mentioned(self, note: dict[str, Any]) -> bool: text = note.get("text", "") if not text: return False @@ -400,8 +402,8 @@ async def send_by_session( if len(text) > self.max_message_length: text = text[: self.max_message_length] + "..." - file_ids: List[str] = [] - fallback_urls: List[str] = [] + file_ids: list[str] = [] + fallback_urls: list[str] = [] if not self.enable_file_upload: return await self._send_text_only_message( @@ -417,7 +419,7 @@ async def send_by_session( upload_concurrency = min(upload_concurrency, MAX_UPLOAD_CONCURRENCY) sem = asyncio.Semaphore(upload_concurrency) - async def _upload_comp(comp) -> Optional[object]: + async def _upload_comp(comp) -> object | None: """组件上传函数:处理 URL(下载后上传)或本地文件(直接上传)""" from .misskey_utils import ( resolve_component_url_or_path, @@ -540,7 +542,7 @@ async def _upload_comp(comp) -> Optional[object]: if fallback_urls: appended = "\n" + "\n".join(fallback_urls) text = (text or "") + appended - payload: Dict[str, Any] = {"toRoomId": room_id, "text": text} + payload: dict[str, Any] = {"toRoomId": room_id, "text": text} if file_ids: payload["fileIds"] = file_ids await self.api.send_room_message(payload) @@ -555,7 +557,7 @@ async def _upload_comp(comp) -> Optional[object]: if fallback_urls: appended = "\n" + "\n".join(fallback_urls) text = (text or "") + appended - payload: Dict[str, Any] = {"toUserId": user_id, "text": text} + payload: dict[str, Any] = {"toUserId": user_id, "text": text} if file_ids: # 聊天消息只支持单个文件,使用 fileId 而不是 fileIds payload["fileId"] = file_ids[0] @@ -610,7 +612,7 @@ async def _upload_comp(comp) -> Optional[object]: return await super().send_by_session(session, message_chain) - async def convert_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage: + async def convert_message(self, raw_data: dict[str, Any]) -> AstrBotMessage: """将 Misskey 贴文数据转换为 AstrBotMessage 对象""" sender_info = extract_sender_info(raw_data, is_chat=False) message = create_base_message( @@ -652,7 +654,7 @@ async def convert_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage: ) return message - async def convert_chat_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage: + async def convert_chat_message(self, raw_data: dict[str, Any]) -> AstrBotMessage: """将 Misskey 聊天消息数据转换为 AstrBotMessage 对象""" sender_info = extract_sender_info(raw_data, is_chat=True) message = create_base_message( @@ -676,7 +678,7 @@ async def convert_chat_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage message.message_str = raw_text if raw_text else "" return message - async def convert_room_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage: + async def convert_room_message(self, raw_data: dict[str, Any]) -> AstrBotMessage: """将 Misskey 群聊消息数据转换为 AstrBotMessage 对象""" sender_info = extract_sender_info(raw_data, is_chat=True) room_id = raw_data.get("toRoomId", "") diff --git a/astrbot/core/platform/sources/misskey/misskey_api.py b/astrbot/core/platform/sources/misskey/misskey_api.py index 4b920508f..dae02b183 100644 --- a/astrbot/core/platform/sources/misskey/misskey_api.py +++ b/astrbot/core/platform/sources/misskey/misskey_api.py @@ -1,7 +1,9 @@ import json import random import asyncio -from typing import Any, Optional, Dict, List, Callable, Awaitable +from typing import Any + +from collections.abc import Callable, Awaitable import uuid try: @@ -54,11 +56,11 @@ class StreamingClient: def __init__(self, instance_url: str, access_token: str): self.instance_url = instance_url.rstrip("/") self.access_token = access_token - self.websocket: Optional[Any] = None + self.websocket: Any | None = None self.is_connected = False - self.message_handlers: Dict[str, Callable] = {} - self.channels: Dict[str, str] = {} - self.desired_channels: Dict[str, Optional[Dict]] = {} + self.message_handlers: dict[str, Callable] = {} + self.channels: dict[str, str] = {} + self.desired_channels: dict[str, dict | None] = {} self._running = False self._last_pong = None @@ -104,7 +106,7 @@ async def disconnect(self): logger.info("[Misskey WebSocket] 连接已断开") async def subscribe_channel( - self, channel_type: str, params: Optional[Dict] = None + self, channel_type: str, params: dict | None = None ) -> str: if not self.is_connected or not self.websocket: raise WebSocketError("WebSocket 未连接") @@ -136,7 +138,7 @@ async def unsubscribe_channel(self, channel_id: str): self.desired_channels.pop(channel_type, None) def add_message_handler( - self, event_type: str, handler: Callable[[Dict], Awaitable[None]] + self, event_type: str, handler: Callable[[dict], Awaitable[None]] ): self.message_handlers[event_type] = handler @@ -188,11 +190,11 @@ async def listen(self): except Exception: pass - async def _handle_message(self, data: Dict[str, Any]): + async def _handle_message(self, data: dict[str, Any]): message_type = data.get("type") body = data.get("body", {}) - def _build_channel_summary(message_type: Optional[str], body: Any) -> str: + def _build_channel_summary(message_type: str | None, body: Any) -> str: try: if not isinstance(body, dict): return f"[Misskey WebSocket] 收到消息类型: {message_type}" @@ -334,12 +336,12 @@ def __init__( allow_insecure_downloads: bool = False, download_timeout: int = 15, chunk_size: int = 64 * 1024, - max_download_bytes: Optional[int] = None, + max_download_bytes: int | None = None, ): self.instance_url = instance_url.rstrip("/") self.access_token = access_token - self._session: Optional[aiohttp.ClientSession] = None - self.streaming: Optional[StreamingClient] = None + self._session: aiohttp.ClientSession | None = None + self.streaming: StreamingClient | None = None # download options self.allow_insecure_downloads = allow_insecure_downloads self.download_timeout = download_timeout @@ -456,7 +458,7 @@ async def _process_response( retryable_exceptions=(APIConnectionError, APIRateLimitError), ) async def _make_request( - self, endpoint: str, data: Optional[Dict[str, Any]] = None + self, endpoint: str, data: dict[str, Any] | None = None ) -> Any: url = f"{self.instance_url}/api/{endpoint}" payload = {"i": self.access_token} @@ -472,24 +474,24 @@ async def _make_request( async def create_note( self, - text: Optional[str] = None, + text: str | None = None, visibility: str = "public", - reply_id: Optional[str] = None, - visible_user_ids: Optional[List[str]] = None, - file_ids: Optional[List[str]] = None, + reply_id: str | None = None, + visible_user_ids: list[str] | None = None, + file_ids: list[str] | None = None, local_only: bool = False, - cw: Optional[str] = None, - poll: Optional[Dict[str, Any]] = None, - renote_id: Optional[str] = None, - channel_id: Optional[str] = None, - reaction_acceptance: Optional[str] = None, - no_extract_mentions: Optional[bool] = None, - no_extract_hashtags: Optional[bool] = None, - no_extract_emojis: Optional[bool] = None, - media_ids: Optional[List[str]] = None, - ) -> Dict[str, Any]: + cw: str | None = None, + poll: dict[str, Any] | None = None, + renote_id: str | None = None, + channel_id: str | None = None, + reaction_acceptance: str | None = None, + no_extract_mentions: bool | None = None, + no_extract_hashtags: bool | None = None, + no_extract_emojis: bool | None = None, + media_ids: list[str] | None = None, + ) -> dict[str, Any]: """Create a note (wrapper for notes/create). All additional fields are optional and passed through to the API.""" - data: Dict[str, Any] = {} + data: dict[str, Any] = {} if text is not None: data["text"] = text @@ -537,9 +539,9 @@ async def create_note( async def upload_file( self, file_path: str, - name: Optional[str] = None, - folder_id: Optional[str] = None, - ) -> Dict[str, Any]: + name: str | None = None, + folder_id: str | None = None, + ) -> dict[str, Any]: """Upload a file to Misskey drive/files/create and return a dict containing id and raw result.""" if not file_path: raise APIError("No file path provided for upload") @@ -574,7 +576,7 @@ async def upload_file( logger.error(f"[Misskey API] 文件上传网络错误: {e}") raise APIConnectionError(f"Upload failed: {e}") from e - async def find_files_by_hash(self, md5_hash: str) -> List[Dict[str, Any]]: + async def find_files_by_hash(self, md5_hash: str) -> list[dict[str, Any]]: """Find files by MD5 hash""" if not md5_hash: raise APIError("No MD5 hash provided for find-by-hash") @@ -593,13 +595,13 @@ async def find_files_by_hash(self, md5_hash: str) -> List[Dict[str, Any]]: raise async def find_files_by_name( - self, name: str, folder_id: Optional[str] = None - ) -> List[Dict[str, Any]]: + self, name: str, folder_id: str | None = None + ) -> list[dict[str, Any]]: """Find files by name""" if not name: raise APIError("No name provided for find") - data: Dict[str, Any] = {"name": name} + data: dict[str, Any] = {"name": name} if folder_id: data["folderId"] = folder_id @@ -617,11 +619,11 @@ async def find_files_by_name( async def find_files( self, limit: int = 10, - folder_id: Optional[str] = None, - type: Optional[str] = None, - ) -> List[Dict[str, Any]]: + folder_id: str | None = None, + type: str | None = None, + ) -> list[dict[str, Any]]: """List files with optional filters""" - data: Dict[str, Any] = {"limit": limit} + data: dict[str, Any] = {"limit": limit} if folder_id is not None: data["folderId"] = folder_id if type is not None: @@ -642,7 +644,7 @@ async def find_files( async def _download_with_existing_session( self, url: str, ssl_verify: bool = True - ) -> Optional[bytes]: + ) -> bytes | None: """使用现有会话下载文件""" if not (hasattr(self, "session") and self.session): raise APIConnectionError("No existing session available") @@ -656,7 +658,7 @@ async def _download_with_existing_session( async def _download_with_temp_session( self, url: str, ssl_verify: bool = True - ) -> Optional[bytes]: + ) -> bytes | None: """使用临时会话下载文件""" connector = aiohttp.TCPConnector(ssl=ssl_verify) async with aiohttp.ClientSession(connector=connector) as temp_session: @@ -670,11 +672,11 @@ async def _download_with_temp_session( async def upload_and_find_file( self, url: str, - name: Optional[str] = None, - folder_id: Optional[str] = None, + name: str | None = None, + folder_id: str | None = None, max_wait_time: float = 30.0, check_interval: float = 2.0, - ) -> Optional[Dict[str, Any]]: + ) -> dict[str, Any] | None: """ 简化的文件上传:尝试 URL 上传,失败则下载后本地上传 @@ -732,13 +734,13 @@ async def upload_and_find_file( return None - async def get_current_user(self) -> Dict[str, Any]: + async def get_current_user(self) -> dict[str, Any]: """获取当前用户信息""" return await self._make_request("i", {}) async def send_message( - self, user_id_or_payload: Any, text: Optional[str] = None - ) -> Dict[str, Any]: + self, user_id_or_payload: Any, text: str | None = None + ) -> dict[str, Any]: """发送聊天消息。 Accepts either (user_id: str, text: str) or a single dict payload prepared by caller. @@ -754,8 +756,8 @@ async def send_message( return result async def send_room_message( - self, room_id_or_payload: Any, text: Optional[str] = None - ) -> Dict[str, Any]: + self, room_id_or_payload: Any, text: str | None = None + ) -> dict[str, Any]: """发送房间消息。 Accepts either (room_id: str, text: str) or a single dict payload. @@ -771,10 +773,10 @@ async def send_room_message( return result async def get_messages( - self, user_id: str, limit: int = 10, since_id: Optional[str] = None - ) -> List[Dict[str, Any]]: + self, user_id: str, limit: int = 10, since_id: str | None = None + ) -> list[dict[str, Any]]: """获取聊天消息历史""" - data: Dict[str, Any] = {"userId": user_id, "limit": limit} + data: dict[str, Any] = {"userId": user_id, "limit": limit} if since_id: data["sinceId"] = since_id @@ -785,10 +787,10 @@ async def get_messages( return [] async def get_mentions( - self, limit: int = 10, since_id: Optional[str] = None - ) -> List[Dict[str, Any]]: + self, limit: int = 10, since_id: str | None = None + ) -> list[dict[str, Any]]: """获取提及通知""" - data: Dict[str, Any] = {"limit": limit} + data: dict[str, Any] = {"limit": limit} if since_id: data["sinceId"] = since_id data["includeTypes"] = ["mention", "reply", "quote"] @@ -806,11 +808,11 @@ async def send_message_with_media( self, message_type: str, target_id: str, - text: Optional[str] = None, - media_urls: Optional[List[str]] = None, - local_files: Optional[List[str]] = None, + text: str | None = None, + media_urls: list[str] | None = None, + local_files: list[str] | None = None, **kwargs, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ 通用消息发送函数:统一处理文本+媒体发送 @@ -846,7 +848,7 @@ async def send_message_with_media( message_type, target_id, text, file_ids, **kwargs ) - async def _process_media_urls(self, urls: List[str]) -> List[str]: + async def _process_media_urls(self, urls: list[str]) -> list[str]: """处理远程媒体文件URL列表,返回文件ID列表""" file_ids = [] for url in urls: @@ -863,7 +865,7 @@ async def _process_media_urls(self, urls: List[str]) -> List[str]: continue return file_ids - async def _process_local_files(self, file_paths: List[str]) -> List[str]: + async def _process_local_files(self, file_paths: list[str]) -> list[str]: """处理本地文件路径列表,返回文件ID列表""" file_ids = [] for file_path in file_paths: @@ -883,10 +885,10 @@ async def _dispatch_message( self, message_type: str, target_id: str, - text: Optional[str], - file_ids: List[str], + text: str | None, + file_ids: list[str], **kwargs, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """根据消息类型分发到对应的发送方法""" if message_type == "chat": # 聊天消息使用 fileId (单数) diff --git a/astrbot/core/platform/sources/misskey/misskey_event.py b/astrbot/core/platform/sources/misskey/misskey_event.py index cd737f78e..0b4a7b876 100644 --- a/astrbot/core/platform/sources/misskey/misskey_event.py +++ b/astrbot/core/platform/sources/misskey/misskey_event.py @@ -1,6 +1,6 @@ import asyncio import re -from typing import AsyncGenerator +from collections.abc import AsyncGenerator from astrbot.api import logger from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.platform import PlatformMetadata, AstrBotMessage diff --git a/astrbot/core/platform/sources/misskey/misskey_utils.py b/astrbot/core/platform/sources/misskey/misskey_utils.py index ebc95d8d7..135744167 100644 --- a/astrbot/core/platform/sources/misskey/misskey_utils.py +++ b/astrbot/core/platform/sources/misskey/misskey_utils.py @@ -1,6 +1,6 @@ """Misskey 平台适配器通用工具函数""" -from typing import Dict, Any, List, Tuple, Optional, Union +from typing import Any import astrbot.api.message_components as Comp from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType @@ -9,7 +9,7 @@ class FileIDExtractor: """从 API 响应中提取文件 ID 的帮助类(无状态)。""" @staticmethod - def extract_file_id(result: Any) -> Optional[str]: + def extract_file_id(result: Any) -> str | None: if not isinstance(result, dict): return None @@ -34,8 +34,8 @@ class MessagePayloadBuilder: @staticmethod def build_chat_payload( - user_id: str, text: Optional[str], file_id: Optional[str] = None - ) -> Dict[str, Any]: + user_id: str, text: str | None, file_id: str | None = None + ) -> dict[str, Any]: payload = {"toUserId": user_id} if text: payload["text"] = text @@ -45,8 +45,8 @@ def build_chat_payload( @staticmethod def build_room_payload( - room_id: str, text: Optional[str], file_id: Optional[str] = None - ) -> Dict[str, Any]: + room_id: str, text: str | None, file_id: str | None = None + ) -> dict[str, Any]: payload = {"toRoomId": room_id} if text: payload["text"] = text @@ -56,9 +56,9 @@ def build_room_payload( @staticmethod def build_note_payload( - text: Optional[str], file_ids: Optional[List[str]] = None, **kwargs - ) -> Dict[str, Any]: - payload: Dict[str, Any] = {} + text: str | None, file_ids: list[str] | None = None, **kwargs + ) -> dict[str, Any]: + payload: dict[str, Any] = {} if text: payload["text"] = text if file_ids: @@ -67,7 +67,7 @@ def build_note_payload( return payload -def serialize_message_chain(chain: List[Any]) -> Tuple[str, bool]: +def serialize_message_chain(chain: list[Any]) -> tuple[str, bool]: """将消息链序列化为文本字符串""" text_parts = [] has_at = False @@ -113,12 +113,12 @@ def process_component(component): def resolve_message_visibility( - user_id: Optional[str] = None, - user_cache: Optional[Dict[str, Any]] = None, - self_id: Optional[str] = None, - raw_message: Optional[Dict[str, Any]] = None, + user_id: str | None = None, + user_cache: dict[str, Any] | None = None, + self_id: str | None = None, + raw_message: dict[str, Any] | None = None, default_visibility: str = "public", -) -> Tuple[str, Optional[List[str]]]: +) -> tuple[str, list[str] | None]: """解析 Misskey 消息的可见性设置 可以从 user_cache 或 raw_message 中解析,支持两种调用方式: @@ -169,13 +169,13 @@ def resolve_message_visibility( # 保留旧函数名作为向后兼容的别名 def resolve_visibility_from_raw_message( - raw_message: Dict[str, Any], self_id: Optional[str] = None -) -> Tuple[str, Optional[List[str]]]: + raw_message: dict[str, Any], self_id: str | None = None +) -> tuple[str, list[str] | None]: """从原始消息数据中解析可见性设置(已弃用,使用 resolve_message_visibility 替代)""" return resolve_message_visibility(raw_message=raw_message, self_id=self_id) -def is_valid_user_session_id(session_id: Union[str, Any]) -> bool: +def is_valid_user_session_id(session_id: str | Any) -> bool: """检查 session_id 是否是有效的聊天用户 session_id (仅限chat%前缀)""" if not isinstance(session_id, str) or "%" not in session_id: return False @@ -189,7 +189,7 @@ def is_valid_user_session_id(session_id: Union[str, Any]) -> bool: ) -def is_valid_room_session_id(session_id: Union[str, Any]) -> bool: +def is_valid_room_session_id(session_id: str | Any) -> bool: """检查 session_id 是否是有效的房间 session_id (仅限room%前缀)""" if not isinstance(session_id, str) or "%" not in session_id: return False @@ -203,7 +203,7 @@ def is_valid_room_session_id(session_id: Union[str, Any]) -> bool: ) -def is_valid_chat_session_id(session_id: Union[str, Any]) -> bool: +def is_valid_chat_session_id(session_id: str | Any) -> bool: """检查 session_id 是否是有效的聊天 session_id (仅限chat%前缀)""" if not isinstance(session_id, str) or "%" not in session_id: return False @@ -236,7 +236,7 @@ def extract_room_id_from_session_id(session_id: str) -> str: def add_at_mention_if_needed( - text: str, user_info: Optional[Dict[str, Any]], has_at: bool = False + text: str, user_info: dict[str, Any] | None, has_at: bool = False ) -> str: """如果需要且没有@用户,则添加@用户 @@ -258,7 +258,7 @@ def add_at_mention_if_needed( return text -def create_file_component(file_info: Dict[str, Any]) -> Tuple[Any, str]: +def create_file_component(file_info: dict[str, Any]) -> tuple[Any, str]: """创建文件组件和描述文本""" file_url = file_info.get("url", "") file_name = file_info.get("name", "未知文件") @@ -287,7 +287,7 @@ def process_files( return file_parts -def format_poll(poll: Dict[str, Any]) -> str: +def format_poll(poll: dict[str, Any]) -> str: """将 Misskey 的 poll 对象格式化为可读字符串。""" if not poll or not isinstance(poll, dict): return "" @@ -304,8 +304,8 @@ def format_poll(poll: Dict[str, Any]) -> str: def extract_sender_info( - raw_data: Dict[str, Any], is_chat: bool = False -) -> Dict[str, Any]: + raw_data: dict[str, Any], is_chat: bool = False +) -> dict[str, Any]: """提取发送者信息""" if is_chat: sender = raw_data.get("fromUser", {}) @@ -323,11 +323,11 @@ def extract_sender_info( def create_base_message( - raw_data: Dict[str, Any], - sender_info: Dict[str, Any], + raw_data: dict[str, Any], + sender_info: dict[str, Any], client_self_id: str, is_chat: bool = False, - room_id: Optional[str] = None, + room_id: str | None = None, unique_session: bool = False, ) -> AstrBotMessage: """创建基础消息对象""" @@ -367,7 +367,7 @@ def create_base_message( def process_at_mention( message: AstrBotMessage, raw_text: str, bot_username: str, client_self_id: str -) -> Tuple[List[str], str]: +) -> tuple[list[str], str]: """处理@提及逻辑,返回消息部分列表和处理后的文本""" message_parts = [] @@ -389,9 +389,9 @@ def process_at_mention( def cache_user_info( - user_cache: Dict[str, Any], - sender_info: Dict[str, Any], - raw_data: Dict[str, Any], + user_cache: dict[str, Any], + sender_info: dict[str, Any], + raw_data: dict[str, Any], client_self_id: str, is_chat: bool = False, ): @@ -417,7 +417,7 @@ def cache_user_info( def cache_room_info( - user_cache: Dict[str, Any], raw_data: Dict[str, Any], client_self_id: str + user_cache: dict[str, Any], raw_data: dict[str, Any], client_self_id: str ): """缓存房间信息""" room_data = raw_data.get("toRoom") @@ -437,7 +437,7 @@ def cache_room_info( async def resolve_component_url_or_path( comp: Any, -) -> Tuple[Optional[str], Optional[str]]: +) -> tuple[str | None, str | None]: """尝试从组件解析可上传的远程 URL 或本地路径。 返回 (url_candidate, local_path)。两者可能都为 None。 @@ -503,7 +503,7 @@ async def _get_str_value(coro_or_val): return url_candidate, local_path -def summarize_component_for_log(comp: Any) -> Dict[str, Any]: +def summarize_component_for_log(comp: Any) -> dict[str, Any]: """生成适合日志的组件属性字典(尽量不抛异常)。""" attrs = {} for a in ("file", "url", "path", "src", "source", "name"): @@ -519,9 +519,9 @@ def summarize_component_for_log(comp: Any) -> Dict[str, Any]: async def upload_local_with_retries( api: Any, local_path: str, - preferred_name: Optional[str], - folder_id: Optional[str], -) -> Optional[str]: + preferred_name: str | None, + folder_id: str | None, +) -> str | None: """尝试本地上传,返回 file id 或 None。如果文件类型不允许则直接失败。""" try: res = await api.upload_file(local_path, preferred_name, folder_id) diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py index 2096237ce..8dc061812 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py @@ -16,7 +16,6 @@ from astrbot.api import logger from botpy.types.message import Media from botpy.types import message -from typing import Optional import random import uuid import os @@ -196,7 +195,7 @@ async def upload_group_and_c2c_image( async def upload_group_and_c2c_record( self, file_source: str, file_type: int, srv_send_msg: bool = False, **kwargs - ) -> Optional[Media]: + ) -> Media | None: """ 上传媒体文件 """ diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py index d5285f759..9acd822b3 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py @@ -19,7 +19,6 @@ ) from astrbot import logger from astrbot.api.event import MessageChain -from typing import Union, List from astrbot.api.message_components import Image, Plain, At from astrbot.core.platform.astr_message_event import MessageSesion from .qqofficial_message_event import QQOfficialMessageEvent @@ -33,7 +32,7 @@ # QQ 机器人官方框架 class botClient(Client): - def set_platform(self, platform: "QQOfficialPlatformAdapter"): + def set_platform(self, platform: QQOfficialPlatformAdapter): self.platform = platform # 收到群消息 @@ -133,7 +132,7 @@ def meta(self) -> PlatformMetadata: @staticmethod def _parse_from_qqofficial( - message: Union[botpy.message.Message, botpy.message.GroupMessage], + message: botpy.message.Message | botpy.message.GroupMessage, message_type: MessageType, ): abm = AstrBotMessage() @@ -142,7 +141,7 @@ def _parse_from_qqofficial( abm.raw_message = message abm.message_id = message.id abm.tag = "qq_official" - msg: List[BaseMessageComponent] = [] + msg: list[BaseMessageComponent] = [] if isinstance(message, botpy.message.GroupMessage) or isinstance( message, botpy.message.C2CMessage diff --git a/astrbot/core/platform/sources/satori/satori_adapter.py b/astrbot/core/platform/sources/satori/satori_adapter.py index a3f4f53ec..06de8bd7f 100644 --- a/astrbot/core/platform/sources/satori/satori_adapter.py +++ b/astrbot/core/platform/sources/satori/satori_adapter.py @@ -3,7 +3,6 @@ import time import websockets from websockets.asyncio.client import connect -from typing import Optional from aiohttp import ClientSession, ClientTimeout from websockets.asyncio.client import ClientConnection from astrbot.api import logger @@ -57,12 +56,12 @@ def __init__( id=self.config["id"], ) - self.ws: Optional[ClientConnection] = None - self.session: Optional[ClientSession] = None + self.ws: ClientConnection | None = None + self.session: ClientSession | None = None self.sequence = 0 self.logins = [] self.running = False - self.heartbeat_task: Optional[asyncio.Task] = None + self.heartbeat_task: asyncio.Task | None = None self.ready_received = False async def send_by_session( @@ -295,10 +294,10 @@ async def convert_satori_message( message: dict, user: dict, channel: dict, - guild: Optional[dict], + guild: dict | None, login: dict, - timestamp: Optional[int] = None, - ) -> Optional[AstrBotMessage]: + timestamp: int | None = None, + ) -> AstrBotMessage | None: try: abm = AstrBotMessage() abm.message_id = message.get("id", "") @@ -438,7 +437,7 @@ def _extract_namespace_prefixes(self, content: str) -> set: return prefixes - async def _extract_quote_element(self, content: str) -> Optional[dict]: + async def _extract_quote_element(self, content: str) -> dict | None: """提取标签信息""" try: # 处理命名空间前缀问题 @@ -506,7 +505,7 @@ async def _extract_quote_element(self, content: str) -> Optional[dict]: logger.error(f"提取标签时发生错误: {e}") return None - async def _extract_quote_with_regex(self, content: str) -> Optional[dict]: + async def _extract_quote_with_regex(self, content: str) -> dict | None: """使用正则表达式提取quote标签信息""" import re @@ -529,7 +528,7 @@ async def _extract_quote_with_regex(self, content: str) -> Optional[dict]: "content_without_quote": content_without_quote, } - async def _convert_quote_message(self, quote: dict) -> Optional[AstrBotMessage]: + async def _convert_quote_message(self, quote: dict) -> AstrBotMessage | None: """转换引用消息""" try: quote_abm = AstrBotMessage() diff --git a/astrbot/core/platform/sources/slack/client.py b/astrbot/core/platform/sources/slack/client.py index 7877e4f52..74b4570ff 100644 --- a/astrbot/core/platform/sources/slack/client.py +++ b/astrbot/core/platform/sources/slack/client.py @@ -3,7 +3,8 @@ import hashlib import asyncio import logging -from typing import Callable, Optional + +from collections.abc import Callable from quart import Quart, request, Response from slack_sdk.web.async_client import AsyncWebClient from slack_sdk.socket_mode.aiohttp import SocketModeClient @@ -22,7 +23,7 @@ def __init__( host: str = "0.0.0.0", port: int = 3000, path: str = "/slack/events", - event_handler: Optional[Callable] = None, + event_handler: Callable | None = None, ): self.web_client = web_client self.signing_secret = signing_secret @@ -119,7 +120,7 @@ def __init__( self, web_client: AsyncWebClient, app_token: str, - event_handler: Optional[Callable] = None, + event_handler: Callable | None = None, ): self.web_client = web_client self.app_token = app_token diff --git a/astrbot/core/platform/sources/slack/slack_adapter.py b/astrbot/core/platform/sources/slack/slack_adapter.py index 7e75f3c20..4ad8433cf 100644 --- a/astrbot/core/platform/sources/slack/slack_adapter.py +++ b/astrbot/core/platform/sources/slack/slack_adapter.py @@ -4,7 +4,9 @@ import aiohttp import re import base64 -from typing import Awaitable, Any +from typing import Any + +from collections.abc import Awaitable from slack_sdk.web.async_client import AsyncWebClient from slack_sdk.socket_mode.request import SocketModeRequest from astrbot.api.platform import ( diff --git a/astrbot/core/platform/sources/slack/slack_event.py b/astrbot/core/platform/sources/slack/slack_event.py index 86f9f9764..692c03be2 100644 --- a/astrbot/core/platform/sources/slack/slack_event.py +++ b/astrbot/core/platform/sources/slack/slack_event.py @@ -1,6 +1,6 @@ import asyncio import re -from typing import AsyncGenerator +from collections.abc import AsyncGenerator from slack_sdk.web.async_client import AsyncWebClient from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.message_components import ( diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index faec122ac..16de7edfa 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -2,7 +2,9 @@ import asyncio import uuid import os -from typing import Awaitable, Any, Callable +from typing import Any + +from collections.abc import Awaitable, Callable from astrbot.core.platform import ( Platform, AstrBotMessage, diff --git a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py index 6b835ecb5..be979a78e 100644 --- a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +++ b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py @@ -4,7 +4,6 @@ import os import traceback import time -from typing import Optional import aiohttp import anyio @@ -137,7 +136,7 @@ def load_credentials(self): """ if os.path.exists(self.credentials_file): try: - with open(self.credentials_file, "r") as f: + with open(self.credentials_file) as f: credentials = json.load(f) logger.info("成功加载 WeChatPadPro 凭据。") return credentials @@ -540,7 +539,7 @@ async def _process_chat_type( async def _get_group_member_nickname( self, group_id: str, member_wxid: str - ) -> Optional[str]: + ) -> str | None: """ 通过接口获取群成员的昵称。 """ @@ -896,7 +895,7 @@ async def get_contact_list(self): async def get_contact_details_list( self, room_wx_id_list: list[str] = None, user_names: list[str] = None - ) -> Optional[dict]: + ) -> dict | None: """ 获取联系人详情列表。 """ diff --git a/astrbot/core/platform/sources/wecom/wecom_kf.py b/astrbot/core/platform/sources/wecom/wecom_kf.py index 118667975..8d838da97 100644 --- a/astrbot/core/platform/sources/wecom/wecom_kf.py +++ b/astrbot/core/platform/sources/wecom/wecom_kf.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """ The MIT License (MIT) diff --git a/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py b/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py index 5332942b9..4a0508051 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- encoding:utf-8 -*- """对企业微信发送给企业后台的消息加解密示例代码. @copyright: Copyright (c) 1998-2020 Tencent Inc. @@ -136,7 +135,7 @@ def decode(self, decrypted): return decrypted[:-pad] -class Prpcrypt(object): +class Prpcrypt: """提供接收和推送给企业微信消息的加解密接口""" def __init__(self, key): @@ -210,7 +209,7 @@ def get_random_str(self): return str(random.randint(1000000000000000, 9999999999999999)).encode() -class WXBizJsonMsgCrypt(object): +class WXBizJsonMsgCrypt: # 构造函数 def __init__(self, sToken, sEncodingAESKey, sReceiveId): try: diff --git a/astrbot/core/platform/sources/wecom_ai_bot/ierror.py b/astrbot/core/platform/sources/wecom_ai_bot/ierror.py index cc1bf221e..0df14a505 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/ierror.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/ierror.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- ######################################################################### # Author: jonyqin # Created Time: Thu 11 Sep 2014 01:53:58 PM CST diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py index 830d8de58..3198dc28f 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py @@ -9,7 +9,9 @@ import uuid import hashlib import base64 -from typing import Awaitable, Any, Dict, Optional, Callable +from typing import Any + +from collections.abc import Awaitable, Callable from astrbot.api.platform import ( @@ -151,8 +153,8 @@ async def _handle_queued_message(self, data: dict): logger.error(f"处理队列消息时发生异常: {e}") async def _process_message( - self, message_data: Dict[str, Any], callback_params: Dict[str, str] - ) -> Optional[str]: + self, message_data: dict[str, Any], callback_params: dict[str, str] + ) -> str | None: """处理接收到的消息 Args: @@ -278,15 +280,15 @@ async def _process_message( return None pass - def _extract_session_id(self, message_data: Dict[str, Any]) -> str: + def _extract_session_id(self, message_data: dict[str, Any]) -> str: """从消息数据中提取会话ID""" user_id = message_data.get("from", {}).get("userid", "default_user") return format_session_id("wecomai", user_id) async def _enqueue_message( self, - message_data: Dict[str, Any], - callback_params: Dict[str, str], + message_data: dict[str, Any], + callback_params: dict[str, str], stream_id: str, session_id: str, ): diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py index 540bf06b6..ea34dee32 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py @@ -6,7 +6,7 @@ import json import base64 import hashlib -from typing import Dict, Any, Optional, Tuple, Union +from typing import Any from Crypto.Cipher import AES import aiohttp @@ -31,7 +31,7 @@ def __init__(self, token: str, encoding_aes_key: str): async def decrypt_message( self, encrypted_data: bytes, msg_signature: str, timestamp: str, nonce: str - ) -> Tuple[int, Optional[Dict[str, Any]]]: + ) -> tuple[int, dict[str, Any] | None]: """解密企业微信消息 Args: @@ -71,7 +71,7 @@ async def decrypt_message( async def encrypt_message( self, plain_message: str, nonce: str, timestamp: str - ) -> Optional[str]: + ) -> str | None: """加密消息 Args: @@ -127,8 +127,8 @@ def verify_url( return "verify fail" async def process_encrypted_image( - self, image_url: str, aes_key_base64: Optional[str] = None - ) -> Tuple[bool, Union[bytes, str]]: + self, image_url: str, aes_key_base64: str | None = None + ) -> tuple[bool, bytes | str]: """下载并解密加密图片 Args: @@ -292,7 +292,7 @@ class WecomAIBotMessageParser: """企业微信智能机器人消息解析器""" @staticmethod - def parse_text_message(data: Dict[str, Any]) -> Optional[str]: + def parse_text_message(data: dict[str, Any]) -> str | None: """解析文本消息 Args: @@ -308,7 +308,7 @@ def parse_text_message(data: Dict[str, Any]) -> Optional[str]: return None @staticmethod - def parse_image_message(data: Dict[str, Any]) -> Optional[str]: + def parse_image_message(data: dict[str, Any]) -> str | None: """解析图片消息 Args: @@ -324,7 +324,7 @@ def parse_image_message(data: Dict[str, Any]) -> Optional[str]: return None @staticmethod - def parse_stream_message(data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + def parse_stream_message(data: dict[str, Any]) -> dict[str, Any] | None: """解析流消息 Args: @@ -346,7 +346,7 @@ def parse_stream_message(data: Dict[str, Any]) -> Optional[Dict[str, Any]]: return None @staticmethod - def parse_mixed_message(data: Dict[str, Any]) -> Optional[list]: + def parse_mixed_message(data: dict[str, Any]) -> list | None: """解析混合消息 Args: @@ -362,7 +362,7 @@ def parse_mixed_message(data: Dict[str, Any]) -> Optional[list]: return None @staticmethod - def parse_event_message(data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + def parse_event_message(data: dict[str, Any]) -> dict[str, Any] | None: """解析事件消息 Args: diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py index 1367301c9..3086e9513 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py @@ -5,7 +5,7 @@ """ import asyncio -from typing import Dict, Any, Optional +from typing import Any from astrbot.api import logger @@ -13,13 +13,13 @@ class WecomAIQueueMgr: """企业微信智能机器人队列管理器""" def __init__(self) -> None: - self.queues: Dict[str, asyncio.Queue] = {} + self.queues: dict[str, asyncio.Queue] = {} """StreamID 到输入队列的映射 - 用于接收用户消息""" - self.back_queues: Dict[str, asyncio.Queue] = {} + self.back_queues: dict[str, asyncio.Queue] = {} """StreamID 到输出队列的映射 - 用于发送机器人响应""" - self.pending_responses: Dict[str, Dict[str, Any]] = {} + self.pending_responses: dict[str, dict[str, Any]] = {} """待处理的响应缓存,用于流式响应""" def get_or_create_queue(self, session_id: str) -> asyncio.Queue: @@ -90,7 +90,7 @@ def has_back_queue(self, session_id: str) -> bool: """ return session_id in self.back_queues - def set_pending_response(self, session_id: str, callback_params: Dict[str, str]): + def set_pending_response(self, session_id: str, callback_params: dict[str, str]): """设置待处理的响应参数 Args: @@ -103,7 +103,7 @@ def set_pending_response(self, session_id: str, callback_params: Dict[str, str]) } logger.debug(f"[WecomAI] 设置待处理响应: {session_id}") - def get_pending_response(self, session_id: str) -> Optional[Dict[str, Any]]: + def get_pending_response(self, session_id: str) -> dict[str, Any] | None: """获取待处理的响应参数 Args: @@ -131,7 +131,7 @@ def cleanup_expired_responses(self, max_age_seconds: int = 300): del self.pending_responses[session_id] logger.debug(f"[WecomAI] 清理过期响应: {session_id}") - def get_stats(self) -> Dict[str, int]: + def get_stats(self) -> dict[str, int]: """获取队列统计信息 Returns: diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py index bbb69d041..27b78dd78 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py @@ -4,7 +4,9 @@ """ import asyncio -from typing import Dict, Any, Optional, Callable +from typing import Any + +from collections.abc import Callable import quart from astrbot.api import logger @@ -21,9 +23,8 @@ def __init__( host: str, port: int, api_client: WecomAIBotAPIClient, - message_handler: Optional[ - Callable[[Dict[str, Any], Dict[str, str]], Any] - ] = None, + message_handler: None + | (Callable[[dict[str, Any], dict[str, str]], Any]) = None, ): """初始化服务器 diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py index dccb2e260..cc4361581 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py @@ -10,7 +10,7 @@ import aiohttp import asyncio from Crypto.Cipher import AES -from typing import Any, Tuple +from typing import Any from astrbot.api import logger @@ -91,7 +91,7 @@ def format_session_id(session_type: str, session_id: str) -> str: return f"wecom_ai_bot_{session_type}_{session_id}" -def parse_session_id(formatted_session_id: str) -> Tuple[str, str]: +def parse_session_id(formatted_session_id: str) -> tuple[str, str]: """解析格式化的会话 ID Args: @@ -145,7 +145,7 @@ def format_error_response(error_code: int, error_msg: str) -> str: async def process_encrypted_image( image_url: str, aes_key_base64: str -) -> Tuple[bool, str]: +) -> tuple[bool, str]: """下载并解密加密图片 Args: From b71ba19be3cf739b3eb5ba92af296c0fb940d059 Mon Sep 17 00:00:00 2001 From: Dale Null Date: Thu, 30 Oct 2025 18:05:46 +0800 Subject: [PATCH 20/44] chore(core.provider): ruff rewrite --- astrbot/core/provider/entities.py | 30 +++++++++---------- astrbot/core/provider/func_tool_manager.py | 16 +++++----- astrbot/core/provider/provider.py | 9 +++--- astrbot/core/provider/register.py | 5 ++-- .../core/provider/sources/anthropic_source.py | 9 +++--- .../core/provider/sources/azure_tts_source.py | 5 ++-- .../core/provider/sources/coze_api_client.py | 8 +++-- astrbot/core/provider/sources/coze_source.py | 7 +++-- .../core/provider/sources/dashscope_tts.py | 9 +++--- .../core/provider/sources/gemini_source.py | 13 ++++---- .../sources/minimax_tts_api_source.py | 7 +++-- .../core/provider/sources/openai_source.py | 13 ++++---- 12 files changed, 66 insertions(+), 65 deletions(-) diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index 85687c417..f59365283 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -4,7 +4,7 @@ from astrbot.core.utils.io import download_image_by_url from astrbot import logger from dataclasses import dataclass, field -from typing import List, Dict, Type, Any +from typing import Any from astrbot.core.agent.tool import ToolSet from openai.types.chat.chat_completion import ChatCompletion from google.genai.types import GenerateContentResponse @@ -32,7 +32,7 @@ class ProviderMetaData: desc: str = "" """提供商适配器描述.""" provider_type: ProviderType = ProviderType.CHAT_COMPLETION - cls_type: Type | None = None + cls_type: type | None = None default_config_tmpl: dict | None = None """平台的默认配置模板""" @@ -61,7 +61,7 @@ class AssistantMessageSegment: """OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling""" content: str | None = None - tool_calls: List[ChatCompletionMessageToolCall | Dict] = field(default_factory=list) + tool_calls: list[ChatCompletionMessageToolCall | dict] = field(default_factory=list) role: str = "assistant" def to_dict(self): @@ -84,10 +84,10 @@ class ToolCallsResult: tool_calls_info: AssistantMessageSegment """函数调用的信息""" - tool_calls_result: List[ToolCallMessageSegment] + tool_calls_result: list[ToolCallMessageSegment] """函数调用的结果""" - def to_openai_messages(self) -> List[Dict]: + def to_openai_messages(self) -> list[dict]: ret = [ self.tool_calls_info.to_dict(), *[item.to_dict() for item in self.tool_calls_result], @@ -175,7 +175,7 @@ def _print_friendly_context(self): return result_parts - async def assemble_context(self) -> Dict: + async def assemble_context(self) -> dict: """将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。""" if self.image_urls: user_content = { @@ -219,15 +219,15 @@ class LLMResponse: """角色, assistant, tool, err""" result_chain: MessageChain | None = None """返回的消息链""" - tools_call_args: List[Dict[str, Any]] = field(default_factory=list) + tools_call_args: list[dict[str, Any]] = field(default_factory=list) """工具调用参数""" - tools_call_name: List[str] = field(default_factory=list) + tools_call_name: list[str] = field(default_factory=list) """工具调用名称""" - tools_call_ids: List[str] = field(default_factory=list) + tools_call_ids: list[str] = field(default_factory=list) """工具调用 ID""" raw_completion: ChatCompletion | GenerateContentResponse | Message | None = None - _new_record: Dict[str, Any] | None = None + _new_record: dict[str, Any] | None = None _completion_text: str = "" @@ -239,11 +239,11 @@ def __init__( role: str, completion_text: str = "", result_chain: MessageChain | None = None, - tools_call_args: List[Dict[str, Any]] | None = None, - tools_call_name: List[str] | None = None, - tools_call_ids: List[str] | None = None, + tools_call_args: list[dict[str, Any]] | None = None, + tools_call_name: list[str] | None = None, + tools_call_ids: list[str] | None = None, raw_completion: ChatCompletion | None = None, - _new_record: Dict[str, Any] | None = None, + _new_record: dict[str, Any] | None = None, is_chunk: bool = False, ): """初始化 LLMResponse @@ -291,7 +291,7 @@ def completion_text(self, value): else: self._completion_text = value - def to_openai_tool_calls(self) -> List[Dict]: + def to_openai_tool_calls(self) -> list[dict]: """将工具调用信息转换为 OpenAI 格式""" ret = [] for idx, tool_call_arg in enumerate(self.tools_call_args): diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 51cde0eb9..ceba07076 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -4,7 +4,9 @@ import asyncio import aiohttp -from typing import Dict, List, Awaitable, Callable, Any +from typing import Any + +from collections.abc import Awaitable, Callable from astrbot import logger from astrbot.core import sp @@ -96,10 +98,10 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: class FunctionToolManager: def __init__(self) -> None: - self.func_list: List[FuncTool] = [] - self.mcp_client_dict: Dict[str, MCPClient] = {} + self.func_list: list[FuncTool] = [] + self.mcp_client_dict: dict[str, MCPClient] = {} """MCP 服务列表""" - self.mcp_client_event: Dict[str, asyncio.Event] = {} + self.mcp_client_event: dict[str, asyncio.Event] = {} def empty(self) -> bool: return len(self.func_list) == 0 @@ -202,8 +204,8 @@ async def init_mcp_clients(self) -> None: logger.info(f"未找到 MCP 服务配置文件,已创建默认配置文件 {mcp_json_file}") return - mcp_server_json_obj: Dict[str, Dict] = json.load( - open(mcp_json_file, "r", encoding="utf-8") + mcp_server_json_obj: dict[str, dict] = json.load( + open(mcp_json_file, encoding="utf-8") )["mcpServers"] for name in mcp_server_json_obj.keys(): @@ -479,7 +481,7 @@ def load_mcp_config(self): return DEFAULT_MCP_CONFIG try: - with open(self.mcp_config_path, "r", encoding="utf-8") as f: + with open(self.mcp_config_path, encoding="utf-8") as f: return json.load(f) except Exception as e: logger.error(f"加载 MCP 配置失败: {e}") diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 9953e9f17..8e48a11a3 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -1,7 +1,6 @@ import abc import asyncio -from typing import List -from typing import AsyncGenerator +from collections.abc import AsyncGenerator from astrbot.core.agent.tool import ToolSet from astrbot.core.provider.entities import ( LLMResponse, @@ -67,7 +66,7 @@ def __init__( def get_current_key(self) -> str: raise NotImplementedError() - def get_keys(self) -> List[str]: + def get_keys(self) -> list[str]: """获得提供商 Key""" keys = self.provider_config.get("key", [""]) return keys or [""] @@ -77,7 +76,7 @@ def set_key(self, key: str): raise NotImplementedError() @abc.abstractmethod - async def get_models(self) -> List[str]: + async def get_models(self) -> list[str]: """获得支持的模型列表""" raise NotImplementedError() @@ -140,7 +139,7 @@ async def text_chat_stream( """ ... - async def pop_record(self, context: List): + async def pop_record(self, context: list): """ 弹出 context 第一条非系统提示词对话记录 """ diff --git a/astrbot/core/provider/register.py b/astrbot/core/provider/register.py index 02d7934d1..31b82bb8a 100644 --- a/astrbot/core/provider/register.py +++ b/astrbot/core/provider/register.py @@ -1,11 +1,10 @@ -from typing import List, Dict from .entities import ProviderMetaData, ProviderType from astrbot.core import logger from .func_tool_manager import FuncCall -provider_registry: List[ProviderMetaData] = [] +provider_registry: list[ProviderMetaData] = [] """维护了通过装饰器注册的 Provider""" -provider_cls_map: Dict[str, ProviderMetaData] = {} +provider_cls_map: dict[str, ProviderMetaData] = {} """维护了 Provider 类型名称和 ProviderMetadata 的映射""" llm_tools = FuncCall() diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index cd4206ce7..8b3a8772a 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -1,7 +1,6 @@ import json import anthropic import base64 -from typing import List from mimetypes import guess_type from anthropic import AsyncAnthropic @@ -13,7 +12,7 @@ from astrbot.core.provider.func_tool_manager import ToolSet from ..register import register_provider_adapter from astrbot.core.provider.entities import LLMResponse -from typing import AsyncGenerator +from collections.abc import AsyncGenerator @register_provider_adapter( @@ -33,7 +32,7 @@ def __init__( ) self.chosen_api_key: str = "" - self.api_keys: List = super().get_keys() + self.api_keys: list = super().get_keys() self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else "" self.base_url = provider_config.get("api_base", "https://api.anthropic.com") self.timeout = provider_config.get("timeout", 120) @@ -326,7 +325,7 @@ async def text_chat_stream( async for llm_response in self._query_stream(payloads, func_tool): yield llm_response - async def assemble_context(self, text: str, image_urls: List[str] | None = None): + async def assemble_context(self, text: str, image_urls: list[str] | None = None): """组装上下文,支持文本和图片""" if not image_urls: return {"role": "user", "content": text} @@ -384,7 +383,7 @@ async def encode_image_bs64(self, image_url: str) -> str: def get_current_key(self) -> str: return self.chosen_api_key - async def get_models(self) -> List[str]: + async def get_models(self) -> list[str]: models_str = [] models = await self.client.models.list() models = sorted(models.data, key=lambda x: x.id) diff --git a/astrbot/core/provider/sources/azure_tts_source.py b/astrbot/core/provider/sources/azure_tts_source.py index 6ddf452d4..5ef101b20 100644 --- a/astrbot/core/provider/sources/azure_tts_source.py +++ b/astrbot/core/provider/sources/azure_tts_source.py @@ -6,7 +6,6 @@ import random import asyncio from pathlib import Path -from typing import Dict from xml.sax.saxutils import escape from httpx import AsyncClient, Timeout @@ -21,7 +20,7 @@ class OTTSProvider: - def __init__(self, config: Dict): + def __init__(self, config: dict): self.skey = config["OTTS_SKEY"] self.api_url = config["OTTS_URL"] self.auth_time_url = config["OTTS_AUTH_TIME"] @@ -58,7 +57,7 @@ async def _generate_signature(self) -> str: path = re.sub(r"^https?://[^/]+", "", self.api_url) or "/" return f"{timestamp}-{nonce}-0-{hashlib.md5(f'{path}-{timestamp}-{nonce}-0-{self.skey}'.encode()).hexdigest()}" - async def get_audio(self, text: str, voice_params: Dict) -> str: + async def get_audio(self, text: str, voice_params: dict) -> str: file_path = TEMP_DIR / f"otts-{uuid.uuid4()}.wav" signature = await self._generate_signature() for attempt in range(self.retry_count): diff --git a/astrbot/core/provider/sources/coze_api_client.py b/astrbot/core/provider/sources/coze_api_client.py index a768979c6..3fe02d8e5 100644 --- a/astrbot/core/provider/sources/coze_api_client.py +++ b/astrbot/core/provider/sources/coze_api_client.py @@ -2,7 +2,9 @@ import asyncio import aiohttp import io -from typing import Dict, List, Any, AsyncGenerator +from typing import Any + +from collections.abc import AsyncGenerator from astrbot.core import logger @@ -117,12 +119,12 @@ async def chat_messages( self, bot_id: str, user_id: str, - additional_messages: List[Dict] | None = None, + additional_messages: list[dict] | None = None, conversation_id: str | None = None, auto_save_history: bool = True, stream: bool = True, timeout: float = 120, - ) -> AsyncGenerator[Dict[str, Any], None]: + ) -> AsyncGenerator[dict[str, Any], None]: """发送聊天消息并返回流式响应 Args: diff --git a/astrbot/core/provider/sources/coze_source.py b/astrbot/core/provider/sources/coze_source.py index 639af0814..93887900c 100644 --- a/astrbot/core/provider/sources/coze_source.py +++ b/astrbot/core/provider/sources/coze_source.py @@ -2,7 +2,8 @@ import os import base64 import hashlib -from typing import AsyncGenerator, Dict + +from collections.abc import AsyncGenerator from astrbot.core.message.message_event_result import MessageChain import astrbot.core.message.components as Comp from astrbot.api.provider import Provider @@ -44,8 +45,8 @@ def __init__( if isinstance(self.timeout, str): self.timeout = int(self.timeout) self.auto_save_history = provider_config.get("auto_save_history", True) - self.conversation_ids: Dict[str, str] = {} - self.file_id_cache: Dict[str, Dict[str, str]] = {} + self.conversation_ids: dict[str, str] = {} + self.file_id_cache: dict[str, dict[str, str]] = {} # 创建 API 客户端 self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base) diff --git a/astrbot/core/provider/sources/dashscope_tts.py b/astrbot/core/provider/sources/dashscope_tts.py index efda31ca9..54e043f6a 100644 --- a/astrbot/core/provider/sources/dashscope_tts.py +++ b/astrbot/core/provider/sources/dashscope_tts.py @@ -3,7 +3,6 @@ import logging import os import uuid -from typing import Optional, Tuple import aiohttp import dashscope from dashscope.audio.tts_v2 import AudioFormat, SpeechSynthesizer @@ -80,7 +79,7 @@ def _call_qwen_tts(self, model: str, text: str): async def _synthesize_with_qwen_tts( self, model: str, text: str - ) -> Tuple[Optional[bytes], str]: + ) -> tuple[bytes | None, str]: loop = asyncio.get_event_loop() response = await loop.run_in_executor(None, self._call_qwen_tts, model, text) audio_bytes = await self._extract_audio_from_response(response) @@ -91,7 +90,7 @@ async def _synthesize_with_qwen_tts( ext = ".wav" return audio_bytes, ext - async def _extract_audio_from_response(self, response) -> Optional[bytes]: + async def _extract_audio_from_response(self, response) -> bytes | None: output = getattr(response, "output", None) audio_obj = getattr(output, "audio", None) if output is not None else None if not audio_obj: @@ -110,7 +109,7 @@ async def _extract_audio_from_response(self, response) -> Optional[bytes]: return await self._download_audio_from_url(url) return None - async def _download_audio_from_url(self, url: str) -> Optional[bytes]: + async def _download_audio_from_url(self, url: str) -> bytes | None: if not url: return None timeout = max(self.timeout_ms / 1000, 1) if self.timeout_ms else 20 @@ -126,7 +125,7 @@ async def _download_audio_from_url(self, url: str) -> Optional[bytes]: async def _synthesize_with_cosyvoice( self, model: str, text: str - ) -> Tuple[Optional[bytes], str]: + ) -> tuple[bytes | None, str]: synthesizer = SpeechSynthesizer( model=model, voice=self.voice, diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index b14a9bdcb..1865741e0 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -3,7 +3,6 @@ import json import logging import random -from typing import Optional, List from collections.abc import AsyncGenerator from google import genai @@ -60,11 +59,11 @@ def __init__( provider_settings, default_persona, ) - self.api_keys: List = super().get_keys() + self.api_keys: list = super().get_keys() self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else "" self.timeout: int = int(provider_config.get("timeout", 180)) - self.api_base: Optional[str] = provider_config.get("api_base", None) + self.api_base: str | None = provider_config.get("api_base", None) if self.api_base and self.api_base.endswith("/"): self.api_base = self.api_base[:-1] @@ -122,9 +121,9 @@ async def _handle_api_error(self, e: APIError, keys: list[str]) -> bool: async def _prepare_query_config( self, payloads: dict, - tools: Optional[ToolSet] = None, - system_instruction: Optional[str] = None, - modalities: Optional[list[str]] = None, + tools: ToolSet | None = None, + system_instruction: str | None = None, + modalities: list[str] | None = None, temperature: float = 0.7, ) -> types.GenerateContentConfig: """准备查询配置""" @@ -406,7 +405,7 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: conversation = self._prepare_conversation(payloads) temperature = payloads.get("temperature", 0.7) - result: Optional[types.GenerateContentResponse] = None + result: types.GenerateContentResponse | None = None while True: try: config = await self._prepare_query_config( diff --git a/astrbot/core/provider/sources/minimax_tts_api_source.py b/astrbot/core/provider/sources/minimax_tts_api_source.py index 5b210835b..66901497f 100644 --- a/astrbot/core/provider/sources/minimax_tts_api_source.py +++ b/astrbot/core/provider/sources/minimax_tts_api_source.py @@ -2,7 +2,8 @@ import os import uuid import aiohttp -from typing import Dict, List, Union, AsyncIterator + +from collections.abc import AsyncIterator from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.api import logger from ..entities import ProviderType @@ -30,7 +31,7 @@ def __init__( self.is_timber_weight: bool = provider_config.get( "minimax-is-timber-weight", False ) - self.timber_weight: List[Dict[str, Union[str, int]]] = json.loads( + self.timber_weight: list[dict[str, str | int]] = json.loads( provider_config.get( "minimax-timber-weight", '[{"voice_id": "Chinese (Mandarin)_Warm_Girl", "weight": 1}]', @@ -66,7 +67,7 @@ def __init__( def _build_tts_stream_body(self, text: str): """构建流式请求体""" - dict_body: Dict[str, object] = { + dict_body: dict[str, object] = { "model": self.model_name, "text": text, "stream": True, diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 09c284acb..6174ca9d9 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -17,7 +17,8 @@ from astrbot.api.provider import Provider from astrbot import logger from astrbot.core.provider.func_tool_manager import ToolSet -from typing import List, AsyncGenerator + +from collections.abc import AsyncGenerator from ..register import register_provider_adapter from astrbot.core.provider.entities import LLMResponse, ToolCallsResult @@ -38,7 +39,7 @@ def __init__( default_persona, ) self.chosen_api_key = None - self.api_keys: List = super().get_keys() + self.api_keys: list = super().get_keys() self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None self.timeout = provider_config.get("timeout", 120) if isinstance(self.timeout, str): @@ -280,7 +281,7 @@ async def _handle_api_error( context_query: list, func_tool: ToolSet, chosen_key: str, - available_api_keys: List[str], + available_api_keys: list[str], retry_cnt: int, max_retries: int, ) -> tuple: @@ -497,7 +498,7 @@ async def text_chat_stream( raise Exception("未知错误") raise last_exception - async def _remove_image_from_context(self, contexts: List): + async def _remove_image_from_context(self, contexts: list): """ 从上下文中删除所有带有 image 的记录 """ @@ -521,14 +522,14 @@ async def _remove_image_from_context(self, contexts: List): def get_current_key(self) -> str: return self.client.api_key - def get_keys(self) -> List[str]: + def get_keys(self) -> list[str]: return self.api_keys def set_key(self, key): self.client.api_key = key async def assemble_context( - self, text: str, image_urls: List[str] | None = None + self, text: str, image_urls: list[str] | None = None ) -> dict: """组装成符合 OpenAI 格式的 role 为 user 的消息段""" if image_urls: From 28d70db85289185f3e419edf2373bffdc8bf34d2 Mon Sep 17 00:00:00 2001 From: Dale Null Date: Thu, 30 Oct 2025 18:06:55 +0800 Subject: [PATCH 21/44] chore(core.star): ruff rewrite --- astrbot/core/star/config.py | 9 ++++---- astrbot/core/star/context.py | 23 ++++++++++----------- astrbot/core/star/filter/command.py | 12 +++++------ astrbot/core/star/filter/command_group.py | 11 +++++----- astrbot/core/star/register/star_handler.py | 8 +++---- astrbot/core/star/session_plugin_manager.py | 5 ++--- astrbot/core/star/star_handler.py | 14 +++++++------ astrbot/core/star/star_tools.py | 12 ++++++----- 8 files changed, 47 insertions(+), 47 deletions(-) diff --git a/astrbot/core/star/config.py b/astrbot/core/star/config.py index 23a522dc1..8d9acbcb2 100644 --- a/astrbot/core/star/config.py +++ b/astrbot/core/star/config.py @@ -2,13 +2,12 @@ 此功能已过时,参考 https://astrbot.app/dev/plugin.html#%E6%B3%A8%E5%86%8C%E6%8F%92%E4%BB%B6%E9%85%8D%E7%BD%AE-beta """ -from typing import Union import os import json from astrbot.core.utils.astrbot_path import get_astrbot_data_path -def load_config(namespace: str) -> Union[dict, bool]: +def load_config(namespace: str) -> dict | bool: """ 从配置文件中加载配置。 namespace: str, 配置的唯一识别符,也就是配置文件的名字。 @@ -17,7 +16,7 @@ def load_config(namespace: str) -> Union[dict, bool]: path = os.path.join(get_astrbot_data_path(), "config", f"{namespace}.json") if not os.path.exists(path): return False - with open(path, "r", encoding="utf-8-sig") as f: + with open(path, encoding="utf-8-sig") as f: ret = {} data = json.load(f) for k in data: @@ -51,7 +50,7 @@ def put_config(namespace: str, name: str, key: str, value, description: str): if not os.path.exists(path): with open(path, "w", encoding="utf-8-sig") as f: f.write("{}") - with open(path, "r", encoding="utf-8-sig") as f: + with open(path, encoding="utf-8-sig") as f: d = json.load(f) assert isinstance(d, dict) if key not in d: @@ -78,7 +77,7 @@ def update_config(namespace: str, key: str, value): path = os.path.join(get_astrbot_data_path(), "config", f"{namespace}.json") if not os.path.exists(path): raise FileNotFoundError(f"配置文件 {namespace}.json 不存在。") - with open(path, "r", encoding="utf-8-sig") as f: + with open(path, encoding="utf-8-sig") as f: d = json.load(f) assert isinstance(d, dict) if key not in d: diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 0229f4dbb..ce795d767 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -1,5 +1,4 @@ from asyncio import Queue -from typing import List, Union from astrbot.core.provider.provider import ( Provider, @@ -25,7 +24,9 @@ from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType from .filter.command import CommandFilter from .filter.regex import RegexFilter -from typing import Awaitable, Any, Callable +from typing import Any + +from collections.abc import Awaitable, Callable from astrbot.core.conversation_mgr import ConversationManager from astrbot.core.star.filter.platform_adapter_type import ( PlatformAdapterType, @@ -42,7 +43,7 @@ class Context: registered_web_apis: list = [] # back compatibility - _register_tasks: List[Awaitable] = [] + _register_tasks: list[Awaitable] = [] _star_manager = None def __init__( @@ -78,7 +79,7 @@ def get_registered_star(self, star_name: str) -> StarMetadata | None: if star.name == star_name: return star - def get_all_stars(self) -> List[StarMetadata]: + def get_all_stars(self) -> list[StarMetadata]: """获取当前载入的所有插件 Metadata 的列表""" return star_registry @@ -116,19 +117,19 @@ def get_provider_by_id( prov = self.provider_manager.inst_map.get(provider_id) return prov - def get_all_providers(self) -> List[Provider]: + def get_all_providers(self) -> list[Provider]: """获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。""" return self.provider_manager.provider_insts - def get_all_tts_providers(self) -> List[TTSProvider]: + def get_all_tts_providers(self) -> list[TTSProvider]: """获取所有用于 TTS 任务的 Provider。""" return self.provider_manager.tts_provider_insts - def get_all_stt_providers(self) -> List[STTProvider]: + def get_all_stt_providers(self) -> list[STTProvider]: """获取所有用于 STT 任务的 Provider。""" return self.provider_manager.stt_provider_insts - def get_all_embedding_providers(self) -> List[EmbeddingProvider]: + def get_all_embedding_providers(self) -> list[EmbeddingProvider]: """获取所有用于 Embedding 任务的 Provider。""" return self.provider_manager.embedding_provider_insts @@ -196,9 +197,7 @@ def get_event_queue(self) -> Queue: return self._event_queue @deprecated(version="4.0.0", reason="Use get_platform_inst instead") - def get_platform( - self, platform_type: Union[PlatformAdapterType, str] - ) -> Platform | None: + def get_platform(self, platform_type: PlatformAdapterType | str) -> Platform | None: """ 获取指定类型的平台适配器。 @@ -231,7 +230,7 @@ def get_platform_inst(self, platform_id: str) -> Platform | None: return platform async def send_message( - self, session: Union[str, MessageSesion], message_chain: MessageChain + self, session: str | MessageSesion, message_chain: MessageChain ) -> bool: """ 根据 session(unified_msg_origin) 主动发送消息。 diff --git a/astrbot/core/star/filter/command.py b/astrbot/core/star/filter/command.py index 3d67cb750..9adc528ee 100755 --- a/astrbot/core/star/filter/command.py +++ b/astrbot/core/star/filter/command.py @@ -2,7 +2,7 @@ import inspect import types import typing -from typing import List, Any, Type, Dict +from typing import Any from . import HandlerFilter from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.config import AstrBotConfig @@ -37,14 +37,14 @@ def __init__( command_name: str, alias: set | None = None, handler_md: StarHandlerMetadata | None = None, - parent_command_names: List[str] = [""], + parent_command_names: list[str] = [""], ): self.command_name = command_name self.alias = alias if alias else set() self.parent_command_names = parent_command_names if handler_md: self.init_handler_md(handler_md) - self.custom_filter_list: List[CustomFilter] = [] + self.custom_filter_list: list[CustomFilter] = [] # Cache for complete command names list self._cmpl_cmd_names: list | None = None @@ -89,8 +89,8 @@ def custom_filter_ok(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: return True def validate_and_convert_params( - self, params: List[Any], param_type: Dict[str, Type] - ) -> Dict[str, Any]: + self, params: list[Any], param_type: dict[str, type] + ) -> dict[str, Any]: """将参数列表 params 根据 param_type 转换为参数字典。""" result = {} param_items = list(param_type.items()) @@ -111,7 +111,7 @@ def validate_and_convert_params( # 没有 GreedyStr 的情况 if i >= len(params): if ( - isinstance(param_type_or_default_val, (Type, types.UnionType)) + isinstance(param_type_or_default_val, (type, types.UnionType)) or typing.get_origin(param_type_or_default_val) is typing.Union or param_type_or_default_val is inspect.Parameter.empty ): diff --git a/astrbot/core/star/filter/command_group.py b/astrbot/core/star/filter/command_group.py index e01fa2c58..2f11f193b 100755 --- a/astrbot/core/star/filter/command_group.py +++ b/astrbot/core/star/filter/command_group.py @@ -1,6 +1,5 @@ from __future__ import annotations -from typing import List, Union from . import HandlerFilter from .command import CommandFilter from astrbot.core.platform.astr_message_event import AstrMessageEvent @@ -18,22 +17,22 @@ def __init__( ): self.group_name = group_name self.alias = alias if alias else set() - self.sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]] = [] - self.custom_filter_list: List[CustomFilter] = [] + self.sub_command_filters: list[CommandFilter | CommandGroupFilter] = [] + self.custom_filter_list: list[CustomFilter] = [] self.parent_group = parent_group # Cache for complete command names list self._cmpl_cmd_names: list | None = None def add_sub_command_filter( - self, sub_command_filter: Union[CommandFilter, CommandGroupFilter] + self, sub_command_filter: CommandFilter | CommandGroupFilter ): self.sub_command_filters.append(sub_command_filter) def add_custom_filter(self, custom_filter: CustomFilter): self.custom_filter_list.append(custom_filter) - def get_complete_command_names(self) -> List[str]: + def get_complete_command_names(self) -> list[str]: """遍历父节点获取完整的指令名。 新版本 v3.4.29 采用预编译指令,不再从指令组递归遍历子指令,因此这个方法是返回包括别名在内的整个指令名列表。""" @@ -59,7 +58,7 @@ def get_complete_command_names(self) -> List[str]: # 以树的形式打印出来 def print_cmd_tree( self, - sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]], + sub_command_filters: list[CommandFilter | CommandGroupFilter], prefix: str = "", event: AstrMessageEvent | None = None, cfg: AstrBotConfig | None = None, diff --git a/astrbot/core/star/register/star_handler.py b/astrbot/core/star/register/star_handler.py index d1c5a6dce..c61e09a86 100644 --- a/astrbot/core/star/register/star_handler.py +++ b/astrbot/core/star/register/star_handler.py @@ -12,7 +12,9 @@ from ..filter.permission import PermissionTypeFilter, PermissionType from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr from ..filter.regex import RegexFilter -from typing import Awaitable, Any, Callable +from typing import Any + +from collections.abc import Awaitable, Callable from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES from astrbot.core.provider.register import llm_tools from astrbot.core.agent.agent import Agent @@ -220,9 +222,7 @@ def decorator(obj): class RegisteringCommandable: """用于指令组级联注册""" - group: Callable[..., Callable[..., "RegisteringCommandable"]] = ( - register_command_group - ) + group: Callable[..., Callable[..., RegisteringCommandable]] = register_command_group command: Callable[..., Callable[..., None]] = register_command custom_filter: Callable[..., Callable[..., None]] = register_custom_filter diff --git a/astrbot/core/star/session_plugin_manager.py b/astrbot/core/star/session_plugin_manager.py index 94a0c8a4d..313ab7c8c 100644 --- a/astrbot/core/star/session_plugin_manager.py +++ b/astrbot/core/star/session_plugin_manager.py @@ -3,7 +3,6 @@ """ from astrbot.core import sp, logger -from typing import Dict, List from astrbot.core.platform.astr_message_event import AstrMessageEvent @@ -95,7 +94,7 @@ def set_plugin_status_for_session( ) @staticmethod - def get_session_plugin_config(session_id: str) -> Dict[str, List[str]]: + def get_session_plugin_config(session_id: str) -> dict[str, list[str]]: """获取指定会话的插件配置 Args: @@ -112,7 +111,7 @@ def get_session_plugin_config(session_id: str) -> Dict[str, List[str]]: ) @staticmethod - def filter_handlers_by_session(event: AstrMessageEvent, handlers: List) -> List: + def filter_handlers_by_session(event: AstrMessageEvent, handlers: list) -> list: """根据会话配置过滤处理器列表 Args: diff --git a/astrbot/core/star/star_handler.py b/astrbot/core/star/star_handler.py index 80b5adb60..5dc076bca 100644 --- a/astrbot/core/star/star_handler.py +++ b/astrbot/core/star/star_handler.py @@ -1,7 +1,9 @@ from __future__ import annotations import enum from dataclasses import dataclass, field -from typing import Callable, Awaitable, Any, List, Dict, TypeVar, Generic +from typing import Any, TypeVar, Generic + +from collections.abc import Callable, Awaitable from .filter import HandlerFilter from .star import star_map @@ -10,8 +12,8 @@ class StarHandlerRegistry(Generic[T]): def __init__(self): - self.star_handlers_map: Dict[str, StarHandlerMetadata] = {} - self._handlers: List[StarHandlerMetadata] = [] + self.star_handlers_map: dict[str, StarHandlerMetadata] = {} + self._handlers: list[StarHandlerMetadata] = [] def append(self, handler: StarHandlerMetadata): """添加一个 Handler,并保持按优先级有序""" @@ -31,7 +33,7 @@ def get_handlers_by_event_type( event_type: EventType, only_activated=True, plugins_name: list[str] | None = None, - ) -> List[StarHandlerMetadata]: + ) -> list[StarHandlerMetadata]: handlers = [] for handler in self._handlers: # 过滤事件类型 @@ -65,7 +67,7 @@ def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata | None def get_handlers_by_module_name( self, module_name: str - ) -> List[StarHandlerMetadata]: + ) -> list[StarHandlerMetadata]: return [ handler for handler in self._handlers @@ -126,7 +128,7 @@ class StarHandlerMetadata: handler: Callable[..., Awaitable[Any]] """Handler 的函数对象,应当是一个异步函数""" - event_filters: List[HandlerFilter] + event_filters: list[HandlerFilter] """一个适配器消息事件过滤器,用于描述这个 Handler 能够处理、应该处理的适配器消息事件""" desc: str = "" diff --git a/astrbot/core/star/star_tools.py b/astrbot/core/star/star_tools.py index 6f9dfe2fa..d37e79189 100644 --- a/astrbot/core/star/star_tools.py +++ b/astrbot/core/star/star_tools.py @@ -22,7 +22,9 @@ import os import uuid from pathlib import Path -from typing import Union, Awaitable, Callable, Any, List, Optional, ClassVar +from typing import Any, ClassVar + +from collections.abc import Awaitable, Callable from astrbot.core.message.components import BaseMessageComponent from astrbot.core.message.message_event_result import MessageChain from astrbot.api.platform import MessageMember, AstrBotMessage, MessageType @@ -44,7 +46,7 @@ class StarTools: 这些方法封装了一些常用操作,使插件开发更加简单便捷! """ - _context: ClassVar[Optional[Context]] = None + _context: ClassVar[Context | None] = None @classmethod def initialize(cls, context: Context) -> None: @@ -58,7 +60,7 @@ def initialize(cls, context: Context) -> None: @classmethod async def send_message( - cls, session: Union[str, MessageSesion], message_chain: MessageChain + cls, session: str | MessageSesion, message_chain: MessageChain ) -> bool: """ 根据session(unified_msg_origin)主动发送消息 @@ -122,7 +124,7 @@ async def create_message( self_id: str, session_id: str, sender: MessageMember, - message: List[BaseMessageComponent], + message: list[BaseMessageComponent], message_str: str, message_id: str = "", raw_message: object = None, @@ -254,7 +256,7 @@ def unregister_llm_tool(cls, name: str) -> None: cls._context.unregister_llm_tool(name) @classmethod - def get_data_dir(cls, plugin_name: Optional[str] = None) -> Path: + def get_data_dir(cls, plugin_name: str | None = None) -> Path: """ 返回插件数据目录的绝对路径。 From 595f766eae717b3bc3b63c04c0d1b0770542751a Mon Sep 17 00:00:00 2001 From: Dale Null Date: Thu, 30 Oct 2025 18:08:06 +0800 Subject: [PATCH 22/44] chore(core.message): ruff rewrite --- astrbot/core/message/message_event_result.py | 28 ++++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py index d058fedd4..030118e90 100644 --- a/astrbot/core/message/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -1,6 +1,6 @@ import enum -from typing import Optional +from typing_extensions import Self from collections.abc import AsyncGenerator from dataclasses import dataclass, field @@ -29,7 +29,7 @@ class MessageChain: type: str | None = None """消息链承载的消息的类型。可选,用于让消息平台区分不同业务场景的消息链。""" - def message(self, message: str) -> "MessageChain": + def message(self, message: str) -> Self: """添加一条文本消息到消息链 `chain` 中。 Example: @@ -41,7 +41,7 @@ def message(self, message: str) -> "MessageChain": self.chain.append(Plain(message)) return self - def at(self, name: str, qq: str | int) -> "MessageChain": + def at(self, name: str, qq: str | int) -> Self: """添加一条 At 消息到消息链 `chain` 中。 Example: @@ -53,7 +53,7 @@ def at(self, name: str, qq: str | int) -> "MessageChain": self.chain.append(At(name=name, qq=qq)) return self - def at_all(self) -> "MessageChain": + def at_all(self) -> Self: """添加一条 AtAll 消息到消息链 `chain` 中。 Example: @@ -66,7 +66,7 @@ def at_all(self) -> "MessageChain": return self @deprecated("请使用 message 方法代替。") - def error(self, message: str) -> "MessageChain": + def error(self, message: str) -> Self: """添加一条错误消息到消息链 `chain` 中 Example: @@ -77,7 +77,7 @@ def error(self, message: str) -> "MessageChain": self.chain.append(Plain(message)) return self - def url_image(self, url: str) -> "MessageChain": + def url_image(self, url: str) -> Self: """添加一条图片消息(https 链接)到消息链 `chain` 中。 Note: @@ -91,7 +91,7 @@ def url_image(self, url: str) -> "MessageChain": self.chain.append(Image.fromURL(url)) return self - def file_image(self, path: str) -> "MessageChain": + def file_image(self, path: str) -> Self: """添加一条图片消息(本地文件路径)到消息链 `chain` 中。 Note: @@ -102,7 +102,7 @@ def file_image(self, path: str) -> "MessageChain": self.chain.append(Image.fromFileSystem(path)) return self - def base64_image(self, base64_str: str) -> "MessageChain": + def base64_image(self, base64_str: str) -> Self: """添加一条图片消息(base64 编码字符串)到消息链 `chain` 中。 Example: @@ -111,7 +111,7 @@ def base64_image(self, base64_str: str) -> "MessageChain": self.chain.append(Image.fromBase64(base64_str)) return self - def use_t2i(self, use_t2i: bool) -> "MessageChain": + def use_t2i(self, use_t2i: bool) -> Self: """设置是否使用文本转图片服务。 Args: @@ -124,7 +124,7 @@ def get_plain_text(self) -> str: """获取纯文本消息。这个方法将获取 chain 中所有 Plain 组件的文本并拼接成一条消息。空格分隔。""" return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)]) - def squash_plain(self) -> Optional["MessageChain"]: + def squash_plain(self) -> Self | None: """将消息链中的所有 Plain 消息段聚合到第一个 Plain 消息段中。""" if not self.chain: return None @@ -196,12 +196,12 @@ class MessageEventResult(MessageChain): async_stream: AsyncGenerator | None = None """异步流""" - def stop_event(self) -> "MessageEventResult": + def stop_event(self) -> Self: """终止事件传播。""" self.result_type = EventResultType.STOP return self - def continue_event(self) -> "MessageEventResult": + def continue_event(self) -> Self: """继续事件传播。""" self.result_type = EventResultType.CONTINUE return self @@ -212,12 +212,12 @@ def is_stopped(self) -> bool: """ return self.result_type == EventResultType.STOP - def set_async_stream(self, stream: AsyncGenerator) -> "MessageEventResult": + def set_async_stream(self, stream: AsyncGenerator) -> Self: """设置异步流。""" self.async_stream = stream return self - def set_result_content_type(self, typ: ResultContentType) -> "MessageEventResult": + def set_result_content_type(self, typ: ResultContentType) -> Self: """设置事件处理的结果类型。 Args: From 932d76d030e7290fc72957dbfe6abd687cea9833 Mon Sep 17 00:00:00 2001 From: Dale Null Date: Thu, 30 Oct 2025 18:09:10 +0800 Subject: [PATCH 23/44] chore(core.utils): ruff rewrite --- astrbot/core/log.py | 3 +-- astrbot/core/utils/dify_api_client.py | 16 +++++++++------- astrbot/core/utils/metrics.py | 2 +- astrbot/core/utils/session_waiter.py | 12 +++++++----- astrbot/core/utils/t2i/local_strategy.py | 10 +++++----- astrbot/core/utils/t2i/template_manager.py | 2 +- 6 files changed, 24 insertions(+), 21 deletions(-) diff --git a/astrbot/core/log.py b/astrbot/core/log.py index 3a1c50371..30699dc0b 100644 --- a/astrbot/core/log.py +++ b/astrbot/core/log.py @@ -28,7 +28,6 @@ import sys from collections import deque from asyncio import Queue -from typing import List # 日志缓存大小 CACHED_SIZE = 200 @@ -87,7 +86,7 @@ class LogBroker: def __init__(self): self.log_cache = deque(maxlen=CACHED_SIZE) # 环形缓冲区, 保存最近的日志 - self.subscribers: List[Queue] = [] # 订阅者列表 + self.subscribers: list[Queue] = [] # 订阅者列表 def register(self) -> Queue: """注册新的订阅者, 并给每个订阅者返回一个带有日志缓存的队列 diff --git a/astrbot/core/utils/dify_api_client.py b/astrbot/core/utils/dify_api_client.py index 15a6b71fb..18be474c8 100644 --- a/astrbot/core/utils/dify_api_client.py +++ b/astrbot/core/utils/dify_api_client.py @@ -2,7 +2,9 @@ import json from astrbot.core import logger from aiohttp import ClientSession, ClientResponse -from typing import Dict, List, Any, AsyncGenerator +from typing import Any + +from collections.abc import AsyncGenerator async def _stream_sse(resp: ClientResponse) -> AsyncGenerator[dict, None]: @@ -39,14 +41,14 @@ def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"): async def chat_messages( self, - inputs: Dict, + inputs: dict, query: str, user: str, response_mode: str = "streaming", conversation_id: str = "", - files: List[Dict[str, Any]] = [], + files: list[dict[str, Any]] = [], timeout: float = 60, - ) -> AsyncGenerator[Dict[str, Any], None]: + ) -> AsyncGenerator[dict[str, Any], None]: url = f"{self.api_base}/chat-messages" payload = locals() payload.pop("self") @@ -65,10 +67,10 @@ async def chat_messages( async def workflow_run( self, - inputs: Dict, + inputs: dict, user: str, response_mode: str = "streaming", - files: List[Dict[str, Any]] = [], + files: list[dict[str, Any]] = [], timeout: float = 60, ): url = f"{self.api_base}/workflows/run" @@ -91,7 +93,7 @@ async def file_upload( self, file_path: str, user: str, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: url = f"{self.api_base}/files/upload" with open(file_path, "rb") as f: payload = { diff --git a/astrbot/core/utils/metrics.py b/astrbot/core/utils/metrics.py index 7fe9bde05..19a50976e 100644 --- a/astrbot/core/utils/metrics.py +++ b/astrbot/core/utils/metrics.py @@ -21,7 +21,7 @@ def get_installation_id(): if os.path.exists(id_file): try: - with open(id_file, "r") as f: + with open(id_file) as f: Metric._iid_cache = f.read().strip() return Metric._iid_cache except Exception: diff --git a/astrbot/core/utils/session_waiter.py b/astrbot/core/utils/session_waiter.py index c27a54113..ae7737a7a 100644 --- a/astrbot/core/utils/session_waiter.py +++ b/astrbot/core/utils/session_waiter.py @@ -8,11 +8,13 @@ import functools import copy import astrbot.core.message.components as Comp -from typing import Dict, Any, Callable, Awaitable, List +from typing import Any + +from collections.abc import Callable, Awaitable from astrbot.core.platform import AstrMessageEvent -USER_SESSIONS: Dict[str, "SessionWaiter"] = {} # 存储 SessionWaiter 实例 -FILTERS: List["SessionFilter"] = [] # 存储 SessionFilter 实例 +USER_SESSIONS: dict[str, "SessionWaiter"] = {} # 存储 SessionWaiter 实例 +FILTERS: list["SessionFilter"] = [] # 存储 SessionFilter 实例 class SessionController: @@ -29,7 +31,7 @@ def __init__(self): self.timeout: float | int = None """上次保持(keep)开始时的超时时间""" - self.history_chains: List[List[Comp.BaseMessageComponent]] = [] + self.history_chains: list[list[Comp.BaseMessageComponent]] = [] def stop(self, error: Exception = None): """立即结束这个会话""" @@ -81,7 +83,7 @@ async def _holding(self, event: asyncio.Event, timeout: int): pass # 避免报错 # finally: - def get_history_chains(self) -> List[List[Comp.BaseMessageComponent]]: + def get_history_chains(self) -> list[list[Comp.BaseMessageComponent]]: """获取历史消息链""" return self.history_chains diff --git a/astrbot/core/utils/t2i/local_strategy.py b/astrbot/core/utils/t2i/local_strategy.py index 19eab2efe..bc9f7dad9 100644 --- a/astrbot/core/utils/t2i/local_strategy.py +++ b/astrbot/core/utils/t2i/local_strategy.py @@ -66,7 +66,7 @@ class TextMeasurer: """测量文本尺寸的工具类""" @staticmethod - def get_text_size(text: str, font: ImageFont.FreeTypeFont) -> Tuple[int, int]: + def get_text_size(text: str, font: ImageFont.FreeTypeFont) -> tuple[int, int]: """获取文本的尺寸""" try: # PIL 9.0.0 以上版本 @@ -82,7 +82,7 @@ def get_text_size(text: str, font: ImageFont.FreeTypeFont) -> Tuple[int, int]: @staticmethod def split_text_to_fit_width( text: str, font: ImageFont.FreeTypeFont, max_width: int - ) -> List[str]: + ) -> list[str]: """将文本拆分为多行,确保每行不超过指定宽度""" lines = [] if not text: @@ -532,7 +532,7 @@ def render( class CodeBlockElement(MarkdownElement): """代码块元素""" - def __init__(self, content: List[str]): + def __init__(self, content: list[str]): super().__init__("\n".join(content)) def calculate_height(self, image_width: int, font_size: int) -> int: @@ -705,7 +705,7 @@ class MarkdownParser: """Markdown解析器,将文本解析为元素""" @staticmethod - async def parse(text: str) -> List[MarkdownElement]: + async def parse(text: str) -> list[MarkdownElement]: elements = [] lines = text.split("\n") @@ -847,7 +847,7 @@ def __init__( self, font_size: int = 26, width: int = 800, - bg_color: Tuple[int, int, int] = (255, 255, 255), + bg_color: tuple[int, int, int] = (255, 255, 255), ): self.font_size = font_size self.width = width diff --git a/astrbot/core/utils/t2i/template_manager.py b/astrbot/core/utils/t2i/template_manager.py index b441a908e..9ae422947 100644 --- a/astrbot/core/utils/t2i/template_manager.py +++ b/astrbot/core/utils/t2i/template_manager.py @@ -43,7 +43,7 @@ def _get_user_template_path(self, name: str) -> str: def _read_file(self, path: str) -> str: """读取文件内容。""" - with open(path, "r", encoding="utf-8") as f: + with open(path, encoding="utf-8") as f: return f.read() def list_templates(self) -> list[dict]: From 70a9186f12a6bf72bad141c88248b6cef120c257 Mon Sep 17 00:00:00 2001 From: Dale Null Date: Thu, 11 Dec 2025 15:27:20 +0800 Subject: [PATCH 24/44] =?UTF-8?q?refactor:=20=E4=B8=BA=E5=A4=A7=E9=87=8F?= =?UTF-8?q?=E6=9E=84=E9=80=A0=E6=96=B9=E6=B3=95=E5=8F=8A=E7=9B=B8=E5=85=B3?= =?UTF-8?q?=E6=96=B9=E6=B3=95=E6=B7=BB=E5=8A=A0=E8=BF=94=E5=9B=9E=E7=B1=BB?= =?UTF-8?q?=E5=9E=8B=20None=20=E6=B3=A8=E8=A7=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/agent/handoff.py | 2 +- astrbot/core/agent/mcp_client.py | 4 +- .../agent/runners/coze/coze_api_client.py | 2 +- .../agent/runners/dify/dify_api_client.py | 2 +- astrbot/core/agent/tool.py | 10 ++-- astrbot/core/astrbot_config_mgr.py | 2 +- astrbot/core/config/astrbot_config.py | 6 +-- astrbot/core/conversation_mgr.py | 2 +- .../db/migration/shared_preferences_v3.py | 2 +- .../db/vec_db/faiss_impl/document_storage.py | 2 +- .../db/vec_db/faiss_impl/embedding_storage.py | 2 +- astrbot/core/db/vec_db/faiss_impl/vec_db.py | 2 +- astrbot/core/event_bus.py | 2 +- astrbot/core/file_token_service.py | 2 +- astrbot/core/initial_loader.py | 2 +- .../knowledge_base/chunking/fixed_size.py | 2 +- .../core/knowledge_base/chunking/recursive.py | 2 +- astrbot/core/knowledge_base/kb_helper.py | 4 +- astrbot/core/knowledge_base/kb_mgr.py | 2 +- .../core/knowledge_base/parsers/url_parser.py | 2 +- .../core/knowledge_base/retrieval/manager.py | 2 +- .../knowledge_base/retrieval/rank_fusion.py | 2 +- .../retrieval/sparse_retriever.py | 2 +- astrbot/core/log.py | 4 +- astrbot/core/message/components.py | 46 +++++++++---------- astrbot/core/persona_mgr.py | 2 +- .../core/pipeline/rate_limit_check/stage.py | 2 +- astrbot/core/pipeline/scheduler.py | 2 +- astrbot/core/platform/astr_message_event.py | 2 +- astrbot/core/platform/astrbot_message.py | 4 +- astrbot/core/platform/manager.py | 2 +- astrbot/core/platform/message_session.py | 2 +- astrbot/core/platform/platform.py | 2 +- .../aiocqhttp/aiocqhttp_message_event.py | 2 +- .../sources/dingtalk/dingtalk_event.py | 2 +- .../core/platform/sources/discord/client.py | 2 +- .../platform/sources/discord/components.py | 8 ++-- .../sources/discord/discord_platform_event.py | 4 +- .../core/platform/sources/lark/lark_event.py | 2 +- .../platform/sources/misskey/misskey_api.py | 4 +- .../platform/sources/misskey/misskey_event.py | 2 +- .../qqofficial/qqofficial_message_event.py | 2 +- .../qqofficial_webhook/qo_webhook_event.py | 2 +- .../qqofficial_webhook/qo_webhook_server.py | 2 +- .../platform/sources/satori/satori_event.py | 2 +- astrbot/core/platform/sources/slack/client.py | 4 +- .../platform/sources/slack/slack_event.py | 2 +- .../platform/sources/telegram/tg_event.py | 2 +- .../platform/sources/webchat/webchat_event.py | 2 +- .../wechatpadpro_message_event.py | 2 +- .../sources/wechatpadpro/xml_data_parser.py | 2 +- .../platform/sources/wecom/wecom_adapter.py | 2 +- .../platform/sources/wecom/wecom_event.py | 2 +- .../sources/wecom_ai_bot/WXBizJsonMsgCrypt.py | 4 +- .../sources/wecom_ai_bot/wecomai_api.py | 2 +- .../sources/wecom_ai_bot/wecomai_event.py | 2 +- .../sources/wecom_ai_bot/wecomai_server.py | 2 +- .../weixin_offacc_adapter.py | 2 +- .../weixin_offacc_event.py | 2 +- astrbot/core/platform_message_history_mgr.py | 2 +- astrbot/core/provider/entities.py | 6 +-- astrbot/core/provider/func_tool_manager.py | 4 +- astrbot/core/provider/manager.py | 2 +- .../core/provider/sources/azure_tts_source.py | 6 +-- astrbot/core/star/__init__.py | 4 +- astrbot/core/star/context.py | 2 +- astrbot/core/star/filter/command.py | 2 +- astrbot/core/star/filter/command_group.py | 2 +- astrbot/core/star/filter/custom_filter.py | 6 +-- .../core/star/filter/event_message_type.py | 2 +- astrbot/core/star/filter/permission.py | 2 +- .../core/star/filter/platform_adapter_type.py | 2 +- astrbot/core/star/filter/regex.py | 2 +- astrbot/core/star/register/star_handler.py | 4 +- astrbot/core/star/star_handler.py | 4 +- astrbot/core/star/star_manager.py | 2 +- astrbot/core/umop_config_router.py | 2 +- astrbot/core/utils/log_pipe.py | 2 +- astrbot/core/utils/pip_installer.py | 2 +- astrbot/core/utils/session_lock.py | 2 +- astrbot/core/utils/session_waiter.py | 4 +- astrbot/core/utils/shared_preferences.py | 2 +- astrbot/core/utils/t2i/renderer.py | 2 +- astrbot/core/utils/t2i/template_manager.py | 2 +- astrbot/dashboard/routes/route.py | 2 +- astrbot/dashboard/routes/t2i.py | 2 +- packages/astrbot/long_term_memory.py | 2 +- packages/astrbot/process_llm_request.py | 2 +- packages/builtin_commands/commands/admin.py | 2 +- .../builtin_commands/commands/alter_cmd.py | 2 +- .../builtin_commands/commands/conversation.py | 2 +- packages/builtin_commands/commands/help.py | 2 +- packages/builtin_commands/commands/llm.py | 2 +- packages/builtin_commands/commands/persona.py | 2 +- packages/builtin_commands/commands/plugin.py | 2 +- .../builtin_commands/commands/provider.py | 2 +- .../builtin_commands/commands/setunset.py | 2 +- packages/builtin_commands/commands/sid.py | 2 +- packages/builtin_commands/commands/t2i.py | 2 +- packages/builtin_commands/commands/tool.py | 2 +- packages/builtin_commands/commands/tts.py | 2 +- packages/session_controller/main.py | 2 +- tests/test_main.py | 2 +- 103 files changed, 153 insertions(+), 153 deletions(-) diff --git a/astrbot/core/agent/handoff.py b/astrbot/core/agent/handoff.py index 85276540b..7b1292cf7 100644 --- a/astrbot/core/agent/handoff.py +++ b/astrbot/core/agent/handoff.py @@ -13,7 +13,7 @@ def __init__( agent: Agent[TContext], parameters: dict | None = None, **kwargs, - ): + ) -> None: self.agent = agent super().__init__( name=f"transfer_to_{agent.name}", diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index c5ff123b2..e02e4b973 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -108,7 +108,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: class MCPClient: - def __init__(self): + def __init__(self) -> None: # Initialize session and client objects self.session: mcp.ClientSession | None = None self.exit_stack = AsyncExitStack() @@ -365,7 +365,7 @@ class MCPTool(FunctionTool, Generic[TContext]): def __init__( self, mcp_tool: mcp.Tool, mcp_client: MCPClient, mcp_server_name: str, **kwargs - ): + ) -> None: super().__init__( name=mcp_tool.name, description=mcp_tool.description or "", diff --git a/astrbot/core/agent/runners/coze/coze_api_client.py b/astrbot/core/agent/runners/coze/coze_api_client.py index e8f3a1e24..5124dc311 100644 --- a/astrbot/core/agent/runners/coze/coze_api_client.py +++ b/astrbot/core/agent/runners/coze/coze_api_client.py @@ -10,7 +10,7 @@ class CozeAPIClient: - def __init__(self, api_key: str, api_base: str = "https://api.coze.cn"): + def __init__(self, api_key: str, api_base: str = "https://api.coze.cn") -> None: self.api_key = api_key self.api_base = api_base self.session = None diff --git a/astrbot/core/agent/runners/dify/dify_api_client.py b/astrbot/core/agent/runners/dify/dify_api_client.py index d9c6556cf..d07569779 100644 --- a/astrbot/core/agent/runners/dify/dify_api_client.py +++ b/astrbot/core/agent/runners/dify/dify_api_client.py @@ -31,7 +31,7 @@ async def _stream_sse(resp: ClientResponse) -> AsyncGenerator[dict, None]: class DifyAPIClient: - def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"): + def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1") -> None: self.api_key = api_key self.api_base = api_base self.session = ClientSession(trust_env=True) diff --git a/astrbot/core/agent/tool.py b/astrbot/core/agent/tool.py index 7f30f44ef..306e5c081 100644 --- a/astrbot/core/agent/tool.py +++ b/astrbot/core/agent/tool.py @@ -58,7 +58,7 @@ class FunctionTool(ToolSchema, Generic[TContext]): You can ignore it when integrating with other frameworks. """ - def __repr__(self): + def __repr__(self) -> str: return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})" async def call(self, context: ContextWrapper[TContext], **kwargs) -> ToolExecResult: @@ -274,17 +274,17 @@ def names(self) -> list[str]: """获取所有工具的名称列表""" return [tool.name for tool in self.tools] - def __len__(self): + def __len__(self) -> int: return len(self.tools) - def __bool__(self): + def __bool__(self) -> bool: return len(self.tools) > 0 def __iter__(self): return iter(self.tools) - def __repr__(self): + def __repr__(self) -> str: return f"ToolSet(tools={self.tools})" - def __str__(self): + def __str__(self) -> str: return f"ToolSet(tools={self.tools})" diff --git a/astrbot/core/astrbot_config_mgr.py b/astrbot/core/astrbot_config_mgr.py index 3a1353ce5..5cc4adb69 100644 --- a/astrbot/core/astrbot_config_mgr.py +++ b/astrbot/core/astrbot_config_mgr.py @@ -36,7 +36,7 @@ def __init__( default_config: AstrBotConfig, ucr: UmopConfigRouter, sp: SharedPreferences, - ): + ) -> None: self.sp = sp self.ucr = ucr self.confs: dict[str, AstrBotConfig] = {} diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py index 9477eabaa..8bb82b621 100644 --- a/astrbot/core/config/astrbot_config.py +++ b/astrbot/core/config/astrbot_config.py @@ -33,7 +33,7 @@ def __init__( config_path: str = ASTRBOT_CONFIG_PATH, default_config: dict = DEFAULT_CONFIG, schema: dict | None = None, - ): + ) -> None: super().__init__() # 调用父类的 __setattr__ 方法,防止保存配置时将此属性写入配置文件 @@ -162,14 +162,14 @@ def __getattr__(self, item): except KeyError: return None - def __delattr__(self, key): + def __delattr__(self, key) -> None: try: del self[key] self.save_config() except KeyError: raise AttributeError(f"没有找到 Key: '{key}'") - def __setattr__(self, key, value): + def __setattr__(self, key, value) -> None: self[key] = value def check_exist(self) -> bool: diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index 287fe03c4..98f10cc1f 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -16,7 +16,7 @@ class ConversationManager: """负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。""" - def __init__(self, db_helper: BaseDatabase): + def __init__(self, db_helper: BaseDatabase) -> None: self.session_conversations: dict[str, str] = {} self.db = db_helper self.save_interval = 60 # 每 60 秒保存一次 diff --git a/astrbot/core/db/migration/shared_preferences_v3.py b/astrbot/core/db/migration/shared_preferences_v3.py index 3abcb1a66..b74fe3e90 100644 --- a/astrbot/core/db/migration/shared_preferences_v3.py +++ b/astrbot/core/db/migration/shared_preferences_v3.py @@ -8,7 +8,7 @@ class SharedPreferences: - def __init__(self, path=None): + def __init__(self, path=None) -> None: if path is None: path = os.path.join(get_astrbot_data_path(), "shared_preferences.json") self.path = path diff --git a/astrbot/core/db/vec_db/faiss_impl/document_storage.py b/astrbot/core/db/vec_db/faiss_impl/document_storage.py index e27eb6fe8..1d6b8fa85 100644 --- a/astrbot/core/db/vec_db/faiss_impl/document_storage.py +++ b/astrbot/core/db/vec_db/faiss_impl/document_storage.py @@ -33,7 +33,7 @@ class Document(BaseDocModel, table=True): class DocumentStorage: - def __init__(self, db_path: str): + def __init__(self, db_path: str) -> None: self.db_path = db_path self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}" self.engine: AsyncEngine | None = None diff --git a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py index 564454cb1..6e5197b9e 100644 --- a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +++ b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py @@ -10,7 +10,7 @@ class EmbeddingStorage: - def __init__(self, dimension: int, path: str | None = None): + def __init__(self, dimension: int, path: str | None = None) -> None: self.dimension = dimension self.path = path self.index = None diff --git a/astrbot/core/db/vec_db/faiss_impl/vec_db.py b/astrbot/core/db/vec_db/faiss_impl/vec_db.py index 14221f1e8..4b7f036ee 100644 --- a/astrbot/core/db/vec_db/faiss_impl/vec_db.py +++ b/astrbot/core/db/vec_db/faiss_impl/vec_db.py @@ -20,7 +20,7 @@ def __init__( index_store_path: str, embedding_provider: EmbeddingProvider, rerank_provider: RerankProvider | None = None, - ): + ) -> None: self.doc_store_path = doc_store_path self.index_store_path = index_store_path self.embedding_provider = embedding_provider diff --git a/astrbot/core/event_bus.py b/astrbot/core/event_bus.py index 0017e65fa..20bd47692 100644 --- a/astrbot/core/event_bus.py +++ b/astrbot/core/event_bus.py @@ -28,7 +28,7 @@ def __init__( event_queue: Queue, pipeline_scheduler_mapping: dict[str, PipelineScheduler], astrbot_config_mgr: AstrBotConfigManager, - ): + ) -> None: self.event_queue = event_queue # 事件队列 # abconf uuid -> scheduler self.pipeline_scheduler_mapping = pipeline_scheduler_mapping diff --git a/astrbot/core/file_token_service.py b/astrbot/core/file_token_service.py index ea97759c1..c807f44e9 100644 --- a/astrbot/core/file_token_service.py +++ b/astrbot/core/file_token_service.py @@ -9,7 +9,7 @@ class FileTokenService: """维护一个简单的基于令牌的文件下载服务,支持超时和懒清除。""" - def __init__(self, default_timeout: float = 300): + def __init__(self, default_timeout: float = 300) -> None: self.lock = asyncio.Lock() self.staged_files = {} # token: (file_path, expire_time) self.default_timeout = default_timeout diff --git a/astrbot/core/initial_loader.py b/astrbot/core/initial_loader.py index f54d18641..c3ee98ad3 100644 --- a/astrbot/core/initial_loader.py +++ b/astrbot/core/initial_loader.py @@ -17,7 +17,7 @@ class InitialLoader: """AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。""" - def __init__(self, db: BaseDatabase, log_broker: LogBroker): + def __init__(self, db: BaseDatabase, log_broker: LogBroker) -> None: self.db = db self.logger = logger self.log_broker = log_broker diff --git a/astrbot/core/knowledge_base/chunking/fixed_size.py b/astrbot/core/knowledge_base/chunking/fixed_size.py index 5439f070f..c0eb17865 100644 --- a/astrbot/core/knowledge_base/chunking/fixed_size.py +++ b/astrbot/core/knowledge_base/chunking/fixed_size.py @@ -12,7 +12,7 @@ class FixedSizeChunker(BaseChunker): 按照固定的字符数分块,并支持块之间的重叠。 """ - def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50): + def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50) -> None: """初始化分块器 Args: diff --git a/astrbot/core/knowledge_base/chunking/recursive.py b/astrbot/core/knowledge_base/chunking/recursive.py index 3f4aabb57..6c14ade63 100644 --- a/astrbot/core/knowledge_base/chunking/recursive.py +++ b/astrbot/core/knowledge_base/chunking/recursive.py @@ -11,7 +11,7 @@ def __init__( length_function: Callable[[str], int] = len, is_separator_regex: bool = False, separators: list[str] | None = None, - ): + ) -> None: """初始化递归字符文本分割器 Args: diff --git a/astrbot/core/knowledge_base/kb_helper.py b/astrbot/core/knowledge_base/kb_helper.py index 4adfb60b8..4bfca90b3 100644 --- a/astrbot/core/knowledge_base/kb_helper.py +++ b/astrbot/core/knowledge_base/kb_helper.py @@ -31,7 +31,7 @@ class RateLimiter: """一个简单的速率限制器""" - def __init__(self, max_rpm: int): + def __init__(self, max_rpm: int) -> None: self.max_per_minute = max_rpm self.interval = 60.0 / max_rpm if max_rpm > 0 else 0 self.last_call_time = 0 @@ -116,7 +116,7 @@ def __init__( provider_manager: ProviderManager, kb_root_dir: str, chunker: BaseChunker, - ): + ) -> None: self.kb_db = kb_db self.kb = kb self.prov_mgr = provider_manager diff --git a/astrbot/core/knowledge_base/kb_mgr.py b/astrbot/core/knowledge_base/kb_mgr.py index 2219cc00b..14b97ac4d 100644 --- a/astrbot/core/knowledge_base/kb_mgr.py +++ b/astrbot/core/knowledge_base/kb_mgr.py @@ -26,7 +26,7 @@ class KnowledgeBaseManager: def __init__( self, provider_manager: ProviderManager, - ): + ) -> None: Path(DB_PATH).parent.mkdir(parents=True, exist_ok=True) self.provider_manager = provider_manager self._session_deleted_callback_registered = False diff --git a/astrbot/core/knowledge_base/parsers/url_parser.py b/astrbot/core/knowledge_base/parsers/url_parser.py index f68e2e0c4..2867164a9 100644 --- a/astrbot/core/knowledge_base/parsers/url_parser.py +++ b/astrbot/core/knowledge_base/parsers/url_parser.py @@ -6,7 +6,7 @@ class URLExtractor: """URL 内容提取器,封装了 Tavily API 调用和密钥管理""" - def __init__(self, tavily_keys: list[str]): + def __init__(self, tavily_keys: list[str]) -> None: """ 初始化 URL 提取器 diff --git a/astrbot/core/knowledge_base/retrieval/manager.py b/astrbot/core/knowledge_base/retrieval/manager.py index 746406e90..d406ceabc 100644 --- a/astrbot/core/knowledge_base/retrieval/manager.py +++ b/astrbot/core/knowledge_base/retrieval/manager.py @@ -44,7 +44,7 @@ def __init__( sparse_retriever: SparseRetriever, rank_fusion: RankFusion, kb_db: KBSQLiteDatabase, - ): + ) -> None: """初始化检索管理器 Args: diff --git a/astrbot/core/knowledge_base/retrieval/rank_fusion.py b/astrbot/core/knowledge_base/retrieval/rank_fusion.py index 26203f94b..40afd9748 100644 --- a/astrbot/core/knowledge_base/retrieval/rank_fusion.py +++ b/astrbot/core/knowledge_base/retrieval/rank_fusion.py @@ -31,7 +31,7 @@ class RankFusion: - 使用 Reciprocal Rank Fusion (RRF) 算法 """ - def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60): + def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60) -> None: """初始化结果融合器 Args: diff --git a/astrbot/core/knowledge_base/retrieval/sparse_retriever.py b/astrbot/core/knowledge_base/retrieval/sparse_retriever.py index ea5da1c9e..d453251d1 100644 --- a/astrbot/core/knowledge_base/retrieval/sparse_retriever.py +++ b/astrbot/core/knowledge_base/retrieval/sparse_retriever.py @@ -34,7 +34,7 @@ class SparseRetriever: - 使用 BM25 算法计算相关度 """ - def __init__(self, kb_db: KBSQLiteDatabase): + def __init__(self, kb_db: KBSQLiteDatabase) -> None: """初始化稀疏检索器 Args: diff --git a/astrbot/core/log.py b/astrbot/core/log.py index 376f5ffd6..b338cb512 100644 --- a/astrbot/core/log.py +++ b/astrbot/core/log.py @@ -86,7 +86,7 @@ class LogBroker: 发布-订阅模式 """ - def __init__(self): + def __init__(self) -> None: self.log_cache = deque(maxlen=CACHED_SIZE) # 环形缓冲区, 保存最近的日志 self.subscribers: list[Queue] = [] # 订阅者列表 @@ -132,7 +132,7 @@ class LogQueueHandler(logging.Handler): 继承自 logging.Handler """ - def __init__(self, log_broker: LogBroker): + def __init__(self, log_broker: LogBroker) -> None: super().__init__() self.log_broker = log_broker diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 0e7b3bab6..778a7f79f 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -66,7 +66,7 @@ class ComponentType(str, Enum): class BaseMessageComponent(BaseModel): type: ComponentType - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) def toDict(self): @@ -89,7 +89,7 @@ class Plain(BaseMessageComponent): text: str convert: bool | None = True - def __init__(self, text: str, convert: bool = True, **_): + def __init__(self, text: str, convert: bool = True, **_) -> None: super().__init__(text=text, convert=convert, **_) def toDict(self): @@ -103,7 +103,7 @@ class Face(BaseMessageComponent): type = ComponentType.Face id: int - def __init__(self, **_): + def __init__(self, **_) -> None: super().__init__(**_) @@ -118,7 +118,7 @@ class Record(BaseMessageComponent): # 额外 path: str | None - def __init__(self, file: str | None, **_): + def __init__(self, file: str | None, **_) -> None: for k in _: if k == "url": pass @@ -221,7 +221,7 @@ class Video(BaseMessageComponent): # 额外 path: str | None = "" - def __init__(self, file: str, **_): + def __init__(self, file: str, **_) -> None: super().__init__(file=file, **_) @staticmethod @@ -303,7 +303,7 @@ class At(BaseMessageComponent): qq: int | str # 此处str为all时代表所有人 name: str | None = "" - def __init__(self, **_): + def __init__(self, **_) -> None: super().__init__(**_) def toDict(self): @@ -316,28 +316,28 @@ def toDict(self): class AtAll(At): qq: str = "all" - def __init__(self, **_): + def __init__(self, **_) -> None: super().__init__(**_) class RPS(BaseMessageComponent): # TODO type = ComponentType.RPS - def __init__(self, **_): + def __init__(self, **_) -> None: super().__init__(**_) class Dice(BaseMessageComponent): # TODO type = ComponentType.Dice - def __init__(self, **_): + def __init__(self, **_) -> None: super().__init__(**_) class Shake(BaseMessageComponent): # TODO type = ComponentType.Shake - def __init__(self, **_): + def __init__(self, **_) -> None: super().__init__(**_) @@ -348,7 +348,7 @@ class Share(BaseMessageComponent): content: str | None = "" image: str | None = "" - def __init__(self, **_): + def __init__(self, **_) -> None: super().__init__(**_) @@ -357,7 +357,7 @@ class Contact(BaseMessageComponent): # TODO _type: str # type 字段冲突 id: int | None = 0 - def __init__(self, **_): + def __init__(self, **_) -> None: super().__init__(**_) @@ -368,7 +368,7 @@ class Location(BaseMessageComponent): # TODO title: str | None = "" content: str | None = "" - def __init__(self, **_): + def __init__(self, **_) -> None: super().__init__(**_) @@ -382,7 +382,7 @@ class Music(BaseMessageComponent): content: str | None = "" image: str | None = "" - def __init__(self, **_): + def __init__(self, **_) -> None: # for k in _.keys(): # if k == "_type" and _[k] not in ["qq", "163", "xm", "custom"]: # logger.warn(f"Protocol: {k}={_[k]} doesn't match values") @@ -402,7 +402,7 @@ class Image(BaseMessageComponent): path: str | None = "" file_unique: str | None = "" # 某些平台可能有图片缓存的唯一标识 - def __init__(self, file: str | None, **_): + def __init__(self, file: str | None, **_) -> None: super().__init__(file=file, **_) @staticmethod @@ -525,7 +525,7 @@ class Reply(BaseMessageComponent): seq: int | None = 0 """deprecated""" - def __init__(self, **_): + def __init__(self, **_) -> None: super().__init__(**_) @@ -534,7 +534,7 @@ class Poke(BaseMessageComponent): id: int | None = 0 qq: int | None = 0 - def __init__(self, type: str, **_): + def __init__(self, type: str, **_) -> None: type = f"Poke:{type}" super().__init__(type=type, **_) @@ -543,7 +543,7 @@ class Forward(BaseMessageComponent): type = ComponentType.Forward id: str - def __init__(self, **_): + def __init__(self, **_) -> None: super().__init__(**_) @@ -558,7 +558,7 @@ class Node(BaseMessageComponent): seq: str | list | None = "" # 忽略 time: int | None = 0 # 忽略 - def __init__(self, content: list[BaseMessageComponent], **_): + def __init__(self, content: list[BaseMessageComponent], **_) -> None: if isinstance(content, Node): # back content = [content] @@ -605,7 +605,7 @@ class Nodes(BaseMessageComponent): type = ComponentType.Nodes nodes: list[Node] - def __init__(self, nodes: list[Node], **_): + def __init__(self, nodes: list[Node], **_) -> None: super().__init__(nodes=nodes, **_) def toDict(self): @@ -632,7 +632,7 @@ class Json(BaseMessageComponent): data: str | dict resid: int | None = 0 - def __init__(self, data, **_): + def __init__(self, data, **_) -> None: if isinstance(data, dict): data = json.dumps(data) super().__init__(data=data, **_) @@ -651,7 +651,7 @@ class File(BaseMessageComponent): file_: str | None = "" # 本地路径 url: str | None = "" # url - def __init__(self, name: str, file: str = "", url: str = ""): + def __init__(self, name: str, file: str = "", url: str = "") -> None: """文件消息段。""" super().__init__(name=name, file_=file, url=url) @@ -787,7 +787,7 @@ class WechatEmoji(BaseMessageComponent): md5_len: int | None = 0 cdnurl: str | None = "" - def __init__(self, **_): + def __init__(self, **_) -> None: super().__init__(**_) diff --git a/astrbot/core/persona_mgr.py b/astrbot/core/persona_mgr.py index b2d2c6be1..52c9b1332 100644 --- a/astrbot/core/persona_mgr.py +++ b/astrbot/core/persona_mgr.py @@ -16,7 +16,7 @@ class PersonaManager: - def __init__(self, db_helper: BaseDatabase, acm: AstrBotConfigManager): + def __init__(self, db_helper: BaseDatabase, acm: AstrBotConfigManager) -> None: self.db = db_helper self.acm = acm default_ps = acm.default_conf.get("provider_settings", {}) diff --git a/astrbot/core/pipeline/rate_limit_check/stage.py b/astrbot/core/pipeline/rate_limit_check/stage.py index 64e21dd7e..392bceff3 100644 --- a/astrbot/core/pipeline/rate_limit_check/stage.py +++ b/astrbot/core/pipeline/rate_limit_check/stage.py @@ -19,7 +19,7 @@ class RateLimitStage(Stage): 如果触发限流,将 stall 流水线,直到下一个时间窗口来临时自动唤醒。 """ - def __init__(self): + def __init__(self) -> None: # 存储每个会话的请求时间队列 self.event_timestamps: defaultdict[str, deque[datetime]] = defaultdict(deque) # 为每个会话设置一个锁,避免并发冲突 diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py index 5fb3034f5..b30c9ddb1 100644 --- a/astrbot/core/pipeline/scheduler.py +++ b/astrbot/core/pipeline/scheduler.py @@ -15,7 +15,7 @@ class PipelineScheduler: """管道调度器,负责调度各个阶段的执行""" - def __init__(self, context: PipelineContext): + def __init__(self, context: PipelineContext) -> None: registered_stages.sort( key=lambda x: STAGES_ORDER.index(x.__name__), ) # 按照顺序排序 diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index f6eda07a9..b215fcd4c 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -35,7 +35,7 @@ def __init__( message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, - ): + ) -> None: self.message_str = message_str """纯文本的消息""" self.message_obj = message_obj diff --git a/astrbot/core/platform/astrbot_message.py b/astrbot/core/platform/astrbot_message.py index 253963322..7e8127649 100644 --- a/astrbot/core/platform/astrbot_message.py +++ b/astrbot/core/platform/astrbot_message.py @@ -11,7 +11,7 @@ class MessageMember: user_id: str # 发送者id nickname: str | None = None - def __str__(self): + def __str__(self) -> str: # 使用 f-string 来构建返回的字符串表示形式 return ( f"User ID: {self.user_id}," @@ -34,7 +34,7 @@ class Group: members: list[MessageMember] | None = None """所有群成员""" - def __str__(self): + def __str__(self) -> str: # 使用 f-string 来构建返回的字符串表示形式 return ( f"Group ID: {self.group_id}\n" diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index b941c8cbc..92440c10f 100644 --- a/astrbot/core/platform/manager.py +++ b/astrbot/core/platform/manager.py @@ -12,7 +12,7 @@ class PlatformManager: - def __init__(self, config: AstrBotConfig, event_queue: Queue): + def __init__(self, config: AstrBotConfig, event_queue: Queue) -> None: self.platform_insts: list[Platform] = [] """加载的 Platform 的实例""" diff --git a/astrbot/core/platform/message_session.py b/astrbot/core/platform/message_session.py index bca5300b8..87c397af4 100644 --- a/astrbot/core/platform/message_session.py +++ b/astrbot/core/platform/message_session.py @@ -15,7 +15,7 @@ class MessageSession: session_id: str platform_id: str | None = None - def __str__(self): + def __str__(self) -> str: return f"{self.platform_id}:{self.message_type.value}:{self.session_id}" def __post_init__(self): diff --git a/astrbot/core/platform/platform.py b/astrbot/core/platform/platform.py index c139b8bd7..8a698aeed 100644 --- a/astrbot/core/platform/platform.py +++ b/astrbot/core/platform/platform.py @@ -34,7 +34,7 @@ class PlatformError: class Platform(abc.ABC): - def __init__(self, config: dict, event_queue: Queue): + def __init__(self, config: dict, event_queue: Queue) -> None: super().__init__() # 平台配置 self.config = config diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index 293b462d3..ff7716b2c 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -26,7 +26,7 @@ def __init__( platform_meta, session_id, bot: CQHttp, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.bot = bot diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py index d520189d8..eb07dffe9 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py @@ -16,7 +16,7 @@ def __init__( platform_meta, session_id, client: dingtalk_stream.ChatbotHandler, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client diff --git a/astrbot/core/platform/sources/discord/client.py b/astrbot/core/platform/sources/discord/client.py index ac0610f2a..b9b33e17b 100644 --- a/astrbot/core/platform/sources/discord/client.py +++ b/astrbot/core/platform/sources/discord/client.py @@ -15,7 +15,7 @@ class DiscordBotClient(discord.Bot): """Discord客户端封装""" - def __init__(self, token: str, proxy: str | None = None): + def __init__(self, token: str, proxy: str | None = None) -> None: self.token = token self.proxy = proxy diff --git a/astrbot/core/platform/sources/discord/components.py b/astrbot/core/platform/sources/discord/components.py index f875652a0..433509f5e 100644 --- a/astrbot/core/platform/sources/discord/components.py +++ b/astrbot/core/platform/sources/discord/components.py @@ -19,7 +19,7 @@ def __init__( image: str | None = None, footer: str | None = None, fields: list[dict] | None = None, - ): + ) -> None: self.title = title self.description = description self.color = color @@ -71,7 +71,7 @@ def __init__( emoji: str | None = None, url: str | None = None, disabled: bool = False, - ): + ) -> None: self.label = label self.custom_id = custom_id self.style = style @@ -85,7 +85,7 @@ class DiscordReference(BaseMessageComponent): type: str = "discord_reference" - def __init__(self, message_id: str, channel_id: str): + def __init__(self, message_id: str, channel_id: str) -> None: self.message_id = message_id self.channel_id = channel_id @@ -99,7 +99,7 @@ def __init__( self, components: list[BaseMessageComponent] | None = None, timeout: float | None = None, - ): + ) -> None: self.components = components or [] self.timeout = timeout diff --git a/astrbot/core/platform/sources/discord/discord_platform_event.py b/astrbot/core/platform/sources/discord/discord_platform_event.py index 053018225..42459c1b2 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_event.py +++ b/astrbot/core/platform/sources/discord/discord_platform_event.py @@ -28,7 +28,7 @@ class DiscordViewComponent(BaseMessageComponent): type: str = "discord_view" - def __init__(self, view: discord.ui.View): + def __init__(self, view: discord.ui.View) -> None: self.view = view @@ -41,7 +41,7 @@ def __init__( session_id: str, client: DiscordBotClient, interaction_followup_webhook: discord.Webhook | None = None, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client self.interaction_followup_webhook = interaction_followup_webhook diff --git a/astrbot/core/platform/sources/lark/lark_event.py b/astrbot/core/platform/sources/lark/lark_event.py index 7b7d20b38..dc952da45 100644 --- a/astrbot/core/platform/sources/lark/lark_event.py +++ b/astrbot/core/platform/sources/lark/lark_event.py @@ -31,7 +31,7 @@ def __init__( platform_meta, session_id, bot: lark.Client, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.bot = bot diff --git a/astrbot/core/platform/sources/misskey/misskey_api.py b/astrbot/core/platform/sources/misskey/misskey_api.py index 06dc6304d..356e9bc3e 100644 --- a/astrbot/core/platform/sources/misskey/misskey_api.py +++ b/astrbot/core/platform/sources/misskey/misskey_api.py @@ -43,7 +43,7 @@ class WebSocketError(APIError): class StreamingClient: - def __init__(self, instance_url: str, access_token: str): + def __init__(self, instance_url: str, access_token: str) -> None: self.instance_url = instance_url.rstrip("/") self.access_token = access_token self.websocket: Any | None = None @@ -334,7 +334,7 @@ def __init__( download_timeout: int = 15, chunk_size: int = 64 * 1024, max_download_bytes: int | None = None, - ): + ) -> None: self.instance_url = instance_url.rstrip("/") self.access_token = access_token self._session: aiohttp.ClientSession | None = None diff --git a/astrbot/core/platform/sources/misskey/misskey_event.py b/astrbot/core/platform/sources/misskey/misskey_event.py index 7975f0ec7..77ed58376 100644 --- a/astrbot/core/platform/sources/misskey/misskey_event.py +++ b/astrbot/core/platform/sources/misskey/misskey_event.py @@ -26,7 +26,7 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, client, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py index d693c4206..2d1f9f46e 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py @@ -32,7 +32,7 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, bot: Client, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.bot = bot self.send_buffer = None diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py index 306db5e56..5ceeb2c70 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py @@ -13,5 +13,5 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, bot: Client, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id, bot) diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py index 2eda11a6c..f8ee930ff 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py @@ -14,7 +14,7 @@ class QQOfficialWebhook: - def __init__(self, config: dict, event_queue: asyncio.Queue, botpy_client: Client): + def __init__(self, config: dict, event_queue: asyncio.Queue, botpy_client: Client) -> None: self.appid = config["appid"] self.secret = config["secret"] self.port = config.get("port", 6196) diff --git a/astrbot/core/platform/sources/satori/satori_event.py b/astrbot/core/platform/sources/satori/satori_event.py index 81a0d222c..722fe939d 100644 --- a/astrbot/core/platform/sources/satori/satori_event.py +++ b/astrbot/core/platform/sources/satori/satori_event.py @@ -28,7 +28,7 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, adapter: "SatoriPlatformAdapter", - ): + ) -> None: # 更新平台元数据 if adapter and hasattr(adapter, "logins") and adapter.logins: current_login = adapter.logins[0] diff --git a/astrbot/core/platform/sources/slack/client.py b/astrbot/core/platform/sources/slack/client.py index fbdc71759..c0dfe969c 100644 --- a/astrbot/core/platform/sources/slack/client.py +++ b/astrbot/core/platform/sources/slack/client.py @@ -27,7 +27,7 @@ def __init__( port: int = 3000, path: str = "/slack/events", event_handler: Callable | None = None, - ): + ) -> None: self.web_client = web_client self.signing_secret = signing_secret self.host = host @@ -135,7 +135,7 @@ def __init__( web_client: AsyncWebClient, app_token: str, event_handler: Callable | None = None, - ): + ) -> None: self.web_client = web_client self.app_token = app_token self.event_handler = event_handler diff --git a/astrbot/core/platform/sources/slack/slack_event.py b/astrbot/core/platform/sources/slack/slack_event.py index 822e6fdeb..b08e05509 100644 --- a/astrbot/core/platform/sources/slack/slack_event.py +++ b/astrbot/core/platform/sources/slack/slack_event.py @@ -24,7 +24,7 @@ def __init__( platform_meta, session_id, web_client: AsyncWebClient, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.web_client = web_client diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py index 37f60e65a..bd518e9e9 100644 --- a/astrbot/core/platform/sources/telegram/tg_event.py +++ b/astrbot/core/platform/sources/telegram/tg_event.py @@ -38,7 +38,7 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, client: ExtBot, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index 9f1a6d059..b220b002b 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -14,7 +14,7 @@ class WebChatMessageEvent(AstrMessageEvent): - def __init__(self, message_str, message_obj, platform_meta, session_id): + def __init__(self, message_str, message_obj, platform_meta, session_id) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) os.makedirs(imgs_dir, exist_ok=True) diff --git a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py index 08ab27013..5a208a04c 100644 --- a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +++ b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py @@ -32,7 +32,7 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, adapter: "WeChatPadProAdapter", # 传递适配器实例 - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.message_obj = message_obj # Save the full message object self.adapter = adapter # Save the adapter instance diff --git a/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py b/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py index 09924edb6..cf23c6b54 100644 --- a/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +++ b/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py @@ -20,7 +20,7 @@ def __init__( cached_images=None, raw_message: dict | None = None, downloader=None, - ): + ) -> None: self._xml = None self.content = content self.is_private_chat = is_private_chat diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index 8f3d091a4..36a8713c5 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -39,7 +39,7 @@ class WecomServer: - def __init__(self, event_queue: asyncio.Queue, config: dict): + def __init__(self, event_queue: asyncio.Queue, config: dict) -> None: self.server = quart.Quart(__name__) self.port = int(cast(str, config.get("port"))) self.callback_server_host = config.get("callback_server_host", "0.0.0.0") diff --git a/astrbot/core/platform/sources/wecom/wecom_event.py b/astrbot/core/platform/sources/wecom/wecom_event.py index 0b5dae272..2e2cdb751 100644 --- a/astrbot/core/platform/sources/wecom/wecom_event.py +++ b/astrbot/core/platform/sources/wecom/wecom_event.py @@ -28,7 +28,7 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, client: WeChatClient, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client diff --git a/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py b/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py index 2df09a763..4e7714b0e 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py @@ -145,7 +145,7 @@ class Prpcrypt: MIN_RANDOM_VALUE = 1000000000000000 # 最小值: 1000000000000000 (16位) RANDOM_RANGE = 9000000000000000 # 范围大小: 确保最大值为 9999999999999999 (16位) - def __init__(self, key): + def __init__(self, key) -> None: # self.key = base64.b64decode(key+"=") self.key = key # 设置加解密模式为AES的CBC模式 @@ -220,7 +220,7 @@ def get_random_str(self): class WXBizJsonMsgCrypt: # 构造函数 - def __init__(self, sToken, sEncodingAESKey, sReceiveId): + def __init__(self, sToken, sEncodingAESKey, sReceiveId) -> None: try: self.key = base64.b64decode(sEncodingAESKey + "=") assert len(self.key) == 32 diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py index 6c448a97e..97831fbb2 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py @@ -19,7 +19,7 @@ class WecomAIBotAPIClient: """企业微信智能机器人 API 客户端""" - def __init__(self, token: str, encoding_aes_key: str): + def __init__(self, token: str, encoding_aes_key: str) -> None: """初始化 API 客户端 Args: diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py index fd11d7ceb..14dadcace 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py @@ -22,7 +22,7 @@ def __init__( session_id: str, api_client: WecomAIBotAPIClient, queue_mgr: WecomAIQueueMgr, - ): + ) -> None: """初始化消息事件 Args: diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py index 5cbdd1130..3d7325c9d 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py @@ -23,7 +23,7 @@ def __init__( port: int, api_client: WecomAIBotAPIClient, message_handler: Callable[[dict[str, Any], dict[str, str]], Any] | None = None, - ): + ) -> None: """初始化服务器 Args: diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py index d0304a48e..65fad22a0 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py @@ -35,7 +35,7 @@ class WeixinOfficialAccountServer: - def __init__(self, event_queue: asyncio.Queue, config: dict): + def __init__(self, event_queue: asyncio.Queue, config: dict) -> None: self.server = quart.Quart(__name__) self.port = int(cast(int | str, config.get("port"))) self.callback_server_host = config.get("callback_server_host", "0.0.0.0") diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py index c1f137a41..66d003039 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py @@ -26,7 +26,7 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, client: WeChatClient, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client diff --git a/astrbot/core/platform_message_history_mgr.py b/astrbot/core/platform_message_history_mgr.py index d6d524698..3c22b641c 100644 --- a/astrbot/core/platform_message_history_mgr.py +++ b/astrbot/core/platform_message_history_mgr.py @@ -3,7 +3,7 @@ class PlatformMessageHistoryManager: - def __init__(self, db_helper: BaseDatabase): + def __init__(self, db_helper: BaseDatabase) -> None: self.db = db_helper async def insert( diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index dc188f141..a141bf73e 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -106,7 +106,7 @@ class ProviderRequest: model: str | None = None """模型名称,为 None 时使用提供商的默认模型""" - def __repr__(self): + def __repr__(self) -> str: return ( f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, " f"image_count={len(self.image_urls or [])}, " @@ -116,7 +116,7 @@ def __repr__(self): f"conversation_id={self.conversation.cid if self.conversation else 'N/A'}, " ) - def __str__(self): + def __str__(self) -> str: return self.__repr__() def append_tool_calls_result(self, tool_calls_result: ToolCallsResult): @@ -241,7 +241,7 @@ def __init__( | AnthropicMessage | None = None, is_chunk: bool = False, - ): + ) -> None: """初始化 LLMResponse Args: diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 7aad86bdd..0f5f26f35 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -575,10 +575,10 @@ async def sync_modelscope_mcp_servers(self, access_token: str) -> None: except Exception as e: raise Exception(f"同步 ModelScope MCP 服务器时发生错误: {e!s}") - def __str__(self): + def __str__(self) -> str: return str(self.func_list) - def __repr__(self): + def __repr__(self) -> str: return str(self.func_list) diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index be8edc282..7054276d6 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -30,7 +30,7 @@ def __init__( acm: AstrBotConfigManager, db_helper: BaseDatabase, persona_mgr: PersonaManager, - ): + ) -> None: self.reload_lock = asyncio.Lock() self.persona_mgr = persona_mgr self.acm = acm diff --git a/astrbot/core/provider/sources/azure_tts_source.py b/astrbot/core/provider/sources/azure_tts_source.py index 2ccf146ca..3a0a79c93 100644 --- a/astrbot/core/provider/sources/azure_tts_source.py +++ b/astrbot/core/provider/sources/azure_tts_source.py @@ -21,7 +21,7 @@ class OTTSProvider: - def __init__(self, config: dict): + def __init__(self, config: dict) -> None: self.skey = config["OTTS_SKEY"] self.api_url = config["OTTS_URL"] self.auth_time_url = config["OTTS_AUTH_TIME"] @@ -103,7 +103,7 @@ async def get_audio(self, text: str, voice_params: dict) -> str: class AzureNativeProvider(TTSProvider): - def __init__(self, provider_config: dict, provider_settings: dict): + def __init__(self, provider_config: dict, provider_settings: dict) -> None: super().__init__(provider_config, provider_settings) self.subscription_key = provider_config.get( "azure_tts_subscription_key", @@ -195,7 +195,7 @@ async def get_audio(self, text: str) -> str: @register_provider_adapter("azure_tts", "Azure TTS", ProviderType.TEXT_TO_SPEECH) class AzureTTSProvider(TTSProvider): - def __init__(self, provider_config: dict, provider_settings: dict): + def __init__(self, provider_config: dict, provider_settings: dict) -> None: super().__init__(provider_config, provider_settings) key_value = provider_config.get("azure_tts_subscription_key", "") self.provider = self._parse_provider(key_value, provider_config) diff --git a/astrbot/core/star/__init__.py b/astrbot/core/star/__init__.py index e27db7405..bdcc8785e 100644 --- a/astrbot/core/star/__init__.py +++ b/astrbot/core/star/__init__.py @@ -11,7 +11,7 @@ class Star(CommandParserMixin): """所有插件(Star)的父类,所有插件都应该继承于这个类""" - def __init__(self, context: Context, config: dict | None = None): + def __init__(self, context: Context, config: dict | None = None) -> None: StarTools.initialize(context) self.context = context @@ -57,7 +57,7 @@ async def initialize(self): async def terminate(self): """当插件被禁用、重载插件时会调用这个方法""" - def __del__(self): + def __del__(self) -> None: """[Deprecated] 当插件被禁用、重载插件时会调用这个方法""" diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 9a52ec8bc..b6ff13b62 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -65,7 +65,7 @@ def __init__( persona_manager: PersonaManager, astrbot_config_mgr: AstrBotConfigManager, knowledge_base_manager: KnowledgeBaseManager, - ): + ) -> None: self._event_queue = event_queue """事件队列。消息平台通过事件队列传递消息事件。""" self._config = config diff --git a/astrbot/core/star/filter/command.py b/astrbot/core/star/filter/command.py index 2a9868fdc..e42c07076 100755 --- a/astrbot/core/star/filter/command.py +++ b/astrbot/core/star/filter/command.py @@ -37,7 +37,7 @@ def __init__( alias: set | None = None, handler_md: StarHandlerMetadata | None = None, parent_command_names: list[str] | None = None, - ): + ) -> None: self.command_name = command_name self.alias = alias if alias else set() self.parent_command_names = ( diff --git a/astrbot/core/star/filter/command_group.py b/astrbot/core/star/filter/command_group.py index e1c2efb22..087f35445 100755 --- a/astrbot/core/star/filter/command_group.py +++ b/astrbot/core/star/filter/command_group.py @@ -15,7 +15,7 @@ def __init__( group_name: str, alias: set | None = None, parent_group: CommandGroupFilter | None = None, - ): + ) -> None: self.group_name = group_name self.alias = alias if alias else set() self.sub_command_filters: list[CommandFilter | CommandGroupFilter] = [] diff --git a/astrbot/core/star/filter/custom_filter.py b/astrbot/core/star/filter/custom_filter.py index d57b5cac0..2e7303d06 100644 --- a/astrbot/core/star/filter/custom_filter.py +++ b/astrbot/core/star/filter/custom_filter.py @@ -19,7 +19,7 @@ def __or__(cls, other): class CustomFilter(HandlerFilter, metaclass=CustomFilterMeta): - def __init__(self, raise_error: bool = True, **kwargs): + def __init__(self, raise_error: bool = True, **kwargs) -> None: self.raise_error = raise_error @abstractmethod @@ -35,7 +35,7 @@ def __and__(self, other): class CustomFilterOr(CustomFilter): - def __init__(self, filter1: CustomFilter, filter2: CustomFilter): + def __init__(self, filter1: CustomFilter, filter2: CustomFilter) -> None: super().__init__() if not isinstance(filter1, (CustomFilter, CustomFilterAnd, CustomFilterOr)): raise ValueError( @@ -49,7 +49,7 @@ def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: class CustomFilterAnd(CustomFilter): - def __init__(self, filter1: CustomFilter, filter2: CustomFilter): + def __init__(self, filter1: CustomFilter, filter2: CustomFilter) -> None: super().__init__() if not isinstance(filter1, (CustomFilter, CustomFilterAnd, CustomFilterOr)): raise ValueError( diff --git a/astrbot/core/star/filter/event_message_type.py b/astrbot/core/star/filter/event_message_type.py index 7f350bd38..604fc3ed3 100644 --- a/astrbot/core/star/filter/event_message_type.py +++ b/astrbot/core/star/filter/event_message_type.py @@ -22,7 +22,7 @@ class EventMessageType(enum.Flag): class EventMessageTypeFilter(HandlerFilter): - def __init__(self, event_message_type: EventMessageType): + def __init__(self, event_message_type: EventMessageType) -> None: self.event_message_type = event_message_type def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: diff --git a/astrbot/core/star/filter/permission.py b/astrbot/core/star/filter/permission.py index 3374544c2..0017489cc 100644 --- a/astrbot/core/star/filter/permission.py +++ b/astrbot/core/star/filter/permission.py @@ -14,7 +14,7 @@ class PermissionType(enum.Flag): class PermissionTypeFilter(HandlerFilter): - def __init__(self, permission_type: PermissionType, raise_error: bool = True): + def __init__(self, permission_type: PermissionType, raise_error: bool = True) -> None: self.permission_type = permission_type self.raise_error = raise_error diff --git a/astrbot/core/star/filter/platform_adapter_type.py b/astrbot/core/star/filter/platform_adapter_type.py index 1182ff9b0..d5ff6146a 100644 --- a/astrbot/core/star/filter/platform_adapter_type.py +++ b/astrbot/core/star/filter/platform_adapter_type.py @@ -58,7 +58,7 @@ class PlatformAdapterType(enum.Flag): class PlatformAdapterTypeFilter(HandlerFilter): - def __init__(self, platform_adapter_type_or_str: PlatformAdapterType | str): + def __init__(self, platform_adapter_type_or_str: PlatformAdapterType | str) -> None: if isinstance(platform_adapter_type_or_str, str): self.platform_type = ADAPTER_NAME_2_TYPE.get(platform_adapter_type_or_str) else: diff --git a/astrbot/core/star/filter/regex.py b/astrbot/core/star/filter/regex.py index cd5bebdb4..abec5a488 100644 --- a/astrbot/core/star/filter/regex.py +++ b/astrbot/core/star/filter/regex.py @@ -10,7 +10,7 @@ class RegexFilter(HandlerFilter): """正则表达式过滤器""" - def __init__(self, regex: str): + def __init__(self, regex: str) -> None: self.regex_str = regex self.regex = re.compile(regex) diff --git a/astrbot/core/star/register/star_handler.py b/astrbot/core/star/register/star_handler.py index daf36a8f6..01b0889e2 100644 --- a/astrbot/core/star/register/star_handler.py +++ b/astrbot/core/star/register/star_handler.py @@ -250,7 +250,7 @@ class RegisteringCommandable: command: Callable[..., Callable[..., None]] = register_command custom_filter: Callable[..., Callable[..., Any]] = register_custom_filter - def __init__(self, parent_group: CommandGroupFilter): + def __init__(self, parent_group: CommandGroupFilter) -> None: self.parent_group = parent_group @@ -492,7 +492,7 @@ def llm_tool(self, *args, **kwargs): kwargs["registering_agent"] = self return register_llm_tool(*args, **kwargs) - def __init__(self, agent: Agent[AstrAgentContext]): + def __init__(self, agent: Agent[AstrAgentContext]) -> None: self._agent = agent diff --git a/astrbot/core/star/star_handler.py b/astrbot/core/star/star_handler.py index da59cd291..663616f33 100644 --- a/astrbot/core/star/star_handler.py +++ b/astrbot/core/star/star_handler.py @@ -12,7 +12,7 @@ class StarHandlerRegistry(Generic[T]): - def __init__(self): + def __init__(self) -> None: self.star_handlers_map: dict[str, StarHandlerMetadata] = {} self._handlers: list[StarHandlerMetadata] = [] @@ -165,7 +165,7 @@ def remove(self, handler: StarHandlerMetadata): def __iter__(self): return iter(self._handlers) - def __len__(self): + def __len__(self) -> int: return len(self._handlers) diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index abdedc249..d73590722 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -37,7 +37,7 @@ class PluginManager: - def __init__(self, context: Context, config: AstrBotConfig): + def __init__(self, context: Context, config: AstrBotConfig) -> None: self.updator = PluginUpdator() self.context = context diff --git a/astrbot/core/umop_config_router.py b/astrbot/core/umop_config_router.py index 27f6232aa..e5d0d2071 100644 --- a/astrbot/core/umop_config_router.py +++ b/astrbot/core/umop_config_router.py @@ -4,7 +4,7 @@ class UmopConfigRouter: """UMOP 配置路由器""" - def __init__(self, sp: SharedPreferences): + def __init__(self, sp: SharedPreferences) -> None: self.umop_to_conf_id: dict[str, str] = {} """UMOP 到配置文件 ID 的映射""" self.sp = sp diff --git a/astrbot/core/utils/log_pipe.py b/astrbot/core/utils/log_pipe.py index 2e931dd81..077275c6d 100644 --- a/astrbot/core/utils/log_pipe.py +++ b/astrbot/core/utils/log_pipe.py @@ -10,7 +10,7 @@ def __init__( logger: Logger, identifier=None, callback=None, - ): + ) -> None: threading.Thread.__init__(self) self.daemon = True self.level = level diff --git a/astrbot/core/utils/pip_installer.py b/astrbot/core/utils/pip_installer.py index 6076a114a..0f1eb6583 100644 --- a/astrbot/core/utils/pip_installer.py +++ b/astrbot/core/utils/pip_installer.py @@ -6,7 +6,7 @@ class PipInstaller: - def __init__(self, pip_install_arg: str, pypi_index_url: str | None = None): + def __init__(self, pip_install_arg: str, pypi_index_url: str | None = None) -> None: self.pip_install_arg = pip_install_arg self.pypi_index_url = pypi_index_url diff --git a/astrbot/core/utils/session_lock.py b/astrbot/core/utils/session_lock.py index 912d91e53..7810d6ce4 100644 --- a/astrbot/core/utils/session_lock.py +++ b/astrbot/core/utils/session_lock.py @@ -4,7 +4,7 @@ class SessionLockManager: - def __init__(self): + def __init__(self) -> None: self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) self._lock_count: dict[str, int] = defaultdict(int) self._access_lock = asyncio.Lock() diff --git a/astrbot/core/utils/session_waiter.py b/astrbot/core/utils/session_waiter.py index e1f2fbef7..3e7f339f9 100644 --- a/astrbot/core/utils/session_waiter.py +++ b/astrbot/core/utils/session_waiter.py @@ -18,7 +18,7 @@ class SessionController: """控制一个 Session 是否已经结束""" - def __init__(self): + def __init__(self) -> None: self.future = asyncio.Future() self.current_event: asyncio.Event | None = None """当前正在等待的所用的异步事件""" @@ -107,7 +107,7 @@ def __init__( session_filter: SessionFilter, session_id: str, record_history_chains: bool, - ): + ) -> None: self.session_id = session_id self.session_filter = session_filter self.handler: ( diff --git a/astrbot/core/utils/shared_preferences.py b/astrbot/core/utils/shared_preferences.py index ccd394ee4..cb3c9fcf2 100644 --- a/astrbot/core/utils/shared_preferences.py +++ b/astrbot/core/utils/shared_preferences.py @@ -12,7 +12,7 @@ class SharedPreferences: - def __init__(self, db_helper: BaseDatabase, json_storage_path=None): + def __init__(self, db_helper: BaseDatabase, json_storage_path=None) -> None: if json_storage_path is None: json_storage_path = os.path.join( get_astrbot_data_path(), diff --git a/astrbot/core/utils/t2i/renderer.py b/astrbot/core/utils/t2i/renderer.py index 2ce7a5ebf..b1c3ec384 100644 --- a/astrbot/core/utils/t2i/renderer.py +++ b/astrbot/core/utils/t2i/renderer.py @@ -7,7 +7,7 @@ class HtmlRenderer: - def __init__(self, endpoint_url: str | None = None): + def __init__(self, endpoint_url: str | None = None) -> None: self.network_strategy = NetworkRenderStrategy(endpoint_url) self.local_strategy = LocalRenderStrategy() diff --git a/astrbot/core/utils/t2i/template_manager.py b/astrbot/core/utils/t2i/template_manager.py index 6d44f735b..8bbdb7e9e 100644 --- a/astrbot/core/utils/t2i/template_manager.py +++ b/astrbot/core/utils/t2i/template_manager.py @@ -14,7 +14,7 @@ class TemplateManager: CORE_TEMPLATES = ["base.html", "astrbot_powershell.html"] - def __init__(self): + def __init__(self) -> None: self.builtin_template_dir = os.path.join( get_astrbot_path(), "astrbot", diff --git a/astrbot/dashboard/routes/route.py b/astrbot/dashboard/routes/route.py index 01ab292d4..72e359737 100644 --- a/astrbot/dashboard/routes/route.py +++ b/astrbot/dashboard/routes/route.py @@ -14,7 +14,7 @@ class RouteContext: class Route: routes: list | dict - def __init__(self, context: RouteContext): + def __init__(self, context: RouteContext) -> None: self.app = context.app self.config = context.config diff --git a/astrbot/dashboard/routes/t2i.py b/astrbot/dashboard/routes/t2i.py index db70a8820..df8fbbaa7 100644 --- a/astrbot/dashboard/routes/t2i.py +++ b/astrbot/dashboard/routes/t2i.py @@ -12,7 +12,7 @@ class T2iRoute(Route): - def __init__(self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle): + def __init__(self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle) -> None: super().__init__(context) self.core_lifecycle = core_lifecycle self.config = core_lifecycle.astrbot_config diff --git a/packages/astrbot/long_term_memory.py b/packages/astrbot/long_term_memory.py index 610995db2..f1cb0ec02 100644 --- a/packages/astrbot/long_term_memory.py +++ b/packages/astrbot/long_term_memory.py @@ -17,7 +17,7 @@ class LongTermMemory: - def __init__(self, acm: AstrBotConfigManager, context: star.Context): + def __init__(self, acm: AstrBotConfigManager, context: star.Context) -> None: self.acm = acm self.context = context self.session_chats = defaultdict(list) diff --git a/packages/astrbot/process_llm_request.py b/packages/astrbot/process_llm_request.py index 28c41df9f..f11987748 100644 --- a/packages/astrbot/process_llm_request.py +++ b/packages/astrbot/process_llm_request.py @@ -11,7 +11,7 @@ class ProcessLLMRequest: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.ctx = context cfg = context.get_config() self.timezone = cfg.get("timezone") diff --git a/packages/builtin_commands/commands/admin.py b/packages/builtin_commands/commands/admin.py index 2073f45a2..5725bd39a 100644 --- a/packages/builtin_commands/commands/admin.py +++ b/packages/builtin_commands/commands/admin.py @@ -5,7 +5,7 @@ class AdminCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context async def op(self, event: AstrMessageEvent, admin_id: str = ""): diff --git a/packages/builtin_commands/commands/alter_cmd.py b/packages/builtin_commands/commands/alter_cmd.py index 50007f6c0..6f62e8c36 100644 --- a/packages/builtin_commands/commands/alter_cmd.py +++ b/packages/builtin_commands/commands/alter_cmd.py @@ -11,7 +11,7 @@ class AlterCmdCommands(CommandParserMixin): - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context async def update_reset_permission(self, scene_key: str, perm_type: str): diff --git a/packages/builtin_commands/commands/conversation.py b/packages/builtin_commands/commands/conversation.py index de3d11ac8..6f0bda4a8 100644 --- a/packages/builtin_commands/commands/conversation.py +++ b/packages/builtin_commands/commands/conversation.py @@ -16,7 +16,7 @@ class ConversationCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context async def _get_current_persona_id(self, session_id): diff --git a/packages/builtin_commands/commands/help.py b/packages/builtin_commands/commands/help.py index 7f5b6c170..8874328b1 100644 --- a/packages/builtin_commands/commands/help.py +++ b/packages/builtin_commands/commands/help.py @@ -7,7 +7,7 @@ class HelpCommand: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context async def _query_astrbot_notice(self): diff --git a/packages/builtin_commands/commands/llm.py b/packages/builtin_commands/commands/llm.py index 85977df40..7d599d6fa 100644 --- a/packages/builtin_commands/commands/llm.py +++ b/packages/builtin_commands/commands/llm.py @@ -3,7 +3,7 @@ class LLMCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context async def llm(self, event: AstrMessageEvent): diff --git a/packages/builtin_commands/commands/persona.py b/packages/builtin_commands/commands/persona.py index 13a57f07f..155ac333a 100644 --- a/packages/builtin_commands/commands/persona.py +++ b/packages/builtin_commands/commands/persona.py @@ -5,7 +5,7 @@ class PersonaCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context async def persona(self, message: AstrMessageEvent): diff --git a/packages/builtin_commands/commands/plugin.py b/packages/builtin_commands/commands/plugin.py index ab45efc11..327606ced 100644 --- a/packages/builtin_commands/commands/plugin.py +++ b/packages/builtin_commands/commands/plugin.py @@ -8,7 +8,7 @@ class PluginCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context async def plugin_ls(self, event: AstrMessageEvent): diff --git a/packages/builtin_commands/commands/provider.py b/packages/builtin_commands/commands/provider.py index ce8f31831..09f685dca 100644 --- a/packages/builtin_commands/commands/provider.py +++ b/packages/builtin_commands/commands/provider.py @@ -8,7 +8,7 @@ class ProviderCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context def _log_reachability_failure( diff --git a/packages/builtin_commands/commands/setunset.py b/packages/builtin_commands/commands/setunset.py index 79e5d5d1c..06a45ff6c 100644 --- a/packages/builtin_commands/commands/setunset.py +++ b/packages/builtin_commands/commands/setunset.py @@ -3,7 +3,7 @@ class SetUnsetCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context async def set_variable(self, event: AstrMessageEvent, key: str, value: str): diff --git a/packages/builtin_commands/commands/sid.py b/packages/builtin_commands/commands/sid.py index 4d95c5a60..35525ccdd 100644 --- a/packages/builtin_commands/commands/sid.py +++ b/packages/builtin_commands/commands/sid.py @@ -7,7 +7,7 @@ class SIDCommand: """会话ID命令类""" - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context async def sid(self, event: AstrMessageEvent): diff --git a/packages/builtin_commands/commands/t2i.py b/packages/builtin_commands/commands/t2i.py index 7766b342f..76e5acf3f 100644 --- a/packages/builtin_commands/commands/t2i.py +++ b/packages/builtin_commands/commands/t2i.py @@ -7,7 +7,7 @@ class T2ICommand: """文本转图片命令类""" - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context async def t2i(self, event: AstrMessageEvent): diff --git a/packages/builtin_commands/commands/tool.py b/packages/builtin_commands/commands/tool.py index 9a6c507e6..bec9b53ac 100644 --- a/packages/builtin_commands/commands/tool.py +++ b/packages/builtin_commands/commands/tool.py @@ -3,7 +3,7 @@ class ToolCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context async def tool_ls(self, event: AstrMessageEvent): diff --git a/packages/builtin_commands/commands/tts.py b/packages/builtin_commands/commands/tts.py index d733ba1ea..22af09238 100644 --- a/packages/builtin_commands/commands/tts.py +++ b/packages/builtin_commands/commands/tts.py @@ -8,7 +8,7 @@ class TTSCommand: """文本转语音命令类""" - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context async def tts(self, event: AstrMessageEvent): diff --git a/packages/session_controller/main.py b/packages/session_controller/main.py index 9ea62ea30..3e2106a9e 100644 --- a/packages/session_controller/main.py +++ b/packages/session_controller/main.py @@ -17,7 +17,7 @@ class Main(Star): """会话控制""" - def __init__(self, context: Context): + def __init__(self, context: Context) -> None: super().__init__(context) @filter.event_message_type(filter.EventMessageType.ALL, priority=maxsize) diff --git a/tests/test_main.py b/tests/test_main.py index 0453a51ee..c70fe6865 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -12,7 +12,7 @@ class _version_info: - def __init__(self, major, minor): + def __init__(self, major, minor) -> None: self.major = major self.minor = minor From 39463e97797399eb28e485eec2de8858e457f44a Mon Sep 17 00:00:00 2001 From: Dale Null Date: Thu, 11 Dec 2025 15:53:03 +0800 Subject: [PATCH 25/44] =?UTF-8?q?refactor:=20=E4=B8=BA=E5=A4=A7=E9=87=8F?= =?UTF-8?q?=E5=87=BD=E6=95=B0=E6=B7=BB=E5=8A=A0=E8=BF=94=E5=9B=9E=20None?= =?UTF-8?q?=20=E7=9A=84=E7=B1=BB=E5=9E=8B=E6=B3=A8=E8=A7=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/cli/commands/cmd_conf.py | 6 +- astrbot/cli/commands/cmd_plug.py | 16 ++--- astrbot/cli/commands/cmd_run.py | 2 +- astrbot/cli/utils/plugin.py | 2 +- astrbot/core/agent/hooks.py | 8 +-- astrbot/core/agent/mcp_client.py | 4 +- .../agent/runners/coze/coze_api_client.py | 4 +- .../dashscope/dashscope_agent_runner.py | 2 +- .../agent/runners/dify/dify_api_client.py | 2 +- astrbot/core/agent/tool.py | 8 +-- astrbot/core/astr_agent_hooks.py | 4 +- astrbot/core/config/astrbot_config.py | 2 +- astrbot/core/conversation_mgr.py | 6 +- astrbot/core/db/__init__.py | 2 +- astrbot/core/db/migration/migra_3_to_4.py | 10 +-- astrbot/core/db/migration/migra_45_to_46.py | 2 +- .../db/migration/migra_webchat_session.py | 2 +- .../db/migration/shared_preferences_v3.py | 6 +- astrbot/core/db/migration/sqlite_v3.py | 14 ++-- astrbot/core/db/sqlite.py | 10 +-- astrbot/core/db/vec_db/base.py | 2 +- .../db/vec_db/faiss_impl/document_storage.py | 12 ++-- .../db/vec_db/faiss_impl/embedding_storage.py | 8 +-- astrbot/core/db/vec_db/faiss_impl/vec_db.py | 8 +-- astrbot/core/event_bus.py | 2 +- astrbot/core/initial_loader.py | 2 +- astrbot/core/knowledge_base/kb_db_sqlite.py | 2 +- astrbot/core/knowledge_base/kb_helper.py | 12 ++-- astrbot/core/knowledge_base/kb_mgr.py | 6 +- astrbot/core/log.py | 6 +- astrbot/core/message/components.py | 4 +- astrbot/core/persona_mgr.py | 4 +- .../pipeline/content_safety_check/stage.py | 2 +- astrbot/core/pipeline/respond/stage.py | 2 +- .../core/pipeline/result_decorate/stage.py | 2 +- astrbot/core/pipeline/scheduler.py | 4 +- astrbot/core/platform/astr_message_event.py | 20 +++--- astrbot/core/platform/manager.py | 10 +-- astrbot/core/platform/platform.py | 10 +-- .../aiocqhttp/aiocqhttp_message_event.py | 2 +- .../aiocqhttp/aiocqhttp_platform_adapter.py | 8 +-- .../sources/dingtalk/dingtalk_adapter.py | 10 +-- .../sources/dingtalk/dingtalk_event.py | 4 +- .../core/platform/sources/discord/client.py | 8 +-- .../discord/discord_platform_adapter.py | 10 +-- .../sources/discord/discord_platform_event.py | 4 +- .../platform/sources/lark/lark_adapter.py | 10 +-- .../core/platform/sources/lark/lark_event.py | 4 +- .../sources/misskey/misskey_adapter.py | 4 +- .../platform/sources/misskey/misskey_api.py | 8 +-- .../platform/sources/misskey/misskey_event.py | 2 +- .../platform/sources/misskey/misskey_utils.py | 4 +- .../qqofficial/qqofficial_message_event.py | 2 +- .../qqofficial/qqofficial_platform_adapter.py | 16 ++--- .../qqofficial_webhook/qo_webhook_adapter.py | 18 ++--- .../qqofficial_webhook/qo_webhook_server.py | 6 +- .../platform/sources/satori/satori_adapter.py | 18 ++--- .../platform/sources/satori/satori_event.py | 2 +- astrbot/core/platform/sources/slack/client.py | 10 +-- .../platform/sources/slack/slack_adapter.py | 6 +- .../platform/sources/slack/slack_event.py | 2 +- .../platform/sources/telegram/tg_adapter.py | 14 ++-- .../platform/sources/telegram/tg_event.py | 4 +- .../sources/webchat/webchat_adapter.py | 10 +-- .../platform/sources/webchat/webchat_event.py | 4 +- .../sources/webchat/webchat_queue_mgr.py | 2 +- .../wechatpadpro/wechatpadpro_adapter.py | 16 ++--- .../wechatpadpro_message_event.py | 2 +- .../platform/sources/wecom/wecom_adapter.py | 12 ++-- .../platform/sources/wecom/wecom_event.py | 2 +- .../sources/wecom_ai_bot/WXBizJsonMsgCrypt.py | 3 +- .../sources/wecom_ai_bot/wecomai_adapter.py | 10 +-- .../sources/wecom_ai_bot/wecomai_event.py | 4 +- .../sources/wecom_ai_bot/wecomai_queue_mgr.py | 6 +- .../sources/wecom_ai_bot/wecomai_server.py | 6 +- .../weixin_offacc_adapter.py | 12 ++-- .../weixin_offacc_event.py | 2 +- astrbot/core/platform_message_history_mgr.py | 2 +- astrbot/core/provider/entities.py | 2 +- astrbot/core/provider/manager.py | 12 ++-- astrbot/core/provider/provider.py | 20 +++--- .../core/provider/sources/anthropic_source.py | 2 +- .../core/provider/sources/gemini_source.py | 4 +- .../provider/sources/gsv_selfhosted_source.py | 4 +- .../core/provider/sources/openai_source.py | 2 +- .../sources/sensevoice_selfhosted_source.py | 2 +- .../sources/whisper_selfhosted_source.py | 2 +- .../sources/xinference_rerank_source.py | 2 +- .../sources/xinference_stt_provider.py | 2 +- astrbot/core/star/__init__.py | 4 +- astrbot/core/star/config.py | 4 +- astrbot/core/star/context.py | 8 +-- astrbot/core/star/filter/command.py | 4 +- astrbot/core/star/filter/command_group.py | 4 +- astrbot/core/star/star_handler.py | 6 +- astrbot/core/star/star_manager.py | 8 +-- astrbot/core/star/updator.py | 2 +- astrbot/core/umop_config_router.py | 6 +- astrbot/core/updator.py | 4 +- astrbot/core/utils/io.py | 6 +- astrbot/core/utils/log_pipe.py | 4 +- astrbot/core/utils/pip_installer.py | 2 +- astrbot/core/utils/session_waiter.py | 4 +- astrbot/core/utils/shared_preferences.py | 20 +++--- astrbot/core/utils/t2i/network_strategy.py | 4 +- astrbot/core/utils/t2i/renderer.py | 2 +- astrbot/core/utils/t2i/template_manager.py | 8 +-- astrbot/core/utils/webhook_utils.py | 2 +- astrbot/core/zip_updator.py | 9 +-- astrbot/dashboard/routes/config.py | 2 +- astrbot/dashboard/routes/route.py | 2 +- astrbot/dashboard/server.py | 2 +- main.py | 2 +- packages/astrbot/long_term_memory.py | 6 +- packages/astrbot/main.py | 6 +- packages/astrbot/process_llm_request.py | 2 +- packages/builtin_commands/commands/admin.py | 10 +-- .../builtin_commands/commands/alter_cmd.py | 4 +- .../builtin_commands/commands/conversation.py | 16 ++--- packages/builtin_commands/commands/help.py | 2 +- packages/builtin_commands/commands/llm.py | 2 +- packages/builtin_commands/commands/persona.py | 2 +- packages/builtin_commands/commands/plugin.py | 10 +-- .../builtin_commands/commands/provider.py | 6 +- .../builtin_commands/commands/setunset.py | 4 +- packages/builtin_commands/commands/sid.py | 2 +- packages/builtin_commands/commands/t2i.py | 2 +- packages/builtin_commands/commands/tool.py | 8 +-- packages/builtin_commands/commands/tts.py | 2 +- packages/builtin_commands/main.py | 72 +++++++++---------- packages/python_interpreter/main.py | 6 +- packages/python_interpreter/shared/api.py | 6 +- packages/reminder/main.py | 4 +- packages/session_controller/main.py | 2 +- packages/web_searcher/main.py | 6 +- tests/test_dashboard.py | 10 +-- tests/test_main.py | 10 +-- tests/test_plugin_manager.py | 16 ++--- tests/test_security_fixes.py | 12 ++-- 139 files changed, 451 insertions(+), 449 deletions(-) diff --git a/astrbot/cli/commands/cmd_conf.py b/astrbot/cli/commands/cmd_conf.py index a9bd40f00..703c9b899 100644 --- a/astrbot/cli/commands/cmd_conf.py +++ b/astrbot/cli/commands/cmd_conf.py @@ -127,7 +127,7 @@ def _get_nested_item(obj: dict[str, Any], path: str) -> Any: @click.group(name="conf") -def conf(): +def conf() -> None: """配置管理命令 支持的配置项: @@ -149,7 +149,7 @@ def conf(): @conf.command(name="set") @click.argument("key") @click.argument("value") -def set_config(key: str, value: str): +def set_config(key: str, value: str) -> None: """设置配置项的值""" if key not in CONFIG_VALIDATORS: raise click.ClickException(f"不支持的配置项: {key}") @@ -178,7 +178,7 @@ def set_config(key: str, value: str): @conf.command(name="get") @click.argument("key", required=False) -def get_config(key: str | None = None): +def get_config(key: str | None = None) -> None: """获取配置项的值,不提供key则显示所有可配置项""" config = _load_config() diff --git a/astrbot/cli/commands/cmd_plug.py b/astrbot/cli/commands/cmd_plug.py index a1099de1d..9cf94365a 100644 --- a/astrbot/cli/commands/cmd_plug.py +++ b/astrbot/cli/commands/cmd_plug.py @@ -15,7 +15,7 @@ @click.group() -def plug(): +def plug() -> None: """插件管理""" @@ -28,7 +28,7 @@ def _get_data_path() -> Path: return (base / "data").resolve() -def display_plugins(plugins, title=None, color=None): +def display_plugins(plugins, title=None, color=None) -> None: if title: click.echo(click.style(title, fg=color, bold=True)) @@ -45,7 +45,7 @@ def display_plugins(plugins, title=None, color=None): @plug.command() @click.argument("name") -def new(name: str): +def new(name: str) -> None: """创建新插件""" base_path = _get_data_path() plug_path = base_path / "plugins" / name @@ -100,7 +100,7 @@ def new(name: str): @plug.command() @click.option("--all", "-a", is_flag=True, help="列出未安装的插件") -def list(all: bool): +def list(all: bool) -> None: """列出插件""" base_path = _get_data_path() plugins = build_plug_list(base_path / "plugins") @@ -141,7 +141,7 @@ def list(all: bool): @plug.command() @click.argument("name") @click.option("--proxy", help="代理服务器地址") -def install(name: str, proxy: str | None): +def install(name: str, proxy: str | None) -> None: """安装插件""" base_path = _get_data_path() plug_path = base_path / "plugins" @@ -164,7 +164,7 @@ def install(name: str, proxy: str | None): @plug.command() @click.argument("name") -def remove(name: str): +def remove(name: str) -> None: """卸载插件""" base_path = _get_data_path() plugins = build_plug_list(base_path / "plugins") @@ -187,7 +187,7 @@ def remove(name: str): @plug.command() @click.argument("name", required=False) @click.option("--proxy", help="Github代理地址") -def update(name: str, proxy: str | None): +def update(name: str, proxy: str | None) -> None: """更新插件""" base_path = _get_data_path() plug_path = base_path / "plugins" @@ -225,7 +225,7 @@ def update(name: str, proxy: str | None): @plug.command() @click.argument("query") -def search(query: str): +def search(query: str) -> None: """搜索插件""" base_path = _get_data_path() plugins = build_plug_list(base_path / "plugins") diff --git a/astrbot/cli/commands/cmd_run.py b/astrbot/cli/commands/cmd_run.py index 9333f1b87..23665dff3 100644 --- a/astrbot/cli/commands/cmd_run.py +++ b/astrbot/cli/commands/cmd_run.py @@ -10,7 +10,7 @@ from ..utils import check_astrbot_root, check_dashboard, get_astrbot_root -async def run_astrbot(astrbot_root: Path): +async def run_astrbot(astrbot_root: Path) -> None: """运行 AstrBot""" from astrbot.core import LogBroker, LogManager, db_helper, logger from astrbot.core.initial_loader import InitialLoader diff --git a/astrbot/cli/utils/plugin.py b/astrbot/cli/utils/plugin.py index cd76a07c8..81f59e0bf 100644 --- a/astrbot/cli/utils/plugin.py +++ b/astrbot/cli/utils/plugin.py @@ -19,7 +19,7 @@ class PluginStatus(str, Enum): NOT_PUBLISHED = "未发布" -def get_git_repo(url: str, target_path: Path, proxy: str | None = None): +def get_git_repo(url: str, target_path: Path, proxy: str | None = None) -> None: """从 Git 仓库下载代码并解压到指定路径""" temp_dir = Path(tempfile.mkdtemp()) try: diff --git a/astrbot/core/agent/hooks.py b/astrbot/core/agent/hooks.py index d834240b7..74ca6335b 100644 --- a/astrbot/core/agent/hooks.py +++ b/astrbot/core/agent/hooks.py @@ -9,22 +9,22 @@ class BaseAgentRunHooks(Generic[TContext]): - async def on_agent_begin(self, run_context: ContextWrapper[TContext]): ... + async def on_agent_begin(self, run_context: ContextWrapper[TContext]) -> None: ... async def on_tool_start( self, run_context: ContextWrapper[TContext], tool: FunctionTool, tool_args: dict | None, - ): ... + ) -> None: ... async def on_tool_end( self, run_context: ContextWrapper[TContext], tool: FunctionTool, tool_args: dict | None, tool_result: mcp.types.CallToolResult | None, - ): ... + ) -> None: ... async def on_agent_done( self, run_context: ContextWrapper[TContext], llm_response: LLMResponse, - ): ... + ) -> None: ... diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index e02e4b973..84b6e83c9 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -126,7 +126,7 @@ def __init__(self) -> None: self._reconnect_lock = asyncio.Lock() # Lock for thread-safe reconnection self._reconnecting: bool = False # For logging and debugging - async def connect_to_server(self, mcp_server_config: dict, name: str): + async def connect_to_server(self, mcp_server_config: dict, name: str) -> None: """Connect to MCP server If `url` parameter exists: @@ -343,7 +343,7 @@ async def _call_with_retry(): return await _call_with_retry() - async def cleanup(self): + async def cleanup(self) -> None: """Clean up resources including old exit stacks from reconnections""" # Close current exit stack try: diff --git a/astrbot/core/agent/runners/coze/coze_api_client.py b/astrbot/core/agent/runners/coze/coze_api_client.py index 5124dc311..f5799dfbb 100644 --- a/astrbot/core/agent/runners/coze/coze_api_client.py +++ b/astrbot/core/agent/runners/coze/coze_api_client.py @@ -277,7 +277,7 @@ async def get_message_list( logger.error(f"获取Coze消息列表失败: {e!s}") raise Exception(f"获取Coze消息列表失败: {e!s}") - async def close(self): + async def close(self) -> None: """关闭会话""" if self.session: await self.session.close() @@ -288,7 +288,7 @@ async def close(self): import asyncio import os - async def test_coze_api_client(): + async def test_coze_api_client() -> None: api_key = os.getenv("COZE_API_KEY", "") bot_id = os.getenv("COZE_BOT_ID", "") client = CozeAPIClient(api_key=api_key) diff --git a/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py b/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py index 7a095a60b..1aaf6e3b9 100644 --- a/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py +++ b/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py @@ -67,7 +67,7 @@ async def reset( if isinstance(self.timeout, str): self.timeout = int(self.timeout) - def has_rag_options(self): + def has_rag_options(self) -> bool: """判断是否有 RAG 选项 Returns: diff --git a/astrbot/core/agent/runners/dify/dify_api_client.py b/astrbot/core/agent/runners/dify/dify_api_client.py index d07569779..26da6dfe9 100644 --- a/astrbot/core/agent/runners/dify/dify_api_client.py +++ b/astrbot/core/agent/runners/dify/dify_api_client.py @@ -155,7 +155,7 @@ async def file_upload( raise Exception(f"Dify 文件上传失败:{resp.status}. {text}") return await resp.json() # {"id": "xxx", ...} - async def close(self): + async def close(self) -> None: await self.session.close() async def get_chat_convs(self, user: str, limit: int = 20): diff --git a/astrbot/core/agent/tool.py b/astrbot/core/agent/tool.py index 306e5c081..98f302fad 100644 --- a/astrbot/core/agent/tool.py +++ b/astrbot/core/agent/tool.py @@ -82,7 +82,7 @@ def empty(self) -> bool: """Check if the tool set is empty.""" return len(self.tools) == 0 - def add_tool(self, tool: FunctionTool): + def add_tool(self, tool: FunctionTool) -> None: """Add a tool to the set.""" # 检查是否已存在同名工具 for i, existing_tool in enumerate(self.tools): @@ -91,7 +91,7 @@ def add_tool(self, tool: FunctionTool): return self.tools.append(tool) - def remove_tool(self, name: str): + def remove_tool(self, name: str) -> None: """Remove a tool by its name.""" self.tools = [tool for tool in self.tools if tool.name != name] @@ -109,7 +109,7 @@ def add_func( func_args: list, desc: str, handler: Callable[..., Awaitable[Any]], - ): + ) -> None: """Add a function tool to the set.""" params = { "type": "object", # hard-coded here @@ -129,7 +129,7 @@ def add_func( self.add_tool(_func) @deprecated(reason="Use remove_tool() instead", version="4.0.0") - def remove_func(self, name: str): + def remove_func(self, name: str) -> None: """Remove a function tool by its name.""" self.remove_tool(name) diff --git a/astrbot/core/astr_agent_hooks.py b/astrbot/core/astr_agent_hooks.py index f394fc947..39e247c6d 100644 --- a/astrbot/core/astr_agent_hooks.py +++ b/astrbot/core/astr_agent_hooks.py @@ -11,7 +11,7 @@ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]): - async def on_agent_done(self, run_context, llm_response): + async def on_agent_done(self, run_context, llm_response) -> None: # 执行事件钩子 await call_event_hook( run_context.context.event, @@ -25,7 +25,7 @@ async def on_tool_end( tool: FunctionTool[Any], tool_args: dict | None, tool_result: CallToolResult | None, - ): + ) -> None: run_context.context.event.clear_result() diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py index 8bb82b621..aa2ad4fee 100644 --- a/astrbot/core/config/astrbot_config.py +++ b/astrbot/core/config/astrbot_config.py @@ -146,7 +146,7 @@ def check_config_integrity(self, refer_conf: dict, conf: dict, path=""): return has_new - def save_config(self, replace_config: dict | None = None): + def save_config(self, replace_config: dict | None = None) -> None: """将配置写入文件 如果传入 replace_config,则将配置替换为 replace_config diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index 98f10cc1f..54043f339 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -105,7 +105,7 @@ async def new_conversation( await sp.session_put(unified_msg_origin, "sel_conv_id", conv.conversation_id) return conv.conversation_id - async def switch_conversation(self, unified_msg_origin: str, conversation_id: str): + async def switch_conversation(self, unified_msg_origin: str, conversation_id: str) -> None: """切换会话的对话 Args: @@ -120,7 +120,7 @@ async def delete_conversation( self, unified_msg_origin: str, conversation_id: str | None = None, - ): + ) -> None: """删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话 Args: @@ -137,7 +137,7 @@ async def delete_conversation( self.session_conversations.pop(unified_msg_origin, None) await sp.session_remove(unified_msg_origin, "sel_conv_id") - async def delete_conversations_by_user_id(self, unified_msg_origin: str): + async def delete_conversations_by_user_id(self, unified_msg_origin: str) -> None: """删除会话的所有对话 Args: diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index 44c69b209..b9fa3d77a 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -37,7 +37,7 @@ def __init__(self) -> None: expire_on_commit=False, ) - async def initialize(self): + async def initialize(self) -> None: """初始化数据库连接""" @asynccontextmanager diff --git a/astrbot/core/db/migration/migra_3_to_4.py b/astrbot/core/db/migration/migra_3_to_4.py index 66b72d5cb..727d97b29 100644 --- a/astrbot/core/db/migration/migra_3_to_4.py +++ b/astrbot/core/db/migration/migra_3_to_4.py @@ -43,7 +43,7 @@ def get_platform_type( async def migration_conversation_table( db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]], -): +) -> None: db_helper_v3 = SQLiteV3DatabaseV3( db_path=DB_PATH.replace("data_v4.db", "data_v3.db"), ) @@ -101,7 +101,7 @@ async def migration_conversation_table( async def migration_platform_table( db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]], -): +) -> None: db_helper_v3 = SQLiteV3DatabaseV3( db_path=DB_PATH.replace("data_v4.db", "data_v3.db"), ) @@ -180,7 +180,7 @@ async def migration_platform_table( async def migration_webchat_data( db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]], -): +) -> None: """迁移 WebChat 的历史记录到新的 PlatformMessageHistory 表中""" db_helper_v3 = SQLiteV3DatabaseV3( db_path=DB_PATH.replace("data_v4.db", "data_v3.db"), @@ -236,7 +236,7 @@ async def migration_webchat_data( async def migration_persona_data( db_helper: BaseDatabase, astrbot_config: AstrBotConfig, -): +) -> None: """迁移 Persona 数据到新的表中。 旧的 Persona 数据存储在 preference 中,新的 Persona 数据存储在 persona 表中。 """ @@ -279,7 +279,7 @@ async def migration_persona_data( async def migration_preferences( db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]], -): +) -> None: # 1. global scope migration keys = [ "inactivated_llm_tools", diff --git a/astrbot/core/db/migration/migra_45_to_46.py b/astrbot/core/db/migration/migra_45_to_46.py index dc70026f9..58736ab51 100644 --- a/astrbot/core/db/migration/migra_45_to_46.py +++ b/astrbot/core/db/migration/migra_45_to_46.py @@ -3,7 +3,7 @@ from astrbot.core.umop_config_router import UmopConfigRouter -async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter): +async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter) -> None: abconf_data = acm.abconf_data if not isinstance(abconf_data, dict): diff --git a/astrbot/core/db/migration/migra_webchat_session.py b/astrbot/core/db/migration/migra_webchat_session.py index ff0b5ca6f..46025fc64 100644 --- a/astrbot/core/db/migration/migra_webchat_session.py +++ b/astrbot/core/db/migration/migra_webchat_session.py @@ -17,7 +17,7 @@ from astrbot.core.db.po import ConversationV2, PlatformMessageHistory, PlatformSession -async def migrate_webchat_session(db_helper: BaseDatabase): +async def migrate_webchat_session(db_helper: BaseDatabase) -> None: """Create PlatformSession records from platform_message_history. This migration extracts all unique user_ids from platform_message_history diff --git a/astrbot/core/db/migration/shared_preferences_v3.py b/astrbot/core/db/migration/shared_preferences_v3.py index b74fe3e90..3d1771b8b 100644 --- a/astrbot/core/db/migration/shared_preferences_v3.py +++ b/astrbot/core/db/migration/shared_preferences_v3.py @@ -31,16 +31,16 @@ def _save_preferences(self): def get(self, key, default: _VT = None) -> _VT: return self._data.get(key, default) - def put(self, key, value): + def put(self, key, value) -> None: self._data[key] = value self._save_preferences() - def remove(self, key): + def remove(self, key) -> None: if key in self._data: del self._data[key] self._save_preferences() - def clear(self): + def clear(self) -> None: self._data.clear() self._save_preferences() diff --git a/astrbot/core/db/migration/sqlite_v3.py b/astrbot/core/db/migration/sqlite_v3.py index b1a780d48..faa7088f2 100644 --- a/astrbot/core/db/migration/sqlite_v3.py +++ b/astrbot/core/db/migration/sqlite_v3.py @@ -144,7 +144,7 @@ def _exec_sql(self, sql: str, params: tuple | None = None): conn.commit() - def insert_platform_metrics(self, metrics: dict): + def insert_platform_metrics(self, metrics: dict) -> None: for k, v in metrics.items(): self._exec_sql( """ @@ -153,7 +153,7 @@ def insert_platform_metrics(self, metrics: dict): (k, v, int(time.time())), ) - def insert_llm_metrics(self, metrics: dict): + def insert_llm_metrics(self, metrics: dict) -> None: for k, v in metrics.items(): self._exec_sql( """ @@ -249,7 +249,7 @@ def get_conversation_by_user_id( return Conversation(*res) - def new_conversation(self, user_id: str, cid: str): + def new_conversation(self, user_id: str, cid: str) -> None: history = "[]" updated_at = int(time.time()) created_at = updated_at @@ -287,7 +287,7 @@ def get_conversations(self, user_id: str) -> list[Conversation]: ) return conversations - def update_conversation(self, user_id: str, cid: str, history: str): + def update_conversation(self, user_id: str, cid: str, history: str) -> None: """更新对话,并且同时更新时间""" updated_at = int(time.time()) self._exec_sql( @@ -297,7 +297,7 @@ def update_conversation(self, user_id: str, cid: str, history: str): (history, updated_at, user_id, cid), ) - def update_conversation_title(self, user_id: str, cid: str, title: str): + def update_conversation_title(self, user_id: str, cid: str, title: str) -> None: self._exec_sql( """ UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ? @@ -305,7 +305,7 @@ def update_conversation_title(self, user_id: str, cid: str, title: str): (title, user_id, cid), ) - def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str): + def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str) -> None: self._exec_sql( """ UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ? @@ -313,7 +313,7 @@ def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str (persona_id, user_id, cid), ) - def delete_conversation(self, user_id: str, cid: str): + def delete_conversation(self, user_id: str, cid: str) -> None: self._exec_sql( """ DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ? diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index 033d076c8..02ef5ea6d 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -257,7 +257,7 @@ async def update_conversation(self, cid, title=None, persona_id=None, content=No await session.execute(query) return await self.get_conversation_by_id(cid) - async def delete_conversation(self, cid): + async def delete_conversation(self, cid) -> None: async with self.get_db() as session: session: AsyncSession async with session.begin(): @@ -413,7 +413,7 @@ async def delete_platform_message_offset( platform_id, user_id, offset_sec=86400, - ): + ) -> None: """Delete platform message history records newer than the specified offset.""" async with self.get_db() as session: session: AsyncSession @@ -586,7 +586,7 @@ async def update_persona( await session.execute(query) return await self.get_persona_by_id(persona_id) - async def delete_persona(self, persona_id): + async def delete_persona(self, persona_id) -> None: """Delete a persona by its ID.""" async with self.get_db() as session: session: AsyncSession @@ -643,7 +643,7 @@ async def get_preferences(self, scope, scope_id=None, key=None): result = await session.execute(query) return result.scalars().all() - async def remove_preference(self, scope, scope_id, key): + async def remove_preference(self, scope, scope_id, key) -> None: """Remove a preference by scope ID and key.""" async with self.get_db() as session: session: AsyncSession @@ -657,7 +657,7 @@ async def remove_preference(self, scope, scope_id, key): ) await session.commit() - async def clear_preferences(self, scope, scope_id): + async def clear_preferences(self, scope, scope_id) -> None: """Clear all preferences for a specific scope ID.""" async with self.get_db() as session: session: AsyncSession diff --git a/astrbot/core/db/vec_db/base.py b/astrbot/core/db/vec_db/base.py index 7440b6f2a..04f8903b1 100644 --- a/astrbot/core/db/vec_db/base.py +++ b/astrbot/core/db/vec_db/base.py @@ -9,7 +9,7 @@ class Result: class BaseVecDB: - async def initialize(self): + async def initialize(self) -> None: """初始化向量数据库""" @abc.abstractmethod diff --git a/astrbot/core/db/vec_db/faiss_impl/document_storage.py b/astrbot/core/db/vec_db/faiss_impl/document_storage.py index 1d6b8fa85..2adae69cc 100644 --- a/astrbot/core/db/vec_db/faiss_impl/document_storage.py +++ b/astrbot/core/db/vec_db/faiss_impl/document_storage.py @@ -43,7 +43,7 @@ def __init__(self, db_path: str) -> None: "sqlite_init.sql", ) - async def initialize(self): + async def initialize(self) -> None: """Initialize the SQLite database and create the documents table if it doesn't exist.""" await self.connect() async with self.engine.begin() as conn: # type: ignore @@ -80,7 +80,7 @@ async def initialize(self): await conn.commit() - async def connect(self): + async def connect(self) -> None: """Connect to the SQLite database.""" if self.engine is None: self.engine = create_async_engine( @@ -211,7 +211,7 @@ async def insert_documents_batch( await session.flush() # Flush to get all IDs return [doc.id for doc in documents] # type: ignore - async def delete_document_by_doc_id(self, doc_id: str): + async def delete_document_by_doc_id(self, doc_id: str) -> None: """Delete a document by its doc_id. Args: @@ -249,7 +249,7 @@ async def get_document_by_doc_id(self, doc_id: str): return self._document_to_dict(document) return None - async def update_document_by_doc_id(self, doc_id: str, new_text: str): + async def update_document_by_doc_id(self, doc_id: str, new_text: str) -> None: """Update a document by its doc_id. Args: @@ -269,7 +269,7 @@ async def update_document_by_doc_id(self, doc_id: str, new_text: str): document.updated_at = datetime.now() session.add(document) - async def delete_documents(self, metadata_filters: dict): + async def delete_documents(self, metadata_filters: dict) -> None: """Delete documents by their metadata filters. Args: @@ -384,7 +384,7 @@ async def tuple_to_dict(self, row): "updated_at": row[5], } - async def close(self): + async def close(self) -> None: """Close the connection to the SQLite database.""" if self.engine: await self.engine.dispose() diff --git a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py index 6e5197b9e..dc6977cf8 100644 --- a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +++ b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py @@ -20,7 +20,7 @@ def __init__(self, dimension: int, path: str | None = None) -> None: base_index = faiss.IndexFlatL2(dimension) self.index = faiss.IndexIDMap(base_index) - async def insert(self, vector: np.ndarray, id: int): + async def insert(self, vector: np.ndarray, id: int) -> None: """插入向量 Args: @@ -38,7 +38,7 @@ async def insert(self, vector: np.ndarray, id: int): self.index.add_with_ids(vector.reshape(1, -1), np.array([id])) await self.save_index() - async def insert_batch(self, vectors: np.ndarray, ids: list[int]): + async def insert_batch(self, vectors: np.ndarray, ids: list[int]) -> None: """批量插入向量 Args: @@ -71,7 +71,7 @@ async def search(self, vector: np.ndarray, k: int) -> tuple: distances, indices = self.index.search(vector, k) return distances, indices - async def delete(self, ids: list[int]): + async def delete(self, ids: list[int]) -> None: """删除向量 Args: @@ -83,7 +83,7 @@ async def delete(self, ids: list[int]): self.index.remove_ids(id_array) await self.save_index() - async def save_index(self): + async def save_index(self) -> None: """保存索引 Args: diff --git a/astrbot/core/db/vec_db/faiss_impl/vec_db.py b/astrbot/core/db/vec_db/faiss_impl/vec_db.py index 4b7f036ee..3fca246ef 100644 --- a/astrbot/core/db/vec_db/faiss_impl/vec_db.py +++ b/astrbot/core/db/vec_db/faiss_impl/vec_db.py @@ -32,7 +32,7 @@ def __init__( self.embedding_provider = embedding_provider self.rerank_provider = rerank_provider - async def initialize(self): + async def initialize(self) -> None: await self.document_storage.initialize() async def insert( @@ -165,7 +165,7 @@ async def retrieve( return top_k_results - async def delete(self, doc_id: str): + async def delete(self, doc_id: str) -> None: """删除一条文档块(chunk)""" # 获得对应的 int id result = await self.document_storage.get_document_by_doc_id(doc_id) @@ -177,7 +177,7 @@ async def delete(self, doc_id: str): await self.document_storage.delete_document_by_doc_id(doc_id) await self.embedding_storage.delete([int_id]) - async def close(self): + async def close(self) -> None: await self.document_storage.close() async def count_documents(self, metadata_filter: dict | None = None) -> int: @@ -192,7 +192,7 @@ async def count_documents(self, metadata_filter: dict | None = None) -> int: ) return count - async def delete_documents(self, metadata_filters: dict): + async def delete_documents(self, metadata_filters: dict) -> None: """根据元数据过滤器删除文档""" docs = await self.document_storage.get_documents( metadata_filters=metadata_filters, diff --git a/astrbot/core/event_bus.py b/astrbot/core/event_bus.py index 20bd47692..773940c59 100644 --- a/astrbot/core/event_bus.py +++ b/astrbot/core/event_bus.py @@ -34,7 +34,7 @@ def __init__( self.pipeline_scheduler_mapping = pipeline_scheduler_mapping self.astrbot_config_mgr = astrbot_config_mgr - async def dispatch(self): + async def dispatch(self) -> None: while True: event: AstrMessageEvent = await self.event_queue.get() conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin) diff --git a/astrbot/core/initial_loader.py b/astrbot/core/initial_loader.py index c3ee98ad3..3f836a4c4 100644 --- a/astrbot/core/initial_loader.py +++ b/astrbot/core/initial_loader.py @@ -23,7 +23,7 @@ def __init__(self, db: BaseDatabase, log_broker: LogBroker) -> None: self.log_broker = log_broker self.webui_dir: str | None = None - async def start(self): + async def start(self) -> None: core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db) try: diff --git a/astrbot/core/knowledge_base/kb_db_sqlite.py b/astrbot/core/knowledge_base/kb_db_sqlite.py index 5e1db842f..ba25ed7e5 100644 --- a/astrbot/core/knowledge_base/kb_db_sqlite.py +++ b/astrbot/core/knowledge_base/kb_db_sqlite.py @@ -253,7 +253,7 @@ async def get_document_with_metadata(self, doc_id: str) -> dict | None: "knowledge_base": row[1], } - async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB): + async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB) -> None: """删除单个文档及其相关数据""" # 在知识库表中删除 async with self.get_db() as session, session.begin(): diff --git a/astrbot/core/knowledge_base/kb_helper.py b/astrbot/core/knowledge_base/kb_helper.py index 4bfca90b3..e8ead1a5e 100644 --- a/astrbot/core/knowledge_base/kb_helper.py +++ b/astrbot/core/knowledge_base/kb_helper.py @@ -130,7 +130,7 @@ def __init__( self.kb_medias_dir.mkdir(parents=True, exist_ok=True) self.kb_files_dir.mkdir(parents=True, exist_ok=True) - async def initialize(self): + async def initialize(self) -> None: await self._ensure_vec_db() async def get_ep(self) -> EmbeddingProvider: @@ -174,7 +174,7 @@ async def _ensure_vec_db(self) -> FaissVecDB: self.vec_db = vec_db return vec_db - async def delete_vec_db(self): + async def delete_vec_db(self) -> None: """删除知识库的向量数据库和所有相关文件""" import shutil @@ -182,7 +182,7 @@ async def delete_vec_db(self): if self.kb_dir.exists(): shutil.rmtree(self.kb_dir) - async def terminate(self): + async def terminate(self) -> None: if self.vec_db: await self.vec_db.close() @@ -360,7 +360,7 @@ async def get_document(self, doc_id: str) -> KBDocument | None: doc = await self.kb_db.get_document_by_id(doc_id) return doc - async def delete_document(self, doc_id: str): + async def delete_document(self, doc_id: str) -> None: """删除单个文档及其相关数据""" await self.kb_db.delete_document_by_id( doc_id=doc_id, @@ -372,7 +372,7 @@ async def delete_document(self, doc_id: str): ) await self.refresh_kb() - async def delete_chunk(self, chunk_id: str, doc_id: str): + async def delete_chunk(self, chunk_id: str, doc_id: str) -> None: """删除单个文本块及其相关数据""" vec_db: FaissVecDB = self.vec_db # type: ignore await vec_db.delete(chunk_id) @@ -383,7 +383,7 @@ async def delete_chunk(self, chunk_id: str, doc_id: str): await self.refresh_kb() await self.refresh_document(doc_id) - async def refresh_kb(self): + async def refresh_kb(self) -> None: if self.kb: kb = await self.kb_db.get_kb_by_id(self.kb.kb_id) if kb: diff --git a/astrbot/core/knowledge_base/kb_mgr.py b/astrbot/core/knowledge_base/kb_mgr.py index 14b97ac4d..9777fea0c 100644 --- a/astrbot/core/knowledge_base/kb_mgr.py +++ b/astrbot/core/knowledge_base/kb_mgr.py @@ -33,7 +33,7 @@ def __init__( self.kb_insts: dict[str, KBHelper] = {} - async def initialize(self): + async def initialize(self) -> None: """初始化知识库模块""" try: logger.info("正在初始化知识库模块...") @@ -64,7 +64,7 @@ async def _init_kb_database(self): await self.kb_db.migrate_to_v1() logger.info(f"KnowledgeBase database initialized: {DB_PATH}") - async def load_kbs(self): + async def load_kbs(self) -> None: """加载所有知识库实例""" kb_records = await self.kb_db.list_kbs() for record in kb_records: @@ -268,7 +268,7 @@ def _format_context(self, results: list[RetrievalResult]) -> str: return "\n".join(lines) - async def terminate(self): + async def terminate(self) -> None: """终止所有知识库实例,关闭数据库连接""" for kb_id, kb_helper in self.kb_insts.items(): try: diff --git a/astrbot/core/log.py b/astrbot/core/log.py index b338cb512..db093f02d 100644 --- a/astrbot/core/log.py +++ b/astrbot/core/log.py @@ -101,7 +101,7 @@ def register(self) -> Queue: self.subscribers.append(q) return q - def unregister(self, q: Queue): + def unregister(self, q: Queue) -> None: """取消订阅 Args: @@ -110,7 +110,7 @@ def unregister(self, q: Queue): """ self.subscribers.remove(q) - def publish(self, log_entry: dict): + def publish(self, log_entry: dict) -> None: """发布新日志到所有订阅者, 使用非阻塞方式投递, 避免一个订阅者阻塞整个系统 Args: @@ -136,7 +136,7 @@ def __init__(self, log_broker: LogBroker) -> None: super().__init__() self.log_broker = log_broker - def emit(self, record): + def emit(self, record) -> None: """日志处理的入口方法, 接受一个日志记录, 转换为字符串后由 LogBroker 发布 这个方法会在每次日志记录时被调用 diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 778a7f79f..e8b7f38a1 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -255,7 +255,7 @@ async def convert_to_file_path(self) -> str: return os.path.abspath(url) raise Exception(f"not a valid file: {url}") - async def register_to_file_service(self): + async def register_to_file_service(self) -> str: """将视频注册到文件服务。 Returns: @@ -737,7 +737,7 @@ async def _download_file(self): await download_file(self.url, file_path) self.file_ = os.path.abspath(file_path) - async def register_to_file_service(self): + async def register_to_file_service(self) -> str: """将文件注册到文件服务。 Returns: diff --git a/astrbot/core/persona_mgr.py b/astrbot/core/persona_mgr.py index 52c9b1332..8b32a1d0e 100644 --- a/astrbot/core/persona_mgr.py +++ b/astrbot/core/persona_mgr.py @@ -28,7 +28,7 @@ def __init__(self, db_helper: BaseDatabase, acm: AstrBotConfigManager) -> None: self.selected_default_persona_v3: Personality | None = None self.persona_v3_config: list[dict] = [] - async def initialize(self): + async def initialize(self) -> None: self.personas = await self.get_all_personas() self.get_v3_persona_data() logger.info(f"已加载 {len(self.personas)} 个人格。") @@ -57,7 +57,7 @@ async def get_default_persona_v3( except Exception: return DEFAULT_PERSONALITY - async def delete_persona(self, persona_id: str): + async def delete_persona(self, persona_id: str) -> None: """删除指定 persona""" if not await self.db.get_persona_by_id(persona_id): raise ValueError(f"Persona with ID {persona_id} does not exist.") diff --git a/astrbot/core/pipeline/content_safety_check/stage.py b/astrbot/core/pipeline/content_safety_check/stage.py index b089c48e0..19037eb08 100644 --- a/astrbot/core/pipeline/content_safety_check/stage.py +++ b/astrbot/core/pipeline/content_safety_check/stage.py @@ -16,7 +16,7 @@ class ContentSafetyCheckStage(Stage): 当前只会检查文本的。 """ - async def initialize(self, ctx: PipelineContext): + async def initialize(self, ctx: PipelineContext) -> None: config = ctx.astrbot_config["content_safety"] self.strategy_selector = StrategySelector(config) diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index 8f1b87efc..b653b0a81 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -35,7 +35,7 @@ class RespondStage(Stage): Comp.WechatEmoji: lambda comp: comp.md5 is not None, # 微信表情 } - async def initialize(self, ctx: PipelineContext): + async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx self.config = ctx.astrbot_config self.platform_settings: dict = self.config.get("platform_settings", {}) diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index 208f3a9f2..89d6b9745 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -19,7 +19,7 @@ @register_stage class ResultDecorateStage(Stage): - async def initialize(self, ctx: PipelineContext): + async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx self.reply_prefix = ctx.astrbot_config["platform_settings"]["reply_prefix"] self.reply_with_mention = ctx.astrbot_config["platform_settings"][ diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py index b30c9ddb1..0fe939d1d 100644 --- a/astrbot/core/pipeline/scheduler.py +++ b/astrbot/core/pipeline/scheduler.py @@ -22,7 +22,7 @@ def __init__(self, context: PipelineContext) -> None: self.ctx = context # 上下文对象 self.stages = [] # 存储阶段实例 - async def initialize(self): + async def initialize(self) -> None: """初始化管道调度器时, 初始化所有阶段""" for stage_cls in registered_stages: stage_instance = stage_cls() # 创建实例 @@ -72,7 +72,7 @@ async def _process_stages(self, event: AstrMessageEvent, from_stage=0): logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。") break - async def execute(self, event: AstrMessageEvent): + async def execute(self, event: AstrMessageEvent) -> None: """执行 pipeline Args: diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index b215fcd4c..d52304814 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -157,7 +157,7 @@ def get_sender_name(self) -> str: return self.message_obj.sender.nickname return "" - def set_extra(self, key, value): + def set_extra(self, key, value) -> None: """设置额外的信息。""" self._extras[key] = value @@ -167,7 +167,7 @@ def get_extra(self, key: str | None = None, default=None) -> Any: return self._extras return self._extras.get(key, default) - def clear_extra(self): + def clear_extra(self) -> None: """清除额外的信息。""" logger.info(f"清除 {self.get_platform_name()} 的额外信息: {self._extras}") self._extras.clear() @@ -200,7 +200,7 @@ async def send_streaming( self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False, - ): + ) -> None: """发送流式消息到消息平台,使用异步生成器。 目前仅支持: telegram,qq official 私聊。 Fallback仅支持 aiocqhttp。 @@ -216,7 +216,7 @@ async def _pre_send(self): async def _post_send(self): """调度器会在执行 send() 后调用该方法 deprecated in v3.5.18""" - def set_result(self, result: MessageEventResult | str): + def set_result(self, result: MessageEventResult | str) -> None: """设置消息事件的结果。 Note: @@ -245,14 +245,14 @@ async def check_count(self, event: AstrMessageEvent): result.chain = [] self._result = result - def stop_event(self): + def stop_event(self) -> None: """终止事件传播。""" if self._result is None: self.set_result(MessageEventResult().stop_event()) else: self._result.stop_event() - def continue_event(self): + def continue_event(self) -> None: """继续事件传播。""" if self._result is None: self.set_result(MessageEventResult().continue_event()) @@ -265,7 +265,7 @@ def is_stopped(self) -> bool: return False # 默认是继续传播 return self._result.is_stopped() - def should_call_llm(self, call_llm: bool): + def should_call_llm(self, call_llm: bool) -> None: """是否在此消息事件中禁止默认的 LLM 请求。 只会阻止 AstrBot 默认的 LLM 请求链路,不会阻止插件中的 LLM 请求。 @@ -276,7 +276,7 @@ def get_result(self) -> MessageEventResult | None: """获取消息事件的结果。""" return self._result - def clear_result(self): + def clear_result(self) -> None: """清除消息事件的结果。""" self._result = None @@ -368,7 +368,7 @@ def request_llm( """平台适配器""" - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: """发送消息到消息平台。 Args: @@ -387,7 +387,7 @@ async def send(self, message: MessageChain): ) self._has_send_oper = True - async def react(self, emoji: str): + async def react(self, emoji: str) -> None: """对消息添加表情回应。 默认实现为发送一条包含该表情的消息。 diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index 92440c10f..9dfff0882 100644 --- a/astrbot/core/platform/manager.py +++ b/astrbot/core/platform/manager.py @@ -25,7 +25,7 @@ def __init__(self, config: AstrBotConfig, event_queue: Queue) -> None: 约定整个项目中对 unique_session 的引用都从 default 的配置中获取""" self.event_queue = event_queue - async def initialize(self): + async def initialize(self) -> None: """初始化所有平台适配器""" for platform in self.platforms_config: try: @@ -43,7 +43,7 @@ async def initialize(self): ), ) - async def load_platform(self, platform_config: dict): + async def load_platform(self, platform_config: dict) -> None: """实例化一个平台""" # 动态导入 try: @@ -171,7 +171,7 @@ async def _task_wrapper(self, task: asyncio.Task, platform: Platform | None = No if platform: platform.record_error(error_msg, tb_str) - async def reload(self, platform_config: dict): + async def reload(self, platform_config: dict) -> None: await self.terminate_platform(platform_config["id"]) if platform_config["enable"]: await self.load_platform(platform_config) @@ -182,7 +182,7 @@ async def reload(self, platform_config: dict): if key not in config_ids: await self.terminate_platform(key) - async def terminate_platform(self, platform_id: str): + async def terminate_platform(self, platform_id: str) -> None: if platform_id in self._inst_map: logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...") @@ -204,7 +204,7 @@ async def terminate_platform(self, platform_id: str): if getattr(inst, "terminate", None): await inst.terminate() - async def terminate(self): + async def terminate(self) -> None: for inst in self.platform_insts: if getattr(inst, "terminate", None): await inst.terminate() diff --git a/astrbot/core/platform/platform.py b/astrbot/core/platform/platform.py index 8a698aeed..46e992543 100644 --- a/astrbot/core/platform/platform.py +++ b/astrbot/core/platform/platform.py @@ -69,12 +69,12 @@ def last_error(self) -> PlatformError | None: """获取最近的错误""" return self._errors[-1] if self._errors else None - def record_error(self, message: str, traceback_str: str | None = None): + def record_error(self, message: str, traceback_str: str | None = None) -> None: """记录一个错误""" self._errors.append(PlatformError(message=message, traceback=traceback_str)) self._status = PlatformStatus.ERROR - def clear_errors(self): + def clear_errors(self) -> None: """清除错误记录""" self._errors.clear() if self._status == PlatformStatus.ERROR: @@ -104,7 +104,7 @@ def run(self) -> Coroutine[Any, Any, None]: """得到一个平台的运行实例,需要返回一个协程对象。""" raise NotImplementedError - async def terminate(self): + async def terminate(self) -> None: """终止一个平台的运行实例。""" @abc.abstractmethod @@ -123,11 +123,11 @@ async def send_by_session( """ await Metric.upload(msg_event_tick=1, adapter_name=self.meta().name) - def commit_event(self, event: AstrMessageEvent): + def commit_event(self, event: AstrMessageEvent) -> None: """提交一个事件到事件队列。""" self._event_queue.put_nowait(event) - def get_client(self): + def get_client(self) -> None: """获取平台的客户端对象。""" async def webhook_callback(self, request: Any) -> Any: diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index ff7716b2c..87177d6fd 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -143,7 +143,7 @@ async def send_message( await cls._dispatch_send(bot, event, is_group, session_id, messages) await asyncio.sleep(0.5) - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: """发送消息""" event = getattr(self.message_obj, "raw_message", None) diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index b3c2229ab..6fc02d609 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -93,7 +93,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: is_group = session.message_type == MessageType.GROUP_MESSAGE if is_group: session_id = session.session_id.split("_")[-1] @@ -416,17 +416,17 @@ def run(self) -> Awaitable[Any]: self.shutdown_event = asyncio.Event() return coro - async def terminate(self): + async def terminate(self) -> None: self.shutdown_event.set() - async def shutdown_trigger_placeholder(self): + async def shutdown_trigger_placeholder(self) -> None: await self.shutdown_event.wait() logger.info("aiocqhttp 适配器已被优雅地关闭") def meta(self) -> PlatformMetadata: return self.metadata - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: message_event = AiocqhttpMessageEvent( message_str=message.message_str, message_obj=message, diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py index 8905698a5..8e06660ab 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py @@ -2,7 +2,7 @@ import os import threading import uuid -from typing import cast +from typing import cast, NoReturn import aiohttp import dingtalk_stream @@ -90,7 +90,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> NoReturn: raise NotImplementedError("钉钉机器人适配器不支持 send_by_session") def meta(self) -> PlatformMetadata: @@ -222,7 +222,7 @@ async def get_access_token(self) -> str: return "" return (await resp.json())["data"]["accessToken"] - async def handle_msg(self, abm: AstrBotMessage): + async def handle_msg(self, abm: AstrBotMessage) -> None: event = DingtalkMessageEvent( message_str=abm.message_str, message_obj=abm, @@ -233,7 +233,7 @@ async def handle_msg(self, abm: AstrBotMessage): self._event_queue.put_nowait(event) - async def run(self): + async def run(self) -> None: # await self.client_.start() # 钉钉的 SDK 并没有实现真正的异步,start() 里面有堵塞方法。 def start_client(loop: asyncio.AbstractEventLoop): @@ -252,7 +252,7 @@ def start_client(loop: asyncio.AbstractEventLoop): loop = asyncio.get_event_loop() await loop.run_in_executor(None, start_client, loop) - async def terminate(self): + async def terminate(self) -> None: def monkey_patch_close(): raise KeyboardInterrupt("Graceful shutdown") diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py index eb07dffe9..db027125a 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py @@ -24,7 +24,7 @@ async def send_with_client( self, client: dingtalk_stream.ChatbotHandler, message: MessageChain, - ): + ) -> None: for segment in message.chain: if isinstance(segment, Comp.Plain): segment.text = segment.text.strip() @@ -64,7 +64,7 @@ async def send_with_client( logger.warning(f"钉钉图片处理失败: {e}, 跳过图片发送") continue - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: await self.send_with_client(self.client, message) await super().send(message) diff --git a/astrbot/core/platform/sources/discord/client.py b/astrbot/core/platform/sources/discord/client.py index b9b33e17b..ebd32c471 100644 --- a/astrbot/core/platform/sources/discord/client.py +++ b/astrbot/core/platform/sources/discord/client.py @@ -32,7 +32,7 @@ def __init__(self, token: str, proxy: str | None = None) -> None: self.on_ready_once_callback: Callable[[], Awaitable[None]] | None = None self._ready_once_fired = False - async def on_ready(self): + async def on_ready(self) -> None: """当机器人成功连接并准备就绪时触发""" if self.user is None: logger.error("[Discord] 客户端未正确加载用户信息 (self.user is None)") @@ -93,7 +93,7 @@ def _create_interaction_data(self, interaction: discord.Interaction) -> dict: "type": "interaction", } - async def on_message(self, message: discord.Message): + async def on_message(self, message: discord.Message) -> None: """当接收到消息时触发""" if message.author.bot: return @@ -130,12 +130,12 @@ def _extract_interaction_content(self, interaction: discord.Interaction) -> str: return str(interaction_data) - async def start_polling(self): + async def start_polling(self) -> None: """开始轮询消息,这是个阻塞方法""" await self.start(self.token) @override - async def close(self): + async def close(self) -> None: """关闭客户端""" if not self.is_closed(): await super().close() diff --git a/astrbot/core/platform/sources/discord/discord_platform_adapter.py b/astrbot/core/platform/sources/discord/discord_platform_adapter.py index 50aa0fe6f..fb8ce972d 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_adapter.py +++ b/astrbot/core/platform/sources/discord/discord_platform_adapter.py @@ -60,7 +60,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: """通过会话发送消息""" if self.client.user is None: logger.error( @@ -122,7 +122,7 @@ def meta(self) -> PlatformMetadata: ) @override - async def run(self): + async def run(self) -> None: """主要运行逻辑""" # 初始化回调函数 @@ -251,7 +251,7 @@ async def convert_message(self, data: dict) -> AstrBotMessage: # 由于 on_interaction 已被禁用,我们只处理普通消息 return self._convert_message_to_abm(data) - async def handle_msg(self, message: AstrBotMessage, followup_webhook=None): + async def handle_msg(self, message: AstrBotMessage, followup_webhook=None) -> None: """处理消息""" message_event = DiscordPlatformEvent( message_str=message.message_str, @@ -323,7 +323,7 @@ async def handle_msg(self, message: AstrBotMessage, followup_webhook=None): self.commit_event(message_event) @override - async def terminate(self): + async def terminate(self) -> None: """终止适配器""" logger.info("[Discord] 正在终止适配器... (step 1: cancel polling task)") self.shutdown_event.set() @@ -358,7 +358,7 @@ async def terminate(self): logger.warning(f"[Discord] 客户端关闭异常: {e}") logger.info("[Discord] 适配器已终止。") - def register_handler(self, handler_info): + def register_handler(self, handler_info) -> None: """注册处理器信息""" self.registered_handlers.append(handler_info) diff --git a/astrbot/core/platform/sources/discord/discord_platform_event.py b/astrbot/core/platform/sources/discord/discord_platform_event.py index 42459c1b2..02d4dae86 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_event.py +++ b/astrbot/core/platform/sources/discord/discord_platform_event.py @@ -46,7 +46,7 @@ def __init__( self.client = client self.interaction_followup_webhook = interaction_followup_webhook - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: """发送消息到Discord平台""" # 解析消息链为 Discord 所需的对象 try: @@ -267,7 +267,7 @@ async def _parse_to_discord( content = content[:2000] return content, files, view, embeds, reference_message_id - async def react(self, emoji: str): + async def react(self, emoji: str) -> None: """对原消息添加反应""" try: if hasattr(self.message_obj, "raw_message") and hasattr( diff --git a/astrbot/core/platform/sources/lark/lark_adapter.py b/astrbot/core/platform/sources/lark/lark_adapter.py index 473be096f..472d34e7f 100644 --- a/astrbot/core/platform/sources/lark/lark_adapter.py +++ b/astrbot/core/platform/sources/lark/lark_adapter.py @@ -78,7 +78,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: if self.lark_api.im is None: logger.error("[Lark] API Client im 模块未初始化,无法发送消息") return @@ -127,7 +127,7 @@ def meta(self) -> PlatformMetadata: support_streaming_message=False, ) - async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1): + async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1) -> None: if event.event is None: logger.debug("[Lark] 收到空事件(event.event is None)") return @@ -279,7 +279,7 @@ async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1): logger.debug(abm) await self.handle_msg(abm) - async def handle_msg(self, abm: AstrBotMessage): + async def handle_msg(self, abm: AstrBotMessage) -> None: event = LarkMessageEvent( message_str=abm.message_str, message_obj=abm, @@ -290,11 +290,11 @@ async def handle_msg(self, abm: AstrBotMessage): self._event_queue.put_nowait(event) - async def run(self): + async def run(self) -> None: # self.client.start() await self.client._connect() - async def terminate(self): + async def terminate(self) -> None: await self.client._disconnect() logger.info("飞书(Lark) 适配器已被优雅地关闭") diff --git a/astrbot/core/platform/sources/lark/lark_event.py b/astrbot/core/platform/sources/lark/lark_event.py index dc952da45..b6c7d5258 100644 --- a/astrbot/core/platform/sources/lark/lark_event.py +++ b/astrbot/core/platform/sources/lark/lark_event.py @@ -110,7 +110,7 @@ async def _convert_to_lark(message: MessageChain, lark_client: lark.Client) -> l ret.append(_stage) return ret - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: res = await LarkMessageEvent._convert_to_lark(message, self.bot) wrapped = { "zh_cn": { @@ -144,7 +144,7 @@ async def send(self, message: MessageChain): await super().send(message) - async def react(self, emoji: str): + async def react(self, emoji: str) -> None: if self.bot.im is None: logger.error("[Lark] API Client im 模块未初始化,无法发送表情") return diff --git a/astrbot/core/platform/sources/misskey/misskey_adapter.py b/astrbot/core/platform/sources/misskey/misskey_adapter.py index 7f3db3062..aab2fe802 100644 --- a/astrbot/core/platform/sources/misskey/misskey_adapter.py +++ b/astrbot/core/platform/sources/misskey/misskey_adapter.py @@ -123,7 +123,7 @@ def meta(self) -> PlatformMetadata: support_streaming_message=False, ) - async def run(self): + async def run(self) -> None: if not self.instance_url or not self.access_token: logger.error("[Misskey] 配置不完整,无法启动") return @@ -759,7 +759,7 @@ async def convert_room_message(self, raw_data: dict[str, Any]) -> AstrBotMessage ) return message - async def terminate(self): + async def terminate(self) -> None: self._running = False if self.api: await self.api.close() diff --git a/astrbot/core/platform/sources/misskey/misskey_api.py b/astrbot/core/platform/sources/misskey/misskey_api.py index 356e9bc3e..0a2acedda 100644 --- a/astrbot/core/platform/sources/misskey/misskey_api.py +++ b/astrbot/core/platform/sources/misskey/misskey_api.py @@ -90,7 +90,7 @@ async def connect(self) -> bool: self.is_connected = False return False - async def disconnect(self): + async def disconnect(self) -> None: self._running = False if self.websocket: await self.websocket.close() @@ -116,7 +116,7 @@ async def subscribe_channel( self.channels[channel_id] = channel_type return channel_id - async def unsubscribe_channel(self, channel_id: str): + async def unsubscribe_channel(self, channel_id: str) -> None: if ( not self.is_connected or not self.websocket @@ -136,10 +136,10 @@ def add_message_handler( self, event_type: str, handler: Callable[[dict], Awaitable[None]], - ): + ) -> None: self.message_handlers[event_type] = handler - async def listen(self): + async def listen(self) -> None: if not self.is_connected or not self.websocket: raise WebSocketError("WebSocket 未连接") diff --git a/astrbot/core/platform/sources/misskey/misskey_event.py b/astrbot/core/platform/sources/misskey/misskey_event.py index 77ed58376..068f7e7a2 100644 --- a/astrbot/core/platform/sources/misskey/misskey_event.py +++ b/astrbot/core/platform/sources/misskey/misskey_event.py @@ -40,7 +40,7 @@ def _is_system_command(self, message_str: str) -> bool: return any(message_trimmed.startswith(prefix) for prefix in system_prefixes) - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: """发送消息,使用适配器的完整上传和发送逻辑""" try: logger.debug( diff --git a/astrbot/core/platform/sources/misskey/misskey_utils.py b/astrbot/core/platform/sources/misskey/misskey_utils.py index 290acd64e..e0ed85d19 100644 --- a/astrbot/core/platform/sources/misskey/misskey_utils.py +++ b/astrbot/core/platform/sources/misskey/misskey_utils.py @@ -406,7 +406,7 @@ def cache_user_info( raw_data: dict[str, Any], client_self_id: str, is_chat: bool = False, -): +) -> None: """缓存用户信息""" if is_chat: user_cache_data = { @@ -432,7 +432,7 @@ def cache_room_info( user_cache: dict[str, Any], raw_data: dict[str, Any], client_self_id: str, -): +) -> None: """缓存房间信息""" room_data = raw_data.get("toRoom") room_id = raw_data.get("toRoomId") diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py index 2d1f9f46e..2da3cb4b7 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py @@ -37,7 +37,7 @@ def __init__( self.bot = bot self.send_buffer = None - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: self.send_buffer = message await self._post_send() diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py index 2a1bcda47..550c4b6fa 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py @@ -4,7 +4,7 @@ import logging import os import time -from typing import cast +from typing import cast, NoReturn import botpy import botpy.message @@ -35,11 +35,11 @@ # QQ 机器人官方框架 class botClient(Client): - def set_platform(self, platform: QQOfficialPlatformAdapter): + def set_platform(self, platform: QQOfficialPlatformAdapter) -> None: self.platform = platform # 收到群消息 - async def on_group_at_message_create(self, message: botpy.message.GroupMessage): + async def on_group_at_message_create(self, message: botpy.message.GroupMessage) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.GROUP_MESSAGE, @@ -52,7 +52,7 @@ async def on_group_at_message_create(self, message: botpy.message.GroupMessage): self._commit(abm) # 收到频道消息 - async def on_at_message_create(self, message: botpy.message.Message): + async def on_at_message_create(self, message: botpy.message.Message) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.GROUP_MESSAGE, @@ -63,7 +63,7 @@ async def on_at_message_create(self, message: botpy.message.Message): self._commit(abm) # 收到私聊消息 - async def on_direct_message_create(self, message: botpy.message.DirectMessage): + async def on_direct_message_create(self, message: botpy.message.DirectMessage) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.FRIEND_MESSAGE, @@ -72,7 +72,7 @@ async def on_direct_message_create(self, message: botpy.message.DirectMessage): self._commit(abm) # 收到 C2C 消息 - async def on_c2c_message_create(self, message: botpy.message.C2CMessage): + async def on_c2c_message_create(self, message: botpy.message.C2CMessage) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.FRIEND_MESSAGE, @@ -133,7 +133,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> NoReturn: raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session") def meta(self) -> PlatformMetadata: @@ -226,6 +226,6 @@ def run(self): def get_client(self) -> botClient: return self.client - async def terminate(self): + async def terminate(self) -> None: await self.client.close() logger.info("QQ 官方机器人接口 适配器已被优雅地关闭") diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py index 63b6726fe..634bc2de1 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py @@ -1,6 +1,6 @@ import asyncio import logging -from typing import Any, cast +from typing import Any, cast, NoReturn import botpy import botpy.message @@ -26,11 +26,11 @@ # QQ 机器人官方框架 class botClient(Client): - def set_platform(self, platform: "QQOfficialWebhookPlatformAdapter"): + def set_platform(self, platform: "QQOfficialWebhookPlatformAdapter") -> None: self.platform = platform # 收到群消息 - async def on_group_at_message_create(self, message: botpy.message.GroupMessage): + async def on_group_at_message_create(self, message: botpy.message.GroupMessage) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.GROUP_MESSAGE, @@ -43,7 +43,7 @@ async def on_group_at_message_create(self, message: botpy.message.GroupMessage): self._commit(abm) # 收到频道消息 - async def on_at_message_create(self, message: botpy.message.Message): + async def on_at_message_create(self, message: botpy.message.Message) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.GROUP_MESSAGE, @@ -54,7 +54,7 @@ async def on_at_message_create(self, message: botpy.message.Message): self._commit(abm) # 收到私聊消息 - async def on_direct_message_create(self, message: botpy.message.DirectMessage): + async def on_direct_message_create(self, message: botpy.message.DirectMessage) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.FRIEND_MESSAGE, @@ -63,7 +63,7 @@ async def on_direct_message_create(self, message: botpy.message.DirectMessage): self._commit(abm) # 收到 C2C 消息 - async def on_c2c_message_create(self, message: botpy.message.C2CMessage): + async def on_c2c_message_create(self, message: botpy.message.C2CMessage) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.FRIEND_MESSAGE, @@ -115,7 +115,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> NoReturn: raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session") def meta(self) -> PlatformMetadata: @@ -125,7 +125,7 @@ def meta(self) -> PlatformMetadata: id=cast(str, self.config.get("id")), ) - async def run(self): + async def run(self) -> None: self.webhook_helper = QQOfficialWebhook( self.config, self._event_queue, @@ -153,7 +153,7 @@ async def webhook_callback(self, request: Any) -> Any: # 复用 webhook_helper 的回调处理逻辑 return await self.webhook_helper.handle_callback(request) - async def terminate(self): + async def terminate(self) -> None: if self.webhook_helper: self.webhook_helper.shutdown_event.set() await self.client.close() diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py index f8ee930ff..643263726 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py @@ -38,7 +38,7 @@ def __init__(self, config: dict, event_queue: asyncio.Queue, botpy_client: Clien self.event_queue = event_queue self.shutdown_event = asyncio.Event() - async def initialize(self): + async def initialize(self) -> None: logger.info("正在登录到 QQ 官方机器人...") self.user = await self.http.login(self.token) logger.info(f"已登录 QQ 官方机器人账号: {self.user}") @@ -115,7 +115,7 @@ async def handle_callback(self, request) -> dict: return {"opcode": 12} - async def start_polling(self): + async def start_polling(self) -> None: logger.info( f"将在 {self.callback_server_host}:{self.port} 端口启动 QQ 官方机器人 webhook 适配器。", ) @@ -125,5 +125,5 @@ async def start_polling(self): shutdown_trigger=self.shutdown_trigger, ) - async def shutdown_trigger(self): + async def shutdown_trigger(self) -> None: await self.shutdown_event.wait() diff --git a/astrbot/core/platform/sources/satori/satori_adapter.py b/astrbot/core/platform/sources/satori/satori_adapter.py index 46f9a4e0f..a1afbd898 100644 --- a/astrbot/core/platform/sources/satori/satori_adapter.py +++ b/astrbot/core/platform/sources/satori/satori_adapter.py @@ -73,7 +73,7 @@ async def send_by_session( self, session: MessageSession, message_chain: MessageChain, - ): + ) -> None: from .satori_event import SatoriPlatformEvent await SatoriPlatformEvent.send_with_adapter( @@ -99,7 +99,7 @@ def _is_websocket_closed(self, ws) -> bool: except AttributeError: return False - async def run(self): + async def run(self) -> None: self.running = True self.session = ClientSession(timeout=ClientTimeout(total=30)) @@ -133,7 +133,7 @@ async def run(self): if self.session: await self.session.close() - async def connect_websocket(self): + async def connect_websocket(self) -> None: logger.info(f"Satori 适配器正在连接到 WebSocket: {self.endpoint}") logger.info(f"Satori 适配器 HTTP API 地址: {self.api_base_url}") @@ -176,7 +176,7 @@ async def connect_websocket(self): except Exception as e: logger.error(f"Satori WebSocket 关闭异常: {e}") - async def send_identify(self): + async def send_identify(self) -> None: if not self.ws: raise Exception("WebSocket连接未建立") @@ -204,7 +204,7 @@ async def send_identify(self): logger.error(f"发送 IDENTIFY 信令失败: {e}") raise - async def heartbeat_loop(self): + async def heartbeat_loop(self) -> None: try: while self.running and self.ws: await asyncio.sleep(self.heartbeat_interval) @@ -229,7 +229,7 @@ async def heartbeat_loop(self): except Exception as e: logger.error(f"心跳任务异常: {e}") - async def handle_message(self, message: str): + async def handle_message(self, message: str) -> None: try: data = json.loads(message) op = data.get("op") @@ -270,7 +270,7 @@ async def handle_message(self, message: str): except Exception as e: logger.error(f"处理 WebSocket 消息异常: {e}") - async def handle_event(self, event_data: dict): + async def handle_event(self, event_data: dict) -> None: try: event_type = event_data.get("type") sn = event_data.get("sn") @@ -715,7 +715,7 @@ async def _parse_xml_node(self, node: ET.Element, elements: list) -> None: if child.tail and child.tail.strip(): elements.append(Plain(text=child.tail)) - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: from .satori_event import SatoriPlatformEvent message_event = SatoriPlatformEvent( @@ -775,7 +775,7 @@ async def send_http_request( logger.error(f"Satori HTTP 请求异常: {e}") return {} - async def terminate(self): + async def terminate(self) -> None: self.running = False if self.heartbeat_task: diff --git a/astrbot/core/platform/sources/satori/satori_event.py b/astrbot/core/platform/sources/satori/satori_event.py index 722fe939d..021422283 100644 --- a/astrbot/core/platform/sources/satori/satori_event.py +++ b/astrbot/core/platform/sources/satori/satori_event.py @@ -110,7 +110,7 @@ async def send_with_adapter( logger.error(f"Satori 消息发送异常: {e}") return None - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: platform = getattr(self, "platform", None) user_id = getattr(self, "user_id", None) diff --git a/astrbot/core/platform/sources/slack/client.py b/astrbot/core/platform/sources/slack/client.py index c0dfe969c..2ec6d2164 100644 --- a/astrbot/core/platform/sources/slack/client.py +++ b/astrbot/core/platform/sources/slack/client.py @@ -105,7 +105,7 @@ async def handle_callback(self, req): logger.error(f"处理 Slack 事件时出错: {e}") return Response("Internal Server Error", status=500) - async def start(self): + async def start(self) -> None: """启动 Webhook 服务器""" logger.info( f"Slack Webhook 服务器启动中,监听 {self.host}:{self.port}{self.path}...", @@ -118,10 +118,10 @@ async def start(self): shutdown_trigger=self.shutdown_trigger, ) - async def shutdown_trigger(self): + async def shutdown_trigger(self) -> None: await self.shutdown_event.wait() - async def stop(self): + async def stop(self) -> None: """停止 Webhook 服务器""" self.shutdown_event.set() logger.info("Slack Webhook 服务器已停止") @@ -160,7 +160,7 @@ async def _handle_events( except Exception as e: logger.error(f"处理 Socket Mode 事件时出错: {e}") - async def start(self): + async def start(self) -> None: """启动 Socket Mode 连接""" self.socket_client = SocketModeClient( app_token=self.app_token, @@ -174,7 +174,7 @@ async def start(self): logger.info("Slack Socket Mode 客户端启动中...") await self.socket_client.connect() - async def stop(self): + async def stop(self) -> None: """停止 Socket Mode 连接""" if self.socket_client: await self.socket_client.disconnect() diff --git a/astrbot/core/platform/sources/slack/slack_adapter.py b/astrbot/core/platform/sources/slack/slack_adapter.py index 4621f8494..33342ee78 100644 --- a/astrbot/core/platform/sources/slack/slack_adapter.py +++ b/astrbot/core/platform/sources/slack/slack_adapter.py @@ -82,7 +82,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: blocks, text = await SlackMessageEvent._parse_slack_blocks( message_chain=message_chain, web_client=self.web_client, @@ -404,7 +404,7 @@ async def webhook_callback(self, request: Any) -> Any: return await self.webhook_client.handle_callback(request) - async def terminate(self): + async def terminate(self) -> None: if self.socket_client: await self.socket_client.stop() if self.webhook_client: @@ -414,7 +414,7 @@ async def terminate(self): def meta(self) -> PlatformMetadata: return self.metadata - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: message_event = SlackMessageEvent( message_str=message.message_str, message_obj=message, diff --git a/astrbot/core/platform/sources/slack/slack_event.py b/astrbot/core/platform/sources/slack/slack_event.py index b08e05509..3f62690b5 100644 --- a/astrbot/core/platform/sources/slack/slack_event.py +++ b/astrbot/core/platform/sources/slack/slack_event.py @@ -126,7 +126,7 @@ async def _parse_slack_blocks( return blocks, "" if blocks else text_content - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: blocks, text = await SlackMessageEvent._parse_slack_blocks( message, self.web_client, diff --git a/astrbot/core/platform/sources/telegram/tg_adapter.py b/astrbot/core/platform/sources/telegram/tg_adapter.py index bca45ea8d..b327e8a4e 100644 --- a/astrbot/core/platform/sources/telegram/tg_adapter.py +++ b/astrbot/core/platform/sources/telegram/tg_adapter.py @@ -94,7 +94,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: from_username = session.session_id await TelegramPlatformEvent.send_with_client( self.client, @@ -109,7 +109,7 @@ def meta(self) -> PlatformMetadata: return PlatformMetadata(name="telegram", description="telegram 适配器", id=id_) @override - async def run(self): + async def run(self) -> None: await self.application.initialize() await self.application.start() @@ -134,7 +134,7 @@ async def run(self): logger.info("Telegram Platform Adapter is running.") await queue - async def register_commands(self): + async def register_commands(self) -> None: """收集所有注册的指令并注册到 Telegram""" try: commands = self.collect_commands() @@ -210,7 +210,7 @@ def _extract_command_info( description = description[:30] + "..." return cmd_name, description - async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE): + async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: if not update.effective_chat: logger.warning( "Received a start command without an effective chat, skipping /start reply.", @@ -221,7 +221,7 @@ async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE): text=self.config["start_message"], ) - async def message_handler(self, update: Update, context: ContextTypes.DEFAULT_TYPE): + async def message_handler(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: logger.debug(f"Telegram message: {update.message}") abm = await self.convert_message(update, context) if abm: @@ -397,7 +397,7 @@ async def convert_message( return message - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: message_event = TelegramPlatformEvent( message_str=message.message_str, message_obj=message, @@ -410,7 +410,7 @@ async def handle_msg(self, message: AstrBotMessage): def get_client(self) -> ExtBot: return self.client - async def terminate(self): + async def terminate(self) -> None: try: if self.scheduler.running: self.scheduler.shutdown() diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py index bd518e9e9..108b5e547 100644 --- a/astrbot/core/platform/sources/telegram/tg_event.py +++ b/astrbot/core/platform/sources/telegram/tg_event.py @@ -134,14 +134,14 @@ async def send_with_client( path = await i.convert_to_file_path() await client.send_voice(voice=path, **cast(Any, payload)) - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: if self.get_message_type() == MessageType.GROUP_MESSAGE: await self.send_with_client(self.client, message, self.message_obj.group_id) else: await self.send_with_client(self.client, message, self.get_sender_id()) await super().send(message) - async def react(self, emoji: str | None, big: bool = False): + async def react(self, emoji: str | None, big: bool = False) -> None: """给原消息添加 Telegram 反应: - 普通 emoji:传入 '👍'、'😂' 等 - 自定义表情:传入其 custom_emoji_id(纯数字字符串) diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index 084d7860d..86f78606d 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -31,7 +31,7 @@ def __init__(self, webchat_queue_mgr: WebChatQueueMgr, callback: Callable) -> No self.callback = callback self.running_tasks = set() - async def listen_to_queue(self, conversation_id: str): + async def listen_to_queue(self, conversation_id: str) -> None: """Listen to a specific conversation queue""" queue = self.webchat_queue_mgr.get_or_create_queue(conversation_id) while True: @@ -44,7 +44,7 @@ async def listen_to_queue(self, conversation_id: str): ) break - async def run(self): + async def run(self) -> None: """Monitor for new conversation queues and start listeners""" monitored_conversations = set() @@ -93,7 +93,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: await WebChatMessageEvent._send(message_chain, session.session_id) await super().send_by_session(session, message_chain) @@ -218,7 +218,7 @@ async def callback(data: tuple): def meta(self) -> PlatformMetadata: return self.metadata - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: message_event = WebChatMessageEvent( message_str=message.message_str, message_obj=message, @@ -235,6 +235,6 @@ async def handle_msg(self, message: AstrBotMessage): self.commit_event(message_event) - async def terminate(self): + async def terminate(self) -> None: # Do nothing pass diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index b220b002b..79c060022 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -101,11 +101,11 @@ async def _send( return data - async def send(self, message: MessageChain | None): + async def send(self, message: MessageChain | None) -> None: await WebChatMessageEvent._send(message, session_id=self.session_id) await super().send(MessageChain([])) - async def send_streaming(self, generator, use_fallback: bool = False): + async def send_streaming(self, generator, use_fallback: bool = False) -> None: final_data = "" reasoning_content = "" cid = self.session_id.split("!")[-1] diff --git a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py index 6c365cb3a..4824e2de9 100644 --- a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py +++ b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py @@ -20,7 +20,7 @@ def get_or_create_back_queue(self, conversation_id: str) -> asyncio.Queue: self.back_queues[conversation_id] = asyncio.Queue() return self.back_queues[conversation_id] - def remove_queues(self, conversation_id: str): + def remove_queues(self, conversation_id: str) -> None: """Remove queues for the given conversation ID""" if conversation_id in self.queues: del self.queues[conversation_id] diff --git a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py index 4c9a9d36b..78c960534 100644 --- a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +++ b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py @@ -150,7 +150,7 @@ def load_credentials(self): logger.error(f"加载 WeChatPadPro 凭据失败: {e}") return None - def save_credentials(self): + def save_credentials(self) -> None: """将 auth_key 和 wxid 保存到文件。""" credentials = { "auth_key": self.auth_key, @@ -165,7 +165,7 @@ def save_credentials(self): except Exception as e: logger.error(f"保存 WeChatPadPro 凭据失败: {e}") - async def check_online_status(self): + async def check_online_status(self) -> bool | None: """检查 WeChatPadPro 设备是否在线。""" if not self.auth_key: return False @@ -219,7 +219,7 @@ def _extract_auth_key(self, data): return data[0] return None - async def generate_auth_key(self): + async def generate_auth_key(self) -> None: """生成授权码。""" url = f"{self.base_url}/admin/GenAuthKey1" params = {"key": self.admin_key} @@ -292,7 +292,7 @@ async def get_login_qr_code(self): logger.error(f"获取登录二维码时发生错误: {e}") return None - async def check_login_status(self): + async def check_login_status(self) -> bool: """循环检测扫码状态。 尝试 6 次后跳出循环,添加倒计时。 返回 True 如果登录成功,否则返回 False。 @@ -359,7 +359,7 @@ async def check_login_status(self): logger.warning("登录检测超过最大尝试次数,退出检测。") return False - async def connect_websocket(self): + async def connect_websocket(self) -> None: """建立 WebSocket 连接并处理接收到的消息。""" os.environ["no_proxy"] = f"localhost,127.0.0.1,{self.host}" ws_url = f"ws://{self.host}:{self.port}/ws/GetSyncMsg?key={self.auth_key}" @@ -399,7 +399,7 @@ async def connect_websocket(self): ) await asyncio.sleep(5) - async def handle_websocket_message(self, message: str | bytes): + async def handle_websocket_message(self, message: str | bytes) -> None: """处理从 WebSocket 接收到的消息。""" logger.debug(f"收到 WebSocket 消息: {message}") try: @@ -837,7 +837,7 @@ async def _process_message_content( else: logger.warning(f"收到未处理的消息类型: {msg_type}。") - async def terminate(self): + async def terminate(self) -> None: """终止一个平台的运行实例。""" logger.info("终止 WeChatPadPro 适配器。") try: @@ -856,7 +856,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: dummy_message_obj = AstrBotMessage() dummy_message_obj.session_id = session.session_id # 根据 session_id 判断消息类型 diff --git a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py index 5a208a04c..101b8f2dd 100644 --- a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +++ b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py @@ -37,7 +37,7 @@ def __init__( self.message_obj = message_obj # Save the full message object self.adapter = adapter # Save the adapter instance - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: async with aiohttp.ClientSession() as session: for comp in message.chain: await asyncio.sleep(1) diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index 36a8713c5..66b5e9588 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -123,7 +123,7 @@ async def handle_callback(self, request) -> str: return "success" - async def start_polling(self): + async def start_polling(self) -> None: logger.info( f"将在 {self.callback_server_host}:{self.port} 端口启动 企业微信 适配器。", ) @@ -133,7 +133,7 @@ async def start_polling(self): shutdown_trigger=self.shutdown_trigger, ) - async def shutdown_trigger(self): + async def shutdown_trigger(self) -> None: await self.shutdown_event.wait() @@ -214,7 +214,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: await super().send_by_session(session, message_chain) @override @@ -227,7 +227,7 @@ def meta(self) -> PlatformMetadata: ) @override - async def run(self): + async def run(self) -> None: loop = asyncio.get_event_loop() if self.kf_name: try: @@ -403,7 +403,7 @@ async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None: return await self.handle_msg(abm) - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: message_event = WecomPlatformEvent( message_str=message.message_str, message_obj=message, @@ -416,7 +416,7 @@ async def handle_msg(self, message: AstrBotMessage): def get_client(self) -> WeChatClient: return self.client - async def terminate(self): + async def terminate(self) -> None: self.server.shutdown_event.set() try: await self.server.server.shutdown() diff --git a/astrbot/core/platform/sources/wecom/wecom_event.py b/astrbot/core/platform/sources/wecom/wecom_event.py index 2e2cdb751..d9ac60ba1 100644 --- a/astrbot/core/platform/sources/wecom/wecom_event.py +++ b/astrbot/core/platform/sources/wecom/wecom_event.py @@ -86,7 +86,7 @@ async def split_plain(self, plain: str) -> list[str]: return result - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: message_obj = self.message_obj is_wechat_kf = hasattr(self.client, "kf_message") diff --git a/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py b/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py index 4e7714b0e..27b30f830 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py @@ -18,6 +18,7 @@ from Crypto.Cipher import AES from . import ierror +from typing import NoReturn """ 关于Crypto.Cipher模块,ImportError: No module named 'Crypto'解决方案 @@ -30,7 +31,7 @@ class FormatException(Exception): pass -def throw_exception(message, exception_class=FormatException): +def throw_exception(message, exception_class=FormatException) -> NoReturn: """My define raise exception function""" raise exception_class(message) diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py index 70581e7ea..e4d34724a 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py @@ -53,7 +53,7 @@ def __init__( self.callback = callback self.running_tasks = set() - async def listen_to_queue(self, session_id: str): + async def listen_to_queue(self, session_id: str) -> None: """监听特定会话的队列""" queue = self.queue_mgr.get_or_create_queue(session_id) while True: @@ -64,7 +64,7 @@ async def listen_to_queue(self, session_id: str): logger.error(f"处理会话 {session_id} 消息时发生错误: {e}") break - async def run(self): + async def run(self) -> None: """监控新会话队列并启动监听器""" monitored_sessions = set() @@ -417,7 +417,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: """通过会话发送消息""" # 企业微信智能机器人主要通过回调响应,这里记录日志 logger.info("会话发送消息: %s -> %s", session.session_id, message_chain) @@ -453,7 +453,7 @@ async def webhook_callback(self, request: Any) -> Any: else: return await self.server.handle_callback(request) - async def terminate(self): + async def terminate(self) -> None: """终止适配器""" logger.info("企业微信智能机器人适配器正在关闭...") self.shutdown_event.set() @@ -463,7 +463,7 @@ def meta(self) -> PlatformMetadata: """获取平台元数据""" return self.metadata - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: """处理消息,创建消息事件并提交到事件队列""" try: message_event = WecomAIBotMessageEvent( diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py index 14dadcace..90a9e363b 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py @@ -90,7 +90,7 @@ async def _send( return data - async def send(self, message: MessageChain | None): + async def send(self, message: MessageChain | None) -> None: """发送消息""" raw = self.message_obj.raw_message assert isinstance(raw, dict), ( @@ -100,7 +100,7 @@ async def send(self, message: MessageChain | None): await WecomAIBotMessageEvent._send(message, stream_id, self.queue_mgr) await super().send(MessageChain([])) - async def send_streaming(self, generator, use_fallback=False): + async def send_streaming(self, generator, use_fallback=False) -> None: """流式发送消息,参考webchat的send_streaming设计""" final_data = "" raw = self.message_obj.raw_message diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py index 3a982bdf7..1c6e9d7ef 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py @@ -52,7 +52,7 @@ def get_or_create_back_queue(self, session_id: str) -> asyncio.Queue: logger.debug(f"[WecomAI] 创建输出队列: {session_id}") return self.back_queues[session_id] - def remove_queues(self, session_id: str): + def remove_queues(self, session_id: str) -> None: """移除指定会话的所有队列 Args: @@ -95,7 +95,7 @@ def has_back_queue(self, session_id: str) -> bool: """ return session_id in self.back_queues - def set_pending_response(self, session_id: str, callback_params: dict[str, str]): + def set_pending_response(self, session_id: str, callback_params: dict[str, str]) -> None: """设置待处理的响应参数 Args: @@ -121,7 +121,7 @@ def get_pending_response(self, session_id: str) -> dict[str, Any] | None: """ return self.pending_responses.get(session_id) - def cleanup_expired_responses(self, max_age_seconds: int = 300): + def cleanup_expired_responses(self, max_age_seconds: int = 300) -> None: """清理过期的待处理响应 Args: diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py index 3d7325c9d..82bfb8722 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py @@ -162,7 +162,7 @@ async def handle_callback(self, request): logger.error("处理消息时发生异常: %s", e) return "内部服务器错误", 500 - async def start_server(self): + async def start_server(self) -> None: """启动服务器""" logger.info("启动企业微信智能机器人服务器,监听 %s:%d", self.host, self.port) @@ -176,11 +176,11 @@ async def start_server(self): logger.error("服务器运行异常: %s", e) raise - async def shutdown_trigger(self): + async def shutdown_trigger(self) -> None: """关闭触发器""" await self.shutdown_event.wait() - async def shutdown(self): + async def shutdown(self) -> None: """关闭服务器""" logger.info("企业微信智能机器人服务器正在关闭...") self.shutdown_event.set() diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py index 65fad22a0..26ce7febd 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py @@ -129,7 +129,7 @@ async def handle_callback(self, request) -> str: return "success" - async def start_polling(self): + async def start_polling(self) -> None: logger.info( f"将在 {self.callback_server_host}:{self.port} 端口启动 微信公众平台 适配器。", ) @@ -139,7 +139,7 @@ async def start_polling(self): shutdown_trigger=self.shutdown_trigger, ) - async def shutdown_trigger(self): + async def shutdown_trigger(self) -> None: await self.shutdown_event.wait() @@ -218,7 +218,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: await super().send_by_session(session, message_chain) @override @@ -231,7 +231,7 @@ def meta(self) -> PlatformMetadata: ) @override - async def run(self): + async def run(self) -> None: # 如果启用统一 webhook 模式,则不启动独立服务器 webhook_uuid = self.config.get("webhook_uuid") if self.unified_webhook_mode and webhook_uuid: @@ -330,7 +330,7 @@ async def convert_message( logger.info(f"abm: {abm}") await self.handle_msg(abm) - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: message_event = WeixinOfficialAccountPlatformEvent( message_str=message.message_str, message_obj=message, @@ -343,7 +343,7 @@ async def handle_msg(self, message: AstrBotMessage): def get_client(self) -> WeChatClient: return self.client - async def terminate(self): + async def terminate(self) -> None: self.server.shutdown_event.set() try: await self.server.server.shutdown() diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py index 66d003039..39ff7bc0e 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py @@ -84,7 +84,7 @@ async def split_plain(self, plain: str) -> list[str]: return result - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: message_obj = self.message_obj active_send_mode = cast(dict, message_obj.raw_message).get( "active_send_mode", False diff --git a/astrbot/core/platform_message_history_mgr.py b/astrbot/core/platform_message_history_mgr.py index 3c22b641c..476e32cb1 100644 --- a/astrbot/core/platform_message_history_mgr.py +++ b/astrbot/core/platform_message_history_mgr.py @@ -40,7 +40,7 @@ async def get( history.reverse() return history - async def delete(self, platform_id: str, user_id: str, offset_sec: int = 86400): + async def delete(self, platform_id: str, user_id: str, offset_sec: int = 86400) -> None: """Delete platform message history records older than the specified offset.""" await self.db.delete_platform_message_offset( platform_id=platform_id, diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index a141bf73e..843b30a19 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -119,7 +119,7 @@ def __repr__(self) -> str: def __str__(self) -> str: return self.__repr__() - def append_tool_calls_result(self, tool_calls_result: ToolCallsResult): + def append_tool_calls_result(self, tool_calls_result: ToolCallsResult) -> None: """添加工具调用结果到请求中""" if not self.tool_calls_result: self.tool_calls_result = [] diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 7054276d6..79b3fa519 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -88,7 +88,7 @@ async def set_provider( provider_id: str, provider_type: ProviderType, umo: str | None = None, - ): + ) -> None: """设置提供商。 Args: @@ -187,7 +187,7 @@ def get_using_provider( raise ValueError(f"Unknown provider type: {provider_type}") return provider - async def initialize(self): + async def initialize(self) -> None: # 逐个初始化提供商 for provider_config in self.providers_config: try: @@ -251,7 +251,7 @@ async def initialize(self): # 初始化 MCP Client 连接 asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients") - async def load_provider(self, provider_config: dict): + async def load_provider(self, provider_config: dict) -> None: if not provider_config["enable"]: logger.info(f"Provider {provider_config['id']} is disabled, skipping") return @@ -491,7 +491,7 @@ async def load_provider(self, provider_config: dict): f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}", ) - async def reload(self, provider_config: dict): + async def reload(self, provider_config: dict) -> None: async with self.reload_lock: await self.terminate_provider(provider_config["id"]) if provider_config["enable"]: @@ -536,7 +536,7 @@ async def reload(self, provider_config: dict): def get_insts(self): return self.provider_insts - async def terminate_provider(self, provider_id: str): + async def terminate_provider(self, provider_id: str) -> None: if provider_id in self.inst_map: logger.info( f"终止 {provider_id} 提供商适配器({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)}) ...", @@ -570,7 +570,7 @@ async def terminate_provider(self, provider_id: str): ) del self.inst_map[provider_id] - async def terminate(self): + async def terminate(self) -> None: for provider_inst in self.provider_insts: if hasattr(provider_inst, "terminate"): await provider_inst.terminate() # type: ignore diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 7f21a2ee1..54df05a34 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -2,7 +2,7 @@ import asyncio import os from collections.abc import AsyncGenerator -from typing import TypeAlias, Union +from typing import TypeAlias, Union, NoReturn from astrbot.core.agent.message import Message from astrbot.core.agent.tool import ToolSet @@ -32,7 +32,7 @@ def __init__(self, provider_config: dict) -> None: self.model_name = "" self.provider_config = provider_config - def set_model(self, model_name: str): + def set_model(self, model_name: str) -> None: """Set the current model name""" self.model_name = model_name @@ -54,7 +54,7 @@ def meta(self) -> ProviderMeta: ) return meta - async def test(self): + async def test(self) -> None: """test the provider is a raises: @@ -84,7 +84,7 @@ def get_keys(self) -> list[str]: return keys or [""] @abc.abstractmethod - def set_key(self, key: str): + def set_key(self, key: str) -> NoReturn: raise NotImplementedError @abc.abstractmethod @@ -155,7 +155,7 @@ async def text_chat_stream( yield None # type: ignore raise NotImplementedError() - async def pop_record(self, context: list): + async def pop_record(self, context: list) -> None: """弹出 context 第一条非系统提示词对话记录""" poped = 0 indexs_to_pop = [] @@ -186,7 +186,7 @@ def _ensure_message_to_dicts( return dicts - async def test(self, timeout: float = 45.0): + async def test(self, timeout: float = 45.0) -> None: await asyncio.wait_for( self.text_chat(prompt="REPLY `PONG` ONLY"), timeout=timeout, @@ -204,7 +204,7 @@ async def get_text(self, audio_url: str) -> str: """获取音频的文本""" raise NotImplementedError - async def test(self): + async def test(self) -> None: sample_audio_path = os.path.join( get_astrbot_path(), "samples", @@ -224,7 +224,7 @@ async def get_audio(self, text: str) -> str: """获取文本的音频,返回音频文件路径""" raise NotImplementedError - async def test(self): + async def test(self) -> None: await self.get_audio("hi") @@ -249,7 +249,7 @@ def get_dim(self) -> int: """获取向量的维度""" ... - async def test(self): + async def test(self) -> None: await self.get_embedding("astrbot") async def get_embeddings_batch( @@ -336,7 +336,7 @@ async def rerank( """获取查询和文档的重排序分数""" ... - async def test(self): + async def test(self) -> None: result = await self.rerank("Apple", documents=["apple", "banana"]) if not result: raise Exception("Rerank provider test failed, no results returned") diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index bd0f06fba..bd09ebae7 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -408,5 +408,5 @@ async def get_models(self) -> list[str]: models_str.append(model.id) return models_str - def set_key(self, key: str): + def set_key(self, key: str) -> None: self.chosen_api_key = key diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index e2efc6aab..18eb8db9f 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -742,7 +742,7 @@ def get_current_key(self) -> str: def get_keys(self) -> list[str]: return self.api_keys - def set_key(self, key): + def set_key(self, key) -> None: self.chosen_api_key = key self._init_client() @@ -782,5 +782,5 @@ async def encode_image_bs64(self, image_url: str) -> str: image_bs64 = base64.b64encode(f.read()).decode("utf-8") return "data:image/jpeg;base64," + image_bs64 - async def terminate(self): + async def terminate(self) -> None: logger.info("Google GenAI 适配器已终止。") diff --git a/astrbot/core/provider/sources/gsv_selfhosted_source.py b/astrbot/core/provider/sources/gsv_selfhosted_source.py index 7f8d39eac..cbc900b5d 100644 --- a/astrbot/core/provider/sources/gsv_selfhosted_source.py +++ b/astrbot/core/provider/sources/gsv_selfhosted_source.py @@ -39,7 +39,7 @@ def __init__( self.timeout = provider_config.get("timeout", 60) self._session: aiohttp.ClientSession | None = None - async def initialize(self): + async def initialize(self) -> None: """异步初始化:在 ProviderManager 中被调用""" self._session = aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=self.timeout), @@ -144,7 +144,7 @@ def build_synthesis_params(self, text: str) -> dict: # TODO: 在此处添加情绪分析,例如 params["emotion"] = detect_emotion(text) return params - async def terminate(self): + async def terminate(self) -> None: """终止释放资源:在 ProviderManager 中被调用""" if self._session and not self._session.closed: await self._session.close() diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 788b649a9..1a2767f67 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -602,7 +602,7 @@ def get_current_key(self) -> str: def get_keys(self) -> list[str]: return self.api_keys - def set_key(self, key): + def set_key(self, key) -> None: self.client.api_key = key async def assemble_context( diff --git a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py index a41bd72fd..fd0615799 100644 --- a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py +++ b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py @@ -37,7 +37,7 @@ def __init__( self.model = None self.is_emotion = provider_config.get("is_emotion", False) - async def initialize(self): + async def initialize(self) -> None: logger.info("下载或者加载 SenseVoice 模型中,这可能需要一些时间 ...") # 将模型加载放到线程池中执行 diff --git a/astrbot/core/provider/sources/whisper_selfhosted_source.py b/astrbot/core/provider/sources/whisper_selfhosted_source.py index a14f93f14..68ef73e46 100644 --- a/astrbot/core/provider/sources/whisper_selfhosted_source.py +++ b/astrbot/core/provider/sources/whisper_selfhosted_source.py @@ -30,7 +30,7 @@ def __init__( self.set_model(provider_config["model"]) self.model = None - async def initialize(self): + async def initialize(self) -> None: loop = asyncio.get_event_loop() logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...") self.model = await loop.run_in_executor( diff --git a/astrbot/core/provider/sources/xinference_rerank_source.py b/astrbot/core/provider/sources/xinference_rerank_source.py index 960408550..9c3a77c15 100644 --- a/astrbot/core/provider/sources/xinference_rerank_source.py +++ b/astrbot/core/provider/sources/xinference_rerank_source.py @@ -37,7 +37,7 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: self.model: AsyncRESTfulRerankModelHandle | None = None self.model_uid = None - async def initialize(self): + async def initialize(self) -> None: if self.api_key: logger.info("Xinference Rerank: Using API key for authentication.") self.client = Client(self.base_url, api_key=self.api_key) diff --git a/astrbot/core/provider/sources/xinference_stt_provider.py b/astrbot/core/provider/sources/xinference_stt_provider.py index 9c69a0039..344655afd 100644 --- a/astrbot/core/provider/sources/xinference_stt_provider.py +++ b/astrbot/core/provider/sources/xinference_stt_provider.py @@ -37,7 +37,7 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: self.client = None self.model_uid = None - async def initialize(self): + async def initialize(self) -> None: if self.api_key: logger.info("Xinference STT: Using API key for authentication.") self.client = Client(self.base_url, api_key=self.api_key) diff --git a/astrbot/core/star/__init__.py b/astrbot/core/star/__init__.py index bdcc8785e..a3ae72e4c 100644 --- a/astrbot/core/star/__init__.py +++ b/astrbot/core/star/__init__.py @@ -51,10 +51,10 @@ async def html_render( options=options, ) - async def initialize(self): + async def initialize(self) -> None: """当插件被激活时会调用这个方法""" - async def terminate(self): + async def terminate(self) -> None: """当插件被禁用、重载插件时会调用这个方法""" def __del__(self) -> None: diff --git a/astrbot/core/star/config.py b/astrbot/core/star/config.py index a9af974c5..4068d8439 100644 --- a/astrbot/core/star/config.py +++ b/astrbot/core/star/config.py @@ -22,7 +22,7 @@ def load_config(namespace: str) -> dict | bool: return ret -def put_config(namespace: str, name: str, key: str, value, description: str): +def put_config(namespace: str, name: str, key: str, value, description: str) -> None: """将配置项写入以namespace为名字的配置文件,如果key不存在于目标配置文件中。当前 value 仅支持 str, int, float, bool, list 类型(暂不支持 dict)。 namespace: str, 配置的唯一识别符,也就是配置文件的名字。 name: str, 配置项的显示名字。 @@ -64,7 +64,7 @@ def put_config(namespace: str, name: str, key: str, value, description: str): f.flush() -def update_config(namespace: str, key: str, value): +def update_config(namespace: str, key: str, value) -> None: """更新配置文件中的配置项。 namespace: str, 配置的唯一识别符,也就是配置文件的名字。 key: str, 配置项的键。 diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index b6ff13b62..4049ea21c 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -398,7 +398,7 @@ def register_web_api( view_handler: Awaitable, methods: list, desc: str, - ): + ) -> None: for idx, api in enumerate(self.registered_web_apis): if api[0] == route and methods == api[2]: self.registered_web_apis[idx] = (route, view_handler, methods, desc) @@ -448,7 +448,7 @@ def get_db(self) -> BaseDatabase: """获取 AstrBot 数据库。""" return self._db - def register_provider(self, provider: Provider): + def register_provider(self, provider: Provider) -> None: """注册一个 LLM Provider(Chat_Completion 类型)。""" self.provider_manager.provider_insts.append(provider) @@ -493,7 +493,7 @@ def register_commands( awaitable: Callable[..., Awaitable[Any]], use_regex=False, ignore_prefix=False, - ): + ) -> None: """注册一个命令。 [Deprecated] 推荐使用装饰器注册指令。该方法将在未来的版本中被移除。 @@ -522,6 +522,6 @@ def register_commands( ) star_handlers_registry.append(md) - def register_task(self, task: Awaitable, desc: str): + def register_task(self, task: Awaitable, desc: str) -> None: """[DEPRECATED]注册一个异步任务。""" self._register_tasks.append(task) diff --git a/astrbot/core/star/filter/command.py b/astrbot/core/star/filter/command.py index e42c07076..f73e1cc66 100755 --- a/astrbot/core/star/filter/command.py +++ b/astrbot/core/star/filter/command.py @@ -62,7 +62,7 @@ def print_types(self): result = "".join(parts).rstrip(",") return result - def init_handler_md(self, handle_md: StarHandlerMetadata): + def init_handler_md(self, handle_md: StarHandlerMetadata) -> None: self.handler_md = handle_md signature = inspect.signature(self.handler_md.handler) self.handler_params = {} # 参数名 -> 参数类型,如果有默认值则为默认值 @@ -80,7 +80,7 @@ def init_handler_md(self, handle_md: StarHandlerMetadata): def get_handler_md(self) -> StarHandlerMetadata: return self.handler_md - def add_custom_filter(self, custom_filter: CustomFilter): + def add_custom_filter(self, custom_filter: CustomFilter) -> None: self.custom_filter_list.append(custom_filter) def custom_filter_ok(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: diff --git a/astrbot/core/star/filter/command_group.py b/astrbot/core/star/filter/command_group.py index 087f35445..b22d5031f 100755 --- a/astrbot/core/star/filter/command_group.py +++ b/astrbot/core/star/filter/command_group.py @@ -28,10 +28,10 @@ def __init__( def add_sub_command_filter( self, sub_command_filter: CommandFilter | CommandGroupFilter, - ): + ) -> None: self.sub_command_filters.append(sub_command_filter) - def add_custom_filter(self, custom_filter: CustomFilter): + def add_custom_filter(self, custom_filter: CustomFilter) -> None: self.custom_filter_list.append(custom_filter) def get_complete_command_names(self) -> list[str]: diff --git a/astrbot/core/star/star_handler.py b/astrbot/core/star/star_handler.py index 663616f33..ed50ee5bb 100644 --- a/astrbot/core/star/star_handler.py +++ b/astrbot/core/star/star_handler.py @@ -16,7 +16,7 @@ def __init__(self) -> None: self.star_handlers_map: dict[str, StarHandlerMetadata] = {} self._handlers: list[StarHandlerMetadata] = [] - def append(self, handler: StarHandlerMetadata): + def append(self, handler: StarHandlerMetadata) -> None: """添加一个 Handler,并保持按优先级有序""" if "priority" not in handler.extras_configs: handler.extras_configs["priority"] = 0 @@ -154,11 +154,11 @@ def get_handlers_by_module_name( if handler.handler_module_path == module_name ] - def clear(self): + def clear(self) -> None: self.star_handlers_map.clear() self._handlers.clear() - def remove(self, handler: StarHandlerMetadata): + def remove(self, handler: StarHandlerMetadata) -> None: self.star_handlers_map.pop(handler.handler_full_name, None) self._handlers = [h for h in self._handlers if h != handler] diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index d73590722..20aa40fcc 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -685,7 +685,7 @@ async def uninstall_plugin( plugin_name: str, delete_config: bool = False, delete_data: bool = False, - ): + ) -> None: """卸载指定的插件。 Args: @@ -825,7 +825,7 @@ async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str): is_reserved=plugin.reserved, ) - async def update_plugin(self, plugin_name: str, proxy=""): + async def update_plugin(self, plugin_name: str, proxy="") -> None: """升级一个插件""" plugin = self.context.get_registered_star(plugin_name) if not plugin: @@ -836,7 +836,7 @@ async def update_plugin(self, plugin_name: str, proxy=""): await self.updator.update(plugin, proxy=proxy) await self.reload(plugin_name) - async def turn_off_plugin(self, plugin_name: str): + async def turn_off_plugin(self, plugin_name: str) -> None: """禁用一个插件。 调用插件的 terminate() 方法, 将插件的 module_path 加入到 data/shared_preferences.json 的 inactivated_plugins 列表中。 @@ -898,7 +898,7 @@ async def _terminate_plugin(star_metadata: StarMetadata): elif "terminate" in star_metadata.star_cls_type.__dict__: await star_metadata.star_cls.terminate() - async def turn_on_plugin(self, plugin_name: str): + async def turn_on_plugin(self, plugin_name: str) -> None: plugin = self.context.get_registered_star(plugin_name) if plugin is None: raise Exception(f"插件 {plugin_name} 不存在。") diff --git a/astrbot/core/star/updator.py b/astrbot/core/star/updator.py index 8793ad505..1a0c5fc26 100644 --- a/astrbot/core/star/updator.py +++ b/astrbot/core/star/updator.py @@ -52,7 +52,7 @@ async def update(self, plugin: StarMetadata, proxy="") -> str: return plugin_path - def unzip_file(self, zip_path: str, target_dir: str): + def unzip_file(self, zip_path: str, target_dir: str) -> None: os.makedirs(target_dir, exist_ok=True) update_dir = "" logger.info(f"正在解压压缩包: {zip_path}") diff --git a/astrbot/core/umop_config_router.py b/astrbot/core/umop_config_router.py index e5d0d2071..049b96856 100644 --- a/astrbot/core/umop_config_router.py +++ b/astrbot/core/umop_config_router.py @@ -47,7 +47,7 @@ def get_conf_id_for_umop(self, umo: str) -> str | None: return conf_id return None - async def update_routing_data(self, new_routing: dict[str, str]): + async def update_routing_data(self, new_routing: dict[str, str]) -> None: """更新路由表 Args: @@ -67,7 +67,7 @@ async def update_routing_data(self, new_routing: dict[str, str]): self.umop_to_conf_id = new_routing await self.sp.global_put("umop_config_routing", self.umop_to_conf_id) - async def update_route(self, umo: str, conf_id: str): + async def update_route(self, umo: str, conf_id: str) -> None: """更新一条路由 Args: @@ -86,7 +86,7 @@ async def update_route(self, umo: str, conf_id: str): self.umop_to_conf_id[umo] = conf_id await self.sp.global_put("umop_config_routing", self.umop_to_conf_id) - async def delete_route(self, umo: str): + async def delete_route(self, umo: str) -> None: """删除一条路由 Args: diff --git a/astrbot/core/updator.py b/astrbot/core/updator.py index 0a7116a0d..7c880ede6 100644 --- a/astrbot/core/updator.py +++ b/astrbot/core/updator.py @@ -23,7 +23,7 @@ def __init__(self, repo_mirror: str = "") -> None: self.MAIN_PATH = get_astrbot_path() self.ASTRBOT_RELEASE_API = "https://api.soulter.top/releases" - def terminate_child_processes(self): + def terminate_child_processes(self) -> None: """终止当前进程的所有子进程 使用 psutil 库获取当前进程的所有子进程,并尝试终止它们 """ @@ -85,7 +85,7 @@ async def check_update( async def get_releases(self) -> list: return await self.fetch_release_info(self.ASTRBOT_RELEASE_API) - async def update(self, reboot=False, latest=True, version=None, proxy=""): + async def update(self, reboot=False, latest=True, version=None, proxy="") -> None: update_data = await self.fetch_release_info(self.ASTRBOT_RELEASE_API, latest) file_url = None diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index fcf5bb3c7..7aef3172d 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -19,7 +19,7 @@ logger = logging.getLogger("astrbot") -def on_error(func, path, exc_info): +def on_error(func, path, exc_info) -> None: """A callback of the rmtree function.""" import stat @@ -37,7 +37,7 @@ def remove_dir(file_path: str) -> bool: return True -def port_checker(port: int, host: str = "localhost"): +def port_checker(port: int, host: str = "localhost") -> bool | None: sk = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sk.settimeout(1) try: @@ -134,7 +134,7 @@ async def download_image_by_url( raise e -async def download_file(url: str, path: str, show_progress: bool = False): +async def download_file(url: str, path: str, show_progress: bool = False) -> None: """从指定 url 下载文件到指定路径 path""" try: ssl_context = ssl.create_default_context( diff --git a/astrbot/core/utils/log_pipe.py b/astrbot/core/utils/log_pipe.py index 077275c6d..6f40f0942 100644 --- a/astrbot/core/utils/log_pipe.py +++ b/astrbot/core/utils/log_pipe.py @@ -24,7 +24,7 @@ def __init__( def fileno(self): return self.fd_write - def run(self): + def run(self) -> None: for line in iter(self.reader.readline, ""): if self.callback: self.callback(line.strip()) @@ -32,5 +32,5 @@ def run(self): self.reader.close() - def close(self): + def close(self) -> None: os.close(self.fd_write) diff --git a/astrbot/core/utils/pip_installer.py b/astrbot/core/utils/pip_installer.py index 0f1eb6583..689a1dc86 100644 --- a/astrbot/core/utils/pip_installer.py +++ b/astrbot/core/utils/pip_installer.py @@ -15,7 +15,7 @@ async def install( package_name: str | None = None, requirements_path: str | None = None, mirror: str | None = None, - ): + ) -> None: args = ["install"] if package_name: args.append(package_name) diff --git a/astrbot/core/utils/session_waiter.py b/astrbot/core/utils/session_waiter.py index 3e7f339f9..4241fc6f2 100644 --- a/astrbot/core/utils/session_waiter.py +++ b/astrbot/core/utils/session_waiter.py @@ -29,7 +29,7 @@ def __init__(self) -> None: self.history_chains: list[list[Comp.BaseMessageComponent]] = [] - def stop(self, error: Exception | None = None): + def stop(self, error: Exception | None = None) -> None: """立即结束这个会话""" if not self.future.done(): if error: @@ -37,7 +37,7 @@ def stop(self, error: Exception | None = None): else: self.future.set_result(None) - def keep(self, timeout: float = 0, reset_timeout=False): + def keep(self, timeout: float = 0, reset_timeout=False) -> None: """保持这个会话 Args: diff --git a/astrbot/core/utils/shared_preferences.py b/astrbot/core/utils/shared_preferences.py index cb3c9fcf2..ee22fa3db 100644 --- a/astrbot/core/utils/shared_preferences.py +++ b/astrbot/core/utils/shared_preferences.py @@ -118,7 +118,7 @@ async def global_get( return await self.range_get_async("global", "global", key) return await self.get_async("global", "global", key, default) - async def put_async(self, scope: str, scope_id: str, key: str, value: Any): + async def put_async(self, scope: str, scope_id: str, key: str, value: Any) -> None: """设置指定范围和键的偏好设置""" await self.db_helper.insert_preference_or_update( scope, @@ -127,24 +127,24 @@ async def put_async(self, scope: str, scope_id: str, key: str, value: Any): {"val": value}, ) - async def session_put(self, umo: str, key: str, value: Any): + async def session_put(self, umo: str, key: str, value: Any) -> None: await self.put_async("umo", umo, key, value) - async def global_put(self, key: str, value: Any): + async def global_put(self, key: str, value: Any) -> None: await self.put_async("global", "global", key, value) - async def remove_async(self, scope: str, scope_id: str, key: str): + async def remove_async(self, scope: str, scope_id: str, key: str) -> None: """删除指定范围和键的偏好设置""" await self.db_helper.remove_preference(scope, scope_id, key) - async def session_remove(self, umo: str, key: str): + async def session_remove(self, umo: str, key: str) -> None: await self.remove_async("umo", umo, key) - async def global_remove(self, key: str): + async def global_remove(self, key: str) -> None: """删除全局偏好设置""" await self.remove_async("global", "global", key) - async def clear_async(self, scope: str, scope_id: str): + async def clear_async(self, scope: str, scope_id: str) -> None: """清空指定范围的所有偏好设置""" await self.db_helper.clear_preferences(scope, scope_id) @@ -188,21 +188,21 @@ def range_get( return result - def put(self, key, value, scope: str | None = None, scope_id: str | None = None): + def put(self, key, value, scope: str | None = None, scope_id: str | None = None) -> None: """设置偏好设置(已弃用)""" asyncio.run_coroutine_threadsafe( self.put_async(scope or "unknown", scope_id or "unknown", key, value), self._sync_loop, ).result() - def remove(self, key, scope: str | None = None, scope_id: str | None = None): + def remove(self, key, scope: str | None = None, scope_id: str | None = None) -> None: """删除偏好设置(已弃用)""" asyncio.run_coroutine_threadsafe( self.remove_async(scope or "unknown", scope_id or "unknown", key), self._sync_loop, ).result() - def clear(self, scope: str | None = None, scope_id: str | None = None): + def clear(self, scope: str | None = None, scope_id: str | None = None) -> None: """清空偏好设置(已弃用)""" asyncio.run_coroutine_threadsafe( self.clear_async(scope or "unknown", scope_id or "unknown"), diff --git a/astrbot/core/utils/t2i/network_strategy.py b/astrbot/core/utils/t2i/network_strategy.py index 7ebba5669..2abb22917 100644 --- a/astrbot/core/utils/t2i/network_strategy.py +++ b/astrbot/core/utils/t2i/network_strategy.py @@ -28,7 +28,7 @@ def __init__(self, base_url: str | None = None) -> None: self.endpoints = [self.BASE_RENDER_URL] self.template_manager = TemplateManager() - async def initialize(self): + async def initialize(self) -> None: if self.BASE_RENDER_URL == ASTRBOT_T2I_DEFAULT_ENDPOINT: asyncio.create_task(self.get_official_endpoints()) @@ -36,7 +36,7 @@ async def get_template(self, name: str = "base") -> str: """通过名称获取文转图 HTML 模板""" return self.template_manager.get_template(name) - async def get_official_endpoints(self): + async def get_official_endpoints(self) -> None: """获取官方的 t2i 端点列表。""" try: async with aiohttp.ClientSession() as session: diff --git a/astrbot/core/utils/t2i/renderer.py b/astrbot/core/utils/t2i/renderer.py index b1c3ec384..e3118d7e8 100644 --- a/astrbot/core/utils/t2i/renderer.py +++ b/astrbot/core/utils/t2i/renderer.py @@ -11,7 +11,7 @@ def __init__(self, endpoint_url: str | None = None) -> None: self.network_strategy = NetworkRenderStrategy(endpoint_url) self.local_strategy = LocalRenderStrategy() - async def initialize(self): + async def initialize(self) -> None: await self.network_strategy.initialize() async def render_custom_template( diff --git a/astrbot/core/utils/t2i/template_manager.py b/astrbot/core/utils/t2i/template_manager.py index 8bbdb7e9e..3df4cab79 100644 --- a/astrbot/core/utils/t2i/template_manager.py +++ b/astrbot/core/utils/t2i/template_manager.py @@ -80,7 +80,7 @@ def get_template(self, name: str) -> str: raise FileNotFoundError("模板不存在。") - def create_template(self, name: str, content: str): + def create_template(self, name: str, content: str) -> None: """在用户目录中创建一个新的模板文件。""" path = self._get_user_template_path(name) if os.path.exists(path): @@ -88,7 +88,7 @@ def create_template(self, name: str, content: str): with open(path, "w", encoding="utf-8") as f: f.write(content) - def update_template(self, name: str, content: str): + def update_template(self, name: str, content: str) -> None: """更新一个模板。此操作始终写入用户目录。 如果更新的是一个内置模板,此操作实际上会在用户目录中创建一个修改后的副本, 从而实现对内置模板的“覆盖”。 @@ -97,7 +97,7 @@ def update_template(self, name: str, content: str): with open(path, "w", encoding="utf-8") as f: f.write(content) - def delete_template(self, name: str): + def delete_template(self, name: str) -> None: """仅删除用户目录中的模板文件。 如果删除的是一个覆盖了内置模板的用户模板,这将有效地“恢复”到内置版本。 """ @@ -106,6 +106,6 @@ def delete_template(self, name: str): raise FileNotFoundError("用户模板不存在,无法删除。") os.remove(path) - def reset_default_template(self): + def reset_default_template(self) -> None: """将核心模板从内置目录强制重置到用户目录。""" self._copy_core_templates(overwrite=True) diff --git a/astrbot/core/utils/webhook_utils.py b/astrbot/core/utils/webhook_utils.py index c56d00b37..93cb73d00 100644 --- a/astrbot/core/utils/webhook_utils.py +++ b/astrbot/core/utils/webhook_utils.py @@ -17,7 +17,7 @@ def _get_dashboard_port() -> int: return 6185 -def log_webhook_info(platform_name: str, webhook_uuid: str): +def log_webhook_info(platform_name: str, webhook_uuid: str) -> None: """打印美观的 webhook 信息日志 Args: diff --git a/astrbot/core/zip_updator.py b/astrbot/core/zip_updator.py index 728dfdabb..9f7f2dde1 100644 --- a/astrbot/core/zip_updator.py +++ b/astrbot/core/zip_updator.py @@ -10,6 +10,7 @@ from astrbot.core import logger from astrbot.core.utils.io import download_file, on_error from astrbot.core.utils.version_comparator import VersionComparator +from typing import NoReturn class ReleaseInfo: @@ -101,10 +102,10 @@ def github_api_release_parser(self, releases: list) -> list: ) return ret - def unzip(self): + def unzip(self) -> NoReturn: raise NotImplementedError - async def update(self): + async def update(self) -> NoReturn: raise NotImplementedError def compare_version(self, v1: str, v2: str) -> int: @@ -148,7 +149,7 @@ async def check_update( body=f"{tag_name}\n\n{sel_release_data['body']}", ) - async def download_from_repo_url(self, target_path: str, repo_url: str, proxy=""): + async def download_from_repo_url(self, target_path: str, repo_url: str, proxy="") -> None: author, repo, branch = self.parse_github_url(repo_url) logger.info(f"正在下载更新 {repo} ...") @@ -203,7 +204,7 @@ def parse_github_url(self, url: str): return author, repo, branch raise ValueError("无效的 GitHub URL") - def unzip_file(self, zip_path: str, target_dir: str): + def unzip_file(self, zip_path: str, target_dir: str) -> None: """解压缩文件, 并将压缩包内**第一个**文件夹内的文件移动到 target_dir""" os.makedirs(target_dir, exist_ok=True) update_dir = "" diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index e8f17cc99..642e4aefa 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -122,7 +122,7 @@ def validate(data: dict, metadata: dict = schema, path=""): return errors, data -def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False): +def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False) -> None: """验证并保存配置""" errors = None logger.info(f"Saving config, is_core={is_core}") diff --git a/astrbot/dashboard/routes/route.py b/astrbot/dashboard/routes/route.py index 72e359737..9fe3c307d 100644 --- a/astrbot/dashboard/routes/route.py +++ b/astrbot/dashboard/routes/route.py @@ -18,7 +18,7 @@ def __init__(self, context: RouteContext) -> None: self.app = context.app self.config = context.config - def register_routes(self): + def register_routes(self) -> None: def _add_rule(path, method, func): # 统一添加 /api 前缀 full_path = f"/api{path}" diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 09ec76b52..888da9009 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -238,6 +238,6 @@ def run(self): shutdown_trigger=self.shutdown_trigger, ) - async def shutdown_trigger(self): + async def shutdown_trigger(self) -> None: await self.shutdown_event.wait() logger.info("AstrBot WebUI 已经被优雅地关闭") diff --git a/main.py b/main.py index 60879f065..339e3a728 100644 --- a/main.py +++ b/main.py @@ -25,7 +25,7 @@ """ -def check_env(): +def check_env() -> None: if not (sys.version_info.major == 3 and sys.version_info.minor >= 10): logger.error("请使用 Python3.10+ 运行本项目。") exit() diff --git a/packages/astrbot/long_term_memory.py b/packages/astrbot/long_term_memory.py index f1cb0ec02..52d015d4f 100644 --- a/packages/astrbot/long_term_memory.py +++ b/packages/astrbot/long_term_memory.py @@ -111,7 +111,7 @@ async def need_active_reply(self, event: AstrMessageEvent) -> bool: return False - async def handle_message(self, event: AstrMessageEvent): + async def handle_message(self, event: AstrMessageEvent) -> None: """仅支持群聊""" if event.get_message_type() == MessageType.GROUP_MESSAGE: datetime_str = datetime.datetime.now().strftime("%H:%M:%S") @@ -148,7 +148,7 @@ async def handle_message(self, event: AstrMessageEvent): if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]: self.session_chats[event.unified_msg_origin].pop(0) - async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest): + async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest) -> None: """当触发 LLM 请求前,调用此方法修改 req""" if event.unified_msg_origin not in self.session_chats: return @@ -171,7 +171,7 @@ async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest): ) req.system_prompt += chats_str - async def after_req_llm(self, event: AstrMessageEvent, llm_resp: LLMResponse): + async def after_req_llm(self, event: AstrMessageEvent, llm_resp: LLMResponse) -> None: if event.unified_msg_origin not in self.session_chats: return diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index 09859ab95..91e5537b2 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -89,7 +89,7 @@ async def on_message(self, event: AstrMessageEvent): logger.error(f"主动回复失败: {e}") @filter.on_llm_request() - async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest): + async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest) -> None: """在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt""" await self.proc_llm_req.process_llm_request(event, req) @@ -100,7 +100,7 @@ async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest): logger.error(f"ltm: {e}") @filter.on_llm_response() - async def inject_reasoning(self, event: AstrMessageEvent, resp: LLMResponse): + async def inject_reasoning(self, event: AstrMessageEvent, resp: LLMResponse) -> None: """在 LLM 响应后基于配置注入思考过程文本 / 在 LLM 响应后记录对话""" umo = event.unified_msg_origin cfg = self.context.get_config(umo).get("provider_settings", {}) @@ -117,7 +117,7 @@ async def inject_reasoning(self, event: AstrMessageEvent, resp: LLMResponse): logger.error(f"ltm: {e}") @filter.after_message_sent() - async def after_message_sent(self, event: AstrMessageEvent): + async def after_message_sent(self, event: AstrMessageEvent) -> None: """消息发送后处理""" if self.ltm and self.ltm_enabled(event): try: diff --git a/packages/astrbot/process_llm_request.py b/packages/astrbot/process_llm_request.py index f11987748..3a097d76b 100644 --- a/packages/astrbot/process_llm_request.py +++ b/packages/astrbot/process_llm_request.py @@ -115,7 +115,7 @@ async def _request_img_caption( f"Cannot get image caption because provider `{provider_id}` is not exist.", ) - async def process_llm_request(self, event: AstrMessageEvent, req: ProviderRequest): + async def process_llm_request(self, event: AstrMessageEvent, req: ProviderRequest) -> None: """在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt""" cfg: dict = self.ctx.get_config(umo=event.unified_msg_origin)[ "provider_settings" diff --git a/packages/builtin_commands/commands/admin.py b/packages/builtin_commands/commands/admin.py index 5725bd39a..11bfb20d3 100644 --- a/packages/builtin_commands/commands/admin.py +++ b/packages/builtin_commands/commands/admin.py @@ -8,7 +8,7 @@ class AdminCommands: def __init__(self, context: star.Context) -> None: self.context = context - async def op(self, event: AstrMessageEvent, admin_id: str = ""): + async def op(self, event: AstrMessageEvent, admin_id: str = "") -> None: """授权管理员。op """ if not admin_id: event.set_result( @@ -21,7 +21,7 @@ async def op(self, event: AstrMessageEvent, admin_id: str = ""): self.context.get_config().save_config() event.set_result(MessageEventResult().message("授权成功。")) - async def deop(self, event: AstrMessageEvent, admin_id: str = ""): + async def deop(self, event: AstrMessageEvent, admin_id: str = "") -> None: """取消授权管理员。deop """ if not admin_id: event.set_result( @@ -39,7 +39,7 @@ async def deop(self, event: AstrMessageEvent, admin_id: str = ""): MessageEventResult().message("此用户 ID 不在管理员名单内。"), ) - async def wl(self, event: AstrMessageEvent, sid: str = ""): + async def wl(self, event: AstrMessageEvent, sid: str = "") -> None: """添加白名单。wl """ if not sid: event.set_result( @@ -53,7 +53,7 @@ async def wl(self, event: AstrMessageEvent, sid: str = ""): cfg.save_config() event.set_result(MessageEventResult().message("添加白名单成功。")) - async def dwl(self, event: AstrMessageEvent, sid: str = ""): + async def dwl(self, event: AstrMessageEvent, sid: str = "") -> None: """删除白名单。dwl """ if not sid: event.set_result( @@ -70,7 +70,7 @@ async def dwl(self, event: AstrMessageEvent, sid: str = ""): except ValueError: event.set_result(MessageEventResult().message("此 SID 不在白名单内。")) - async def update_dashboard(self, event: AstrMessageEvent): + async def update_dashboard(self, event: AstrMessageEvent) -> None: await event.send(MessageChain().message("正在尝试更新管理面板...")) await download_dashboard(version=f"v{VERSION}", latest=False) await event.send(MessageChain().message("管理面板更新完成。")) diff --git a/packages/builtin_commands/commands/alter_cmd.py b/packages/builtin_commands/commands/alter_cmd.py index 6f62e8c36..ba31c3326 100644 --- a/packages/builtin_commands/commands/alter_cmd.py +++ b/packages/builtin_commands/commands/alter_cmd.py @@ -14,7 +14,7 @@ class AlterCmdCommands(CommandParserMixin): def __init__(self, context: star.Context) -> None: self.context = context - async def update_reset_permission(self, scene_key: str, perm_type: str): + async def update_reset_permission(self, scene_key: str, perm_type: str) -> None: """更新reset命令在特定场景下的权限设置""" from astrbot.api import sp @@ -26,7 +26,7 @@ async def update_reset_permission(self, scene_key: str, perm_type: str): alter_cmd_cfg["astrbot"] = plugin_cfg await sp.global_put("alter_cmd", alter_cmd_cfg) - async def alter_cmd(self, event: AstrMessageEvent): + async def alter_cmd(self, event: AstrMessageEvent) -> None: token = self.parse_commands(event.message_str) if token.len < 3: await event.send( diff --git a/packages/builtin_commands/commands/conversation.py b/packages/builtin_commands/commands/conversation.py index 6f0bda4a8..eb8cfdefa 100644 --- a/packages/builtin_commands/commands/conversation.py +++ b/packages/builtin_commands/commands/conversation.py @@ -33,7 +33,7 @@ async def _get_current_persona_id(self, session_id): return None return conv.persona_id - async def reset(self, message: AstrMessageEvent): + async def reset(self, message: AstrMessageEvent) -> None: """重置 LLM 会话""" umo = message.unified_msg_origin cfg = self.context.get_config(umo=message.unified_msg_origin) @@ -98,7 +98,7 @@ async def reset(self, message: AstrMessageEvent): message.set_result(MessageEventResult().message(ret)) - async def his(self, message: AstrMessageEvent, page: int = 1): + async def his(self, message: AstrMessageEvent, page: int = 1) -> None: """查看对话记录""" if not self.context.get_using_provider(message.unified_msg_origin): message.set_result( @@ -141,7 +141,7 @@ async def his(self, message: AstrMessageEvent, page: int = 1): message.set_result(MessageEventResult().message(ret).use_t2i(False)) - async def convs(self, message: AstrMessageEvent, page: int = 1): + async def convs(self, message: AstrMessageEvent, page: int = 1) -> None: """查看对话列表""" cfg = self.context.get_config(umo=message.unified_msg_origin) agent_runner_type = cfg["provider_settings"]["agent_runner_type"] @@ -216,7 +216,7 @@ async def convs(self, message: AstrMessageEvent, page: int = 1): message.set_result(MessageEventResult().message(ret).use_t2i(False)) return - async def new_conv(self, message: AstrMessageEvent): + async def new_conv(self, message: AstrMessageEvent) -> None: """创建新对话""" cfg = self.context.get_config(umo=message.unified_msg_origin) agent_runner_type = cfg["provider_settings"]["agent_runner_type"] @@ -242,7 +242,7 @@ async def new_conv(self, message: AstrMessageEvent): MessageEventResult().message(f"切换到新对话: 新对话({cid[:4]})。"), ) - async def groupnew_conv(self, message: AstrMessageEvent, sid: str = ""): + async def groupnew_conv(self, message: AstrMessageEvent, sid: str = "") -> None: """创建新群聊对话""" if sid: session = str( @@ -273,7 +273,7 @@ async def switch_conv( self, message: AstrMessageEvent, index: int | None = None, - ): + ) -> None: """通过 /ls 前面的序号切换对话""" if not isinstance(index, int): message.set_result( @@ -308,7 +308,7 @@ async def switch_conv( ), ) - async def rename_conv(self, message: AstrMessageEvent, new_name: str = ""): + async def rename_conv(self, message: AstrMessageEvent, new_name: str = "") -> None: """重命名对话""" if not new_name: message.set_result(MessageEventResult().message("请输入新的对话名称。")) @@ -319,7 +319,7 @@ async def rename_conv(self, message: AstrMessageEvent, new_name: str = ""): ) message.set_result(MessageEventResult().message("重命名对话成功。")) - async def del_conv(self, message: AstrMessageEvent): + async def del_conv(self, message: AstrMessageEvent) -> None: """删除当前对话""" cfg = self.context.get_config(umo=message.unified_msg_origin) is_unique_session = cfg["platform_settings"]["unique_session"] diff --git a/packages/builtin_commands/commands/help.py b/packages/builtin_commands/commands/help.py index 8874328b1..7e7cdbb95 100644 --- a/packages/builtin_commands/commands/help.py +++ b/packages/builtin_commands/commands/help.py @@ -21,7 +21,7 @@ async def _query_astrbot_notice(self): except BaseException: return "" - async def help(self, event: AstrMessageEvent): + async def help(self, event: AstrMessageEvent) -> None: """查看帮助""" notice = "" try: diff --git a/packages/builtin_commands/commands/llm.py b/packages/builtin_commands/commands/llm.py index 7d599d6fa..ba9ba5c9b 100644 --- a/packages/builtin_commands/commands/llm.py +++ b/packages/builtin_commands/commands/llm.py @@ -6,7 +6,7 @@ class LLMCommands: def __init__(self, context: star.Context) -> None: self.context = context - async def llm(self, event: AstrMessageEvent): + async def llm(self, event: AstrMessageEvent) -> None: """开启/关闭 LLM""" cfg = self.context.get_config(umo=event.unified_msg_origin) enable = cfg["provider_settings"].get("enable", True) diff --git a/packages/builtin_commands/commands/persona.py b/packages/builtin_commands/commands/persona.py index 155ac333a..1a5ddb848 100644 --- a/packages/builtin_commands/commands/persona.py +++ b/packages/builtin_commands/commands/persona.py @@ -8,7 +8,7 @@ class PersonaCommands: def __init__(self, context: star.Context) -> None: self.context = context - async def persona(self, message: AstrMessageEvent): + async def persona(self, message: AstrMessageEvent) -> None: l = message.message_str.split(" ") # noqa: E741 umo = message.unified_msg_origin diff --git a/packages/builtin_commands/commands/plugin.py b/packages/builtin_commands/commands/plugin.py index 327606ced..49bee9462 100644 --- a/packages/builtin_commands/commands/plugin.py +++ b/packages/builtin_commands/commands/plugin.py @@ -11,7 +11,7 @@ class PluginCommands: def __init__(self, context: star.Context) -> None: self.context = context - async def plugin_ls(self, event: AstrMessageEvent): + async def plugin_ls(self, event: AstrMessageEvent) -> None: """获取已经安装的插件列表。""" parts = ["已加载的插件:\n"] for plugin in self.context.get_all_stars(): @@ -30,7 +30,7 @@ async def plugin_ls(self, event: AstrMessageEvent): MessageEventResult().message(f"{plugin_list_info}").use_t2i(False), ) - async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = ""): + async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = "") -> None: """禁用插件""" if DEMO_MODE: event.set_result(MessageEventResult().message("演示模式下无法禁用插件。")) @@ -43,7 +43,7 @@ async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = ""): await self.context._star_manager.turn_off_plugin(plugin_name) # type: ignore event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已禁用。")) - async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = ""): + async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = "") -> None: """启用插件""" if DEMO_MODE: event.set_result(MessageEventResult().message("演示模式下无法启用插件。")) @@ -56,7 +56,7 @@ async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = ""): await self.context._star_manager.turn_on_plugin(plugin_name) # type: ignore event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已启用。")) - async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = ""): + async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = "") -> None: """安装插件""" if DEMO_MODE: event.set_result(MessageEventResult().message("演示模式下无法安装插件。")) @@ -77,7 +77,7 @@ async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = ""): event.set_result(MessageEventResult().message(f"安装插件失败: {e}")) return - async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = ""): + async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = "") -> None: """获取插件帮助""" if not plugin_name: event.set_result( diff --git a/packages/builtin_commands/commands/provider.py b/packages/builtin_commands/commands/provider.py index 09f685dca..408511d1a 100644 --- a/packages/builtin_commands/commands/provider.py +++ b/packages/builtin_commands/commands/provider.py @@ -49,7 +49,7 @@ async def provider( event: AstrMessageEvent, idx: str | int | None = None, idx2: int | None = None, - ): + ) -> None: """查看或者切换 LLM Provider""" umo = event.unified_msg_origin cfg = self.context.get_config(umo).get("provider_settings", {}) @@ -226,7 +226,7 @@ async def model_ls( self, message: AstrMessageEvent, idx_or_name: int | str | None = None, - ): + ) -> None: """查看或者切换模型""" prov = self.context.get_using_provider(message.unified_msg_origin) if not prov: @@ -291,7 +291,7 @@ async def model_ls( MessageEventResult().message(f"切换模型到 {prov.get_model()}。"), ) - async def key(self, message: AstrMessageEvent, index: int | None = None): + async def key(self, message: AstrMessageEvent, index: int | None = None) -> None: prov = self.context.get_using_provider(message.unified_msg_origin) if not prov: message.set_result( diff --git a/packages/builtin_commands/commands/setunset.py b/packages/builtin_commands/commands/setunset.py index 06a45ff6c..096698844 100644 --- a/packages/builtin_commands/commands/setunset.py +++ b/packages/builtin_commands/commands/setunset.py @@ -6,7 +6,7 @@ class SetUnsetCommands: def __init__(self, context: star.Context) -> None: self.context = context - async def set_variable(self, event: AstrMessageEvent, key: str, value: str): + async def set_variable(self, event: AstrMessageEvent, key: str, value: str) -> None: """设置会话变量""" uid = event.unified_msg_origin session_var = await sp.session_get(uid, "session_variables", {}) @@ -19,7 +19,7 @@ async def set_variable(self, event: AstrMessageEvent, key: str, value: str): ), ) - async def unset_variable(self, event: AstrMessageEvent, key: str): + async def unset_variable(self, event: AstrMessageEvent, key: str) -> None: """移除会话变量""" uid = event.unified_msg_origin session_var = await sp.session_get(uid, "session_variables", {}) diff --git a/packages/builtin_commands/commands/sid.py b/packages/builtin_commands/commands/sid.py index 35525ccdd..e8bdbffb1 100644 --- a/packages/builtin_commands/commands/sid.py +++ b/packages/builtin_commands/commands/sid.py @@ -10,7 +10,7 @@ class SIDCommand: def __init__(self, context: star.Context) -> None: self.context = context - async def sid(self, event: AstrMessageEvent): + async def sid(self, event: AstrMessageEvent) -> None: """获取消息来源信息""" sid = event.unified_msg_origin user_id = str(event.get_sender_id()) diff --git a/packages/builtin_commands/commands/t2i.py b/packages/builtin_commands/commands/t2i.py index 76e5acf3f..78d6b0df7 100644 --- a/packages/builtin_commands/commands/t2i.py +++ b/packages/builtin_commands/commands/t2i.py @@ -10,7 +10,7 @@ class T2ICommand: def __init__(self, context: star.Context) -> None: self.context = context - async def t2i(self, event: AstrMessageEvent): + async def t2i(self, event: AstrMessageEvent) -> None: """开关文本转图片""" config = self.context.get_config(umo=event.unified_msg_origin) if config["t2i"]: diff --git a/packages/builtin_commands/commands/tool.py b/packages/builtin_commands/commands/tool.py index bec9b53ac..09b239b8c 100644 --- a/packages/builtin_commands/commands/tool.py +++ b/packages/builtin_commands/commands/tool.py @@ -6,25 +6,25 @@ class ToolCommands: def __init__(self, context: star.Context) -> None: self.context = context - async def tool_ls(self, event: AstrMessageEvent): + async def tool_ls(self, event: AstrMessageEvent) -> None: """查看函数工具列表""" event.set_result( MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。"), ) - async def tool_on(self, event: AstrMessageEvent, tool_name: str = ""): + async def tool_on(self, event: AstrMessageEvent, tool_name: str = "") -> None: """启用一个函数工具""" event.set_result( MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。"), ) - async def tool_off(self, event: AstrMessageEvent, tool_name: str = ""): + async def tool_off(self, event: AstrMessageEvent, tool_name: str = "") -> None: """停用一个函数工具""" event.set_result( MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。"), ) - async def tool_all_off(self, event: AstrMessageEvent): + async def tool_all_off(self, event: AstrMessageEvent) -> None: """停用所有函数工具""" event.set_result( MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。"), diff --git a/packages/builtin_commands/commands/tts.py b/packages/builtin_commands/commands/tts.py index 22af09238..5c245ed26 100644 --- a/packages/builtin_commands/commands/tts.py +++ b/packages/builtin_commands/commands/tts.py @@ -11,7 +11,7 @@ class TTSCommand: def __init__(self, context: star.Context) -> None: self.context = context - async def tts(self, event: AstrMessageEvent): + async def tts(self, event: AstrMessageEvent) -> None: """开关文本转语音(会话级别)""" umo = event.unified_msg_origin ses_tts = SessionServiceManager.is_tts_enabled_for_session(umo) diff --git a/packages/builtin_commands/main.py b/packages/builtin_commands/main.py index 291bed456..57a7217f8 100644 --- a/packages/builtin_commands/main.py +++ b/packages/builtin_commands/main.py @@ -37,108 +37,108 @@ def __init__(self, context: star.Context) -> None: self.sid_c = SIDCommand(self.context) @filter.command("help") - async def help(self, event: AstrMessageEvent): + async def help(self, event: AstrMessageEvent) -> None: """查看帮助""" await self.help_c.help(event) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("llm") - async def llm(self, event: AstrMessageEvent): + async def llm(self, event: AstrMessageEvent) -> None: """开启/关闭 LLM""" await self.llm_c.llm(event) @filter.command_group("tool") - def tool(self): + def tool(self) -> None: pass @tool.command("ls") - async def tool_ls(self, event: AstrMessageEvent): + async def tool_ls(self, event: AstrMessageEvent) -> None: """查看函数工具列表""" await self.tool_c.tool_ls(event) @tool.command("on") - async def tool_on(self, event: AstrMessageEvent, tool_name: str): + async def tool_on(self, event: AstrMessageEvent, tool_name: str) -> None: """启用一个函数工具""" await self.tool_c.tool_on(event, tool_name) @tool.command("off") - async def tool_off(self, event: AstrMessageEvent, tool_name: str): + async def tool_off(self, event: AstrMessageEvent, tool_name: str) -> None: """停用一个函数工具""" await self.tool_c.tool_off(event, tool_name) @tool.command("off_all") - async def tool_all_off(self, event: AstrMessageEvent): + async def tool_all_off(self, event: AstrMessageEvent) -> None: """停用所有函数工具""" await self.tool_c.tool_all_off(event) @filter.command_group("plugin") - def plugin(self): + def plugin(self) -> None: pass @plugin.command("ls") - async def plugin_ls(self, event: AstrMessageEvent): + async def plugin_ls(self, event: AstrMessageEvent) -> None: """获取已经安装的插件列表。""" await self.plugin_c.plugin_ls(event) @filter.permission_type(filter.PermissionType.ADMIN) @plugin.command("off") - async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = ""): + async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = "") -> None: """禁用插件""" await self.plugin_c.plugin_off(event, plugin_name) @filter.permission_type(filter.PermissionType.ADMIN) @plugin.command("on") - async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = ""): + async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = "") -> None: """启用插件""" await self.plugin_c.plugin_on(event, plugin_name) @filter.permission_type(filter.PermissionType.ADMIN) @plugin.command("get") - async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = ""): + async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = "") -> None: """安装插件""" await self.plugin_c.plugin_get(event, plugin_repo) @plugin.command("help") - async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = ""): + async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = "") -> None: """获取插件帮助""" await self.plugin_c.plugin_help(event, plugin_name) @filter.command("t2i") - async def t2i(self, event: AstrMessageEvent): + async def t2i(self, event: AstrMessageEvent) -> None: """开关文本转图片""" await self.t2i_c.t2i(event) @filter.command("tts") - async def tts(self, event: AstrMessageEvent): + async def tts(self, event: AstrMessageEvent) -> None: """开关文本转语音(会话级别)""" await self.tts_c.tts(event) @filter.command("sid") - async def sid(self, event: AstrMessageEvent): + async def sid(self, event: AstrMessageEvent) -> None: """获取会话 ID 和 管理员 ID""" await self.sid_c.sid(event) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("op") - async def op(self, event: AstrMessageEvent, admin_id: str = ""): + async def op(self, event: AstrMessageEvent, admin_id: str = "") -> None: """授权管理员。op """ await self.admin_c.op(event, admin_id) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("deop") - async def deop(self, event: AstrMessageEvent, admin_id: str): + async def deop(self, event: AstrMessageEvent, admin_id: str) -> None: """取消授权管理员。deop """ await self.admin_c.deop(event, admin_id) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("wl") - async def wl(self, event: AstrMessageEvent, sid: str = ""): + async def wl(self, event: AstrMessageEvent, sid: str = "") -> None: """添加白名单。wl """ await self.admin_c.wl(event, sid) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("dwl") - async def dwl(self, event: AstrMessageEvent, sid: str): + async def dwl(self, event: AstrMessageEvent, sid: str) -> None: """删除白名单。dwl """ await self.admin_c.dwl(event, sid) @@ -149,12 +149,12 @@ async def provider( event: AstrMessageEvent, idx: str | int | None = None, idx2: int | None = None, - ): + ) -> None: """查看或者切换 LLM Provider""" await self.provider_c.provider(event, idx, idx2) @filter.command("reset") - async def reset(self, message: AstrMessageEvent): + async def reset(self, message: AstrMessageEvent) -> None: """重置 LLM 会话""" await self.conversation_c.reset(message) @@ -164,73 +164,73 @@ async def model_ls( self, message: AstrMessageEvent, idx_or_name: int | str | None = None, - ): + ) -> None: """查看或者切换模型""" await self.provider_c.model_ls(message, idx_or_name) @filter.command("history") - async def his(self, message: AstrMessageEvent, page: int = 1): + async def his(self, message: AstrMessageEvent, page: int = 1) -> None: """查看对话记录""" await self.conversation_c.his(message, page) @filter.command("ls") - async def convs(self, message: AstrMessageEvent, page: int = 1): + async def convs(self, message: AstrMessageEvent, page: int = 1) -> None: """查看对话列表""" await self.conversation_c.convs(message, page) @filter.command("new") - async def new_conv(self, message: AstrMessageEvent): + async def new_conv(self, message: AstrMessageEvent) -> None: """创建新对话""" await self.conversation_c.new_conv(message) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("groupnew") - async def groupnew_conv(self, message: AstrMessageEvent, sid: str): + async def groupnew_conv(self, message: AstrMessageEvent, sid: str) -> None: """创建新群聊对话""" await self.conversation_c.groupnew_conv(message, sid) @filter.command("switch") - async def switch_conv(self, message: AstrMessageEvent, index: int | None = None): + async def switch_conv(self, message: AstrMessageEvent, index: int | None = None) -> None: """通过 /ls 前面的序号切换对话""" await self.conversation_c.switch_conv(message, index) @filter.command("rename") - async def rename_conv(self, message: AstrMessageEvent, new_name: str): + async def rename_conv(self, message: AstrMessageEvent, new_name: str) -> None: """重命名对话""" await self.conversation_c.rename_conv(message, new_name) @filter.command("del") - async def del_conv(self, message: AstrMessageEvent): + async def del_conv(self, message: AstrMessageEvent) -> None: """删除当前对话""" await self.conversation_c.del_conv(message) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("key") - async def key(self, message: AstrMessageEvent, index: int | None = None): + async def key(self, message: AstrMessageEvent, index: int | None = None) -> None: """查看或者切换 Key""" await self.provider_c.key(message, index) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("persona") - async def persona(self, message: AstrMessageEvent): + async def persona(self, message: AstrMessageEvent) -> None: """查看或者切换 Persona""" await self.persona_c.persona(message) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("dashboard_update") - async def update_dashboard(self, event: AstrMessageEvent): + async def update_dashboard(self, event: AstrMessageEvent) -> None: await self.admin_c.update_dashboard(event) @filter.command("set") - async def set_variable(self, event: AstrMessageEvent, key: str, value: str): + async def set_variable(self, event: AstrMessageEvent, key: str, value: str) -> None: await self.setunset_c.set_variable(event, key, value) @filter.command("unset") - async def unset_variable(self, event: AstrMessageEvent, key: str): + async def unset_variable(self, event: AstrMessageEvent, key: str) -> None: await self.setunset_c.unset_variable(event, key) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("alter_cmd", alias={"alter"}) - async def alter_cmd(self, event: AstrMessageEvent): + async def alter_cmd(self, event: AstrMessageEvent) -> None: """修改命令权限""" await self.alter_cmd_c.alter_cmd(event) diff --git a/packages/python_interpreter/main.py b/packages/python_interpreter/main.py index 98496157a..8646f8f3a 100644 --- a/packages/python_interpreter/main.py +++ b/packages/python_interpreter/main.py @@ -124,7 +124,7 @@ def __init__(self, context: star.Context) -> None: with open(PATH) as f: self.config = json.load(f) - async def initialize(self): + async def initialize(self) -> None: ok = await self.is_docker_available() if not ok: logger.info( @@ -240,7 +240,7 @@ async def on_message(self, event: AstrMessageEvent): del self.user_waiting[uid] @filter.on_llm_request() - async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest): + async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest) -> None: if event.get_session_id() in self.user_file_msg_buffer: files = self.user_file_msg_buffer[event.get_session_id()] if not request.prompt: @@ -248,7 +248,7 @@ async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest): request.prompt += f"\nUser provided files: {files}" @filter.command_group("pi") - def pi(self): + def pi(self) -> None: pass @pi.command("absdir") diff --git a/packages/python_interpreter/shared/api.py b/packages/python_interpreter/shared/api.py index 287773fb0..f4cf51ff2 100644 --- a/packages/python_interpreter/shared/api.py +++ b/packages/python_interpreter/shared/api.py @@ -6,17 +6,17 @@ def _get_magic_code(): return os.getenv("MAGIC_CODE") -def send_text(text: str): +def send_text(text: str) -> None: print(f"[ASTRBOT_TEXT_OUTPUT#{_get_magic_code()}]: {text}") -def send_image(image_path: str): +def send_image(image_path: str) -> None: if not os.path.exists(image_path): raise Exception(f"Image file not found: {image_path}") print(f"[ASTRBOT_IMAGE_OUTPUT#{_get_magic_code()}]: {image_path}") -def send_file(file_path: str): +def send_file(file_path: str) -> None: if not os.path.exists(file_path): raise Exception(f"File not found: {file_path}") print(f"[ASTRBOT_FILE_OUTPUT#{_get_magic_code()}]: {file_path}") diff --git a/packages/reminder/main.py b/packages/reminder/main.py index 8f61e02fe..fcc849aa9 100644 --- a/packages/reminder/main.py +++ b/packages/reminder/main.py @@ -178,7 +178,7 @@ async def reminder_tool( ) @filter.command_group("reminder") - def reminder(self): + def reminder(self) -> None: """The command group of the reminder.""" async def get_upcoming_reminders(self, unified_msg_origin: str): @@ -260,7 +260,7 @@ async def _reminder_callback(self, unified_msg_origin: str, d: dict): ), ) - async def terminate(self): + async def terminate(self) -> None: self.scheduler.shutdown() await self._save_data() logger.info("Reminder plugin terminated.") diff --git a/packages/session_controller/main.py b/packages/session_controller/main.py index 3e2106a9e..b1aecc766 100644 --- a/packages/session_controller/main.py +++ b/packages/session_controller/main.py @@ -21,7 +21,7 @@ def __init__(self, context: Context) -> None: super().__init__(context) @filter.event_message_type(filter.EventMessageType.ALL, priority=maxsize) - async def handle_session_control_agent(self, event: AstrMessageEvent): + async def handle_session_control_agent(self, event: AstrMessageEvent) -> None: """会话控制代理""" for session_filter in FILTERS: session_id = session_filter.filter(event) diff --git a/packages/web_searcher/main.py b/packages/web_searcher/main.py index 118ef2483..7d5f14149 100644 --- a/packages/web_searcher/main.py +++ b/packages/web_searcher/main.py @@ -184,7 +184,7 @@ async def _extract_tavily(self, cfg: AstrBotConfig, payload: dict) -> list[dict] return results @filter.command("websearch") - async def websearch(self, event: AstrMessageEvent, oper: str | None = None): + async def websearch(self, event: AstrMessageEvent, oper: str | None = None) -> None: event.set_result( MessageEventResult().message( "此指令已经被废弃,请在 WebUI 中开启或关闭网页搜索功能。", @@ -230,7 +230,7 @@ async def search_from_search_engine( return ret - async def ensure_baidu_ai_search_mcp(self, umo: str | None = None): + async def ensure_baidu_ai_search_mcp(self, umo: str | None = None) -> None: if self.baidu_initialized: return cfg = self.context.get_config(umo=umo) @@ -378,7 +378,7 @@ async def edit_web_search_tools( self, event: AstrMessageEvent, req: ProviderRequest, - ): + ) -> None: """Get the session conversation for the given event.""" cfg = self.context.get_config(umo=event.unified_msg_origin) prov_settings = cfg.get("provider_settings", {}) diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index f5439e9d5..d53984814 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -61,7 +61,7 @@ async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecyc @pytest.mark.asyncio -async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle): +async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle) -> None: """Tests the login functionality with both wrong and correct credentials.""" test_client = app.test_client() response = await test_client.post( @@ -83,7 +83,7 @@ async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle): @pytest.mark.asyncio -async def test_get_stat(app: Quart, authenticated_header: dict): +async def test_get_stat(app: Quart, authenticated_header: dict) -> None: test_client = app.test_client() response = await test_client.get("/api/stat/get") assert response.status_code == 401 @@ -94,7 +94,7 @@ async def test_get_stat(app: Quart, authenticated_header: dict): @pytest.mark.asyncio -async def test_plugins(app: Quart, authenticated_header: dict): +async def test_plugins(app: Quart, authenticated_header: dict) -> None: test_client = app.test_client() # 已经安装的插件 response = await test_client.get("/api/plugin/get", headers=authenticated_header) @@ -161,7 +161,7 @@ async def test_plugins(app: Quart, authenticated_header: dict): @pytest.mark.asyncio -async def test_check_update(app: Quart, authenticated_header: dict): +async def test_check_update(app: Quart, authenticated_header: dict) -> None: test_client = app.test_client() response = await test_client.get("/api/update/check", headers=authenticated_header) assert response.status_code == 200 @@ -176,7 +176,7 @@ async def test_do_update( core_lifecycle_td: AstrBotCoreLifecycle, monkeypatch, tmp_path_factory, -): +) -> None: test_client = app.test_client() # Use a temporary path for the mock update to avoid side effects diff --git a/tests/test_main.py b/tests/test_main.py index c70fe6865..d84cd44c9 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -17,7 +17,7 @@ def __init__(self, major, minor) -> None: self.minor = minor -def test_check_env(monkeypatch): +def test_check_env(monkeypatch) -> None: version_info_correct = _version_info(3, 10) version_info_wrong = _version_info(3, 9) monkeypatch.setattr(sys, "version_info", version_info_correct) @@ -33,7 +33,7 @@ def test_check_env(monkeypatch): @pytest.mark.asyncio -async def test_check_dashboard_files_not_exists(monkeypatch): +async def test_check_dashboard_files_not_exists(monkeypatch) -> None: """Tests dashboard download when files do not exist.""" monkeypatch.setattr(os.path, "exists", lambda x: False) @@ -43,7 +43,7 @@ async def test_check_dashboard_files_not_exists(monkeypatch): @pytest.mark.asyncio -async def test_check_dashboard_files_exists_and_version_match(monkeypatch): +async def test_check_dashboard_files_exists_and_version_match(monkeypatch) -> None: """Tests that dashboard is not downloaded when it exists and version matches.""" # Mock os.path.exists to return True monkeypatch.setattr(os.path, "exists", lambda x: True) @@ -62,7 +62,7 @@ async def test_check_dashboard_files_exists_and_version_match(monkeypatch): @pytest.mark.asyncio -async def test_check_dashboard_files_exists_but_version_mismatch(monkeypatch): +async def test_check_dashboard_files_exists_but_version_mismatch(monkeypatch) -> None: """Tests that a warning is logged when dashboard version mismatches.""" monkeypatch.setattr(os.path, "exists", lambda x: True) @@ -77,7 +77,7 @@ async def test_check_dashboard_files_exists_but_version_mismatch(monkeypatch): @pytest.mark.asyncio -async def test_check_dashboard_files_with_webui_dir_arg(monkeypatch): +async def test_check_dashboard_files_with_webui_dir_arg(monkeypatch) -> None: """Tests that providing a valid webui_dir skips all checks.""" valid_dir = "/tmp/my-custom-webui" monkeypatch.setattr(os.path, "exists", lambda path: path == valid_dir) diff --git a/tests/test_plugin_manager.py b/tests/test_plugin_manager.py index 1e4cd866a..6e3cbc2c9 100644 --- a/tests/test_plugin_manager.py +++ b/tests/test_plugin_manager.py @@ -59,21 +59,21 @@ def plugin_manager_pm(tmp_path): return manager -def test_plugin_manager_initialization(plugin_manager_pm: PluginManager): +def test_plugin_manager_initialization(plugin_manager_pm: PluginManager) -> None: assert plugin_manager_pm is not None assert plugin_manager_pm.context is not None assert plugin_manager_pm.config is not None @pytest.mark.asyncio -async def test_plugin_manager_reload(plugin_manager_pm: PluginManager): +async def test_plugin_manager_reload(plugin_manager_pm: PluginManager) -> None: success, err_message = await plugin_manager_pm.reload() assert success is True assert err_message is None @pytest.mark.asyncio -async def test_install_plugin(plugin_manager_pm: PluginManager): +async def test_install_plugin(plugin_manager_pm: PluginManager) -> None: """Tests successful plugin installation in an isolated environment.""" test_repo = "https://github.com/Soulter/astrbot_plugin_essential" plugin_info = await plugin_manager_pm.install_plugin(test_repo) @@ -90,7 +90,7 @@ async def test_install_plugin(plugin_manager_pm: PluginManager): @pytest.mark.asyncio -async def test_install_nonexistent_plugin(plugin_manager_pm: PluginManager): +async def test_install_nonexistent_plugin(plugin_manager_pm: PluginManager) -> None: """Tests that installing a non-existent plugin raises an exception.""" with pytest.raises(Exception): await plugin_manager_pm.install_plugin( @@ -99,7 +99,7 @@ async def test_install_nonexistent_plugin(plugin_manager_pm: PluginManager): @pytest.mark.asyncio -async def test_update_plugin(plugin_manager_pm: PluginManager): +async def test_update_plugin(plugin_manager_pm: PluginManager) -> None: """Tests updating an existing plugin in an isolated environment.""" # First, install the plugin test_repo = "https://github.com/Soulter/astrbot_plugin_essential" @@ -110,14 +110,14 @@ async def test_update_plugin(plugin_manager_pm: PluginManager): @pytest.mark.asyncio -async def test_update_nonexistent_plugin(plugin_manager_pm: PluginManager): +async def test_update_nonexistent_plugin(plugin_manager_pm: PluginManager) -> None: """Tests that updating a non-existent plugin raises an exception.""" with pytest.raises(Exception): await plugin_manager_pm.update_plugin("non_existent_plugin") @pytest.mark.asyncio -async def test_uninstall_plugin(plugin_manager_pm: PluginManager): +async def test_uninstall_plugin(plugin_manager_pm: PluginManager) -> None: """Tests successful plugin uninstallation in an isolated environment.""" # First, install the plugin test_repo = "https://github.com/Soulter/astrbot_plugin_essential" @@ -144,7 +144,7 @@ async def test_uninstall_plugin(plugin_manager_pm: PluginManager): @pytest.mark.asyncio -async def test_uninstall_nonexistent_plugin(plugin_manager_pm: PluginManager): +async def test_uninstall_nonexistent_plugin(plugin_manager_pm: PluginManager) -> None: """Tests that uninstalling a non-existent plugin raises an exception.""" with pytest.raises(Exception): await plugin_manager_pm.uninstall_plugin("non_existent_plugin") diff --git a/tests/test_security_fixes.py b/tests/test_security_fixes.py index d4e455541..06939ebf4 100644 --- a/tests/test_security_fixes.py +++ b/tests/test_security_fixes.py @@ -10,7 +10,7 @@ import pytest -def test_wecom_crypto_uses_secrets(): +def test_wecom_crypto_uses_secrets() -> None: """Test that WXBizJsonMsgCrypt uses secrets module instead of random.""" from astrbot.core.platform.sources.wecom_ai_bot.WXBizJsonMsgCrypt import Prpcrypt @@ -33,7 +33,7 @@ def test_wecom_crypto_uses_secrets(): assert 1000000000000000 <= int(decoded) <= 9999999999999999 -def test_wecomai_utils_uses_secrets(): +def test_wecomai_utils_uses_secrets() -> None: """Test that wecomai_utils uses secrets module for random string generation.""" from astrbot.core.platform.sources.wecom_ai_bot.wecomai_utils import ( generate_random_string, @@ -53,7 +53,7 @@ def test_wecomai_utils_uses_secrets(): assert len(set(random_strings)) >= 19 # Allow for 1 collision in 20 (very unlikely) -def test_azure_tts_signature_uses_secrets(): +def test_azure_tts_signature_uses_secrets() -> None: """Test that Azure TTS signature generation uses secrets module.""" import asyncio @@ -94,7 +94,7 @@ async def test_nonce_generation(): asyncio.run(test_nonce_generation()) -def test_ssl_context_fallback_explicit(): +def test_ssl_context_fallback_explicit() -> None: """Test that SSL context fallback is properly configured.""" # This test verifies the SSL context configuration # We can't easily test the full io.py functions without network calls, @@ -113,7 +113,7 @@ def test_ssl_context_fallback_explicit(): # The actual code only uses this when certificate validation fails -def test_io_module_has_ssl_imports(): +def test_io_module_has_ssl_imports() -> None: """Verify that io.py properly imports ssl module.""" from astrbot.core.utils import io @@ -124,7 +124,7 @@ def test_io_module_has_ssl_imports(): assert hasattr(io.ssl, "CERT_NONE") -def test_secrets_module_randomness_quality(): +def test_secrets_module_randomness_quality() -> None: """Test that secrets module provides high-quality randomness.""" import secrets From 8d766f0cb989b61afd30b598dfac56fcf9d8659a Mon Sep 17 00:00:00 2001 From: Dale Null Date: Thu, 11 Dec 2025 15:55:34 +0800 Subject: [PATCH 26/44] =?UTF-8?q?refactor:=20=E4=B8=BA=20save=5Fmcp=5Fconf?= =?UTF-8?q?ig=20=E6=B7=BB=E5=8A=A0=E8=BF=94=E5=9B=9E=E7=B1=BB=E5=9E=8B?= =?UTF-8?q?=E6=B3=A8=E8=A7=A3=20bool?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/provider/func_tool_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 0f5f26f35..106b42cc5 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -500,7 +500,7 @@ def load_mcp_config(self): logger.error(f"加载 MCP 配置失败: {e}") return DEFAULT_MCP_CONFIG - def save_mcp_config(self, config: dict): + def save_mcp_config(self, config: dict) -> bool: try: with open(self.mcp_config_path, "w", encoding="utf-8") as f: json.dump(config, f, ensure_ascii=False, indent=4) From 1f2a16cdb3d1e318fc2f46a3894b20b9872e3c26 Mon Sep 17 00:00:00 2001 From: Dale Null Date: Thu, 11 Dec 2025 15:56:31 +0800 Subject: [PATCH 27/44] =?UTF-8?q?refactor:=20=E7=BB=9F=E4=B8=80=20typing?= =?UTF-8?q?=20=E5=AF=BC=E5=85=A5=E9=A1=BA=E5=BA=8F=E5=B9=B6=E5=B0=86=20NoR?= =?UTF-8?q?eturn=20=E5=AF=BC=E5=85=A5=E6=8F=90=E5=89=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py | 2 +- .../platform/sources/qqofficial/qqofficial_platform_adapter.py | 2 +- .../platform/sources/qqofficial_webhook/qo_webhook_adapter.py | 2 +- astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py | 2 +- astrbot/core/provider/provider.py | 2 +- astrbot/core/zip_updator.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py index 8e06660ab..090df74d7 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py @@ -2,7 +2,7 @@ import os import threading import uuid -from typing import cast, NoReturn +from typing import NoReturn, cast import aiohttp import dingtalk_stream diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py index 550c4b6fa..4a2bd202d 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py @@ -4,7 +4,7 @@ import logging import os import time -from typing import cast, NoReturn +from typing import NoReturn, cast import botpy import botpy.message diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py index 634bc2de1..61caee5aa 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py @@ -1,6 +1,6 @@ import asyncio import logging -from typing import Any, cast, NoReturn +from typing import Any, NoReturn, cast import botpy import botpy.message diff --git a/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py b/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py index 27b30f830..260b950d1 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py @@ -14,11 +14,11 @@ import socket import struct import time +from typing import NoReturn from Crypto.Cipher import AES from . import ierror -from typing import NoReturn """ 关于Crypto.Cipher模块,ImportError: No module named 'Crypto'解决方案 diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 54df05a34..00a150cc6 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -2,7 +2,7 @@ import asyncio import os from collections.abc import AsyncGenerator -from typing import TypeAlias, Union, NoReturn +from typing import NoReturn, TypeAlias, Union from astrbot.core.agent.message import Message from astrbot.core.agent.tool import ToolSet diff --git a/astrbot/core/zip_updator.py b/astrbot/core/zip_updator.py index 9f7f2dde1..c5bf5b77f 100644 --- a/astrbot/core/zip_updator.py +++ b/astrbot/core/zip_updator.py @@ -3,6 +3,7 @@ import shutil import ssl import zipfile +from typing import NoReturn import aiohttp import certifi @@ -10,7 +11,6 @@ from astrbot.core import logger from astrbot.core.utils.io import download_file, on_error from astrbot.core.utils.version_comparator import VersionComparator -from typing import NoReturn class ReleaseInfo: From 1b564f3507b1767ea5181066959243539cac8696 Mon Sep 17 00:00:00 2001 From: Dale Null Date: Thu, 11 Dec 2025 16:03:56 +0800 Subject: [PATCH 28/44] =?UTF-8?q?refactor:=20=E4=B8=BA=E5=A4=9A=E5=A4=84?= =?UTF-8?q?=E5=87=BD=E6=95=B0=E6=B7=BB=E5=8A=A0=E8=BF=94=E5=9B=9E=E5=80=BC?= =?UTF-8?q?/=E5=8F=82=E6=95=B0=E7=B1=BB=E5=9E=8B=E6=B3=A8=E8=A7=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/agent/mcp_client.py | 4 ++-- astrbot/core/astrbot_config_mgr.py | 2 +- astrbot/core/config/astrbot_config.py | 2 +- astrbot/core/db/migration/shared_preferences_v3.py | 2 +- astrbot/core/db/migration/sqlite_v3.py | 2 +- astrbot/core/db/sqlite.py | 6 +++--- astrbot/core/event_bus.py | 2 +- astrbot/core/file_token_service.py | 2 +- astrbot/core/knowledge_base/kb_helper.py | 2 +- astrbot/core/knowledge_base/kb_mgr.py | 2 +- astrbot/core/log.py | 8 ++++---- astrbot/core/message/components.py | 4 ++-- .../method/agent_sub_stages/internal.py | 12 ++++++------ astrbot/core/pipeline/respond/stage.py | 2 +- astrbot/core/pipeline/scheduler.py | 2 +- astrbot/core/platform/astr_message_event.py | 4 ++-- astrbot/core/platform/astrbot_message.py | 2 +- astrbot/core/platform/manager.py | 2 +- astrbot/core/platform/platform.py | 2 +- .../sources/aiocqhttp/aiocqhttp_message_event.py | 4 ++-- .../sources/aiocqhttp/aiocqhttp_platform_adapter.py | 10 +++++----- .../platform/sources/dingtalk/dingtalk_adapter.py | 4 ++-- .../sources/discord/discord_platform_adapter.py | 8 ++++---- astrbot/core/platform/sources/lark/lark_adapter.py | 4 ++-- .../core/platform/sources/misskey/misskey_adapter.py | 12 ++++++------ astrbot/core/platform/sources/misskey/misskey_api.py | 6 +++--- .../qqofficial/qqofficial_platform_adapter.py | 2 +- .../sources/qqofficial_webhook/qo_webhook_adapter.py | 2 +- .../sources/qqofficial_webhook/qo_webhook_server.py | 2 +- astrbot/core/platform/sources/slack/client.py | 4 ++-- astrbot/core/platform/sources/slack/slack_adapter.py | 4 ++-- astrbot/core/platform/sources/telegram/tg_event.py | 2 +- .../core/platform/sources/webchat/webchat_adapter.py | 2 +- .../sources/wechatpadpro/wechatpadpro_adapter.py | 4 ++-- .../wechatpadpro/wechatpadpro_message_event.py | 10 +++++----- astrbot/core/platform/sources/wecom/wecom_adapter.py | 2 +- astrbot/core/platform/sources/wecom/wecom_event.py | 2 +- .../platform/sources/wecom_ai_bot/wecomai_adapter.py | 6 +++--- .../platform/sources/wecom_ai_bot/wecomai_server.py | 2 +- .../weixin_official_account/weixin_offacc_event.py | 2 +- astrbot/core/provider/entities.py | 2 +- astrbot/core/provider/provider.py | 2 +- astrbot/core/provider/sources/azure_tts_source.py | 4 ++-- .../core/provider/sources/gsv_selfhosted_source.py | 2 +- astrbot/core/provider/sources/openai_source.py | 2 +- .../provider/sources/sensevoice_selfhosted_source.py | 2 +- astrbot/core/provider/sources/whisper_api_source.py | 2 +- .../provider/sources/whisper_selfhosted_source.py | 2 +- astrbot/core/star/star_handler.py | 2 +- astrbot/core/star/star_manager.py | 12 ++++++------ astrbot/core/star/star_tools.py | 2 +- astrbot/core/umop_config_router.py | 2 +- astrbot/core/updator.py | 2 +- astrbot/core/utils/metrics.py | 2 +- astrbot/core/utils/session_waiter.py | 6 +++--- astrbot/core/utils/t2i/template_manager.py | 4 ++-- astrbot/dashboard/routes/chat.py | 2 +- astrbot/dashboard/routes/config.py | 8 ++++---- astrbot/dashboard/routes/knowledge_base.py | 8 ++++---- astrbot/dashboard/routes/platform.py | 2 +- astrbot/dashboard/routes/plugin.py | 2 +- astrbot/dashboard/routes/route.py | 2 +- astrbot/dashboard/routes/static_file.py | 2 +- astrbot/dashboard/server.py | 2 +- packages/astrbot/process_llm_request.py | 4 ++-- packages/builtin_commands/commands/provider.py | 2 +- packages/python_interpreter/main.py | 2 +- packages/reminder/main.py | 6 +++--- packages/session_controller/main.py | 2 +- packages/web_searcher/engines/__init__.py | 3 ++- tests/test_dashboard.py | 6 +++--- tests/test_security_fixes.py | 2 +- 72 files changed, 132 insertions(+), 131 deletions(-) diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index 84b6e83c9..18f4d47e0 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -144,7 +144,7 @@ async def connect_to_server(self, mcp_server_config: dict, name: str) -> None: cfg = _prepare_config(mcp_server_config.copy()) - def logging_callback(msg: str): + def logging_callback(msg: str) -> None: # Handle MCP service error logs print(f"MCP Server {name} Error: {msg}") self.server_errlogs.append(msg) @@ -214,7 +214,7 @@ def logging_callback(msg: str): **cfg, ) - def callback(msg: str): + def callback(msg: str) -> None: # Handle MCP service error logs self.server_errlogs.append(msg) diff --git a/astrbot/core/astrbot_config_mgr.py b/astrbot/core/astrbot_config_mgr.py index 5cc4adb69..c2bfb1c37 100644 --- a/astrbot/core/astrbot_config_mgr.py +++ b/astrbot/core/astrbot_config_mgr.py @@ -56,7 +56,7 @@ def _get_abconf_data(self) -> dict: ) return self.abconf_data - def _load_all_configs(self): + def _load_all_configs(self) -> None: """Load all configurations from the shared preferences.""" abconf_data = self._get_abconf_data() self.abconf_data = abconf_data diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py index aa2ad4fee..875b1647e 100644 --- a/astrbot/core/config/astrbot_config.py +++ b/astrbot/core/config/astrbot_config.py @@ -66,7 +66,7 @@ def _config_schema_to_default_config(self, schema: dict) -> dict: """将 Schema 转换成 Config""" conf = {} - def _parse_schema(schema: dict, conf: dict): + def _parse_schema(schema: dict, conf: dict) -> None: for k, v in schema.items(): if v["type"] not in DEFAULT_VALUE_MAP: raise TypeError( diff --git a/astrbot/core/db/migration/shared_preferences_v3.py b/astrbot/core/db/migration/shared_preferences_v3.py index 3d1771b8b..05b514583 100644 --- a/astrbot/core/db/migration/shared_preferences_v3.py +++ b/astrbot/core/db/migration/shared_preferences_v3.py @@ -23,7 +23,7 @@ def _load_preferences(self): os.remove(self.path) return {} - def _save_preferences(self): + def _save_preferences(self) -> None: with open(self.path, "w") as f: json.dump(self._data, f, indent=4, ensure_ascii=False) f.flush() diff --git a/astrbot/core/db/migration/sqlite_v3.py b/astrbot/core/db/migration/sqlite_v3.py index faa7088f2..59aab736c 100644 --- a/astrbot/core/db/migration/sqlite_v3.py +++ b/astrbot/core/db/migration/sqlite_v3.py @@ -127,7 +127,7 @@ def _get_conn(self, db_path: str) -> sqlite3.Connection: conn.text_factory = str return conn - def _exec_sql(self, sql: str, params: tuple | None = None): + def _exec_sql(self, sql: str, params: tuple | None = None) -> None: conn = self.conn try: c = self.conn.cursor() diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index 02ef5ea6d..5dfaf3d67 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -699,7 +699,7 @@ async def _inner(): result = None - def runner(): + def runner() -> None: nonlocal result result = asyncio.run(_inner()) @@ -722,7 +722,7 @@ async def _inner(): result = None - def runner(): + def runner() -> None: nonlocal result result = asyncio.run(_inner()) @@ -757,7 +757,7 @@ async def _inner(): result = None - def runner(): + def runner() -> None: nonlocal result result = asyncio.run(_inner()) diff --git a/astrbot/core/event_bus.py b/astrbot/core/event_bus.py index 773940c59..44cdccb83 100644 --- a/astrbot/core/event_bus.py +++ b/astrbot/core/event_bus.py @@ -47,7 +47,7 @@ async def dispatch(self) -> None: continue asyncio.create_task(scheduler.execute(event)) - def _print_event(self, event: AstrMessageEvent, conf_name: str): + def _print_event(self, event: AstrMessageEvent, conf_name: str) -> None: """用于记录事件信息 Args: diff --git a/astrbot/core/file_token_service.py b/astrbot/core/file_token_service.py index c807f44e9..42fbd23df 100644 --- a/astrbot/core/file_token_service.py +++ b/astrbot/core/file_token_service.py @@ -14,7 +14,7 @@ def __init__(self, default_timeout: float = 300) -> None: self.staged_files = {} # token: (file_path, expire_time) self.default_timeout = default_timeout - async def _cleanup_expired_tokens(self): + async def _cleanup_expired_tokens(self) -> None: """清理过期的令牌""" now = time.time() expired_tokens = [ diff --git a/astrbot/core/knowledge_base/kb_helper.py b/astrbot/core/knowledge_base/kb_helper.py index e8ead1a5e..1e9127d72 100644 --- a/astrbot/core/knowledge_base/kb_helper.py +++ b/astrbot/core/knowledge_base/kb_helper.py @@ -293,7 +293,7 @@ async def upload_document( await progress_callback("chunking", 100, 100) # 阶段3: 生成向量(带进度回调) - async def embedding_progress_callback(current, total): + async def embedding_progress_callback(current, total) -> None: if progress_callback: await progress_callback("embedding", current, total) diff --git a/astrbot/core/knowledge_base/kb_mgr.py b/astrbot/core/knowledge_base/kb_mgr.py index 9777fea0c..b8584a02c 100644 --- a/astrbot/core/knowledge_base/kb_mgr.py +++ b/astrbot/core/knowledge_base/kb_mgr.py @@ -58,7 +58,7 @@ async def initialize(self) -> None: logger.error(f"知识库模块初始化失败: {e}") logger.error(traceback.format_exc()) - async def _init_kb_database(self): + async def _init_kb_database(self) -> None: self.kb_db = KBSQLiteDatabase(DB_PATH.as_posix()) await self.kb_db.initialize() await self.kb_db.migrate_to_v1() diff --git a/astrbot/core/log.py b/astrbot/core/log.py index db093f02d..5b3d290cc 100644 --- a/astrbot/core/log.py +++ b/astrbot/core/log.py @@ -193,7 +193,7 @@ def GetLogger(cls, log_name: str = "default"): class PluginFilter(logging.Filter): """插件过滤器类, 用于标记日志来源是插件还是核心组件""" - def filter(self, record): + def filter(self, record) -> bool: record.plugin_tag = ( "[Plug]" if is_plugin_path(record.pathname) else "[Core]" ) @@ -205,7 +205,7 @@ class FileNameFilter(logging.Filter): """ # 获取这个文件和父文件夹的名字:. 并且去除 .py - def filter(self, record): + def filter(self, record) -> bool: dirname = os.path.dirname(record.pathname) record.filename = ( os.path.basename(dirname) @@ -218,7 +218,7 @@ class LevelNameFilter(logging.Filter): """短日志级别名称过滤器类, 用于将日志级别名称转换为四个字母的缩写""" # 添加短日志级别名称 - def filter(self, record): + def filter(self, record) -> bool: record.short_levelname = get_short_level_name(record.levelname) return True @@ -232,7 +232,7 @@ def filter(self, record): return logger @classmethod - def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker): + def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker) -> None: """设置队列处理器, 用于将日志消息发送到 LogBroker Args: diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index e8b7f38a1..ff2092b97 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -687,7 +687,7 @@ def file(self) -> str: return "" @file.setter - def file(self, value: str): + def file(self, value: str) -> None: """向前兼容, 设置file属性, 传入的参数可能是文件路径或URL Args: @@ -722,7 +722,7 @@ async def get_file(self, allow_return_url: bool = False) -> str: return "" - async def _download_file(self): + async def _download_file(self) -> None: """下载文件""" if not self.url: raise ValueError("Download failed: No URL provided in File component.") diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 7e3305f55..9178752e5 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -98,7 +98,7 @@ async def _apply_kb( self, event: AstrMessageEvent, req: ProviderRequest, - ): + ) -> None: """Apply knowledge base context to the provider request""" if not self.kb_agentic_mode: if req.prompt is None: @@ -126,7 +126,7 @@ async def _apply_file_extract( self, event: AstrMessageEvent, req: ProviderRequest, - ): + ) -> None: """Apply file extract to the provider request""" file_paths = [] file_names = [] @@ -198,7 +198,7 @@ def _modalities_fix( self, provider: Provider, req: ProviderRequest, - ): + ) -> None: """检查提供商的模态能力,清理请求中的不支持内容""" if req.image_urls: provider_cfg = provider.provider_config.get("modalities", ["image"]) @@ -218,7 +218,7 @@ def _plugin_tool_fix( self, event: AstrMessageEvent, req: ProviderRequest, - ): + ) -> None: """根据事件中的插件设置,过滤请求中的工具列表""" if event.plugins_name is not None and req.func_tool: new_tool_set = ToolSet() @@ -238,7 +238,7 @@ async def _handle_webchat( event: AstrMessageEvent, req: ProviderRequest, prov: Provider, - ): + ) -> None: """处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title""" if not req.conversation: return @@ -294,7 +294,7 @@ async def _save_to_history( event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse | None, - ): + ) -> None: if ( not req or not req.conversation diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index b653b0a81..41e4a68c4 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -91,7 +91,7 @@ async def _calc_comp_interval(self, comp: BaseMessageComponent) -> float: # random return random.uniform(self.interval[0], self.interval[1]) - async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]): + async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]) -> bool: """检查消息链是否为空 Args: diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py index 0fe939d1d..5361fa321 100644 --- a/astrbot/core/pipeline/scheduler.py +++ b/astrbot/core/pipeline/scheduler.py @@ -29,7 +29,7 @@ async def initialize(self) -> None: await stage_instance.initialize(self.ctx) self.stages.append(stage_instance) - async def _process_stages(self, event: AstrMessageEvent, from_stage=0): + async def _process_stages(self, event: AstrMessageEvent, from_stage=0) -> None: """依次执行各个阶段 Args: diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index d52304814..c1aa2e291 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -210,10 +210,10 @@ async def send_streaming( ) self._has_send_oper = True - async def _pre_send(self): + async def _pre_send(self) -> None: """调度器会在执行 send() 前调用该方法 deprecated in v3.5.18""" - async def _post_send(self): + async def _post_send(self) -> None: """调度器会在执行 send() 后调用该方法 deprecated in v3.5.18""" def set_result(self, result: MessageEventResult | str) -> None: diff --git a/astrbot/core/platform/astrbot_message.py b/astrbot/core/platform/astrbot_message.py index 7e8127649..3db53fd48 100644 --- a/astrbot/core/platform/astrbot_message.py +++ b/astrbot/core/platform/astrbot_message.py @@ -78,7 +78,7 @@ def group_id(self) -> str: return "" @group_id.setter - def group_id(self, value: str | None): + def group_id(self, value: str | None) -> None: """设置 group_id""" if value: if self.group: diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index 9dfff0882..3c2fa8036 100644 --- a/astrbot/core/platform/manager.py +++ b/astrbot/core/platform/manager.py @@ -149,7 +149,7 @@ async def load_platform(self, platform_config: dict) -> None: except Exception: logger.error(traceback.format_exc()) - async def _task_wrapper(self, task: asyncio.Task, platform: Platform | None = None): + async def _task_wrapper(self, task: asyncio.Task, platform: Platform | None = None) -> None: # 设置平台状态为运行中 if platform: platform.status = PlatformStatus.RUNNING diff --git a/astrbot/core/platform/platform.py b/astrbot/core/platform/platform.py index 46e992543..1263496e1 100644 --- a/astrbot/core/platform/platform.py +++ b/astrbot/core/platform/platform.py @@ -53,7 +53,7 @@ def status(self) -> PlatformStatus: return self._status @status.setter - def status(self, value: PlatformStatus): + def status(self, value: PlatformStatus) -> None: """设置平台运行状态""" self._status = value if value == PlatformStatus.RUNNING and self._started_at is None: diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index 87177d6fd..1ad26c265 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -72,7 +72,7 @@ async def _dispatch_send( is_group: bool, session_id: str | None, messages: list[dict], - ): + ) -> None: # session_id 必须是纯数字字符串 session_id_int = ( int(session_id) if session_id and session_id.isdigit() else None @@ -97,7 +97,7 @@ async def send_message( event: Event | None = None, is_group: bool = False, session_id: str | None = None, - ): + ) -> None: """发送消息至 QQ 协议端(aiocqhttp)。 Args: diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index 6fc02d609..c8b53abba 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -62,31 +62,31 @@ def __init__( ) @self.bot.on_request() - async def request(event: Event): + async def request(event: Event) -> None: abm = await self.convert_message(event) if abm: await self.handle_msg(abm) @self.bot.on_notice() - async def notice(event: Event): + async def notice(event: Event) -> None: abm = await self.convert_message(event) if abm: await self.handle_msg(abm) @self.bot.on_message("group") - async def group(event: Event): + async def group(event: Event) -> None: abm = await self.convert_message(event) if abm: await self.handle_msg(abm) @self.bot.on_message("private") - async def private(event: Event): + async def private(event: Event) -> None: abm = await self.convert_message(event) if abm: await self.handle_msg(abm) @self.bot.on_websocket_connection - def on_websocket_connection(_): + def on_websocket_connection(_) -> None: logger.info("aiocqhttp(OneBot v11) 适配器已连接。") async def send_by_session( diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py index 090df74d7..78e61324b 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py @@ -236,7 +236,7 @@ async def handle_msg(self, abm: AstrBotMessage) -> None: async def run(self) -> None: # await self.client_.start() # 钉钉的 SDK 并没有实现真正的异步,start() 里面有堵塞方法。 - def start_client(loop: asyncio.AbstractEventLoop): + def start_client(loop: asyncio.AbstractEventLoop) -> None: try: self._shutdown_event = threading.Event() task = loop.create_task(self.client_.start()) @@ -253,7 +253,7 @@ def start_client(loop: asyncio.AbstractEventLoop): await loop.run_in_executor(None, start_client, loop) async def terminate(self) -> None: - def monkey_patch_close(): + def monkey_patch_close() -> NoReturn: raise KeyboardInterrupt("Graceful shutdown") if self.client_.websocket is not None: diff --git a/astrbot/core/platform/sources/discord/discord_platform_adapter.py b/astrbot/core/platform/sources/discord/discord_platform_adapter.py index fb8ce972d..ce25da419 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_adapter.py +++ b/astrbot/core/platform/sources/discord/discord_platform_adapter.py @@ -126,7 +126,7 @@ async def run(self) -> None: """主要运行逻辑""" # 初始化回调函数 - async def on_received(message_data): + async def on_received(message_data) -> None: logger.debug(f"[Discord] 收到消息: {message_data}") if self.client_self_id is None: self.client_self_id = message_data.get("bot_id") @@ -143,7 +143,7 @@ async def on_received(message_data): self.client = DiscordBotClient(token, proxy) self.client.on_message_received = on_received - async def callback(): + async def callback() -> None: if self.enable_command_register: await self._collect_and_register_commands() if self.activity_name: @@ -362,7 +362,7 @@ def register_handler(self, handler_info) -> None: """注册处理器信息""" self.registered_handlers.append(handler_info) - async def _collect_and_register_commands(self): + async def _collect_and_register_commands(self) -> None: """收集所有指令并注册到Discord""" logger.info("[Discord] 开始收集并注册斜杠指令...") registered_commands = [] @@ -418,7 +418,7 @@ def _create_dynamic_callback(self, cmd_name: str): async def dynamic_callback( ctx: discord.ApplicationContext, params: str | None = None - ): + ) -> None: # 将平台特定的前缀'/'剥离,以适配通用的CommandFilter logger.debug(f"[Discord] 回调函数触发: {cmd_name}") logger.debug(f"[Discord] 回调函数参数: {ctx}") diff --git a/astrbot/core/platform/sources/lark/lark_adapter.py b/astrbot/core/platform/sources/lark/lark_adapter.py index 472d34e7f..def0f2cad 100644 --- a/astrbot/core/platform/sources/lark/lark_adapter.py +++ b/astrbot/core/platform/sources/lark/lark_adapter.py @@ -50,10 +50,10 @@ def __init__( if not self.bot_name: logger.warning("未设置飞书机器人名称,@ 机器人可能得不到回复。") - async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1): + async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1) -> None: await self.convert_msg(event) - def do_v2_msg_event(event: lark.im.v1.P2ImMessageReceiveV1): + def do_v2_msg_event(event: lark.im.v1.P2ImMessageReceiveV1) -> None: asyncio.create_task(on_msg_event_recv(event)) self.event_handler = ( diff --git a/astrbot/core/platform/sources/misskey/misskey_adapter.py b/astrbot/core/platform/sources/misskey/misskey_adapter.py index aab2fe802..23418a91e 100644 --- a/astrbot/core/platform/sources/misskey/misskey_adapter.py +++ b/astrbot/core/platform/sources/misskey/misskey_adapter.py @@ -152,7 +152,7 @@ async def run(self) -> None: await self._start_websocket_connection() - def _register_event_handlers(self, streaming): + def _register_event_handlers(self, streaming) -> None: """注册事件处理器""" streaming.add_message_handler("notification", self._handle_notification) streaming.add_message_handler("main:notification", self._handle_notification) @@ -196,7 +196,7 @@ def _process_poll_data( message: AstrBotMessage, poll: dict[str, Any], message_parts: list[str], - ): + ) -> None: """处理投票数据,将其添加到消息中""" try: if not isinstance(message.raw_message, dict): @@ -235,7 +235,7 @@ def _extract_additional_fields(self, session, message_chain) -> dict[str, Any]: return fields - async def _start_websocket_connection(self): + async def _start_websocket_connection(self) -> None: backoff_delay = 1.0 max_backoff = 300.0 backoff_multiplier = 1.5 @@ -283,7 +283,7 @@ async def _start_websocket_connection(self): await asyncio.sleep(sleep_time) backoff_delay = min(backoff_delay * backoff_multiplier, max_backoff) - async def _handle_notification(self, data: dict[str, Any]): + async def _handle_notification(self, data: dict[str, Any]) -> None: try: notification_type = data.get("type") logger.debug( @@ -307,7 +307,7 @@ async def _handle_notification(self, data: dict[str, Any]): except Exception as e: logger.error(f"[Misskey] 处理通知失败: {e}") - async def _handle_chat_message(self, data: dict[str, Any]): + async def _handle_chat_message(self, data: dict[str, Any]) -> None: try: sender_id = str( data.get("fromUserId", "") or data.get("fromUser", {}).get("id", ""), @@ -342,7 +342,7 @@ async def _handle_chat_message(self, data: dict[str, Any]): except Exception as e: logger.error(f"[Misskey] 处理聊天消息失败: {e}") - async def _debug_handler(self, data: dict[str, Any]): + async def _debug_handler(self, data: dict[str, Any]) -> None: event_type = data.get("type", "unknown") logger.debug( f"[Misskey] 收到未处理事件: type={event_type}, channel={data.get('channel', 'unknown')}", diff --git a/astrbot/core/platform/sources/misskey/misskey_api.py b/astrbot/core/platform/sources/misskey/misskey_api.py index 0a2acedda..86636b12c 100644 --- a/astrbot/core/platform/sources/misskey/misskey_api.py +++ b/astrbot/core/platform/sources/misskey/misskey_api.py @@ -3,7 +3,7 @@ import random import uuid from collections.abc import Awaitable, Callable -from typing import Any +from typing import Any, NoReturn try: import aiohttp @@ -187,7 +187,7 @@ async def listen(self) -> None: except Exception: pass - async def _handle_message(self, data: dict[str, Any]): + async def _handle_message(self, data: dict[str, Any]) -> None: message_type = data.get("type") body = data.get("body", {}) @@ -375,7 +375,7 @@ def session(self) -> aiohttp.ClientSession: self._session = aiohttp.ClientSession(headers=headers) return self._session - def _handle_response_status(self, status: int, endpoint: str): + def _handle_response_status(self, status: int, endpoint: str) -> NoReturn: """处理 HTTP 响应状态码""" if status == 400: logger.error(f"[Misskey API] 请求参数错误: {endpoint} (HTTP {status})") diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py index 4a2bd202d..23de9b83c 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py @@ -80,7 +80,7 @@ async def on_c2c_message_create(self, message: botpy.message.C2CMessage) -> None abm.session_id = abm.sender.user_id self._commit(abm) - def _commit(self, abm: AstrBotMessage): + def _commit(self, abm: AstrBotMessage) -> None: self.platform.commit_event( QQOfficialMessageEvent( abm.message_str, diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py index 61caee5aa..d34e43740 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py @@ -71,7 +71,7 @@ async def on_c2c_message_create(self, message: botpy.message.C2CMessage) -> None abm.session_id = abm.sender.user_id self._commit(abm) - def _commit(self, abm: AstrBotMessage): + def _commit(self, abm: AstrBotMessage) -> None: self.platform.commit_event( QQOfficialWebhookMessageEvent( abm.message_str, diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py index 643263726..86bbf55b2 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py @@ -46,7 +46,7 @@ async def initialize(self) -> None: self.client.api = self.api self.client.http = self.http - async def bot_connect(): + async def bot_connect() -> None: pass self._connection = ConnectionSession( diff --git a/astrbot/core/platform/sources/slack/client.py b/astrbot/core/platform/sources/slack/client.py index 2ec6d2164..efd7a6f3d 100644 --- a/astrbot/core/platform/sources/slack/client.py +++ b/astrbot/core/platform/sources/slack/client.py @@ -44,7 +44,7 @@ def __init__( self.shutdown_event = asyncio.Event() - def _setup_routes(self): + def _setup_routes(self) -> None: """设置路由""" @self.app.route(self.path, methods=["POST"]) @@ -143,7 +143,7 @@ def __init__( async def _handle_events( self, _: AsyncBaseSocketModeClient, req: SocketModeRequest - ): + ) -> None: """处理 Socket Mode 事件""" try: if self.socket_client is None: diff --git a/astrbot/core/platform/sources/slack/slack_adapter.py b/astrbot/core/platform/sources/slack/slack_adapter.py index 33342ee78..3184f5186 100644 --- a/astrbot/core/platform/sources/slack/slack_adapter.py +++ b/astrbot/core/platform/sources/slack/slack_adapter.py @@ -288,7 +288,7 @@ def _parse_blocks(self, blocks: list) -> list: return message_components - async def _handle_socket_event(self, req: SocketModeRequest): + async def _handle_socket_event(self, req: SocketModeRequest) -> None: """处理 Socket Mode 事件""" if req.type == "events_api": # 事件 API @@ -377,7 +377,7 @@ async def run(self) -> None: f"不支持的连接模式: {self.connection_mode},请使用 'socket' 或 'webhook'", ) - async def _handle_webhook_event(self, event_data: dict): + async def _handle_webhook_event(self, event_data: dict) -> None: """处理 Webhook 事件""" event = event_data.get("event", {}) diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py index 108b5e547..80f0cd6af 100644 --- a/astrbot/core/platform/sources/telegram/tg_event.py +++ b/astrbot/core/platform/sources/telegram/tg_event.py @@ -73,7 +73,7 @@ async def send_with_client( client: ExtBot, message: MessageChain, user_name: str, - ): + ) -> None: image_path = None has_reply = False diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index 86f78606d..47dbb22ec 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -208,7 +208,7 @@ async def convert_message(self, data: tuple) -> AstrBotMessage: return abm def run(self) -> Coroutine[Any, Any, None]: - async def callback(data: tuple): + async def callback(data: tuple) -> None: abm = await self.convert_message(data) await self.handle_msg(abm) diff --git a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py index 78c960534..d2012d9a7 100644 --- a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +++ b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py @@ -487,7 +487,7 @@ async def _process_chat_type( to_user_name: str, content: str, push_content: str, - ): + ) -> bool: """判断消息是群聊还是私聊,并设置 AstrBotMessage 的基本属性。""" if from_user_name == "weixin": return False @@ -640,7 +640,7 @@ async def _process_message_content( raw_message: dict, msg_type: int, content: str, - ): + ) -> None: """根据消息类型处理消息内容,填充 AstrBotMessage 的 message 列表。""" if msg_type == 1: # 文本消息 abm.message_str = content diff --git a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py index 101b8f2dd..5ee8d7c07 100644 --- a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +++ b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py @@ -66,7 +66,7 @@ async def send_streaming( await self.send(buffer) return await super().send_streaming(generator, use_fallback) - async def _send_image(self, session: aiohttp.ClientSession, comp: Image): + async def _send_image(self, session: aiohttp.ClientSession, comp: Image) -> None: b64 = await comp.convert_to_base64() raw = self._validate_base64(b64) b64c = self._compress_image(raw) @@ -78,7 +78,7 @@ async def _send_image(self, session: aiohttp.ClientSession, comp: Image): url = f"{self.adapter.base_url}/message/SendImageNewMessage" await self._post(session, url, payload) - async def _send_text(self, session: aiohttp.ClientSession, text: str): + async def _send_text(self, session: aiohttp.ClientSession, text: str) -> None: if ( self.message_obj.type == MessageType.GROUP_MESSAGE # 确保是群聊消息 and self.adapter.settings.get( @@ -114,7 +114,7 @@ async def _send_text(self, session: aiohttp.ClientSession, text: str): url = f"{self.adapter.base_url}/message/SendTextMessage" await self._post(session, url, payload) - async def _send_emoji(self, session: aiohttp.ClientSession, comp: WechatEmoji): + async def _send_emoji(self, session: aiohttp.ClientSession, comp: WechatEmoji) -> None: payload = { "EmojiList": [ { @@ -127,7 +127,7 @@ async def _send_emoji(self, session: aiohttp.ClientSession, comp: WechatEmoji): url = f"{self.adapter.base_url}/message/SendEmojiMessage" await self._post(session, url, payload) - async def _send_voice(self, session: aiohttp.ClientSession, comp: Record): + async def _send_voice(self, session: aiohttp.ClientSession, comp: Record) -> None: record_path = await comp.convert_to_file_path() # 默认已经存在 data/temp 中 b64, duration = await audio_to_tencent_silk_base64(record_path) @@ -157,7 +157,7 @@ def _compress_image(data: bytes) -> str: # logger.info("图片处理完成!!!") return base64.b64encode(buf.getvalue()).decode() - async def _post(self, session, url, payload): + async def _post(self, session, url, payload) -> None: params = {"key": self.adapter.auth_key} try: async with session.post(url, params=params, json=payload) as resp: diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index 66b5e9588..fdea8423b 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -182,7 +182,7 @@ def __init__( self.client.__setattr__("API_BASE_URL", self.api_base_url) - async def callback(msg: BaseMessage): + async def callback(msg: BaseMessage) -> None: if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event": def get_latest_msg_item() -> dict | None: diff --git a/astrbot/core/platform/sources/wecom/wecom_event.py b/astrbot/core/platform/sources/wecom/wecom_event.py index d9ac60ba1..865a14234 100644 --- a/astrbot/core/platform/sources/wecom/wecom_event.py +++ b/astrbot/core/platform/sources/wecom/wecom_event.py @@ -37,7 +37,7 @@ async def send_with_client( client: WeChatClient, message: MessageChain, user_name: str, - ): + ) -> None: pass async def split_plain(self, plain: str) -> list[str]: diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py index e4d34724a..08a07ed1b 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py @@ -153,7 +153,7 @@ def __init__( self._handle_queued_message, ) - async def _handle_queued_message(self, data: dict): + async def _handle_queued_message(self, data: dict) -> None: """处理队列中的消息,类似webchat的callback""" try: abm = await self.convert_message(data) @@ -313,7 +313,7 @@ async def _enqueue_message( callback_params: dict[str, str], stream_id: str, session_id: str, - ): + ) -> None: """将消息放入队列进行异步处理""" input_queue = self.queue_mgr.get_or_create_queue(stream_id) _ = self.queue_mgr.get_or_create_back_queue(stream_id) @@ -426,7 +426,7 @@ async def send_by_session( def run(self) -> Awaitable[Any]: """运行适配器,同时启动HTTP服务器和队列监听器""" - async def run_both(): + async def run_both() -> None: # 如果启用统一 webhook 模式,则不启动独立服务器 webhook_uuid = self.config.get("webhook_uuid") if self.unified_webhook_mode and webhook_uuid: diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py index 82bfb8722..80ec5179e 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py @@ -43,7 +43,7 @@ def __init__( self.shutdown_event = asyncio.Event() - def _setup_routes(self): + def _setup_routes(self) -> None: """设置 Quart 路由""" # 使用 Quart 的 add_url_rule 方法添加路由 self.app.add_url_rule( diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py index 39ff7bc0e..995b16690 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py @@ -35,7 +35,7 @@ async def send_with_client( client: WeChatClient, message: MessageChain, user_name: str, - ): + ) -> None: pass async def split_plain(self, plain: str) -> list[str]: diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index 843b30a19..9d0adeb82 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -279,7 +279,7 @@ def completion_text(self): return self._completion_text @completion_text.setter - def completion_text(self, value): + def completion_text(self, value) -> None: if self.result_chain: self.result_chain.chain = [ comp diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 00a150cc6..2979a8609 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -279,7 +279,7 @@ async def get_embeddings_batch( completed_count = 0 total_count = len(texts) - async def process_batch(batch_idx: int, batch_texts: list[str]): + async def process_batch(batch_idx: int, batch_texts: list[str]) -> None: nonlocal completed_count async with semaphore: for attempt in range(max_retries): diff --git a/astrbot/core/provider/sources/azure_tts_source.py b/astrbot/core/provider/sources/azure_tts_source.py index 3a0a79c93..08180222a 100644 --- a/astrbot/core/provider/sources/azure_tts_source.py +++ b/astrbot/core/provider/sources/azure_tts_source.py @@ -48,7 +48,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self._client.aclose() self._client = None - async def _sync_time(self): + async def _sync_time(self) -> None: try: response = await self.client.get(self.auth_time_url) response.raise_for_status() @@ -149,7 +149,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self._client.aclose() self._client = None - async def _refresh_token(self): + async def _refresh_token(self) -> None: token_url = ( f"https://{self.region}.api.cognitive.microsoft.com/sts/v1.0/issuetoken" ) diff --git a/astrbot/core/provider/sources/gsv_selfhosted_source.py b/astrbot/core/provider/sources/gsv_selfhosted_source.py index cbc900b5d..029f6af10 100644 --- a/astrbot/core/provider/sources/gsv_selfhosted_source.py +++ b/astrbot/core/provider/sources/gsv_selfhosted_source.py @@ -85,7 +85,7 @@ async def _make_request( logger.error(f"[GSV TTS] 请求 {endpoint} 最终失败:{e}") raise - async def _set_model_weights(self): + async def _set_model_weights(self) -> None: """设置模型路径""" try: if self.gpt_weights_path: diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 1a2767f67..996a5ab24 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -74,7 +74,7 @@ def __init__(self, provider_config, provider_settings) -> None: self.reasoning_key = "reasoning_content" - def _maybe_inject_xai_search(self, payloads: dict, **kwargs): + def _maybe_inject_xai_search(self, payloads: dict, **kwargs) -> None: """当开启 xAI 原生搜索时,向请求体注入 Live Search 参数。 - 仅在 provider_config.xai_native_search 为 True 时生效 diff --git a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py index fd0615799..965b83a5a 100644 --- a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py +++ b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py @@ -52,7 +52,7 @@ async def get_timestamped_path(self) -> str: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") return os.path.join("data", "temp", f"{timestamp}") - async def _is_silk_file(self, file_path): + async def _is_silk_file(self, file_path) -> bool: silk_header = b"SILK" with open(file_path, "rb") as f: file_header = f.read(8) diff --git a/astrbot/core/provider/sources/whisper_api_source.py b/astrbot/core/provider/sources/whisper_api_source.py index fa69206ef..9532ca5b9 100644 --- a/astrbot/core/provider/sources/whisper_api_source.py +++ b/astrbot/core/provider/sources/whisper_api_source.py @@ -38,7 +38,7 @@ def __init__( self.set_model(provider_config["model"]) - async def _get_audio_format(self, file_path): + async def _get_audio_format(self, file_path) -> str | None: # 定义要检测的头部字节 silk_header = b"SILK" amr_header = b"#!AMR" diff --git a/astrbot/core/provider/sources/whisper_selfhosted_source.py b/astrbot/core/provider/sources/whisper_selfhosted_source.py index 68ef73e46..d5d2dc340 100644 --- a/astrbot/core/provider/sources/whisper_selfhosted_source.py +++ b/astrbot/core/provider/sources/whisper_selfhosted_source.py @@ -40,7 +40,7 @@ async def initialize(self) -> None: ) logger.info("Whisper 模型加载完成。") - async def _is_silk_file(self, file_path): + async def _is_silk_file(self, file_path) -> bool: silk_header = b"SILK" with open(file_path, "rb") as f: file_header = f.read(8) diff --git a/astrbot/core/star/star_handler.py b/astrbot/core/star/star_handler.py index ed50ee5bb..0d617d46a 100644 --- a/astrbot/core/star/star_handler.py +++ b/astrbot/core/star/star_handler.py @@ -25,7 +25,7 @@ def append(self, handler: StarHandlerMetadata) -> None: self._handlers.append(handler) self._handlers.sort(key=lambda h: -h.extras_configs["priority"]) - def _print_handlers(self): + def _print_handlers(self) -> None: for handler in self._handlers: print(handler.handler_full_name) diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 20aa40fcc..e054aa5a7 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -65,7 +65,7 @@ def __init__(self, context: Context, config: AstrBotConfig) -> None: if os.getenv("ASTRBOT_RELOAD", "0") == "1": asyncio.create_task(self._watch_plugins_changes()) - async def _watch_plugins_changes(self): + async def _watch_plugins_changes(self) -> None: """监视插件文件变化""" try: async for changes in awatch( @@ -82,7 +82,7 @@ async def _watch_plugins_changes(self): logger.error(f"插件热重载监视任务异常: {e!s}") logger.error(traceback.format_exc()) - async def _handle_file_changes(self, changes): + async def _handle_file_changes(self, changes) -> None: """处理文件变化""" logger.info(f"检测到文件变化: {changes}") plugins_to_check = [] @@ -166,7 +166,7 @@ def _get_plugin_modules(self) -> list[dict]: plugins.extend(_p) return plugins - async def _check_plugin_dept_update(self, target_plugin: str | None = None): + async def _check_plugin_dept_update(self, target_plugin: str | None = None) -> bool | None: """检查插件的依赖 如果 target_plugin 为 None,则检查所有插件的依赖 """ @@ -263,7 +263,7 @@ def _purge_modules( module_patterns: list[str] | None = None, root_dir_name: str | None = None, is_reserved: bool = False, - ): + ) -> None: """从 sys.modules 中移除指定的模块 可以基于模块名模式或插件目录名移除模块,用于清理插件相关的模块缓存 @@ -774,7 +774,7 @@ async def uninstall_plugin( except Exception as e: logger.warning(f"删除插件持久化数据失败 (plugins_data): {e!s}") - async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str): + async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str) -> None: """解绑并移除一个插件。 Args: @@ -878,7 +878,7 @@ async def turn_off_plugin(self, plugin_name: str) -> None: plugin.activated = False @staticmethod - async def _terminate_plugin(star_metadata: StarMetadata): + async def _terminate_plugin(star_metadata: StarMetadata) -> None: """终止插件,调用插件的 terminate() 和 __del__() 方法""" logger.info(f"正在终止插件 {star_metadata.name} ...") diff --git a/astrbot/core/star/star_tools.py b/astrbot/core/star/star_tools.py index 7a66449b4..4d85131fc 100644 --- a/astrbot/core/star/star_tools.py +++ b/astrbot/core/star/star_tools.py @@ -89,7 +89,7 @@ async def send_message_by_id( id: str, message_chain: MessageChain, platform: str = "aiocqhttp", - ): + ) -> None: """根据 id(例如qq号, 群号等) 直接, 主动地发送消息 Args: diff --git a/astrbot/core/umop_config_router.py b/astrbot/core/umop_config_router.py index 049b96856..dad17815f 100644 --- a/astrbot/core/umop_config_router.py +++ b/astrbot/core/umop_config_router.py @@ -11,7 +11,7 @@ def __init__(self, sp: SharedPreferences) -> None: self._load_routing_table() - def _load_routing_table(self): + def _load_routing_table(self) -> None: """加载路由表""" # 从 SharedPreferences 中加载 umop_to_conf_id 映射 sp_data = self.sp.get( diff --git a/astrbot/core/updator.py b/astrbot/core/updator.py index 7c880ede6..9e4a0be84 100644 --- a/astrbot/core/updator.py +++ b/astrbot/core/updator.py @@ -44,7 +44,7 @@ def terminate_child_processes(self) -> None: except psutil.NoSuchProcess: pass - def _reboot(self, delay: int = 3): + def _reboot(self, delay: int = 3) -> None: """重启当前程序 在指定的延迟后,终止所有子进程并重新启动程序 这里只能使用 os.exec* 来重启程序 diff --git a/astrbot/core/utils/metrics.py b/astrbot/core/utils/metrics.py index f12019e3c..06acdba4b 100644 --- a/astrbot/core/utils/metrics.py +++ b/astrbot/core/utils/metrics.py @@ -40,7 +40,7 @@ def get_installation_id(): return "null" @staticmethod - async def upload(**kwargs): + async def upload(**kwargs) -> None: """上传相关非敏感的指标以更好地了解 AstrBot 的使用情况。上传的指标不会包含任何有关消息文本、用户信息等敏感信息。 Powered by TickStats. diff --git a/astrbot/core/utils/session_waiter.py b/astrbot/core/utils/session_waiter.py index 4241fc6f2..b327a6184 100644 --- a/astrbot/core/utils/session_waiter.py +++ b/astrbot/core/utils/session_waiter.py @@ -71,7 +71,7 @@ def keep(self, timeout: float = 0, reset_timeout=False) -> None: asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep - async def _holding(self, event: asyncio.Event, timeout: float): + async def _holding(self, event: asyncio.Event, timeout: float) -> None: """等待事件结束或超时""" try: await asyncio.wait_for(event.wait(), timeout) @@ -141,7 +141,7 @@ async def register_wait( finally: self._cleanup() - def _cleanup(self, error: Exception | None = None): + def _cleanup(self, error: Exception | None = None) -> None: """清理会话""" USER_SESSIONS.pop(self.session_id, None) try: @@ -151,7 +151,7 @@ def _cleanup(self, error: Exception | None = None): self.session_controller.stop(error) @classmethod - async def trigger(cls, session_id: str, event: AstrMessageEvent): + async def trigger(cls, session_id: str, event: AstrMessageEvent) -> None: """外部输入触发会话处理""" session = USER_SESSIONS.get(session_id) if not session or session.session_controller.future.done(): diff --git a/astrbot/core/utils/t2i/template_manager.py b/astrbot/core/utils/t2i/template_manager.py index 3df4cab79..b3eb0c9ff 100644 --- a/astrbot/core/utils/t2i/template_manager.py +++ b/astrbot/core/utils/t2i/template_manager.py @@ -28,7 +28,7 @@ def __init__(self) -> None: os.makedirs(self.user_template_dir, exist_ok=True) self._initialize_user_templates() - def _copy_core_templates(self, overwrite: bool = False): + def _copy_core_templates(self, overwrite: bool = False) -> None: """从内置目录复制核心模板到用户目录。""" for filename in self.CORE_TEMPLATES: src = os.path.join(self.builtin_template_dir, filename) @@ -36,7 +36,7 @@ def _copy_core_templates(self, overwrite: bool = False): if os.path.exists(src) and (overwrite or not os.path.exists(dst)): shutil.copyfile(src, dst) - def _initialize_user_templates(self): + def _initialize_user_templates(self) -> None: """如果用户目录下缺少核心模板,则进行复制。""" self._copy_core_templates(overwrite=False) diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index cfb750803..ff9f78908 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -510,7 +510,7 @@ def _extract_attachment_ids(self, history_list) -> list[str]: attachment_ids.append(part["attachment_id"]) return attachment_ids - async def _delete_attachments(self, attachment_ids: list[str]): + async def _delete_attachments(self, attachment_ids: list[str]) -> None: """删除附件(包括数据库记录和磁盘文件)""" try: attachments = await self.db.get_attachments(attachment_ids) diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 642e4aefa..2b34c5063 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -49,7 +49,7 @@ def try_cast(value: Any, type_: str): def validate_config(data, schema: dict, is_core: bool) -> tuple[list[str], dict]: errors = [] - def validate(data: dict, metadata: dict = schema, path=""): + def validate(data: dict, metadata: dict = schema, path="") -> None: for key, value in data.items(): if key not in metadata: continue @@ -677,7 +677,7 @@ async def get_llm_tools(self): tools = tool_mgr.get_func_desc_openai_style() return Response().ok(tools).__dict__ - async def _register_platform_logo(self, platform, platform_default_tmpl): + async def _register_platform_logo(self, platform, platform_default_tmpl) -> None: """注册平台logo文件并生成访问令牌""" if not platform.logo_path: return @@ -800,7 +800,7 @@ async def _get_plugin_config(self, plugin_name: str): async def _save_astrbot_configs( self, post_configs: dict, conf_id: str | None = None - ): + ) -> None: try: if conf_id not in self.acm.confs: raise ValueError(f"配置文件 {conf_id} 不存在") @@ -816,7 +816,7 @@ async def _save_astrbot_configs( except Exception as e: raise e - async def _save_plugin_configs(self, post_configs: dict, plugin_name: str): + async def _save_plugin_configs(self, post_configs: dict, plugin_name: str) -> None: md = None for plugin_md in star_registry: if plugin_md.name == plugin_name: diff --git a/astrbot/dashboard/routes/knowledge_base.py b/astrbot/dashboard/routes/knowledge_base.py index d7db42c40..f5564d6b2 100644 --- a/astrbot/dashboard/routes/knowledge_base.py +++ b/astrbot/dashboard/routes/knowledge_base.py @@ -76,7 +76,7 @@ async def _background_upload_task( batch_size: int, tasks_limit: int, max_retries: int, - ): + ) -> None: """后台上传任务""" try: # 初始化任务状态 @@ -112,7 +112,7 @@ async def _background_upload_task( ) # 创建进度回调函数 - async def progress_callback(stage, current, total): + async def progress_callback(stage, current, total) -> None: if task_id in self.upload_progress: self.upload_progress[task_id].update( { @@ -1013,7 +1013,7 @@ async def _background_upload_from_url_task( max_retries: int, enable_cleaning: bool, cleaning_provider_id: str | None, - ): + ) -> None: """后台上传URL任务""" try: # 初始化任务状态 @@ -1033,7 +1033,7 @@ async def _background_upload_from_url_task( } # 创建进度回调函数 - async def progress_callback(stage, current, total): + async def progress_callback(stage, current, total) -> None: if task_id in self.upload_progress: self.upload_progress[task_id].update( { diff --git a/astrbot/dashboard/routes/platform.py b/astrbot/dashboard/routes/platform.py index 5b709a628..0b50c930e 100644 --- a/astrbot/dashboard/routes/platform.py +++ b/astrbot/dashboard/routes/platform.py @@ -26,7 +26,7 @@ def __init__( self._register_webhook_routes() - def _register_webhook_routes(self): + def _register_webhook_routes(self) -> None: """注册 webhook 路由""" # 统一 webhook 入口,支持 GET 和 POST self.app.add_url_rule( diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py index c249b07b7..ff8bc5474 100644 --- a/astrbot/dashboard/routes/plugin.py +++ b/astrbot/dashboard/routes/plugin.py @@ -256,7 +256,7 @@ def _load_plugin_cache(self, cache_file: str): logger.warning(f"加载插件市场缓存失败: {e}") return None - def _save_plugin_cache(self, cache_file: str, data, md5: str | None = None): + def _save_plugin_cache(self, cache_file: str, data, md5: str | None = None) -> None: """保存插件市场数据到本地缓存""" try: # 确保目录存在 diff --git a/astrbot/dashboard/routes/route.py b/astrbot/dashboard/routes/route.py index 9fe3c307d..53c623443 100644 --- a/astrbot/dashboard/routes/route.py +++ b/astrbot/dashboard/routes/route.py @@ -19,7 +19,7 @@ def __init__(self, context: RouteContext) -> None: self.config = context.config def register_routes(self) -> None: - def _add_rule(path, method, func): + def _add_rule(path, method, func) -> None: # 统一添加 /api 前缀 full_path = f"/api{path}" self.app.add_url_rule(full_path, view_func=func, methods=[method]) diff --git a/astrbot/dashboard/routes/static_file.py b/astrbot/dashboard/routes/static_file.py index 3d3d0ca51..e056b6c5a 100644 --- a/astrbot/dashboard/routes/static_file.py +++ b/astrbot/dashboard/routes/static_file.py @@ -30,7 +30,7 @@ def __init__(self, context: RouteContext) -> None: self.app.add_url_rule(i, view_func=self.index) @self.app.errorhandler(404) - async def page_not_found(e): + async def page_not_found(e) -> str: return "404 Not found。如果你初次使用打开面板发现 404, 请参考文档: https://astrbot.app/faq.html。如果你正在测试回调地址可达性,显示这段文字说明测试成功了。" async def index(self): diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 888da9009..bfc3df0b3 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -168,7 +168,7 @@ def get_process_using_port(self, port: int) -> str: except Exception as e: return f"获取进程信息失败: {e!s}" - def _init_jwt_secret(self): + def _init_jwt_secret(self) -> None: if not self.config.get("dashboard", {}).get("jwt_secret", None): # 如果没有设置 JWT 密钥,则生成一个新的密钥 jwt_secret = os.urandom(32).hex() diff --git a/packages/astrbot/process_llm_request.py b/packages/astrbot/process_llm_request.py index 3a097d76b..8250233ed 100644 --- a/packages/astrbot/process_llm_request.py +++ b/packages/astrbot/process_llm_request.py @@ -21,7 +21,7 @@ def __init__(self, context: star.Context) -> None: else: logger.info(f"Timezone set to: {self.timezone}") - async def _ensure_persona(self, req: ProviderRequest, cfg: dict, umo: str): + async def _ensure_persona(self, req: ProviderRequest, cfg: dict, umo: str) -> None: """确保用户人格已加载""" if not req.conversation: return @@ -77,7 +77,7 @@ async def _ensure_img_caption( req: ProviderRequest, cfg: dict, img_cap_prov_id: str, - ): + ) -> None: try: caption = await self._request_img_caption( img_cap_prov_id, diff --git a/packages/builtin_commands/commands/provider.py b/packages/builtin_commands/commands/provider.py index 408511d1a..ede51a7c8 100644 --- a/packages/builtin_commands/commands/provider.py +++ b/packages/builtin_commands/commands/provider.py @@ -17,7 +17,7 @@ def _log_reachability_failure( provider_capability_type: ProviderType | None, err_code: str, err_reason: str, - ): + ) -> None: """记录不可达原因到日志。""" meta = provider.meta() logger.warning( diff --git a/packages/python_interpreter/main.py b/packages/python_interpreter/main.py index 8646f8f3a..1bb75a7ae 100644 --- a/packages/python_interpreter/main.py +++ b/packages/python_interpreter/main.py @@ -171,7 +171,7 @@ async def get_image_name(self) -> str: return f"{self.config['sandbox']['docker_mirror']}/{self.config['sandbox']['image']}" return self.config["sandbox"]["image"] - def _save_config(self): + def _save_config(self) -> None: with open(PATH, "w") as f: json.dump(self.config, f) diff --git a/packages/reminder/main.py b/packages/reminder/main.py index fcc849aa9..e0395d457 100644 --- a/packages/reminder/main.py +++ b/packages/reminder/main.py @@ -38,7 +38,7 @@ def __init__(self, context: star.Context) -> None: self._init_scheduler() self.scheduler.start() - def _init_scheduler(self): + def _init_scheduler(self) -> None: """Initialize the scheduler.""" for group in self.reminder_data: for reminder in self.reminder_data[group]: @@ -82,7 +82,7 @@ def check_is_outdated(self, reminder: dict): return reminder_time < datetime.datetime.now(self.timezone) return False - async def _save_data(self): + async def _save_data(self) -> None: """Save the reminder data.""" reminder_file = os.path.join(get_astrbot_data_path(), "astrbot-reminder.json") with open(reminder_file, "w", encoding="utf-8") as f: @@ -246,7 +246,7 @@ async def reminder_rm(self, event: AstrMessageEvent, index: int): await self._save_data() yield event.plain_result("成功删除待办事项:\n" + reminder["text"]) - async def _reminder_callback(self, unified_msg_origin: str, d: dict): + async def _reminder_callback(self, unified_msg_origin: str, d: dict) -> None: """The callback function of the reminder.""" logger.info(f"Reminder Activated: {d['text']}, created by {unified_msg_origin}") await self.context.send_message( diff --git a/packages/session_controller/main.py b/packages/session_controller/main.py index b1aecc766..fdc295313 100644 --- a/packages/session_controller/main.py +++ b/packages/session_controller/main.py @@ -91,7 +91,7 @@ async def handle_empty_mention(self, event: AstrMessageEvent): async def empty_mention_waiter( controller: SessionController, event: AstrMessageEvent, - ): + ) -> None: event.message_obj.message.insert( 0, Comp.At(qq=event.get_self_id(), name=event.get_self_id()), diff --git a/packages/web_searcher/engines/__init__.py b/packages/web_searcher/engines/__init__.py index 699438602..ad893d5c2 100644 --- a/packages/web_searcher/engines/__init__.py +++ b/packages/web_searcher/engines/__init__.py @@ -1,6 +1,7 @@ import random import urllib.parse from dataclasses import dataclass +from typing import NoReturn from aiohttp import ClientSession from bs4 import BeautifulSoup, Tag @@ -48,7 +49,7 @@ def __init__(self) -> None: def _set_selector(self, selector: str) -> str: raise NotImplementedError - def _get_next_page(self, query: str): + def _get_next_page(self, query: str) -> NoReturn: raise NotImplementedError async def _get_html(self, url: str, data: dict | None = None) -> str: diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index d53984814..1ae2576d5 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -183,15 +183,15 @@ async def test_do_update( temp_release_dir = tmp_path_factory.mktemp("release") release_path = temp_release_dir / "astrbot" - async def mock_update(*args, **kwargs): + async def mock_update(*args, **kwargs) -> None: """Mocks the update process by creating a directory in the temp path.""" os.makedirs(release_path, exist_ok=True) - async def mock_download_dashboard(*args, **kwargs): + async def mock_download_dashboard(*args, **kwargs) -> None: """Mocks the dashboard download to prevent network access.""" return - async def mock_pip_install(*args, **kwargs): + async def mock_pip_install(*args, **kwargs) -> None: """Mocks pip install to prevent actual installation.""" return diff --git a/tests/test_security_fixes.py b/tests/test_security_fixes.py index 06939ebf4..8e5cf7f6d 100644 --- a/tests/test_security_fixes.py +++ b/tests/test_security_fixes.py @@ -66,7 +66,7 @@ def test_azure_tts_signature_uses_secrets() -> None: "OTTS_AUTH_TIME": "https://example.com/api/time", } - async def test_nonce_generation(): + async def test_nonce_generation() -> None: async with OTTSProvider(config) as provider: # Mock time sync to avoid actual API calls provider.time_offset = 0 From dc873ae0bcc719e56aa2d1ecb5c026195e3fb69d Mon Sep 17 00:00:00 2001 From: Dale Null Date: Thu, 11 Dec 2025 16:23:06 +0800 Subject: [PATCH 29/44] =?UTF-8?q?feat:=20=E4=B8=BA=20build=5Fplug=5Flist?= =?UTF-8?q?=20=E6=B7=BB=E5=8A=A0=20PluginInfo=20TypedDict=20=E4=B8=8E?= =?UTF-8?q?=E7=B1=BB=E5=9E=8B=E6=B3=A8=E8=A7=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: aider (openai/gemini-3-pro-high) --- astrbot/cli/utils/plugin.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/astrbot/cli/utils/plugin.py b/astrbot/cli/utils/plugin.py index 81f59e0bf..94dd84003 100644 --- a/astrbot/cli/utils/plugin.py +++ b/astrbot/cli/utils/plugin.py @@ -3,6 +3,7 @@ from enum import Enum from io import BytesIO from pathlib import Path +from typing import Optional, TypedDict from zipfile import ZipFile import click @@ -19,6 +20,16 @@ class PluginStatus(str, Enum): NOT_PUBLISHED = "未发布" +class PluginInfo(TypedDict): + name: str + desc: str + version: str + author: str + repo: str + status: PluginStatus + local_path: Optional[str] + + def get_git_repo(url: str, target_path: Path, proxy: str | None = None) -> None: """从 Git 仓库下载代码并解压到指定路径""" temp_dir = Path(tempfile.mkdtemp()) @@ -102,18 +113,18 @@ def load_yaml_metadata(plugin_dir: Path) -> dict: return {} -def build_plug_list(plugins_dir: Path) -> list: +def build_plug_list(plugins_dir: Path) -> list[PluginInfo]: """构建插件列表,包含本地和在线插件信息 Args: plugins_dir (Path): 插件目录路径 Returns: - list: 包含插件信息的字典列表 + list[PluginInfo]: 包含插件信息的字典列表 """ # 获取本地插件信息 - result = [] + result: list[PluginInfo] = [] if plugins_dir.exists(): for plugin_name in [d.name for d in plugins_dir.glob("*") if d.is_dir()]: plugin_dir = plugins_dir / plugin_name @@ -141,7 +152,7 @@ def build_plug_list(plugins_dir: Path) -> list: ) # 获取在线插件列表 - online_plugins = [] + online_plugins: list[PluginInfo] = [] try: with httpx.Client() as client: resp = client.get("https://api.soulter.top/astrbot/plugins") From c0babb0166e139b806a1730ed3dc1ad9b1dbbdcb Mon Sep 17 00:00:00 2001 From: Dale Null Date: Thu, 11 Dec 2025 16:26:57 +0800 Subject: [PATCH 30/44] =?UTF-8?q?refactor:=20=E4=B8=BA=20display=5Fplugins?= =?UTF-8?q?=20=E5=A2=9E=E5=8A=A0=E5=8F=82=E6=95=B0=E7=B1=BB=E5=9E=8B?= =?UTF-8?q?=E6=B3=A8=E8=A7=A3=E5=B9=B6=E5=AF=BC=E5=85=A5=20PluginInfo?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: aider (openai/gemini-3-pro-high) --- astrbot/cli/commands/cmd_plug.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/astrbot/cli/commands/cmd_plug.py b/astrbot/cli/commands/cmd_plug.py index 9cf94365a..d08677b3e 100644 --- a/astrbot/cli/commands/cmd_plug.py +++ b/astrbot/cli/commands/cmd_plug.py @@ -5,6 +5,7 @@ import click from ..utils import ( + PluginInfo, PluginStatus, build_plug_list, check_astrbot_root, @@ -28,7 +29,11 @@ def _get_data_path() -> Path: return (base / "data").resolve() -def display_plugins(plugins, title=None, color=None) -> None: +def display_plugins( + plugins: list[PluginInfo], + title: str | None = None, + color: str | None = None, +) -> None: if title: click.echo(click.style(title, fg=color, bold=True)) From c13f5aba199ede2c8ec4cdc02aaa59ea53e6a769 Mon Sep 17 00:00:00 2001 From: Dale Null Date: Thu, 11 Dec 2025 18:45:44 +0800 Subject: [PATCH 31/44] =?UTF-8?q?style:=20=E7=BB=9F=E4=B8=80=E5=A4=9A?= =?UTF-8?q?=E5=A4=84=E5=87=BD=E6=95=B0=E5=8F=82=E6=95=B0=E6=8D=A2=E8=A1=8C?= =?UTF-8?q?=E4=B8=8E=E7=B1=BB=E5=9E=8B=E6=B3=A8=E9=87=8A=E9=A3=8E=E6=A0=BC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/pipeline/result_decorate/stage.py | 2 +- .../sources/qqofficial_webhook/qo_webhook_adapter.py | 8 ++++++-- .../sources/qqofficial_webhook/qo_webhook_server.py | 4 +++- astrbot/core/platform/sources/telegram/tg_adapter.py | 4 +++- .../sources/wechatpadpro/wechatpadpro_message_event.py | 4 +++- astrbot/core/platform_message_history_mgr.py | 4 +++- astrbot/core/star/filter/permission.py | 4 +++- astrbot/core/star/star_manager.py | 4 +++- astrbot/core/utils/shared_preferences.py | 8 ++++++-- astrbot/core/zip_updator.py | 4 +++- astrbot/dashboard/routes/config.py | 4 +++- astrbot/dashboard/routes/t2i.py | 4 +++- packages/astrbot/long_term_memory.py | 4 +++- packages/astrbot/main.py | 8 ++++++-- packages/astrbot/process_llm_request.py | 4 +++- packages/builtin_commands/main.py | 4 +++- packages/python_interpreter/main.py | 4 +++- 17 files changed, 58 insertions(+), 20 deletions(-) diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index 80869d3d8..d2e371634 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -140,7 +140,7 @@ async def process( if isinstance(self.content_safe_check_stage, ContentSafetyCheckStage): async for _ in self.content_safe_check_stage.process( event, - check_text=text, + check_text=text, # type:ignore ): yield diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py index d34e43740..ef0c562e6 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py @@ -30,7 +30,9 @@ def set_platform(self, platform: "QQOfficialWebhookPlatformAdapter") -> None: self.platform = platform # 收到群消息 - async def on_group_at_message_create(self, message: botpy.message.GroupMessage) -> None: + async def on_group_at_message_create( + self, message: botpy.message.GroupMessage + ) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.GROUP_MESSAGE, @@ -54,7 +56,9 @@ async def on_at_message_create(self, message: botpy.message.Message) -> None: self._commit(abm) # 收到私聊消息 - async def on_direct_message_create(self, message: botpy.message.DirectMessage) -> None: + async def on_direct_message_create( + self, message: botpy.message.DirectMessage + ) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.FRIEND_MESSAGE, diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py index 86bbf55b2..5f35471ee 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py @@ -14,7 +14,9 @@ class QQOfficialWebhook: - def __init__(self, config: dict, event_queue: asyncio.Queue, botpy_client: Client) -> None: + def __init__( + self, config: dict, event_queue: asyncio.Queue, botpy_client: Client + ) -> None: self.appid = config["appid"] self.secret = config["secret"] self.port = config.get("port", 6196) diff --git a/astrbot/core/platform/sources/telegram/tg_adapter.py b/astrbot/core/platform/sources/telegram/tg_adapter.py index b327e8a4e..7e6d16c10 100644 --- a/astrbot/core/platform/sources/telegram/tg_adapter.py +++ b/astrbot/core/platform/sources/telegram/tg_adapter.py @@ -221,7 +221,9 @@ async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> Non text=self.config["start_message"], ) - async def message_handler(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: + async def message_handler( + self, update: Update, context: ContextTypes.DEFAULT_TYPE + ) -> None: logger.debug(f"Telegram message: {update.message}") abm = await self.convert_message(update, context) if abm: diff --git a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py index 5ee8d7c07..8ecc4c64f 100644 --- a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +++ b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py @@ -114,7 +114,9 @@ async def _send_text(self, session: aiohttp.ClientSession, text: str) -> None: url = f"{self.adapter.base_url}/message/SendTextMessage" await self._post(session, url, payload) - async def _send_emoji(self, session: aiohttp.ClientSession, comp: WechatEmoji) -> None: + async def _send_emoji( + self, session: aiohttp.ClientSession, comp: WechatEmoji + ) -> None: payload = { "EmojiList": [ { diff --git a/astrbot/core/platform_message_history_mgr.py b/astrbot/core/platform_message_history_mgr.py index 476e32cb1..ad8bb44f6 100644 --- a/astrbot/core/platform_message_history_mgr.py +++ b/astrbot/core/platform_message_history_mgr.py @@ -40,7 +40,9 @@ async def get( history.reverse() return history - async def delete(self, platform_id: str, user_id: str, offset_sec: int = 86400) -> None: + async def delete( + self, platform_id: str, user_id: str, offset_sec: int = 86400 + ) -> None: """Delete platform message history records older than the specified offset.""" await self.db.delete_platform_message_offset( platform_id=platform_id, diff --git a/astrbot/core/star/filter/permission.py b/astrbot/core/star/filter/permission.py index 0017489cc..a70299fa9 100644 --- a/astrbot/core/star/filter/permission.py +++ b/astrbot/core/star/filter/permission.py @@ -14,7 +14,9 @@ class PermissionType(enum.Flag): class PermissionTypeFilter(HandlerFilter): - def __init__(self, permission_type: PermissionType, raise_error: bool = True) -> None: + def __init__( + self, permission_type: PermissionType, raise_error: bool = True + ) -> None: self.permission_type = permission_type self.raise_error = raise_error diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index e054aa5a7..ff78c6b85 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -166,7 +166,9 @@ def _get_plugin_modules(self) -> list[dict]: plugins.extend(_p) return plugins - async def _check_plugin_dept_update(self, target_plugin: str | None = None) -> bool | None: + async def _check_plugin_dept_update( + self, target_plugin: str | None = None + ) -> bool | None: """检查插件的依赖 如果 target_plugin 为 None,则检查所有插件的依赖 """ diff --git a/astrbot/core/utils/shared_preferences.py b/astrbot/core/utils/shared_preferences.py index ee22fa3db..a4f69100a 100644 --- a/astrbot/core/utils/shared_preferences.py +++ b/astrbot/core/utils/shared_preferences.py @@ -188,14 +188,18 @@ def range_get( return result - def put(self, key, value, scope: str | None = None, scope_id: str | None = None) -> None: + def put( + self, key, value, scope: str | None = None, scope_id: str | None = None + ) -> None: """设置偏好设置(已弃用)""" asyncio.run_coroutine_threadsafe( self.put_async(scope or "unknown", scope_id or "unknown", key, value), self._sync_loop, ).result() - def remove(self, key, scope: str | None = None, scope_id: str | None = None) -> None: + def remove( + self, key, scope: str | None = None, scope_id: str | None = None + ) -> None: """删除偏好设置(已弃用)""" asyncio.run_coroutine_threadsafe( self.remove_async(scope or "unknown", scope_id or "unknown", key), diff --git a/astrbot/core/zip_updator.py b/astrbot/core/zip_updator.py index c5bf5b77f..6cea6b38d 100644 --- a/astrbot/core/zip_updator.py +++ b/astrbot/core/zip_updator.py @@ -149,7 +149,9 @@ async def check_update( body=f"{tag_name}\n\n{sel_release_data['body']}", ) - async def download_from_repo_url(self, target_path: str, repo_url: str, proxy="") -> None: + async def download_from_repo_url( + self, target_path: str, repo_url: str, proxy="" + ) -> None: author, repo, branch = self.parse_github_url(repo_url) logger.info(f"正在下载更新 {repo} ...") diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 2b34c5063..746f1cb85 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -122,7 +122,9 @@ def validate(data: dict, metadata: dict = schema, path="") -> None: return errors, data -def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False) -> None: +def save_config( + post_config: dict, config: AstrBotConfig, is_core: bool = False +) -> None: """验证并保存配置""" errors = None logger.info(f"Saving config, is_core={is_core}") diff --git a/astrbot/dashboard/routes/t2i.py b/astrbot/dashboard/routes/t2i.py index df8fbbaa7..8d06826be 100644 --- a/astrbot/dashboard/routes/t2i.py +++ b/astrbot/dashboard/routes/t2i.py @@ -12,7 +12,9 @@ class T2iRoute(Route): - def __init__(self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle) -> None: + def __init__( + self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle + ) -> None: super().__init__(context) self.core_lifecycle = core_lifecycle self.config = core_lifecycle.astrbot_config diff --git a/packages/astrbot/long_term_memory.py b/packages/astrbot/long_term_memory.py index 52d015d4f..e08cdc515 100644 --- a/packages/astrbot/long_term_memory.py +++ b/packages/astrbot/long_term_memory.py @@ -171,7 +171,9 @@ async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest) -> Non ) req.system_prompt += chats_str - async def after_req_llm(self, event: AstrMessageEvent, llm_resp: LLMResponse) -> None: + async def after_req_llm( + self, event: AstrMessageEvent, llm_resp: LLMResponse + ) -> None: if event.unified_msg_origin not in self.session_chats: return diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index 91e5537b2..18a3f05cd 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -89,7 +89,9 @@ async def on_message(self, event: AstrMessageEvent): logger.error(f"主动回复失败: {e}") @filter.on_llm_request() - async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest) -> None: + async def decorate_llm_req( + self, event: AstrMessageEvent, req: ProviderRequest + ) -> None: """在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt""" await self.proc_llm_req.process_llm_request(event, req) @@ -100,7 +102,9 @@ async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest) logger.error(f"ltm: {e}") @filter.on_llm_response() - async def inject_reasoning(self, event: AstrMessageEvent, resp: LLMResponse) -> None: + async def inject_reasoning( + self, event: AstrMessageEvent, resp: LLMResponse + ) -> None: """在 LLM 响应后基于配置注入思考过程文本 / 在 LLM 响应后记录对话""" umo = event.unified_msg_origin cfg = self.context.get_config(umo).get("provider_settings", {}) diff --git a/packages/astrbot/process_llm_request.py b/packages/astrbot/process_llm_request.py index 8250233ed..cec405926 100644 --- a/packages/astrbot/process_llm_request.py +++ b/packages/astrbot/process_llm_request.py @@ -115,7 +115,9 @@ async def _request_img_caption( f"Cannot get image caption because provider `{provider_id}` is not exist.", ) - async def process_llm_request(self, event: AstrMessageEvent, req: ProviderRequest) -> None: + async def process_llm_request( + self, event: AstrMessageEvent, req: ProviderRequest + ) -> None: """在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt""" cfg: dict = self.ctx.get_config(umo=event.unified_msg_origin)[ "provider_settings" diff --git a/packages/builtin_commands/main.py b/packages/builtin_commands/main.py index 57a7217f8..12078d70d 100644 --- a/packages/builtin_commands/main.py +++ b/packages/builtin_commands/main.py @@ -190,7 +190,9 @@ async def groupnew_conv(self, message: AstrMessageEvent, sid: str) -> None: await self.conversation_c.groupnew_conv(message, sid) @filter.command("switch") - async def switch_conv(self, message: AstrMessageEvent, index: int | None = None) -> None: + async def switch_conv( + self, message: AstrMessageEvent, index: int | None = None + ) -> None: """通过 /ls 前面的序号切换对话""" await self.conversation_c.switch_conv(message, index) diff --git a/packages/python_interpreter/main.py b/packages/python_interpreter/main.py index 1bb75a7ae..33479f7ba 100644 --- a/packages/python_interpreter/main.py +++ b/packages/python_interpreter/main.py @@ -240,7 +240,9 @@ async def on_message(self, event: AstrMessageEvent): del self.user_waiting[uid] @filter.on_llm_request() - async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest) -> None: + async def on_llm_req( + self, event: AstrMessageEvent, request: ProviderRequest + ) -> None: if event.get_session_id() in self.user_file_msg_buffer: files = self.user_file_msg_buffer[event.get_session_id()] if not request.prompt: From 1256a667807895906a1ead3b36ddb4ba214b5bb6 Mon Sep 17 00:00:00 2001 From: Dale Null Date: Fri, 12 Dec 2025 11:21:36 +0800 Subject: [PATCH 32/44] =?UTF-8?q?refactor:=20=E5=A2=9E=E5=BC=BA=E7=B1=BB?= =?UTF-8?q?=E5=9E=8B=E6=B3=A8=E8=A7=A3=E5=B9=B6=E4=BF=AE=E6=AD=A3=20MCP=20?= =?UTF-8?q?=E5=AE=A2=E6=88=B7=E7=AB=AF=E4=B8=8E=E6=B6=88=E6=81=AF=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E7=AD=BE=E5=90=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/agent/mcp_client.py | 15 +++++++++++---- astrbot/core/agent/message.py | 13 +++++++------ 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index 18f4d47e0..edf93be8a 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -2,7 +2,7 @@ import logging from contextlib import AsyncExitStack from datetime import timedelta -from typing import Generic +from typing import Any, Generic from tenacity import ( before_sleep_log, @@ -11,6 +11,7 @@ stop_after_attempt, wait_exponential, ) +from mcp.types import CallToolResult from astrbot import logger from astrbot.core.agent.run_context import ContextWrapper @@ -322,7 +323,7 @@ async def call_tool_with_reconnect( before_sleep=before_sleep_log(logger, logging.WARNING), reraise=True, ) - async def _call_with_retry(): + async def _call_with_retry() -> CallToolResult: if not self.session: raise ValueError("MCP session is not available for MCP function tools.") @@ -364,7 +365,11 @@ class MCPTool(FunctionTool, Generic[TContext]): """A function tool that calls an MCP service.""" def __init__( - self, mcp_tool: mcp.Tool, mcp_client: MCPClient, mcp_server_name: str, **kwargs + self, + mcp_tool: mcp.Tool, + mcp_client: MCPClient, + mcp_server_name: str, + **kwargs: Any, # noqa: ANN401 ) -> None: super().__init__( name=mcp_tool.name, @@ -376,7 +381,9 @@ def __init__( self.mcp_server_name = mcp_server_name async def call( - self, context: ContextWrapper[TContext], **kwargs + self, + context: ContextWrapper[TContext], + **kwargs: Any, # noqa: ANN401 ) -> mcp.types.CallToolResult: return await self.mcp_client.call_tool_with_reconnect( tool_name=self.mcp_tool.name, diff --git a/astrbot/core/agent/message.py b/astrbot/core/agent/message.py index 6e3a4012d..a7d27abe4 100644 --- a/astrbot/core/agent/message.py +++ b/astrbot/core/agent/message.py @@ -1,6 +1,7 @@ # Inspired by MoonshotAI/kosong, credits to MoonshotAI/kosong authors for the original implementation. # License: Apache License 2.0 +import builtins from typing import Any, ClassVar, Literal, cast from pydantic import BaseModel, GetCoreSchemaHandler, model_validator @@ -14,7 +15,7 @@ class ContentPart(BaseModel): type: str - def __init_subclass__(cls, **kwargs: Any) -> None: + def __init_subclass__(cls, **kwargs: Any) -> None: # noqa:ANN401 super().__init_subclass__(**kwargs) invalid_subclass_error_msg = f"ContentPart subclass {cls.__name__} must have a `type` field of type `str`" @@ -27,15 +28,15 @@ def __init_subclass__(cls, **kwargs: Any) -> None: @classmethod def __get_pydantic_core_schema__( - cls, source_type: Any, handler: GetCoreSchemaHandler + cls, source_type: builtins.type[BaseModel], handler: GetCoreSchemaHandler ) -> core_schema.CoreSchema: # If we're dealing with the base ContentPart class, use custom validation if cls.__name__ == "ContentPart": - def validate_content_part(value: Any) -> Any: + def validate_content_part(value: object) -> "ContentPart": # if it's already an instance of a ContentPart subclass, return it if hasattr(value, "__class__") and issubclass(value.__class__, cls): - return value + return cast("ContentPart", value) # if it's a dict with a type field, dispatch to the appropriate subclass if isinstance(value, dict) and "type" in value: @@ -122,7 +123,7 @@ class FunctionBody(BaseModel): extra_content: dict[str, Any] | None = None """Extra metadata for the tool call.""" - def model_dump(self, **kwargs: Any) -> dict[str, Any]: + def model_dump(self, **kwargs: Any) -> dict[str, Any]: # noqa:ANN401 if self.extra_content is None: kwargs.setdefault("exclude", set()).add("extra_content") return super().model_dump(**kwargs) @@ -155,7 +156,7 @@ class Message(BaseModel): """The ID of the tool call.""" @model_validator(mode="after") - def check_content_required(self): + def check_content_required(self) -> "Message": # assistant + tool_calls is not None: allow content to be None if self.role == "assistant" and self.tool_calls is not None: return self From ee5eb7b6d2e8c435cddad536423a6d1a8a0a602d Mon Sep 17 00:00:00 2001 From: Dale Null Date: Fri, 12 Dec 2025 11:36:05 +0800 Subject: [PATCH 33/44] =?UTF-8?q?fix:=20=E7=BB=9F=E4=B8=80=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D=20ANN=20=E6=B3=A8=E8=A7=A3=E9=97=AE=E9=A2=98=EF=BC=8C?= =?UTF-8?q?=E6=9B=BF=E6=8D=A2=20Any=20=E4=B8=BA=20object=EF=BC=8C=E5=AE=8C?= =?UTF-8?q?=E5=96=84=E8=BF=94=E5=9B=9E=E7=B1=BB=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: aider (openai/gemini-3-pro-high) --- .../agent/runners/coze/coze_agent_runner.py | 4 ++-- .../agent/runners/coze/coze_api_client.py | 6 +++--- .../dashscope/dashscope_agent_runner.py | 20 +++++++++---------- .../agent/runners/dify/dify_agent_runner.py | 6 +++--- .../agent/runners/dify/dify_api_client.py | 8 ++++---- astrbot/core/agent/tool.py | 4 +++- astrbot/core/agent/tool_executor.py | 2 +- 7 files changed, 26 insertions(+), 24 deletions(-) diff --git a/astrbot/core/agent/runners/coze/coze_agent_runner.py b/astrbot/core/agent/runners/coze/coze_agent_runner.py index a8300bb71..c2e8a948e 100644 --- a/astrbot/core/agent/runners/coze/coze_agent_runner.py +++ b/astrbot/core/agent/runners/coze/coze_agent_runner.py @@ -70,7 +70,7 @@ async def reset( self.file_id_cache: dict[str, dict[str, str]] = {} @override - async def step(self): + async def step(self) -> T.AsyncGenerator[AgentResponse, None]: """ 执行 Coze Agent 的一个步骤 """ @@ -113,7 +113,7 @@ async def step_until_done( async for resp in self.step(): yield resp - async def _execute_coze_request(self): + async def _execute_coze_request(self) -> T.AsyncGenerator[AgentResponse, None]: """执行 Coze 请求的核心逻辑""" prompt = self.req.prompt or "" session_id = self.req.session_id or "unknown" diff --git a/astrbot/core/agent/runners/coze/coze_api_client.py b/astrbot/core/agent/runners/coze/coze_api_client.py index f5799dfbb..a5e62520b 100644 --- a/astrbot/core/agent/runners/coze/coze_api_client.py +++ b/astrbot/core/agent/runners/coze/coze_api_client.py @@ -15,7 +15,7 @@ def __init__(self, api_key: str, api_base: str = "https://api.coze.cn") -> None: self.api_base = api_base self.session = None - async def _ensure_session(self): + async def _ensure_session(self) -> aiohttp.ClientSession: """确保HTTP session存在""" if self.session is None: connector = aiohttp.TCPConnector( @@ -208,7 +208,7 @@ async def chat_messages( except Exception as e: raise Exception(f"Coze API 流式请求失败: {e!s}") - async def clear_context(self, conversation_id: str): + async def clear_context(self, conversation_id: str) -> dict: """清空会话上下文 Args: @@ -247,7 +247,7 @@ async def get_message_list( order: str = "desc", limit: int = 10, offset: int = 0, - ): + ) -> dict: """获取消息列表 Args: diff --git a/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py b/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py index 1aaf6e3b9..f14fae0ca 100644 --- a/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py +++ b/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py @@ -82,7 +82,7 @@ def has_rag_options(self) -> bool: return False @override - async def step(self): + async def step(self) -> T.AsyncGenerator[AgentResponse, None]: """ 执行 Dashscope Agent 的一个步骤 """ @@ -124,7 +124,7 @@ async def step_until_done( yield resp def _consume_sync_generator( - self, response: T.Any, response_queue: queue.Queue + self, response: T.Iterable[object], response_queue: queue.Queue ) -> None: """在线程中消费同步generator,将结果放入队列 @@ -278,7 +278,7 @@ async def _build_request_payload( return payload async def _handle_streaming_response( - self, response: T.Any, session_id: str + self, response: object, session_id: str ) -> T.AsyncGenerator[AgentResponse, None]: """处理流式响应 @@ -292,7 +292,7 @@ async def _handle_streaming_response( response_queue = queue.Queue() consumer_thread = threading.Thread( target=self._consume_sync_generator, - args=(response, response_queue), + args=(T.cast(T.Iterable[object], response), response_queue), daemon=True, ) consumer_thread.start() @@ -319,14 +319,14 @@ async def _handle_streaming_response( ( output_text, chunk_doc_refs, - response, + response_obj, ) = await self._process_stream_chunk(chunk, output_text) - if response: - if response.type == "err": - yield response + if response_obj: + if response_obj.type == "err": + yield response_obj return - yield response + yield response_obj if chunk_doc_refs: doc_references = chunk_doc_refs @@ -366,7 +366,7 @@ async def _handle_streaming_response( data=AgentResponseData(chain=chain), ) - async def _execute_dashscope_request(self): + async def _execute_dashscope_request(self) -> T.AsyncGenerator[AgentResponse, None]: """执行 Dashscope 请求的核心逻辑""" prompt = self.req.prompt or "" session_id = self.req.session_id or "unknown" diff --git a/astrbot/core/agent/runners/dify/dify_agent_runner.py b/astrbot/core/agent/runners/dify/dify_agent_runner.py index d9a8b7cd6..64f028dd9 100644 --- a/astrbot/core/agent/runners/dify/dify_agent_runner.py +++ b/astrbot/core/agent/runners/dify/dify_agent_runner.py @@ -63,7 +63,7 @@ async def reset( self.api_client = DifyAPIClient(self.api_key, self.api_base) @override - async def step(self): + async def step(self) -> T.AsyncGenerator[AgentResponse, None]: """ 执行 Dify Agent 的一个步骤 """ @@ -106,7 +106,7 @@ async def step_until_done( async for resp in self.step(): yield resp - async def _execute_dify_request(self): + async def _execute_dify_request(self) -> T.AsyncGenerator[AgentResponse, None]: """执行 Dify 请求的核心逻辑""" prompt = self.req.prompt or "" session_id = self.req.session_id or "unknown" @@ -285,7 +285,7 @@ async def parse_dify_result(self, chunk: dict | str) -> MessageChain: # Chat return MessageChain(chain=[Comp.Plain(chunk)]) - async def parse_file(item: dict): + async def parse_file(item: dict) -> object: match item["type"]: case "image": return Comp.Image(file=item["url"], url=item["url"]) diff --git a/astrbot/core/agent/runners/dify/dify_api_client.py b/astrbot/core/agent/runners/dify/dify_api_client.py index 26da6dfe9..9e683d257 100644 --- a/astrbot/core/agent/runners/dify/dify_api_client.py +++ b/astrbot/core/agent/runners/dify/dify_api_client.py @@ -77,7 +77,7 @@ async def workflow_run( response_mode: str = "streaming", files: list[dict[str, Any]] | None = None, timeout: float = 60, - ): + ) -> AsyncGenerator[dict[str, Any], None]: if files is None: files = [] url = f"{self.api_base}/workflows/run" @@ -158,7 +158,7 @@ async def file_upload( async def close(self) -> None: await self.session.close() - async def get_chat_convs(self, user: str, limit: int = 20): + async def get_chat_convs(self, user: str, limit: int = 20) -> dict: # conversations. GET url = f"{self.api_base}/conversations" payload = { @@ -168,7 +168,7 @@ async def get_chat_convs(self, user: str, limit: int = 20): async with self.session.get(url, params=payload, headers=self.headers) as resp: return await resp.json() - async def delete_chat_conv(self, user: str, conversation_id: str): + async def delete_chat_conv(self, user: str, conversation_id: str) -> dict: # conversation. DELETE url = f"{self.api_base}/conversations/{conversation_id}" payload = { @@ -183,7 +183,7 @@ async def rename( name: str, user: str, auto_generate: bool = False, - ): + ) -> dict: # /conversations/:conversation_id/name url = f"{self.api_base}/conversations/{conversation_id}/name" payload = { diff --git a/astrbot/core/agent/tool.py b/astrbot/core/agent/tool.py index 160563e26..03679a245 100644 --- a/astrbot/core/agent/tool.py +++ b/astrbot/core/agent/tool.py @@ -61,7 +61,9 @@ class FunctionTool(ToolSchema, Generic[TContext]): def __repr__(self) -> str: return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})" - async def call(self, context: ContextWrapper[TContext], **kwargs) -> ToolExecResult: + async def call( + self, context: ContextWrapper[TContext], **kwargs: object + ) -> ToolExecResult: """Run the tool with the given arguments. The handler field has priority.""" raise NotImplementedError( "FunctionTool.call() must be implemented by subclasses or set a handler." diff --git a/astrbot/core/agent/tool_executor.py b/astrbot/core/agent/tool_executor.py index 2704119d4..b35c34cad 100644 --- a/astrbot/core/agent/tool_executor.py +++ b/astrbot/core/agent/tool_executor.py @@ -13,5 +13,5 @@ async def execute( cls, tool: FunctionTool, run_context: ContextWrapper[TContext], - **tool_args, + **tool_args: object, ) -> AsyncGenerator[Any | mcp.types.CallToolResult, None]: ... From c5db8a4862b460b4f4dc826cfcb6c1a8b8e6aaca Mon Sep 17 00:00:00 2001 From: Dale Null Date: Fri, 12 Dec 2025 11:42:35 +0800 Subject: [PATCH 34/44] =?UTF-8?q?fix:=20=E5=B0=86=20KnowledgeBaseQueryTool?= =?UTF-8?q?.call=20=E7=9A=84=20kwargs=20=E6=A0=87=E6=B3=A8=E4=B8=BA=20obje?= =?UTF-8?q?ct=EF=BC=8C=E5=B9=B6=E5=B0=86=20query=20=E8=BD=AC=E4=B8=BA=20st?= =?UTF-8?q?r=20=E4=BC=A0=E5=85=A5=E6=A3=80=E7=B4=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: aider (openai/gemini-3-pro-high) --- astrbot/core/pipeline/process_stage/utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/astrbot/core/pipeline/process_stage/utils.py b/astrbot/core/pipeline/process_stage/utils.py index 24e052e1e..09965d7f3 100644 --- a/astrbot/core/pipeline/process_stage/utils.py +++ b/astrbot/core/pipeline/process_stage/utils.py @@ -31,13 +31,17 @@ class KnowledgeBaseQueryTool(FunctionTool[AstrAgentContext]): ) async def call( - self, context: ContextWrapper[AstrAgentContext], **kwargs + self, context: ContextWrapper[AstrAgentContext], **kwargs: object ) -> ToolExecResult: query = kwargs.get("query", "") if not query: return "error: Query parameter is empty." + + # 显式转换为 str,解决类型检查报错 "object cannot be assigned to str" + query_str = str(query) + result = await retrieve_knowledge_base( - query=kwargs.get("query", ""), + query=query_str, umo=context.context.event.unified_msg_origin, context=context.context.context, ) From 9c8400714d614d9b09dfe8ba918bfe9bf94fad1e Mon Sep 17 00:00:00 2001 From: Dale Null Date: Fri, 12 Dec 2025 13:08:07 +0800 Subject: [PATCH 35/44] =?UTF-8?q?refactor:=20=E5=B0=86=E5=A4=9A=E5=A4=84?= =?UTF-8?q?=E5=87=BD=E6=95=B0=E5=8F=82=E6=95=B0=E7=9A=84=20Any=20=E6=94=B9?= =?UTF-8?q?=E4=B8=BA=20object=EF=BC=8C=E7=BB=9F=E4=B8=80=E7=B1=BB=E5=9E=8B?= =?UTF-8?q?=E6=B3=A8=E8=A7=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/agent/handoff.py | 2 +- astrbot/core/agent/mcp_client.py | 4 +- astrbot/core/agent/message.py | 4 +- astrbot/core/agent/runners/base.py | 2 +- astrbot/core/conversation_mgr.py | 2 +- astrbot/core/db/__init__.py | 2 +- astrbot/core/db/sqlite.py | 2 +- astrbot/core/knowledge_base/chunking/base.py | 2 +- .../knowledge_base/chunking/fixed_size.py | 2 +- .../core/knowledge_base/chunking/recursive.py | 2 +- astrbot/core/message/components.py | 58 +++++++++---------- astrbot/core/pipeline/context_utils.py | 8 +-- astrbot/core/pipeline/process_stage/utils.py | 4 +- 13 files changed, 47 insertions(+), 47 deletions(-) diff --git a/astrbot/core/agent/handoff.py b/astrbot/core/agent/handoff.py index 37203b131..cf3cced38 100644 --- a/astrbot/core/agent/handoff.py +++ b/astrbot/core/agent/handoff.py @@ -12,7 +12,7 @@ def __init__( self, agent: Agent[TContext], parameters: dict | None = None, - **kwargs: Any, # noqa: ANN401 + **kwargs: object, ) -> None: self.agent = agent super().__init__( diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index edf93be8a..1f6c3cdb4 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -369,7 +369,7 @@ def __init__( mcp_tool: mcp.Tool, mcp_client: MCPClient, mcp_server_name: str, - **kwargs: Any, # noqa: ANN401 + **kwargs: object, ) -> None: super().__init__( name=mcp_tool.name, @@ -383,7 +383,7 @@ def __init__( async def call( self, context: ContextWrapper[TContext], - **kwargs: Any, # noqa: ANN401 + **kwargs: object, ) -> mcp.types.CallToolResult: return await self.mcp_client.call_tool_with_reconnect( tool_name=self.mcp_tool.name, diff --git a/astrbot/core/agent/message.py b/astrbot/core/agent/message.py index a7d27abe4..aba9ef264 100644 --- a/astrbot/core/agent/message.py +++ b/astrbot/core/agent/message.py @@ -15,7 +15,7 @@ class ContentPart(BaseModel): type: str - def __init_subclass__(cls, **kwargs: Any) -> None: # noqa:ANN401 + def __init_subclass__(cls, **kwargs: object) -> None: super().__init_subclass__(**kwargs) invalid_subclass_error_msg = f"ContentPart subclass {cls.__name__} must have a `type` field of type `str`" @@ -123,7 +123,7 @@ class FunctionBody(BaseModel): extra_content: dict[str, Any] | None = None """Extra metadata for the tool call.""" - def model_dump(self, **kwargs: Any) -> dict[str, Any]: # noqa:ANN401 + def model_dump(self, **kwargs: Any) -> dict[str, Any]: if self.extra_content is None: kwargs.setdefault("exclude", set()).add("extra_content") return super().model_dump(**kwargs) diff --git a/astrbot/core/agent/runners/base.py b/astrbot/core/agent/runners/base.py index 2b1466495..a9ddc0e35 100644 --- a/astrbot/core/agent/runners/base.py +++ b/astrbot/core/agent/runners/base.py @@ -25,7 +25,7 @@ async def reset( self, run_context: ContextWrapper[TContext], agent_hooks: BaseAgentRunHooks[TContext], - **kwargs: T.Any, # noqa: ANN401 + **kwargs: object, ) -> None: """Reset the agent to its initial state. This method should be called before starting a new run. diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index 59f9e02e3..3548f0eb4 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -226,7 +226,7 @@ async def get_filtered_conversations( page_size: int = 20, platform_ids: list[str] | None = None, search_query: str = "", - **kwargs: Any, # noqa: ANN401 + **kwargs: object, ) -> tuple[list[Conversation], int]: """获取过滤后的对话列表. diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index 693b2a7a8..1dd204993 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -123,7 +123,7 @@ async def get_filtered_conversations( page_size: int = 20, platform_ids: list[str] | None = None, search_query: str = "", - **kwargs: T.Any, # noqa: ANN401 + **kwargs: object, ) -> tuple[list[ConversationV2], int]: """Get conversations filtered by platform IDs and search query.""" ... diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index 066b18e15..77dd0f784 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -161,7 +161,7 @@ async def get_filtered_conversations( page_size: int = 20, platform_ids: list[str] | None = None, search_query: str = "", - **kwargs: T.Any, # noqa:ANN401 + **kwargs: object, ) -> tuple[list[ConversationV2], int]: async with self.get_db() as session: session: AsyncSession diff --git a/astrbot/core/knowledge_base/chunking/base.py b/astrbot/core/knowledge_base/chunking/base.py index a59fce3c0..a7915d737 100644 --- a/astrbot/core/knowledge_base/chunking/base.py +++ b/astrbot/core/knowledge_base/chunking/base.py @@ -14,7 +14,7 @@ class BaseChunker(ABC): """ @abstractmethod - async def chunk(self, text: str, **kwargs: Any) -> list[str]: # noqa:ANN401 + async def chunk(self, text: str, **kwargs: object) -> list[str]: """将文本分块 Args: diff --git a/astrbot/core/knowledge_base/chunking/fixed_size.py b/astrbot/core/knowledge_base/chunking/fixed_size.py index 058b49cd1..fb46603ee 100644 --- a/astrbot/core/knowledge_base/chunking/fixed_size.py +++ b/astrbot/core/knowledge_base/chunking/fixed_size.py @@ -25,7 +25,7 @@ def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50) -> None: self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap - async def chunk(self, text: str, **kwargs: Any) -> list[str]: # noqa:ANN401 + async def chunk(self, text: str, **kwargs: object) -> list[str]: """固定大小分块 Args: diff --git a/astrbot/core/knowledge_base/chunking/recursive.py b/astrbot/core/knowledge_base/chunking/recursive.py index ba71e99d7..611fcbefa 100644 --- a/astrbot/core/knowledge_base/chunking/recursive.py +++ b/astrbot/core/knowledge_base/chunking/recursive.py @@ -40,7 +40,7 @@ def __init__( "", # 字符 ] - async def chunk(self, text: str, **kwargs: Any) -> list[str]: # noqa:ANN401 + async def chunk(self, text: str, **kwargs: object) -> list[str]: """递归地将文本分割成块 Args: diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index a8ddb1e6a..4dfadfc2f 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -90,7 +90,7 @@ class Plain(BaseMessageComponent): text: str convert: bool | None = True - def __init__(self, text: str, convert: bool = True, **_: T.Any) -> None: # noqa: ANN401 + def __init__(self, text: str, convert: bool = True, **_: object) -> None: super().__init__(text=text, convert=convert, **_) def toDict(self): @@ -104,7 +104,7 @@ class Face(BaseMessageComponent): type = ComponentType.Face id: int - def __init__(self, **_: T.Any) -> None: # noqa: ANN401 + def __init__(self, **_: object) -> None: super().__init__(**_) @@ -119,7 +119,7 @@ class Record(BaseMessageComponent): # 额外 path: str | None - def __init__(self, file: str | None, **_: T.Any) -> None: # noqa: ANN401 + def __init__(self, file: str | None, **_: object) -> None: for k in _: if k == "url": pass @@ -127,17 +127,17 @@ def __init__(self, file: str | None, **_: T.Any) -> None: # noqa: ANN401 super().__init__(file=file, **_) @staticmethod - def fromFileSystem(path: str, **_: T.Any) -> "Record": # noqa: ANN401 + def fromFileSystem(path: str, **_: object) -> "Record": return Record(file=f"file:///{os.path.abspath(path)}", path=path, **_) @staticmethod - def fromURL(url: str, **_: T.Any) -> "Record": # noqa: ANN401 + def fromURL(url: str, **_: object) -> "Record": if url.startswith("http://") or url.startswith("https://"): return Record(file=url, **_) raise Exception("not a valid url") @staticmethod - def fromBase64(bs64_data: str, **_: T.Any) -> "Record": # noqa: ANN401 + def fromBase64(bs64_data: str, **_: object) -> "Record": return Record(file=f"base64://{bs64_data}", **_) async def convert_to_file_path(self) -> str: @@ -222,15 +222,15 @@ class Video(BaseMessageComponent): # 额外 path: str | None = "" - def __init__(self, file: str, **_: T.Any) -> None: # noqa: ANN401 + def __init__(self, file: str, **_: object) -> None: super().__init__(file=file, **_) @staticmethod - def fromFileSystem(path: str, **_: T.Any) -> "Video": # noqa: ANN401 + def fromFileSystem(path: str, **_: object) -> "Video": return Video(file=f"file:///{os.path.abspath(path)}", path=path, **_) @staticmethod - def fromURL(url: str, **_: T.Any) -> "Video": # noqa: ANN401 + def fromURL(url: str, **_: object) -> "Video": if url.startswith("http://") or url.startswith("https://"): return Video(file=url, **_) raise Exception("not a valid url") @@ -304,7 +304,7 @@ class At(BaseMessageComponent): qq: int | str # 此处str为all时代表所有人 name: str | None = "" - def __init__(self, **_: T.Any) -> None: # noqa: ANN401 + def __init__(self, **_: object) -> None: super().__init__(**_) def toDict(self) -> dict: @@ -317,28 +317,28 @@ def toDict(self) -> dict: class AtAll(At): qq: str = "all" - def __init__(self, **_: T.Any) -> None: # noqa: ANN401 + def __init__(self, **_: object) -> None: super().__init__(**_) class RPS(BaseMessageComponent): # TODO type = ComponentType.RPS - def __init__(self, **_: T.Any) -> None: # noqa: ANN401 + def __init__(self, **_: object) -> None: super().__init__(**_) class Dice(BaseMessageComponent): # TODO type = ComponentType.Dice - def __init__(self, **_: T.Any) -> None: # noqa: ANN401 + def __init__(self, **_: object) -> None: super().__init__(**_) class Shake(BaseMessageComponent): # TODO type = ComponentType.Shake - def __init__(self, **_: T.Any) -> None: # noqa: ANN401 + def __init__(self, **_: object) -> None: super().__init__(**_) @@ -349,7 +349,7 @@ class Share(BaseMessageComponent): content: str | None = "" image: str | None = "" - def __init__(self, **_: T.Any) -> None: # noqa: ANN401 + def __init__(self, **_: object) -> None: super().__init__(**_) @@ -358,7 +358,7 @@ class Contact(BaseMessageComponent): # TODO _type: str # type 字段冲突 id: int | None = 0 - def __init__(self, **_: T.Any) -> None: # noqa: ANN401 + def __init__(self, **_: object) -> None: super().__init__(**_) @@ -369,7 +369,7 @@ class Location(BaseMessageComponent): # TODO title: str | None = "" content: str | None = "" - def __init__(self, **_: T.Any) -> None: # noqa: ANN401 + def __init__(self, **_: object) -> None: super().__init__(**_) @@ -383,7 +383,7 @@ class Music(BaseMessageComponent): content: str | None = "" image: str | None = "" - def __init__(self, **_: T.Any) -> None: # noqa: ANN401 + def __init__(self, **_: object) -> None: # for k in _.keys(): # if k == "_type" and _[k] not in ["qq", "163", "xm", "custom"]: # logger.warn(f"Protocol: {k}={_[k]} doesn't match values") @@ -403,21 +403,21 @@ class Image(BaseMessageComponent): path: str | None = "" file_unique: str | None = "" # 某些平台可能有图片缓存的唯一标识 - def __init__(self, file: str | None, **_: T.Any) -> None: # noqa: ANN401 + def __init__(self, file: str | None, **_: object) -> None: super().__init__(file=file, **_) @staticmethod - def fromURL(url: str, **_: T.Any) -> "Image": # noqa: ANN401 + def fromURL(url: str, **_: object) -> "Image": if url.startswith("http://") or url.startswith("https://"): return Image(file=url, **_) raise Exception("not a valid url") @staticmethod - def fromFileSystem(path: str, **_: T.Any) -> "Image": # noqa: ANN401 + def fromFileSystem(path: str, **_: object) -> "Image": return Image(file=f"file:///{os.path.abspath(path)}", path=path, **_) @staticmethod - def fromBase64(base64: str, **_: T.Any) -> "Image": # noqa: ANN401 + def fromBase64(base64: str, **_: object) -> "Image": return Image(f"base64://{base64}", **_) @staticmethod @@ -526,7 +526,7 @@ class Reply(BaseMessageComponent): seq: int | None = 0 """deprecated""" - def __init__(self, **_: T.Any) -> None: # noqa: ANN401 + def __init__(self, **_: object) -> None: super().__init__(**_) @@ -535,7 +535,7 @@ class Poke(BaseMessageComponent): id: int | None = 0 qq: int | None = 0 - def __init__(self, type: str, **_: T.Any) -> None: # noqa: ANN401 + def __init__(self, type: str, **_: object) -> None: type = f"Poke:{type}" super().__init__(type=type, **_) @@ -544,7 +544,7 @@ class Forward(BaseMessageComponent): type = ComponentType.Forward id: str - def __init__(self, **_: T.Any) -> None: # noqa: ANN401 + def __init__(self, **_: object) -> None: super().__init__(**_) @@ -559,7 +559,7 @@ class Node(BaseMessageComponent): seq: str | list | None = "" # 忽略 time: int | None = 0 # 忽略 - def __init__(self, content: list[BaseMessageComponent], **_: T.Any) -> None: # noqa: ANN401 + def __init__(self, content: list[BaseMessageComponent], **_: object) -> None: if isinstance(content, Node): # back content = [content] @@ -606,7 +606,7 @@ class Nodes(BaseMessageComponent): type = ComponentType.Nodes nodes: list[Node] - def __init__(self, nodes: list[Node], **_: T.Any) -> None: # noqa: ANN401 + def __init__(self, nodes: list[Node], **_: object) -> None: super().__init__(nodes=nodes, **_) def toDict(self) -> dict: @@ -633,7 +633,7 @@ class Json(BaseMessageComponent): data: str | dict resid: int | None = 0 - def __init__(self, data: str | dict, **_: T.Any) -> None: # noqa: ANN401 + def __init__(self, data: str | dict, **_: object) -> None: if isinstance(data, dict): data = json.dumps(data) super().__init__(data=data, **_) @@ -788,7 +788,7 @@ class WechatEmoji(BaseMessageComponent): md5_len: int | None = 0 cdnurl: str | None = "" - def __init__(self, **_: T.Any) -> None: # noqa: ANN401 + def __init__(self, **_: object) -> None: super().__init__(**_) diff --git a/astrbot/core/pipeline/context_utils.py b/astrbot/core/pipeline/context_utils.py index 27da6a27f..49fe3be3a 100644 --- a/astrbot/core/pipeline/context_utils.py +++ b/astrbot/core/pipeline/context_utils.py @@ -12,8 +12,8 @@ async def call_handler( event: AstrMessageEvent, handler: T.Callable[..., T.Awaitable[T.Any] | T.AsyncGenerator[T.Any, None]], - *args: T.Any, # noqa: ANN401 - **kwargs: T.Any, # noqa: ANN401 + *args: object, + **kwargs: object, ) -> T.AsyncGenerator[T.Any, None]: """执行事件处理函数并处理其返回结果 @@ -75,8 +75,8 @@ async def call_handler( async def call_event_hook( event: AstrMessageEvent, hook_type: EventType, - *args: T.Any, # noqa: ANN401 - **kwargs: T.Any, # noqa: ANN401 + *args: object, + **kwargs: object, ) -> bool: """调用事件钩子函数 diff --git a/astrbot/core/pipeline/process_stage/utils.py b/astrbot/core/pipeline/process_stage/utils.py index 09965d7f3..d6ad8f2c1 100644 --- a/astrbot/core/pipeline/process_stage/utils.py +++ b/astrbot/core/pipeline/process_stage/utils.py @@ -36,10 +36,10 @@ async def call( query = kwargs.get("query", "") if not query: return "error: Query parameter is empty." - + # 显式转换为 str,解决类型检查报错 "object cannot be assigned to str" query_str = str(query) - + result = await retrieve_knowledge_base( query=query_str, umo=context.context.event.unified_msg_origin, From eda5909d753818313f893539f0990a0d5ddcf505 Mon Sep 17 00:00:00 2001 From: Dale Null Date: Wed, 17 Dec 2025 09:55:34 +0800 Subject: [PATCH 36/44] =?UTF-8?q?fix:=20=E8=AE=A9=20ContentPart.=5F=5Finit?= =?UTF-8?q?=5Fsubclass=5F=5F=20=E4=BD=BF=E7=94=A8=20Unpack[ConfigDict]=20?= =?UTF-8?q?=E4=B8=8E=E7=B1=BB=E5=9E=8B=E6=A3=80=E6=9F=A5=E5=AF=B9=E9=BD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: aider (openai/gpt-5) --- astrbot/core/agent/message.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/astrbot/core/agent/message.py b/astrbot/core/agent/message.py index aba9ef264..8565ff58d 100644 --- a/astrbot/core/agent/message.py +++ b/astrbot/core/agent/message.py @@ -6,6 +6,8 @@ from pydantic import BaseModel, GetCoreSchemaHandler, model_validator from pydantic_core import core_schema +from typing_extensions import Unpack +from pydantic.config import ConfigDict class ContentPart(BaseModel): @@ -15,7 +17,7 @@ class ContentPart(BaseModel): type: str - def __init_subclass__(cls, **kwargs: object) -> None: + def __init_subclass__(cls, **kwargs: Unpack[ConfigDict]) -> None: super().__init_subclass__(**kwargs) invalid_subclass_error_msg = f"ContentPart subclass {cls.__name__} must have a `type` field of type `str`" From 1539a3d3c38d18f08814e7205c436d23e716fbb0 Mon Sep 17 00:00:00 2001 From: Dale Null Date: Wed, 17 Dec 2025 10:06:59 +0800 Subject: [PATCH 37/44] =?UTF-8?q?refactor:=20=E7=94=A8=20TypedDict+Unpack?= =?UTF-8?q?=20=E9=87=8D=E6=9E=84=20HandoffTool=20=5F=5Finit=5F=5F=20?= =?UTF-8?q?=E5=8F=82=E6=95=B0=E7=B1=BB=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: aider (openai/gpt-5) --- astrbot/core/agent/handoff.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/astrbot/core/agent/handoff.py b/astrbot/core/agent/handoff.py index cf3cced38..1ca10c818 100644 --- a/astrbot/core/agent/handoff.py +++ b/astrbot/core/agent/handoff.py @@ -1,8 +1,19 @@ +from __future__ import annotations + +from collections.abc import AsyncGenerator, Awaitable, Callable from typing import Any, Generic +from typing_extensions import TypedDict, Unpack +from astrbot.core.message.message_event_result import MessageEventResult from .agent import Agent from .run_context import TContext -from .tool import FunctionTool +from .tool import FunctionTool, ParametersType + + +class HandoffInitKwargs(TypedDict, total=False): + handler: Callable[..., Awaitable[str | None] | AsyncGenerator[MessageEventResult, None]] | None + handler_module_path: str | None + active: bool class HandoffTool(FunctionTool, Generic[TContext]): @@ -11,8 +22,8 @@ class HandoffTool(FunctionTool, Generic[TContext]): def __init__( self, agent: Agent[TContext], - parameters: dict | None = None, - **kwargs: object, + parameters: ParametersType | None = None, + **kwargs: Unpack[HandoffInitKwargs], ) -> None: self.agent = agent super().__init__( From a230bbdfddf1d602996317b6ca362e68d7c96a1d Mon Sep 17 00:00:00 2001 From: Dale Null Date: Wed, 17 Dec 2025 10:12:47 +0800 Subject: [PATCH 38/44] =?UTF-8?q?chore:=20=E7=A7=BB=E9=99=A4=E6=9C=AA?= =?UTF-8?q?=E4=BD=BF=E7=94=A8=E7=9A=84=20Any=20=E5=AF=BC=E5=85=A5=E5=B9=B6?= =?UTF-8?q?=E6=95=B4=E7=90=86=E5=AF=BC=E5=85=A5=E4=BE=9D=E8=B5=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/agent/handoff.py | 8 ++++++-- astrbot/core/agent/mcp_client.py | 4 ++-- astrbot/core/agent/message.py | 2 +- astrbot/core/conversation_mgr.py | 1 - astrbot/core/knowledge_base/chunking/base.py | 1 - astrbot/core/knowledge_base/chunking/fixed_size.py | 2 -- astrbot/core/knowledge_base/chunking/recursive.py | 1 - 7 files changed, 9 insertions(+), 10 deletions(-) diff --git a/astrbot/core/agent/handoff.py b/astrbot/core/agent/handoff.py index 1ca10c818..755cc45a6 100644 --- a/astrbot/core/agent/handoff.py +++ b/astrbot/core/agent/handoff.py @@ -1,17 +1,21 @@ from __future__ import annotations from collections.abc import AsyncGenerator, Awaitable, Callable -from typing import Any, Generic +from typing import Generic + from typing_extensions import TypedDict, Unpack from astrbot.core.message.message_event_result import MessageEventResult + from .agent import Agent from .run_context import TContext from .tool import FunctionTool, ParametersType class HandoffInitKwargs(TypedDict, total=False): - handler: Callable[..., Awaitable[str | None] | AsyncGenerator[MessageEventResult, None]] | None + handler: ( + Callable[..., Awaitable[str | None] | AsyncGenerator[MessageEventResult]] | None + ) handler_module_path: str | None active: bool diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index 1f6c3cdb4..13ac2d7de 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -2,8 +2,9 @@ import logging from contextlib import AsyncExitStack from datetime import timedelta -from typing import Any, Generic +from typing import Generic +from mcp.types import CallToolResult from tenacity import ( before_sleep_log, retry, @@ -11,7 +12,6 @@ stop_after_attempt, wait_exponential, ) -from mcp.types import CallToolResult from astrbot import logger from astrbot.core.agent.run_context import ContextWrapper diff --git a/astrbot/core/agent/message.py b/astrbot/core/agent/message.py index 8565ff58d..8eff62c05 100644 --- a/astrbot/core/agent/message.py +++ b/astrbot/core/agent/message.py @@ -5,9 +5,9 @@ from typing import Any, ClassVar, Literal, cast from pydantic import BaseModel, GetCoreSchemaHandler, model_validator +from pydantic.config import ConfigDict from pydantic_core import core_schema from typing_extensions import Unpack -from pydantic.config import ConfigDict class ContentPart(BaseModel): diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index 3548f0eb4..afdbd24f7 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -6,7 +6,6 @@ import json from collections.abc import Awaitable, Callable -from typing import Any from astrbot.core import sp from astrbot.core.agent.message import AssistantMessageSegment, UserMessageSegment diff --git a/astrbot/core/knowledge_base/chunking/base.py b/astrbot/core/knowledge_base/chunking/base.py index a7915d737..11ae0caba 100644 --- a/astrbot/core/knowledge_base/chunking/base.py +++ b/astrbot/core/knowledge_base/chunking/base.py @@ -4,7 +4,6 @@ """ from abc import ABC, abstractmethod -from typing import Any class BaseChunker(ABC): diff --git a/astrbot/core/knowledge_base/chunking/fixed_size.py b/astrbot/core/knowledge_base/chunking/fixed_size.py index fb46603ee..90e1e3dba 100644 --- a/astrbot/core/knowledge_base/chunking/fixed_size.py +++ b/astrbot/core/knowledge_base/chunking/fixed_size.py @@ -3,8 +3,6 @@ 按照固定的字符数将文本分块,支持重叠区域。 """ -from typing import Any - from .base import BaseChunker diff --git a/astrbot/core/knowledge_base/chunking/recursive.py b/astrbot/core/knowledge_base/chunking/recursive.py index 611fcbefa..8e0dd7e9d 100644 --- a/astrbot/core/knowledge_base/chunking/recursive.py +++ b/astrbot/core/knowledge_base/chunking/recursive.py @@ -1,5 +1,4 @@ from collections.abc import Callable -from typing import Any from .base import BaseChunker From 80a7e3782926003ce5d7bfd8d4b1b85aff1c69ee Mon Sep 17 00:00:00 2001 From: Dale Null Date: Wed, 17 Dec 2025 11:37:42 +0800 Subject: [PATCH 39/44] =?UTF-8?q?refactor:=20=E4=BD=BF=E7=94=A8=20TypedDic?= =?UTF-8?q?t+Unpack=20=E9=87=8D=E6=9E=84=20get=5Ffiltered=5Fconversations?= =?UTF-8?q?=20=E7=9A=84=20kwargs=20=E7=B1=BB=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: aider (openai/gpt-5) --- astrbot/core/db/sqlite.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index b7a5754d9..d960e257d 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -31,6 +31,12 @@ NOT_GIVEN = T.TypeVar("NOT_GIVEN") TxResult = T.TypeVar("TxResult") +class FilterKwargs(T.TypedDict, total=False): + message_types: list[str] + platforms: list[str] + exclude_ids: list[str] + exclude_platforms: list[str] + class SQLiteDatabase(BaseDatabase): def __init__(self, db_path: str) -> None: @@ -165,7 +171,7 @@ async def get_filtered_conversations( page_size: int = 20, platform_ids: list[str] | None = None, search_query: str = "", - **kwargs: object, + **kwargs: T.Unpack[FilterKwargs], ) -> tuple[list[ConversationV2], int]: async with self.get_db() as session: session: AsyncSession From 14de5842bbc2713de293c0812a314bf2b8af408c Mon Sep 17 00:00:00 2001 From: Dale Null Date: Wed, 17 Dec 2025 11:40:23 +0800 Subject: [PATCH 40/44] =?UTF-8?q?style:=20=E5=9C=A8=20NOT=5FGIVEN=20?= =?UTF-8?q?=E4=B8=8E=20TxResult=20=E4=B9=8B=E9=97=B4=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E7=A9=BA=E8=A1=8C=E4=BB=A5=E6=94=B9=E5=96=84=E5=8F=AF=E8=AF=BB?= =?UTF-8?q?=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/db/sqlite.py | 1 + 1 file changed, 1 insertion(+) diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index d960e257d..c06dcd7c3 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -31,6 +31,7 @@ NOT_GIVEN = T.TypeVar("NOT_GIVEN") TxResult = T.TypeVar("TxResult") + class FilterKwargs(T.TypedDict, total=False): message_types: list[str] platforms: list[str] From 8fc7f3c07fb9e6b84192f6140b44b11b0a1630eb Mon Sep 17 00:00:00 2001 From: Dale Null Date: Wed, 17 Dec 2025 11:45:18 +0800 Subject: [PATCH 41/44] =?UTF-8?q?fix:=20=E5=85=BC=E5=AE=B9=20Unpack=20?= =?UTF-8?q?=E7=9A=84=20typing=20=E5=AF=BC=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: aider (openai/gpt-5) --- astrbot/core/db/sqlite.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index c06dcd7c3..79b559ee7 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -1,6 +1,11 @@ import asyncio import threading import typing as T + +try: + from typing import Unpack # type: ignore[attr-defined] +except ImportError: + from typing_extensions import Unpack from collections.abc import Awaitable, Callable from datetime import datetime, timedelta, timezone @@ -172,7 +177,7 @@ async def get_filtered_conversations( page_size: int = 20, platform_ids: list[str] | None = None, search_query: str = "", - **kwargs: T.Unpack[FilterKwargs], + **kwargs: Unpack[FilterKwargs], ) -> tuple[list[ConversationV2], int]: async with self.get_db() as session: session: AsyncSession From 280a3ed03ba709eab287c29497bdc5c998738dac Mon Sep 17 00:00:00 2001 From: Dale Null Date: Wed, 17 Dec 2025 14:31:01 +0800 Subject: [PATCH 42/44] =?UTF-8?q?refactor:=20=E5=B0=86=20chunk=20=E5=87=BD?= =?UTF-8?q?=E6=95=B0=E6=94=B9=E4=B8=BA=E5=85=B3=E9=94=AE=E5=AD=97=E5=8F=82?= =?UTF-8?q?=E6=95=B0=E5=B9=B6=E5=8A=A0=E5=85=A5=E8=BE=B9=E7=95=8C=E5=A4=84?= =?UTF-8?q?=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: aider (openai/gpt-5) --- astrbot/core/knowledge_base/chunking/recursive.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/astrbot/core/knowledge_base/chunking/recursive.py b/astrbot/core/knowledge_base/chunking/recursive.py index 8e0dd7e9d..040f2ffd0 100644 --- a/astrbot/core/knowledge_base/chunking/recursive.py +++ b/astrbot/core/knowledge_base/chunking/recursive.py @@ -39,7 +39,13 @@ def __init__( "", # 字符 ] - async def chunk(self, text: str, **kwargs: object) -> list[str]: + async def chunk( + self, + text: str, + *, + chunk_size: int | None = None, + chunk_overlap: int | None = None, + ) -> list[str]: """递归地将文本分割成块 Args: @@ -54,8 +60,11 @@ async def chunk(self, text: str, **kwargs: object) -> list[str]: if not text: return [] - overlap = kwargs.get("chunk_overlap", self.chunk_overlap) - chunk_size = kwargs.get("chunk_size", self.chunk_size) + overlap = self.chunk_overlap if chunk_overlap is None else chunk_overlap + chunk_size = self.chunk_size if chunk_size is None else chunk_size + if chunk_size <= 0: + return [text] + overlap = max(0, min(overlap, chunk_size - 1)) text_length = self.length_function(text) if text_length <= chunk_size: From 7f891482fc4501e257c740884401b6cf15ca3902 Mon Sep 17 00:00:00 2001 From: Dale Null Date: Wed, 17 Dec 2025 14:32:42 +0800 Subject: [PATCH 43/44] =?UTF-8?q?fix:=20=E8=B0=83=E6=95=B4=20FixedSizeChun?= =?UTF-8?q?ker.chunk=20=E7=AD=BE=E5=90=8D=E5=B9=B6=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E5=85=B3=E9=94=AE=E5=AD=97=E5=8F=82=E6=95=B0=E4=B8=8E=E8=BE=B9?= =?UTF-8?q?=E7=95=8C=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: aider (openai/gpt-5) --- .../core/knowledge_base/chunking/fixed_size.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/astrbot/core/knowledge_base/chunking/fixed_size.py b/astrbot/core/knowledge_base/chunking/fixed_size.py index 90e1e3dba..cd146f137 100644 --- a/astrbot/core/knowledge_base/chunking/fixed_size.py +++ b/astrbot/core/knowledge_base/chunking/fixed_size.py @@ -23,7 +23,13 @@ def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50) -> None: self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap - async def chunk(self, text: str, **kwargs: object) -> list[str]: + async def chunk( + self, + text: str, + *, + chunk_size: int | None = None, + chunk_overlap: int | None = None, + ) -> list[str]: """固定大小分块 Args: @@ -35,8 +41,11 @@ async def chunk(self, text: str, **kwargs: object) -> list[str]: list[str]: 分块后的文本列表 """ - chunk_size = kwargs.get("chunk_size", self.chunk_size) - chunk_overlap = kwargs.get("chunk_overlap", self.chunk_overlap) + chunk_size = self.chunk_size if chunk_size is None else chunk_size + chunk_overlap = self.chunk_overlap if chunk_overlap is None else chunk_overlap + if chunk_size <= 0: + return [text] + chunk_overlap = max(0, min(chunk_overlap, chunk_size - 1)) chunks = [] start = 0 From 681b6d18f6a4c217eb53289917392911b4056718 Mon Sep 17 00:00:00 2001 From: Dale Null Date: Wed, 17 Dec 2025 14:37:32 +0800 Subject: [PATCH 44/44] =?UTF-8?q?fix:=20=E5=B0=86=20any=20=E6=9B=BF?= =?UTF-8?q?=E6=8D=A2=E4=B8=BA=20Any=EF=BC=8C=E5=B9=B6=E5=AF=BC=E5=85=A5=20?= =?UTF-8?q?typing=20=E7=9A=84=20Any?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/dashboard/routes/knowledge_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/astrbot/dashboard/routes/knowledge_base.py b/astrbot/dashboard/routes/knowledge_base.py index fd9eff1a9..72ffd5fd3 100644 --- a/astrbot/dashboard/routes/knowledge_base.py +++ b/astrbot/dashboard/routes/knowledge_base.py @@ -4,6 +4,7 @@ import os import traceback import uuid +from typing import Any import aiofiles from quart import request @@ -75,7 +76,7 @@ def _init_task(self, task_id: str, status: str = "pending") -> None: } def _set_task_result( - self, task_id: str, status: str, result: any = None, error: str | None = None + self, task_id: str, status: str, result: Any = None, error: str | None = None ) -> None: self.upload_tasks[task_id] = { "status": status,