Skip to content

Commit 09d1f96

Browse files
SoulterDt8333
andcommitted
fix: 修复 /alter_cmd 指令无法控制指令组、子指令组和子指令组下子指令的问题 (#2873)
* fix: revert changes in command_group.py at 782c036 to fix command group permission check * fix: 不传递 GroupCommand handler * perf: alter_cmd 指令支持对子指令、指令组进行配置 * chore: remove test commands and subcommands from test_group * chore: add cache for complete command names list in CommandFilter and CommandGroupFilter --------- Co-authored-by: Dt8333 <25431943+Dt8333@users.noreply.github.com> Co-authored-by: Soulter <905617992@qq.com>
1 parent 26aa18d commit 09d1f96

File tree

5 files changed

+77
-41
lines changed

5 files changed

+77
-41
lines changed

astrbot/core/pipeline/waking_check/stage.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
66
from astrbot.core.platform.astr_message_event import AstrMessageEvent
77
from astrbot.core.star.filter.permission import PermissionTypeFilter
8+
from astrbot.core.star.filter.command_group import CommandGroupFilter
89
from astrbot.core.star.session_plugin_manager import SessionPluginManager
910
from astrbot.core.star.star import star_map
1011
from astrbot.core.star.star_handler import EventType, star_handlers_registry
@@ -170,11 +171,15 @@ async def process(
170171
is_wake = True
171172
event.is_wake = True
172173

173-
activated_handlers.append(handler)
174-
if "parsed_params" in event.get_extra():
175-
handlers_parsed_params[handler.handler_full_name] = event.get_extra(
176-
"parsed_params"
177-
)
174+
is_group_cmd_handler = any(
175+
isinstance(f, CommandGroupFilter) for f in handler.event_filters
176+
)
177+
if not is_group_cmd_handler:
178+
activated_handlers.append(handler)
179+
if "parsed_params" in event.get_extra(default={}):
180+
handlers_parsed_params[handler.handler_full_name] = (
181+
event.get_extra("parsed_params")
182+
)
178183

179184
event._extras.pop("parsed_params", None)
180185

astrbot/core/platform/astr_message_event.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import hashlib
55
import uuid
66

7-
from typing import List, Union, Optional, AsyncGenerator
7+
from typing import List, Union, Optional, AsyncGenerator, TypeVar, Any
88

99
from astrbot import logger
1010
from astrbot.core.db.po import Conversation
@@ -26,6 +26,8 @@
2626
from .platform_metadata import PlatformMetadata
2727
from .message_session import MessageSession, MessageSesion # noqa
2828

29+
_VT = TypeVar("_VT")
30+
2931

3032
class AstrMessageEvent(abc.ABC):
3133
def __init__(
@@ -49,15 +51,15 @@ def __init__(
4951
"""是否唤醒(是否通过 WakingStage)"""
5052
self.is_at_or_wake_command = False
5153
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
52-
self._extras = {}
54+
self._extras: dict[str, Any] = {}
5355
self.session = MessageSesion(
5456
platform_name=platform_meta.id,
5557
message_type=message_obj.type,
5658
session_id=session_id,
5759
)
5860
self.unified_msg_origin = str(self.session)
5961
"""统一的消息来源字符串。格式为 platform_name:message_type:session_id"""
60-
self._result: MessageEventResult = None
62+
self._result: MessageEventResult | None = None
6163
"""消息事件的结果"""
6264

6365
self._has_send_oper = False
@@ -173,13 +175,15 @@ def set_extra(self, key, value):
173175
"""
174176
self._extras[key] = value
175177

176-
def get_extra(self, key=None):
178+
def get_extra(
179+
self, key: str | None = None, default: _VT = None
180+
) -> dict[str, Any] | _VT:
177181
"""
178182
获取额外的信息。
179183
"""
180184
if key is None:
181185
return self._extras
182-
return self._extras.get(key, None)
186+
return self._extras.get(key, default)
183187

184188
def clear_extra(self):
185189
"""

astrbot/core/star/filter/command.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ def __init__(
3232
self.init_handler_md(handler_md)
3333
self.custom_filter_list: List[CustomFilter] = []
3434

35+
# Cache for complete command names list
36+
self._cmpl_cmd_names: list | None = None
37+
3538
def print_types(self):
3639
result = ""
3740
for k, v in self.handler_params.items():
@@ -136,6 +139,28 @@ def validate_and_convert_params(
136139
)
137140
return result
138141

142+
def get_complete_command_names(self):
143+
if self._cmpl_cmd_names is not None:
144+
return self._cmpl_cmd_names
145+
self._cmpl_cmd_names = [
146+
f"{parent} {cmd}" if parent else cmd
147+
for cmd in [self.command_name] + list(self.alias)
148+
for parent in self.parent_command_names or [""]
149+
]
150+
return self._cmpl_cmd_names
151+
152+
def startswith(self, message_str: str) -> bool:
153+
for full_cmd in self.get_complete_command_names():
154+
if message_str.startswith(f"{full_cmd} ") or message_str == full_cmd:
155+
return True
156+
return False
157+
158+
def equals(self, message_str: str) -> bool:
159+
for full_cmd in self.get_complete_command_names():
160+
if message_str == full_cmd:
161+
return True
162+
return False
163+
139164
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
140165
if not event.is_at_or_wake_command:
141166
return False
@@ -145,19 +170,7 @@ def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
145170

146171
# 检查是否以指令开头
147172
message_str = re.sub(r"\s+", " ", event.get_message_str().strip())
148-
candidates = [self.command_name] + list(self.alias)
149-
ok = False
150-
for candidate in candidates:
151-
for parent_command_name in self.parent_command_names:
152-
if parent_command_name:
153-
_full = f"{parent_command_name} {candidate}"
154-
else:
155-
_full = candidate
156-
if message_str.startswith(f"{_full} ") or message_str == _full:
157-
message_str = message_str[len(_full) :].strip()
158-
ok = True
159-
break
160-
if not ok:
173+
if not self.startswith(message_str):
161174
return False
162175

163176
# 分割为列表

astrbot/core/star/filter/command_group.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ def __init__(
2222
self.custom_filter_list: List[CustomFilter] = []
2323
self.parent_group = parent_group
2424

25+
# Cache for complete command names list
26+
self._cmpl_cmd_names: list | None = None
27+
2528
def add_sub_command_filter(
2629
self, sub_command_filter: Union[CommandFilter, CommandGroupFilter]
2730
):
@@ -34,6 +37,9 @@ def get_complete_command_names(self) -> List[str]:
3437
"""遍历父节点获取完整的指令名。
3538
3639
新版本 v3.4.29 采用预编译指令,不再从指令组递归遍历子指令,因此这个方法是返回包括别名在内的整个指令名列表。"""
40+
if self._cmpl_cmd_names is not None:
41+
return self._cmpl_cmd_names
42+
3743
parent_cmd_names = (
3844
self.parent_group.get_complete_command_names() if self.parent_group else []
3945
)
@@ -47,6 +53,7 @@ def get_complete_command_names(self) -> List[str]:
4753
for parent_cmd_name in parent_cmd_names:
4854
for candidate in candidates:
4955
result.append(parent_cmd_name + " " + candidate)
56+
self._cmpl_cmd_names = result
5057
return result
5158

5259
# 以树的形式打印出来
@@ -97,6 +104,12 @@ def custom_filter_ok(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
97104
return False
98105
return True
99106

107+
def startswith(self, message_str: str) -> bool:
108+
return message_str.startswith(tuple(self.get_complete_command_names()))
109+
110+
def equals(self, message_str: str) -> bool:
111+
return message_str in self.get_complete_command_names()
112+
100113
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
101114
if not event.is_at_or_wake_command:
102115
return False
@@ -105,8 +118,7 @@ def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
105118
if not self.custom_filter_ok(event, cfg):
106119
return False
107120

108-
complete_command_names = self.get_complete_command_names()
109-
if event.message_str.strip() in complete_command_names:
121+
if self.equals(event.message_str.strip()):
110122
tree = (
111123
self.group_name
112124
+ "\n"
@@ -116,6 +128,4 @@ def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
116128
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" + tree
117129
)
118130

119-
# complete_command_names = [name + " " for name in complete_command_names]
120-
# return event.message_str.startswith(tuple(complete_command_names))
121-
return False
131+
return self.startswith(event.message_str)

packages/astrbot/main.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,22 +1348,22 @@ async def after_llm_req(self, event: AstrMessageEvent):
13481348
logger.error(f"ltm: {e}")
13491349

13501350
@filter.permission_type(filter.PermissionType.ADMIN)
1351-
@filter.command("alter_cmd")
1351+
@filter.command("alter_cmd", alias={"alter"})
13521352
async def alter_cmd(self, event: AstrMessageEvent):
1353-
# token = event.message_str.split(" ")
13541353
token = self.parse_commands(event.message_str)
1355-
if token.len < 2:
1354+
if token.len < 3:
13561355
yield event.plain_result(
1357-
"可设置所有其他指令是否需要管理员权限。\n格式: /alter_cmd <cmd_name> <admin/member>\n 例如: /alter_cmd provider admin 将 provider 设置为管理员指令\n /alter_cmd reset config 打开reset权限配置"
1356+
"该指令用于设置指令或指令组的权限。\n"
1357+
"格式: /alter_cmd <cmd_name> <admin/member>\n"
1358+
"例1: /alter_cmd c1 admin 将 c1 设为管理员指令\n"
1359+
"例2: /alter_cmd g1 c1 admin 将 g1 指令组的 c1 子指令设为管理员指令\n"
1360+
"/alter_cmd reset config 打开 reset 权限配置"
13581361
)
13591362
return
13601363

1361-
cmd_name = token.get(1)
1362-
cmd_type = token.get(2)
1364+
cmd_name = " ".join(token.tokens[1:-1])
1365+
cmd_type = token.get(-1)
13631366

1364-
# ============================
1365-
# 对reset权限进行特殊处理
1366-
# ============================
13671367
if cmd_name == "reset" and cmd_type == "config":
13681368
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
13691369
plugin_ = alter_cmd_cfg.get("astrbot", {})
@@ -1413,16 +1413,18 @@ async def alter_cmd(self, event: AstrMessageEvent):
14131413

14141414
# 查找指令
14151415
found_command = None
1416+
cmd_group = False
14161417
for handler in star_handlers_registry:
14171418
assert isinstance(handler, StarHandlerMetadata)
14181419
for filter_ in handler.event_filters:
14191420
if isinstance(filter_, CommandFilter):
1420-
if filter_.command_name == cmd_name:
1421+
if filter_.equals(cmd_name):
14211422
found_command = handler
14221423
break
14231424
elif isinstance(filter_, CommandGroupFilter):
1424-
if cmd_name == filter_.group_name:
1425+
if filter_.equals(cmd_name):
14251426
found_command = handler
1427+
cmd_group = True
14261428
break
14271429

14281430
if not found_command:
@@ -1459,8 +1461,10 @@ async def alter_cmd(self, event: AstrMessageEvent):
14591461
else filter.PermissionType.MEMBER
14601462
),
14611463
)
1462-
1463-
yield event.plain_result(f"已将 {cmd_name} 设置为 {cmd_type} 指令")
1464+
cmd_group_str = "指令组" if cmd_group else "指令"
1465+
yield event.plain_result(
1466+
f"已将「{cmd_name}{cmd_group_str} 的权限级别调整为 {cmd_type}。"
1467+
)
14641468

14651469
async def update_reset_permission(self, scene_key: str, perm_type: str):
14661470
"""更新reset命令在特定场景下的权限设置

0 commit comments

Comments
 (0)