diff --git a/astrbot/cli/commands/cmd_conf.py b/astrbot/cli/commands/cmd_conf.py index a9bd40f00..64adf6fc4 100644 --- a/astrbot/cli/commands/cmd_conf.py +++ b/astrbot/cli/commands/cmd_conf.py @@ -104,7 +104,7 @@ def _save_config(config: dict[str, Any]) -> None: ) -def _set_nested_item(obj: dict[str, Any], path: str, value: Any) -> None: +def _set_nested_item(obj: dict[str, Any], path: str, value: object) -> None: """设置嵌套字典中的值""" parts = path.split(".") for part in parts[:-1]: @@ -118,7 +118,7 @@ def _set_nested_item(obj: dict[str, Any], path: str, value: Any) -> None: obj[parts[-1]] = value -def _get_nested_item(obj: dict[str, Any], path: str) -> Any: +def _get_nested_item(obj: dict[str, Any], path: str) -> object: """获取嵌套字典中的值""" parts = path.split(".") for part in parts: @@ -127,7 +127,7 @@ def _get_nested_item(obj: dict[str, Any], path: str) -> Any: @click.group(name="conf") -def conf(): +def conf() -> None: """配置管理命令 支持的配置项: @@ -149,7 +149,7 @@ def conf(): @conf.command(name="set") @click.argument("key") @click.argument("value") -def set_config(key: str, value: str): +def set_config(key: str, value: str) -> None: """设置配置项的值""" if key not in CONFIG_VALIDATORS: raise click.ClickException(f"不支持的配置项: {key}") @@ -178,7 +178,7 @@ def set_config(key: str, value: str): @conf.command(name="get") @click.argument("key", required=False) -def get_config(key: str | None = None): +def get_config(key: str | None = None) -> None: """获取配置项的值,不提供key则显示所有可配置项""" config = _load_config() diff --git a/astrbot/cli/commands/cmd_plug.py b/astrbot/cli/commands/cmd_plug.py index a1099de1d..aeaedc70b 100644 --- a/astrbot/cli/commands/cmd_plug.py +++ b/astrbot/cli/commands/cmd_plug.py @@ -5,6 +5,7 @@ import click from ..utils import ( + PluginInfo, PluginStatus, build_plug_list, check_astrbot_root, @@ -15,7 +16,7 @@ @click.group() -def plug(): +def plug() -> None: """插件管理""" @@ -28,7 +29,11 @@ def _get_data_path() -> Path: return (base / "data").resolve() -def display_plugins(plugins, title=None, color=None): +def display_plugins( + plugins: list[PluginInfo], + title: str | None = None, + color: int | tuple[int, int, int] | str | None = None, +) -> None: if title: click.echo(click.style(title, fg=color, bold=True)) @@ -45,7 +50,7 @@ def display_plugins(plugins, title=None, color=None): @plug.command() @click.argument("name") -def new(name: str): +def new(name: str) -> None: """创建新插件""" base_path = _get_data_path() plug_path = base_path / "plugins" / name @@ -100,7 +105,7 @@ def new(name: str): @plug.command() @click.option("--all", "-a", is_flag=True, help="列出未安装的插件") -def list(all: bool): +def list(all: bool) -> None: """列出插件""" base_path = _get_data_path() plugins = build_plug_list(base_path / "plugins") @@ -141,7 +146,7 @@ def list(all: bool): @plug.command() @click.argument("name") @click.option("--proxy", help="代理服务器地址") -def install(name: str, proxy: str | None): +def install(name: str, proxy: str | None) -> None: """安装插件""" base_path = _get_data_path() plug_path = base_path / "plugins" @@ -164,13 +169,13 @@ def install(name: str, proxy: str | None): @plug.command() @click.argument("name") -def remove(name: str): +def remove(name: str) -> None: """卸载插件""" base_path = _get_data_path() plugins = build_plug_list(base_path / "plugins") plugin = next((p for p in plugins if p["name"] == name), None) - if not plugin or not plugin.get("local_path"): + if not plugin or not plugin["local_path"]: raise click.ClickException(f"插件 {name} 不存在或未安装") plugin_path = plugin["local_path"] @@ -187,7 +192,7 @@ def remove(name: str): @plug.command() @click.argument("name", required=False) @click.option("--proxy", help="Github代理地址") -def update(name: str, proxy: str | None): +def update(name: str, proxy: str | None) -> None: """更新插件""" base_path = _get_data_path() plug_path = base_path / "plugins" @@ -225,7 +230,7 @@ def update(name: str, proxy: str | None): @plug.command() @click.argument("query") -def search(query: str): +def search(query: str) -> None: """搜索插件""" base_path = _get_data_path() plugins = build_plug_list(base_path / "plugins") diff --git a/astrbot/cli/commands/cmd_run.py b/astrbot/cli/commands/cmd_run.py index 9333f1b87..23665dff3 100644 --- a/astrbot/cli/commands/cmd_run.py +++ b/astrbot/cli/commands/cmd_run.py @@ -10,7 +10,7 @@ from ..utils import check_astrbot_root, check_dashboard, get_astrbot_root -async def run_astrbot(astrbot_root: Path): +async def run_astrbot(astrbot_root: Path) -> None: """运行 AstrBot""" from astrbot.core import LogBroker, LogManager, db_helper, logger from astrbot.core.initial_loader import InitialLoader diff --git a/astrbot/cli/utils/__init__.py b/astrbot/cli/utils/__init__.py index 3830682f0..6d6b02bf2 100644 --- a/astrbot/cli/utils/__init__.py +++ b/astrbot/cli/utils/__init__.py @@ -3,10 +3,17 @@ check_dashboard, get_astrbot_root, ) -from .plugin import PluginStatus, build_plug_list, get_git_repo, manage_plugin +from .plugin import ( + PluginInfo, + PluginStatus, + build_plug_list, + get_git_repo, + manage_plugin, +) from .version_comparator import VersionComparator __all__ = [ + "PluginInfo", "PluginStatus", "VersionComparator", "build_plug_list", diff --git a/astrbot/cli/utils/plugin.py b/astrbot/cli/utils/plugin.py index cd76a07c8..44542c33d 100644 --- a/astrbot/cli/utils/plugin.py +++ b/astrbot/cli/utils/plugin.py @@ -3,6 +3,7 @@ from enum import Enum from io import BytesIO from pathlib import Path +from typing import TypedDict from zipfile import ZipFile import click @@ -19,7 +20,17 @@ class PluginStatus(str, Enum): NOT_PUBLISHED = "未发布" -def get_git_repo(url: str, target_path: Path, proxy: str | None = None): +class PluginInfo(TypedDict): + name: str + desc: str + version: str + author: str + repo: str + status: PluginStatus + local_path: str | None + + +def get_git_repo(url: str, target_path: Path, proxy: str | None = None) -> None: """从 Git 仓库下载代码并解压到指定路径""" temp_dir = Path(tempfile.mkdtemp()) try: @@ -102,18 +113,18 @@ def load_yaml_metadata(plugin_dir: Path) -> dict: return {} -def build_plug_list(plugins_dir: Path) -> list: +def build_plug_list(plugins_dir: Path) -> list[PluginInfo]: """构建插件列表,包含本地和在线插件信息 Args: plugins_dir (Path): 插件目录路径 Returns: - list: 包含插件信息的字典列表 + list[PluginInfo]: 包含插件信息的字典列表 """ # 获取本地插件信息 - result = [] + result: list[PluginInfo] = [] if plugins_dir.exists(): for plugin_name in [d.name for d in plugins_dir.glob("*") if d.is_dir()]: plugin_dir = plugins_dir / plugin_name @@ -141,7 +152,7 @@ def build_plug_list(plugins_dir: Path) -> list: ) # 获取在线插件列表 - online_plugins = [] + online_plugins: list[PluginInfo] = [] try: with httpx.Client() as client: resp = client.get("https://api.soulter.top/astrbot/plugins") @@ -191,7 +202,7 @@ def build_plug_list(plugins_dir: Path) -> list: def manage_plugin( - plugin: dict, + plugin: PluginInfo, plugins_dir: Path, is_update: bool = False, proxy: str | None = None, @@ -209,7 +220,7 @@ def manage_plugin( repo_url = plugin["repo"] # 如果是更新且有本地路径,直接使用本地路径 - if is_update and plugin.get("local_path"): + if is_update and plugin["local_path"]: target_path = Path(plugin["local_path"]) else: target_path = plugins_dir / plugin_name diff --git a/astrbot/cli/utils/version_comparator.py b/astrbot/cli/utils/version_comparator.py index 0aaf8dcab..fbf15a612 100644 --- a/astrbot/cli/utils/version_comparator.py +++ b/astrbot/cli/utils/version_comparator.py @@ -15,7 +15,7 @@ def compare_version(v1: str, v2: str) -> int: v1 = v1.lower().replace("v", "") v2 = v2.lower().replace("v", "") - def split_version(version): + def split_version(version: str) -> tuple[list[int], list[int | str] | None]: match = re.match( r"^([0-9]+(?:\.[0-9]+)*)(?:-([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+(.+))?$", version, @@ -77,7 +77,7 @@ def split_version(version): return 0 # 数字部分和预发布标签都相同 @staticmethod - def _split_prerelease(prerelease): + def _split_prerelease(prerelease: str) -> list[int | str] | None: if not prerelease: return None parts = prerelease.split(".") diff --git a/astrbot/core/agent/handoff.py b/astrbot/core/agent/handoff.py index 85276540b..755cc45a6 100644 --- a/astrbot/core/agent/handoff.py +++ b/astrbot/core/agent/handoff.py @@ -1,8 +1,23 @@ +from __future__ import annotations + +from collections.abc import AsyncGenerator, Awaitable, Callable from typing import Generic +from typing_extensions import TypedDict, Unpack + +from astrbot.core.message.message_event_result import MessageEventResult + from .agent import Agent from .run_context import TContext -from .tool import FunctionTool +from .tool import FunctionTool, ParametersType + + +class HandoffInitKwargs(TypedDict, total=False): + handler: ( + Callable[..., Awaitable[str | None] | AsyncGenerator[MessageEventResult]] | None + ) + handler_module_path: str | None + active: bool class HandoffTool(FunctionTool, Generic[TContext]): @@ -11,9 +26,9 @@ class HandoffTool(FunctionTool, Generic[TContext]): def __init__( self, agent: Agent[TContext], - parameters: dict | None = None, - **kwargs, - ): + parameters: ParametersType | None = None, + **kwargs: Unpack[HandoffInitKwargs], + ) -> None: self.agent = agent super().__init__( name=f"transfer_to_{agent.name}", diff --git a/astrbot/core/agent/hooks.py b/astrbot/core/agent/hooks.py index d834240b7..74ca6335b 100644 --- a/astrbot/core/agent/hooks.py +++ b/astrbot/core/agent/hooks.py @@ -9,22 +9,22 @@ class BaseAgentRunHooks(Generic[TContext]): - async def on_agent_begin(self, run_context: ContextWrapper[TContext]): ... + async def on_agent_begin(self, run_context: ContextWrapper[TContext]) -> None: ... async def on_tool_start( self, run_context: ContextWrapper[TContext], tool: FunctionTool, tool_args: dict | None, - ): ... + ) -> None: ... async def on_tool_end( self, run_context: ContextWrapper[TContext], tool: FunctionTool, tool_args: dict | None, tool_result: mcp.types.CallToolResult | None, - ): ... + ) -> None: ... async def on_agent_done( self, run_context: ContextWrapper[TContext], llm_response: LLMResponse, - ): ... + ) -> None: ... diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index c5ff123b2..13ac2d7de 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -4,6 +4,7 @@ from datetime import timedelta from typing import Generic +from mcp.types import CallToolResult from tenacity import ( before_sleep_log, retry, @@ -108,7 +109,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: class MCPClient: - def __init__(self): + def __init__(self) -> None: # Initialize session and client objects self.session: mcp.ClientSession | None = None self.exit_stack = AsyncExitStack() @@ -126,7 +127,7 @@ def __init__(self): self._reconnect_lock = asyncio.Lock() # Lock for thread-safe reconnection self._reconnecting: bool = False # For logging and debugging - async def connect_to_server(self, mcp_server_config: dict, name: str): + async def connect_to_server(self, mcp_server_config: dict, name: str) -> None: """Connect to MCP server If `url` parameter exists: @@ -144,7 +145,7 @@ async def connect_to_server(self, mcp_server_config: dict, name: str): cfg = _prepare_config(mcp_server_config.copy()) - def logging_callback(msg: str): + def logging_callback(msg: str) -> None: # Handle MCP service error logs print(f"MCP Server {name} Error: {msg}") self.server_errlogs.append(msg) @@ -214,7 +215,7 @@ def logging_callback(msg: str): **cfg, ) - def callback(msg: str): + def callback(msg: str) -> None: # Handle MCP service error logs self.server_errlogs.append(msg) @@ -322,7 +323,7 @@ async def call_tool_with_reconnect( before_sleep=before_sleep_log(logger, logging.WARNING), reraise=True, ) - async def _call_with_retry(): + async def _call_with_retry() -> CallToolResult: if not self.session: raise ValueError("MCP session is not available for MCP function tools.") @@ -343,7 +344,7 @@ async def _call_with_retry(): return await _call_with_retry() - async def cleanup(self): + async def cleanup(self) -> None: """Clean up resources including old exit stacks from reconnections""" # Close current exit stack try: @@ -364,8 +365,12 @@ class MCPTool(FunctionTool, Generic[TContext]): """A function tool that calls an MCP service.""" def __init__( - self, mcp_tool: mcp.Tool, mcp_client: MCPClient, mcp_server_name: str, **kwargs - ): + self, + mcp_tool: mcp.Tool, + mcp_client: MCPClient, + mcp_server_name: str, + **kwargs: object, + ) -> None: super().__init__( name=mcp_tool.name, description=mcp_tool.description or "", @@ -376,7 +381,9 @@ def __init__( self.mcp_server_name = mcp_server_name async def call( - self, context: ContextWrapper[TContext], **kwargs + self, + context: ContextWrapper[TContext], + **kwargs: object, ) -> mcp.types.CallToolResult: return await self.mcp_client.call_tool_with_reconnect( tool_name=self.mcp_tool.name, diff --git a/astrbot/core/agent/message.py b/astrbot/core/agent/message.py index 6e3a4012d..8eff62c05 100644 --- a/astrbot/core/agent/message.py +++ b/astrbot/core/agent/message.py @@ -1,10 +1,13 @@ # Inspired by MoonshotAI/kosong, credits to MoonshotAI/kosong authors for the original implementation. # License: Apache License 2.0 +import builtins from typing import Any, ClassVar, Literal, cast from pydantic import BaseModel, GetCoreSchemaHandler, model_validator +from pydantic.config import ConfigDict from pydantic_core import core_schema +from typing_extensions import Unpack class ContentPart(BaseModel): @@ -14,7 +17,7 @@ class ContentPart(BaseModel): type: str - def __init_subclass__(cls, **kwargs: Any) -> None: + def __init_subclass__(cls, **kwargs: Unpack[ConfigDict]) -> None: super().__init_subclass__(**kwargs) invalid_subclass_error_msg = f"ContentPart subclass {cls.__name__} must have a `type` field of type `str`" @@ -27,15 +30,15 @@ def __init_subclass__(cls, **kwargs: Any) -> None: @classmethod def __get_pydantic_core_schema__( - cls, source_type: Any, handler: GetCoreSchemaHandler + cls, source_type: builtins.type[BaseModel], handler: GetCoreSchemaHandler ) -> core_schema.CoreSchema: # If we're dealing with the base ContentPart class, use custom validation if cls.__name__ == "ContentPart": - def validate_content_part(value: Any) -> Any: + def validate_content_part(value: object) -> "ContentPart": # if it's already an instance of a ContentPart subclass, return it if hasattr(value, "__class__") and issubclass(value.__class__, cls): - return value + return cast("ContentPart", value) # if it's a dict with a type field, dispatch to the appropriate subclass if isinstance(value, dict) and "type" in value: @@ -155,7 +158,7 @@ class Message(BaseModel): """The ID of the tool call.""" @model_validator(mode="after") - def check_content_required(self): + def check_content_required(self) -> "Message": # assistant + tool_calls is not None: allow content to be None if self.role == "assistant" and self.tool_calls is not None: return self diff --git a/astrbot/core/agent/runners/base.py b/astrbot/core/agent/runners/base.py index 21e796433..a9ddc0e35 100644 --- a/astrbot/core/agent/runners/base.py +++ b/astrbot/core/agent/runners/base.py @@ -25,7 +25,7 @@ async def reset( self, run_context: ContextWrapper[TContext], agent_hooks: BaseAgentRunHooks[TContext], - **kwargs: T.Any, + **kwargs: object, ) -> None: """Reset the agent to its initial state. This method should be called before starting a new run. diff --git a/astrbot/core/agent/runners/coze/coze_agent_runner.py b/astrbot/core/agent/runners/coze/coze_agent_runner.py index a8300bb71..c2e8a948e 100644 --- a/astrbot/core/agent/runners/coze/coze_agent_runner.py +++ b/astrbot/core/agent/runners/coze/coze_agent_runner.py @@ -70,7 +70,7 @@ async def reset( self.file_id_cache: dict[str, dict[str, str]] = {} @override - async def step(self): + async def step(self) -> T.AsyncGenerator[AgentResponse, None]: """ 执行 Coze Agent 的一个步骤 """ @@ -113,7 +113,7 @@ async def step_until_done( async for resp in self.step(): yield resp - async def _execute_coze_request(self): + async def _execute_coze_request(self) -> T.AsyncGenerator[AgentResponse, None]: """执行 Coze 请求的核心逻辑""" prompt = self.req.prompt or "" session_id = self.req.session_id or "unknown" diff --git a/astrbot/core/agent/runners/coze/coze_api_client.py b/astrbot/core/agent/runners/coze/coze_api_client.py index e8f3a1e24..a5e62520b 100644 --- a/astrbot/core/agent/runners/coze/coze_api_client.py +++ b/astrbot/core/agent/runners/coze/coze_api_client.py @@ -10,12 +10,12 @@ class CozeAPIClient: - def __init__(self, api_key: str, api_base: str = "https://api.coze.cn"): + def __init__(self, api_key: str, api_base: str = "https://api.coze.cn") -> None: self.api_key = api_key self.api_base = api_base self.session = None - async def _ensure_session(self): + async def _ensure_session(self) -> aiohttp.ClientSession: """确保HTTP session存在""" if self.session is None: connector = aiohttp.TCPConnector( @@ -208,7 +208,7 @@ async def chat_messages( except Exception as e: raise Exception(f"Coze API 流式请求失败: {e!s}") - async def clear_context(self, conversation_id: str): + async def clear_context(self, conversation_id: str) -> dict: """清空会话上下文 Args: @@ -247,7 +247,7 @@ async def get_message_list( order: str = "desc", limit: int = 10, offset: int = 0, - ): + ) -> dict: """获取消息列表 Args: @@ -277,7 +277,7 @@ async def get_message_list( logger.error(f"获取Coze消息列表失败: {e!s}") raise Exception(f"获取Coze消息列表失败: {e!s}") - async def close(self): + async def close(self) -> None: """关闭会话""" if self.session: await self.session.close() @@ -288,7 +288,7 @@ async def close(self): import asyncio import os - async def test_coze_api_client(): + async def test_coze_api_client() -> None: api_key = os.getenv("COZE_API_KEY", "") bot_id = os.getenv("COZE_BOT_ID", "") client = CozeAPIClient(api_key=api_key) diff --git a/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py b/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py index 7a095a60b..f14fae0ca 100644 --- a/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py +++ b/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py @@ -67,7 +67,7 @@ async def reset( if isinstance(self.timeout, str): self.timeout = int(self.timeout) - def has_rag_options(self): + def has_rag_options(self) -> bool: """判断是否有 RAG 选项 Returns: @@ -82,7 +82,7 @@ def has_rag_options(self): return False @override - async def step(self): + async def step(self) -> T.AsyncGenerator[AgentResponse, None]: """ 执行 Dashscope Agent 的一个步骤 """ @@ -124,7 +124,7 @@ async def step_until_done( yield resp def _consume_sync_generator( - self, response: T.Any, response_queue: queue.Queue + self, response: T.Iterable[object], response_queue: queue.Queue ) -> None: """在线程中消费同步generator,将结果放入队列 @@ -278,7 +278,7 @@ async def _build_request_payload( return payload async def _handle_streaming_response( - self, response: T.Any, session_id: str + self, response: object, session_id: str ) -> T.AsyncGenerator[AgentResponse, None]: """处理流式响应 @@ -292,7 +292,7 @@ async def _handle_streaming_response( response_queue = queue.Queue() consumer_thread = threading.Thread( target=self._consume_sync_generator, - args=(response, response_queue), + args=(T.cast(T.Iterable[object], response), response_queue), daemon=True, ) consumer_thread.start() @@ -319,14 +319,14 @@ async def _handle_streaming_response( ( output_text, chunk_doc_refs, - response, + response_obj, ) = await self._process_stream_chunk(chunk, output_text) - if response: - if response.type == "err": - yield response + if response_obj: + if response_obj.type == "err": + yield response_obj return - yield response + yield response_obj if chunk_doc_refs: doc_references = chunk_doc_refs @@ -366,7 +366,7 @@ async def _handle_streaming_response( data=AgentResponseData(chain=chain), ) - async def _execute_dashscope_request(self): + async def _execute_dashscope_request(self) -> T.AsyncGenerator[AgentResponse, None]: """执行 Dashscope 请求的核心逻辑""" prompt = self.req.prompt or "" session_id = self.req.session_id or "unknown" diff --git a/astrbot/core/agent/runners/dify/dify_agent_runner.py b/astrbot/core/agent/runners/dify/dify_agent_runner.py index d9a8b7cd6..64f028dd9 100644 --- a/astrbot/core/agent/runners/dify/dify_agent_runner.py +++ b/astrbot/core/agent/runners/dify/dify_agent_runner.py @@ -63,7 +63,7 @@ async def reset( self.api_client = DifyAPIClient(self.api_key, self.api_base) @override - async def step(self): + async def step(self) -> T.AsyncGenerator[AgentResponse, None]: """ 执行 Dify Agent 的一个步骤 """ @@ -106,7 +106,7 @@ async def step_until_done( async for resp in self.step(): yield resp - async def _execute_dify_request(self): + async def _execute_dify_request(self) -> T.AsyncGenerator[AgentResponse, None]: """执行 Dify 请求的核心逻辑""" prompt = self.req.prompt or "" session_id = self.req.session_id or "unknown" @@ -285,7 +285,7 @@ async def parse_dify_result(self, chunk: dict | str) -> MessageChain: # Chat return MessageChain(chain=[Comp.Plain(chunk)]) - async def parse_file(item: dict): + async def parse_file(item: dict) -> object: match item["type"]: case "image": return Comp.Image(file=item["url"], url=item["url"]) diff --git a/astrbot/core/agent/runners/dify/dify_api_client.py b/astrbot/core/agent/runners/dify/dify_api_client.py index d9c6556cf..9e683d257 100644 --- a/astrbot/core/agent/runners/dify/dify_api_client.py +++ b/astrbot/core/agent/runners/dify/dify_api_client.py @@ -31,7 +31,7 @@ async def _stream_sse(resp: ClientResponse) -> AsyncGenerator[dict, None]: class DifyAPIClient: - def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"): + def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1") -> None: self.api_key = api_key self.api_base = api_base self.session = ClientSession(trust_env=True) @@ -77,7 +77,7 @@ async def workflow_run( response_mode: str = "streaming", files: list[dict[str, Any]] | None = None, timeout: float = 60, - ): + ) -> AsyncGenerator[dict[str, Any], None]: if files is None: files = [] url = f"{self.api_base}/workflows/run" @@ -155,10 +155,10 @@ async def file_upload( raise Exception(f"Dify 文件上传失败:{resp.status}. {text}") return await resp.json() # {"id": "xxx", ...} - async def close(self): + async def close(self) -> None: await self.session.close() - async def get_chat_convs(self, user: str, limit: int = 20): + async def get_chat_convs(self, user: str, limit: int = 20) -> dict: # conversations. GET url = f"{self.api_base}/conversations" payload = { @@ -168,7 +168,7 @@ async def get_chat_convs(self, user: str, limit: int = 20): async with self.session.get(url, params=payload, headers=self.headers) as resp: return await resp.json() - async def delete_chat_conv(self, user: str, conversation_id: str): + async def delete_chat_conv(self, user: str, conversation_id: str) -> dict: # conversation. DELETE url = f"{self.api_base}/conversations/{conversation_id}" payload = { @@ -183,7 +183,7 @@ async def rename( name: str, user: str, auto_generate: bool = False, - ): + ) -> dict: # /conversations/:conversation_id/name url = f"{self.api_base}/conversations/{conversation_id}/name" payload = { diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 450e4dbcb..920cd3101 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -79,7 +79,7 @@ async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]: yield await self.provider.text_chat(**self.req.__dict__) @override - async def step(self): + async def step(self) -> T.AsyncGenerator[AgentResponse, None]: """Process a single step of the agent. This method should return the result of the step. """ diff --git a/astrbot/core/agent/tool.py b/astrbot/core/agent/tool.py index 7f30f44ef..03679a245 100644 --- a/astrbot/core/agent/tool.py +++ b/astrbot/core/agent/tool.py @@ -1,4 +1,4 @@ -from collections.abc import AsyncGenerator, Awaitable, Callable +from collections.abc import AsyncGenerator, Awaitable, Callable, Iterator from typing import Any, Generic import jsonschema @@ -58,10 +58,12 @@ class FunctionTool(ToolSchema, Generic[TContext]): You can ignore it when integrating with other frameworks. """ - def __repr__(self): + def __repr__(self) -> str: return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})" - async def call(self, context: ContextWrapper[TContext], **kwargs) -> ToolExecResult: + async def call( + self, context: ContextWrapper[TContext], **kwargs: object + ) -> ToolExecResult: """Run the tool with the given arguments. The handler field has priority.""" raise NotImplementedError( "FunctionTool.call() must be implemented by subclasses or set a handler." @@ -82,7 +84,7 @@ def empty(self) -> bool: """Check if the tool set is empty.""" return len(self.tools) == 0 - def add_tool(self, tool: FunctionTool): + def add_tool(self, tool: FunctionTool) -> None: """Add a tool to the set.""" # 检查是否已存在同名工具 for i, existing_tool in enumerate(self.tools): @@ -91,7 +93,7 @@ def add_tool(self, tool: FunctionTool): return self.tools.append(tool) - def remove_tool(self, name: str): + def remove_tool(self, name: str) -> None: """Remove a tool by its name.""" self.tools = [tool for tool in self.tools if tool.name != name] @@ -109,7 +111,7 @@ def add_func( func_args: list, desc: str, handler: Callable[..., Awaitable[Any]], - ): + ) -> None: """Add a function tool to the set.""" params = { "type": "object", # hard-coded here @@ -129,7 +131,7 @@ def add_func( self.add_tool(_func) @deprecated(reason="Use remove_tool() instead", version="4.0.0") - def remove_func(self, name: str): + def remove_func(self, name: str) -> None: """Remove a function tool by its name.""" self.remove_tool(name) @@ -259,32 +261,34 @@ def convert_schema(schema: dict) -> dict: return declarations @deprecated(reason="Use openai_schema() instead", version="4.0.0") - def get_func_desc_openai_style(self, omit_empty_parameter_field: bool = False): + def get_func_desc_openai_style( + self, omit_empty_parameter_field: bool = False + ) -> list[dict]: return self.openai_schema(omit_empty_parameter_field) @deprecated(reason="Use anthropic_schema() instead", version="4.0.0") - def get_func_desc_anthropic_style(self): + def get_func_desc_anthropic_style(self) -> list[dict]: return self.anthropic_schema() @deprecated(reason="Use google_schema() instead", version="4.0.0") - def get_func_desc_google_genai_style(self): + def get_func_desc_google_genai_style(self) -> dict: return self.google_schema() def names(self) -> list[str]: """获取所有工具的名称列表""" return [tool.name for tool in self.tools] - def __len__(self): + def __len__(self) -> int: return len(self.tools) - def __bool__(self): + def __bool__(self) -> bool: return len(self.tools) > 0 - def __iter__(self): + def __iter__(self) -> Iterator: return iter(self.tools) - def __repr__(self): + def __repr__(self) -> str: return f"ToolSet(tools={self.tools})" - def __str__(self): + def __str__(self) -> str: return f"ToolSet(tools={self.tools})" diff --git a/astrbot/core/agent/tool_executor.py b/astrbot/core/agent/tool_executor.py index 2704119d4..b35c34cad 100644 --- a/astrbot/core/agent/tool_executor.py +++ b/astrbot/core/agent/tool_executor.py @@ -13,5 +13,5 @@ async def execute( cls, tool: FunctionTool, run_context: ContextWrapper[TContext], - **tool_args, + **tool_args: object, ) -> AsyncGenerator[Any | mcp.types.CallToolResult, None]: ... diff --git a/astrbot/core/astr_agent_hooks.py b/astrbot/core/astr_agent_hooks.py index f394fc947..39e247c6d 100644 --- a/astrbot/core/astr_agent_hooks.py +++ b/astrbot/core/astr_agent_hooks.py @@ -11,7 +11,7 @@ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]): - async def on_agent_done(self, run_context, llm_response): + async def on_agent_done(self, run_context, llm_response) -> None: # 执行事件钩子 await call_event_hook( run_context.context.event, @@ -25,7 +25,7 @@ async def on_tool_end( tool: FunctionTool[Any], tool_args: dict | None, tool_result: CallToolResult | None, - ): + ) -> None: run_context.context.event.clear_result() diff --git a/astrbot/core/astrbot_config_mgr.py b/astrbot/core/astrbot_config_mgr.py index 3a1353ce5..c2bfb1c37 100644 --- a/astrbot/core/astrbot_config_mgr.py +++ b/astrbot/core/astrbot_config_mgr.py @@ -36,7 +36,7 @@ def __init__( default_config: AstrBotConfig, ucr: UmopConfigRouter, sp: SharedPreferences, - ): + ) -> None: self.sp = sp self.ucr = ucr self.confs: dict[str, AstrBotConfig] = {} @@ -56,7 +56,7 @@ def _get_abconf_data(self) -> dict: ) return self.abconf_data - def _load_all_configs(self): + def _load_all_configs(self) -> None: """Load all configurations from the shared preferences.""" abconf_data = self._get_abconf_data() self.abconf_data = abconf_data diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py index 9477eabaa..fd458e933 100644 --- a/astrbot/core/config/astrbot_config.py +++ b/astrbot/core/config/astrbot_config.py @@ -33,9 +33,8 @@ def __init__( config_path: str = ASTRBOT_CONFIG_PATH, default_config: dict = DEFAULT_CONFIG, schema: dict | None = None, - ): + ) -> None: super().__init__() - # 调用父类的 __setattr__ 方法,防止保存配置时将此属性写入配置文件 object.__setattr__(self, "config_path", config_path) object.__setattr__(self, "default_config", default_config) @@ -66,7 +65,7 @@ def _config_schema_to_default_config(self, schema: dict) -> dict: """将 Schema 转换成 Config""" conf = {} - def _parse_schema(schema: dict, conf: dict): + def _parse_schema(schema: dict, conf: dict) -> None: for k, v in schema.items(): if v["type"] not in DEFAULT_VALUE_MAP: raise TypeError( @@ -87,7 +86,9 @@ def _parse_schema(schema: dict, conf: dict): return conf - def check_config_integrity(self, refer_conf: dict, conf: dict, path=""): + def check_config_integrity( + self, refer_conf: dict, conf: dict, path: str = "" + ) -> bool: """检查配置完整性,如果有新的配置项或顺序不一致则返回 True""" has_new = False @@ -146,7 +147,7 @@ def check_config_integrity(self, refer_conf: dict, conf: dict, path=""): return has_new - def save_config(self, replace_config: dict | None = None): + def save_config(self, replace_config: dict | None = None) -> None: """将配置写入文件 如果传入 replace_config,则将配置替换为 replace_config @@ -156,20 +157,20 @@ def save_config(self, replace_config: dict | None = None): with open(self.config_path, "w", encoding="utf-8-sig") as f: json.dump(self, f, indent=2, ensure_ascii=False) - def __getattr__(self, item): + def __getattr__(self, item: str) -> object: try: return self[item] except KeyError: return None - def __delattr__(self, key): + def __delattr__(self, key: str) -> None: try: del self[key] self.save_config() except KeyError: raise AttributeError(f"没有找到 Key: '{key}'") - def __setattr__(self, key, value): + def __setattr__(self, key: str, value: object) -> None: self[key] = value def check_exist(self) -> bool: diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index 287fe03c4..afdbd24f7 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -16,7 +16,7 @@ class ConversationManager: """负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。""" - def __init__(self, db_helper: BaseDatabase): + def __init__(self, db_helper: BaseDatabase) -> None: self.session_conversations: dict[str, str] = {} self.db = db_helper self.save_interval = 60 # 每 60 秒保存一次 @@ -105,7 +105,9 @@ async def new_conversation( await sp.session_put(unified_msg_origin, "sel_conv_id", conv.conversation_id) return conv.conversation_id - async def switch_conversation(self, unified_msg_origin: str, conversation_id: str): + async def switch_conversation( + self, unified_msg_origin: str, conversation_id: str + ) -> None: """切换会话的对话 Args: @@ -120,7 +122,7 @@ async def delete_conversation( self, unified_msg_origin: str, conversation_id: str | None = None, - ): + ) -> None: """删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话 Args: @@ -137,7 +139,7 @@ async def delete_conversation( self.session_conversations.pop(unified_msg_origin, None) await sp.session_remove(unified_msg_origin, "sel_conv_id") - async def delete_conversations_by_user_id(self, unified_msg_origin: str): + async def delete_conversations_by_user_id(self, unified_msg_origin: str) -> None: """删除会话的所有对话 Args: @@ -223,7 +225,7 @@ async def get_filtered_conversations( page_size: int = 20, platform_ids: list[str] | None = None, search_query: str = "", - **kwargs, + **kwargs: object, ) -> tuple[list[Conversation], int]: """获取过滤后的对话列表. diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index 192c7b263..3c491372a 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -39,7 +39,7 @@ def __init__(self) -> None: expire_on_commit=False, ) - async def initialize(self): + async def initialize(self) -> None: """初始化数据库连接""" @asynccontextmanager @@ -105,7 +105,7 @@ async def get_conversations( ... @abc.abstractmethod - async def get_conversation_by_id(self, cid: str) -> ConversationV2: + async def get_conversation_by_id(self, cid: str) -> ConversationV2 | None: """Get a specific conversation by its ID.""" ... @@ -125,7 +125,7 @@ async def get_filtered_conversations( page_size: int = 20, platform_ids: list[str] | None = None, search_query: str = "", - **kwargs, + **kwargs: object, ) -> tuple[list[ConversationV2], int]: """Get conversations filtered by platform IDs and search query.""" ... @@ -152,7 +152,7 @@ async def update_conversation( title: str | None = None, persona_id: str | None = None, content: list[dict] | None = None, - ) -> None: + ) -> ConversationV2 | None: """Update a conversation's history.""" ... @@ -213,12 +213,12 @@ async def insert_attachment( path: str, type: str, mime_type: str, - ): + ) -> Attachment: """Insert a new attachment record.""" ... @abc.abstractmethod - async def get_attachment_by_id(self, attachment_id: str) -> Attachment: + async def get_attachment_by_id(self, attachment_id: str) -> Attachment | None: """Get an attachment by its ID.""" ... @@ -255,7 +255,7 @@ async def insert_persona( ... @abc.abstractmethod - async def get_persona_by_id(self, persona_id: str) -> Persona: + async def get_persona_by_id(self, persona_id: str) -> Persona | None: """Get a persona by its ID.""" ... @@ -292,7 +292,9 @@ async def insert_preference_or_update( ... @abc.abstractmethod - async def get_preference(self, scope: str, scope_id: str, key: str) -> Preference: + async def get_preference( + self, scope: str, scope_id: str, key: str + ) -> Preference | None: """Get a preference by scope ID and key.""" ... diff --git a/astrbot/core/db/migration/migra_3_to_4.py b/astrbot/core/db/migration/migra_3_to_4.py index 66b72d5cb..727d97b29 100644 --- a/astrbot/core/db/migration/migra_3_to_4.py +++ b/astrbot/core/db/migration/migra_3_to_4.py @@ -43,7 +43,7 @@ def get_platform_type( async def migration_conversation_table( db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]], -): +) -> None: db_helper_v3 = SQLiteV3DatabaseV3( db_path=DB_PATH.replace("data_v4.db", "data_v3.db"), ) @@ -101,7 +101,7 @@ async def migration_conversation_table( async def migration_platform_table( db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]], -): +) -> None: db_helper_v3 = SQLiteV3DatabaseV3( db_path=DB_PATH.replace("data_v4.db", "data_v3.db"), ) @@ -180,7 +180,7 @@ async def migration_platform_table( async def migration_webchat_data( db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]], -): +) -> None: """迁移 WebChat 的历史记录到新的 PlatformMessageHistory 表中""" db_helper_v3 = SQLiteV3DatabaseV3( db_path=DB_PATH.replace("data_v4.db", "data_v3.db"), @@ -236,7 +236,7 @@ async def migration_webchat_data( async def migration_persona_data( db_helper: BaseDatabase, astrbot_config: AstrBotConfig, -): +) -> None: """迁移 Persona 数据到新的表中。 旧的 Persona 数据存储在 preference 中,新的 Persona 数据存储在 persona 表中。 """ @@ -279,7 +279,7 @@ async def migration_persona_data( async def migration_preferences( db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]], -): +) -> None: # 1. global scope migration keys = [ "inactivated_llm_tools", diff --git a/astrbot/core/db/migration/migra_45_to_46.py b/astrbot/core/db/migration/migra_45_to_46.py index dc70026f9..58736ab51 100644 --- a/astrbot/core/db/migration/migra_45_to_46.py +++ b/astrbot/core/db/migration/migra_45_to_46.py @@ -3,7 +3,7 @@ from astrbot.core.umop_config_router import UmopConfigRouter -async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter): +async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter) -> None: abconf_data = acm.abconf_data if not isinstance(abconf_data, dict): diff --git a/astrbot/core/db/migration/migra_webchat_session.py b/astrbot/core/db/migration/migra_webchat_session.py index ff0b5ca6f..46025fc64 100644 --- a/astrbot/core/db/migration/migra_webchat_session.py +++ b/astrbot/core/db/migration/migra_webchat_session.py @@ -17,7 +17,7 @@ from astrbot.core.db.po import ConversationV2, PlatformMessageHistory, PlatformSession -async def migrate_webchat_session(db_helper: BaseDatabase): +async def migrate_webchat_session(db_helper: BaseDatabase) -> None: """Create PlatformSession records from platform_message_history. This migration extracts all unique user_ids from platform_message_history diff --git a/astrbot/core/db/migration/shared_preferences_v3.py b/astrbot/core/db/migration/shared_preferences_v3.py index 3abcb1a66..72802f057 100644 --- a/astrbot/core/db/migration/shared_preferences_v3.py +++ b/astrbot/core/db/migration/shared_preferences_v3.py @@ -8,13 +8,13 @@ class SharedPreferences: - def __init__(self, path=None): + def __init__(self, path: str | None = None) -> None: if path is None: path = os.path.join(get_astrbot_data_path(), "shared_preferences.json") self.path = path self._data = self._load_preferences() - def _load_preferences(self): + def _load_preferences(self) -> dict: if os.path.exists(self.path): try: with open(self.path) as f: @@ -23,24 +23,24 @@ def _load_preferences(self): os.remove(self.path) return {} - def _save_preferences(self): + def _save_preferences(self) -> None: with open(self.path, "w") as f: json.dump(self._data, f, indent=4, ensure_ascii=False) f.flush() - def get(self, key, default: _VT = None) -> _VT: + def get(self, key: str, default: _VT = None) -> _VT: return self._data.get(key, default) - def put(self, key, value): + def put(self, key: str, value: object) -> None: self._data[key] = value self._save_preferences() - def remove(self, key): + def remove(self, key: str) -> None: if key in self._data: del self._data[key] self._save_preferences() - def clear(self): + def clear(self) -> None: self._data.clear() self._save_preferences() diff --git a/astrbot/core/db/migration/sqlite_v3.py b/astrbot/core/db/migration/sqlite_v3.py index b1a780d48..b326ebb44 100644 --- a/astrbot/core/db/migration/sqlite_v3.py +++ b/astrbot/core/db/migration/sqlite_v3.py @@ -127,7 +127,7 @@ def _get_conn(self, db_path: str) -> sqlite3.Connection: conn.text_factory = str return conn - def _exec_sql(self, sql: str, params: tuple | None = None): + def _exec_sql(self, sql: str, params: tuple | None = None) -> None: conn = self.conn try: c = self.conn.cursor() @@ -144,7 +144,7 @@ def _exec_sql(self, sql: str, params: tuple | None = None): conn.commit() - def insert_platform_metrics(self, metrics: dict): + def insert_platform_metrics(self, metrics: dict) -> None: for k, v in metrics.items(): self._exec_sql( """ @@ -153,7 +153,7 @@ def insert_platform_metrics(self, metrics: dict): (k, v, int(time.time())), ) - def insert_llm_metrics(self, metrics: dict): + def insert_llm_metrics(self, metrics: dict) -> None: for k, v in metrics.items(): self._exec_sql( """ @@ -249,7 +249,7 @@ def get_conversation_by_user_id( return Conversation(*res) - def new_conversation(self, user_id: str, cid: str): + def new_conversation(self, user_id: str, cid: str) -> None: history = "[]" updated_at = int(time.time()) created_at = updated_at @@ -287,7 +287,7 @@ def get_conversations(self, user_id: str) -> list[Conversation]: ) return conversations - def update_conversation(self, user_id: str, cid: str, history: str): + def update_conversation(self, user_id: str, cid: str, history: str) -> None: """更新对话,并且同时更新时间""" updated_at = int(time.time()) self._exec_sql( @@ -297,7 +297,7 @@ def update_conversation(self, user_id: str, cid: str, history: str): (history, updated_at, user_id, cid), ) - def update_conversation_title(self, user_id: str, cid: str, title: str): + def update_conversation_title(self, user_id: str, cid: str, title: str) -> None: self._exec_sql( """ UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ? @@ -305,7 +305,9 @@ def update_conversation_title(self, user_id: str, cid: str, title: str): (title, user_id, cid), ) - def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str): + def update_conversation_persona_id( + self, user_id: str, cid: str, persona_id: str + ) -> None: self._exec_sql( """ UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ? @@ -313,7 +315,7 @@ def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str (persona_id, user_id, cid), ) - def delete_conversation(self, user_id: str, cid: str): + def delete_conversation(self, user_id: str, cid: str) -> None: self._exec_sql( """ DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ? diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index fa3ca9a76..79b559ee7 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -1,6 +1,11 @@ import asyncio import threading import typing as T + +try: + from typing import Unpack # type: ignore[attr-defined] +except ImportError: + from typing_extensions import Unpack from collections.abc import Awaitable, Callable from datetime import datetime, timedelta, timezone @@ -32,6 +37,13 @@ TxResult = T.TypeVar("TxResult") +class FilterKwargs(T.TypedDict, total=False): + message_types: list[str] + platforms: list[str] + exclude_ids: list[str] + exclude_platforms: list[str] + + class SQLiteDatabase(BaseDatabase): def __init__(self, db_path: str) -> None: self.db_path = db_path @@ -57,10 +69,10 @@ async def initialize(self) -> None: async def insert_platform_stats( self, - platform_id, - platform_type, - count=1, - timestamp=None, + platform_id: str, + platform_type: str, + count: int = 1, + timestamp: datetime | None = None, ) -> None: """Insert a new platform statistic record.""" async with self.get_db() as session: @@ -121,7 +133,9 @@ async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat # Conversation Management # ==== - async def get_conversations(self, user_id=None, platform_id=None): + async def get_conversations( + self, user_id: str | None = None, platform_id: str | None = None + ) -> list[ConversationV2]: async with self.get_db() as session: session: AsyncSession query = select(ConversationV2) @@ -134,16 +148,18 @@ async def get_conversations(self, user_id=None, platform_id=None): query = query.order_by(desc(ConversationV2.created_at)) result = await session.execute(query) - return result.scalars().all() + return result.scalars().all() # type: ignore - async def get_conversation_by_id(self, cid): + async def get_conversation_by_id(self, cid: str) -> ConversationV2 | None: async with self.get_db() as session: session: AsyncSession query = select(ConversationV2).where(ConversationV2.conversation_id == cid) result = await session.execute(query) return result.scalar_one_or_none() - async def get_all_conversations(self, page=1, page_size=20): + async def get_all_conversations( + self, page: int = 1, page_size: int = 20 + ) -> list[ConversationV2]: async with self.get_db() as session: session: AsyncSession offset = (page - 1) * page_size @@ -153,16 +169,16 @@ async def get_all_conversations(self, page=1, page_size=20): .offset(offset) .limit(page_size), ) - return result.scalars().all() + return result.scalars().all() # type:ignore async def get_filtered_conversations( self, - page=1, - page_size=20, - platform_ids=None, - search_query="", - **kwargs, - ): + page: int = 1, + page_size: int = 20, + platform_ids: list[str] | None = None, + search_query: str = "", + **kwargs: Unpack[FilterKwargs], + ) -> tuple[list[ConversationV2], int]: async with self.get_db() as session: session: AsyncSession # Build the base query with filters @@ -207,19 +223,19 @@ async def get_filtered_conversations( result = await session.execute(result_query) conversations = result.scalars().all() - return conversations, total + return conversations, total # type:ignore async def create_conversation( self, - user_id, - platform_id, - content=None, - title=None, - persona_id=None, - cid=None, - created_at=None, - updated_at=None, - ): + user_id: str, + platform_id: str, + content: list[dict] | None = None, + title: str | None = None, + persona_id: str | None = None, + cid: str | None = None, + created_at: datetime | None = None, + updated_at: datetime | None = None, + ) -> ConversationV2: kwargs = {} if cid: kwargs["conversation_id"] = cid @@ -241,7 +257,13 @@ async def create_conversation( session.add(new_conversation) return new_conversation - async def update_conversation(self, cid, title=None, persona_id=None, content=None): + async def update_conversation( + self, + cid: str, + title: str | None = None, + persona_id: str | None = None, + content: list[dict] | None = None, + ) -> ConversationV2 | None: async with self.get_db() as session: session: AsyncSession async with session.begin(): @@ -261,7 +283,7 @@ async def update_conversation(self, cid, title=None, persona_id=None, content=No await session.execute(query) return await self.get_conversation_by_id(cid) - async def delete_conversation(self, cid): + async def delete_conversation(self, cid: str) -> None: async with self.get_db() as session: session: AsyncSession async with session.begin(): @@ -283,10 +305,10 @@ async def delete_conversations_by_user_id(self, user_id: str) -> None: async def get_session_conversations( self, - page=1, - page_size=20, - search_query=None, - platform=None, + page: int = 1, + page_size: int = 20, + search_query: str | None = None, + platform: str | None = None, ) -> tuple[list[dict], int]: """Get paginated session conversations with joined conversation and persona details.""" async with self.get_db() as session: @@ -392,12 +414,12 @@ async def get_session_conversations( async def insert_platform_message_history( self, - platform_id, - user_id, - content, - sender_id=None, - sender_name=None, - ): + platform_id: str, + user_id: str, + content: dict, + sender_id: str | None = None, + sender_name: str | None = None, + ) -> PlatformMessageHistory: """Insert a new platform message history record.""" async with self.get_db() as session: session: AsyncSession @@ -413,12 +435,9 @@ async def insert_platform_message_history( return new_history async def delete_platform_message_offset( - self, - platform_id, - user_id, - offset_sec=86400, - ): - """Delete platform message history records newer than the specified offset.""" + self, platform_id: str, user_id: str, offset_sec: int = 86400 + ) -> None: + """Delete platform message history records older than the specified offset.""" async with self.get_db() as session: session: AsyncSession async with session.begin(): @@ -433,12 +452,8 @@ async def delete_platform_message_offset( ) async def get_platform_message_history( - self, - platform_id, - user_id, - page=1, - page_size=20, - ): + self, platform_id: str, user_id: str, page: int = 1, page_size: int = 20 + ) -> list[PlatformMessageHistory]: """Get platform message history records.""" async with self.get_db() as session: session: AsyncSession @@ -452,7 +467,7 @@ async def get_platform_message_history( .order_by(desc(PlatformMessageHistory.created_at)) ) result = await session.execute(query.offset(offset).limit(page_size)) - return result.scalars().all() + return result.scalars().all() # type:ignore async def get_platform_message_history_by_id( self, message_id: int @@ -479,7 +494,7 @@ async def insert_attachment(self, path, type, mime_type): session.add(new_attachment) return new_attachment - async def get_attachment_by_id(self, attachment_id): + async def get_attachment_by_id(self, attachment_id: str) -> Attachment | None: """Get an attachment by its ID.""" async with self.get_db() as session: session: AsyncSession @@ -531,11 +546,11 @@ async def delete_attachments(self, attachment_ids: list[str]) -> int: async def insert_persona( self, - persona_id, - system_prompt, - begin_dialogs=None, - tools=None, - ): + persona_id: str, + system_prompt: str, + begin_dialogs: list[str] | None = None, + tools: list[str] | None = None, + ) -> Persona: """Insert a new persona record.""" async with self.get_db() as session: session: AsyncSession @@ -549,7 +564,7 @@ async def insert_persona( session.add(new_persona) return new_persona - async def get_persona_by_id(self, persona_id): + async def get_persona_by_id(self, persona_id: str) -> Persona | None: """Get a persona by its ID.""" async with self.get_db() as session: session: AsyncSession @@ -557,21 +572,21 @@ async def get_persona_by_id(self, persona_id): result = await session.execute(query) return result.scalar_one_or_none() - async def get_personas(self): + async def get_personas(self) -> list[Persona]: """Get all personas for a specific bot.""" async with self.get_db() as session: session: AsyncSession query = select(Persona) result = await session.execute(query) - return result.scalars().all() + return result.scalars().all() # type:ignore async def update_persona( self, - persona_id, - system_prompt=None, - begin_dialogs=None, - tools=NOT_GIVEN, - ): + persona_id: str, + system_prompt: str | None = None, + begin_dialogs: list[str] | None = None, + tools: list[str] | None = None, + ) -> Persona | None: """Update a persona's system prompt or begin dialogs.""" async with self.get_db() as session: session: AsyncSession @@ -590,7 +605,7 @@ async def update_persona( await session.execute(query) return await self.get_persona_by_id(persona_id) - async def delete_persona(self, persona_id): + async def delete_persona(self, persona_id: str) -> None: """Delete a persona by its ID.""" async with self.get_db() as session: session: AsyncSession @@ -599,7 +614,9 @@ async def delete_persona(self, persona_id): delete(Persona).where(col(Persona.persona_id) == persona_id), ) - async def insert_preference_or_update(self, scope, scope_id, key, value): + async def insert_preference_or_update( + self, scope: str, scope_id: str, key: str, value: dict + ) -> Preference: """Insert a new preference record or update if it exists.""" async with self.get_db() as session: session: AsyncSession @@ -623,7 +640,9 @@ async def insert_preference_or_update(self, scope, scope_id, key, value): session.add(new_preference) return existing_preference or new_preference - async def get_preference(self, scope, scope_id, key): + async def get_preference( + self, scope: str, scope_id: str, key: str + ) -> Preference | None: """Get a preference by key.""" async with self.get_db() as session: session: AsyncSession @@ -635,7 +654,9 @@ async def get_preference(self, scope, scope_id, key): result = await session.execute(query) return result.scalar_one_or_none() - async def get_preferences(self, scope, scope_id=None, key=None): + async def get_preferences( + self, scope: str, scope_id: str | None = None, key: str | None = None + ) -> list[Preference]: """Get all preferences for a specific scope ID or key.""" async with self.get_db() as session: session: AsyncSession @@ -645,9 +666,9 @@ async def get_preferences(self, scope, scope_id=None, key=None): if key is not None: query = query.where(Preference.key == key) result = await session.execute(query) - return result.scalars().all() + return result.scalars().all() # type:ignore - async def remove_preference(self, scope, scope_id, key): + async def remove_preference(self, scope: str, scope_id: str, key: str) -> None: """Remove a preference by scope ID and key.""" async with self.get_db() as session: session: AsyncSession @@ -661,7 +682,7 @@ async def remove_preference(self, scope, scope_id, key): ) await session.commit() - async def clear_preferences(self, scope, scope_id): + async def clear_preferences(self, scope: str, scope_id: str) -> None: """Clear all preferences for a specific scope ID.""" async with self.get_db() as session: session: AsyncSession @@ -914,10 +935,10 @@ async def _op(session: AsyncSession) -> None: # Deprecated Methods # ==== - def get_base_stats(self, offset_sec=86400): + def get_base_stats(self, offset_sec: int = 86400) -> DeprecatedStats: """Get base statistics within the specified offset in seconds.""" - async def _inner(): + async def _inner() -> DeprecatedStats: async with self.get_db() as session: session: AsyncSession now = datetime.now() @@ -939,19 +960,19 @@ async def _inner(): result = None - def runner(): + def runner() -> None: nonlocal result result = asyncio.run(_inner()) t = threading.Thread(target=runner) t.start() t.join() - return result + return result # type:ignore - def get_total_message_count(self): + def get_total_message_count(self) -> int: """Get the total message count from platform statistics.""" - async def _inner(): + async def _inner() -> int: async with self.get_db() as session: session: AsyncSession result = await session.execute( @@ -962,18 +983,18 @@ async def _inner(): result = None - def runner(): + def runner() -> None: nonlocal result result = asyncio.run(_inner()) t = threading.Thread(target=runner) t.start() t.join() - return result + return result # type:ignore - def get_grouped_base_stats(self, offset_sec=86400): + def get_grouped_base_stats(self, offset_sec: int = 86400) -> DeprecatedStats: # group by platform_id - async def _inner(): + async def _inner() -> DeprecatedStats: async with self.get_db() as session: session: AsyncSession now = datetime.now() @@ -997,14 +1018,14 @@ async def _inner(): result = None - def runner(): + def runner() -> None: nonlocal result result = asyncio.run(_inner()) t = threading.Thread(target=runner) t.start() t.join() - return result + return result # type:ignore # ==== # Platform Session Management diff --git a/astrbot/core/db/vec_db/base.py b/astrbot/core/db/vec_db/base.py index 7440b6f2a..0ac3c608d 100644 --- a/astrbot/core/db/vec_db/base.py +++ b/astrbot/core/db/vec_db/base.py @@ -1,4 +1,5 @@ import abc +from collections.abc import Awaitable, Callable from dataclasses import dataclass @@ -9,7 +10,7 @@ class Result: class BaseVecDB: - async def initialize(self): + async def initialize(self) -> None: """初始化向量数据库""" @abc.abstractmethod @@ -31,7 +32,7 @@ async def insert_batch( batch_size: int = 32, tasks_limit: int = 3, max_retries: int = 3, - progress_callback=None, + progress_callback: Callable[[int, int], Awaitable[None]] | None = None, ) -> int: """批量插入文本和其对应向量,自动生成 ID 并保持一致性。 @@ -70,4 +71,4 @@ async def delete(self, doc_id: str) -> bool: ... @abc.abstractmethod - async def close(self): ... + async def close(self) -> None: ... diff --git a/astrbot/core/db/vec_db/faiss_impl/document_storage.py b/astrbot/core/db/vec_db/faiss_impl/document_storage.py index e27eb6fe8..db275006f 100644 --- a/astrbot/core/db/vec_db/faiss_impl/document_storage.py +++ b/astrbot/core/db/vec_db/faiss_impl/document_storage.py @@ -1,5 +1,6 @@ import json import os +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from datetime import datetime @@ -33,7 +34,7 @@ class Document(BaseDocModel, table=True): class DocumentStorage: - def __init__(self, db_path: str): + def __init__(self, db_path: str) -> None: self.db_path = db_path self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}" self.engine: AsyncEngine | None = None @@ -43,7 +44,7 @@ def __init__(self, db_path: str): "sqlite_init.sql", ) - async def initialize(self): + async def initialize(self) -> None: """Initialize the SQLite database and create the documents table if it doesn't exist.""" await self.connect() async with self.engine.begin() as conn: # type: ignore @@ -80,7 +81,7 @@ async def initialize(self): await conn.commit() - async def connect(self): + async def connect(self) -> None: """Connect to the SQLite database.""" if self.engine is None: self.engine = create_async_engine( @@ -95,7 +96,7 @@ async def connect(self): ) # type: ignore @asynccontextmanager - async def get_session(self): + async def get_session(self) -> AsyncGenerator[AsyncSession, None]: """Context manager for database sessions.""" async with self.async_session_maker() as session: # type: ignore yield session @@ -211,7 +212,7 @@ async def insert_documents_batch( await session.flush() # Flush to get all IDs return [doc.id for doc in documents] # type: ignore - async def delete_document_by_doc_id(self, doc_id: str): + async def delete_document_by_doc_id(self, doc_id: str) -> None: """Delete a document by its doc_id. Args: @@ -228,7 +229,7 @@ async def delete_document_by_doc_id(self, doc_id: str): if document: await session.delete(document) - async def get_document_by_doc_id(self, doc_id: str): + async def get_document_by_doc_id(self, doc_id: str) -> dict | None: """Retrieve a document by its doc_id. Args: @@ -249,7 +250,7 @@ async def get_document_by_doc_id(self, doc_id: str): return self._document_to_dict(document) return None - async def update_document_by_doc_id(self, doc_id: str, new_text: str): + async def update_document_by_doc_id(self, doc_id: str, new_text: str) -> None: """Update a document by its doc_id. Args: @@ -269,7 +270,7 @@ async def update_document_by_doc_id(self, doc_id: str, new_text: str): document.updated_at = datetime.now() session.add(document) - async def delete_documents(self, metadata_filters: dict): + async def delete_documents(self, metadata_filters: dict) -> None: """Delete documents by their metadata filters. Args: @@ -363,7 +364,7 @@ def _document_to_dict(self, document: Document) -> dict: else document.updated_at, } - async def tuple_to_dict(self, row): + async def tuple_to_dict(self, row: tuple) -> dict: """Convert a tuple to a dictionary. Args: @@ -384,7 +385,7 @@ async def tuple_to_dict(self, row): "updated_at": row[5], } - async def close(self): + async def close(self) -> None: """Close the connection to the SQLite database.""" if self.engine: await self.engine.dispose() diff --git a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py index 564454cb1..dc6977cf8 100644 --- a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +++ b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py @@ -10,7 +10,7 @@ class EmbeddingStorage: - def __init__(self, dimension: int, path: str | None = None): + def __init__(self, dimension: int, path: str | None = None) -> None: self.dimension = dimension self.path = path self.index = None @@ -20,7 +20,7 @@ def __init__(self, dimension: int, path: str | None = None): base_index = faiss.IndexFlatL2(dimension) self.index = faiss.IndexIDMap(base_index) - async def insert(self, vector: np.ndarray, id: int): + async def insert(self, vector: np.ndarray, id: int) -> None: """插入向量 Args: @@ -38,7 +38,7 @@ async def insert(self, vector: np.ndarray, id: int): self.index.add_with_ids(vector.reshape(1, -1), np.array([id])) await self.save_index() - async def insert_batch(self, vectors: np.ndarray, ids: list[int]): + async def insert_batch(self, vectors: np.ndarray, ids: list[int]) -> None: """批量插入向量 Args: @@ -71,7 +71,7 @@ async def search(self, vector: np.ndarray, k: int) -> tuple: distances, indices = self.index.search(vector, k) return distances, indices - async def delete(self, ids: list[int]): + async def delete(self, ids: list[int]) -> None: """删除向量 Args: @@ -83,7 +83,7 @@ async def delete(self, ids: list[int]): self.index.remove_ids(id_array) await self.save_index() - async def save_index(self): + async def save_index(self) -> None: """保存索引 Args: diff --git a/astrbot/core/db/vec_db/faiss_impl/vec_db.py b/astrbot/core/db/vec_db/faiss_impl/vec_db.py index 14221f1e8..173c4c7f6 100644 --- a/astrbot/core/db/vec_db/faiss_impl/vec_db.py +++ b/astrbot/core/db/vec_db/faiss_impl/vec_db.py @@ -1,5 +1,6 @@ import time import uuid +from collections.abc import Awaitable, Callable import numpy as np @@ -20,7 +21,7 @@ def __init__( index_store_path: str, embedding_provider: EmbeddingProvider, rerank_provider: RerankProvider | None = None, - ): + ) -> None: self.doc_store_path = doc_store_path self.index_store_path = index_store_path self.embedding_provider = embedding_provider @@ -32,7 +33,7 @@ def __init__( self.embedding_provider = embedding_provider self.rerank_provider = rerank_provider - async def initialize(self): + async def initialize(self) -> None: await self.document_storage.initialize() async def insert( @@ -63,7 +64,7 @@ async def insert_batch( batch_size: int = 32, tasks_limit: int = 3, max_retries: int = 3, - progress_callback=None, + progress_callback: Callable[[str, int, int], Awaitable[None]] | None = None, ) -> list[int]: """批量插入文本和其对应向量,自动生成 ID 并保持一致性。 @@ -165,7 +166,7 @@ async def retrieve( return top_k_results - async def delete(self, doc_id: str): + async def delete(self, doc_id: str) -> None: """删除一条文档块(chunk)""" # 获得对应的 int id result = await self.document_storage.get_document_by_doc_id(doc_id) @@ -177,7 +178,7 @@ async def delete(self, doc_id: str): await self.document_storage.delete_document_by_doc_id(doc_id) await self.embedding_storage.delete([int_id]) - async def close(self): + async def close(self) -> None: await self.document_storage.close() async def count_documents(self, metadata_filter: dict | None = None) -> int: @@ -192,7 +193,7 @@ async def count_documents(self, metadata_filter: dict | None = None) -> int: ) return count - async def delete_documents(self, metadata_filters: dict): + async def delete_documents(self, metadata_filters: dict) -> None: """根据元数据过滤器删除文档""" docs = await self.document_storage.get_documents( metadata_filters=metadata_filters, diff --git a/astrbot/core/event_bus.py b/astrbot/core/event_bus.py index 0017e65fa..44cdccb83 100644 --- a/astrbot/core/event_bus.py +++ b/astrbot/core/event_bus.py @@ -28,13 +28,13 @@ def __init__( event_queue: Queue, pipeline_scheduler_mapping: dict[str, PipelineScheduler], astrbot_config_mgr: AstrBotConfigManager, - ): + ) -> None: self.event_queue = event_queue # 事件队列 # abconf uuid -> scheduler self.pipeline_scheduler_mapping = pipeline_scheduler_mapping self.astrbot_config_mgr = astrbot_config_mgr - async def dispatch(self): + async def dispatch(self) -> None: while True: event: AstrMessageEvent = await self.event_queue.get() conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin) @@ -47,7 +47,7 @@ async def dispatch(self): continue asyncio.create_task(scheduler.execute(event)) - def _print_event(self, event: AstrMessageEvent, conf_name: str): + def _print_event(self, event: AstrMessageEvent, conf_name: str) -> None: """用于记录事件信息 Args: diff --git a/astrbot/core/file_token_service.py b/astrbot/core/file_token_service.py index ea97759c1..42fbd23df 100644 --- a/astrbot/core/file_token_service.py +++ b/astrbot/core/file_token_service.py @@ -9,12 +9,12 @@ class FileTokenService: """维护一个简单的基于令牌的文件下载服务,支持超时和懒清除。""" - def __init__(self, default_timeout: float = 300): + def __init__(self, default_timeout: float = 300) -> None: self.lock = asyncio.Lock() self.staged_files = {} # token: (file_path, expire_time) self.default_timeout = default_timeout - async def _cleanup_expired_tokens(self): + async def _cleanup_expired_tokens(self) -> None: """清理过期的令牌""" now = time.time() expired_tokens = [ diff --git a/astrbot/core/initial_loader.py b/astrbot/core/initial_loader.py index f54d18641..3f836a4c4 100644 --- a/astrbot/core/initial_loader.py +++ b/astrbot/core/initial_loader.py @@ -17,13 +17,13 @@ class InitialLoader: """AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。""" - def __init__(self, db: BaseDatabase, log_broker: LogBroker): + def __init__(self, db: BaseDatabase, log_broker: LogBroker) -> None: self.db = db self.logger = logger self.log_broker = log_broker self.webui_dir: str | None = None - async def start(self): + async def start(self) -> None: core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db) try: diff --git a/astrbot/core/knowledge_base/chunking/base.py b/astrbot/core/knowledge_base/chunking/base.py index a45d86ad1..11ae0caba 100644 --- a/astrbot/core/knowledge_base/chunking/base.py +++ b/astrbot/core/knowledge_base/chunking/base.py @@ -13,7 +13,7 @@ class BaseChunker(ABC): """ @abstractmethod - async def chunk(self, text: str, **kwargs) -> list[str]: + async def chunk(self, text: str, **kwargs: object) -> list[str]: """将文本分块 Args: diff --git a/astrbot/core/knowledge_base/chunking/fixed_size.py b/astrbot/core/knowledge_base/chunking/fixed_size.py index 5439f070f..cd146f137 100644 --- a/astrbot/core/knowledge_base/chunking/fixed_size.py +++ b/astrbot/core/knowledge_base/chunking/fixed_size.py @@ -12,7 +12,7 @@ class FixedSizeChunker(BaseChunker): 按照固定的字符数分块,并支持块之间的重叠。 """ - def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50): + def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50) -> None: """初始化分块器 Args: @@ -23,7 +23,13 @@ def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50): self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap - async def chunk(self, text: str, **kwargs) -> list[str]: + async def chunk( + self, + text: str, + *, + chunk_size: int | None = None, + chunk_overlap: int | None = None, + ) -> list[str]: """固定大小分块 Args: @@ -35,8 +41,11 @@ async def chunk(self, text: str, **kwargs) -> list[str]: list[str]: 分块后的文本列表 """ - chunk_size = kwargs.get("chunk_size", self.chunk_size) - chunk_overlap = kwargs.get("chunk_overlap", self.chunk_overlap) + chunk_size = self.chunk_size if chunk_size is None else chunk_size + chunk_overlap = self.chunk_overlap if chunk_overlap is None else chunk_overlap + if chunk_size <= 0: + return [text] + chunk_overlap = max(0, min(chunk_overlap, chunk_size - 1)) chunks = [] start = 0 diff --git a/astrbot/core/knowledge_base/chunking/recursive.py b/astrbot/core/knowledge_base/chunking/recursive.py index 3f4aabb57..040f2ffd0 100644 --- a/astrbot/core/knowledge_base/chunking/recursive.py +++ b/astrbot/core/knowledge_base/chunking/recursive.py @@ -11,7 +11,7 @@ def __init__( length_function: Callable[[str], int] = len, is_separator_regex: bool = False, separators: list[str] | None = None, - ): + ) -> None: """初始化递归字符文本分割器 Args: @@ -39,7 +39,13 @@ def __init__( "", # 字符 ] - async def chunk(self, text: str, **kwargs) -> list[str]: + async def chunk( + self, + text: str, + *, + chunk_size: int | None = None, + chunk_overlap: int | None = None, + ) -> list[str]: """递归地将文本分割成块 Args: @@ -54,8 +60,11 @@ async def chunk(self, text: str, **kwargs) -> list[str]: if not text: return [] - overlap = kwargs.get("chunk_overlap", self.chunk_overlap) - chunk_size = kwargs.get("chunk_size", self.chunk_size) + overlap = self.chunk_overlap if chunk_overlap is None else chunk_overlap + chunk_size = self.chunk_size if chunk_size is None else chunk_size + if chunk_size <= 0: + return [text] + overlap = max(0, min(overlap, chunk_size - 1)) text_length = self.length_function(text) if text_length <= chunk_size: diff --git a/astrbot/core/knowledge_base/kb_db_sqlite.py b/astrbot/core/knowledge_base/kb_db_sqlite.py index 5e1db842f..fa8809cb8 100644 --- a/astrbot/core/knowledge_base/kb_db_sqlite.py +++ b/astrbot/core/knowledge_base/kb_db_sqlite.py @@ -1,3 +1,4 @@ +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from pathlib import Path @@ -46,7 +47,7 @@ def __init__(self, db_path: str = "data/knowledge_base/kb.db") -> None: ) @asynccontextmanager - async def get_db(self): + async def get_db(self) -> AsyncGenerator[AsyncSession, None]: """获取数据库会话 用法: @@ -253,7 +254,7 @@ async def get_document_with_metadata(self, doc_id: str) -> dict | None: "knowledge_base": row[1], } - async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB): + async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB) -> None: """删除单个文档及其相关数据""" # 在知识库表中删除 async with self.get_db() as session, session.begin(): diff --git a/astrbot/core/knowledge_base/kb_helper.py b/astrbot/core/knowledge_base/kb_helper.py index 4adfb60b8..37f849bc6 100644 --- a/astrbot/core/knowledge_base/kb_helper.py +++ b/astrbot/core/knowledge_base/kb_helper.py @@ -3,6 +3,7 @@ import re import time import uuid +from collections.abc import Awaitable, Callable from pathlib import Path import aiofiles @@ -31,7 +32,7 @@ class RateLimiter: """一个简单的速率限制器""" - def __init__(self, max_rpm: int): + def __init__(self, max_rpm: int) -> None: self.max_per_minute = max_rpm self.interval = 60.0 / max_rpm if max_rpm > 0 else 0 self.last_call_time = 0 @@ -116,7 +117,7 @@ def __init__( provider_manager: ProviderManager, kb_root_dir: str, chunker: BaseChunker, - ): + ) -> None: self.kb_db = kb_db self.kb = kb self.prov_mgr = provider_manager @@ -130,7 +131,7 @@ def __init__( self.kb_medias_dir.mkdir(parents=True, exist_ok=True) self.kb_files_dir.mkdir(parents=True, exist_ok=True) - async def initialize(self): + async def initialize(self) -> None: await self._ensure_vec_db() async def get_ep(self) -> EmbeddingProvider: @@ -174,7 +175,7 @@ async def _ensure_vec_db(self) -> FaissVecDB: self.vec_db = vec_db return vec_db - async def delete_vec_db(self): + async def delete_vec_db(self) -> None: """删除知识库的向量数据库和所有相关文件""" import shutil @@ -182,7 +183,7 @@ async def delete_vec_db(self): if self.kb_dir.exists(): shutil.rmtree(self.kb_dir) - async def terminate(self): + async def terminate(self) -> None: if self.vec_db: await self.vec_db.close() @@ -196,7 +197,7 @@ async def upload_document( batch_size: int = 32, tasks_limit: int = 3, max_retries: int = 3, - progress_callback=None, + progress_callback: Callable[[str, int, int], Awaitable[None]] | None = None, pre_chunked_text: list[str] | None = None, ) -> KBDocument: """上传并处理文档(带原子性保证和失败清理) @@ -293,7 +294,7 @@ async def upload_document( await progress_callback("chunking", 100, 100) # 阶段3: 生成向量(带进度回调) - async def embedding_progress_callback(current, total): + async def embedding_progress_callback(current: int, total: int) -> None: if progress_callback: await progress_callback("embedding", current, total) @@ -360,7 +361,7 @@ async def get_document(self, doc_id: str) -> KBDocument | None: doc = await self.kb_db.get_document_by_id(doc_id) return doc - async def delete_document(self, doc_id: str): + async def delete_document(self, doc_id: str) -> None: """删除单个文档及其相关数据""" await self.kb_db.delete_document_by_id( doc_id=doc_id, @@ -372,7 +373,7 @@ async def delete_document(self, doc_id: str): ) await self.refresh_kb() - async def delete_chunk(self, chunk_id: str, doc_id: str): + async def delete_chunk(self, chunk_id: str, doc_id: str) -> None: """删除单个文本块及其相关数据""" vec_db: FaissVecDB = self.vec_db # type: ignore await vec_db.delete(chunk_id) @@ -383,7 +384,7 @@ async def delete_chunk(self, chunk_id: str, doc_id: str): await self.refresh_kb() await self.refresh_document(doc_id) - async def refresh_kb(self): + async def refresh_kb(self) -> None: if self.kb: kb = await self.kb_db.get_kb_by_id(self.kb.kb_id) if kb: diff --git a/astrbot/core/knowledge_base/kb_mgr.py b/astrbot/core/knowledge_base/kb_mgr.py index 2219cc00b..b8584a02c 100644 --- a/astrbot/core/knowledge_base/kb_mgr.py +++ b/astrbot/core/knowledge_base/kb_mgr.py @@ -26,14 +26,14 @@ class KnowledgeBaseManager: def __init__( self, provider_manager: ProviderManager, - ): + ) -> None: Path(DB_PATH).parent.mkdir(parents=True, exist_ok=True) self.provider_manager = provider_manager self._session_deleted_callback_registered = False self.kb_insts: dict[str, KBHelper] = {} - async def initialize(self): + async def initialize(self) -> None: """初始化知识库模块""" try: logger.info("正在初始化知识库模块...") @@ -58,13 +58,13 @@ async def initialize(self): logger.error(f"知识库模块初始化失败: {e}") logger.error(traceback.format_exc()) - async def _init_kb_database(self): + async def _init_kb_database(self) -> None: self.kb_db = KBSQLiteDatabase(DB_PATH.as_posix()) await self.kb_db.initialize() await self.kb_db.migrate_to_v1() logger.info(f"KnowledgeBase database initialized: {DB_PATH}") - async def load_kbs(self): + async def load_kbs(self) -> None: """加载所有知识库实例""" kb_records = await self.kb_db.list_kbs() for record in kb_records: @@ -268,7 +268,7 @@ def _format_context(self, results: list[RetrievalResult]) -> str: return "\n".join(lines) - async def terminate(self): + async def terminate(self) -> None: """终止所有知识库实例,关闭数据库连接""" for kb_id, kb_helper in self.kb_insts.items(): try: diff --git a/astrbot/core/knowledge_base/parsers/url_parser.py b/astrbot/core/knowledge_base/parsers/url_parser.py index f68e2e0c4..2867164a9 100644 --- a/astrbot/core/knowledge_base/parsers/url_parser.py +++ b/astrbot/core/knowledge_base/parsers/url_parser.py @@ -6,7 +6,7 @@ class URLExtractor: """URL 内容提取器,封装了 Tavily API 调用和密钥管理""" - def __init__(self, tavily_keys: list[str]): + def __init__(self, tavily_keys: list[str]) -> None: """ 初始化 URL 提取器 diff --git a/astrbot/core/knowledge_base/retrieval/manager.py b/astrbot/core/knowledge_base/retrieval/manager.py index 746406e90..a90cbef11 100644 --- a/astrbot/core/knowledge_base/retrieval/manager.py +++ b/astrbot/core/knowledge_base/retrieval/manager.py @@ -44,7 +44,7 @@ def __init__( sparse_retriever: SparseRetriever, rank_fusion: RankFusion, kb_db: KBSQLiteDatabase, - ): + ) -> None: """初始化检索管理器 Args: @@ -195,7 +195,7 @@ async def _dense_retrieve( query: str, kb_ids: list[str], kb_options: dict, - ): + ) -> list[Result]: """稠密检索 (向量相似度) 为每个知识库使用独立的向量数据库进行检索,然后合并结果。 diff --git a/astrbot/core/knowledge_base/retrieval/rank_fusion.py b/astrbot/core/knowledge_base/retrieval/rank_fusion.py index 26203f94b..40afd9748 100644 --- a/astrbot/core/knowledge_base/retrieval/rank_fusion.py +++ b/astrbot/core/knowledge_base/retrieval/rank_fusion.py @@ -31,7 +31,7 @@ class RankFusion: - 使用 Reciprocal Rank Fusion (RRF) 算法 """ - def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60): + def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60) -> None: """初始化结果融合器 Args: diff --git a/astrbot/core/knowledge_base/retrieval/sparse_retriever.py b/astrbot/core/knowledge_base/retrieval/sparse_retriever.py index ea5da1c9e..d453251d1 100644 --- a/astrbot/core/knowledge_base/retrieval/sparse_retriever.py +++ b/astrbot/core/knowledge_base/retrieval/sparse_retriever.py @@ -34,7 +34,7 @@ class SparseRetriever: - 使用 BM25 算法计算相关度 """ - def __init__(self, kb_db: KBSQLiteDatabase): + def __init__(self, kb_db: KBSQLiteDatabase) -> None: """初始化稀疏检索器 Args: diff --git a/astrbot/core/log.py b/astrbot/core/log.py index 806ebcebb..ba2ce0d65 100644 --- a/astrbot/core/log.py +++ b/astrbot/core/log.py @@ -87,7 +87,7 @@ class LogBroker: 发布-订阅模式 """ - def __init__(self): + def __init__(self) -> None: self.log_cache = deque(maxlen=CACHED_SIZE) # 环形缓冲区, 保存最近的日志 self.subscribers: list[Queue] = [] # 订阅者列表 @@ -102,7 +102,7 @@ def register(self) -> Queue: self.subscribers.append(q) return q - def unregister(self, q: Queue): + def unregister(self, q: Queue) -> None: """取消订阅 Args: @@ -111,7 +111,7 @@ def unregister(self, q: Queue): """ self.subscribers.remove(q) - def publish(self, log_entry: dict): + def publish(self, log_entry: dict) -> None: """发布新日志到所有订阅者, 使用非阻塞方式投递, 避免一个订阅者阻塞整个系统 Args: @@ -133,11 +133,11 @@ class LogQueueHandler(logging.Handler): 继承自 logging.Handler """ - def __init__(self, log_broker: LogBroker): + def __init__(self, log_broker: LogBroker) -> None: super().__init__() self.log_broker = log_broker - def emit(self, record): + def emit(self, record) -> None: """日志处理的入口方法, 接受一个日志记录, 转换为字符串后由 LogBroker 发布 这个方法会在每次日志记录时被调用 @@ -194,7 +194,7 @@ def GetLogger(cls, log_name: str = "default"): class PluginFilter(logging.Filter): """插件过滤器类, 用于标记日志来源是插件还是核心组件""" - def filter(self, record): + def filter(self, record) -> bool: record.plugin_tag = ( "[Plug]" if is_plugin_path(record.pathname) else "[Core]" ) @@ -206,7 +206,7 @@ class FileNameFilter(logging.Filter): """ # 获取这个文件和父文件夹的名字:. 并且去除 .py - def filter(self, record): + def filter(self, record) -> bool: dirname = os.path.dirname(record.pathname) record.filename = ( os.path.basename(dirname) @@ -219,7 +219,7 @@ class LevelNameFilter(logging.Filter): """短日志级别名称过滤器类, 用于将日志级别名称转换为四个字母的缩写""" # 添加短日志级别名称 - def filter(self, record): + def filter(self, record) -> bool: record.short_levelname = get_short_level_name(record.levelname) return True @@ -233,7 +233,7 @@ def filter(self, record): return logger @classmethod - def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker): + def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker) -> None: """设置队列处理器, 用于将日志消息发送到 LogBroker Args: diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 0e7b3bab6..4dfadfc2f 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -25,6 +25,7 @@ import base64 import json import os +import typing as T import uuid from enum import Enum @@ -66,10 +67,10 @@ class ComponentType(str, Enum): class BaseMessageComponent(BaseModel): type: ComponentType - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - def toDict(self): + def toDict(self) -> dict: data = {} for k, v in self.__dict__.items(): if k == "type" or v is None: @@ -89,13 +90,13 @@ class Plain(BaseMessageComponent): text: str convert: bool | None = True - def __init__(self, text: str, convert: bool = True, **_): + def __init__(self, text: str, convert: bool = True, **_: object) -> None: super().__init__(text=text, convert=convert, **_) def toDict(self): return {"type": "text", "data": {"text": self.text.strip()}} - async def to_dict(self): + async def to_dict(self) -> dict: return {"type": "text", "data": {"text": self.text}} @@ -103,7 +104,7 @@ class Face(BaseMessageComponent): type = ComponentType.Face id: int - def __init__(self, **_): + def __init__(self, **_: object) -> None: super().__init__(**_) @@ -118,7 +119,7 @@ class Record(BaseMessageComponent): # 额外 path: str | None - def __init__(self, file: str | None, **_): + def __init__(self, file: str | None, **_: object) -> None: for k in _: if k == "url": pass @@ -126,17 +127,17 @@ def __init__(self, file: str | None, **_): super().__init__(file=file, **_) @staticmethod - def fromFileSystem(path, **_): + def fromFileSystem(path: str, **_: object) -> "Record": return Record(file=f"file:///{os.path.abspath(path)}", path=path, **_) @staticmethod - def fromURL(url: str, **_): + def fromURL(url: str, **_: object) -> "Record": if url.startswith("http://") or url.startswith("https://"): return Record(file=url, **_) raise Exception("not a valid url") @staticmethod - def fromBase64(bs64_data: str, **_): + def fromBase64(bs64_data: str, **_: object) -> "Record": return Record(file=f"base64://{bs64_data}", **_) async def convert_to_file_path(self) -> str: @@ -221,15 +222,15 @@ class Video(BaseMessageComponent): # 额外 path: str | None = "" - def __init__(self, file: str, **_): + def __init__(self, file: str, **_: object) -> None: super().__init__(file=file, **_) @staticmethod - def fromFileSystem(path, **_): + def fromFileSystem(path: str, **_: object) -> "Video": return Video(file=f"file:///{os.path.abspath(path)}", path=path, **_) @staticmethod - def fromURL(url: str, **_): + def fromURL(url: str, **_: object) -> "Video": if url.startswith("http://") or url.startswith("https://"): return Video(file=url, **_) raise Exception("not a valid url") @@ -255,7 +256,7 @@ async def convert_to_file_path(self) -> str: return os.path.abspath(url) raise Exception(f"not a valid file: {url}") - async def register_to_file_service(self): + async def register_to_file_service(self) -> str: """将视频注册到文件服务。 Returns: @@ -278,7 +279,7 @@ async def register_to_file_service(self): return f"{callback_host}/api/file/{token}" - async def to_dict(self): + async def to_dict(self) -> dict: """需要和 toDict 区分开,toDict 是同步方法""" url_or_path = self.file if url_or_path.startswith("http"): @@ -303,10 +304,10 @@ class At(BaseMessageComponent): qq: int | str # 此处str为all时代表所有人 name: str | None = "" - def __init__(self, **_): + def __init__(self, **_: object) -> None: super().__init__(**_) - def toDict(self): + def toDict(self) -> dict: return { "type": "at", "data": {"qq": str(self.qq)}, @@ -316,28 +317,28 @@ def toDict(self): class AtAll(At): qq: str = "all" - def __init__(self, **_): + def __init__(self, **_: object) -> None: super().__init__(**_) class RPS(BaseMessageComponent): # TODO type = ComponentType.RPS - def __init__(self, **_): + def __init__(self, **_: object) -> None: super().__init__(**_) class Dice(BaseMessageComponent): # TODO type = ComponentType.Dice - def __init__(self, **_): + def __init__(self, **_: object) -> None: super().__init__(**_) class Shake(BaseMessageComponent): # TODO type = ComponentType.Shake - def __init__(self, **_): + def __init__(self, **_: object) -> None: super().__init__(**_) @@ -348,7 +349,7 @@ class Share(BaseMessageComponent): content: str | None = "" image: str | None = "" - def __init__(self, **_): + def __init__(self, **_: object) -> None: super().__init__(**_) @@ -357,7 +358,7 @@ class Contact(BaseMessageComponent): # TODO _type: str # type 字段冲突 id: int | None = 0 - def __init__(self, **_): + def __init__(self, **_: object) -> None: super().__init__(**_) @@ -368,7 +369,7 @@ class Location(BaseMessageComponent): # TODO title: str | None = "" content: str | None = "" - def __init__(self, **_): + def __init__(self, **_: object) -> None: super().__init__(**_) @@ -382,7 +383,7 @@ class Music(BaseMessageComponent): content: str | None = "" image: str | None = "" - def __init__(self, **_): + def __init__(self, **_: object) -> None: # for k in _.keys(): # if k == "_type" and _[k] not in ["qq", "163", "xm", "custom"]: # logger.warn(f"Protocol: {k}={_[k]} doesn't match values") @@ -402,29 +403,29 @@ class Image(BaseMessageComponent): path: str | None = "" file_unique: str | None = "" # 某些平台可能有图片缓存的唯一标识 - def __init__(self, file: str | None, **_): + def __init__(self, file: str | None, **_: object) -> None: super().__init__(file=file, **_) @staticmethod - def fromURL(url: str, **_): + def fromURL(url: str, **_: object) -> "Image": if url.startswith("http://") or url.startswith("https://"): return Image(file=url, **_) raise Exception("not a valid url") @staticmethod - def fromFileSystem(path, **_): + def fromFileSystem(path: str, **_: object) -> "Image": return Image(file=f"file:///{os.path.abspath(path)}", path=path, **_) @staticmethod - def fromBase64(base64: str, **_): + def fromBase64(base64: str, **_: object) -> "Image": return Image(f"base64://{base64}", **_) @staticmethod - def fromBytes(byte: bytes): + def fromBytes(byte: bytes) -> "Image": return Image.fromBase64(base64.b64encode(byte).decode()) @staticmethod - def fromIO(IO): + def fromIO(IO: T.BinaryIO) -> "Image": return Image.fromBytes(IO.read()) async def convert_to_file_path(self) -> str: @@ -525,16 +526,16 @@ class Reply(BaseMessageComponent): seq: int | None = 0 """deprecated""" - def __init__(self, **_): + def __init__(self, **_: object) -> None: super().__init__(**_) class Poke(BaseMessageComponent): - type: str = ComponentType.Poke + type = ComponentType.Poke id: int | None = 0 qq: int | None = 0 - def __init__(self, type: str, **_): + def __init__(self, type: str, **_: object) -> None: type = f"Poke:{type}" super().__init__(type=type, **_) @@ -543,7 +544,7 @@ class Forward(BaseMessageComponent): type = ComponentType.Forward id: str - def __init__(self, **_): + def __init__(self, **_: object) -> None: super().__init__(**_) @@ -558,13 +559,13 @@ class Node(BaseMessageComponent): seq: str | list | None = "" # 忽略 time: int | None = 0 # 忽略 - def __init__(self, content: list[BaseMessageComponent], **_): + def __init__(self, content: list[BaseMessageComponent], **_: object) -> None: if isinstance(content, Node): # back content = [content] super().__init__(content=content, **_) - async def to_dict(self): + async def to_dict(self) -> dict: data_content = [] for comp in self.content: if isinstance(comp, (Image, Record)): @@ -605,10 +606,10 @@ class Nodes(BaseMessageComponent): type = ComponentType.Nodes nodes: list[Node] - def __init__(self, nodes: list[Node], **_): + def __init__(self, nodes: list[Node], **_: object) -> None: super().__init__(nodes=nodes, **_) - def toDict(self): + def toDict(self) -> dict: """Deprecated. Use to_dict instead""" ret = { "messages": [], @@ -632,7 +633,7 @@ class Json(BaseMessageComponent): data: str | dict resid: int | None = 0 - def __init__(self, data, **_): + def __init__(self, data: str | dict, **_: object) -> None: if isinstance(data, dict): data = json.dumps(data) super().__init__(data=data, **_) @@ -651,7 +652,7 @@ class File(BaseMessageComponent): file_: str | None = "" # 本地路径 url: str | None = "" # url - def __init__(self, name: str, file: str = "", url: str = ""): + def __init__(self, name: str, file: str = "", url: str = "") -> None: """文件消息段。""" super().__init__(name=name, file_=file, url=url) @@ -687,7 +688,7 @@ def file(self) -> str: return "" @file.setter - def file(self, value: str): + def file(self, value: str) -> None: """向前兼容, 设置file属性, 传入的参数可能是文件路径或URL Args: @@ -722,7 +723,7 @@ async def get_file(self, allow_return_url: bool = False) -> str: return "" - async def _download_file(self): + async def _download_file(self) -> None: """下载文件""" if not self.url: raise ValueError("Download failed: No URL provided in File component.") @@ -737,7 +738,7 @@ async def _download_file(self): await download_file(self.url, file_path) self.file_ = os.path.abspath(file_path) - async def register_to_file_service(self): + async def register_to_file_service(self) -> str: """将文件注册到文件服务。 Returns: @@ -760,7 +761,7 @@ async def register_to_file_service(self): return f"{callback_host}/api/file/{token}" - async def to_dict(self): + async def to_dict(self) -> dict: """需要和 toDict 区分开,toDict 是同步方法""" url_or_path = await self.get_file(allow_return_url=True) if url_or_path.startswith("http"): @@ -787,7 +788,7 @@ class WechatEmoji(BaseMessageComponent): md5_len: int | None = 0 cdnurl: str | None = "" - def __init__(self, **_): + def __init__(self, **_: object) -> None: super().__init__(**_) diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py index ed4e25f43..2e1527e27 100644 --- a/astrbot/core/message/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -2,7 +2,7 @@ from collections.abc import AsyncGenerator from dataclasses import dataclass, field -from typing_extensions import deprecated +from typing_extensions import Self, deprecated from astrbot.core.message.components import ( At, @@ -29,7 +29,7 @@ class MessageChain: type: str | None = None """消息链承载的消息的类型。可选,用于让消息平台区分不同业务场景的消息链。""" - def message(self, message: str): + def message(self, message: str) -> Self: """添加一条文本消息到消息链 `chain` 中。 Example: @@ -40,7 +40,7 @@ def message(self, message: str): self.chain.append(Plain(message)) return self - def at(self, name: str, qq: str | int): + def at(self, name: str, qq: str | int) -> Self: """添加一条 At 消息到消息链 `chain` 中。 Example: @@ -51,7 +51,7 @@ def at(self, name: str, qq: str | int): self.chain.append(At(name=name, qq=qq)) return self - def at_all(self): + def at_all(self) -> Self: """添加一条 AtAll 消息到消息链 `chain` 中。 Example: @@ -63,7 +63,7 @@ def at_all(self): return self @deprecated("请使用 message 方法代替。") - def error(self, message: str): + def error(self, message: str) -> Self: """添加一条错误消息到消息链 `chain` 中 Example: @@ -73,7 +73,7 @@ def error(self, message: str): self.chain.append(Plain(message)) return self - def url_image(self, url: str): + def url_image(self, url: str) -> Self: """添加一条图片消息(https 链接)到消息链 `chain` 中。 Note: @@ -86,7 +86,7 @@ def url_image(self, url: str): self.chain.append(Image.fromURL(url)) return self - def file_image(self, path: str): + def file_image(self, path: str) -> Self: """添加一条图片消息(本地文件路径)到消息链 `chain` 中。 Note: @@ -98,7 +98,7 @@ def file_image(self, path: str): self.chain.append(Image.fromFileSystem(path)) return self - def base64_image(self, base64_str: str): + def base64_image(self, base64_str: str) -> Self: """添加一条图片消息(base64 编码字符串)到消息链 `chain` 中。 Example: @@ -107,7 +107,7 @@ def base64_image(self, base64_str: str): self.chain.append(Image.fromBase64(base64_str)) return self - def use_t2i(self, use_t2i: bool): + def use_t2i(self, use_t2i: bool) -> Self: """设置是否使用文本转图片服务。 Args: @@ -121,7 +121,7 @@ def get_plain_text(self) -> str: """获取纯文本消息。这个方法将获取 chain 中所有 Plain 组件的文本并拼接成一条消息。空格分隔。""" return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)]) - def squash_plain(self): + def squash_plain(self) -> Self | None: """将消息链中的所有 Plain 消息段聚合到第一个 Plain 消息段中。""" if not self.chain: return None @@ -195,12 +195,12 @@ class MessageEventResult(MessageChain): async_stream: AsyncGenerator | None = None """异步流""" - def stop_event(self) -> "MessageEventResult": + def stop_event(self) -> Self: """终止事件传播。""" self.result_type = EventResultType.STOP return self - def continue_event(self) -> "MessageEventResult": + def continue_event(self) -> Self: """继续事件传播。""" self.result_type = EventResultType.CONTINUE return self @@ -209,12 +209,12 @@ def is_stopped(self) -> bool: """是否终止事件传播。""" return self.result_type == EventResultType.STOP - def set_async_stream(self, stream: AsyncGenerator) -> "MessageEventResult": + def set_async_stream(self, stream: AsyncGenerator) -> Self: """设置异步流。""" self.async_stream = stream return self - def set_result_content_type(self, typ: ResultContentType) -> "MessageEventResult": + def set_result_content_type(self, typ: ResultContentType) -> Self: """设置事件处理的结果类型。 Args: diff --git a/astrbot/core/persona_mgr.py b/astrbot/core/persona_mgr.py index b2d2c6be1..8b32a1d0e 100644 --- a/astrbot/core/persona_mgr.py +++ b/astrbot/core/persona_mgr.py @@ -16,7 +16,7 @@ class PersonaManager: - def __init__(self, db_helper: BaseDatabase, acm: AstrBotConfigManager): + def __init__(self, db_helper: BaseDatabase, acm: AstrBotConfigManager) -> None: self.db = db_helper self.acm = acm default_ps = acm.default_conf.get("provider_settings", {}) @@ -28,7 +28,7 @@ def __init__(self, db_helper: BaseDatabase, acm: AstrBotConfigManager): self.selected_default_persona_v3: Personality | None = None self.persona_v3_config: list[dict] = [] - async def initialize(self): + async def initialize(self) -> None: self.personas = await self.get_all_personas() self.get_v3_persona_data() logger.info(f"已加载 {len(self.personas)} 个人格。") @@ -57,7 +57,7 @@ async def get_default_persona_v3( except Exception: return DEFAULT_PERSONALITY - async def delete_persona(self, persona_id: str): + async def delete_persona(self, persona_id: str) -> None: """删除指定 persona""" if not await self.db.get_persona_by_id(persona_id): raise ValueError(f"Persona with ID {persona_id} does not exist.") diff --git a/astrbot/core/pipeline/content_safety_check/stage.py b/astrbot/core/pipeline/content_safety_check/stage.py index b089c48e0..19037eb08 100644 --- a/astrbot/core/pipeline/content_safety_check/stage.py +++ b/astrbot/core/pipeline/content_safety_check/stage.py @@ -16,7 +16,7 @@ class ContentSafetyCheckStage(Stage): 当前只会检查文本的。 """ - async def initialize(self, ctx: PipelineContext): + async def initialize(self, ctx: PipelineContext) -> None: config = ctx.astrbot_config["content_safety"] self.strategy_selector = StrategySelector(config) diff --git a/astrbot/core/pipeline/context_utils.py b/astrbot/core/pipeline/context_utils.py index 1f5ba43a0..49fe3be3a 100644 --- a/astrbot/core/pipeline/context_utils.py +++ b/astrbot/core/pipeline/context_utils.py @@ -12,8 +12,8 @@ async def call_handler( event: AstrMessageEvent, handler: T.Callable[..., T.Awaitable[T.Any] | T.AsyncGenerator[T.Any, None]], - *args, - **kwargs, + *args: object, + **kwargs: object, ) -> T.AsyncGenerator[T.Any, None]: """执行事件处理函数并处理其返回结果 @@ -75,8 +75,8 @@ async def call_handler( async def call_event_hook( event: AstrMessageEvent, hook_type: EventType, - *args, - **kwargs, + *args: object, + **kwargs: object, ) -> bool: """调用事件钩子函数 diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 7e3305f55..9178752e5 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -98,7 +98,7 @@ async def _apply_kb( self, event: AstrMessageEvent, req: ProviderRequest, - ): + ) -> None: """Apply knowledge base context to the provider request""" if not self.kb_agentic_mode: if req.prompt is None: @@ -126,7 +126,7 @@ async def _apply_file_extract( self, event: AstrMessageEvent, req: ProviderRequest, - ): + ) -> None: """Apply file extract to the provider request""" file_paths = [] file_names = [] @@ -198,7 +198,7 @@ def _modalities_fix( self, provider: Provider, req: ProviderRequest, - ): + ) -> None: """检查提供商的模态能力,清理请求中的不支持内容""" if req.image_urls: provider_cfg = provider.provider_config.get("modalities", ["image"]) @@ -218,7 +218,7 @@ def _plugin_tool_fix( self, event: AstrMessageEvent, req: ProviderRequest, - ): + ) -> None: """根据事件中的插件设置,过滤请求中的工具列表""" if event.plugins_name is not None and req.func_tool: new_tool_set = ToolSet() @@ -238,7 +238,7 @@ async def _handle_webchat( event: AstrMessageEvent, req: ProviderRequest, prov: Provider, - ): + ) -> None: """处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title""" if not req.conversation: return @@ -294,7 +294,7 @@ async def _save_to_history( event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse | None, - ): + ) -> None: if ( not req or not req.conversation diff --git a/astrbot/core/pipeline/process_stage/utils.py b/astrbot/core/pipeline/process_stage/utils.py index 24e052e1e..d6ad8f2c1 100644 --- a/astrbot/core/pipeline/process_stage/utils.py +++ b/astrbot/core/pipeline/process_stage/utils.py @@ -31,13 +31,17 @@ class KnowledgeBaseQueryTool(FunctionTool[AstrAgentContext]): ) async def call( - self, context: ContextWrapper[AstrAgentContext], **kwargs + self, context: ContextWrapper[AstrAgentContext], **kwargs: object ) -> ToolExecResult: query = kwargs.get("query", "") if not query: return "error: Query parameter is empty." + + # 显式转换为 str,解决类型检查报错 "object cannot be assigned to str" + query_str = str(query) + result = await retrieve_knowledge_base( - query=kwargs.get("query", ""), + query=query_str, umo=context.context.event.unified_msg_origin, context=context.context.context, ) diff --git a/astrbot/core/pipeline/rate_limit_check/stage.py b/astrbot/core/pipeline/rate_limit_check/stage.py index 64e21dd7e..392bceff3 100644 --- a/astrbot/core/pipeline/rate_limit_check/stage.py +++ b/astrbot/core/pipeline/rate_limit_check/stage.py @@ -19,7 +19,7 @@ class RateLimitStage(Stage): 如果触发限流,将 stall 流水线,直到下一个时间窗口来临时自动唤醒。 """ - def __init__(self): + def __init__(self) -> None: # 存储每个会话的请求时间队列 self.event_timestamps: defaultdict[str, deque[datetime]] = defaultdict(deque) # 为每个会话设置一个锁,避免并发冲突 diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index bfbcaf33a..97fd184bb 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -35,7 +35,7 @@ class RespondStage(Stage): Comp.WechatEmoji: lambda comp: comp.md5 is not None, # 微信表情 } - async def initialize(self, ctx: PipelineContext): + async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx self.config = ctx.astrbot_config self.platform_settings: dict = self.config.get("platform_settings", {}) @@ -91,7 +91,7 @@ async def _calc_comp_interval(self, comp: BaseMessageComponent) -> float: # random return random.uniform(self.interval[0], self.interval[1]) - async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]): + async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]) -> bool: """检查消息链是否为空 Args: @@ -136,7 +136,7 @@ def _extract_comp( raw_chain: list[BaseMessageComponent], extract_types: set[ComponentType], modify_raw_chain: bool = True, - ): + ) -> list[BaseMessageComponent]: extracted = [] if modify_raw_chain: remaining = [] diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index 7647ef022..8b6ff630d 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -3,6 +3,7 @@ import time import traceback from collections.abc import AsyncGenerator +from typing import cast from astrbot.core import file_token_service, html_renderer, logger from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply @@ -20,7 +21,9 @@ @register_stage class ResultDecorateStage(Stage): - async def initialize(self, ctx: PipelineContext): + content_safe_check_stage: ContentSafetyCheckStage | None + + async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx self.reply_prefix = ctx.astrbot_config["platform_settings"]["reply_prefix"] self.reply_with_mention = ctx.astrbot_config["platform_settings"][ @@ -95,7 +98,9 @@ async def initialize(self, ctx: PipelineContext): if self.content_safe_check_reply: for stage_cls in registered_stages: if stage_cls.__name__ == "ContentSafetyCheckStage": - self.content_safe_check_stage = stage_cls() + self.content_safe_check_stage = cast( + ContentSafetyCheckStage, stage_cls() + ) await self.content_safe_check_stage.initialize(ctx) def _split_text_by_words(self, text: str) -> list[str]: @@ -148,7 +153,7 @@ async def process( if isinstance(self.content_safe_check_stage, ContentSafetyCheckStage): async for _ in self.content_safe_check_stage.process( event, - check_text=text, + check_text=text, # type:ignore ): yield diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py index 5fb3034f5..dea31dcb6 100644 --- a/astrbot/core/pipeline/scheduler.py +++ b/astrbot/core/pipeline/scheduler.py @@ -15,21 +15,23 @@ class PipelineScheduler: """管道调度器,负责调度各个阶段的执行""" - def __init__(self, context: PipelineContext): + def __init__(self, context: PipelineContext) -> None: registered_stages.sort( key=lambda x: STAGES_ORDER.index(x.__name__), ) # 按照顺序排序 self.ctx = context # 上下文对象 self.stages = [] # 存储阶段实例 - async def initialize(self): + async def initialize(self) -> None: """初始化管道调度器时, 初始化所有阶段""" for stage_cls in registered_stages: stage_instance = stage_cls() # 创建实例 await stage_instance.initialize(self.ctx) self.stages.append(stage_instance) - async def _process_stages(self, event: AstrMessageEvent, from_stage=0): + async def _process_stages( + self, event: AstrMessageEvent, from_stage: int = 0 + ) -> None: """依次执行各个阶段 Args: @@ -72,7 +74,7 @@ async def _process_stages(self, event: AstrMessageEvent, from_stage=0): logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。") break - async def execute(self, event: AstrMessageEvent): + async def execute(self, event: AstrMessageEvent) -> None: """执行 pipeline Args: diff --git a/astrbot/core/pipeline/stage.py b/astrbot/core/pipeline/stage.py index 74aca4ef1..e3a91e1be 100644 --- a/astrbot/core/pipeline/stage.py +++ b/astrbot/core/pipeline/stage.py @@ -10,7 +10,7 @@ registered_stages: list[type[Stage]] = [] # 维护了所有已注册的 Stage 实现类类型 -def register_stage(cls): +def register_stage(cls: type[Stage]) -> type[Stage]: """一个简单的装饰器,用于注册 pipeline 包下的 Stage 实现类""" registered_stages.append(cls) return cls @@ -33,7 +33,7 @@ async def initialize(self, ctx: PipelineContext) -> None: async def process( self, event: AstrMessageEvent, - ) -> None | AsyncGenerator[None, None]: + ) -> None | AsyncGenerator[None]: """处理事件 Args: diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index f6eda07a9..c1aa2e291 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -35,7 +35,7 @@ def __init__( message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, - ): + ) -> None: self.message_str = message_str """纯文本的消息""" self.message_obj = message_obj @@ -157,7 +157,7 @@ def get_sender_name(self) -> str: return self.message_obj.sender.nickname return "" - def set_extra(self, key, value): + def set_extra(self, key, value) -> None: """设置额外的信息。""" self._extras[key] = value @@ -167,7 +167,7 @@ def get_extra(self, key: str | None = None, default=None) -> Any: return self._extras return self._extras.get(key, default) - def clear_extra(self): + def clear_extra(self) -> None: """清除额外的信息。""" logger.info(f"清除 {self.get_platform_name()} 的额外信息: {self._extras}") self._extras.clear() @@ -200,7 +200,7 @@ async def send_streaming( self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False, - ): + ) -> None: """发送流式消息到消息平台,使用异步生成器。 目前仅支持: telegram,qq official 私聊。 Fallback仅支持 aiocqhttp。 @@ -210,13 +210,13 @@ async def send_streaming( ) self._has_send_oper = True - async def _pre_send(self): + async def _pre_send(self) -> None: """调度器会在执行 send() 前调用该方法 deprecated in v3.5.18""" - async def _post_send(self): + async def _post_send(self) -> None: """调度器会在执行 send() 后调用该方法 deprecated in v3.5.18""" - def set_result(self, result: MessageEventResult | str): + def set_result(self, result: MessageEventResult | str) -> None: """设置消息事件的结果。 Note: @@ -245,14 +245,14 @@ async def check_count(self, event: AstrMessageEvent): result.chain = [] self._result = result - def stop_event(self): + def stop_event(self) -> None: """终止事件传播。""" if self._result is None: self.set_result(MessageEventResult().stop_event()) else: self._result.stop_event() - def continue_event(self): + def continue_event(self) -> None: """继续事件传播。""" if self._result is None: self.set_result(MessageEventResult().continue_event()) @@ -265,7 +265,7 @@ def is_stopped(self) -> bool: return False # 默认是继续传播 return self._result.is_stopped() - def should_call_llm(self, call_llm: bool): + def should_call_llm(self, call_llm: bool) -> None: """是否在此消息事件中禁止默认的 LLM 请求。 只会阻止 AstrBot 默认的 LLM 请求链路,不会阻止插件中的 LLM 请求。 @@ -276,7 +276,7 @@ def get_result(self) -> MessageEventResult | None: """获取消息事件的结果。""" return self._result - def clear_result(self): + def clear_result(self) -> None: """清除消息事件的结果。""" self._result = None @@ -368,7 +368,7 @@ def request_llm( """平台适配器""" - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: """发送消息到消息平台。 Args: @@ -387,7 +387,7 @@ async def send(self, message: MessageChain): ) self._has_send_oper = True - async def react(self, emoji: str): + async def react(self, emoji: str) -> None: """对消息添加表情回应。 默认实现为发送一条包含该表情的消息。 diff --git a/astrbot/core/platform/astrbot_message.py b/astrbot/core/platform/astrbot_message.py index 253963322..3db53fd48 100644 --- a/astrbot/core/platform/astrbot_message.py +++ b/astrbot/core/platform/astrbot_message.py @@ -11,7 +11,7 @@ class MessageMember: user_id: str # 发送者id nickname: str | None = None - def __str__(self): + def __str__(self) -> str: # 使用 f-string 来构建返回的字符串表示形式 return ( f"User ID: {self.user_id}," @@ -34,7 +34,7 @@ class Group: members: list[MessageMember] | None = None """所有群成员""" - def __str__(self): + def __str__(self) -> str: # 使用 f-string 来构建返回的字符串表示形式 return ( f"Group ID: {self.group_id}\n" @@ -78,7 +78,7 @@ def group_id(self) -> str: return "" @group_id.setter - def group_id(self, value: str | None): + def group_id(self, value: str | None) -> None: """设置 group_id""" if value: if self.group: diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index f4313f642..7af97745f 100644 --- a/astrbot/core/platform/manager.py +++ b/astrbot/core/platform/manager.py @@ -13,7 +13,7 @@ class PlatformManager: - def __init__(self, config: AstrBotConfig, event_queue: Queue): + def __init__(self, config: AstrBotConfig, event_queue: Queue) -> None: self.platform_insts: list[Platform] = [] """加载的 Platform 的实例""" @@ -27,7 +27,7 @@ def __init__(self, config: AstrBotConfig, event_queue: Queue): 约定整个项目中对 unique_session 的引用都从 default 的配置中获取""" self.event_queue = event_queue - async def initialize(self): + async def initialize(self) -> None: """初始化所有平台适配器""" for platform in self.platforms_config: try: @@ -47,7 +47,7 @@ async def initialize(self): ), ) - async def load_platform(self, platform_config: dict): + async def load_platform(self, platform_config: dict) -> None: """实例化一个平台""" # 动态导入 try: @@ -153,7 +153,9 @@ async def load_platform(self, platform_config: dict): except Exception: logger.error(traceback.format_exc()) - async def _task_wrapper(self, task: asyncio.Task, platform: Platform | None = None): + async def _task_wrapper( + self, task: asyncio.Task, platform: Platform | None = None + ) -> None: # 设置平台状态为运行中 if platform: platform.status = PlatformStatus.RUNNING @@ -175,7 +177,7 @@ async def _task_wrapper(self, task: asyncio.Task, platform: Platform | None = No if platform: platform.record_error(error_msg, tb_str) - async def reload(self, platform_config: dict): + async def reload(self, platform_config: dict) -> None: await self.terminate_platform(platform_config["id"]) if platform_config["enable"]: await self.load_platform(platform_config) @@ -186,7 +188,7 @@ async def reload(self, platform_config: dict): if key not in config_ids: await self.terminate_platform(key) - async def terminate_platform(self, platform_id: str): + async def terminate_platform(self, platform_id: str) -> None: if platform_id in self._inst_map: logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...") @@ -208,7 +210,7 @@ async def terminate_platform(self, platform_id: str): if getattr(inst, "terminate", None): await inst.terminate() - async def terminate(self): + async def terminate(self) -> None: for inst in self.platform_insts: if getattr(inst, "terminate", None): await inst.terminate() diff --git a/astrbot/core/platform/message_session.py b/astrbot/core/platform/message_session.py index bca5300b8..87c397af4 100644 --- a/astrbot/core/platform/message_session.py +++ b/astrbot/core/platform/message_session.py @@ -15,7 +15,7 @@ class MessageSession: session_id: str platform_id: str | None = None - def __str__(self): + def __str__(self) -> str: return f"{self.platform_id}:{self.message_type.value}:{self.session_id}" def __post_init__(self): diff --git a/astrbot/core/platform/platform.py b/astrbot/core/platform/platform.py index c2e55fb63..3bcb6fdde 100644 --- a/astrbot/core/platform/platform.py +++ b/astrbot/core/platform/platform.py @@ -34,7 +34,7 @@ class PlatformError: class Platform(abc.ABC): - def __init__(self, config: dict, event_queue: Queue): + def __init__(self, config: dict, event_queue: Queue) -> None: super().__init__() # 平台配置 self.config = config @@ -53,7 +53,7 @@ def status(self) -> PlatformStatus: return self._status @status.setter - def status(self, value: PlatformStatus): + def status(self, value: PlatformStatus) -> None: """设置平台运行状态""" self._status = value if value == PlatformStatus.RUNNING and self._started_at is None: @@ -69,12 +69,12 @@ def last_error(self) -> PlatformError | None: """获取最近的错误""" return self._errors[-1] if self._errors else None - def record_error(self, message: str, traceback_str: str | None = None): + def record_error(self, message: str, traceback_str: str | None = None) -> None: """记录一个错误""" self._errors.append(PlatformError(message=message, traceback=traceback_str)) self._status = PlatformStatus.ERROR - def clear_errors(self): + def clear_errors(self) -> None: """清除错误记录""" self._errors.clear() if self._status == PlatformStatus.ERROR: @@ -112,7 +112,7 @@ def run(self) -> Coroutine[Any, Any, None]: """得到一个平台的运行实例,需要返回一个协程对象。""" raise NotImplementedError - async def terminate(self): + async def terminate(self) -> None: """终止一个平台的运行实例。""" @abc.abstractmethod @@ -131,11 +131,11 @@ async def send_by_session( """ await Metric.upload(msg_event_tick=1, adapter_name=self.meta().name) - def commit_event(self, event: AstrMessageEvent): + def commit_event(self, event: AstrMessageEvent) -> None: """提交一个事件到事件队列。""" self._event_queue.put_nowait(event) - def get_client(self): + def get_client(self) -> None: """获取平台的客户端对象。""" async def webhook_callback(self, request: Any) -> Any: diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index 293b462d3..1ad26c265 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -26,7 +26,7 @@ def __init__( platform_meta, session_id, bot: CQHttp, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.bot = bot @@ -72,7 +72,7 @@ async def _dispatch_send( is_group: bool, session_id: str | None, messages: list[dict], - ): + ) -> None: # session_id 必须是纯数字字符串 session_id_int = ( int(session_id) if session_id and session_id.isdigit() else None @@ -97,7 +97,7 @@ async def send_message( event: Event | None = None, is_group: bool = False, session_id: str | None = None, - ): + ) -> None: """发送消息至 QQ 协议端(aiocqhttp)。 Args: @@ -143,7 +143,7 @@ async def send_message( await cls._dispatch_send(bot, event, is_group, session_id, messages) await asyncio.sleep(0.5) - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: """发送消息""" event = getattr(self.message_obj, "raw_message", None) diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index 52dd21d56..b8ac5ceb6 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -62,38 +62,38 @@ def __init__( ) @self.bot.on_request() - async def request(event: Event): + async def request(event: Event) -> None: abm = await self.convert_message(event) if abm: await self.handle_msg(abm) @self.bot.on_notice() - async def notice(event: Event): + async def notice(event: Event) -> None: abm = await self.convert_message(event) if abm: await self.handle_msg(abm) @self.bot.on_message("group") - async def group(event: Event): + async def group(event: Event) -> None: abm = await self.convert_message(event) if abm: await self.handle_msg(abm) @self.bot.on_message("private") - async def private(event: Event): + async def private(event: Event) -> None: abm = await self.convert_message(event) if abm: await self.handle_msg(abm) @self.bot.on_websocket_connection - def on_websocket_connection(_): + def on_websocket_connection(_) -> None: logger.info("aiocqhttp(OneBot v11) 适配器已连接。") async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: is_group = session.message_type == MessageType.GROUP_MESSAGE if is_group: session_id = session.session_id.split("_")[-1] @@ -416,17 +416,17 @@ def run(self) -> Awaitable[Any]: self.shutdown_event = asyncio.Event() return coro - async def terminate(self): + async def terminate(self) -> None: self.shutdown_event.set() - async def shutdown_trigger_placeholder(self): + async def shutdown_trigger_placeholder(self) -> None: await self.shutdown_event.wait() logger.info("aiocqhttp 适配器已被关闭") def meta(self) -> PlatformMetadata: return self.metadata - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: message_event = AiocqhttpMessageEvent( message_str=message.message_str, message_obj=message, diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py index 6f9e25df4..25de33808 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py @@ -2,7 +2,7 @@ import os import threading import uuid -from typing import cast +from typing import NoReturn, cast import aiohttp import dingtalk_stream @@ -90,7 +90,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> NoReturn: raise NotImplementedError("钉钉机器人适配器不支持 send_by_session") def meta(self) -> PlatformMetadata: @@ -222,7 +222,7 @@ async def get_access_token(self) -> str: return "" return (await resp.json())["data"]["accessToken"] - async def handle_msg(self, abm: AstrBotMessage): + async def handle_msg(self, abm: AstrBotMessage) -> None: event = DingtalkMessageEvent( message_str=abm.message_str, message_obj=abm, @@ -233,10 +233,10 @@ async def handle_msg(self, abm: AstrBotMessage): self._event_queue.put_nowait(event) - async def run(self): + async def run(self) -> None: # await self.client_.start() # 钉钉的 SDK 并没有实现真正的异步,start() 里面有堵塞方法。 - def start_client(loop: asyncio.AbstractEventLoop): + def start_client(loop: asyncio.AbstractEventLoop) -> None: try: self._shutdown_event = threading.Event() task = loop.create_task(self.client_.start()) @@ -252,8 +252,8 @@ def start_client(loop: asyncio.AbstractEventLoop): loop = asyncio.get_event_loop() await loop.run_in_executor(None, start_client, loop) - async def terminate(self): - def monkey_patch_close(): + async def terminate(self) -> None: + def monkey_patch_close() -> NoReturn: raise KeyboardInterrupt("Graceful shutdown") if self.client_.websocket is not None: diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py index d520189d8..db027125a 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py @@ -16,7 +16,7 @@ def __init__( platform_meta, session_id, client: dingtalk_stream.ChatbotHandler, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client @@ -24,7 +24,7 @@ async def send_with_client( self, client: dingtalk_stream.ChatbotHandler, message: MessageChain, - ): + ) -> None: for segment in message.chain: if isinstance(segment, Comp.Plain): segment.text = segment.text.strip() @@ -64,7 +64,7 @@ async def send_with_client( logger.warning(f"钉钉图片处理失败: {e}, 跳过图片发送") continue - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: await self.send_with_client(self.client, message) await super().send(message) diff --git a/astrbot/core/platform/sources/discord/client.py b/astrbot/core/platform/sources/discord/client.py index ac0610f2a..ebd32c471 100644 --- a/astrbot/core/platform/sources/discord/client.py +++ b/astrbot/core/platform/sources/discord/client.py @@ -15,7 +15,7 @@ class DiscordBotClient(discord.Bot): """Discord客户端封装""" - def __init__(self, token: str, proxy: str | None = None): + def __init__(self, token: str, proxy: str | None = None) -> None: self.token = token self.proxy = proxy @@ -32,7 +32,7 @@ def __init__(self, token: str, proxy: str | None = None): self.on_ready_once_callback: Callable[[], Awaitable[None]] | None = None self._ready_once_fired = False - async def on_ready(self): + async def on_ready(self) -> None: """当机器人成功连接并准备就绪时触发""" if self.user is None: logger.error("[Discord] 客户端未正确加载用户信息 (self.user is None)") @@ -93,7 +93,7 @@ def _create_interaction_data(self, interaction: discord.Interaction) -> dict: "type": "interaction", } - async def on_message(self, message: discord.Message): + async def on_message(self, message: discord.Message) -> None: """当接收到消息时触发""" if message.author.bot: return @@ -130,12 +130,12 @@ def _extract_interaction_content(self, interaction: discord.Interaction) -> str: return str(interaction_data) - async def start_polling(self): + async def start_polling(self) -> None: """开始轮询消息,这是个阻塞方法""" await self.start(self.token) @override - async def close(self): + async def close(self) -> None: """关闭客户端""" if not self.is_closed(): await super().close() diff --git a/astrbot/core/platform/sources/discord/components.py b/astrbot/core/platform/sources/discord/components.py index f875652a0..433509f5e 100644 --- a/astrbot/core/platform/sources/discord/components.py +++ b/astrbot/core/platform/sources/discord/components.py @@ -19,7 +19,7 @@ def __init__( image: str | None = None, footer: str | None = None, fields: list[dict] | None = None, - ): + ) -> None: self.title = title self.description = description self.color = color @@ -71,7 +71,7 @@ def __init__( emoji: str | None = None, url: str | None = None, disabled: bool = False, - ): + ) -> None: self.label = label self.custom_id = custom_id self.style = style @@ -85,7 +85,7 @@ class DiscordReference(BaseMessageComponent): type: str = "discord_reference" - def __init__(self, message_id: str, channel_id: str): + def __init__(self, message_id: str, channel_id: str) -> None: self.message_id = message_id self.channel_id = channel_id @@ -99,7 +99,7 @@ def __init__( self, components: list[BaseMessageComponent] | None = None, timeout: float | None = None, - ): + ) -> None: self.components = components or [] self.timeout = timeout diff --git a/astrbot/core/platform/sources/discord/discord_platform_adapter.py b/astrbot/core/platform/sources/discord/discord_platform_adapter.py index 50aa0fe6f..ce25da419 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_adapter.py +++ b/astrbot/core/platform/sources/discord/discord_platform_adapter.py @@ -60,7 +60,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: """通过会话发送消息""" if self.client.user is None: logger.error( @@ -122,11 +122,11 @@ def meta(self) -> PlatformMetadata: ) @override - async def run(self): + async def run(self) -> None: """主要运行逻辑""" # 初始化回调函数 - async def on_received(message_data): + async def on_received(message_data) -> None: logger.debug(f"[Discord] 收到消息: {message_data}") if self.client_self_id is None: self.client_self_id = message_data.get("bot_id") @@ -143,7 +143,7 @@ async def on_received(message_data): self.client = DiscordBotClient(token, proxy) self.client.on_message_received = on_received - async def callback(): + async def callback() -> None: if self.enable_command_register: await self._collect_and_register_commands() if self.activity_name: @@ -251,7 +251,7 @@ async def convert_message(self, data: dict) -> AstrBotMessage: # 由于 on_interaction 已被禁用,我们只处理普通消息 return self._convert_message_to_abm(data) - async def handle_msg(self, message: AstrBotMessage, followup_webhook=None): + async def handle_msg(self, message: AstrBotMessage, followup_webhook=None) -> None: """处理消息""" message_event = DiscordPlatformEvent( message_str=message.message_str, @@ -323,7 +323,7 @@ async def handle_msg(self, message: AstrBotMessage, followup_webhook=None): self.commit_event(message_event) @override - async def terminate(self): + async def terminate(self) -> None: """终止适配器""" logger.info("[Discord] 正在终止适配器... (step 1: cancel polling task)") self.shutdown_event.set() @@ -358,11 +358,11 @@ async def terminate(self): logger.warning(f"[Discord] 客户端关闭异常: {e}") logger.info("[Discord] 适配器已终止。") - def register_handler(self, handler_info): + def register_handler(self, handler_info) -> None: """注册处理器信息""" self.registered_handlers.append(handler_info) - async def _collect_and_register_commands(self): + async def _collect_and_register_commands(self) -> None: """收集所有指令并注册到Discord""" logger.info("[Discord] 开始收集并注册斜杠指令...") registered_commands = [] @@ -418,7 +418,7 @@ def _create_dynamic_callback(self, cmd_name: str): async def dynamic_callback( ctx: discord.ApplicationContext, params: str | None = None - ): + ) -> None: # 将平台特定的前缀'/'剥离,以适配通用的CommandFilter logger.debug(f"[Discord] 回调函数触发: {cmd_name}") logger.debug(f"[Discord] 回调函数参数: {ctx}") diff --git a/astrbot/core/platform/sources/discord/discord_platform_event.py b/astrbot/core/platform/sources/discord/discord_platform_event.py index 053018225..02d4dae86 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_event.py +++ b/astrbot/core/platform/sources/discord/discord_platform_event.py @@ -28,7 +28,7 @@ class DiscordViewComponent(BaseMessageComponent): type: str = "discord_view" - def __init__(self, view: discord.ui.View): + def __init__(self, view: discord.ui.View) -> None: self.view = view @@ -41,12 +41,12 @@ def __init__( session_id: str, client: DiscordBotClient, interaction_followup_webhook: discord.Webhook | None = None, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client self.interaction_followup_webhook = interaction_followup_webhook - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: """发送消息到Discord平台""" # 解析消息链为 Discord 所需的对象 try: @@ -267,7 +267,7 @@ async def _parse_to_discord( content = content[:2000] return content, files, view, embeds, reference_message_id - async def react(self, emoji: str): + async def react(self, emoji: str) -> None: """对原消息添加反应""" try: if hasattr(self.message_obj, "raw_message") and hasattr( diff --git a/astrbot/core/platform/sources/lark/lark_adapter.py b/astrbot/core/platform/sources/lark/lark_adapter.py index 08df1f359..c2830f748 100644 --- a/astrbot/core/platform/sources/lark/lark_adapter.py +++ b/astrbot/core/platform/sources/lark/lark_adapter.py @@ -58,10 +58,10 @@ def __init__( logger.warning("未设置飞书机器人名称,@ 机器人可能得不到回复。") # 初始化 WebSocket 长连接相关配置 - async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1): + async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1) -> None: await self.convert_msg(event) - def do_v2_msg_event(event: lark.im.v1.P2ImMessageReceiveV1): + def do_v2_msg_event(event: lark.im.v1.P2ImMessageReceiveV1) -> None: asyncio.create_task(on_msg_event_recv(event)) self.event_handler = ( @@ -126,7 +126,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: if self.lark_api.im is None: logger.error("[Lark] API Client im 模块未初始化,无法发送消息") return @@ -175,7 +175,7 @@ def meta(self) -> PlatformMetadata: support_streaming_message=False, ) - async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1): + async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1) -> None: if event.event is None: logger.debug("[Lark] 收到空事件(event.event is None)") return @@ -331,7 +331,7 @@ async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1): logger.debug(abm) await self.handle_msg(abm) - async def handle_msg(self, abm: AstrBotMessage): + async def handle_msg(self, abm: AstrBotMessage) -> None: event = LarkMessageEvent( message_str=abm.message_str, message_obj=abm, @@ -364,7 +364,7 @@ async def handle_webhook_event(self, event_data: dict): except Exception as e: logger.error(f"[Lark Webhook] 处理事件失败: {e}", exc_info=True) - async def run(self): + async def run(self) -> None: if self.connection_mode == "webhook": # Webhook 模式 if self.webhook_server is None: @@ -387,7 +387,7 @@ async def webhook_callback(self, request: Any) -> Any: return await self.webhook_server.handle_callback(request) - async def terminate(self): + async def terminate(self) -> None: if self.connection_mode == "socket": await self.client._disconnect() logger.info("飞书(Lark) 适配器已关闭") diff --git a/astrbot/core/platform/sources/lark/lark_event.py b/astrbot/core/platform/sources/lark/lark_event.py index 7b7d20b38..b6c7d5258 100644 --- a/astrbot/core/platform/sources/lark/lark_event.py +++ b/astrbot/core/platform/sources/lark/lark_event.py @@ -31,7 +31,7 @@ def __init__( platform_meta, session_id, bot: lark.Client, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.bot = bot @@ -110,7 +110,7 @@ async def _convert_to_lark(message: MessageChain, lark_client: lark.Client) -> l ret.append(_stage) return ret - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: res = await LarkMessageEvent._convert_to_lark(message, self.bot) wrapped = { "zh_cn": { @@ -144,7 +144,7 @@ async def send(self, message: MessageChain): await super().send(message) - async def react(self, emoji: str): + async def react(self, emoji: str) -> None: if self.bot.im is None: logger.error("[Lark] API Client im 模块未初始化,无法发送表情") return diff --git a/astrbot/core/platform/sources/misskey/misskey_adapter.py b/astrbot/core/platform/sources/misskey/misskey_adapter.py index 7f3db3062..23418a91e 100644 --- a/astrbot/core/platform/sources/misskey/misskey_adapter.py +++ b/astrbot/core/platform/sources/misskey/misskey_adapter.py @@ -123,7 +123,7 @@ def meta(self) -> PlatformMetadata: support_streaming_message=False, ) - async def run(self): + async def run(self) -> None: if not self.instance_url or not self.access_token: logger.error("[Misskey] 配置不完整,无法启动") return @@ -152,7 +152,7 @@ async def run(self): await self._start_websocket_connection() - def _register_event_handlers(self, streaming): + def _register_event_handlers(self, streaming) -> None: """注册事件处理器""" streaming.add_message_handler("notification", self._handle_notification) streaming.add_message_handler("main:notification", self._handle_notification) @@ -196,7 +196,7 @@ def _process_poll_data( message: AstrBotMessage, poll: dict[str, Any], message_parts: list[str], - ): + ) -> None: """处理投票数据,将其添加到消息中""" try: if not isinstance(message.raw_message, dict): @@ -235,7 +235,7 @@ def _extract_additional_fields(self, session, message_chain) -> dict[str, Any]: return fields - async def _start_websocket_connection(self): + async def _start_websocket_connection(self) -> None: backoff_delay = 1.0 max_backoff = 300.0 backoff_multiplier = 1.5 @@ -283,7 +283,7 @@ async def _start_websocket_connection(self): await asyncio.sleep(sleep_time) backoff_delay = min(backoff_delay * backoff_multiplier, max_backoff) - async def _handle_notification(self, data: dict[str, Any]): + async def _handle_notification(self, data: dict[str, Any]) -> None: try: notification_type = data.get("type") logger.debug( @@ -307,7 +307,7 @@ async def _handle_notification(self, data: dict[str, Any]): except Exception as e: logger.error(f"[Misskey] 处理通知失败: {e}") - async def _handle_chat_message(self, data: dict[str, Any]): + async def _handle_chat_message(self, data: dict[str, Any]) -> None: try: sender_id = str( data.get("fromUserId", "") or data.get("fromUser", {}).get("id", ""), @@ -342,7 +342,7 @@ async def _handle_chat_message(self, data: dict[str, Any]): except Exception as e: logger.error(f"[Misskey] 处理聊天消息失败: {e}") - async def _debug_handler(self, data: dict[str, Any]): + async def _debug_handler(self, data: dict[str, Any]) -> None: event_type = data.get("type", "unknown") logger.debug( f"[Misskey] 收到未处理事件: type={event_type}, channel={data.get('channel', 'unknown')}", @@ -759,7 +759,7 @@ async def convert_room_message(self, raw_data: dict[str, Any]) -> AstrBotMessage ) return message - async def terminate(self): + async def terminate(self) -> None: self._running = False if self.api: await self.api.close() diff --git a/astrbot/core/platform/sources/misskey/misskey_api.py b/astrbot/core/platform/sources/misskey/misskey_api.py index 06dc6304d..86636b12c 100644 --- a/astrbot/core/platform/sources/misskey/misskey_api.py +++ b/astrbot/core/platform/sources/misskey/misskey_api.py @@ -3,7 +3,7 @@ import random import uuid from collections.abc import Awaitable, Callable -from typing import Any +from typing import Any, NoReturn try: import aiohttp @@ -43,7 +43,7 @@ class WebSocketError(APIError): class StreamingClient: - def __init__(self, instance_url: str, access_token: str): + def __init__(self, instance_url: str, access_token: str) -> None: self.instance_url = instance_url.rstrip("/") self.access_token = access_token self.websocket: Any | None = None @@ -90,7 +90,7 @@ async def connect(self) -> bool: self.is_connected = False return False - async def disconnect(self): + async def disconnect(self) -> None: self._running = False if self.websocket: await self.websocket.close() @@ -116,7 +116,7 @@ async def subscribe_channel( self.channels[channel_id] = channel_type return channel_id - async def unsubscribe_channel(self, channel_id: str): + async def unsubscribe_channel(self, channel_id: str) -> None: if ( not self.is_connected or not self.websocket @@ -136,10 +136,10 @@ def add_message_handler( self, event_type: str, handler: Callable[[dict], Awaitable[None]], - ): + ) -> None: self.message_handlers[event_type] = handler - async def listen(self): + async def listen(self) -> None: if not self.is_connected or not self.websocket: raise WebSocketError("WebSocket 未连接") @@ -187,7 +187,7 @@ async def listen(self): except Exception: pass - async def _handle_message(self, data: dict[str, Any]): + async def _handle_message(self, data: dict[str, Any]) -> None: message_type = data.get("type") body = data.get("body", {}) @@ -334,7 +334,7 @@ def __init__( download_timeout: int = 15, chunk_size: int = 64 * 1024, max_download_bytes: int | None = None, - ): + ) -> None: self.instance_url = instance_url.rstrip("/") self.access_token = access_token self._session: aiohttp.ClientSession | None = None @@ -375,7 +375,7 @@ def session(self) -> aiohttp.ClientSession: self._session = aiohttp.ClientSession(headers=headers) return self._session - def _handle_response_status(self, status: int, endpoint: str): + def _handle_response_status(self, status: int, endpoint: str) -> NoReturn: """处理 HTTP 响应状态码""" if status == 400: logger.error(f"[Misskey API] 请求参数错误: {endpoint} (HTTP {status})") diff --git a/astrbot/core/platform/sources/misskey/misskey_event.py b/astrbot/core/platform/sources/misskey/misskey_event.py index 7975f0ec7..068f7e7a2 100644 --- a/astrbot/core/platform/sources/misskey/misskey_event.py +++ b/astrbot/core/platform/sources/misskey/misskey_event.py @@ -26,7 +26,7 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, client, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client @@ -40,7 +40,7 @@ def _is_system_command(self, message_str: str) -> bool: return any(message_trimmed.startswith(prefix) for prefix in system_prefixes) - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: """发送消息,使用适配器的完整上传和发送逻辑""" try: logger.debug( diff --git a/astrbot/core/platform/sources/misskey/misskey_utils.py b/astrbot/core/platform/sources/misskey/misskey_utils.py index 290acd64e..e0ed85d19 100644 --- a/astrbot/core/platform/sources/misskey/misskey_utils.py +++ b/astrbot/core/platform/sources/misskey/misskey_utils.py @@ -406,7 +406,7 @@ def cache_user_info( raw_data: dict[str, Any], client_self_id: str, is_chat: bool = False, -): +) -> None: """缓存用户信息""" if is_chat: user_cache_data = { @@ -432,7 +432,7 @@ def cache_room_info( user_cache: dict[str, Any], raw_data: dict[str, Any], client_self_id: str, -): +) -> None: """缓存房间信息""" room_data = raw_data.get("toRoom") room_id = raw_data.get("toRoomId") diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py index d693c4206..2da3cb4b7 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py @@ -32,12 +32,12 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, bot: Client, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.bot = bot self.send_buffer = None - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: self.send_buffer = message await self._post_send() diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py index 2a1bcda47..f6654b6d6 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py @@ -4,7 +4,7 @@ import logging import os import time -from typing import cast +from typing import NoReturn, cast import botpy import botpy.message @@ -35,11 +35,13 @@ # QQ 机器人官方框架 class botClient(Client): - def set_platform(self, platform: QQOfficialPlatformAdapter): + def set_platform(self, platform: QQOfficialPlatformAdapter) -> None: self.platform = platform # 收到群消息 - async def on_group_at_message_create(self, message: botpy.message.GroupMessage): + async def on_group_at_message_create( + self, message: botpy.message.GroupMessage + ) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.GROUP_MESSAGE, @@ -52,7 +54,7 @@ async def on_group_at_message_create(self, message: botpy.message.GroupMessage): self._commit(abm) # 收到频道消息 - async def on_at_message_create(self, message: botpy.message.Message): + async def on_at_message_create(self, message: botpy.message.Message) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.GROUP_MESSAGE, @@ -63,7 +65,9 @@ async def on_at_message_create(self, message: botpy.message.Message): self._commit(abm) # 收到私聊消息 - async def on_direct_message_create(self, message: botpy.message.DirectMessage): + async def on_direct_message_create( + self, message: botpy.message.DirectMessage + ) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.FRIEND_MESSAGE, @@ -72,7 +76,7 @@ async def on_direct_message_create(self, message: botpy.message.DirectMessage): self._commit(abm) # 收到 C2C 消息 - async def on_c2c_message_create(self, message: botpy.message.C2CMessage): + async def on_c2c_message_create(self, message: botpy.message.C2CMessage) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.FRIEND_MESSAGE, @@ -80,7 +84,7 @@ async def on_c2c_message_create(self, message: botpy.message.C2CMessage): abm.session_id = abm.sender.user_id self._commit(abm) - def _commit(self, abm: AstrBotMessage): + def _commit(self, abm: AstrBotMessage) -> None: self.platform.commit_event( QQOfficialMessageEvent( abm.message_str, @@ -133,7 +137,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> NoReturn: raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session") def meta(self) -> PlatformMetadata: @@ -226,6 +230,6 @@ def run(self): def get_client(self) -> botClient: return self.client - async def terminate(self): + async def terminate(self) -> None: await self.client.close() logger.info("QQ 官方机器人接口 适配器已被优雅地关闭") diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py index 63b6726fe..ef0c562e6 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py @@ -1,6 +1,6 @@ import asyncio import logging -from typing import Any, cast +from typing import Any, NoReturn, cast import botpy import botpy.message @@ -26,11 +26,13 @@ # QQ 机器人官方框架 class botClient(Client): - def set_platform(self, platform: "QQOfficialWebhookPlatformAdapter"): + def set_platform(self, platform: "QQOfficialWebhookPlatformAdapter") -> None: self.platform = platform # 收到群消息 - async def on_group_at_message_create(self, message: botpy.message.GroupMessage): + async def on_group_at_message_create( + self, message: botpy.message.GroupMessage + ) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.GROUP_MESSAGE, @@ -43,7 +45,7 @@ async def on_group_at_message_create(self, message: botpy.message.GroupMessage): self._commit(abm) # 收到频道消息 - async def on_at_message_create(self, message: botpy.message.Message): + async def on_at_message_create(self, message: botpy.message.Message) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.GROUP_MESSAGE, @@ -54,7 +56,9 @@ async def on_at_message_create(self, message: botpy.message.Message): self._commit(abm) # 收到私聊消息 - async def on_direct_message_create(self, message: botpy.message.DirectMessage): + async def on_direct_message_create( + self, message: botpy.message.DirectMessage + ) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.FRIEND_MESSAGE, @@ -63,7 +67,7 @@ async def on_direct_message_create(self, message: botpy.message.DirectMessage): self._commit(abm) # 收到 C2C 消息 - async def on_c2c_message_create(self, message: botpy.message.C2CMessage): + async def on_c2c_message_create(self, message: botpy.message.C2CMessage) -> None: abm = QQOfficialPlatformAdapter._parse_from_qqofficial( message, MessageType.FRIEND_MESSAGE, @@ -71,7 +75,7 @@ async def on_c2c_message_create(self, message: botpy.message.C2CMessage): abm.session_id = abm.sender.user_id self._commit(abm) - def _commit(self, abm: AstrBotMessage): + def _commit(self, abm: AstrBotMessage) -> None: self.platform.commit_event( QQOfficialWebhookMessageEvent( abm.message_str, @@ -115,7 +119,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> NoReturn: raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session") def meta(self) -> PlatformMetadata: @@ -125,7 +129,7 @@ def meta(self) -> PlatformMetadata: id=cast(str, self.config.get("id")), ) - async def run(self): + async def run(self) -> None: self.webhook_helper = QQOfficialWebhook( self.config, self._event_queue, @@ -153,7 +157,7 @@ async def webhook_callback(self, request: Any) -> Any: # 复用 webhook_helper 的回调处理逻辑 return await self.webhook_helper.handle_callback(request) - async def terminate(self): + async def terminate(self) -> None: if self.webhook_helper: self.webhook_helper.shutdown_event.set() await self.client.close() diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py index 306db5e56..5ceeb2c70 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py @@ -13,5 +13,5 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, bot: Client, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id, bot) diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py index 2eda11a6c..5f35471ee 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py @@ -14,7 +14,9 @@ class QQOfficialWebhook: - def __init__(self, config: dict, event_queue: asyncio.Queue, botpy_client: Client): + def __init__( + self, config: dict, event_queue: asyncio.Queue, botpy_client: Client + ) -> None: self.appid = config["appid"] self.secret = config["secret"] self.port = config.get("port", 6196) @@ -38,7 +40,7 @@ def __init__(self, config: dict, event_queue: asyncio.Queue, botpy_client: Clien self.event_queue = event_queue self.shutdown_event = asyncio.Event() - async def initialize(self): + async def initialize(self) -> None: logger.info("正在登录到 QQ 官方机器人...") self.user = await self.http.login(self.token) logger.info(f"已登录 QQ 官方机器人账号: {self.user}") @@ -46,7 +48,7 @@ async def initialize(self): self.client.api = self.api self.client.http = self.http - async def bot_connect(): + async def bot_connect() -> None: pass self._connection = ConnectionSession( @@ -115,7 +117,7 @@ async def handle_callback(self, request) -> dict: return {"opcode": 12} - async def start_polling(self): + async def start_polling(self) -> None: logger.info( f"将在 {self.callback_server_host}:{self.port} 端口启动 QQ 官方机器人 webhook 适配器。", ) @@ -125,5 +127,5 @@ async def start_polling(self): shutdown_trigger=self.shutdown_trigger, ) - async def shutdown_trigger(self): + async def shutdown_trigger(self) -> None: await self.shutdown_event.wait() diff --git a/astrbot/core/platform/sources/satori/satori_adapter.py b/astrbot/core/platform/sources/satori/satori_adapter.py index 46f9a4e0f..a1afbd898 100644 --- a/astrbot/core/platform/sources/satori/satori_adapter.py +++ b/astrbot/core/platform/sources/satori/satori_adapter.py @@ -73,7 +73,7 @@ async def send_by_session( self, session: MessageSession, message_chain: MessageChain, - ): + ) -> None: from .satori_event import SatoriPlatformEvent await SatoriPlatformEvent.send_with_adapter( @@ -99,7 +99,7 @@ def _is_websocket_closed(self, ws) -> bool: except AttributeError: return False - async def run(self): + async def run(self) -> None: self.running = True self.session = ClientSession(timeout=ClientTimeout(total=30)) @@ -133,7 +133,7 @@ async def run(self): if self.session: await self.session.close() - async def connect_websocket(self): + async def connect_websocket(self) -> None: logger.info(f"Satori 适配器正在连接到 WebSocket: {self.endpoint}") logger.info(f"Satori 适配器 HTTP API 地址: {self.api_base_url}") @@ -176,7 +176,7 @@ async def connect_websocket(self): except Exception as e: logger.error(f"Satori WebSocket 关闭异常: {e}") - async def send_identify(self): + async def send_identify(self) -> None: if not self.ws: raise Exception("WebSocket连接未建立") @@ -204,7 +204,7 @@ async def send_identify(self): logger.error(f"发送 IDENTIFY 信令失败: {e}") raise - async def heartbeat_loop(self): + async def heartbeat_loop(self) -> None: try: while self.running and self.ws: await asyncio.sleep(self.heartbeat_interval) @@ -229,7 +229,7 @@ async def heartbeat_loop(self): except Exception as e: logger.error(f"心跳任务异常: {e}") - async def handle_message(self, message: str): + async def handle_message(self, message: str) -> None: try: data = json.loads(message) op = data.get("op") @@ -270,7 +270,7 @@ async def handle_message(self, message: str): except Exception as e: logger.error(f"处理 WebSocket 消息异常: {e}") - async def handle_event(self, event_data: dict): + async def handle_event(self, event_data: dict) -> None: try: event_type = event_data.get("type") sn = event_data.get("sn") @@ -715,7 +715,7 @@ async def _parse_xml_node(self, node: ET.Element, elements: list) -> None: if child.tail and child.tail.strip(): elements.append(Plain(text=child.tail)) - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: from .satori_event import SatoriPlatformEvent message_event = SatoriPlatformEvent( @@ -775,7 +775,7 @@ async def send_http_request( logger.error(f"Satori HTTP 请求异常: {e}") return {} - async def terminate(self): + async def terminate(self) -> None: self.running = False if self.heartbeat_task: diff --git a/astrbot/core/platform/sources/satori/satori_event.py b/astrbot/core/platform/sources/satori/satori_event.py index 81a0d222c..021422283 100644 --- a/astrbot/core/platform/sources/satori/satori_event.py +++ b/astrbot/core/platform/sources/satori/satori_event.py @@ -28,7 +28,7 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, adapter: "SatoriPlatformAdapter", - ): + ) -> None: # 更新平台元数据 if adapter and hasattr(adapter, "logins") and adapter.logins: current_login = adapter.logins[0] @@ -110,7 +110,7 @@ async def send_with_adapter( logger.error(f"Satori 消息发送异常: {e}") return None - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: platform = getattr(self, "platform", None) user_id = getattr(self, "user_id", None) diff --git a/astrbot/core/platform/sources/slack/client.py b/astrbot/core/platform/sources/slack/client.py index fbdc71759..efd7a6f3d 100644 --- a/astrbot/core/platform/sources/slack/client.py +++ b/astrbot/core/platform/sources/slack/client.py @@ -27,7 +27,7 @@ def __init__( port: int = 3000, path: str = "/slack/events", event_handler: Callable | None = None, - ): + ) -> None: self.web_client = web_client self.signing_secret = signing_secret self.host = host @@ -44,7 +44,7 @@ def __init__( self.shutdown_event = asyncio.Event() - def _setup_routes(self): + def _setup_routes(self) -> None: """设置路由""" @self.app.route(self.path, methods=["POST"]) @@ -105,7 +105,7 @@ async def handle_callback(self, req): logger.error(f"处理 Slack 事件时出错: {e}") return Response("Internal Server Error", status=500) - async def start(self): + async def start(self) -> None: """启动 Webhook 服务器""" logger.info( f"Slack Webhook 服务器启动中,监听 {self.host}:{self.port}{self.path}...", @@ -118,10 +118,10 @@ async def start(self): shutdown_trigger=self.shutdown_trigger, ) - async def shutdown_trigger(self): + async def shutdown_trigger(self) -> None: await self.shutdown_event.wait() - async def stop(self): + async def stop(self) -> None: """停止 Webhook 服务器""" self.shutdown_event.set() logger.info("Slack Webhook 服务器已停止") @@ -135,7 +135,7 @@ def __init__( web_client: AsyncWebClient, app_token: str, event_handler: Callable | None = None, - ): + ) -> None: self.web_client = web_client self.app_token = app_token self.event_handler = event_handler @@ -143,7 +143,7 @@ def __init__( async def _handle_events( self, _: AsyncBaseSocketModeClient, req: SocketModeRequest - ): + ) -> None: """处理 Socket Mode 事件""" try: if self.socket_client is None: @@ -160,7 +160,7 @@ async def _handle_events( except Exception as e: logger.error(f"处理 Socket Mode 事件时出错: {e}") - async def start(self): + async def start(self) -> None: """启动 Socket Mode 连接""" self.socket_client = SocketModeClient( app_token=self.app_token, @@ -174,7 +174,7 @@ async def start(self): logger.info("Slack Socket Mode 客户端启动中...") await self.socket_client.connect() - async def stop(self): + async def stop(self) -> None: """停止 Socket Mode 连接""" if self.socket_client: await self.socket_client.disconnect() diff --git a/astrbot/core/platform/sources/slack/slack_adapter.py b/astrbot/core/platform/sources/slack/slack_adapter.py index ed838b0a9..16a80f2e8 100644 --- a/astrbot/core/platform/sources/slack/slack_adapter.py +++ b/astrbot/core/platform/sources/slack/slack_adapter.py @@ -82,7 +82,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: blocks, text = await SlackMessageEvent._parse_slack_blocks( message_chain=message_chain, web_client=self.web_client, @@ -288,7 +288,7 @@ def _parse_blocks(self, blocks: list) -> list: return message_components - async def _handle_socket_event(self, req: SocketModeRequest): + async def _handle_socket_event(self, req: SocketModeRequest) -> None: """处理 Socket Mode 事件""" if req.type == "events_api": # 事件 API @@ -377,7 +377,7 @@ async def run(self) -> None: f"不支持的连接模式: {self.connection_mode},请使用 'socket' 或 'webhook'", ) - async def _handle_webhook_event(self, event_data: dict): + async def _handle_webhook_event(self, event_data: dict) -> None: """处理 Webhook 事件""" event = event_data.get("event", {}) @@ -404,7 +404,7 @@ async def webhook_callback(self, request: Any) -> Any: return await self.webhook_client.handle_callback(request) - async def terminate(self): + async def terminate(self) -> None: if self.socket_client: await self.socket_client.stop() if self.webhook_client: @@ -414,7 +414,7 @@ async def terminate(self): def meta(self) -> PlatformMetadata: return self.metadata - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: message_event = SlackMessageEvent( message_str=message.message_str, message_obj=message, diff --git a/astrbot/core/platform/sources/slack/slack_event.py b/astrbot/core/platform/sources/slack/slack_event.py index 822e6fdeb..3f62690b5 100644 --- a/astrbot/core/platform/sources/slack/slack_event.py +++ b/astrbot/core/platform/sources/slack/slack_event.py @@ -24,7 +24,7 @@ def __init__( platform_meta, session_id, web_client: AsyncWebClient, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.web_client = web_client @@ -126,7 +126,7 @@ async def _parse_slack_blocks( return blocks, "" if blocks else text_content - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: blocks, text = await SlackMessageEvent._parse_slack_blocks( message, self.web_client, diff --git a/astrbot/core/platform/sources/telegram/tg_adapter.py b/astrbot/core/platform/sources/telegram/tg_adapter.py index 218d13bdc..2254216b0 100644 --- a/astrbot/core/platform/sources/telegram/tg_adapter.py +++ b/astrbot/core/platform/sources/telegram/tg_adapter.py @@ -94,7 +94,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: from_username = session.session_id await TelegramPlatformEvent.send_with_client( self.client, @@ -109,7 +109,7 @@ def meta(self) -> PlatformMetadata: return PlatformMetadata(name="telegram", description="telegram 适配器", id=id_) @override - async def run(self): + async def run(self) -> None: await self.application.initialize() await self.application.start() @@ -134,7 +134,7 @@ async def run(self): logger.info("Telegram Platform Adapter is running.") await queue - async def register_commands(self): + async def register_commands(self) -> None: """收集所有注册的指令并注册到 Telegram""" try: commands = self.collect_commands() @@ -210,7 +210,7 @@ def _extract_command_info( description = description[:30] + "..." return cmd_name, description - async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE): + async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: if not update.effective_chat: logger.warning( "Received a start command without an effective chat, skipping /start reply.", @@ -221,7 +221,9 @@ async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE): text=self.config["start_message"], ) - async def message_handler(self, update: Update, context: ContextTypes.DEFAULT_TYPE): + async def message_handler( + self, update: Update, context: ContextTypes.DEFAULT_TYPE + ) -> None: logger.debug(f"Telegram message: {update.message}") abm = await self.convert_message(update, context) if abm: @@ -397,7 +399,7 @@ async def convert_message( return message - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: message_event = TelegramPlatformEvent( message_str=message.message_str, message_obj=message, @@ -410,7 +412,7 @@ async def handle_msg(self, message: AstrBotMessage): def get_client(self) -> ExtBot: return self.client - async def terminate(self): + async def terminate(self) -> None: try: if self.scheduler.running: self.scheduler.shutdown() diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py index 37f60e65a..80f0cd6af 100644 --- a/astrbot/core/platform/sources/telegram/tg_event.py +++ b/astrbot/core/platform/sources/telegram/tg_event.py @@ -38,7 +38,7 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, client: ExtBot, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client @@ -73,7 +73,7 @@ async def send_with_client( client: ExtBot, message: MessageChain, user_name: str, - ): + ) -> None: image_path = None has_reply = False @@ -134,14 +134,14 @@ async def send_with_client( path = await i.convert_to_file_path() await client.send_voice(voice=path, **cast(Any, payload)) - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: if self.get_message_type() == MessageType.GROUP_MESSAGE: await self.send_with_client(self.client, message, self.message_obj.group_id) else: await self.send_with_client(self.client, message, self.get_sender_id()) await super().send(message) - async def react(self, emoji: str | None, big: bool = False): + async def react(self, emoji: str | None, big: bool = False) -> None: """给原消息添加 Telegram 反应: - 普通 emoji:传入 '👍'、'😂' 等 - 自定义表情:传入其 custom_emoji_id(纯数字字符串) diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index 084d7860d..47dbb22ec 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -31,7 +31,7 @@ def __init__(self, webchat_queue_mgr: WebChatQueueMgr, callback: Callable) -> No self.callback = callback self.running_tasks = set() - async def listen_to_queue(self, conversation_id: str): + async def listen_to_queue(self, conversation_id: str) -> None: """Listen to a specific conversation queue""" queue = self.webchat_queue_mgr.get_or_create_queue(conversation_id) while True: @@ -44,7 +44,7 @@ async def listen_to_queue(self, conversation_id: str): ) break - async def run(self): + async def run(self) -> None: """Monitor for new conversation queues and start listeners""" monitored_conversations = set() @@ -93,7 +93,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: await WebChatMessageEvent._send(message_chain, session.session_id) await super().send_by_session(session, message_chain) @@ -208,7 +208,7 @@ async def convert_message(self, data: tuple) -> AstrBotMessage: return abm def run(self) -> Coroutine[Any, Any, None]: - async def callback(data: tuple): + async def callback(data: tuple) -> None: abm = await self.convert_message(data) await self.handle_msg(abm) @@ -218,7 +218,7 @@ async def callback(data: tuple): def meta(self) -> PlatformMetadata: return self.metadata - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: message_event = WebChatMessageEvent( message_str=message.message_str, message_obj=message, @@ -235,6 +235,6 @@ async def handle_msg(self, message: AstrBotMessage): self.commit_event(message_event) - async def terminate(self): + async def terminate(self) -> None: # Do nothing pass diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index 9f1a6d059..79c060022 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -14,7 +14,7 @@ class WebChatMessageEvent(AstrMessageEvent): - def __init__(self, message_str, message_obj, platform_meta, session_id): + def __init__(self, message_str, message_obj, platform_meta, session_id) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) os.makedirs(imgs_dir, exist_ok=True) @@ -101,11 +101,11 @@ async def _send( return data - async def send(self, message: MessageChain | None): + async def send(self, message: MessageChain | None) -> None: await WebChatMessageEvent._send(message, session_id=self.session_id) await super().send(MessageChain([])) - async def send_streaming(self, generator, use_fallback: bool = False): + async def send_streaming(self, generator, use_fallback: bool = False) -> None: final_data = "" reasoning_content = "" cid = self.session_id.split("!")[-1] diff --git a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py index 6c365cb3a..4824e2de9 100644 --- a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py +++ b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py @@ -20,7 +20,7 @@ def get_or_create_back_queue(self, conversation_id: str) -> asyncio.Queue: self.back_queues[conversation_id] = asyncio.Queue() return self.back_queues[conversation_id] - def remove_queues(self, conversation_id: str): + def remove_queues(self, conversation_id: str) -> None: """Remove queues for the given conversation ID""" if conversation_id in self.queues: del self.queues[conversation_id] diff --git a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py index 4c9a9d36b..d2012d9a7 100644 --- a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +++ b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py @@ -150,7 +150,7 @@ def load_credentials(self): logger.error(f"加载 WeChatPadPro 凭据失败: {e}") return None - def save_credentials(self): + def save_credentials(self) -> None: """将 auth_key 和 wxid 保存到文件。""" credentials = { "auth_key": self.auth_key, @@ -165,7 +165,7 @@ def save_credentials(self): except Exception as e: logger.error(f"保存 WeChatPadPro 凭据失败: {e}") - async def check_online_status(self): + async def check_online_status(self) -> bool | None: """检查 WeChatPadPro 设备是否在线。""" if not self.auth_key: return False @@ -219,7 +219,7 @@ def _extract_auth_key(self, data): return data[0] return None - async def generate_auth_key(self): + async def generate_auth_key(self) -> None: """生成授权码。""" url = f"{self.base_url}/admin/GenAuthKey1" params = {"key": self.admin_key} @@ -292,7 +292,7 @@ async def get_login_qr_code(self): logger.error(f"获取登录二维码时发生错误: {e}") return None - async def check_login_status(self): + async def check_login_status(self) -> bool: """循环检测扫码状态。 尝试 6 次后跳出循环,添加倒计时。 返回 True 如果登录成功,否则返回 False。 @@ -359,7 +359,7 @@ async def check_login_status(self): logger.warning("登录检测超过最大尝试次数,退出检测。") return False - async def connect_websocket(self): + async def connect_websocket(self) -> None: """建立 WebSocket 连接并处理接收到的消息。""" os.environ["no_proxy"] = f"localhost,127.0.0.1,{self.host}" ws_url = f"ws://{self.host}:{self.port}/ws/GetSyncMsg?key={self.auth_key}" @@ -399,7 +399,7 @@ async def connect_websocket(self): ) await asyncio.sleep(5) - async def handle_websocket_message(self, message: str | bytes): + async def handle_websocket_message(self, message: str | bytes) -> None: """处理从 WebSocket 接收到的消息。""" logger.debug(f"收到 WebSocket 消息: {message}") try: @@ -487,7 +487,7 @@ async def _process_chat_type( to_user_name: str, content: str, push_content: str, - ): + ) -> bool: """判断消息是群聊还是私聊,并设置 AstrBotMessage 的基本属性。""" if from_user_name == "weixin": return False @@ -640,7 +640,7 @@ async def _process_message_content( raw_message: dict, msg_type: int, content: str, - ): + ) -> None: """根据消息类型处理消息内容,填充 AstrBotMessage 的 message 列表。""" if msg_type == 1: # 文本消息 abm.message_str = content @@ -837,7 +837,7 @@ async def _process_message_content( else: logger.warning(f"收到未处理的消息类型: {msg_type}。") - async def terminate(self): + async def terminate(self) -> None: """终止一个平台的运行实例。""" logger.info("终止 WeChatPadPro 适配器。") try: @@ -856,7 +856,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: dummy_message_obj = AstrBotMessage() dummy_message_obj.session_id = session.session_id # 根据 session_id 判断消息类型 diff --git a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py index 08ab27013..8ecc4c64f 100644 --- a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +++ b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py @@ -32,12 +32,12 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, adapter: "WeChatPadProAdapter", # 传递适配器实例 - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.message_obj = message_obj # Save the full message object self.adapter = adapter # Save the adapter instance - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: async with aiohttp.ClientSession() as session: for comp in message.chain: await asyncio.sleep(1) @@ -66,7 +66,7 @@ async def send_streaming( await self.send(buffer) return await super().send_streaming(generator, use_fallback) - async def _send_image(self, session: aiohttp.ClientSession, comp: Image): + async def _send_image(self, session: aiohttp.ClientSession, comp: Image) -> None: b64 = await comp.convert_to_base64() raw = self._validate_base64(b64) b64c = self._compress_image(raw) @@ -78,7 +78,7 @@ async def _send_image(self, session: aiohttp.ClientSession, comp: Image): url = f"{self.adapter.base_url}/message/SendImageNewMessage" await self._post(session, url, payload) - async def _send_text(self, session: aiohttp.ClientSession, text: str): + async def _send_text(self, session: aiohttp.ClientSession, text: str) -> None: if ( self.message_obj.type == MessageType.GROUP_MESSAGE # 确保是群聊消息 and self.adapter.settings.get( @@ -114,7 +114,9 @@ async def _send_text(self, session: aiohttp.ClientSession, text: str): url = f"{self.adapter.base_url}/message/SendTextMessage" await self._post(session, url, payload) - async def _send_emoji(self, session: aiohttp.ClientSession, comp: WechatEmoji): + async def _send_emoji( + self, session: aiohttp.ClientSession, comp: WechatEmoji + ) -> None: payload = { "EmojiList": [ { @@ -127,7 +129,7 @@ async def _send_emoji(self, session: aiohttp.ClientSession, comp: WechatEmoji): url = f"{self.adapter.base_url}/message/SendEmojiMessage" await self._post(session, url, payload) - async def _send_voice(self, session: aiohttp.ClientSession, comp: Record): + async def _send_voice(self, session: aiohttp.ClientSession, comp: Record) -> None: record_path = await comp.convert_to_file_path() # 默认已经存在 data/temp 中 b64, duration = await audio_to_tencent_silk_base64(record_path) @@ -157,7 +159,7 @@ def _compress_image(data: bytes) -> str: # logger.info("图片处理完成!!!") return base64.b64encode(buf.getvalue()).decode() - async def _post(self, session, url, payload): + async def _post(self, session, url, payload) -> None: params = {"key": self.adapter.auth_key} try: async with session.post(url, params=params, json=payload) as resp: diff --git a/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py b/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py index 09924edb6..cf23c6b54 100644 --- a/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +++ b/astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py @@ -20,7 +20,7 @@ def __init__( cached_images=None, raw_message: dict | None = None, downloader=None, - ): + ) -> None: self._xml = None self.content = content self.is_private_chat = is_private_chat diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index 44ed75117..369b389e7 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -39,7 +39,7 @@ class WecomServer: - def __init__(self, event_queue: asyncio.Queue, config: dict): + def __init__(self, event_queue: asyncio.Queue, config: dict) -> None: self.server = quart.Quart(__name__) self.port = int(cast(str, config.get("port"))) self.callback_server_host = config.get("callback_server_host", "0.0.0.0") @@ -123,7 +123,7 @@ async def handle_callback(self, request) -> str: return "success" - async def start_polling(self): + async def start_polling(self) -> None: logger.info( f"将在 {self.callback_server_host}:{self.port} 端口启动 企业微信 适配器。", ) @@ -133,7 +133,7 @@ async def start_polling(self): shutdown_trigger=self.shutdown_trigger, ) - async def shutdown_trigger(self): + async def shutdown_trigger(self) -> None: await self.shutdown_event.wait() @@ -182,7 +182,7 @@ def __init__( self.client.__setattr__("API_BASE_URL", self.api_base_url) - async def callback(msg: BaseMessage): + async def callback(msg: BaseMessage) -> None: if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event": def get_latest_msg_item() -> dict | None: @@ -214,7 +214,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: await super().send_by_session(session, message_chain) @override @@ -227,7 +227,7 @@ def meta(self) -> PlatformMetadata: ) @override - async def run(self): + async def run(self) -> None: loop = asyncio.get_event_loop() if self.kf_name: try: @@ -403,7 +403,7 @@ async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None: return await self.handle_msg(abm) - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: message_event = WecomPlatformEvent( message_str=message.message_str, message_obj=message, @@ -416,7 +416,7 @@ async def handle_msg(self, message: AstrBotMessage): def get_client(self) -> WeChatClient: return self.client - async def terminate(self): + async def terminate(self) -> None: self.server.shutdown_event.set() try: await self.server.server.shutdown() diff --git a/astrbot/core/platform/sources/wecom/wecom_event.py b/astrbot/core/platform/sources/wecom/wecom_event.py index 0b5dae272..865a14234 100644 --- a/astrbot/core/platform/sources/wecom/wecom_event.py +++ b/astrbot/core/platform/sources/wecom/wecom_event.py @@ -28,7 +28,7 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, client: WeChatClient, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client @@ -37,7 +37,7 @@ async def send_with_client( client: WeChatClient, message: MessageChain, user_name: str, - ): + ) -> None: pass async def split_plain(self, plain: str) -> list[str]: @@ -86,7 +86,7 @@ async def split_plain(self, plain: str) -> list[str]: return result - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: message_obj = self.message_obj is_wechat_kf = hasattr(self.client, "kf_message") diff --git a/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py b/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py index 2df09a763..260b950d1 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py @@ -14,6 +14,7 @@ import socket import struct import time +from typing import NoReturn from Crypto.Cipher import AES @@ -30,7 +31,7 @@ class FormatException(Exception): pass -def throw_exception(message, exception_class=FormatException): +def throw_exception(message, exception_class=FormatException) -> NoReturn: """My define raise exception function""" raise exception_class(message) @@ -145,7 +146,7 @@ class Prpcrypt: MIN_RANDOM_VALUE = 1000000000000000 # 最小值: 1000000000000000 (16位) RANDOM_RANGE = 9000000000000000 # 范围大小: 确保最大值为 9999999999999999 (16位) - def __init__(self, key): + def __init__(self, key) -> None: # self.key = base64.b64decode(key+"=") self.key = key # 设置加解密模式为AES的CBC模式 @@ -220,7 +221,7 @@ def get_random_str(self): class WXBizJsonMsgCrypt: # 构造函数 - def __init__(self, sToken, sEncodingAESKey, sReceiveId): + def __init__(self, sToken, sEncodingAESKey, sReceiveId) -> None: try: self.key = base64.b64decode(sEncodingAESKey + "=") assert len(self.key) == 32 diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py index 70581e7ea..08a07ed1b 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py @@ -53,7 +53,7 @@ def __init__( self.callback = callback self.running_tasks = set() - async def listen_to_queue(self, session_id: str): + async def listen_to_queue(self, session_id: str) -> None: """监听特定会话的队列""" queue = self.queue_mgr.get_or_create_queue(session_id) while True: @@ -64,7 +64,7 @@ async def listen_to_queue(self, session_id: str): logger.error(f"处理会话 {session_id} 消息时发生错误: {e}") break - async def run(self): + async def run(self) -> None: """监控新会话队列并启动监听器""" monitored_sessions = set() @@ -153,7 +153,7 @@ def __init__( self._handle_queued_message, ) - async def _handle_queued_message(self, data: dict): + async def _handle_queued_message(self, data: dict) -> None: """处理队列中的消息,类似webchat的callback""" try: abm = await self.convert_message(data) @@ -313,7 +313,7 @@ async def _enqueue_message( callback_params: dict[str, str], stream_id: str, session_id: str, - ): + ) -> None: """将消息放入队列进行异步处理""" input_queue = self.queue_mgr.get_or_create_queue(stream_id) _ = self.queue_mgr.get_or_create_back_queue(stream_id) @@ -417,7 +417,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: """通过会话发送消息""" # 企业微信智能机器人主要通过回调响应,这里记录日志 logger.info("会话发送消息: %s -> %s", session.session_id, message_chain) @@ -426,7 +426,7 @@ async def send_by_session( def run(self) -> Awaitable[Any]: """运行适配器,同时启动HTTP服务器和队列监听器""" - async def run_both(): + async def run_both() -> None: # 如果启用统一 webhook 模式,则不启动独立服务器 webhook_uuid = self.config.get("webhook_uuid") if self.unified_webhook_mode and webhook_uuid: @@ -453,7 +453,7 @@ async def webhook_callback(self, request: Any) -> Any: else: return await self.server.handle_callback(request) - async def terminate(self): + async def terminate(self) -> None: """终止适配器""" logger.info("企业微信智能机器人适配器正在关闭...") self.shutdown_event.set() @@ -463,7 +463,7 @@ def meta(self) -> PlatformMetadata: """获取平台元数据""" return self.metadata - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: """处理消息,创建消息事件并提交到事件队列""" try: message_event = WecomAIBotMessageEvent( diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py index 6c448a97e..97831fbb2 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py @@ -19,7 +19,7 @@ class WecomAIBotAPIClient: """企业微信智能机器人 API 客户端""" - def __init__(self, token: str, encoding_aes_key: str): + def __init__(self, token: str, encoding_aes_key: str) -> None: """初始化 API 客户端 Args: diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py index fd11d7ceb..90a9e363b 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py @@ -22,7 +22,7 @@ def __init__( session_id: str, api_client: WecomAIBotAPIClient, queue_mgr: WecomAIQueueMgr, - ): + ) -> None: """初始化消息事件 Args: @@ -90,7 +90,7 @@ async def _send( return data - async def send(self, message: MessageChain | None): + async def send(self, message: MessageChain | None) -> None: """发送消息""" raw = self.message_obj.raw_message assert isinstance(raw, dict), ( @@ -100,7 +100,7 @@ async def send(self, message: MessageChain | None): await WecomAIBotMessageEvent._send(message, stream_id, self.queue_mgr) await super().send(MessageChain([])) - async def send_streaming(self, generator, use_fallback=False): + async def send_streaming(self, generator, use_fallback=False) -> None: """流式发送消息,参考webchat的send_streaming设计""" final_data = "" raw = self.message_obj.raw_message diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py index 3a982bdf7..db6aa408e 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py @@ -52,7 +52,7 @@ def get_or_create_back_queue(self, session_id: str) -> asyncio.Queue: logger.debug(f"[WecomAI] 创建输出队列: {session_id}") return self.back_queues[session_id] - def remove_queues(self, session_id: str): + def remove_queues(self, session_id: str) -> None: """移除指定会话的所有队列 Args: @@ -95,7 +95,9 @@ def has_back_queue(self, session_id: str) -> bool: """ return session_id in self.back_queues - def set_pending_response(self, session_id: str, callback_params: dict[str, str]): + def set_pending_response( + self, session_id: str, callback_params: dict[str, str] + ) -> None: """设置待处理的响应参数 Args: @@ -121,7 +123,7 @@ def get_pending_response(self, session_id: str) -> dict[str, Any] | None: """ return self.pending_responses.get(session_id) - def cleanup_expired_responses(self, max_age_seconds: int = 300): + def cleanup_expired_responses(self, max_age_seconds: int = 300) -> None: """清理过期的待处理响应 Args: diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py index 5cbdd1130..80ec5179e 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py @@ -23,7 +23,7 @@ def __init__( port: int, api_client: WecomAIBotAPIClient, message_handler: Callable[[dict[str, Any], dict[str, str]], Any] | None = None, - ): + ) -> None: """初始化服务器 Args: @@ -43,7 +43,7 @@ def __init__( self.shutdown_event = asyncio.Event() - def _setup_routes(self): + def _setup_routes(self) -> None: """设置 Quart 路由""" # 使用 Quart 的 add_url_rule 方法添加路由 self.app.add_url_rule( @@ -162,7 +162,7 @@ async def handle_callback(self, request): logger.error("处理消息时发生异常: %s", e) return "内部服务器错误", 500 - async def start_server(self): + async def start_server(self) -> None: """启动服务器""" logger.info("启动企业微信智能机器人服务器,监听 %s:%d", self.host, self.port) @@ -176,11 +176,11 @@ async def start_server(self): logger.error("服务器运行异常: %s", e) raise - async def shutdown_trigger(self): + async def shutdown_trigger(self) -> None: """关闭触发器""" await self.shutdown_event.wait() - async def shutdown(self): + async def shutdown(self) -> None: """关闭服务器""" logger.info("企业微信智能机器人服务器正在关闭...") self.shutdown_event.set() diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py index d12285d68..0e4ecff85 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py @@ -35,7 +35,7 @@ class WeixinOfficialAccountServer: - def __init__(self, event_queue: asyncio.Queue, config: dict): + def __init__(self, event_queue: asyncio.Queue, config: dict) -> None: self.server = quart.Quart(__name__) self.port = int(cast(int | str, config.get("port"))) self.callback_server_host = config.get("callback_server_host", "0.0.0.0") @@ -129,7 +129,7 @@ async def handle_callback(self, request) -> str: return "success" - async def start_polling(self): + async def start_polling(self) -> None: logger.info( f"将在 {self.callback_server_host}:{self.port} 端口启动 微信公众平台 适配器。", ) @@ -139,7 +139,7 @@ async def start_polling(self): shutdown_trigger=self.shutdown_trigger, ) - async def shutdown_trigger(self): + async def shutdown_trigger(self) -> None: await self.shutdown_event.wait() @@ -218,7 +218,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: await super().send_by_session(session, message_chain) @override @@ -231,7 +231,7 @@ def meta(self) -> PlatformMetadata: ) @override - async def run(self): + async def run(self) -> None: # 如果启用统一 webhook 模式,则不启动独立服务器 webhook_uuid = self.config.get("webhook_uuid") if self.unified_webhook_mode and webhook_uuid: @@ -330,7 +330,7 @@ async def convert_message( logger.info(f"abm: {abm}") await self.handle_msg(abm) - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage) -> None: message_event = WeixinOfficialAccountPlatformEvent( message_str=message.message_str, message_obj=message, @@ -343,7 +343,7 @@ async def handle_msg(self, message: AstrBotMessage): def get_client(self) -> WeChatClient: return self.client - async def terminate(self): + async def terminate(self) -> None: self.server.shutdown_event.set() try: await self.server.server.shutdown() diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py index c1f137a41..995b16690 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py @@ -26,7 +26,7 @@ def __init__( platform_meta: PlatformMetadata, session_id: str, client: WeChatClient, - ): + ) -> None: super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client @@ -35,7 +35,7 @@ async def send_with_client( client: WeChatClient, message: MessageChain, user_name: str, - ): + ) -> None: pass async def split_plain(self, plain: str) -> list[str]: @@ -84,7 +84,7 @@ async def split_plain(self, plain: str) -> list[str]: return result - async def send(self, message: MessageChain): + async def send(self, message: MessageChain) -> None: message_obj = self.message_obj active_send_mode = cast(dict, message_obj.raw_message).get( "active_send_mode", False diff --git a/astrbot/core/platform_message_history_mgr.py b/astrbot/core/platform_message_history_mgr.py index d6d524698..ad8bb44f6 100644 --- a/astrbot/core/platform_message_history_mgr.py +++ b/astrbot/core/platform_message_history_mgr.py @@ -3,7 +3,7 @@ class PlatformMessageHistoryManager: - def __init__(self, db_helper: BaseDatabase): + def __init__(self, db_helper: BaseDatabase) -> None: self.db = db_helper async def insert( @@ -40,7 +40,9 @@ async def get( history.reverse() return history - async def delete(self, platform_id: str, user_id: str, offset_sec: int = 86400): + async def delete( + self, platform_id: str, user_id: str, offset_sec: int = 86400 + ) -> None: """Delete platform message history records older than the specified offset.""" await self.db.delete_platform_message_offset( platform_id=platform_id, diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index dc188f141..9d0adeb82 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -106,7 +106,7 @@ class ProviderRequest: model: str | None = None """模型名称,为 None 时使用提供商的默认模型""" - def __repr__(self): + def __repr__(self) -> str: return ( f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, " f"image_count={len(self.image_urls or [])}, " @@ -116,10 +116,10 @@ def __repr__(self): f"conversation_id={self.conversation.cid if self.conversation else 'N/A'}, " ) - def __str__(self): + def __str__(self) -> str: return self.__repr__() - def append_tool_calls_result(self, tool_calls_result: ToolCallsResult): + def append_tool_calls_result(self, tool_calls_result: ToolCallsResult) -> None: """添加工具调用结果到请求中""" if not self.tool_calls_result: self.tool_calls_result = [] @@ -241,7 +241,7 @@ def __init__( | AnthropicMessage | None = None, is_chunk: bool = False, - ): + ) -> None: """初始化 LLMResponse Args: @@ -279,7 +279,7 @@ def completion_text(self): return self._completion_text @completion_text.setter - def completion_text(self, value): + def completion_text(self, value) -> None: if self.result_chain: self.result_chain.chain = [ comp diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 7aad86bdd..106b42cc5 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -500,7 +500,7 @@ def load_mcp_config(self): logger.error(f"加载 MCP 配置失败: {e}") return DEFAULT_MCP_CONFIG - def save_mcp_config(self, config: dict): + def save_mcp_config(self, config: dict) -> bool: try: with open(self.mcp_config_path, "w", encoding="utf-8") as f: json.dump(config, f, ensure_ascii=False, indent=4) @@ -575,10 +575,10 @@ async def sync_modelscope_mcp_servers(self, access_token: str) -> None: except Exception as e: raise Exception(f"同步 ModelScope MCP 服务器时发生错误: {e!s}") - def __str__(self): + def __str__(self) -> str: return str(self.func_list) - def __repr__(self): + def __repr__(self) -> str: return str(self.func_list) diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index be8edc282..79b3fa519 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -30,7 +30,7 @@ def __init__( acm: AstrBotConfigManager, db_helper: BaseDatabase, persona_mgr: PersonaManager, - ): + ) -> None: self.reload_lock = asyncio.Lock() self.persona_mgr = persona_mgr self.acm = acm @@ -88,7 +88,7 @@ async def set_provider( provider_id: str, provider_type: ProviderType, umo: str | None = None, - ): + ) -> None: """设置提供商。 Args: @@ -187,7 +187,7 @@ def get_using_provider( raise ValueError(f"Unknown provider type: {provider_type}") return provider - async def initialize(self): + async def initialize(self) -> None: # 逐个初始化提供商 for provider_config in self.providers_config: try: @@ -251,7 +251,7 @@ async def initialize(self): # 初始化 MCP Client 连接 asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients") - async def load_provider(self, provider_config: dict): + async def load_provider(self, provider_config: dict) -> None: if not provider_config["enable"]: logger.info(f"Provider {provider_config['id']} is disabled, skipping") return @@ -491,7 +491,7 @@ async def load_provider(self, provider_config: dict): f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}", ) - async def reload(self, provider_config: dict): + async def reload(self, provider_config: dict) -> None: async with self.reload_lock: await self.terminate_provider(provider_config["id"]) if provider_config["enable"]: @@ -536,7 +536,7 @@ async def reload(self, provider_config: dict): def get_insts(self): return self.provider_insts - async def terminate_provider(self, provider_id: str): + async def terminate_provider(self, provider_id: str) -> None: if provider_id in self.inst_map: logger.info( f"终止 {provider_id} 提供商适配器({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)}) ...", @@ -570,7 +570,7 @@ async def terminate_provider(self, provider_id: str): ) del self.inst_map[provider_id] - async def terminate(self): + async def terminate(self) -> None: for provider_inst in self.provider_insts: if hasattr(provider_inst, "terminate"): await provider_inst.terminate() # type: ignore diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 7f21a2ee1..2979a8609 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -2,7 +2,7 @@ import asyncio import os from collections.abc import AsyncGenerator -from typing import TypeAlias, Union +from typing import NoReturn, TypeAlias, Union from astrbot.core.agent.message import Message from astrbot.core.agent.tool import ToolSet @@ -32,7 +32,7 @@ def __init__(self, provider_config: dict) -> None: self.model_name = "" self.provider_config = provider_config - def set_model(self, model_name: str): + def set_model(self, model_name: str) -> None: """Set the current model name""" self.model_name = model_name @@ -54,7 +54,7 @@ def meta(self) -> ProviderMeta: ) return meta - async def test(self): + async def test(self) -> None: """test the provider is a raises: @@ -84,7 +84,7 @@ def get_keys(self) -> list[str]: return keys or [""] @abc.abstractmethod - def set_key(self, key: str): + def set_key(self, key: str) -> NoReturn: raise NotImplementedError @abc.abstractmethod @@ -155,7 +155,7 @@ async def text_chat_stream( yield None # type: ignore raise NotImplementedError() - async def pop_record(self, context: list): + async def pop_record(self, context: list) -> None: """弹出 context 第一条非系统提示词对话记录""" poped = 0 indexs_to_pop = [] @@ -186,7 +186,7 @@ def _ensure_message_to_dicts( return dicts - async def test(self, timeout: float = 45.0): + async def test(self, timeout: float = 45.0) -> None: await asyncio.wait_for( self.text_chat(prompt="REPLY `PONG` ONLY"), timeout=timeout, @@ -204,7 +204,7 @@ async def get_text(self, audio_url: str) -> str: """获取音频的文本""" raise NotImplementedError - async def test(self): + async def test(self) -> None: sample_audio_path = os.path.join( get_astrbot_path(), "samples", @@ -224,7 +224,7 @@ async def get_audio(self, text: str) -> str: """获取文本的音频,返回音频文件路径""" raise NotImplementedError - async def test(self): + async def test(self) -> None: await self.get_audio("hi") @@ -249,7 +249,7 @@ def get_dim(self) -> int: """获取向量的维度""" ... - async def test(self): + async def test(self) -> None: await self.get_embedding("astrbot") async def get_embeddings_batch( @@ -279,7 +279,7 @@ async def get_embeddings_batch( completed_count = 0 total_count = len(texts) - async def process_batch(batch_idx: int, batch_texts: list[str]): + async def process_batch(batch_idx: int, batch_texts: list[str]) -> None: nonlocal completed_count async with semaphore: for attempt in range(max_retries): @@ -336,7 +336,7 @@ async def rerank( """获取查询和文档的重排序分数""" ... - async def test(self): + async def test(self) -> None: result = await self.rerank("Apple", documents=["apple", "banana"]) if not result: raise Exception("Rerank provider test failed, no results returned") diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index bd0f06fba..bd09ebae7 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -408,5 +408,5 @@ async def get_models(self) -> list[str]: models_str.append(model.id) return models_str - def set_key(self, key: str): + def set_key(self, key: str) -> None: self.chosen_api_key = key diff --git a/astrbot/core/provider/sources/azure_tts_source.py b/astrbot/core/provider/sources/azure_tts_source.py index 2ccf146ca..08180222a 100644 --- a/astrbot/core/provider/sources/azure_tts_source.py +++ b/astrbot/core/provider/sources/azure_tts_source.py @@ -21,7 +21,7 @@ class OTTSProvider: - def __init__(self, config: dict): + def __init__(self, config: dict) -> None: self.skey = config["OTTS_SKEY"] self.api_url = config["OTTS_URL"] self.auth_time_url = config["OTTS_AUTH_TIME"] @@ -48,7 +48,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self._client.aclose() self._client = None - async def _sync_time(self): + async def _sync_time(self) -> None: try: response = await self.client.get(self.auth_time_url) response.raise_for_status() @@ -103,7 +103,7 @@ async def get_audio(self, text: str, voice_params: dict) -> str: class AzureNativeProvider(TTSProvider): - def __init__(self, provider_config: dict, provider_settings: dict): + def __init__(self, provider_config: dict, provider_settings: dict) -> None: super().__init__(provider_config, provider_settings) self.subscription_key = provider_config.get( "azure_tts_subscription_key", @@ -149,7 +149,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): await self._client.aclose() self._client = None - async def _refresh_token(self): + async def _refresh_token(self) -> None: token_url = ( f"https://{self.region}.api.cognitive.microsoft.com/sts/v1.0/issuetoken" ) @@ -195,7 +195,7 @@ async def get_audio(self, text: str) -> str: @register_provider_adapter("azure_tts", "Azure TTS", ProviderType.TEXT_TO_SPEECH) class AzureTTSProvider(TTSProvider): - def __init__(self, provider_config: dict, provider_settings: dict): + def __init__(self, provider_config: dict, provider_settings: dict) -> None: super().__init__(provider_config, provider_settings) key_value = provider_config.get("azure_tts_subscription_key", "") self.provider = self._parse_provider(key_value, provider_config) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index e2efc6aab..18eb8db9f 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -742,7 +742,7 @@ def get_current_key(self) -> str: def get_keys(self) -> list[str]: return self.api_keys - def set_key(self, key): + def set_key(self, key) -> None: self.chosen_api_key = key self._init_client() @@ -782,5 +782,5 @@ async def encode_image_bs64(self, image_url: str) -> str: image_bs64 = base64.b64encode(f.read()).decode("utf-8") return "data:image/jpeg;base64," + image_bs64 - async def terminate(self): + async def terminate(self) -> None: logger.info("Google GenAI 适配器已终止。") diff --git a/astrbot/core/provider/sources/gsv_selfhosted_source.py b/astrbot/core/provider/sources/gsv_selfhosted_source.py index 7f8d39eac..029f6af10 100644 --- a/astrbot/core/provider/sources/gsv_selfhosted_source.py +++ b/astrbot/core/provider/sources/gsv_selfhosted_source.py @@ -39,7 +39,7 @@ def __init__( self.timeout = provider_config.get("timeout", 60) self._session: aiohttp.ClientSession | None = None - async def initialize(self): + async def initialize(self) -> None: """异步初始化:在 ProviderManager 中被调用""" self._session = aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=self.timeout), @@ -85,7 +85,7 @@ async def _make_request( logger.error(f"[GSV TTS] 请求 {endpoint} 最终失败:{e}") raise - async def _set_model_weights(self): + async def _set_model_weights(self) -> None: """设置模型路径""" try: if self.gpt_weights_path: @@ -144,7 +144,7 @@ def build_synthesis_params(self, text: str) -> dict: # TODO: 在此处添加情绪分析,例如 params["emotion"] = detect_emotion(text) return params - async def terminate(self): + async def terminate(self) -> None: """终止释放资源:在 ProviderManager 中被调用""" if self._session and not self._session.closed: await self._session.close() diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 788b649a9..996a5ab24 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -74,7 +74,7 @@ def __init__(self, provider_config, provider_settings) -> None: self.reasoning_key = "reasoning_content" - def _maybe_inject_xai_search(self, payloads: dict, **kwargs): + def _maybe_inject_xai_search(self, payloads: dict, **kwargs) -> None: """当开启 xAI 原生搜索时,向请求体注入 Live Search 参数。 - 仅在 provider_config.xai_native_search 为 True 时生效 @@ -602,7 +602,7 @@ def get_current_key(self) -> str: def get_keys(self) -> list[str]: return self.api_keys - def set_key(self, key): + def set_key(self, key) -> None: self.client.api_key = key async def assemble_context( diff --git a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py index a41bd72fd..965b83a5a 100644 --- a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py +++ b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py @@ -37,7 +37,7 @@ def __init__( self.model = None self.is_emotion = provider_config.get("is_emotion", False) - async def initialize(self): + async def initialize(self) -> None: logger.info("下载或者加载 SenseVoice 模型中,这可能需要一些时间 ...") # 将模型加载放到线程池中执行 @@ -52,7 +52,7 @@ async def get_timestamped_path(self) -> str: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") return os.path.join("data", "temp", f"{timestamp}") - async def _is_silk_file(self, file_path): + async def _is_silk_file(self, file_path) -> bool: silk_header = b"SILK" with open(file_path, "rb") as f: file_header = f.read(8) diff --git a/astrbot/core/provider/sources/whisper_api_source.py b/astrbot/core/provider/sources/whisper_api_source.py index fa69206ef..9532ca5b9 100644 --- a/astrbot/core/provider/sources/whisper_api_source.py +++ b/astrbot/core/provider/sources/whisper_api_source.py @@ -38,7 +38,7 @@ def __init__( self.set_model(provider_config["model"]) - async def _get_audio_format(self, file_path): + async def _get_audio_format(self, file_path) -> str | None: # 定义要检测的头部字节 silk_header = b"SILK" amr_header = b"#!AMR" diff --git a/astrbot/core/provider/sources/whisper_selfhosted_source.py b/astrbot/core/provider/sources/whisper_selfhosted_source.py index a14f93f14..d5d2dc340 100644 --- a/astrbot/core/provider/sources/whisper_selfhosted_source.py +++ b/astrbot/core/provider/sources/whisper_selfhosted_source.py @@ -30,7 +30,7 @@ def __init__( self.set_model(provider_config["model"]) self.model = None - async def initialize(self): + async def initialize(self) -> None: loop = asyncio.get_event_loop() logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...") self.model = await loop.run_in_executor( @@ -40,7 +40,7 @@ async def initialize(self): ) logger.info("Whisper 模型加载完成。") - async def _is_silk_file(self, file_path): + async def _is_silk_file(self, file_path) -> bool: silk_header = b"SILK" with open(file_path, "rb") as f: file_header = f.read(8) diff --git a/astrbot/core/provider/sources/xinference_rerank_source.py b/astrbot/core/provider/sources/xinference_rerank_source.py index 960408550..9c3a77c15 100644 --- a/astrbot/core/provider/sources/xinference_rerank_source.py +++ b/astrbot/core/provider/sources/xinference_rerank_source.py @@ -37,7 +37,7 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: self.model: AsyncRESTfulRerankModelHandle | None = None self.model_uid = None - async def initialize(self): + async def initialize(self) -> None: if self.api_key: logger.info("Xinference Rerank: Using API key for authentication.") self.client = Client(self.base_url, api_key=self.api_key) diff --git a/astrbot/core/provider/sources/xinference_stt_provider.py b/astrbot/core/provider/sources/xinference_stt_provider.py index 9c69a0039..344655afd 100644 --- a/astrbot/core/provider/sources/xinference_stt_provider.py +++ b/astrbot/core/provider/sources/xinference_stt_provider.py @@ -37,7 +37,7 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: self.client = None self.model_uid = None - async def initialize(self): + async def initialize(self) -> None: if self.api_key: logger.info("Xinference STT: Using API key for authentication.") self.client = Client(self.base_url, api_key=self.api_key) diff --git a/astrbot/core/star/__init__.py b/astrbot/core/star/__init__.py index c474962c5..2bf86872e 100644 --- a/astrbot/core/star/__init__.py +++ b/astrbot/core/star/__init__.py @@ -15,7 +15,7 @@ class Star(CommandParserMixin, PluginKVStoreMixin): author: str name: str - def __init__(self, context: Context, config: dict | None = None): + def __init__(self, context: Context, config: dict | None = None) -> None: StarTools.initialize(context) self.context = context @@ -55,13 +55,13 @@ async def html_render( options=options, ) - async def initialize(self): + async def initialize(self) -> None: """当插件被激活时会调用这个方法""" - async def terminate(self): + async def terminate(self) -> None: """当插件被禁用、重载插件时会调用这个方法""" - def __del__(self): + def __del__(self) -> None: """[Deprecated] 当插件被禁用、重载插件时会调用这个方法""" diff --git a/astrbot/core/star/config.py b/astrbot/core/star/config.py index a9af974c5..4068d8439 100644 --- a/astrbot/core/star/config.py +++ b/astrbot/core/star/config.py @@ -22,7 +22,7 @@ def load_config(namespace: str) -> dict | bool: return ret -def put_config(namespace: str, name: str, key: str, value, description: str): +def put_config(namespace: str, name: str, key: str, value, description: str) -> None: """将配置项写入以namespace为名字的配置文件,如果key不存在于目标配置文件中。当前 value 仅支持 str, int, float, bool, list 类型(暂不支持 dict)。 namespace: str, 配置的唯一识别符,也就是配置文件的名字。 name: str, 配置项的显示名字。 @@ -64,7 +64,7 @@ def put_config(namespace: str, name: str, key: str, value, description: str): f.flush() -def update_config(namespace: str, key: str, value): +def update_config(namespace: str, key: str, value) -> None: """更新配置文件中的配置项。 namespace: str, 配置的唯一识别符,也就是配置文件的名字。 key: str, 配置项的键。 diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 9a52ec8bc..4049ea21c 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -65,7 +65,7 @@ def __init__( persona_manager: PersonaManager, astrbot_config_mgr: AstrBotConfigManager, knowledge_base_manager: KnowledgeBaseManager, - ): + ) -> None: self._event_queue = event_queue """事件队列。消息平台通过事件队列传递消息事件。""" self._config = config @@ -398,7 +398,7 @@ def register_web_api( view_handler: Awaitable, methods: list, desc: str, - ): + ) -> None: for idx, api in enumerate(self.registered_web_apis): if api[0] == route and methods == api[2]: self.registered_web_apis[idx] = (route, view_handler, methods, desc) @@ -448,7 +448,7 @@ def get_db(self) -> BaseDatabase: """获取 AstrBot 数据库。""" return self._db - def register_provider(self, provider: Provider): + def register_provider(self, provider: Provider) -> None: """注册一个 LLM Provider(Chat_Completion 类型)。""" self.provider_manager.provider_insts.append(provider) @@ -493,7 +493,7 @@ def register_commands( awaitable: Callable[..., Awaitable[Any]], use_regex=False, ignore_prefix=False, - ): + ) -> None: """注册一个命令。 [Deprecated] 推荐使用装饰器注册指令。该方法将在未来的版本中被移除。 @@ -522,6 +522,6 @@ def register_commands( ) star_handlers_registry.append(md) - def register_task(self, task: Awaitable, desc: str): + def register_task(self, task: Awaitable, desc: str) -> None: """[DEPRECATED]注册一个异步任务。""" self._register_tasks.append(task) diff --git a/astrbot/core/star/filter/command.py b/astrbot/core/star/filter/command.py index 51ad5f089..e27f9d58a 100755 --- a/astrbot/core/star/filter/command.py +++ b/astrbot/core/star/filter/command.py @@ -37,7 +37,7 @@ def __init__( alias: set | None = None, handler_md: StarHandlerMetadata | None = None, parent_command_names: list[str] | None = None, - ): + ) -> None: self.command_name = command_name self.alias = alias if alias else set() self._original_command_name = command_name @@ -63,7 +63,7 @@ def print_types(self): result = "".join(parts).rstrip(",") return result - def init_handler_md(self, handle_md: StarHandlerMetadata): + def init_handler_md(self, handle_md: StarHandlerMetadata) -> None: self.handler_md = handle_md signature = inspect.signature(self.handler_md.handler) self.handler_params = {} # 参数名 -> 参数类型,如果有默认值则为默认值 @@ -81,7 +81,7 @@ def init_handler_md(self, handle_md: StarHandlerMetadata): def get_handler_md(self) -> StarHandlerMetadata: return self.handler_md - def add_custom_filter(self, custom_filter: CustomFilter): + def add_custom_filter(self, custom_filter: CustomFilter) -> None: self.custom_filter_list.append(custom_filter) def custom_filter_ok(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: diff --git a/astrbot/core/star/filter/command_group.py b/astrbot/core/star/filter/command_group.py index 4cbd2c007..52fb6a452 100755 --- a/astrbot/core/star/filter/command_group.py +++ b/astrbot/core/star/filter/command_group.py @@ -15,7 +15,7 @@ def __init__( group_name: str, alias: set | None = None, parent_group: CommandGroupFilter | None = None, - ): + ) -> None: self.group_name = group_name self.alias = alias if alias else set() self._original_group_name = group_name @@ -29,10 +29,10 @@ def __init__( def add_sub_command_filter( self, sub_command_filter: CommandFilter | CommandGroupFilter, - ): + ) -> None: self.sub_command_filters.append(sub_command_filter) - def add_custom_filter(self, custom_filter: CustomFilter): + def add_custom_filter(self, custom_filter: CustomFilter) -> None: self.custom_filter_list.append(custom_filter) def get_complete_command_names(self) -> list[str]: diff --git a/astrbot/core/star/filter/custom_filter.py b/astrbot/core/star/filter/custom_filter.py index d57b5cac0..2e7303d06 100644 --- a/astrbot/core/star/filter/custom_filter.py +++ b/astrbot/core/star/filter/custom_filter.py @@ -19,7 +19,7 @@ def __or__(cls, other): class CustomFilter(HandlerFilter, metaclass=CustomFilterMeta): - def __init__(self, raise_error: bool = True, **kwargs): + def __init__(self, raise_error: bool = True, **kwargs) -> None: self.raise_error = raise_error @abstractmethod @@ -35,7 +35,7 @@ def __and__(self, other): class CustomFilterOr(CustomFilter): - def __init__(self, filter1: CustomFilter, filter2: CustomFilter): + def __init__(self, filter1: CustomFilter, filter2: CustomFilter) -> None: super().__init__() if not isinstance(filter1, (CustomFilter, CustomFilterAnd, CustomFilterOr)): raise ValueError( @@ -49,7 +49,7 @@ def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: class CustomFilterAnd(CustomFilter): - def __init__(self, filter1: CustomFilter, filter2: CustomFilter): + def __init__(self, filter1: CustomFilter, filter2: CustomFilter) -> None: super().__init__() if not isinstance(filter1, (CustomFilter, CustomFilterAnd, CustomFilterOr)): raise ValueError( diff --git a/astrbot/core/star/filter/event_message_type.py b/astrbot/core/star/filter/event_message_type.py index 7f350bd38..604fc3ed3 100644 --- a/astrbot/core/star/filter/event_message_type.py +++ b/astrbot/core/star/filter/event_message_type.py @@ -22,7 +22,7 @@ class EventMessageType(enum.Flag): class EventMessageTypeFilter(HandlerFilter): - def __init__(self, event_message_type: EventMessageType): + def __init__(self, event_message_type: EventMessageType) -> None: self.event_message_type = event_message_type def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool: diff --git a/astrbot/core/star/filter/permission.py b/astrbot/core/star/filter/permission.py index 3374544c2..a70299fa9 100644 --- a/astrbot/core/star/filter/permission.py +++ b/astrbot/core/star/filter/permission.py @@ -14,7 +14,9 @@ class PermissionType(enum.Flag): class PermissionTypeFilter(HandlerFilter): - def __init__(self, permission_type: PermissionType, raise_error: bool = True): + def __init__( + self, permission_type: PermissionType, raise_error: bool = True + ) -> None: self.permission_type = permission_type self.raise_error = raise_error diff --git a/astrbot/core/star/filter/platform_adapter_type.py b/astrbot/core/star/filter/platform_adapter_type.py index 1182ff9b0..d5ff6146a 100644 --- a/astrbot/core/star/filter/platform_adapter_type.py +++ b/astrbot/core/star/filter/platform_adapter_type.py @@ -58,7 +58,7 @@ class PlatformAdapterType(enum.Flag): class PlatformAdapterTypeFilter(HandlerFilter): - def __init__(self, platform_adapter_type_or_str: PlatformAdapterType | str): + def __init__(self, platform_adapter_type_or_str: PlatformAdapterType | str) -> None: if isinstance(platform_adapter_type_or_str, str): self.platform_type = ADAPTER_NAME_2_TYPE.get(platform_adapter_type_or_str) else: diff --git a/astrbot/core/star/filter/regex.py b/astrbot/core/star/filter/regex.py index cd5bebdb4..abec5a488 100644 --- a/astrbot/core/star/filter/regex.py +++ b/astrbot/core/star/filter/regex.py @@ -10,7 +10,7 @@ class RegexFilter(HandlerFilter): """正则表达式过滤器""" - def __init__(self, regex: str): + def __init__(self, regex: str) -> None: self.regex_str = regex self.regex = re.compile(regex) diff --git a/astrbot/core/star/register/star_handler.py b/astrbot/core/star/register/star_handler.py index daf36a8f6..01b0889e2 100644 --- a/astrbot/core/star/register/star_handler.py +++ b/astrbot/core/star/register/star_handler.py @@ -250,7 +250,7 @@ class RegisteringCommandable: command: Callable[..., Callable[..., None]] = register_command custom_filter: Callable[..., Callable[..., Any]] = register_custom_filter - def __init__(self, parent_group: CommandGroupFilter): + def __init__(self, parent_group: CommandGroupFilter) -> None: self.parent_group = parent_group @@ -492,7 +492,7 @@ def llm_tool(self, *args, **kwargs): kwargs["registering_agent"] = self return register_llm_tool(*args, **kwargs) - def __init__(self, agent: Agent[AstrAgentContext]): + def __init__(self, agent: Agent[AstrAgentContext]) -> None: self._agent = agent diff --git a/astrbot/core/star/star_handler.py b/astrbot/core/star/star_handler.py index be5b4679f..261c5234d 100644 --- a/astrbot/core/star/star_handler.py +++ b/astrbot/core/star/star_handler.py @@ -12,11 +12,11 @@ class StarHandlerRegistry(Generic[T]): - def __init__(self): + def __init__(self) -> None: self.star_handlers_map: dict[str, StarHandlerMetadata] = {} self._handlers: list[StarHandlerMetadata] = [] - def append(self, handler: StarHandlerMetadata): + def append(self, handler: StarHandlerMetadata) -> None: """添加一个 Handler,并保持按优先级有序""" if "priority" not in handler.extras_configs: handler.extras_configs["priority"] = 0 @@ -25,7 +25,7 @@ def append(self, handler: StarHandlerMetadata): self._handlers.append(handler) self._handlers.sort(key=lambda h: -h.extras_configs["priority"]) - def _print_handlers(self): + def _print_handlers(self) -> None: for handler in self._handlers: print(handler.handler_full_name) @@ -156,18 +156,18 @@ def get_handlers_by_module_name( if handler.handler_module_path == module_name ] - def clear(self): + def clear(self) -> None: self.star_handlers_map.clear() self._handlers.clear() - def remove(self, handler: StarHandlerMetadata): + def remove(self, handler: StarHandlerMetadata) -> None: self.star_handlers_map.pop(handler.handler_full_name, None) self._handlers = [h for h in self._handlers if h != handler] def __iter__(self): return iter(self._handlers) - def __len__(self): + def __len__(self) -> int: return len(self._handlers) diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 1f9f95ae5..552fb3acd 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -38,7 +38,7 @@ class PluginManager: - def __init__(self, context: Context, config: AstrBotConfig): + def __init__(self, context: Context, config: AstrBotConfig) -> None: self.updator = PluginUpdator() self.context = context @@ -66,7 +66,7 @@ def __init__(self, context: Context, config: AstrBotConfig): if os.getenv("ASTRBOT_RELOAD", "0") == "1": asyncio.create_task(self._watch_plugins_changes()) - async def _watch_plugins_changes(self): + async def _watch_plugins_changes(self) -> None: """监视插件文件变化""" try: async for changes in awatch( @@ -83,7 +83,7 @@ async def _watch_plugins_changes(self): logger.error(f"插件热重载监视任务异常: {e!s}") logger.error(traceback.format_exc()) - async def _handle_file_changes(self, changes): + async def _handle_file_changes(self, changes) -> None: """处理文件变化""" logger.info(f"检测到文件变化: {changes}") plugins_to_check = [] @@ -167,7 +167,9 @@ def _get_plugin_modules(self) -> list[dict]: plugins.extend(_p) return plugins - async def _check_plugin_dept_update(self, target_plugin: str | None = None): + async def _check_plugin_dept_update( + self, target_plugin: str | None = None + ) -> bool | None: """检查插件的依赖 如果 target_plugin 为 None,则检查所有插件的依赖 """ @@ -264,7 +266,7 @@ def _purge_modules( module_patterns: list[str] | None = None, root_dir_name: str | None = None, is_reserved: bool = False, - ): + ) -> None: """从 sys.modules 中移除指定的模块 可以基于模块名模式或插件目录名移除模块,用于清理插件相关的模块缓存 @@ -699,7 +701,7 @@ async def uninstall_plugin( plugin_name: str, delete_config: bool = False, delete_data: bool = False, - ): + ) -> None: """卸载指定的插件。 Args: @@ -788,7 +790,7 @@ async def uninstall_plugin( except Exception as e: logger.warning(f"删除插件持久化数据失败 (plugins_data): {e!s}") - async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str): + async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str) -> None: """解绑并移除一个插件。 Args: @@ -839,7 +841,7 @@ async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str): is_reserved=plugin.reserved, ) - async def update_plugin(self, plugin_name: str, proxy=""): + async def update_plugin(self, plugin_name: str, proxy="") -> None: """升级一个插件""" plugin = self.context.get_registered_star(plugin_name) if not plugin: @@ -850,7 +852,7 @@ async def update_plugin(self, plugin_name: str, proxy=""): await self.updator.update(plugin, proxy=proxy) await self.reload(plugin_name) - async def turn_off_plugin(self, plugin_name: str): + async def turn_off_plugin(self, plugin_name: str) -> None: """禁用一个插件。 调用插件的 terminate() 方法, 将插件的 module_path 加入到 data/shared_preferences.json 的 inactivated_plugins 列表中。 @@ -892,7 +894,7 @@ async def turn_off_plugin(self, plugin_name: str): plugin.activated = False @staticmethod - async def _terminate_plugin(star_metadata: StarMetadata): + async def _terminate_plugin(star_metadata: StarMetadata) -> None: """终止插件,调用插件的 terminate() 和 __del__() 方法""" logger.info(f"正在终止插件 {star_metadata.name} ...") @@ -912,7 +914,7 @@ async def _terminate_plugin(star_metadata: StarMetadata): elif "terminate" in star_metadata.star_cls_type.__dict__: await star_metadata.star_cls.terminate() - async def turn_on_plugin(self, plugin_name: str): + async def turn_on_plugin(self, plugin_name: str) -> None: plugin = self.context.get_registered_star(plugin_name) if plugin is None: raise Exception(f"插件 {plugin_name} 不存在。") diff --git a/astrbot/core/star/star_tools.py b/astrbot/core/star/star_tools.py index 7a66449b4..4d85131fc 100644 --- a/astrbot/core/star/star_tools.py +++ b/astrbot/core/star/star_tools.py @@ -89,7 +89,7 @@ async def send_message_by_id( id: str, message_chain: MessageChain, platform: str = "aiocqhttp", - ): + ) -> None: """根据 id(例如qq号, 群号等) 直接, 主动地发送消息 Args: diff --git a/astrbot/core/star/updator.py b/astrbot/core/star/updator.py index 8793ad505..1a0c5fc26 100644 --- a/astrbot/core/star/updator.py +++ b/astrbot/core/star/updator.py @@ -52,7 +52,7 @@ async def update(self, plugin: StarMetadata, proxy="") -> str: return plugin_path - def unzip_file(self, zip_path: str, target_dir: str): + def unzip_file(self, zip_path: str, target_dir: str) -> None: os.makedirs(target_dir, exist_ok=True) update_dir = "" logger.info(f"正在解压压缩包: {zip_path}") diff --git a/astrbot/core/umop_config_router.py b/astrbot/core/umop_config_router.py index 27f6232aa..dad17815f 100644 --- a/astrbot/core/umop_config_router.py +++ b/astrbot/core/umop_config_router.py @@ -4,14 +4,14 @@ class UmopConfigRouter: """UMOP 配置路由器""" - def __init__(self, sp: SharedPreferences): + def __init__(self, sp: SharedPreferences) -> None: self.umop_to_conf_id: dict[str, str] = {} """UMOP 到配置文件 ID 的映射""" self.sp = sp self._load_routing_table() - def _load_routing_table(self): + def _load_routing_table(self) -> None: """加载路由表""" # 从 SharedPreferences 中加载 umop_to_conf_id 映射 sp_data = self.sp.get( @@ -47,7 +47,7 @@ def get_conf_id_for_umop(self, umo: str) -> str | None: return conf_id return None - async def update_routing_data(self, new_routing: dict[str, str]): + async def update_routing_data(self, new_routing: dict[str, str]) -> None: """更新路由表 Args: @@ -67,7 +67,7 @@ async def update_routing_data(self, new_routing: dict[str, str]): self.umop_to_conf_id = new_routing await self.sp.global_put("umop_config_routing", self.umop_to_conf_id) - async def update_route(self, umo: str, conf_id: str): + async def update_route(self, umo: str, conf_id: str) -> None: """更新一条路由 Args: @@ -86,7 +86,7 @@ async def update_route(self, umo: str, conf_id: str): self.umop_to_conf_id[umo] = conf_id await self.sp.global_put("umop_config_routing", self.umop_to_conf_id) - async def delete_route(self, umo: str): + async def delete_route(self, umo: str) -> None: """删除一条路由 Args: diff --git a/astrbot/core/updator.py b/astrbot/core/updator.py index 0a7116a0d..9e4a0be84 100644 --- a/astrbot/core/updator.py +++ b/astrbot/core/updator.py @@ -23,7 +23,7 @@ def __init__(self, repo_mirror: str = "") -> None: self.MAIN_PATH = get_astrbot_path() self.ASTRBOT_RELEASE_API = "https://api.soulter.top/releases" - def terminate_child_processes(self): + def terminate_child_processes(self) -> None: """终止当前进程的所有子进程 使用 psutil 库获取当前进程的所有子进程,并尝试终止它们 """ @@ -44,7 +44,7 @@ def terminate_child_processes(self): except psutil.NoSuchProcess: pass - def _reboot(self, delay: int = 3): + def _reboot(self, delay: int = 3) -> None: """重启当前程序 在指定的延迟后,终止所有子进程并重新启动程序 这里只能使用 os.exec* 来重启程序 @@ -85,7 +85,7 @@ async def check_update( async def get_releases(self) -> list: return await self.fetch_release_info(self.ASTRBOT_RELEASE_API) - async def update(self, reboot=False, latest=True, version=None, proxy=""): + async def update(self, reboot=False, latest=True, version=None, proxy="") -> None: update_data = await self.fetch_release_info(self.ASTRBOT_RELEASE_API, latest) file_url = None diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index fcf5bb3c7..7aef3172d 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -19,7 +19,7 @@ logger = logging.getLogger("astrbot") -def on_error(func, path, exc_info): +def on_error(func, path, exc_info) -> None: """A callback of the rmtree function.""" import stat @@ -37,7 +37,7 @@ def remove_dir(file_path: str) -> bool: return True -def port_checker(port: int, host: str = "localhost"): +def port_checker(port: int, host: str = "localhost") -> bool | None: sk = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sk.settimeout(1) try: @@ -134,7 +134,7 @@ async def download_image_by_url( raise e -async def download_file(url: str, path: str, show_progress: bool = False): +async def download_file(url: str, path: str, show_progress: bool = False) -> None: """从指定 url 下载文件到指定路径 path""" try: ssl_context = ssl.create_default_context( diff --git a/astrbot/core/utils/log_pipe.py b/astrbot/core/utils/log_pipe.py index 2e931dd81..6f40f0942 100644 --- a/astrbot/core/utils/log_pipe.py +++ b/astrbot/core/utils/log_pipe.py @@ -10,7 +10,7 @@ def __init__( logger: Logger, identifier=None, callback=None, - ): + ) -> None: threading.Thread.__init__(self) self.daemon = True self.level = level @@ -24,7 +24,7 @@ def __init__( def fileno(self): return self.fd_write - def run(self): + def run(self) -> None: for line in iter(self.reader.readline, ""): if self.callback: self.callback(line.strip()) @@ -32,5 +32,5 @@ def run(self): self.reader.close() - def close(self): + def close(self) -> None: os.close(self.fd_write) diff --git a/astrbot/core/utils/metrics.py b/astrbot/core/utils/metrics.py index f12019e3c..06acdba4b 100644 --- a/astrbot/core/utils/metrics.py +++ b/astrbot/core/utils/metrics.py @@ -40,7 +40,7 @@ def get_installation_id(): return "null" @staticmethod - async def upload(**kwargs): + async def upload(**kwargs) -> None: """上传相关非敏感的指标以更好地了解 AstrBot 的使用情况。上传的指标不会包含任何有关消息文本、用户信息等敏感信息。 Powered by TickStats. diff --git a/astrbot/core/utils/pip_installer.py b/astrbot/core/utils/pip_installer.py index 6076a114a..689a1dc86 100644 --- a/astrbot/core/utils/pip_installer.py +++ b/astrbot/core/utils/pip_installer.py @@ -6,7 +6,7 @@ class PipInstaller: - def __init__(self, pip_install_arg: str, pypi_index_url: str | None = None): + def __init__(self, pip_install_arg: str, pypi_index_url: str | None = None) -> None: self.pip_install_arg = pip_install_arg self.pypi_index_url = pypi_index_url @@ -15,7 +15,7 @@ async def install( package_name: str | None = None, requirements_path: str | None = None, mirror: str | None = None, - ): + ) -> None: args = ["install"] if package_name: args.append(package_name) diff --git a/astrbot/core/utils/session_lock.py b/astrbot/core/utils/session_lock.py index 912d91e53..7810d6ce4 100644 --- a/astrbot/core/utils/session_lock.py +++ b/astrbot/core/utils/session_lock.py @@ -4,7 +4,7 @@ class SessionLockManager: - def __init__(self): + def __init__(self) -> None: self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) self._lock_count: dict[str, int] = defaultdict(int) self._access_lock = asyncio.Lock() diff --git a/astrbot/core/utils/session_waiter.py b/astrbot/core/utils/session_waiter.py index e1f2fbef7..b327a6184 100644 --- a/astrbot/core/utils/session_waiter.py +++ b/astrbot/core/utils/session_waiter.py @@ -18,7 +18,7 @@ class SessionController: """控制一个 Session 是否已经结束""" - def __init__(self): + def __init__(self) -> None: self.future = asyncio.Future() self.current_event: asyncio.Event | None = None """当前正在等待的所用的异步事件""" @@ -29,7 +29,7 @@ def __init__(self): self.history_chains: list[list[Comp.BaseMessageComponent]] = [] - def stop(self, error: Exception | None = None): + def stop(self, error: Exception | None = None) -> None: """立即结束这个会话""" if not self.future.done(): if error: @@ -37,7 +37,7 @@ def stop(self, error: Exception | None = None): else: self.future.set_result(None) - def keep(self, timeout: float = 0, reset_timeout=False): + def keep(self, timeout: float = 0, reset_timeout=False) -> None: """保持这个会话 Args: @@ -71,7 +71,7 @@ def keep(self, timeout: float = 0, reset_timeout=False): asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep - async def _holding(self, event: asyncio.Event, timeout: float): + async def _holding(self, event: asyncio.Event, timeout: float) -> None: """等待事件结束或超时""" try: await asyncio.wait_for(event.wait(), timeout) @@ -107,7 +107,7 @@ def __init__( session_filter: SessionFilter, session_id: str, record_history_chains: bool, - ): + ) -> None: self.session_id = session_id self.session_filter = session_filter self.handler: ( @@ -141,7 +141,7 @@ async def register_wait( finally: self._cleanup() - def _cleanup(self, error: Exception | None = None): + def _cleanup(self, error: Exception | None = None) -> None: """清理会话""" USER_SESSIONS.pop(self.session_id, None) try: @@ -151,7 +151,7 @@ def _cleanup(self, error: Exception | None = None): self.session_controller.stop(error) @classmethod - async def trigger(cls, session_id: str, event: AstrMessageEvent): + async def trigger(cls, session_id: str, event: AstrMessageEvent) -> None: """外部输入触发会话处理""" session = USER_SESSIONS.get(session_id) if not session or session.session_controller.future.done(): diff --git a/astrbot/core/utils/shared_preferences.py b/astrbot/core/utils/shared_preferences.py index ccd394ee4..a4f69100a 100644 --- a/astrbot/core/utils/shared_preferences.py +++ b/astrbot/core/utils/shared_preferences.py @@ -12,7 +12,7 @@ class SharedPreferences: - def __init__(self, db_helper: BaseDatabase, json_storage_path=None): + def __init__(self, db_helper: BaseDatabase, json_storage_path=None) -> None: if json_storage_path is None: json_storage_path = os.path.join( get_astrbot_data_path(), @@ -118,7 +118,7 @@ async def global_get( return await self.range_get_async("global", "global", key) return await self.get_async("global", "global", key, default) - async def put_async(self, scope: str, scope_id: str, key: str, value: Any): + async def put_async(self, scope: str, scope_id: str, key: str, value: Any) -> None: """设置指定范围和键的偏好设置""" await self.db_helper.insert_preference_or_update( scope, @@ -127,24 +127,24 @@ async def put_async(self, scope: str, scope_id: str, key: str, value: Any): {"val": value}, ) - async def session_put(self, umo: str, key: str, value: Any): + async def session_put(self, umo: str, key: str, value: Any) -> None: await self.put_async("umo", umo, key, value) - async def global_put(self, key: str, value: Any): + async def global_put(self, key: str, value: Any) -> None: await self.put_async("global", "global", key, value) - async def remove_async(self, scope: str, scope_id: str, key: str): + async def remove_async(self, scope: str, scope_id: str, key: str) -> None: """删除指定范围和键的偏好设置""" await self.db_helper.remove_preference(scope, scope_id, key) - async def session_remove(self, umo: str, key: str): + async def session_remove(self, umo: str, key: str) -> None: await self.remove_async("umo", umo, key) - async def global_remove(self, key: str): + async def global_remove(self, key: str) -> None: """删除全局偏好设置""" await self.remove_async("global", "global", key) - async def clear_async(self, scope: str, scope_id: str): + async def clear_async(self, scope: str, scope_id: str) -> None: """清空指定范围的所有偏好设置""" await self.db_helper.clear_preferences(scope, scope_id) @@ -188,21 +188,25 @@ def range_get( return result - def put(self, key, value, scope: str | None = None, scope_id: str | None = None): + def put( + self, key, value, scope: str | None = None, scope_id: str | None = None + ) -> None: """设置偏好设置(已弃用)""" asyncio.run_coroutine_threadsafe( self.put_async(scope or "unknown", scope_id or "unknown", key, value), self._sync_loop, ).result() - def remove(self, key, scope: str | None = None, scope_id: str | None = None): + def remove( + self, key, scope: str | None = None, scope_id: str | None = None + ) -> None: """删除偏好设置(已弃用)""" asyncio.run_coroutine_threadsafe( self.remove_async(scope or "unknown", scope_id or "unknown", key), self._sync_loop, ).result() - def clear(self, scope: str | None = None, scope_id: str | None = None): + def clear(self, scope: str | None = None, scope_id: str | None = None) -> None: """清空偏好设置(已弃用)""" asyncio.run_coroutine_threadsafe( self.clear_async(scope or "unknown", scope_id or "unknown"), diff --git a/astrbot/core/utils/t2i/network_strategy.py b/astrbot/core/utils/t2i/network_strategy.py index 7ebba5669..2abb22917 100644 --- a/astrbot/core/utils/t2i/network_strategy.py +++ b/astrbot/core/utils/t2i/network_strategy.py @@ -28,7 +28,7 @@ def __init__(self, base_url: str | None = None) -> None: self.endpoints = [self.BASE_RENDER_URL] self.template_manager = TemplateManager() - async def initialize(self): + async def initialize(self) -> None: if self.BASE_RENDER_URL == ASTRBOT_T2I_DEFAULT_ENDPOINT: asyncio.create_task(self.get_official_endpoints()) @@ -36,7 +36,7 @@ async def get_template(self, name: str = "base") -> str: """通过名称获取文转图 HTML 模板""" return self.template_manager.get_template(name) - async def get_official_endpoints(self): + async def get_official_endpoints(self) -> None: """获取官方的 t2i 端点列表。""" try: async with aiohttp.ClientSession() as session: diff --git a/astrbot/core/utils/t2i/renderer.py b/astrbot/core/utils/t2i/renderer.py index 2ce7a5ebf..e3118d7e8 100644 --- a/astrbot/core/utils/t2i/renderer.py +++ b/astrbot/core/utils/t2i/renderer.py @@ -7,11 +7,11 @@ class HtmlRenderer: - def __init__(self, endpoint_url: str | None = None): + def __init__(self, endpoint_url: str | None = None) -> None: self.network_strategy = NetworkRenderStrategy(endpoint_url) self.local_strategy = LocalRenderStrategy() - async def initialize(self): + async def initialize(self) -> None: await self.network_strategy.initialize() async def render_custom_template( diff --git a/astrbot/core/utils/t2i/template_manager.py b/astrbot/core/utils/t2i/template_manager.py index 6d44f735b..b3eb0c9ff 100644 --- a/astrbot/core/utils/t2i/template_manager.py +++ b/astrbot/core/utils/t2i/template_manager.py @@ -14,7 +14,7 @@ class TemplateManager: CORE_TEMPLATES = ["base.html", "astrbot_powershell.html"] - def __init__(self): + def __init__(self) -> None: self.builtin_template_dir = os.path.join( get_astrbot_path(), "astrbot", @@ -28,7 +28,7 @@ def __init__(self): os.makedirs(self.user_template_dir, exist_ok=True) self._initialize_user_templates() - def _copy_core_templates(self, overwrite: bool = False): + def _copy_core_templates(self, overwrite: bool = False) -> None: """从内置目录复制核心模板到用户目录。""" for filename in self.CORE_TEMPLATES: src = os.path.join(self.builtin_template_dir, filename) @@ -36,7 +36,7 @@ def _copy_core_templates(self, overwrite: bool = False): if os.path.exists(src) and (overwrite or not os.path.exists(dst)): shutil.copyfile(src, dst) - def _initialize_user_templates(self): + def _initialize_user_templates(self) -> None: """如果用户目录下缺少核心模板,则进行复制。""" self._copy_core_templates(overwrite=False) @@ -80,7 +80,7 @@ def get_template(self, name: str) -> str: raise FileNotFoundError("模板不存在。") - def create_template(self, name: str, content: str): + def create_template(self, name: str, content: str) -> None: """在用户目录中创建一个新的模板文件。""" path = self._get_user_template_path(name) if os.path.exists(path): @@ -88,7 +88,7 @@ def create_template(self, name: str, content: str): with open(path, "w", encoding="utf-8") as f: f.write(content) - def update_template(self, name: str, content: str): + def update_template(self, name: str, content: str) -> None: """更新一个模板。此操作始终写入用户目录。 如果更新的是一个内置模板,此操作实际上会在用户目录中创建一个修改后的副本, 从而实现对内置模板的“覆盖”。 @@ -97,7 +97,7 @@ def update_template(self, name: str, content: str): with open(path, "w", encoding="utf-8") as f: f.write(content) - def delete_template(self, name: str): + def delete_template(self, name: str) -> None: """仅删除用户目录中的模板文件。 如果删除的是一个覆盖了内置模板的用户模板,这将有效地“恢复”到内置版本。 """ @@ -106,6 +106,6 @@ def delete_template(self, name: str): raise FileNotFoundError("用户模板不存在,无法删除。") os.remove(path) - def reset_default_template(self): + def reset_default_template(self) -> None: """将核心模板从内置目录强制重置到用户目录。""" self._copy_core_templates(overwrite=True) diff --git a/astrbot/core/utils/webhook_utils.py b/astrbot/core/utils/webhook_utils.py index 0e1c3f9cd..07abc115a 100644 --- a/astrbot/core/utils/webhook_utils.py +++ b/astrbot/core/utils/webhook_utils.py @@ -20,7 +20,7 @@ def _get_dashboard_port() -> int: return 6185 -def log_webhook_info(platform_name: str, webhook_uuid: str): +def log_webhook_info(platform_name: str, webhook_uuid: str) -> None: """打印美观的 webhook 信息日志 Args: diff --git a/astrbot/core/zip_updator.py b/astrbot/core/zip_updator.py index 728dfdabb..6cea6b38d 100644 --- a/astrbot/core/zip_updator.py +++ b/astrbot/core/zip_updator.py @@ -3,6 +3,7 @@ import shutil import ssl import zipfile +from typing import NoReturn import aiohttp import certifi @@ -101,10 +102,10 @@ def github_api_release_parser(self, releases: list) -> list: ) return ret - def unzip(self): + def unzip(self) -> NoReturn: raise NotImplementedError - async def update(self): + async def update(self) -> NoReturn: raise NotImplementedError def compare_version(self, v1: str, v2: str) -> int: @@ -148,7 +149,9 @@ async def check_update( body=f"{tag_name}\n\n{sel_release_data['body']}", ) - async def download_from_repo_url(self, target_path: str, repo_url: str, proxy=""): + async def download_from_repo_url( + self, target_path: str, repo_url: str, proxy="" + ) -> None: author, repo, branch = self.parse_github_url(repo_url) logger.info(f"正在下载更新 {repo} ...") @@ -203,7 +206,7 @@ def parse_github_url(self, url: str): return author, repo, branch raise ValueError("无效的 GitHub URL") - def unzip_file(self, zip_path: str, target_dir: str): + def unzip_file(self, zip_path: str, target_dir: str) -> None: """解压缩文件, 并将压缩包内**第一个**文件夹内的文件移动到 target_dir""" os.makedirs(target_dir, exist_ok=True) update_dir = "" diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index f2439c058..38dc27f40 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -512,7 +512,7 @@ def _extract_attachment_ids(self, history_list) -> list[str]: attachment_ids.append(part["attachment_id"]) return attachment_ids - async def _delete_attachments(self, attachment_ids: list[str]): + async def _delete_attachments(self, attachment_ids: list[str]) -> None: """删除附件(包括数据库记录和磁盘文件)""" try: attachments = await self.db.get_attachments(attachment_ids) diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 0edbe8377..116b587ce 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -48,7 +48,7 @@ def try_cast(value: Any, type_: str): def validate_config(data, schema: dict, is_core: bool) -> tuple[list[str], dict]: errors = [] - def validate(data: dict, metadata: dict = schema, path=""): + def validate(data: dict, metadata: dict = schema, path="") -> None: for key, value in data.items(): if key not in metadata: continue @@ -121,7 +121,9 @@ def validate(data: dict, metadata: dict = schema, path=""): return errors, data -def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False): +def save_config( + post_config: dict, config: AstrBotConfig, is_core: bool = False +) -> None: """验证并保存配置""" errors = None logger.info(f"Saving config, is_core={is_core}") @@ -666,7 +668,7 @@ async def get_llm_tools(self): tools = tool_mgr.get_func_desc_openai_style() return Response().ok(tools).__dict__ - async def _register_platform_logo(self, platform, platform_default_tmpl): + async def _register_platform_logo(self, platform, platform_default_tmpl) -> None: """注册平台logo文件并生成访问令牌""" if not platform.logo_path: return @@ -789,7 +791,7 @@ async def _get_plugin_config(self, plugin_name: str): async def _save_astrbot_configs( self, post_configs: dict, conf_id: str | None = None - ): + ) -> None: try: if conf_id not in self.acm.confs: raise ValueError(f"配置文件 {conf_id} 不存在") @@ -805,7 +807,7 @@ async def _save_astrbot_configs( except Exception as e: raise e - async def _save_plugin_configs(self, post_configs: dict, plugin_name: str): + async def _save_plugin_configs(self, post_configs: dict, plugin_name: str) -> None: md = None for plugin_md in star_registry: if plugin_md.name == plugin_name: diff --git a/astrbot/dashboard/routes/knowledge_base.py b/astrbot/dashboard/routes/knowledge_base.py index 537a81f0b..72ffd5fd3 100644 --- a/astrbot/dashboard/routes/knowledge_base.py +++ b/astrbot/dashboard/routes/knowledge_base.py @@ -4,6 +4,7 @@ import os import traceback import uuid +from typing import Any import aiofiles from quart import request @@ -75,7 +76,7 @@ def _init_task(self, task_id: str, status: str = "pending") -> None: } def _set_task_result( - self, task_id: str, status: str, result: any = None, error: str | None = None + self, task_id: str, status: str, result: Any = None, error: str | None = None ) -> None: self.upload_tasks[task_id] = { "status": status, @@ -136,7 +137,7 @@ async def _background_upload_task( batch_size: int, tasks_limit: int, max_retries: int, - ): + ) -> None: """后台上传任务""" try: # 初始化任务状态 @@ -1214,7 +1215,7 @@ async def _background_upload_from_url_task( max_retries: int, enable_cleaning: bool, cleaning_provider_id: str | None, - ): + ) -> None: """后台上传URL任务""" try: # 初始化任务状态 diff --git a/astrbot/dashboard/routes/platform.py b/astrbot/dashboard/routes/platform.py index 4d8fdddfe..874bc19db 100644 --- a/astrbot/dashboard/routes/platform.py +++ b/astrbot/dashboard/routes/platform.py @@ -26,7 +26,7 @@ def __init__( self._register_webhook_routes() - def _register_webhook_routes(self): + def _register_webhook_routes(self) -> None: """注册 webhook 路由""" # 统一 webhook 入口,支持 GET 和 POST self.app.add_url_rule( diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py index fd808c6c9..ebdfe5bf6 100644 --- a/astrbot/dashboard/routes/plugin.py +++ b/astrbot/dashboard/routes/plugin.py @@ -260,7 +260,7 @@ def _load_plugin_cache(self, cache_file: str): logger.warning(f"加载插件市场缓存失败: {e}") return None - def _save_plugin_cache(self, cache_file: str, data, md5: str | None = None): + def _save_plugin_cache(self, cache_file: str, data, md5: str | None = None) -> None: """保存插件市场数据到本地缓存""" try: # 确保目录存在 diff --git a/astrbot/dashboard/routes/route.py b/astrbot/dashboard/routes/route.py index 01ab292d4..53c623443 100644 --- a/astrbot/dashboard/routes/route.py +++ b/astrbot/dashboard/routes/route.py @@ -14,12 +14,12 @@ class RouteContext: class Route: routes: list | dict - def __init__(self, context: RouteContext): + def __init__(self, context: RouteContext) -> None: self.app = context.app self.config = context.config - def register_routes(self): - def _add_rule(path, method, func): + def register_routes(self) -> None: + def _add_rule(path, method, func) -> None: # 统一添加 /api 前缀 full_path = f"/api{path}" self.app.add_url_rule(full_path, view_func=func, methods=[method]) diff --git a/astrbot/dashboard/routes/static_file.py b/astrbot/dashboard/routes/static_file.py index 3d3d0ca51..e056b6c5a 100644 --- a/astrbot/dashboard/routes/static_file.py +++ b/astrbot/dashboard/routes/static_file.py @@ -30,7 +30,7 @@ def __init__(self, context: RouteContext) -> None: self.app.add_url_rule(i, view_func=self.index) @self.app.errorhandler(404) - async def page_not_found(e): + async def page_not_found(e) -> str: return "404 Not found。如果你初次使用打开面板发现 404, 请参考文档: https://astrbot.app/faq.html。如果你正在测试回调地址可达性,显示这段文字说明测试成功了。" async def index(self): diff --git a/astrbot/dashboard/routes/t2i.py b/astrbot/dashboard/routes/t2i.py index db70a8820..8d06826be 100644 --- a/astrbot/dashboard/routes/t2i.py +++ b/astrbot/dashboard/routes/t2i.py @@ -12,7 +12,9 @@ class T2iRoute(Route): - def __init__(self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle): + def __init__( + self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle + ) -> None: super().__init__(context) self.core_lifecycle = core_lifecycle self.config = core_lifecycle.astrbot_config diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 6d6530c90..af176bea8 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -169,7 +169,7 @@ def get_process_using_port(self, port: int) -> str: except Exception as e: return f"获取进程信息失败: {e!s}" - def _init_jwt_secret(self): + def _init_jwt_secret(self) -> None: if not self.config.get("dashboard", {}).get("jwt_secret", None): # 如果没有设置 JWT 密钥,则生成一个新的密钥 jwt_secret = os.urandom(32).hex() @@ -239,6 +239,6 @@ def run(self): shutdown_trigger=self.shutdown_trigger, ) - async def shutdown_trigger(self): + async def shutdown_trigger(self) -> None: await self.shutdown_event.wait() logger.info("AstrBot WebUI 已经被优雅地关闭") diff --git a/main.py b/main.py index 60879f065..339e3a728 100644 --- a/main.py +++ b/main.py @@ -25,7 +25,7 @@ """ -def check_env(): +def check_env() -> None: if not (sys.version_info.major == 3 and sys.version_info.minor >= 10): logger.error("请使用 Python3.10+ 运行本项目。") exit() diff --git a/packages/astrbot/long_term_memory.py b/packages/astrbot/long_term_memory.py index 610995db2..e08cdc515 100644 --- a/packages/astrbot/long_term_memory.py +++ b/packages/astrbot/long_term_memory.py @@ -17,7 +17,7 @@ class LongTermMemory: - def __init__(self, acm: AstrBotConfigManager, context: star.Context): + def __init__(self, acm: AstrBotConfigManager, context: star.Context) -> None: self.acm = acm self.context = context self.session_chats = defaultdict(list) @@ -111,7 +111,7 @@ async def need_active_reply(self, event: AstrMessageEvent) -> bool: return False - async def handle_message(self, event: AstrMessageEvent): + async def handle_message(self, event: AstrMessageEvent) -> None: """仅支持群聊""" if event.get_message_type() == MessageType.GROUP_MESSAGE: datetime_str = datetime.datetime.now().strftime("%H:%M:%S") @@ -148,7 +148,7 @@ async def handle_message(self, event: AstrMessageEvent): if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]: self.session_chats[event.unified_msg_origin].pop(0) - async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest): + async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest) -> None: """当触发 LLM 请求前,调用此方法修改 req""" if event.unified_msg_origin not in self.session_chats: return @@ -171,7 +171,9 @@ async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest): ) req.system_prompt += chats_str - async def after_req_llm(self, event: AstrMessageEvent, llm_resp: LLMResponse): + async def after_req_llm( + self, event: AstrMessageEvent, llm_resp: LLMResponse + ) -> None: if event.unified_msg_origin not in self.session_chats: return diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index 09859ab95..18a3f05cd 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -89,7 +89,9 @@ async def on_message(self, event: AstrMessageEvent): logger.error(f"主动回复失败: {e}") @filter.on_llm_request() - async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest): + async def decorate_llm_req( + self, event: AstrMessageEvent, req: ProviderRequest + ) -> None: """在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt""" await self.proc_llm_req.process_llm_request(event, req) @@ -100,7 +102,9 @@ async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest): logger.error(f"ltm: {e}") @filter.on_llm_response() - async def inject_reasoning(self, event: AstrMessageEvent, resp: LLMResponse): + async def inject_reasoning( + self, event: AstrMessageEvent, resp: LLMResponse + ) -> None: """在 LLM 响应后基于配置注入思考过程文本 / 在 LLM 响应后记录对话""" umo = event.unified_msg_origin cfg = self.context.get_config(umo).get("provider_settings", {}) @@ -117,7 +121,7 @@ async def inject_reasoning(self, event: AstrMessageEvent, resp: LLMResponse): logger.error(f"ltm: {e}") @filter.after_message_sent() - async def after_message_sent(self, event: AstrMessageEvent): + async def after_message_sent(self, event: AstrMessageEvent) -> None: """消息发送后处理""" if self.ltm and self.ltm_enabled(event): try: diff --git a/packages/astrbot/process_llm_request.py b/packages/astrbot/process_llm_request.py index 28c41df9f..cec405926 100644 --- a/packages/astrbot/process_llm_request.py +++ b/packages/astrbot/process_llm_request.py @@ -11,7 +11,7 @@ class ProcessLLMRequest: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.ctx = context cfg = context.get_config() self.timezone = cfg.get("timezone") @@ -21,7 +21,7 @@ def __init__(self, context: star.Context): else: logger.info(f"Timezone set to: {self.timezone}") - async def _ensure_persona(self, req: ProviderRequest, cfg: dict, umo: str): + async def _ensure_persona(self, req: ProviderRequest, cfg: dict, umo: str) -> None: """确保用户人格已加载""" if not req.conversation: return @@ -77,7 +77,7 @@ async def _ensure_img_caption( req: ProviderRequest, cfg: dict, img_cap_prov_id: str, - ): + ) -> None: try: caption = await self._request_img_caption( img_cap_prov_id, @@ -115,7 +115,9 @@ async def _request_img_caption( f"Cannot get image caption because provider `{provider_id}` is not exist.", ) - async def process_llm_request(self, event: AstrMessageEvent, req: ProviderRequest): + async def process_llm_request( + self, event: AstrMessageEvent, req: ProviderRequest + ) -> None: """在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt""" cfg: dict = self.ctx.get_config(umo=event.unified_msg_origin)[ "provider_settings" diff --git a/packages/builtin_commands/commands/admin.py b/packages/builtin_commands/commands/admin.py index 83d4b5974..a4f46b603 100644 --- a/packages/builtin_commands/commands/admin.py +++ b/packages/builtin_commands/commands/admin.py @@ -5,10 +5,10 @@ class AdminCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def op(self, event: AstrMessageEvent, admin_id: str = ""): + async def op(self, event: AstrMessageEvent, admin_id: str = "") -> None: """授权管理员。op """ if not admin_id: event.set_result( @@ -21,7 +21,7 @@ async def op(self, event: AstrMessageEvent, admin_id: str = ""): self.context.get_config().save_config() event.set_result(MessageEventResult().message("授权成功。")) - async def deop(self, event: AstrMessageEvent, admin_id: str = ""): + async def deop(self, event: AstrMessageEvent, admin_id: str = "") -> None: """取消授权管理员。deop """ if not admin_id: event.set_result( @@ -39,7 +39,7 @@ async def deop(self, event: AstrMessageEvent, admin_id: str = ""): MessageEventResult().message("此用户 ID 不在管理员名单内。"), ) - async def wl(self, event: AstrMessageEvent, sid: str = ""): + async def wl(self, event: AstrMessageEvent, sid: str = "") -> None: """添加白名单。wl """ if not sid: event.set_result( @@ -53,7 +53,7 @@ async def wl(self, event: AstrMessageEvent, sid: str = ""): cfg.save_config() event.set_result(MessageEventResult().message("添加白名单成功。")) - async def dwl(self, event: AstrMessageEvent, sid: str = ""): + async def dwl(self, event: AstrMessageEvent, sid: str = "") -> None: """删除白名单。dwl """ if not sid: event.set_result( @@ -70,7 +70,7 @@ async def dwl(self, event: AstrMessageEvent, sid: str = ""): except ValueError: event.set_result(MessageEventResult().message("此 SID 不在白名单内。")) - async def update_dashboard(self, event: AstrMessageEvent): + async def update_dashboard(self, event: AstrMessageEvent) -> None: """更新管理面板""" await event.send(MessageChain().message("正在尝试更新管理面板...")) await download_dashboard(version=f"v{VERSION}", latest=False) diff --git a/packages/builtin_commands/commands/alter_cmd.py b/packages/builtin_commands/commands/alter_cmd.py index 50007f6c0..ba31c3326 100644 --- a/packages/builtin_commands/commands/alter_cmd.py +++ b/packages/builtin_commands/commands/alter_cmd.py @@ -11,10 +11,10 @@ class AlterCmdCommands(CommandParserMixin): - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def update_reset_permission(self, scene_key: str, perm_type: str): + async def update_reset_permission(self, scene_key: str, perm_type: str) -> None: """更新reset命令在特定场景下的权限设置""" from astrbot.api import sp @@ -26,7 +26,7 @@ async def update_reset_permission(self, scene_key: str, perm_type: str): alter_cmd_cfg["astrbot"] = plugin_cfg await sp.global_put("alter_cmd", alter_cmd_cfg) - async def alter_cmd(self, event: AstrMessageEvent): + async def alter_cmd(self, event: AstrMessageEvent) -> None: token = self.parse_commands(event.message_str) if token.len < 3: await event.send( diff --git a/packages/builtin_commands/commands/conversation.py b/packages/builtin_commands/commands/conversation.py index de3d11ac8..eb8cfdefa 100644 --- a/packages/builtin_commands/commands/conversation.py +++ b/packages/builtin_commands/commands/conversation.py @@ -16,7 +16,7 @@ class ConversationCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context async def _get_current_persona_id(self, session_id): @@ -33,7 +33,7 @@ async def _get_current_persona_id(self, session_id): return None return conv.persona_id - async def reset(self, message: AstrMessageEvent): + async def reset(self, message: AstrMessageEvent) -> None: """重置 LLM 会话""" umo = message.unified_msg_origin cfg = self.context.get_config(umo=message.unified_msg_origin) @@ -98,7 +98,7 @@ async def reset(self, message: AstrMessageEvent): message.set_result(MessageEventResult().message(ret)) - async def his(self, message: AstrMessageEvent, page: int = 1): + async def his(self, message: AstrMessageEvent, page: int = 1) -> None: """查看对话记录""" if not self.context.get_using_provider(message.unified_msg_origin): message.set_result( @@ -141,7 +141,7 @@ async def his(self, message: AstrMessageEvent, page: int = 1): message.set_result(MessageEventResult().message(ret).use_t2i(False)) - async def convs(self, message: AstrMessageEvent, page: int = 1): + async def convs(self, message: AstrMessageEvent, page: int = 1) -> None: """查看对话列表""" cfg = self.context.get_config(umo=message.unified_msg_origin) agent_runner_type = cfg["provider_settings"]["agent_runner_type"] @@ -216,7 +216,7 @@ async def convs(self, message: AstrMessageEvent, page: int = 1): message.set_result(MessageEventResult().message(ret).use_t2i(False)) return - async def new_conv(self, message: AstrMessageEvent): + async def new_conv(self, message: AstrMessageEvent) -> None: """创建新对话""" cfg = self.context.get_config(umo=message.unified_msg_origin) agent_runner_type = cfg["provider_settings"]["agent_runner_type"] @@ -242,7 +242,7 @@ async def new_conv(self, message: AstrMessageEvent): MessageEventResult().message(f"切换到新对话: 新对话({cid[:4]})。"), ) - async def groupnew_conv(self, message: AstrMessageEvent, sid: str = ""): + async def groupnew_conv(self, message: AstrMessageEvent, sid: str = "") -> None: """创建新群聊对话""" if sid: session = str( @@ -273,7 +273,7 @@ async def switch_conv( self, message: AstrMessageEvent, index: int | None = None, - ): + ) -> None: """通过 /ls 前面的序号切换对话""" if not isinstance(index, int): message.set_result( @@ -308,7 +308,7 @@ async def switch_conv( ), ) - async def rename_conv(self, message: AstrMessageEvent, new_name: str = ""): + async def rename_conv(self, message: AstrMessageEvent, new_name: str = "") -> None: """重命名对话""" if not new_name: message.set_result(MessageEventResult().message("请输入新的对话名称。")) @@ -319,7 +319,7 @@ async def rename_conv(self, message: AstrMessageEvent, new_name: str = ""): ) message.set_result(MessageEventResult().message("重命名对话成功。")) - async def del_conv(self, message: AstrMessageEvent): + async def del_conv(self, message: AstrMessageEvent) -> None: """删除当前对话""" cfg = self.context.get_config(umo=message.unified_msg_origin) is_unique_session = cfg["platform_settings"]["unique_session"] diff --git a/packages/builtin_commands/commands/help.py b/packages/builtin_commands/commands/help.py index 092fc59ec..bbc25b754 100644 --- a/packages/builtin_commands/commands/help.py +++ b/packages/builtin_commands/commands/help.py @@ -8,7 +8,7 @@ class HelpCommand: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context async def _query_astrbot_notice(self): @@ -62,7 +62,7 @@ def walk(items: list[dict], indent: int = 0): walk(commands) return lines - async def help(self, event: AstrMessageEvent): + async def help(self, event: AstrMessageEvent) -> None: """查看帮助""" notice = "" try: diff --git a/packages/builtin_commands/commands/llm.py b/packages/builtin_commands/commands/llm.py index 85977df40..ba9ba5c9b 100644 --- a/packages/builtin_commands/commands/llm.py +++ b/packages/builtin_commands/commands/llm.py @@ -3,10 +3,10 @@ class LLMCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def llm(self, event: AstrMessageEvent): + async def llm(self, event: AstrMessageEvent) -> None: """开启/关闭 LLM""" cfg = self.context.get_config(umo=event.unified_msg_origin) enable = cfg["provider_settings"].get("enable", True) diff --git a/packages/builtin_commands/commands/persona.py b/packages/builtin_commands/commands/persona.py index 13a57f07f..1a5ddb848 100644 --- a/packages/builtin_commands/commands/persona.py +++ b/packages/builtin_commands/commands/persona.py @@ -5,10 +5,10 @@ class PersonaCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def persona(self, message: AstrMessageEvent): + async def persona(self, message: AstrMessageEvent) -> None: l = message.message_str.split(" ") # noqa: E741 umo = message.unified_msg_origin diff --git a/packages/builtin_commands/commands/plugin.py b/packages/builtin_commands/commands/plugin.py index ab45efc11..49bee9462 100644 --- a/packages/builtin_commands/commands/plugin.py +++ b/packages/builtin_commands/commands/plugin.py @@ -8,10 +8,10 @@ class PluginCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def plugin_ls(self, event: AstrMessageEvent): + async def plugin_ls(self, event: AstrMessageEvent) -> None: """获取已经安装的插件列表。""" parts = ["已加载的插件:\n"] for plugin in self.context.get_all_stars(): @@ -30,7 +30,7 @@ async def plugin_ls(self, event: AstrMessageEvent): MessageEventResult().message(f"{plugin_list_info}").use_t2i(False), ) - async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = ""): + async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = "") -> None: """禁用插件""" if DEMO_MODE: event.set_result(MessageEventResult().message("演示模式下无法禁用插件。")) @@ -43,7 +43,7 @@ async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = ""): await self.context._star_manager.turn_off_plugin(plugin_name) # type: ignore event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已禁用。")) - async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = ""): + async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = "") -> None: """启用插件""" if DEMO_MODE: event.set_result(MessageEventResult().message("演示模式下无法启用插件。")) @@ -56,7 +56,7 @@ async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = ""): await self.context._star_manager.turn_on_plugin(plugin_name) # type: ignore event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已启用。")) - async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = ""): + async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = "") -> None: """安装插件""" if DEMO_MODE: event.set_result(MessageEventResult().message("演示模式下无法安装插件。")) @@ -77,7 +77,7 @@ async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = ""): event.set_result(MessageEventResult().message(f"安装插件失败: {e}")) return - async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = ""): + async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = "") -> None: """获取插件帮助""" if not plugin_name: event.set_result( diff --git a/packages/builtin_commands/commands/provider.py b/packages/builtin_commands/commands/provider.py index ce8f31831..ede51a7c8 100644 --- a/packages/builtin_commands/commands/provider.py +++ b/packages/builtin_commands/commands/provider.py @@ -8,7 +8,7 @@ class ProviderCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context def _log_reachability_failure( @@ -17,7 +17,7 @@ def _log_reachability_failure( provider_capability_type: ProviderType | None, err_code: str, err_reason: str, - ): + ) -> None: """记录不可达原因到日志。""" meta = provider.meta() logger.warning( @@ -49,7 +49,7 @@ async def provider( event: AstrMessageEvent, idx: str | int | None = None, idx2: int | None = None, - ): + ) -> None: """查看或者切换 LLM Provider""" umo = event.unified_msg_origin cfg = self.context.get_config(umo).get("provider_settings", {}) @@ -226,7 +226,7 @@ async def model_ls( self, message: AstrMessageEvent, idx_or_name: int | str | None = None, - ): + ) -> None: """查看或者切换模型""" prov = self.context.get_using_provider(message.unified_msg_origin) if not prov: @@ -291,7 +291,7 @@ async def model_ls( MessageEventResult().message(f"切换模型到 {prov.get_model()}。"), ) - async def key(self, message: AstrMessageEvent, index: int | None = None): + async def key(self, message: AstrMessageEvent, index: int | None = None) -> None: prov = self.context.get_using_provider(message.unified_msg_origin) if not prov: message.set_result( diff --git a/packages/builtin_commands/commands/setunset.py b/packages/builtin_commands/commands/setunset.py index 79e5d5d1c..096698844 100644 --- a/packages/builtin_commands/commands/setunset.py +++ b/packages/builtin_commands/commands/setunset.py @@ -3,10 +3,10 @@ class SetUnsetCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def set_variable(self, event: AstrMessageEvent, key: str, value: str): + async def set_variable(self, event: AstrMessageEvent, key: str, value: str) -> None: """设置会话变量""" uid = event.unified_msg_origin session_var = await sp.session_get(uid, "session_variables", {}) @@ -19,7 +19,7 @@ async def set_variable(self, event: AstrMessageEvent, key: str, value: str): ), ) - async def unset_variable(self, event: AstrMessageEvent, key: str): + async def unset_variable(self, event: AstrMessageEvent, key: str) -> None: """移除会话变量""" uid = event.unified_msg_origin session_var = await sp.session_get(uid, "session_variables", {}) diff --git a/packages/builtin_commands/commands/sid.py b/packages/builtin_commands/commands/sid.py index 4d95c5a60..e8bdbffb1 100644 --- a/packages/builtin_commands/commands/sid.py +++ b/packages/builtin_commands/commands/sid.py @@ -7,10 +7,10 @@ class SIDCommand: """会话ID命令类""" - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def sid(self, event: AstrMessageEvent): + async def sid(self, event: AstrMessageEvent) -> None: """获取消息来源信息""" sid = event.unified_msg_origin user_id = str(event.get_sender_id()) diff --git a/packages/builtin_commands/commands/t2i.py b/packages/builtin_commands/commands/t2i.py index 7766b342f..78d6b0df7 100644 --- a/packages/builtin_commands/commands/t2i.py +++ b/packages/builtin_commands/commands/t2i.py @@ -7,10 +7,10 @@ class T2ICommand: """文本转图片命令类""" - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def t2i(self, event: AstrMessageEvent): + async def t2i(self, event: AstrMessageEvent) -> None: """开关文本转图片""" config = self.context.get_config(umo=event.unified_msg_origin) if config["t2i"]: diff --git a/packages/builtin_commands/commands/tool.py b/packages/builtin_commands/commands/tool.py index 9a6c507e6..09b239b8c 100644 --- a/packages/builtin_commands/commands/tool.py +++ b/packages/builtin_commands/commands/tool.py @@ -3,28 +3,28 @@ class ToolCommands: - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def tool_ls(self, event: AstrMessageEvent): + async def tool_ls(self, event: AstrMessageEvent) -> None: """查看函数工具列表""" event.set_result( MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。"), ) - async def tool_on(self, event: AstrMessageEvent, tool_name: str = ""): + async def tool_on(self, event: AstrMessageEvent, tool_name: str = "") -> None: """启用一个函数工具""" event.set_result( MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。"), ) - async def tool_off(self, event: AstrMessageEvent, tool_name: str = ""): + async def tool_off(self, event: AstrMessageEvent, tool_name: str = "") -> None: """停用一个函数工具""" event.set_result( MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。"), ) - async def tool_all_off(self, event: AstrMessageEvent): + async def tool_all_off(self, event: AstrMessageEvent) -> None: """停用所有函数工具""" event.set_result( MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。"), diff --git a/packages/builtin_commands/commands/tts.py b/packages/builtin_commands/commands/tts.py index d733ba1ea..5c245ed26 100644 --- a/packages/builtin_commands/commands/tts.py +++ b/packages/builtin_commands/commands/tts.py @@ -8,10 +8,10 @@ class TTSCommand: """文本转语音命令类""" - def __init__(self, context: star.Context): + def __init__(self, context: star.Context) -> None: self.context = context - async def tts(self, event: AstrMessageEvent): + async def tts(self, event: AstrMessageEvent) -> None: """开关文本转语音(会话级别)""" umo = event.unified_msg_origin ses_tts = SessionServiceManager.is_tts_enabled_for_session(umo) diff --git a/packages/builtin_commands/main.py b/packages/builtin_commands/main.py index 7809c4359..d19f27f15 100644 --- a/packages/builtin_commands/main.py +++ b/packages/builtin_commands/main.py @@ -37,108 +37,108 @@ def __init__(self, context: star.Context) -> None: self.sid_c = SIDCommand(self.context) @filter.command("help") - async def help(self, event: AstrMessageEvent): + async def help(self, event: AstrMessageEvent) -> None: """查看帮助""" await self.help_c.help(event) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("llm") - async def llm(self, event: AstrMessageEvent): + async def llm(self, event: AstrMessageEvent) -> None: """开启/关闭 LLM""" await self.llm_c.llm(event) @filter.command_group("tool") - def tool(self): + def tool(self) -> None: """函数工具管理""" @tool.command("ls") - async def tool_ls(self, event: AstrMessageEvent): + async def tool_ls(self, event: AstrMessageEvent) -> None: """查看函数工具列表""" await self.tool_c.tool_ls(event) @tool.command("on") - async def tool_on(self, event: AstrMessageEvent, tool_name: str): + async def tool_on(self, event: AstrMessageEvent, tool_name: str) -> None: """启用一个函数工具""" await self.tool_c.tool_on(event, tool_name) @tool.command("off") - async def tool_off(self, event: AstrMessageEvent, tool_name: str): + async def tool_off(self, event: AstrMessageEvent, tool_name: str) -> None: """停用一个函数工具""" await self.tool_c.tool_off(event, tool_name) @tool.command("off_all") - async def tool_all_off(self, event: AstrMessageEvent): + async def tool_all_off(self, event: AstrMessageEvent) -> None: """停用所有函数工具""" await self.tool_c.tool_all_off(event) @filter.command_group("plugin") - def plugin(self): + def plugin(self) -> None: """插件管理""" @plugin.command("ls") - async def plugin_ls(self, event: AstrMessageEvent): + async def plugin_ls(self, event: AstrMessageEvent) -> None: """获取已经安装的插件列表。""" await self.plugin_c.plugin_ls(event) @filter.permission_type(filter.PermissionType.ADMIN) @plugin.command("off") - async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = ""): + async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = "") -> None: """禁用插件""" await self.plugin_c.plugin_off(event, plugin_name) @filter.permission_type(filter.PermissionType.ADMIN) @plugin.command("on") - async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = ""): + async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = "") -> None: """启用插件""" await self.plugin_c.plugin_on(event, plugin_name) @filter.permission_type(filter.PermissionType.ADMIN) @plugin.command("get") - async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = ""): + async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = "") -> None: """安装插件""" await self.plugin_c.plugin_get(event, plugin_repo) @plugin.command("help") - async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = ""): + async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = "") -> None: """获取插件帮助""" await self.plugin_c.plugin_help(event, plugin_name) @filter.command("t2i") - async def t2i(self, event: AstrMessageEvent): + async def t2i(self, event: AstrMessageEvent) -> None: """开关文本转图片""" await self.t2i_c.t2i(event) @filter.command("tts") - async def tts(self, event: AstrMessageEvent): + async def tts(self, event: AstrMessageEvent) -> None: """开关文本转语音(会话级别)""" await self.tts_c.tts(event) @filter.command("sid") - async def sid(self, event: AstrMessageEvent): + async def sid(self, event: AstrMessageEvent) -> None: """获取会话 ID 和 管理员 ID""" await self.sid_c.sid(event) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("op") - async def op(self, event: AstrMessageEvent, admin_id: str = ""): + async def op(self, event: AstrMessageEvent, admin_id: str = "") -> None: """授权管理员。op """ await self.admin_c.op(event, admin_id) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("deop") - async def deop(self, event: AstrMessageEvent, admin_id: str): + async def deop(self, event: AstrMessageEvent, admin_id: str) -> None: """取消授权管理员。deop """ await self.admin_c.deop(event, admin_id) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("wl") - async def wl(self, event: AstrMessageEvent, sid: str = ""): + async def wl(self, event: AstrMessageEvent, sid: str = "") -> None: """添加白名单。wl """ await self.admin_c.wl(event, sid) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("dwl") - async def dwl(self, event: AstrMessageEvent, sid: str): + async def dwl(self, event: AstrMessageEvent, sid: str) -> None: """删除白名单。dwl """ await self.admin_c.dwl(event, sid) @@ -149,12 +149,12 @@ async def provider( event: AstrMessageEvent, idx: str | int | None = None, idx2: int | None = None, - ): + ) -> None: """查看或者切换 LLM Provider""" await self.provider_c.provider(event, idx, idx2) @filter.command("reset") - async def reset(self, message: AstrMessageEvent): + async def reset(self, message: AstrMessageEvent) -> None: """重置 LLM 会话""" await self.conversation_c.reset(message) @@ -164,74 +164,76 @@ async def model_ls( self, message: AstrMessageEvent, idx_or_name: int | str | None = None, - ): + ) -> None: """查看或者切换模型""" await self.provider_c.model_ls(message, idx_or_name) @filter.command("history") - async def his(self, message: AstrMessageEvent, page: int = 1): + async def his(self, message: AstrMessageEvent, page: int = 1) -> None: """查看对话记录""" await self.conversation_c.his(message, page) @filter.command("ls") - async def convs(self, message: AstrMessageEvent, page: int = 1): + async def convs(self, message: AstrMessageEvent, page: int = 1) -> None: """查看对话列表""" await self.conversation_c.convs(message, page) @filter.command("new") - async def new_conv(self, message: AstrMessageEvent): + async def new_conv(self, message: AstrMessageEvent) -> None: """创建新对话""" await self.conversation_c.new_conv(message) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("groupnew") - async def groupnew_conv(self, message: AstrMessageEvent, sid: str): + async def groupnew_conv(self, message: AstrMessageEvent, sid: str) -> None: """创建新群聊对话""" await self.conversation_c.groupnew_conv(message, sid) @filter.command("switch") - async def switch_conv(self, message: AstrMessageEvent, index: int | None = None): + async def switch_conv( + self, message: AstrMessageEvent, index: int | None = None + ) -> None: """通过 /ls 前面的序号切换对话""" await self.conversation_c.switch_conv(message, index) @filter.command("rename") - async def rename_conv(self, message: AstrMessageEvent, new_name: str): + async def rename_conv(self, message: AstrMessageEvent, new_name: str) -> None: """重命名对话""" await self.conversation_c.rename_conv(message, new_name) @filter.command("del") - async def del_conv(self, message: AstrMessageEvent): + async def del_conv(self, message: AstrMessageEvent) -> None: """删除当前对话""" await self.conversation_c.del_conv(message) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("key") - async def key(self, message: AstrMessageEvent, index: int | None = None): + async def key(self, message: AstrMessageEvent, index: int | None = None) -> None: """查看或者切换 Key""" await self.provider_c.key(message, index) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("persona") - async def persona(self, message: AstrMessageEvent): + async def persona(self, message: AstrMessageEvent) -> None: """查看或者切换 Persona""" await self.persona_c.persona(message) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("dashboard_update") - async def update_dashboard(self, event: AstrMessageEvent): + async def update_dashboard(self, event: AstrMessageEvent) -> None: """更新管理面板""" await self.admin_c.update_dashboard(event) @filter.command("set") - async def set_variable(self, event: AstrMessageEvent, key: str, value: str): + async def set_variable(self, event: AstrMessageEvent, key: str, value: str) -> None: await self.setunset_c.set_variable(event, key, value) @filter.command("unset") - async def unset_variable(self, event: AstrMessageEvent, key: str): + async def unset_variable(self, event: AstrMessageEvent, key: str) -> None: await self.setunset_c.unset_variable(event, key) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("alter_cmd", alias={"alter"}) - async def alter_cmd(self, event: AstrMessageEvent): + async def alter_cmd(self, event: AstrMessageEvent) -> None: """修改命令权限""" await self.alter_cmd_c.alter_cmd(event) diff --git a/packages/python_interpreter/main.py b/packages/python_interpreter/main.py index afbef7560..f78602770 100644 --- a/packages/python_interpreter/main.py +++ b/packages/python_interpreter/main.py @@ -124,7 +124,7 @@ def __init__(self, context: star.Context) -> None: with open(PATH) as f: self.config = json.load(f) - async def initialize(self): + async def initialize(self) -> None: ok = await self.is_docker_available() if not ok: logger.info( @@ -171,7 +171,7 @@ async def get_image_name(self) -> str: return f"{self.config['sandbox']['docker_mirror']}/{self.config['sandbox']['image']}" return self.config["sandbox"]["image"] - def _save_config(self): + def _save_config(self) -> None: with open(PATH, "w") as f: json.dump(self.config, f) @@ -240,7 +240,9 @@ async def on_message(self, event: AstrMessageEvent): del self.user_waiting[uid] @filter.on_llm_request() - async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest): + async def on_llm_req( + self, event: AstrMessageEvent, request: ProviderRequest + ) -> None: if event.get_session_id() in self.user_file_msg_buffer: files = self.user_file_msg_buffer[event.get_session_id()] if not request.prompt: @@ -248,7 +250,7 @@ async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest): request.prompt += f"\nUser provided files: {files}" @filter.command_group("pi") - def pi(self): + def pi(self) -> None: """代码执行器配置""" @pi.command("absdir") diff --git a/packages/python_interpreter/shared/api.py b/packages/python_interpreter/shared/api.py index 287773fb0..f4cf51ff2 100644 --- a/packages/python_interpreter/shared/api.py +++ b/packages/python_interpreter/shared/api.py @@ -6,17 +6,17 @@ def _get_magic_code(): return os.getenv("MAGIC_CODE") -def send_text(text: str): +def send_text(text: str) -> None: print(f"[ASTRBOT_TEXT_OUTPUT#{_get_magic_code()}]: {text}") -def send_image(image_path: str): +def send_image(image_path: str) -> None: if not os.path.exists(image_path): raise Exception(f"Image file not found: {image_path}") print(f"[ASTRBOT_IMAGE_OUTPUT#{_get_magic_code()}]: {image_path}") -def send_file(file_path: str): +def send_file(file_path: str) -> None: if not os.path.exists(file_path): raise Exception(f"File not found: {file_path}") print(f"[ASTRBOT_FILE_OUTPUT#{_get_magic_code()}]: {file_path}") diff --git a/packages/reminder/main.py b/packages/reminder/main.py index 62af7ae56..34d4f2652 100644 --- a/packages/reminder/main.py +++ b/packages/reminder/main.py @@ -38,7 +38,7 @@ def __init__(self, context: star.Context) -> None: self._init_scheduler() self.scheduler.start() - def _init_scheduler(self): + def _init_scheduler(self) -> None: """Initialize the scheduler.""" for group in self.reminder_data: for reminder in self.reminder_data[group]: @@ -82,7 +82,7 @@ def check_is_outdated(self, reminder: dict): return reminder_time < datetime.datetime.now(self.timezone) return False - async def _save_data(self): + async def _save_data(self) -> None: """Save the reminder data.""" reminder_file = os.path.join(get_astrbot_data_path(), "astrbot-reminder.json") with open(reminder_file, "w", encoding="utf-8") as f: @@ -178,7 +178,7 @@ async def reminder_tool( ) @filter.command_group("reminder") - def reminder(self): + def reminder(self) -> None: """待办提醒""" async def get_upcoming_reminders(self, unified_msg_origin: str): @@ -246,7 +246,7 @@ async def reminder_rm(self, event: AstrMessageEvent, index: int): await self._save_data() yield event.plain_result("成功删除待办事项:\n" + reminder["text"]) - async def _reminder_callback(self, unified_msg_origin: str, d: dict): + async def _reminder_callback(self, unified_msg_origin: str, d: dict) -> None: """The callback function of the reminder.""" logger.info(f"Reminder Activated: {d['text']}, created by {unified_msg_origin}") await self.context.send_message( @@ -260,7 +260,7 @@ async def _reminder_callback(self, unified_msg_origin: str, d: dict): ), ) - async def terminate(self): + async def terminate(self) -> None: self.scheduler.shutdown() await self._save_data() logger.info("Reminder plugin terminated.") diff --git a/packages/session_controller/main.py b/packages/session_controller/main.py index 9ea62ea30..fdc295313 100644 --- a/packages/session_controller/main.py +++ b/packages/session_controller/main.py @@ -17,11 +17,11 @@ class Main(Star): """会话控制""" - def __init__(self, context: Context): + def __init__(self, context: Context) -> None: super().__init__(context) @filter.event_message_type(filter.EventMessageType.ALL, priority=maxsize) - async def handle_session_control_agent(self, event: AstrMessageEvent): + async def handle_session_control_agent(self, event: AstrMessageEvent) -> None: """会话控制代理""" for session_filter in FILTERS: session_id = session_filter.filter(event) @@ -91,7 +91,7 @@ async def handle_empty_mention(self, event: AstrMessageEvent): async def empty_mention_waiter( controller: SessionController, event: AstrMessageEvent, - ): + ) -> None: event.message_obj.message.insert( 0, Comp.At(qq=event.get_self_id(), name=event.get_self_id()), diff --git a/packages/web_searcher/engines/__init__.py b/packages/web_searcher/engines/__init__.py index 699438602..2c18d9884 100644 --- a/packages/web_searcher/engines/__init__.py +++ b/packages/web_searcher/engines/__init__.py @@ -48,7 +48,7 @@ def __init__(self) -> None: def _set_selector(self, selector: str) -> str: raise NotImplementedError - def _get_next_page(self, query: str): + async def _get_next_page(self, query: str) -> str: raise NotImplementedError async def _get_html(self, url: str, data: dict | None = None) -> str: diff --git a/packages/web_searcher/main.py b/packages/web_searcher/main.py index 4745cd0c0..c51e78aa1 100644 --- a/packages/web_searcher/main.py +++ b/packages/web_searcher/main.py @@ -184,7 +184,7 @@ async def _extract_tavily(self, cfg: AstrBotConfig, payload: dict) -> list[dict] return results @filter.command("websearch") - async def websearch(self, event: AstrMessageEvent, oper: str | None = None): + async def websearch(self, event: AstrMessageEvent, oper: str | None = None) -> None: """网页搜索指令(已废弃)""" event.set_result( MessageEventResult().message( @@ -231,7 +231,7 @@ async def search_from_search_engine( return ret - async def ensure_baidu_ai_search_mcp(self, umo: str | None = None): + async def ensure_baidu_ai_search_mcp(self, umo: str | None = None) -> None: if self.baidu_initialized: return cfg = self.context.get_config(umo=umo) @@ -379,7 +379,7 @@ async def edit_web_search_tools( self, event: AstrMessageEvent, req: ProviderRequest, - ): + ) -> None: """Get the session conversation for the given event.""" cfg = self.context.get_config(umo=event.unified_msg_origin) prov_settings = cfg.get("provider_settings", {}) diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index 969f0da6d..0e41d13e1 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -61,7 +61,7 @@ async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecyc @pytest.mark.asyncio -async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle): +async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle) -> None: """Tests the login functionality with both wrong and correct credentials.""" test_client = app.test_client() response = await test_client.post( @@ -83,7 +83,7 @@ async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle): @pytest.mark.asyncio -async def test_get_stat(app: Quart, authenticated_header: dict): +async def test_get_stat(app: Quart, authenticated_header: dict) -> None: test_client = app.test_client() response = await test_client.get("/api/stat/get") assert response.status_code == 401 @@ -94,7 +94,7 @@ async def test_get_stat(app: Quart, authenticated_header: dict): @pytest.mark.asyncio -async def test_plugins(app: Quart, authenticated_header: dict): +async def test_plugins(app: Quart, authenticated_header: dict) -> None: test_client = app.test_client() # 已经安装的插件 response = await test_client.get("/api/plugin/get", headers=authenticated_header) @@ -189,7 +189,7 @@ async def test_commands_api(app: Quart, authenticated_header: dict): @pytest.mark.asyncio -async def test_check_update(app: Quart, authenticated_header: dict): +async def test_check_update(app: Quart, authenticated_header: dict) -> None: test_client = app.test_client() response = await test_client.get("/api/update/check", headers=authenticated_header) assert response.status_code == 200 @@ -204,22 +204,22 @@ async def test_do_update( core_lifecycle_td: AstrBotCoreLifecycle, monkeypatch, tmp_path_factory, -): +) -> None: test_client = app.test_client() # Use a temporary path for the mock update to avoid side effects temp_release_dir = tmp_path_factory.mktemp("release") release_path = temp_release_dir / "astrbot" - async def mock_update(*args, **kwargs): + async def mock_update(*args, **kwargs) -> None: """Mocks the update process by creating a directory in the temp path.""" os.makedirs(release_path, exist_ok=True) - async def mock_download_dashboard(*args, **kwargs): + async def mock_download_dashboard(*args, **kwargs) -> None: """Mocks the dashboard download to prevent network access.""" return - async def mock_pip_install(*args, **kwargs): + async def mock_pip_install(*args, **kwargs) -> None: """Mocks pip install to prevent actual installation.""" return diff --git a/tests/test_main.py b/tests/test_main.py index 0453a51ee..d84cd44c9 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -12,12 +12,12 @@ class _version_info: - def __init__(self, major, minor): + def __init__(self, major, minor) -> None: self.major = major self.minor = minor -def test_check_env(monkeypatch): +def test_check_env(monkeypatch) -> None: version_info_correct = _version_info(3, 10) version_info_wrong = _version_info(3, 9) monkeypatch.setattr(sys, "version_info", version_info_correct) @@ -33,7 +33,7 @@ def test_check_env(monkeypatch): @pytest.mark.asyncio -async def test_check_dashboard_files_not_exists(monkeypatch): +async def test_check_dashboard_files_not_exists(monkeypatch) -> None: """Tests dashboard download when files do not exist.""" monkeypatch.setattr(os.path, "exists", lambda x: False) @@ -43,7 +43,7 @@ async def test_check_dashboard_files_not_exists(monkeypatch): @pytest.mark.asyncio -async def test_check_dashboard_files_exists_and_version_match(monkeypatch): +async def test_check_dashboard_files_exists_and_version_match(monkeypatch) -> None: """Tests that dashboard is not downloaded when it exists and version matches.""" # Mock os.path.exists to return True monkeypatch.setattr(os.path, "exists", lambda x: True) @@ -62,7 +62,7 @@ async def test_check_dashboard_files_exists_and_version_match(monkeypatch): @pytest.mark.asyncio -async def test_check_dashboard_files_exists_but_version_mismatch(monkeypatch): +async def test_check_dashboard_files_exists_but_version_mismatch(monkeypatch) -> None: """Tests that a warning is logged when dashboard version mismatches.""" monkeypatch.setattr(os.path, "exists", lambda x: True) @@ -77,7 +77,7 @@ async def test_check_dashboard_files_exists_but_version_mismatch(monkeypatch): @pytest.mark.asyncio -async def test_check_dashboard_files_with_webui_dir_arg(monkeypatch): +async def test_check_dashboard_files_with_webui_dir_arg(monkeypatch) -> None: """Tests that providing a valid webui_dir skips all checks.""" valid_dir = "/tmp/my-custom-webui" monkeypatch.setattr(os.path, "exists", lambda path: path == valid_dir) diff --git a/tests/test_plugin_manager.py b/tests/test_plugin_manager.py index 1e4cd866a..6e3cbc2c9 100644 --- a/tests/test_plugin_manager.py +++ b/tests/test_plugin_manager.py @@ -59,21 +59,21 @@ def plugin_manager_pm(tmp_path): return manager -def test_plugin_manager_initialization(plugin_manager_pm: PluginManager): +def test_plugin_manager_initialization(plugin_manager_pm: PluginManager) -> None: assert plugin_manager_pm is not None assert plugin_manager_pm.context is not None assert plugin_manager_pm.config is not None @pytest.mark.asyncio -async def test_plugin_manager_reload(plugin_manager_pm: PluginManager): +async def test_plugin_manager_reload(plugin_manager_pm: PluginManager) -> None: success, err_message = await plugin_manager_pm.reload() assert success is True assert err_message is None @pytest.mark.asyncio -async def test_install_plugin(plugin_manager_pm: PluginManager): +async def test_install_plugin(plugin_manager_pm: PluginManager) -> None: """Tests successful plugin installation in an isolated environment.""" test_repo = "https://github.com/Soulter/astrbot_plugin_essential" plugin_info = await plugin_manager_pm.install_plugin(test_repo) @@ -90,7 +90,7 @@ async def test_install_plugin(plugin_manager_pm: PluginManager): @pytest.mark.asyncio -async def test_install_nonexistent_plugin(plugin_manager_pm: PluginManager): +async def test_install_nonexistent_plugin(plugin_manager_pm: PluginManager) -> None: """Tests that installing a non-existent plugin raises an exception.""" with pytest.raises(Exception): await plugin_manager_pm.install_plugin( @@ -99,7 +99,7 @@ async def test_install_nonexistent_plugin(plugin_manager_pm: PluginManager): @pytest.mark.asyncio -async def test_update_plugin(plugin_manager_pm: PluginManager): +async def test_update_plugin(plugin_manager_pm: PluginManager) -> None: """Tests updating an existing plugin in an isolated environment.""" # First, install the plugin test_repo = "https://github.com/Soulter/astrbot_plugin_essential" @@ -110,14 +110,14 @@ async def test_update_plugin(plugin_manager_pm: PluginManager): @pytest.mark.asyncio -async def test_update_nonexistent_plugin(plugin_manager_pm: PluginManager): +async def test_update_nonexistent_plugin(plugin_manager_pm: PluginManager) -> None: """Tests that updating a non-existent plugin raises an exception.""" with pytest.raises(Exception): await plugin_manager_pm.update_plugin("non_existent_plugin") @pytest.mark.asyncio -async def test_uninstall_plugin(plugin_manager_pm: PluginManager): +async def test_uninstall_plugin(plugin_manager_pm: PluginManager) -> None: """Tests successful plugin uninstallation in an isolated environment.""" # First, install the plugin test_repo = "https://github.com/Soulter/astrbot_plugin_essential" @@ -144,7 +144,7 @@ async def test_uninstall_plugin(plugin_manager_pm: PluginManager): @pytest.mark.asyncio -async def test_uninstall_nonexistent_plugin(plugin_manager_pm: PluginManager): +async def test_uninstall_nonexistent_plugin(plugin_manager_pm: PluginManager) -> None: """Tests that uninstalling a non-existent plugin raises an exception.""" with pytest.raises(Exception): await plugin_manager_pm.uninstall_plugin("non_existent_plugin") diff --git a/tests/test_security_fixes.py b/tests/test_security_fixes.py index d4e455541..8e5cf7f6d 100644 --- a/tests/test_security_fixes.py +++ b/tests/test_security_fixes.py @@ -10,7 +10,7 @@ import pytest -def test_wecom_crypto_uses_secrets(): +def test_wecom_crypto_uses_secrets() -> None: """Test that WXBizJsonMsgCrypt uses secrets module instead of random.""" from astrbot.core.platform.sources.wecom_ai_bot.WXBizJsonMsgCrypt import Prpcrypt @@ -33,7 +33,7 @@ def test_wecom_crypto_uses_secrets(): assert 1000000000000000 <= int(decoded) <= 9999999999999999 -def test_wecomai_utils_uses_secrets(): +def test_wecomai_utils_uses_secrets() -> None: """Test that wecomai_utils uses secrets module for random string generation.""" from astrbot.core.platform.sources.wecom_ai_bot.wecomai_utils import ( generate_random_string, @@ -53,7 +53,7 @@ def test_wecomai_utils_uses_secrets(): assert len(set(random_strings)) >= 19 # Allow for 1 collision in 20 (very unlikely) -def test_azure_tts_signature_uses_secrets(): +def test_azure_tts_signature_uses_secrets() -> None: """Test that Azure TTS signature generation uses secrets module.""" import asyncio @@ -66,7 +66,7 @@ def test_azure_tts_signature_uses_secrets(): "OTTS_AUTH_TIME": "https://example.com/api/time", } - async def test_nonce_generation(): + async def test_nonce_generation() -> None: async with OTTSProvider(config) as provider: # Mock time sync to avoid actual API calls provider.time_offset = 0 @@ -94,7 +94,7 @@ async def test_nonce_generation(): asyncio.run(test_nonce_generation()) -def test_ssl_context_fallback_explicit(): +def test_ssl_context_fallback_explicit() -> None: """Test that SSL context fallback is properly configured.""" # This test verifies the SSL context configuration # We can't easily test the full io.py functions without network calls, @@ -113,7 +113,7 @@ def test_ssl_context_fallback_explicit(): # The actual code only uses this when certificate validation fails -def test_io_module_has_ssl_imports(): +def test_io_module_has_ssl_imports() -> None: """Verify that io.py properly imports ssl module.""" from astrbot.core.utils import io @@ -124,7 +124,7 @@ def test_io_module_has_ssl_imports(): assert hasattr(io.ssl, "CERT_NONE") -def test_secrets_module_randomness_quality(): +def test_secrets_module_randomness_quality() -> None: """Test that secrets module provides high-quality randomness.""" import secrets