diff --git a/astrbot/core/star/filter/command.py b/astrbot/core/star/filter/command.py index d8a1eb22e..3d67cb750 100755 --- a/astrbot/core/star/filter/command.py +++ b/astrbot/core/star/filter/command.py @@ -1,5 +1,7 @@ import re import inspect +import types +import typing from typing import List, Any, Type, Dict from . import HandlerFilter from astrbot.core.platform.astr_message_event import AstrMessageEvent @@ -14,6 +16,18 @@ class GreedyStr(str): pass +def unwrap_optional(annotation) -> tuple: + """去掉 Optional[T] / Union[T, None] / T|None,返回 T""" + args = typing.get_args(annotation) + non_none_args = [a for a in args if a is not type(None)] + if len(non_none_args) == 1: + return (non_none_args[0],) + elif len(non_none_args) > 1: + return tuple(non_none_args) + else: + return () + + # 标准指令受到 wake_prefix 的制约。 class CommandFilter(HandlerFilter): """标准指令过滤器""" @@ -40,6 +54,8 @@ def print_types(self): for k, v in self.handler_params.items(): if isinstance(v, type): result += f"{k}({v.__name__})," + elif isinstance(v, types.UnionType) or typing.get_origin(v) is typing.Union: + result += f"{k}({v})," else: result += f"{k}({type(v).__name__})={v}," result = result.rstrip(",") @@ -95,7 +111,8 @@ def validate_and_convert_params( # 没有 GreedyStr 的情况 if i >= len(params): if ( - isinstance(param_type_or_default_val, Type) + isinstance(param_type_or_default_val, (Type, types.UnionType)) + or typing.get_origin(param_type_or_default_val) is typing.Union or param_type_or_default_val is inspect.Parameter.empty ): # 是类型 @@ -132,7 +149,20 @@ def validate_and_convert_params( elif isinstance(param_type_or_default_val, float): result[param_name] = float(params[i]) else: - result[param_name] = param_type_or_default_val(params[i]) + origin = typing.get_origin(param_type_or_default_val) + if origin in (typing.Union, types.UnionType): + # 注解是联合类型 + # NOTE: 目前没有处理联合类型嵌套相关的注解写法 + nn_types = unwrap_optional(param_type_or_default_val) + if len(nn_types) == 1: + # 只有一个非 NoneType 类型 + result[param_name] = nn_types[0](params[i]) + else: + # 没有或者有多个非 NoneType 类型,这里我们暂时直接赋值为原始值。 + # NOTE: 目前还没有做类型校验 + result[param_name] = params[i] + else: + result[param_name] = param_type_or_default_val(params[i]) except ValueError: raise ValueError( f"参数 {param_name} 类型错误。完整参数: {self.print_types()}"