Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions astrbot/core/star/filter/command.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import re
import inspect
import types
import typing
from typing import List, Any, Type, Dict
from . import HandlerFilter
from astrbot.core.platform.astr_message_event import AstrMessageEvent
Expand All @@ -14,6 +16,18 @@ class GreedyStr(str):
pass


def unwrap_optional(annotation) -> tuple:
"""去掉 Optional[T] / Union[T, None] / T|None,返回 T"""
args = typing.get_args(annotation)
non_none_args = [a for a in args if a is not type(None)]
if len(non_none_args) == 1:
return (non_none_args[0],)
elif len(non_none_args) > 1:
return tuple(non_none_args)
else:
return ()


# 标准指令受到 wake_prefix 的制约。
class CommandFilter(HandlerFilter):
"""标准指令过滤器"""
Expand All @@ -40,6 +54,8 @@ def print_types(self):
for k, v in self.handler_params.items():
if isinstance(v, type):
result += f"{k}({v.__name__}),"
elif isinstance(v, types.UnionType) or typing.get_origin(v) is typing.Union:
result += f"{k}({v}),"
else:
result += f"{k}({type(v).__name__})={v},"
result = result.rstrip(",")
Expand Down Expand Up @@ -95,7 +111,8 @@ def validate_and_convert_params(
# 没有 GreedyStr 的情况
if i >= len(params):
if (
isinstance(param_type_or_default_val, Type)
isinstance(param_type_or_default_val, (Type, types.UnionType))
or typing.get_origin(param_type_or_default_val) is typing.Union
or param_type_or_default_val is inspect.Parameter.empty
):
# 是类型
Expand Down Expand Up @@ -132,7 +149,20 @@ def validate_and_convert_params(
elif isinstance(param_type_or_default_val, float):
result[param_name] = float(params[i])
else:
result[param_name] = param_type_or_default_val(params[i])
origin = typing.get_origin(param_type_or_default_val)
if origin in (typing.Union, types.UnionType):
# 注解是联合类型
# NOTE: 目前没有处理联合类型嵌套相关的注解写法
nn_types = unwrap_optional(param_type_or_default_val)
if len(nn_types) == 1:
# 只有一个非 NoneType 类型
result[param_name] = nn_types[0](params[i])
else:
# 没有或者有多个非 NoneType 类型,这里我们暂时直接赋值为原始值。
# NOTE: 目前还没有做类型校验
result[param_name] = params[i]
else:
result[param_name] = param_type_or_default_val(params[i])
except ValueError:
raise ValueError(
f"参数 {param_name} 类型错误。完整参数: {self.print_types()}"
Expand Down