diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 01618767c..fdbd9daec 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -75,7 +75,7 @@ def set_key(self, key: str): raise NotImplementedError() @abc.abstractmethod - def get_models(self) -> List[str]: + async def get_models(self) -> List[str]: """获得支持的模型列表""" raise NotImplementedError() diff --git a/astrbot/core/utils/command_parser.py b/astrbot/core/utils/command_parser.py index 7829140f5..557793f0a 100644 --- a/astrbot/core/utils/command_parser.py +++ b/astrbot/core/utils/command_parser.py @@ -6,7 +6,7 @@ def __init__(self) -> None: self.tokens = [] self.len = 0 - def get(self, idx: int): + def get(self, idx: int) -> str | None: if idx >= self.len: return None return self.tokens[idx].strip() diff --git a/packages/astrbot/commands/__init__.py b/packages/astrbot/commands/__init__.py new file mode 100644 index 000000000..995022a14 --- /dev/null +++ b/packages/astrbot/commands/__init__.py @@ -0,0 +1,31 @@ +# Commands module + +from .help import HelpCommand +from .llm import LLMCommands +from .tool import ToolCommands +from .plugin import PluginCommands +from .admin import AdminCommands +from .conversation import ConversationCommands +from .provider import ProviderCommands +from .persona import PersonaCommands +from .alter_cmd import AlterCmdCommands +from .setunset import SetUnsetCommands +from .t2i import T2ICommand +from .tts import TTSCommand +from .sid import SIDCommand + +__all__ = [ + "HelpCommand", + "LLMCommands", + "ToolCommands", + "PluginCommands", + "AdminCommands", + "ConversationCommands", + "ProviderCommands", + "PersonaCommands", + "AlterCmdCommands", + "SetUnsetCommands", + "T2ICommand", + "TTSCommand", + "SIDCommand", +] diff --git a/packages/astrbot/commands/admin.py b/packages/astrbot/commands/admin.py new file mode 100644 index 000000000..5da546a3b --- /dev/null +++ b/packages/astrbot/commands/admin.py @@ -0,0 +1,76 @@ +import astrbot.api.star as star +from astrbot.api.event import AstrMessageEvent, MessageEventResult, MessageChain +from astrbot.core.utils.io import download_dashboard +from astrbot.core.config.default import VERSION + + +class AdminCommands: + def __init__(self, context: star.Context): + self.context = context + + async def op(self, event: AstrMessageEvent, admin_id: str = ""): + """授权管理员。op """ + if admin_id == "": + event.set_result( + MessageEventResult().message( + "使用方法: /op 授权管理员;/deop 取消管理员。可通过 /sid 获取 ID。" + ) + ) + return + self.context.get_config()["admins_id"].append(str(admin_id)) + self.context.get_config().save_config() + event.set_result(MessageEventResult().message("授权成功。")) + + async def deop(self, event: AstrMessageEvent, admin_id: str = ""): + """取消授权管理员。deop """ + if admin_id == "": + event.set_result( + MessageEventResult().message( + "使用方法: /deop 取消管理员。可通过 /sid 获取 ID。" + ) + ) + return + try: + self.context.get_config()["admins_id"].remove(str(admin_id)) + self.context.get_config().save_config() + event.set_result(MessageEventResult().message("取消授权成功。")) + except ValueError: + event.set_result( + MessageEventResult().message("此用户 ID 不在管理员名单内。") + ) + + async def wl(self, event: AstrMessageEvent, sid: str = ""): + """添加白名单。wl """ + if sid == "": + event.set_result( + MessageEventResult().message( + "使用方法: /wl 添加白名单;/dwl 删除白名单。可通过 /sid 获取 ID。" + ) + ) + return + cfg = self.context.get_config(umo=event.unified_msg_origin) + cfg["platform_settings"]["id_whitelist"].append(str(sid)) + cfg.save_config() + event.set_result(MessageEventResult().message("添加白名单成功。")) + + async def dwl(self, event: AstrMessageEvent, sid: str = ""): + """删除白名单。dwl """ + if sid == "": + event.set_result( + MessageEventResult().message( + "使用方法: /dwl 删除白名单。可通过 /sid 获取 ID。" + ) + ) + return + try: + cfg = self.context.get_config(umo=event.unified_msg_origin) + cfg["platform_settings"]["id_whitelist"].remove(str(sid)) + cfg.save_config() + event.set_result(MessageEventResult().message("删除白名单成功。")) + except ValueError: + event.set_result(MessageEventResult().message("此 SID 不在白名单内。")) + + async def update_dashboard(self, event: AstrMessageEvent): + await event.send(MessageChain().message("正在尝试更新管理面板...")) + await download_dashboard(version=f"v{VERSION}", latest=False) + await event.send(MessageChain().message("管理面板更新完成。")) diff --git a/packages/astrbot/commands/alter_cmd.py b/packages/astrbot/commands/alter_cmd.py new file mode 100644 index 000000000..bad5072cf --- /dev/null +++ b/packages/astrbot/commands/alter_cmd.py @@ -0,0 +1,188 @@ +import astrbot.api.star as star +from astrbot.api.event import AstrMessageEvent, MessageChain +from astrbot.core.utils.command_parser import CommandParserMixin +from astrbot.core.star.star_handler import star_handlers_registry, StarHandlerMetadata +from astrbot.core.star.star import star_map +from astrbot.core.star.filter.command import CommandFilter +from astrbot.core.star.filter.command_group import CommandGroupFilter +from astrbot.core.star.filter.permission import PermissionTypeFilter +from enum import Enum + + +class RstScene(Enum): + GROUP_UNIQUE_ON = ("group_unique_on", "群聊+会话隔离开启") + GROUP_UNIQUE_OFF = ("group_unique_off", "群聊+会话隔离关闭") + PRIVATE = ("private", "私聊") + + @property + def key(self) -> str: + return self.value[0] + + @property + def name(self) -> str: + return self.value[1] + + @classmethod + def from_index(cls, index: int) -> "RstScene": + mapping = {1: cls.GROUP_UNIQUE_ON, 2: cls.GROUP_UNIQUE_OFF, 3: cls.PRIVATE} + return mapping[index] + + +class AlterCmdCommands(CommandParserMixin): + def __init__(self, context: star.Context): + self.context = context + + async def update_reset_permission(self, scene_key: str, perm_type: str): + """更新reset命令在特定场景下的权限设置""" + from astrbot.api import sp + + alter_cmd_cfg = await sp.global_get("alter_cmd", {}) + plugin_cfg = alter_cmd_cfg.get("astrbot", {}) + reset_cfg = plugin_cfg.get("reset", {}) + reset_cfg[scene_key] = perm_type + plugin_cfg["reset"] = reset_cfg + alter_cmd_cfg["astrbot"] = plugin_cfg + await sp.global_put("alter_cmd", alter_cmd_cfg) + + async def alter_cmd(self, event: AstrMessageEvent): + token = self.parse_commands(event.message_str) + if token.len < 3: + await event.send( + MessageChain().message( + "该指令用于设置指令或指令组的权限。\n" + "格式: /alter_cmd \n" + "例1: /alter_cmd c1 admin 将 c1 设为管理员指令\n" + "例2: /alter_cmd g1 c1 admin 将 g1 指令组的 c1 子指令设为管理员指令\n" + "/alter_cmd reset config 打开 reset 权限配置" + ) + ) + return + + cmd_name = " ".join(token.tokens[1:-1]) + cmd_type = token.get(-1) + + if cmd_name == "reset" and cmd_type == "config": + from astrbot.api import sp + + alter_cmd_cfg = await sp.global_get("alter_cmd", {}) + plugin_ = alter_cmd_cfg.get("astrbot", {}) + reset_cfg = plugin_.get("reset", {}) + + group_unique_on = reset_cfg.get("group_unique_on", "admin") + group_unique_off = reset_cfg.get("group_unique_off", "admin") + private = reset_cfg.get("private", "member") + + config_menu = f"""reset命令权限细粒度配置 + 当前配置: + 1. 群聊+会话隔离开: {group_unique_on} + 2. 群聊+会话隔离关: {group_unique_off} + 3. 私聊: {private} + 修改指令格式: + /alter_cmd reset scene <场景编号> + 例如: /alter_cmd reset scene 2 member""" + await event.send(MessageChain().message(config_menu)) + return + + if cmd_name == "reset" and cmd_type == "scene" and token.len >= 4: + scene_num = token.get(3) + perm_type = token.get(4) + + if scene_num is None or perm_type is None: + await event.send(MessageChain().message("场景编号和权限类型不能为空")) + return + + if not scene_num.isdigit() or int(scene_num) < 1 or int(scene_num) > 3: + await event.send( + MessageChain().message("场景编号必须是 1-3 之间的数字") + ) + return + + if perm_type not in ["admin", "member"]: + await event.send( + MessageChain().message("权限类型错误,只能是 admin 或 member") + ) + return + + scene_num = int(scene_num) + scene = RstScene.from_index(scene_num) + scene_key = scene.key + + await self.update_reset_permission(scene_key, perm_type) + + await event.send( + MessageChain().message( + f"已将 reset 命令在{scene.name}场景下的权限设为{perm_type}" + ) + ) + return + + if cmd_type not in ["admin", "member"]: + await event.send( + MessageChain().message("指令类型错误,可选类型有 admin, member") + ) + return + + # 查找指令 + found_command = None + cmd_group = False + for handler in star_handlers_registry: + assert isinstance(handler, StarHandlerMetadata) + for filter_ in handler.event_filters: + if isinstance(filter_, CommandFilter): + if filter_.equals(cmd_name): + found_command = handler + break + elif isinstance(filter_, CommandGroupFilter): + if filter_.equals(cmd_name): + found_command = handler + cmd_group = True + break + + if not found_command: + await event.send(MessageChain().message("未找到该指令")) + return + + found_plugin = star_map[found_command.handler_module_path] + + from astrbot.api import sp + + alter_cmd_cfg = await sp.global_get("alter_cmd", {}) + plugin_ = alter_cmd_cfg.get(found_plugin.name, {}) + cfg = plugin_.get(found_command.handler_name, {}) + cfg["permission"] = cmd_type + plugin_[found_command.handler_name] = cfg + alter_cmd_cfg[found_plugin.name] = plugin_ + + await sp.global_put("alter_cmd", alter_cmd_cfg) + + # 注入权限过滤器 + found_permission_filter = False + for filter_ in found_command.event_filters: + if isinstance(filter_, PermissionTypeFilter): + if cmd_type == "admin": + import astrbot.api.event.filter as filter + + filter_.permission_type = filter.PermissionType.ADMIN + else: + import astrbot.api.event.filter as filter + + filter_.permission_type = filter.PermissionType.MEMBER + found_permission_filter = True + break + if not found_permission_filter: + import astrbot.api.event.filter as filter + + found_command.event_filters.insert( + 0, + PermissionTypeFilter( + filter.PermissionType.ADMIN + if cmd_type == "admin" + else filter.PermissionType.MEMBER + ), + ) + cmd_group_str = "指令组" if cmd_group else "指令" + await event.send( + MessageChain().message( + f"已将「{cmd_name}」{cmd_group_str} 的权限级别调整为 {cmd_type}。" + ) + ) diff --git a/packages/astrbot/commands/conversation.py b/packages/astrbot/commands/conversation.py new file mode 100644 index 000000000..2d5317644 --- /dev/null +++ b/packages/astrbot/commands/conversation.py @@ -0,0 +1,440 @@ +import datetime +import astrbot.api.star as star +from astrbot.api.event import AstrMessageEvent, MessageEventResult +from astrbot.core.platform.astr_message_event import MessageSesion +from astrbot.core.platform.message_type import MessageType +from astrbot.core.provider.sources.dify_source import ProviderDify +from astrbot.core.provider.sources.coze_source import ProviderCoze +from astrbot.api import sp, logger +from typing import Union +from enum import Enum + + +class RstScene(Enum): + GROUP_UNIQUE_ON = ("group_unique_on", "群聊+会话隔离开启") + GROUP_UNIQUE_OFF = ("group_unique_off", "群聊+会话隔离关闭") + PRIVATE = ("private", "私聊") + + @property + def key(self) -> str: + return self.value[0] + + @property + def name(self) -> str: + return self.value[1] + + @classmethod + def from_index(cls, index: int) -> "RstScene": + mapping = {1: cls.GROUP_UNIQUE_ON, 2: cls.GROUP_UNIQUE_OFF, 3: cls.PRIVATE} + return mapping[index] + + @classmethod + def get_scene(cls, is_group: bool, is_unique_session: bool) -> "RstScene": + if is_group: + return cls.GROUP_UNIQUE_ON if is_unique_session else cls.GROUP_UNIQUE_OFF + return cls.PRIVATE + + +class ConversationCommands: + def __init__(self, context: star.Context, ltm=None): + self.context = context + self.ltm = ltm + + def ltm_enabled(self, event: AstrMessageEvent): + if not self.ltm: + return False + ltmse = self.context.get_config(umo=event.unified_msg_origin)[ + "provider_ltm_settings" + ] + return ltmse["group_icl_enable"] or ltmse["active_reply"]["enable"] + + async def reset(self, message: AstrMessageEvent): + """重置 LLM 会话""" + + is_unique_session = self.context.get_config()["platform_settings"][ + "unique_session" + ] + is_group = bool(message.get_group_id()) + + scene = RstScene.get_scene(is_group, is_unique_session) + + alter_cmd_cfg = await sp.get_async("global", "global", "alter_cmd", {}) + plugin_config = alter_cmd_cfg.get("astrbot", {}) + reset_cfg = plugin_config.get("reset", {}) + + required_perm = reset_cfg.get( + scene.key, "admin" if is_group and not is_unique_session else "member" + ) + + if required_perm == "admin" and message.role != "admin": + message.set_result( + MessageEventResult().message( + f"在{scene.name}场景下,reset命令需要管理员权限," + f"您 (ID {message.get_sender_id()}) 不是管理员,无法执行此操作。" + ) + ) + return + + if not self.context.get_using_provider(message.unified_msg_origin): + message.set_result( + MessageEventResult().message("未找到任何 LLM 提供商。请先配置。") + ) + return + + provider = self.context.get_using_provider(message.unified_msg_origin) + if provider and provider.meta().type in ["dify", "coze"]: + assert isinstance(provider, (ProviderDify, ProviderCoze)), ( + "provider type is not dify or coze" + ) + await provider.forget(message.unified_msg_origin) + message.set_result( + MessageEventResult().message( + "已重置当前 Dify / Coze 会话,新聊天将更换到新的会话。" + ) + ) + return + + cid = await self.context.conversation_manager.get_curr_conversation_id( + message.unified_msg_origin + ) + + if not cid: + message.set_result( + MessageEventResult().message( + "当前未处于对话状态,请 /switch 切换或者 /new 创建。" + ) + ) + return + + await self.context.conversation_manager.update_conversation( + message.unified_msg_origin, cid, [] + ) + + ret = "清除会话 LLM 聊天历史成功。" + if self.ltm and self.ltm_enabled(message): + cnt = await self.ltm.remove_session(event=message) + ret += f"\n聊天增强: 已清除 {cnt} 条聊天记录。" + + message.set_result(MessageEventResult().message(ret)) + + async def his(self, message: AstrMessageEvent, page: int = 1): + """查看对话记录""" + if not self.context.get_using_provider(message.unified_msg_origin): + message.set_result( + MessageEventResult().message("未找到任何 LLM 提供商。请先配置。") + ) + return + + size_per_page = 6 + + conv_mgr = self.context.conversation_manager + umo = message.unified_msg_origin + session_curr_cid = await conv_mgr.get_curr_conversation_id(umo) + + if not session_curr_cid: + session_curr_cid = await conv_mgr.new_conversation( + umo, message.get_platform_id() + ) + + contexts, total_pages = await conv_mgr.get_human_readable_context( + umo, session_curr_cid, page, size_per_page + ) + + history = "" + for context in contexts: + if len(context) > 150: + context = context[:150] + "..." + history += f"{context}\n" + + ret = ( + f"当前对话历史记录:" + f"{history or '无历史记录'}\n\n" + f"第 {page} 页 | 共 {total_pages} 页\n" + f"*输入 /history 2 跳转到第 2 页" + ) + + message.set_result(MessageEventResult().message(ret).use_t2i(False)) + + async def convs(self, message: AstrMessageEvent, page: int = 1): + """查看对话列表""" + + provider = self.context.get_using_provider(message.unified_msg_origin) + if provider and provider.meta().type == "dify": + """原有的Dify处理逻辑保持不变""" + ret = "Dify 对话列表:\n" + assert isinstance(provider, ProviderDify) + data = await provider.api_client.get_chat_convs(message.unified_msg_origin) + idx = 1 + for conv in data["data"]: + ts_h = datetime.datetime.fromtimestamp(conv["updated_at"]).strftime( + "%m-%d %H:%M" + ) + ret += f"{idx}. {conv['name']}({conv['id'][:4]})\n 上次更新:{ts_h}\n" + idx += 1 + if idx == 1: + ret += "没有找到任何对话。" + dify_cid = provider.conversation_ids.get(message.unified_msg_origin, None) + ret += f"\n\n用户: {message.unified_msg_origin}\n当前对话: {dify_cid}\n使用 /switch <序号> 切换对话。" + message.set_result(MessageEventResult().message(ret)) + return + + size_per_page = 6 + """获取所有对话列表""" + conversations_all = await self.context.conversation_manager.get_conversations( + message.unified_msg_origin + ) + """计算总页数""" + total_pages = (len(conversations_all) + size_per_page - 1) // size_per_page + """确保页码有效""" + page = max(1, min(page, total_pages)) + """分页处理""" + start_idx = (page - 1) * size_per_page + end_idx = start_idx + size_per_page + conversations_paged = conversations_all[start_idx:end_idx] + + ret = "对话列表:\n---\n" + """全局序号从当前页的第一个开始""" + global_index = start_idx + 1 + + """生成所有对话的标题字典""" + _titles = {} + for conv in conversations_all: + title = conv.title if conv.title else "新对话" + _titles[conv.cid] = title + + """遍历分页后的对话生成列表显示""" + for conv in conversations_paged: + persona_id = conv.persona_id + if not persona_id or persona_id == "[%None]": + persona = await self.context.persona_manager.get_default_persona_v3( + umo=message.unified_msg_origin + ) + persona_id = persona["name"] + title = _titles.get(conv.cid, "新对话") + ret += f"{global_index}. {title}({conv.cid[:4]})\n 人格情景: {persona_id}\n 上次更新: {datetime.datetime.fromtimestamp(conv.updated_at).strftime('%m-%d %H:%M')}\n" + global_index += 1 + + ret += "---\n" + curr_cid = await self.context.conversation_manager.get_curr_conversation_id( + message.unified_msg_origin + ) + if curr_cid: + """从所有对话的标题字典中获取标题""" + title = _titles.get(curr_cid, "新对话") + ret += f"\n当前对话: {title}({curr_cid[:4]})" + else: + ret += "\n当前对话: 无" + + unique_session = self.context.get_config()["platform_settings"][ + "unique_session" + ] + if unique_session: + ret += "\n会话隔离粒度: 个人" + else: + ret += "\n会话隔离粒度: 群聊" + + ret += f"\n第 {page} 页 | 共 {total_pages} 页" + ret += "\n*输入 /ls 2 跳转到第 2 页" + + message.set_result(MessageEventResult().message(ret).use_t2i(False)) + return + + async def new_conv(self, message: AstrMessageEvent): + """ + 创建新对话 + """ + provider = self.context.get_using_provider(message.unified_msg_origin) + if provider and provider.meta().type in ["dify", "coze"]: + assert isinstance(provider, (ProviderDify, ProviderCoze)), ( + "provider type is not dify or coze" + ) + await provider.forget(message.unified_msg_origin) + message.set_result( + MessageEventResult().message("成功,下次聊天将是新对话。") + ) + return + + cid = await self.context.conversation_manager.new_conversation( + message.unified_msg_origin, message.get_platform_id() + ) + + # 长期记忆 + if self.ltm and self.ltm_enabled(message): + try: + await self.ltm.remove_session(event=message) + except Exception as e: + logger.error(f"清理聊天增强记录失败: {e}") + + message.set_result( + MessageEventResult().message(f"切换到新对话: 新对话({cid[:4]})。") + ) + + async def groupnew_conv(self, message: AstrMessageEvent, sid: str = ""): + """创建新群聊对话""" + provider = self.context.get_using_provider(message.unified_msg_origin) + if provider and provider.meta().type in ["dify", "coze"]: + assert isinstance(provider, (ProviderDify, ProviderCoze)), ( + "provider type is not dify or coze" + ) + await provider.forget(message.unified_msg_origin) + message.set_result( + MessageEventResult().message("成功,下次聊天将是新对话。") + ) + return + if sid: + session = str( + MessageSesion( + platform_name=message.platform_meta.id, + message_type=MessageType("GroupMessage"), + session_id=sid, + ) + ) + cid = await self.context.conversation_manager.new_conversation( + session, message.get_platform_id() + ) + message.set_result( + MessageEventResult().message( + f"群聊 {session} 已切换到新对话: 新对话({cid[:4]})。" + ) + ) + else: + message.set_result( + MessageEventResult().message("请输入群聊 ID。/groupnew 群聊ID。") + ) + + async def switch_conv( + self, message: AstrMessageEvent, index: Union[int, None] = None + ): + """通过 /ls 前面的序号切换对话""" + + if not isinstance(index, int): + message.set_result( + MessageEventResult().message("类型错误,请输入数字对话序号。") + ) + return + + provider = self.context.get_using_provider(message.unified_msg_origin) + if provider and provider.meta().type == "dify": + assert isinstance(provider, ProviderDify), "provider type is not dify" + data = await provider.api_client.get_chat_convs(message.unified_msg_origin) + if not data["data"]: + message.set_result(MessageEventResult().message("未找到任何对话。")) + return + selected_conv = None + if index is not None: + try: + selected_conv = data["data"][index - 1] + except IndexError: + message.set_result( + MessageEventResult().message("对话序号错误,请使用 /ls 查看") + ) + return + else: + selected_conv = data["data"][0] + ret = ( + f"Dify 切换到对话: {selected_conv['name']}({selected_conv['id'][:4]})。" + ) + provider.conversation_ids[message.unified_msg_origin] = selected_conv["id"] + message.set_result(MessageEventResult().message(ret)) + return + + if index is None: + message.set_result( + MessageEventResult().message( + "请输入对话序号。/switch 对话序号。/ls 查看对话 /new 新建对话" + ) + ) + return + conversations = await self.context.conversation_manager.get_conversations( + message.unified_msg_origin + ) + if index > len(conversations) or index < 1: + message.set_result( + MessageEventResult().message("对话序号错误,请使用 /ls 查看") + ) + else: + conversation = conversations[index - 1] + title = conversation.title if conversation.title else "新对话" + await self.context.conversation_manager.switch_conversation( + message.unified_msg_origin, conversation.cid + ) + message.set_result( + MessageEventResult().message( + f"切换到对话: {title}({conversation.cid[:4]})。" + ) + ) + + async def rename_conv(self, message: AstrMessageEvent, new_name: str = ""): + """重命名对话""" + if not new_name: + message.set_result(MessageEventResult().message("请输入新的对话名称。")) + return + + provider = self.context.get_using_provider(message.unified_msg_origin) + + if provider and provider.meta().type == "dify": + assert isinstance(provider, ProviderDify) + cid = provider.conversation_ids.get(message.unified_msg_origin, None) + if not cid: + message.set_result(MessageEventResult().message("未找到当前对话。")) + return + await provider.api_client.rename(cid, new_name, message.unified_msg_origin) + message.set_result(MessageEventResult().message("重命名对话成功。")) + return + + await self.context.conversation_manager.update_conversation_title( + message.unified_msg_origin, new_name + ) + message.set_result(MessageEventResult().message("重命名对话成功。")) + + async def del_conv(self, message: AstrMessageEvent): + """删除当前对话""" + is_unique_session = self.context.get_config()["platform_settings"][ + "unique_session" + ] + if message.get_group_id() and not is_unique_session and message.role != "admin": + # 群聊,没开独立会话,发送人不是管理员 + message.set_result( + MessageEventResult().message( + f"会话处于群聊,并且未开启独立会话,并且您 (ID {message.get_sender_id()}) 不是管理员,因此没有权限删除当前对话。" + ) + ) + return + + provider = self.context.get_using_provider(message.unified_msg_origin) + if provider and provider.meta().type == "dify": + assert isinstance(provider, ProviderDify) + dify_cid = provider.conversation_ids.pop(message.unified_msg_origin, None) + if dify_cid: + await provider.api_client.delete_chat_conv( + message.unified_msg_origin, dify_cid + ) + message.set_result( + MessageEventResult().message( + "删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。" + ) + ) + return + + session_curr_cid = ( + await self.context.conversation_manager.get_curr_conversation_id( + message.unified_msg_origin + ) + ) + + if not session_curr_cid: + message.set_result( + MessageEventResult().message( + "当前未处于对话状态,请 /switch 序号 切换或 /new 创建。" + ) + ) + return + + await self.context.conversation_manager.delete_conversation( + message.unified_msg_origin, session_curr_cid + ) + message.set_result( + MessageEventResult().message( + "删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。" + ) + ) diff --git a/packages/astrbot/commands/help.py b/packages/astrbot/commands/help.py new file mode 100644 index 000000000..de192ce3d --- /dev/null +++ b/packages/astrbot/commands/help.py @@ -0,0 +1,61 @@ +import aiohttp +import astrbot.api.star as star +from astrbot.api.event import AstrMessageEvent, MessageEventResult +from astrbot.core.config.default import VERSION +from astrbot.core.utils.io import get_dashboard_version + + +class HelpCommand: + def __init__(self, context: star.Context): + self.context = context + + async def _query_astrbot_notice(self): + try: + async with aiohttp.ClientSession(trust_env=True) as session: + async with session.get( + "https://astrbot.app/notice.json", timeout=2 + ) as resp: + return (await resp.json())["notice"] + except BaseException: + return "" + + async def help(self, event: AstrMessageEvent): + """查看帮助""" + notice = "" + try: + notice = await self._query_astrbot_notice() + except BaseException: + pass + + dashboard_version = await get_dashboard_version() + + msg = f"""AstrBot v{VERSION}(WebUI: {dashboard_version}) +内置指令: +[System] +/plugin: 查看插件、插件帮助 +/t2i: 开关文本转图片 +/tts: 开关文本转语音 +/sid: 获取会话 ID +/op: 管理员 +/wl: 白名单 +/dashboard_update: 更新管理面板(op) +/alter_cmd: 设置指令权限(op) + +[大模型] +/llm: 开启/关闭 LLM +/provider: 大模型提供商 +/model: 模型列表 +/ls: 对话列表 +/new: 创建新对话 +/groupnew 群号: 为群聊创建新对话(op) +/switch 序号: 切换对话 +/rename 新名字: 重命名当前对话 +/del: 删除当前会话对话(op) +/reset: 重置 LLM 会话 +/history: 当前对话的对话记录 +/persona: 人格情景(op) +/key: API Key(op) +/websearch: 网页搜索 +{notice}""" + + event.set_result(MessageEventResult().message(msg).use_t2i(False)) diff --git a/packages/astrbot/commands/llm.py b/packages/astrbot/commands/llm.py new file mode 100644 index 000000000..51f8d9923 --- /dev/null +++ b/packages/astrbot/commands/llm.py @@ -0,0 +1,20 @@ +import astrbot.api.star as star +from astrbot.api.event import AstrMessageEvent, MessageChain + + +class LLMCommands: + def __init__(self, context: star.Context): + self.context = context + + async def llm(self, event: AstrMessageEvent): + """开启/关闭 LLM""" + cfg = self.context.get_config(umo=event.unified_msg_origin) + enable = cfg["provider_settings"].get("enable", True) + if enable: + cfg["provider_settings"]["enable"] = False + status = "关闭" + else: + cfg["provider_settings"]["enable"] = True + status = "开启" + cfg.save_config() + await event.send(MessageChain().message(f"{status} LLM 聊天功能。")) diff --git a/packages/astrbot/commands/persona.py b/packages/astrbot/commands/persona.py new file mode 100644 index 000000000..9971df6f0 --- /dev/null +++ b/packages/astrbot/commands/persona.py @@ -0,0 +1,122 @@ +import builtins +import astrbot.api.star as star +from astrbot.api.event import AstrMessageEvent, MessageEventResult + + +class PersonaCommands: + def __init__(self, context: star.Context): + self.context = context + + async def persona(self, message: AstrMessageEvent): + l = message.message_str.split(" ") # noqa: E741 + umo = message.unified_msg_origin + + curr_persona_name = "无" + cid = await self.context.conversation_manager.get_curr_conversation_id(umo) + default_persona = await self.context.persona_manager.get_default_persona_v3( + umo=umo + ) + curr_cid_title = "无" + if cid: + conv = await self.context.conversation_manager.get_conversation( + unified_msg_origin=umo, + conversation_id=cid, + create_if_not_exists=True, + ) + if conv is None: + message.set_result( + MessageEventResult().message( + "当前对话不存在,请先使用 /new 新建一个对话。" + ) + ) + return + if not conv.persona_id and conv.persona_id != "[%None]": + curr_persona_name = default_persona["name"] + else: + curr_persona_name = conv.persona_id + + curr_cid_title = conv.title if conv.title else "新对话" + curr_cid_title += f"({cid[:4]})" + + if len(l) == 1: + message.set_result( + MessageEventResult() + .message( + f"""[Persona] + +- 人格情景列表: `/persona list` +- 设置人格情景: `/persona 人格` +- 人格情景详细信息: `/persona view 人格` +- 取消人格: `/persona unset` + +默认人格情景: {default_persona["name"]} +当前对话 {curr_cid_title} 的人格情景: {curr_persona_name} + +配置人格情景请前往管理面板-配置页 +""" + ) + .use_t2i(False) + ) + elif l[1] == "list": + msg = "人格列表:\n" + for persona in self.context.provider_manager.personas: + msg += f"- {persona['name']}\n" + msg += "\n\n*输入 `/persona view 人格名` 查看人格详细信息" + message.set_result(MessageEventResult().message(msg)) + elif l[1] == "view": + if len(l) == 2: + message.set_result(MessageEventResult().message("请输入人格情景名")) + return + ps = l[2].strip() + if persona := next( + builtins.filter( + lambda persona: persona["name"] == ps, + self.context.provider_manager.personas, + ), + None, + ): + msg = f"人格{ps}的详细信息:\n" + msg += f"{persona['prompt']}\n" + else: + msg = f"人格{ps}不存在" + message.set_result(MessageEventResult().message(msg)) + elif l[1] == "unset": + if not cid: + message.set_result( + MessageEventResult().message("当前没有对话,无法取消人格。") + ) + return + await self.context.conversation_manager.update_conversation_persona_id( + message.unified_msg_origin, "[%None]" + ) + message.set_result(MessageEventResult().message("取消人格成功。")) + else: + ps = "".join(l[1:]).strip() + if not cid: + message.set_result( + MessageEventResult().message( + "当前没有对话,请先开始对话或使用 /new 创建一个对话。" + ) + ) + return + if persona := next( + builtins.filter( + lambda persona: persona["name"] == ps, + self.context.provider_manager.personas, + ), + None, + ): + await self.context.conversation_manager.update_conversation_persona_id( + message.unified_msg_origin, ps + ) + message.set_result( + MessageEventResult().message( + "设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。" + ) + ) + else: + message.set_result( + MessageEventResult().message( + "不存在该人格情景。使用 /persona list 查看所有。" + ) + ) diff --git a/packages/astrbot/commands/plugin.py b/packages/astrbot/commands/plugin.py new file mode 100644 index 000000000..8f705b417 --- /dev/null +++ b/packages/astrbot/commands/plugin.py @@ -0,0 +1,117 @@ +import astrbot.api.star as star +from astrbot.api.event import AstrMessageEvent, MessageEventResult +from astrbot.core.star.star_handler import star_handlers_registry, StarHandlerMetadata +from astrbot.core.star.filter.command import CommandFilter +from astrbot.core.star.filter.command_group import CommandGroupFilter +from astrbot.core.star.star_manager import PluginManager +from astrbot.core import DEMO_MODE, logger + + +class PluginCommands: + def __init__(self, context: star.Context): + self.context = context + + async def plugin_ls(self, event: AstrMessageEvent): + """获取已经安装的插件列表。""" + plugin_list_info = "已加载的插件:\n" + for plugin in self.context.get_all_stars(): + plugin_list_info += f"- `{plugin.name}` By {plugin.author}: {plugin.desc}" + if not plugin.activated: + plugin_list_info += " (未启用)" + plugin_list_info += "\n" + if plugin_list_info.strip() == "": + plugin_list_info = "没有加载任何插件。" + + plugin_list_info += "\n使用 /plugin help <插件名> 查看插件帮助和加载的指令。\n使用 /plugin on/off <插件名> 启用或者禁用插件。" + event.set_result( + MessageEventResult().message(f"{plugin_list_info}").use_t2i(False) + ) + + async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = ""): + """禁用插件""" + if DEMO_MODE: + event.set_result(MessageEventResult().message("演示模式下无法禁用插件。")) + return + if not plugin_name: + event.set_result( + MessageEventResult().message("/plugin off <插件名> 禁用插件。") + ) + return + await self.context._star_manager.turn_off_plugin(plugin_name) # type: ignore + event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已禁用。")) + + async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = ""): + """启用插件""" + if DEMO_MODE: + event.set_result(MessageEventResult().message("演示模式下无法启用插件。")) + return + if not plugin_name: + event.set_result( + MessageEventResult().message("/plugin on <插件名> 启用插件。") + ) + return + await self.context._star_manager.turn_on_plugin(plugin_name) # type: ignore + event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已启用。")) + + async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = ""): + """安装插件""" + if DEMO_MODE: + event.set_result(MessageEventResult().message("演示模式下无法安装插件。")) + return + if not plugin_repo: + event.set_result( + MessageEventResult().message("/plugin get <插件仓库地址> 安装插件") + ) + return + logger.info(f"准备从 {plugin_repo} 安装插件。") + if self.context._star_manager: + star_mgr: PluginManager = self.context._star_manager + try: + await star_mgr.install_plugin(plugin_repo) # type: ignore + event.set_result(MessageEventResult().message("安装插件成功。")) + except Exception as e: + logger.error(f"安装插件失败: {e}") + event.set_result(MessageEventResult().message(f"安装插件失败: {e}")) + return + + async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = ""): + """获取插件帮助""" + if not plugin_name: + event.set_result( + MessageEventResult().message("/plugin help <插件名> 查看插件信息。") + ) + return + plugin = self.context.get_registered_star(plugin_name) + if plugin is None: + event.set_result(MessageEventResult().message("未找到此插件。")) + return + help_msg = "" + help_msg += f"\n\n✨ 作者: {plugin.author}\n✨ 版本: {plugin.version}" + command_handlers = [] + command_names = [] + for handler in star_handlers_registry: + assert isinstance(handler, StarHandlerMetadata) + if handler.handler_module_path != plugin.module_path: + continue + for filter_ in handler.event_filters: + if isinstance(filter_, CommandFilter): + command_handlers.append(handler) + command_names.append(filter_.command_name) + break + elif isinstance(filter_, CommandGroupFilter): + command_handlers.append(handler) + command_names.append(filter_.group_name) + + if len(command_handlers) > 0: + help_msg += "\n\n🔧 指令列表:\n" + for i in range(len(command_handlers)): + help_msg += f"- {command_names[i]}" + if command_handlers[i].desc: + help_msg += f": {command_handlers[i].desc}" + help_msg += "\n" + + help_msg += "\nTip: 指令的触发需要添加唤醒前缀,默认为 /。" + + ret = f"🧩 插件 {plugin_name} 帮助信息:\n" + help_msg + ret += "更多帮助信息请查看插件仓库 README。" + event.set_result(MessageEventResult().message(ret).use_t2i(False)) diff --git a/packages/astrbot/commands/provider.py b/packages/astrbot/commands/provider.py new file mode 100644 index 000000000..3a184475b --- /dev/null +++ b/packages/astrbot/commands/provider.py @@ -0,0 +1,201 @@ +import re +from typing import Union +import astrbot.api.star as star +from astrbot.api.event import AstrMessageEvent, MessageEventResult +from astrbot.core.provider.entities import ProviderType + + +class ProviderCommands: + def __init__(self, context: star.Context): + self.context = context + + async def provider( + self, + event: AstrMessageEvent, + idx: Union[str, int, None] = None, + idx2: Union[int, None] = None, + ): + """查看或者切换 LLM Provider""" + umo = event.unified_msg_origin + + if idx is None: + ret = "## 载入的 LLM 提供商\n" + for idx, llm in enumerate(self.context.get_all_providers()): + id_ = llm.meta().id + ret += f"{idx + 1}. {id_} ({llm.meta().model})" + provider_using = self.context.get_using_provider(umo=umo) + if provider_using and provider_using.meta().id == id_: + ret += " (当前使用)" + ret += "\n" + + tts_providers = self.context.get_all_tts_providers() + if tts_providers: + ret += "\n## 载入的 TTS 提供商\n" + for idx, tts in enumerate(tts_providers): + id_ = tts.meta().id + ret += f"{idx + 1}. {id_}" + tts_using = self.context.get_using_tts_provider(umo=umo) + if tts_using and tts_using.meta().id == id_: + ret += " (当前使用)" + ret += "\n" + + stt_providers = self.context.get_all_stt_providers() + if stt_providers: + ret += "\n## 载入的 STT 提供商\n" + for idx, stt in enumerate(stt_providers): + id_ = stt.meta().id + ret += f"{idx + 1}. {id_}" + stt_using = self.context.get_using_stt_provider(umo=umo) + if stt_using and stt_using.meta().id == id_: + ret += " (当前使用)" + ret += "\n" + + ret += "\n使用 /provider <序号> 切换 LLM 提供商。" + + if tts_providers: + ret += "\n使用 /provider tts <序号> 切换 TTS 提供商。" + if stt_providers: + ret += "\n使用 /provider stt <切换> STT 提供商。" + + event.set_result(MessageEventResult().message(ret)) + elif idx == "tts": + if idx2 is None: + event.set_result(MessageEventResult().message("请输入序号。")) + return + else: + if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1: + event.set_result(MessageEventResult().message("无效的序号。")) + provider = self.context.get_all_tts_providers()[idx2 - 1] + id_ = provider.meta().id + await self.context.provider_manager.set_provider( + provider_id=id_, + provider_type=ProviderType.TEXT_TO_SPEECH, + umo=umo, + ) + event.set_result(MessageEventResult().message(f"成功切换到 {id_}。")) + elif idx == "stt": + if idx2 is None: + event.set_result(MessageEventResult().message("请输入序号。")) + return + else: + if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1: + event.set_result(MessageEventResult().message("无效的序号。")) + provider = self.context.get_all_stt_providers()[idx2 - 1] + id_ = provider.meta().id + await self.context.provider_manager.set_provider( + provider_id=id_, + provider_type=ProviderType.SPEECH_TO_TEXT, + umo=umo, + ) + event.set_result(MessageEventResult().message(f"成功切换到 {id_}。")) + elif isinstance(idx, int): + if idx > len(self.context.get_all_providers()) or idx < 1: + event.set_result(MessageEventResult().message("无效的序号。")) + + provider = self.context.get_all_providers()[idx - 1] + id_ = provider.meta().id + await self.context.provider_manager.set_provider( + provider_id=id_, + provider_type=ProviderType.CHAT_COMPLETION, + umo=umo, + ) + event.set_result(MessageEventResult().message(f"成功切换到 {id_}。")) + else: + event.set_result(MessageEventResult().message("无效的参数。")) + + async def model_ls( + self, message: AstrMessageEvent, idx_or_name: Union[int, str, None] = None + ): + """查看或者切换模型""" + prov = self.context.get_using_provider(message.unified_msg_origin) + if not prov: + message.set_result( + MessageEventResult().message("未找到任何 LLM 提供商。请先配置。") + ) + return + # 定义正则表达式匹配 API 密钥 + api_key_pattern = re.compile(r"key=[^&'\" ]+") + + if idx_or_name is None: + models = [] + try: + models = await prov.get_models() + except BaseException as e: + err_msg = api_key_pattern.sub("key=***", str(e)) + message.set_result( + MessageEventResult() + .message("获取模型列表失败: " + err_msg) + .use_t2i(False) + ) + return + i = 1 + ret = "下面列出了此服务提供商可用模型:" + for model in models: + ret += f"\n{i}. {model}" + i += 1 + + curr_model = prov.get_model() or "无" + ret += f"\n当前模型: [{curr_model}]" + + ret += "\nTips: 使用 /model <模型名/编号>,即可实时更换模型。如目标模型不存在于上表,请输入模型名。" + message.set_result(MessageEventResult().message(ret).use_t2i(False)) + else: + if isinstance(idx_or_name, int): + models = [] + try: + models = await prov.get_models() + except BaseException as e: + message.set_result( + MessageEventResult().message("获取模型列表失败: " + str(e)) + ) + return + if idx_or_name > len(models) or idx_or_name < 1: + message.set_result(MessageEventResult().message("模型序号错误。")) + else: + try: + new_model = models[idx_or_name - 1] + prov.set_model(new_model) + except BaseException as e: + message.set_result( + MessageEventResult().message("切换模型未知错误: " + str(e)) + ) + message.set_result(MessageEventResult().message("切换模型成功。")) + else: + prov.set_model(idx_or_name) + message.set_result( + MessageEventResult().message(f"切换模型到 {prov.get_model()}。") + ) + + async def key(self, message: AstrMessageEvent, index: Union[int, None] = None): + prov = self.context.get_using_provider(message.unified_msg_origin) + if not prov: + message.set_result( + MessageEventResult().message("未找到任何 LLM 提供商。请先配置。") + ) + return + + if index is None: + keys_data = prov.get_keys() + curr_key = prov.get_current_key() + ret = "Key:" + for i, k in enumerate(keys_data): + ret += f"\n{i + 1}. {k[:8]}" + + ret += f"\n当前 Key: {curr_key[:8]}" + ret += "\n当前模型: " + prov.get_model() + ret += "\n使用 /key 切换 Key。" + + message.set_result(MessageEventResult().message(ret).use_t2i(False)) + else: + keys_data = prov.get_keys() + if index > len(keys_data) or index < 1: + message.set_result(MessageEventResult().message("Key 序号错误。")) + else: + try: + new_key = keys_data[index - 1] + prov.set_key(new_key) + except BaseException as e: + message.set_result( + MessageEventResult().message(f"切换 Key 未知错误: {str(e)}") + ) + message.set_result(MessageEventResult().message("切换 Key 成功。")) diff --git a/packages/astrbot/commands/setunset.py b/packages/astrbot/commands/setunset.py new file mode 100644 index 000000000..a82fcdca3 --- /dev/null +++ b/packages/astrbot/commands/setunset.py @@ -0,0 +1,37 @@ +import astrbot.api.star as star +from astrbot.api.event import AstrMessageEvent, MessageEventResult +from astrbot.api import sp + + +class SetUnsetCommands: + def __init__(self, context: star.Context): + self.context = context + + async def set_variable(self, event: AstrMessageEvent, key: str, value: str): + """设置会话变量""" + uid = event.unified_msg_origin + session_var = await sp.session_get(uid, "session_variables", {}) + session_var[key] = value + await sp.session_put(uid, "session_variables", session_var) + + event.set_result( + MessageEventResult().message( + f"会话 {uid} 变量 {key} 存储成功。使用 /unset 移除。" + ) + ) + + async def unset_variable(self, event: AstrMessageEvent, key: str): + """移除会话变量""" + uid = event.unified_msg_origin + session_var = await sp.session_get(uid, "session_variables", {}) + + if key not in session_var: + event.set_result( + MessageEventResult().message("没有那个变量名。格式 /unset 变量名。") + ) + else: + del session_var[key] + await sp.session_put(uid, "session_variables", session_var) + event.set_result( + MessageEventResult().message(f"会话 {uid} 变量 {key} 移除成功。") + ) diff --git a/packages/astrbot/commands/sid.py b/packages/astrbot/commands/sid.py new file mode 100644 index 000000000..165683e43 --- /dev/null +++ b/packages/astrbot/commands/sid.py @@ -0,0 +1,29 @@ +"""会话ID命令""" + +import astrbot.api.star as star +from astrbot.api.event import AstrMessageEvent, MessageEventResult + + +class SIDCommand: + """会话ID命令类""" + + def __init__(self, context: star.Context): + self.context = context + + async def sid(self, event: AstrMessageEvent): + """获取会话 ID 和 管理员 ID""" + sid = event.unified_msg_origin + user_id = str(event.get_sender_id()) + ret = f"""SID: {sid} 此 ID 可用于设置会话白名单。 +/wl 添加白名单, /dwl 删除白名单。 + +UID: {user_id} 此 ID 可用于设置管理员。 +/op 授权管理员, /deop 取消管理员。""" + + if ( + self.context.get_config()["platform_settings"]["unique_session"] + and event.get_group_id() + ): + ret += f"\n\n当前处于独立会话模式, 此群 ID: {event.get_group_id()}, 也可将此 ID 加入白名单来放行整个群聊。" + + event.set_result(MessageEventResult().message(ret).use_t2i(False)) diff --git a/packages/astrbot/commands/t2i.py b/packages/astrbot/commands/t2i.py new file mode 100644 index 000000000..28c1d4eb6 --- /dev/null +++ b/packages/astrbot/commands/t2i.py @@ -0,0 +1,23 @@ +"""文本转图片命令""" + +import astrbot.api.star as star +from astrbot.api.event import AstrMessageEvent, MessageEventResult + + +class T2ICommand: + """文本转图片命令类""" + + def __init__(self, context: star.Context): + self.context = context + + async def t2i(self, event: AstrMessageEvent): + """开关文本转图片""" + config = self.context.get_config(umo=event.unified_msg_origin) + if config["t2i"]: + config["t2i"] = False + config.save_config() + event.set_result(MessageEventResult().message("已关闭文本转图片模式。")) + return + config["t2i"] = True + config.save_config() + event.set_result(MessageEventResult().message("已开启文本转图片模式。")) diff --git a/packages/astrbot/commands/tool.py b/packages/astrbot/commands/tool.py new file mode 100644 index 000000000..335ed5580 --- /dev/null +++ b/packages/astrbot/commands/tool.py @@ -0,0 +1,31 @@ +import astrbot.api.star as star +from astrbot.api.event import AstrMessageEvent, MessageEventResult + + +class ToolCommands: + def __init__(self, context: star.Context): + self.context = context + + async def tool_ls(self, event: AstrMessageEvent): + """查看函数工具列表""" + event.set_result( + MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。") + ) + + async def tool_on(self, event: AstrMessageEvent, tool_name: str = ""): + """启用一个函数工具""" + event.set_result( + MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。") + ) + + async def tool_off(self, event: AstrMessageEvent, tool_name: str = ""): + """停用一个函数工具""" + event.set_result( + MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。") + ) + + async def tool_all_off(self, event: AstrMessageEvent): + """停用所有函数工具""" + event.set_result( + MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。") + ) diff --git a/packages/astrbot/commands/tts.py b/packages/astrbot/commands/tts.py new file mode 100644 index 000000000..a0102fb76 --- /dev/null +++ b/packages/astrbot/commands/tts.py @@ -0,0 +1,36 @@ +"""文本转语音命令""" + +import astrbot.api.star as star +from astrbot.api.event import AstrMessageEvent, MessageEventResult +from astrbot.core.star.session_llm_manager import SessionServiceManager + + +class TTSCommand: + """文本转语音命令类""" + + def __init__(self, context: star.Context): + self.context = context + + async def tts(self, event: AstrMessageEvent): + """开关文本转语音(会话级别)""" + umo = event.unified_msg_origin + ses_tts = SessionServiceManager.is_tts_enabled_for_session(umo) + cfg = self.context.get_config(umo=umo) + tts_enable = cfg["provider_tts_settings"]["enable"] + + # 切换状态 + new_status = not ses_tts + SessionServiceManager.set_tts_status_for_session(umo, new_status) + + status_text = "已开启" if new_status else "已关闭" + + if new_status and not tts_enable: + event.set_result( + MessageEventResult().message( + f"{status_text}当前会话的文本转语音。但 TTS 功能在配置中未启用,请前往 WebUI 开启。" + ) + ) + else: + event.set_result( + MessageEventResult().message(f"{status_text}当前会话的文本转语音。") + ) diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index 327dd8f58..08bbcba6e 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -1,86 +1,55 @@ -import aiohttp -import datetime -import builtins import traceback -import re -import zoneinfo import astrbot.api.star as star import astrbot.api.event.filter as filter -from astrbot.api.event import AstrMessageEvent, MessageEventResult -from astrbot.api import sp +from astrbot.api.event import AstrMessageEvent from astrbot.api.provider import ProviderRequest -from astrbot.core import DEMO_MODE -from astrbot.core.platform.astr_message_event import MessageSesion -from astrbot.core.platform.message_type import MessageType -from astrbot.core.provider.entities import ProviderType from astrbot.core.provider.sources.dify_source import ProviderDify -from astrbot.core.utils.io import download_dashboard, get_dashboard_version -from astrbot.core.star.star_handler import star_handlers_registry, StarHandlerMetadata -from astrbot.core.star.star import star_map -from astrbot.core.star.star_manager import PluginManager -from astrbot.core.star.filter.command import CommandFilter -from astrbot.core.star.filter.command_group import CommandGroupFilter -from astrbot.core.star.filter.permission import PermissionTypeFilter -from astrbot.core.config.default import VERSION from .long_term_memory import LongTermMemory from astrbot.core import logger -from astrbot.api.message_components import Plain, Image, Reply -from astrbot.core.star.session_llm_manager import SessionServiceManager -from astrbot.core.provider.func_tool_manager import ToolSet +from astrbot.api.message_components import Plain, Image from typing import Union -from enum import Enum - -class RstScene(Enum): - GROUP_UNIQUE_ON = ("group_unique_on", "群聊+会话隔离开启") - GROUP_UNIQUE_OFF = ("group_unique_off", "群聊+会话隔离关闭") - PRIVATE = ("private", "私聊") - - @property - def key(self) -> str: - return self.value[0] - - @property - def name(self) -> str: - return self.value[1] - - @classmethod - def from_index(cls, index: int) -> "RstScene": - mapping = {1: cls.GROUP_UNIQUE_ON, 2: cls.GROUP_UNIQUE_OFF, 3: cls.PRIVATE} - return mapping[index] - - @classmethod - def get_scene(cls, is_group: bool, is_unique_session: bool) -> "RstScene": - if is_group: - return cls.GROUP_UNIQUE_ON if is_unique_session else cls.GROUP_UNIQUE_OFF - return cls.PRIVATE +from .commands import ( + HelpCommand, + LLMCommands, + ToolCommands, + PluginCommands, + AdminCommands, + ConversationCommands, + ProviderCommands, + PersonaCommands, + AlterCmdCommands, + SetUnsetCommands, + T2ICommand, + TTSCommand, + SIDCommand, +) +from .process_llm_request import ProcessLLMRequest class Main(star.Star): def __init__(self, context: star.Context) -> None: self.context = context - cfg = context.get_config() - self.timezone = cfg.get("timezone") - if not self.timezone: - # 系统默认时区 - self.timezone = None - else: - logger.info(f"Timezone set to: {self.timezone}") self.ltm = None try: self.ltm = LongTermMemory(self.context.astrbot_config_mgr, self.context) except BaseException as e: logger.error(f"聊天增强 err: {e}") - async def _query_astrbot_notice(self): - try: - async with aiohttp.ClientSession(trust_env=True) as session: - async with session.get( - "https://astrbot.app/notice.json", timeout=2 - ) as resp: - return (await resp.json())["notice"] - except BaseException: - return "" + self.help_c = HelpCommand(self.context) + self.llm_c = LLMCommands(self.context) + self.tool_c = ToolCommands(self.context) + self.plugin_c = PluginCommands(self.context) + self.admin_c = AdminCommands(self.context) + self.conversation_c = ConversationCommands(self.context) + self.provider_c = ProviderCommands(self.context) + self.persona_c = PersonaCommands(self.context) + self.alter_cmd_c = AlterCmdCommands(self.context) + self.setunset_c = SetUnsetCommands(self.context) + self.t2i_c = T2ICommand(self.context) + self.tts_c = TTSCommand(self.context) + self.sid_c = SIDCommand(self.context) + self.proc_llm_req = ProcessLLMRequest(self.context) def ltm_enabled(self, event: AstrMessageEvent): ltmse = self.context.get_config(umo=event.unified_msg_origin)[ @@ -91,59 +60,13 @@ def ltm_enabled(self, event: AstrMessageEvent): @filter.command("help") async def help(self, event: AstrMessageEvent): """查看帮助""" - notice = "" - try: - notice = await self._query_astrbot_notice() - except BaseException: - pass - - dashboard_version = await get_dashboard_version() - - msg = f"""AstrBot v{VERSION}(WebUI: {dashboard_version}) -内置指令: -[System] -/plugin: 查看插件、插件帮助 -/t2i: 开关文本转图片 -/tts: 开关文本转语音 -/sid: 获取会话 ID -/op: 管理员 -/wl: 白名单 -/dashboard_update: 更新管理面板(op) -/alter_cmd: 设置指令权限(op) - -[大模型] -/llm: 开启/关闭 LLM -/provider: 大模型提供商 -/model: 模型列表 -/ls: 对话列表 -/new: 创建新对话 -/groupnew 群号: 为群聊创建新对话(op) -/switch 序号: 切换对话 -/rename 新名字: 重命名当前对话 -/del: 删除当前会话对话(op) -/reset: 重置 LLM 会话 -/history: 当前对话的对话记录 -/persona: 人格情景(op) -/key: API Key(op) -/websearch: 网页搜索 -{notice}""" - - event.set_result(MessageEventResult().message(msg).use_t2i(False)) + await self.help_c.help(event) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("llm") async def llm(self, event: AstrMessageEvent): """开启/关闭 LLM""" - cfg = self.context.get_config(umo=event.unified_msg_origin) - enable = cfg["provider_settings"]["enable"] - if enable: - cfg["provider_settings"]["enable"] = False - status = "关闭" - else: - cfg["provider_settings"]["enable"] = True - status = "开启" - cfg.save_config() - yield event.plain_result(f"{status} LLM 聊天功能。") + await self.llm_c.llm(event) @filter.command_group("tool") def tool(self): @@ -152,30 +75,22 @@ def tool(self): @tool.command("ls") async def tool_ls(self, event: AstrMessageEvent): """查看函数工具列表""" - event.set_result( - MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。") - ) + await self.tool_c.tool_ls(event) @tool.command("on") async def tool_on(self, event: AstrMessageEvent, tool_name: str): """启用一个函数工具""" - event.set_result( - MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。") - ) + await self.tool_c.tool_on(event, tool_name) @tool.command("off") async def tool_off(self, event: AstrMessageEvent, tool_name: str): """停用一个函数工具""" - event.set_result( - MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。") - ) + await self.tool_c.tool_off(event, tool_name) @tool.command("off_all") async def tool_all_off(self, event: AstrMessageEvent): """停用所有函数工具""" - event.set_result( - MessageEventResult().message("tool 指令在 AstrBot v4.0.0 已经被移除。") - ) + await self.tool_c.tool_all_off(event) @filter.command_group("plugin") def plugin(self): @@ -184,227 +99,69 @@ def plugin(self): @plugin.command("ls") async def plugin_ls(self, event: AstrMessageEvent): """获取已经安装的插件列表。""" - plugin_list_info = "已加载的插件:\n" - for plugin in self.context.get_all_stars(): - plugin_list_info += f"- `{plugin.name}` By {plugin.author}: {plugin.desc}" - if not plugin.activated: - plugin_list_info += " (未启用)" - plugin_list_info += "\n" - if plugin_list_info.strip() == "": - plugin_list_info = "没有加载任何插件。" - - plugin_list_info += "\n使用 /plugin help <插件名> 查看插件帮助和加载的指令。\n使用 /plugin on/off <插件名> 启用或者禁用插件。" - event.set_result( - MessageEventResult().message(f"{plugin_list_info}").use_t2i(False) - ) + await self.plugin_c.plugin_ls(event) @filter.permission_type(filter.PermissionType.ADMIN) @plugin.command("off") async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = None): """禁用插件""" - if DEMO_MODE: - event.set_result(MessageEventResult().message("演示模式下无法禁用插件。")) - return - if not plugin_name: - event.set_result( - MessageEventResult().message("/plugin off <插件名> 禁用插件。") - ) - return - await self.context._star_manager.turn_off_plugin(plugin_name) - event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已禁用。")) + await self.plugin_c.plugin_off(event, plugin_name) @filter.permission_type(filter.PermissionType.ADMIN) @plugin.command("on") async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = None): """启用插件""" - if DEMO_MODE: - event.set_result(MessageEventResult().message("演示模式下无法启用插件。")) - return - if not plugin_name: - event.set_result( - MessageEventResult().message("/plugin on <插件名> 启用插件。") - ) - return - await self.context._star_manager.turn_on_plugin(plugin_name) - event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已启用。")) + await self.plugin_c.plugin_on(event, plugin_name) @filter.permission_type(filter.PermissionType.ADMIN) @plugin.command("get") async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = None): """安装插件""" - if DEMO_MODE: - event.set_result(MessageEventResult().message("演示模式下无法安装插件。")) - return - if not plugin_repo: - event.set_result( - MessageEventResult().message("/plugin get <插件仓库地址> 安装插件") - ) - return - logger.info(f"准备从 {plugin_repo} 安装插件。") - if self.context._star_manager: - star_mgr: PluginManager = self.context._star_manager - try: - await star_mgr.install_plugin(plugin_repo) - event.set_result(MessageEventResult().message("安装插件成功。")) - except Exception as e: - logger.error(f"安装插件失败: {e}") - event.set_result(MessageEventResult().message(f"安装插件失败: {e}")) - return + await self.plugin_c.plugin_get(event, plugin_repo) @plugin.command("help") async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = None): """获取插件帮助""" - if not plugin_name: - event.set_result( - MessageEventResult().message("/plugin help <插件名> 查看插件信息。") - ) - return - plugin = self.context.get_registered_star(plugin_name) - if plugin is None: - event.set_result(MessageEventResult().message("未找到此插件。")) - return - help_msg = "" - help_msg += f"\n\n✨ 作者: {plugin.author}\n✨ 版本: {plugin.version}" - command_handlers = [] - command_names = [] - for handler in star_handlers_registry: - assert isinstance(handler, StarHandlerMetadata) - if handler.handler_module_path != plugin.module_path: - continue - for filter_ in handler.event_filters: - if isinstance(filter_, CommandFilter): - command_handlers.append(handler) - command_names.append(filter_.command_name) - break - elif isinstance(filter_, CommandGroupFilter): - command_handlers.append(handler) - command_names.append(filter_.group_name) - - if len(command_handlers) > 0: - help_msg += "\n\n🔧 指令列表:\n" - for i in range(len(command_handlers)): - help_msg += f"- {command_names[i]}" - if command_handlers[i].desc: - help_msg += f": {command_handlers[i].desc}" - help_msg += "\n" - - help_msg += "\nTip: 指令的触发需要添加唤醒前缀,默认为 /。" - - ret = f"🧩 插件 {plugin_name} 帮助信息:\n" + help_msg - ret += "更多帮助信息请查看插件仓库 README。" - event.set_result(MessageEventResult().message(ret).use_t2i(False)) + await self.plugin_c.plugin_help(event, plugin_name) @filter.command("t2i") async def t2i(self, event: AstrMessageEvent): """开关文本转图片""" - config = self.context.get_config(umo=event.unified_msg_origin) - if config["t2i"]: - config["t2i"] = False - config.save_config() - event.set_result(MessageEventResult().message("已关闭文本转图片模式。")) - return - config["t2i"] = True - config.save_config() - event.set_result(MessageEventResult().message("已开启文本转图片模式。")) + await self.t2i_c.t2i(event) @filter.command("tts") async def tts(self, event: AstrMessageEvent): """开关文本转语音(会话级别)""" - umo = event.unified_msg_origin - ses_tts = SessionServiceManager.is_tts_enabled_for_session(umo) - cfg = self.context.get_config(umo=umo) - tts_enable = cfg["provider_tts_settings"]["enable"] - - # 切换状态 - new_status = not ses_tts - SessionServiceManager.set_tts_status_for_session(umo, new_status) - - status_text = "已开启" if new_status else "已关闭" - - if new_status and not tts_enable: - event.set_result( - MessageEventResult().message( - f"{status_text}当前会话的文本转语音。但 TTS 功能在配置中未启用,请前往 WebUI 开启。" - ) - ) - else: - event.set_result( - MessageEventResult().message(f"{status_text}当前会话的文本转语音。") - ) + await self.tts_c.tts(event) @filter.command("sid") async def sid(self, event: AstrMessageEvent): """获取会话 ID 和 管理员 ID""" - sid = event.unified_msg_origin - user_id = str(event.get_sender_id()) - ret = f"""SID: {sid} 此 ID 可用于设置会话白名单。 -/wl 添加白名单, /dwl 删除白名单。 - -UID: {user_id} 此 ID 可用于设置管理员。 -/op 授权管理员, /deop 取消管理员。""" - - if ( - self.context.get_config()["platform_settings"]["unique_session"] - and event.get_group_id() - ): - ret += f"\n\n当前处于独立会话模式, 此群 ID: {event.get_group_id()}, 也可将此 ID 加入白名单来放行整个群聊。" - - event.set_result(MessageEventResult().message(ret).use_t2i(False)) + await self.sid_c.sid(event) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("op") async def op(self, event: AstrMessageEvent, admin_id: str = None): """授权管理员。op """ - if admin_id is None: - event.set_result( - MessageEventResult().message( - "使用方法: /op 授权管理员;/deop 取消管理员。可通过 /sid 获取 ID。" - ) - ) - return - self.context.get_config()["admins_id"].append(str(admin_id)) - self.context.get_config().save_config() - event.set_result(MessageEventResult().message("授权成功。")) + await self.admin_c.op(event, admin_id) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("deop") async def deop(self, event: AstrMessageEvent, admin_id: str): """取消授权管理员。deop """ - try: - self.context.get_config()["admins_id"].remove(str(admin_id)) - self.context.get_config().save_config() - event.set_result(MessageEventResult().message("取消授权成功。")) - except ValueError: - event.set_result( - MessageEventResult().message("此用户 ID 不在管理员名单内。") - ) + await self.admin_c.deop(event, admin_id) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("wl") async def wl(self, event: AstrMessageEvent, sid: str = None): """添加白名单。wl """ - if sid is None: - event.set_result( - MessageEventResult().message( - "使用方法: /wl 添加白名单;/dwl 删除白名单。可通过 /sid 获取 ID。" - ) - ) - cfg = self.context.get_config(umo=event.unified_msg_origin) - cfg["platform_settings"]["id_whitelist"].append(str(sid)) - cfg.save_config() - event.set_result(MessageEventResult().message("添加白名单成功。")) + await self.admin_c.wl(event, sid) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("dwl") async def dwl(self, event: AstrMessageEvent, sid: str): """删除白名单。dwl """ - try: - cfg = self.context.get_config(umo=event.unified_msg_origin) - cfg["platform_settings"]["id_whitelist"].remove(str(sid)) - cfg.save_config() - event.set_result(MessageEventResult().message("删除白名单成功。")) - except ValueError: - event.set_result(MessageEventResult().message("此 SID 不在白名单内。")) + await self.admin_c.dwl(event, sid) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("provider") @@ -412,162 +169,12 @@ async def provider( self, event: AstrMessageEvent, idx: Union[str, int] = None, idx2: int = None ): """查看或者切换 LLM Provider""" - umo = event.unified_msg_origin - - if idx is None: - ret = "## 载入的 LLM 提供商\n" - for idx, llm in enumerate(self.context.get_all_providers()): - id_ = llm.meta().id - ret += f"{idx + 1}. {id_} ({llm.meta().model})" - provider_using = self.context.get_using_provider(umo=umo) - if provider_using and provider_using.meta().id == id_: - ret += " (当前使用)" - ret += "\n" - - tts_providers = self.context.get_all_tts_providers() - if tts_providers: - ret += "\n## 载入的 TTS 提供商\n" - for idx, tts in enumerate(tts_providers): - id_ = tts.meta().id - ret += f"{idx + 1}. {id_}" - tts_using = self.context.get_using_tts_provider(umo=umo) - if tts_using and tts_using.meta().id == id_: - ret += " (当前使用)" - ret += "\n" - - stt_providers = self.context.get_all_stt_providers() - if stt_providers: - ret += "\n## 载入的 STT 提供商\n" - for idx, stt in enumerate(stt_providers): - id_ = stt.meta().id - ret += f"{idx + 1}. {id_}" - stt_using = self.context.get_using_stt_provider(umo=umo) - if stt_using and stt_using.meta().id == id_: - ret += " (当前使用)" - ret += "\n" - - ret += "\n使用 /provider <序号> 切换 LLM 提供商。" - - if tts_providers: - ret += "\n使用 /provider tts <序号> 切换 TTS 提供商。" - if stt_providers: - ret += "\n使用 /provider stt <切换> STT 提供商。" - - event.set_result(MessageEventResult().message(ret)) - elif idx == "tts": - if idx2 is None: - event.set_result(MessageEventResult().message("请输入序号。")) - return - else: - if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1: - event.set_result(MessageEventResult().message("无效的序号。")) - provider = self.context.get_all_tts_providers()[idx2 - 1] - id_ = provider.meta().id - await self.context.provider_manager.set_provider( - provider_id=id_, - provider_type=ProviderType.TEXT_TO_SPEECH, - umo=umo, - ) - event.set_result(MessageEventResult().message(f"成功切换到 {id_}。")) - elif idx == "stt": - if idx2 is None: - event.set_result(MessageEventResult().message("请输入序号。")) - return - else: - if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1: - event.set_result(MessageEventResult().message("无效的序号。")) - provider = self.context.get_all_stt_providers()[idx2 - 1] - id_ = provider.meta().id - await self.context.provider_manager.set_provider( - provider_id=id_, - provider_type=ProviderType.SPEECH_TO_TEXT, - umo=umo, - ) - event.set_result(MessageEventResult().message(f"成功切换到 {id_}。")) - elif isinstance(idx, int): - if idx > len(self.context.get_all_providers()) or idx < 1: - event.set_result(MessageEventResult().message("无效的序号。")) - - provider = self.context.get_all_providers()[idx - 1] - id_ = provider.meta().id - await self.context.provider_manager.set_provider( - provider_id=id_, - provider_type=ProviderType.CHAT_COMPLETION, - umo=umo, - ) - event.set_result(MessageEventResult().message(f"成功切换到 {id_}。")) - else: - event.set_result(MessageEventResult().message("无效的参数。")) + await self.provider_c.provider(event, idx, idx2) @filter.command("reset") async def reset(self, message: AstrMessageEvent): """重置 LLM 会话""" - - # ============================== - # 读取当前情况和配置 - # ============================== - is_unique_session = self.context.get_config()["platform_settings"][ - "unique_session" - ] - is_group = bool(message.get_group_id()) - - scene = RstScene.get_scene(is_group, is_unique_session) - - alter_cmd_cfg = await sp.get_async("global", "global", "alter_cmd", {}) - plugin_config = alter_cmd_cfg.get("astrbot", {}) - reset_cfg = plugin_config.get("reset", {}) - - required_perm = reset_cfg.get( - scene.key, "admin" if is_group and not is_unique_session else "member" - ) - - if required_perm == "admin" and message.role != "admin": - message.set_result( - MessageEventResult().message( - f"在{scene.name}场景下,reset命令需要管理员权限," - f"您 (ID {message.get_sender_id()}) 不是管理员,无法执行此操作。" - ) - ) - return - - if not self.context.get_using_provider(message.unified_msg_origin): - message.set_result( - MessageEventResult().message("未找到任何 LLM 提供商。请先配置。") - ) - return - - provider = self.context.get_using_provider(message.unified_msg_origin) - if provider and provider.meta().type in ["dify", "coze"]: - await provider.forget(message.unified_msg_origin) - message.set_result( - MessageEventResult().message( - "已重置当前 Dify / Coze 会话,新聊天将更换到新的会话。" - ) - ) - return - - cid = await self.context.conversation_manager.get_curr_conversation_id( - message.unified_msg_origin - ) - - if not cid: - message.set_result( - MessageEventResult().message( - "当前未处于对话状态,请 /switch 切换或者 /new 创建。" - ) - ) - return - - await self.context.conversation_manager.update_conversation( - message.unified_msg_origin, cid, [] - ) - - ret = "清除会话 LLM 聊天历史成功。" - if self.ltm and self.ltm_enabled(message): - cnt = await self.ltm.remove_session(event=message) - ret += f"\n聊天增强: 已清除 {cnt} 条聊天记录。" - - message.set_result(MessageEventResult().message(ret)) + await self.conversation_c.reset(message) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("model") @@ -575,557 +182,70 @@ async def model_ls( self, message: AstrMessageEvent, idx_or_name: Union[int, str] = None ): """查看或者切换模型""" - prov = self.context.get_using_provider(message.unified_msg_origin) - if not prov: - message.set_result( - MessageEventResult().message("未找到任何 LLM 提供商。请先配置。") - ) - return - # 定义正则表达式匹配 API 密钥 - api_key_pattern = re.compile(r"key=[^&'\" ]+") - - if idx_or_name is None: - models = [] - try: - models = await prov.get_models() - except BaseException as e: - err_msg = api_key_pattern.sub("key=***", str(e)) - message.set_result( - MessageEventResult() - .message("获取模型列表失败: " + err_msg) - .use_t2i(False) - ) - return - i = 1 - ret = "下面列出了此服务提供商可用模型:" - for model in models: - ret += f"\n{i}. {model}" - i += 1 - - curr_model = prov.get_model() or "无" - ret += f"\n当前模型: [{curr_model}]" - - ret += "\nTips: 使用 /model <模型名/编号>,即可实时更换模型。如目标模型不存在于上表,请输入模型名。" - message.set_result(MessageEventResult().message(ret).use_t2i(False)) - else: - if isinstance(idx_or_name, int): - models = [] - try: - models = await prov.get_models() - except BaseException as e: - message.set_result( - MessageEventResult().message("获取模型列表失败: " + str(e)) - ) - return - if idx_or_name > len(models) or idx_or_name < 1: - message.set_result(MessageEventResult().message("模型序号错误。")) - else: - try: - new_model = models[idx_or_name - 1] - prov.set_model(new_model) - except BaseException as e: - message.set_result( - MessageEventResult().message("切换模型未知错误: " + str(e)) - ) - message.set_result(MessageEventResult().message("切换模型成功。")) - else: - prov.set_model(idx_or_name) - message.set_result( - MessageEventResult().message(f"切换模型到 {prov.get_model()}。") - ) + await self.provider_c.model_ls(message, idx_or_name) @filter.command("history") async def his(self, message: AstrMessageEvent, page: int = 1): """查看对话记录""" - if not self.context.get_using_provider(message.unified_msg_origin): - message.set_result( - MessageEventResult().message("未找到任何 LLM 提供商。请先配置。") - ) - return - - size_per_page = 6 - - conv_mgr = self.context.conversation_manager - umo = message.unified_msg_origin - session_curr_cid = await conv_mgr.get_curr_conversation_id(umo) - - if not session_curr_cid: - session_curr_cid = await conv_mgr.new_conversation( - umo, message.get_platform_id() - ) - - contexts, total_pages = await conv_mgr.get_human_readable_context( - umo, session_curr_cid, page, size_per_page - ) - - history = "" - for context in contexts: - if len(context) > 150: - context = context[:150] + "..." - history += f"{context}\n" - - ret = ( - f"当前对话历史记录:" - f"{history if history else '无历史记录'}\n\n" - f"第 {page} 页 | 共 {total_pages} 页\n" - f"*输入 /history 2 跳转到第 2 页" - ) - - message.set_result(MessageEventResult().message(ret).use_t2i(False)) + await self.conversation_c.his(message, page) @filter.command("ls") async def convs(self, message: AstrMessageEvent, page: int = 1): """查看对话列表""" - - provider = self.context.get_using_provider(message.unified_msg_origin) - if provider and provider.meta().type == "dify": - """原有的Dify处理逻辑保持不变""" - ret = "Dify 对话列表:\n" - assert isinstance(provider, ProviderDify) - data = await provider.api_client.get_chat_convs(message.unified_msg_origin) - idx = 1 - for conv in data["data"]: - ts_h = datetime.datetime.fromtimestamp(conv["updated_at"]).strftime( - "%m-%d %H:%M" - ) - ret += f"{idx}. {conv['name']}({conv['id'][:4]})\n 上次更新:{ts_h}\n" - idx += 1 - if idx == 1: - ret += "没有找到任何对话。" - dify_cid = provider.conversation_ids.get(message.unified_msg_origin, None) - ret += f"\n\n用户: {message.unified_msg_origin}\n当前对话: {dify_cid}\n使用 /switch <序号> 切换对话。" - message.set_result(MessageEventResult().message(ret)) - return - - size_per_page = 6 - """获取所有对话列表""" - conversations_all = await self.context.conversation_manager.get_conversations( - message.unified_msg_origin - ) - """计算总页数""" - total_pages = (len(conversations_all) + size_per_page - 1) // size_per_page - """确保页码有效""" - page = max(1, min(page, total_pages)) - """分页处理""" - start_idx = (page - 1) * size_per_page - end_idx = start_idx + size_per_page - conversations_paged = conversations_all[start_idx:end_idx] - - ret = "对话列表:\n---\n" - """全局序号从当前页的第一个开始""" - global_index = start_idx + 1 - - """生成所有对话的标题字典""" - _titles = {} - for conv in conversations_all: - title = conv.title if conv.title else "新对话" - _titles[conv.cid] = title - - """遍历分页后的对话生成列表显示""" - for conv in conversations_paged: - persona_id = conv.persona_id - if not persona_id or persona_id == "[%None]": - persona = await self.context.persona_manager.get_default_persona_v3( - umo=message.unified_msg_origin - ) - persona_id = persona["name"] - title = _titles.get(conv.cid, "新对话") - ret += f"{global_index}. {title}({conv.cid[:4]})\n 人格情景: {persona_id}\n 上次更新: {datetime.datetime.fromtimestamp(conv.updated_at).strftime('%m-%d %H:%M')}\n" - global_index += 1 - - ret += "---\n" - curr_cid = await self.context.conversation_manager.get_curr_conversation_id( - message.unified_msg_origin - ) - if curr_cid: - """从所有对话的标题字典中获取标题""" - title = _titles.get(curr_cid, "新对话") - ret += f"\n当前对话: {title}({curr_cid[:4]})" - else: - ret += "\n当前对话: 无" - - unique_session = self.context.get_config()["platform_settings"][ - "unique_session" - ] - if unique_session: - ret += "\n会话隔离粒度: 个人" - else: - ret += "\n会话隔离粒度: 群聊" - - ret += f"\n第 {page} 页 | 共 {total_pages} 页" - ret += "\n*输入 /ls 2 跳转到第 2 页" - - message.set_result(MessageEventResult().message(ret).use_t2i(False)) - return + await self.conversation_c.convs(message, page) @filter.command("new") async def new_conv(self, message: AstrMessageEvent): """ 创建新对话 """ - provider = self.context.get_using_provider(message.unified_msg_origin) - if provider and provider.meta().type in ["dify", "coze"]: - await provider.forget(message.unified_msg_origin) - message.set_result( - MessageEventResult().message("成功,下次聊天将是新对话。") - ) - return - - cid = await self.context.conversation_manager.new_conversation( - message.unified_msg_origin, message.get_platform_id() - ) - - # 长期记忆 - if self.ltm and self.ltm_enabled(message): - try: - await self.ltm.remove_session(event=message) - except Exception as e: - logger.error(f"清理聊天增强记录失败: {e}") - - message.set_result( - MessageEventResult().message(f"切换到新对话: 新对话({cid[:4]})。") - ) + await self.conversation_c.new_conv(message) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("groupnew") async def groupnew_conv(self, message: AstrMessageEvent, sid: str): """创建新群聊对话""" - provider = self.context.get_using_provider(message.unified_msg_origin) - if provider and provider.meta().type in ["dify", "coze"]: - await provider.forget(message.unified_msg_origin) - message.set_result( - MessageEventResult().message("成功,下次聊天将是新对话。") - ) - return - if sid: - session = str( - MessageSesion( - platform_name=message.platform_meta.id, - message_type=MessageType("GroupMessage"), - session_id=sid, - ) - ) - cid = await self.context.conversation_manager.new_conversation( - session, message.get_platform_id() - ) - message.set_result( - MessageEventResult().message( - f"群聊 {session} 已切换到新对话: 新对话({cid[:4]})。" - ) - ) - else: - message.set_result( - MessageEventResult().message("请输入群聊 ID。/groupnew 群聊ID。") - ) + await self.conversation_c.groupnew_conv(message, sid) @filter.command("switch") async def switch_conv(self, message: AstrMessageEvent, index: int = None): """通过 /ls 前面的序号切换对话""" - - if not isinstance(index, int): - message.set_result( - MessageEventResult().message("类型错误,请输入数字对话序号。") - ) - return - - provider = self.context.get_using_provider(message.unified_msg_origin) - if provider and provider.meta().type == "dify": - data = await provider.api_client.get_chat_convs(message.unified_msg_origin) - if not data["data"]: - message.set_result(MessageEventResult().message("未找到任何对话。")) - return - selected_conv = None - if index is not None: - try: - selected_conv = data["data"][index - 1] - except IndexError: - message.set_result( - MessageEventResult().message("对话序号错误,请使用 /ls 查看") - ) - return - else: - selected_conv = data["data"][0] - ret = ( - f"Dify 切换到对话: {selected_conv['name']}({selected_conv['id'][:4]})。" - ) - provider.conversation_ids[message.unified_msg_origin] = selected_conv["id"] - message.set_result(MessageEventResult().message(ret)) - return - - if index is None: - message.set_result( - MessageEventResult().message( - "请输入对话序号。/switch 对话序号。/ls 查看对话 /new 新建对话" - ) - ) - return - conversations = await self.context.conversation_manager.get_conversations( - message.unified_msg_origin - ) - if index > len(conversations) or index < 1: - message.set_result( - MessageEventResult().message("对话序号错误,请使用 /ls 查看") - ) - else: - conversation = conversations[index - 1] - title = conversation.title if conversation.title else "新对话" - await self.context.conversation_manager.switch_conversation( - message.unified_msg_origin, conversation.cid - ) - message.set_result( - MessageEventResult().message( - f"切换到对话: {title}({conversation.cid[:4]})。" - ) - ) + await self.conversation_c.switch_conv(message, index) @filter.command("rename") async def rename_conv(self, message: AstrMessageEvent, new_name: str): """重命名对话""" - provider = self.context.get_using_provider(message.unified_msg_origin) - - if provider and provider.meta().type == "dify": - assert isinstance(provider, ProviderDify) - cid = provider.conversation_ids.get(message.unified_msg_origin, None) - if not cid: - message.set_result(MessageEventResult().message("未找到当前对话。")) - return - await provider.api_client.rename(cid, new_name, message.unified_msg_origin) - message.set_result(MessageEventResult().message("重命名对话成功。")) - return - - await self.context.conversation_manager.update_conversation_title( - message.unified_msg_origin, new_name - ) - message.set_result(MessageEventResult().message("重命名对话成功。")) + await self.conversation_c.rename_conv(message, new_name) @filter.command("del") async def del_conv(self, message: AstrMessageEvent): """删除当前对话""" - is_unique_session = self.context.get_config()["platform_settings"][ - "unique_session" - ] - if message.get_group_id() and not is_unique_session and message.role != "admin": - # 群聊,没开独立会话,发送人不是管理员 - message.set_result( - MessageEventResult().message( - f"会话处于群聊,并且未开启独立会话,并且您 (ID {message.get_sender_id()}) 不是管理员,因此没有权限删除当前对话。" - ) - ) - return - - provider = self.context.get_using_provider(message.unified_msg_origin) - if provider and provider.meta().type == "dify": - assert isinstance(provider, ProviderDify) - dify_cid = provider.conversation_ids.pop(message.unified_msg_origin, None) - if dify_cid: - await provider.api_client.delete_chat_conv( - message.unified_msg_origin, dify_cid - ) - message.set_result( - MessageEventResult().message( - "删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。" - ) - ) - return - - session_curr_cid = ( - await self.context.conversation_manager.get_curr_conversation_id( - message.unified_msg_origin - ) - ) - - if not session_curr_cid: - message.set_result( - MessageEventResult().message( - "当前未处于对话状态,请 /switch 序号 切换或 /new 创建。" - ) - ) - return - - await self.context.conversation_manager.delete_conversation( - message.unified_msg_origin, session_curr_cid - ) - message.set_result( - MessageEventResult().message( - "删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。" - ) - ) + await self.conversation_c.del_conv(message) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("key") async def key(self, message: AstrMessageEvent, index: int = None): - prov = self.context.get_using_provider(message.unified_msg_origin) - if not prov: - message.set_result( - MessageEventResult().message("未找到任何 LLM 提供商。请先配置。") - ) - return - - if index is None: - keys_data = prov.get_keys() - curr_key = prov.get_current_key() - ret = "Key:" - for i, k in enumerate(keys_data): - ret += f"\n{i + 1}. {k[:8]}" - - ret += f"\n当前 Key: {curr_key[:8]}" - ret += "\n当前模型: " + prov.get_model() - ret += "\n使用 /key 切换 Key。" - - message.set_result(MessageEventResult().message(ret).use_t2i(False)) - else: - keys_data = prov.get_keys() - if index > len(keys_data) or index < 1: - message.set_result(MessageEventResult().message("Key 序号错误。")) - else: - try: - new_key = keys_data[index - 1] - prov.set_key(new_key) - except BaseException as e: - message.set_result( - MessageEventResult().message("切换 Key 未知错误: " + str(e)) - ) - message.set_result(MessageEventResult().message("切换 Key 成功。")) + """查看或者切换 Key""" + await self.provider_c.key(message, index) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("persona") async def persona(self, message: AstrMessageEvent): - l = message.message_str.split(" ") # noqa: E741 - umo = message.unified_msg_origin - - curr_persona_name = "无" - cid = await self.context.conversation_manager.get_curr_conversation_id(umo) - default_persona = await self.context.persona_manager.get_default_persona_v3( - umo=umo - ) - curr_cid_title = "无" - if cid: - conversation = await self.context.conversation_manager.get_conversation( - unified_msg_origin=umo, - conversation_id=cid, - create_if_not_exists=True, - ) - if not conversation.persona_id and not conversation.persona_id == "[%None]": - curr_persona_name = default_persona["name"] - else: - curr_persona_name = conversation.persona_id - - curr_cid_title = conversation.title if conversation.title else "新对话" - curr_cid_title += f"({cid[:4]})" - - if len(l) == 1: - message.set_result( - MessageEventResult() - .message( - f"""[Persona] - -- 人格情景列表: `/persona list` -- 设置人格情景: `/persona 人格` -- 人格情景详细信息: `/persona view 人格` -- 取消人格: `/persona unset` - -默认人格情景: {default_persona["name"]} -当前对话 {curr_cid_title} 的人格情景: {curr_persona_name} - -配置人格情景请前往管理面板-配置页 -""" - ) - .use_t2i(False) - ) - elif l[1] == "list": - msg = "人格列表:\n" - for persona in self.context.provider_manager.personas: - msg += f"- {persona['name']}\n" - msg += "\n\n*输入 `/persona view 人格名` 查看人格详细信息" - message.set_result(MessageEventResult().message(msg)) - elif l[1] == "view": - if len(l) == 2: - message.set_result(MessageEventResult().message("请输入人格情景名")) - return - ps = l[2].strip() - if persona := next( - builtins.filter( - lambda persona: persona["name"] == ps, - self.context.provider_manager.personas, - ), - None, - ): - msg = f"人格{ps}的详细信息:\n" - msg += f"{persona['prompt']}\n" - else: - msg = f"人格{ps}不存在" - message.set_result(MessageEventResult().message(msg)) - elif l[1] == "unset": - if not cid: - message.set_result( - MessageEventResult().message("当前没有对话,无法取消人格。") - ) - return - await self.context.conversation_manager.update_conversation_persona_id( - message.unified_msg_origin, "[%None]" - ) - message.set_result(MessageEventResult().message("取消人格成功。")) - else: - ps = "".join(l[1:]).strip() - if not cid: - message.set_result( - MessageEventResult().message( - "当前没有对话,请先开始对话或使用 /new 创建一个对话。" - ) - ) - return - if persona := next( - builtins.filter( - lambda persona: persona["name"] == ps, - self.context.provider_manager.personas, - ), - None, - ): - await self.context.conversation_manager.update_conversation_persona_id( - message.unified_msg_origin, ps - ) - message.set_result( - MessageEventResult().message( - "设置成功。如果您正在切换到不同的人格,请注意使用 /reset 来清空上下文,防止原人格对话影响现人格。" - ) - ) - else: - message.set_result( - MessageEventResult().message( - "不存在该人格情景。使用 /persona list 查看所有。" - ) - ) + """查看或者切换 Persona""" + await self.persona_c.persona(message) @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("dashboard_update") async def update_dashboard(self, event: AstrMessageEvent): - yield event.plain_result("正在尝试更新管理面板...") - await download_dashboard(version=f"v{VERSION}", latest=False) - yield event.plain_result("管理面板更新完成。") + await self.admin_c.update_dashboard(event) @filter.command("set") async def set_variable(self, event: AstrMessageEvent, key: str, value: str): - # session_id = event.get_session_id() - uid = event.unified_msg_origin - session_var = await sp.session_get(uid, "session_variables", {}) - session_var[key] = value - await sp.session_put(uid, "session_variables", session_var) - - yield event.plain_result(f"会话 {uid} 变量 {key} 存储成功。使用 /unset 移除。") + await self.setunset_c.set_variable(event, key, value) @filter.command("unset") async def unset_variable(self, event: AstrMessageEvent, key: str): - uid = event.unified_msg_origin - session_var = await sp.session_get( - umo="uid", key="session_variables", default={} - ) - - if key not in session_var: - yield event.plain_result("没有那个变量名。格式 /unset 变量名。") - else: - del session_var[key] - await sp.session_put(uid, "session_variables", session_var) - yield event.plain_result(f"会话 {uid} 变量 {key} 移除成功。") + await self.setunset_c.unset_variable(event, key) @filter.platform_adapter_type(filter.PlatformAdapterType.ALL) async def on_message(self, event: AstrMessageEvent): @@ -1203,140 +323,7 @@ async def decorate_result(self, event: AstrMessageEvent): @filter.on_llm_request() async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest): """在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt""" - cfg = self.context.get_config(umo=event.unified_msg_origin)["provider_settings"] - if prefix := cfg.get("prompt_prefix"): - req.prompt = prefix + req.prompt - - # 解析引用内容 - quote = None - for comp in event.message_obj.message: - if isinstance(comp, Reply): - quote = comp - break - - if cfg.get("identifier"): - user_id = event.message_obj.sender.user_id - user_nickname = event.message_obj.sender.nickname - user_info = f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n" - req.prompt = user_info + req.prompt - - if cfg.get("group_name_display") and event.message_obj.group_id: - group_name = event.message_obj.group.group_name - - if group_name: - req.system_prompt += f"\nGroup name: {group_name}\n" - - # 启用附加时间戳 - if cfg.get("datetime_system_prompt"): - current_time = None - if self.timezone: - # 启用时区 - try: - now = datetime.datetime.now(zoneinfo.ZoneInfo(self.timezone)) - current_time = now.strftime("%Y-%m-%d %H:%M (%Z)") - except Exception as e: - logger.error(f"时区设置错误: {e}, 使用本地时区") - if not current_time: - current_time = ( - datetime.datetime.now().astimezone().strftime("%Y-%m-%d %H:%M (%Z)") - ) - req.system_prompt += f"\nCurrent datetime: {current_time}\n" - - img_cap_prov_id = cfg.get("default_image_caption_provider_id") - if req.conversation: - # persona inject - persona_id = req.conversation.persona_id or cfg.get("default_personality") - if not persona_id and persona_id != "[%None]": # [%None] 为用户取消人格 - default_persona = ( - self.context.persona_manager.selected_default_persona_v3 - ) - if default_persona: - persona_id = default_persona["name"] - persona = next( - builtins.filter( - lambda persona: persona["name"] == persona_id, - self.context.persona_manager.personas_v3, - ), - None, - ) - if persona: - if prompt := persona["prompt"]: - req.system_prompt += prompt - if begin_dialogs := persona["_begin_dialogs_processed"]: - req.contexts[:0] = begin_dialogs - - # tools select - tmgr = self.context.get_llm_tool_manager() - if (persona and persona.get("tools") is None) or not persona: - # select all - toolset = tmgr.get_full_tool_set() - for tool in toolset: - if not tool.active: - toolset.remove_tool(tool.name) - else: - toolset = ToolSet() - for tool_name in persona["tools"]: - tool = tmgr.get_func(tool_name) - if tool and tool.active: - toolset.add_tool(tool) - req.func_tool = toolset - logger.debug(f"Tool set for persona {persona_id}: {toolset.names()}") - - # image caption - if img_cap_prov_id and req.image_urls: - img_cap_prompt = cfg.get( - "image_caption_prompt", "Please describe the image." - ) - try: - if prov := self.context.get_provider_by_id(img_cap_prov_id): - logger.debug( - f"Processing image caption with provider: {img_cap_prov_id}" - ) - llm_resp = await prov.text_chat( - prompt=img_cap_prompt, - image_urls=req.image_urls, - ) - if llm_resp.completion_text: - req.prompt = f"(Image Caption: {llm_resp.completion_text})\n\n{req.prompt}" - req.image_urls = [] - except Exception as e: - logger.error(f"处理图片描述失败: {e}") - - if quote: - sender_info = "" - if quote.sender_nickname: - sender_info = f"(Sent by {quote.sender_nickname})" - message_str = quote.message_str or "[Empty Text]" - req.system_prompt += ( - f"\nUser is quoting a message{sender_info}.\n" - f"Here are the information of the quoted message: Text Content: {message_str}.\n" - ) - image_seg = None - if quote.chain: - for comp in quote.chain: - if isinstance(comp, Image): - image_seg = comp - break - if image_seg: - try: - prov = None - if img_cap_prov_id: - prov = self.context.get_provider_by_id(img_cap_prov_id) - if prov is None: - prov = self.context.get_using_provider(event.unified_msg_origin) - if prov: - llm_resp = await prov.text_chat( - prompt="Please describe the image content.", - image_urls=[await image_seg.convert_to_file_path()], - ) - if llm_resp.completion_text: - req.system_prompt += ( - f"Image Caption: {llm_resp.completion_text}\n" - ) - else: - logger.warning("No provider found for image captioning.") - except BaseException as e: - logger.error(f"处理引用图片失败: {e}") + await self.proc_llm_req.process_llm_request(event, req) if self.ltm and self.ltm_enabled(event): try: @@ -1356,133 +343,5 @@ async def after_llm_req(self, event: AstrMessageEvent): @filter.permission_type(filter.PermissionType.ADMIN) @filter.command("alter_cmd", alias={"alter"}) async def alter_cmd(self, event: AstrMessageEvent): - token = self.parse_commands(event.message_str) - if token.len < 3: - yield event.plain_result( - "该指令用于设置指令或指令组的权限。\n" - "格式: /alter_cmd \n" - "例1: /alter_cmd c1 admin 将 c1 设为管理员指令\n" - "例2: /alter_cmd g1 c1 admin 将 g1 指令组的 c1 子指令设为管理员指令\n" - "/alter_cmd reset config 打开 reset 权限配置" - ) - return - - cmd_name = " ".join(token.tokens[1:-1]) - cmd_type = token.get(-1) - - if cmd_name == "reset" and cmd_type == "config": - alter_cmd_cfg = await sp.global_get("alter_cmd", {}) - plugin_ = alter_cmd_cfg.get("astrbot", {}) - reset_cfg = plugin_.get("reset", {}) - - group_unique_on = reset_cfg.get("group_unique_on", "admin") - group_unique_off = reset_cfg.get("group_unique_off", "admin") - private = reset_cfg.get("private", "member") - - config_menu = f"""reset命令权限细粒度配置 - 当前配置: - 1. 群聊+会话隔离开: {group_unique_on} - 2. 群聊+会话隔离关: {group_unique_off} - 3. 私聊: {private} - 修改指令格式: - /alter_cmd reset scene <场景编号> - 例如: /alter_cmd reset scene 2 member""" - yield event.plain_result(config_menu) - return - - if cmd_name == "reset" and cmd_type == "scene" and token.len >= 4: - scene_num = token.get(3) - perm_type = token.get(4) - - if not scene_num.isdigit() or int(scene_num) < 1 or int(scene_num) > 3: - yield event.plain_result("场景编号必须是1-3之间的数字") - return - - if perm_type not in ["admin", "member"]: - yield event.plain_result("权限类型错误,只能是admin或member") - return - - scene_num = int(scene_num) - scene = RstScene.from_index(scene_num) - scene_key = scene.key - - await self.update_reset_permission(scene_key, perm_type) - - yield event.plain_result( - f"已将 reset 命令在{scene.name}场景下的权限设为{perm_type}" - ) - return - - if cmd_type not in ["admin", "member"]: - yield event.plain_result("指令类型错误,可选类型有 admin, member") - return - - # 查找指令 - found_command = None - cmd_group = False - for handler in star_handlers_registry: - assert isinstance(handler, StarHandlerMetadata) - for filter_ in handler.event_filters: - if isinstance(filter_, CommandFilter): - if filter_.equals(cmd_name): - found_command = handler - break - elif isinstance(filter_, CommandGroupFilter): - if filter_.equals(cmd_name): - found_command = handler - cmd_group = True - break - - if not found_command: - yield event.plain_result("未找到该指令") - return - - found_plugin = star_map[found_command.handler_module_path] - - alter_cmd_cfg = await sp.global_get("alter_cmd", {}) - plugin_ = alter_cmd_cfg.get(found_plugin.name, {}) - cfg = plugin_.get(found_command.handler_name, {}) - cfg["permission"] = cmd_type - plugin_[found_command.handler_name] = cfg - alter_cmd_cfg[found_plugin.name] = plugin_ - - await sp.global_put("alter_cmd", alter_cmd_cfg) - - # 注入权限过滤器 - found_permission_filter = False - for filter_ in found_command.event_filters: - if isinstance(filter_, PermissionTypeFilter): - if cmd_type == "admin": - filter_.permission_type = filter.PermissionType.ADMIN - else: - filter_.permission_type = filter.PermissionType.MEMBER - found_permission_filter = True - break - if not found_permission_filter: - found_command.event_filters.insert( - 0, - PermissionTypeFilter( - filter.PermissionType.ADMIN - if cmd_type == "admin" - else filter.PermissionType.MEMBER - ), - ) - cmd_group_str = "指令组" if cmd_group else "指令" - yield event.plain_result( - f"已将「{cmd_name}」{cmd_group_str} 的权限级别调整为 {cmd_type}。" - ) - - async def update_reset_permission(self, scene_key: str, perm_type: str): - """更新reset命令在特定场景下的权限设置 - - Args: - scene_key (str): 场景编号,1-3 - perm_type (str): 权限类型,admin或member - """ - alter_cmd_cfg = await sp.global_get("alter_cmd", {}) - plugin_cfg = alter_cmd_cfg.get("astrbot", {}) - reset_cfg = plugin_cfg.get("reset", {}) - reset_cfg[scene_key] = perm_type - plugin_cfg["reset"] = reset_cfg - alter_cmd_cfg["astrbot"] = plugin_cfg - await sp.global_put("alter_cmd", alter_cmd_cfg) + """修改命令权限""" + await self.alter_cmd_c.alter_cmd(event) diff --git a/packages/astrbot/process_llm_request.py b/packages/astrbot/process_llm_request.py new file mode 100644 index 000000000..e1d7ab42b --- /dev/null +++ b/packages/astrbot/process_llm_request.py @@ -0,0 +1,191 @@ +import astrbot.api.star as star +import builtins +import datetime +import zoneinfo +from astrbot.api import logger +from astrbot.api.event import AstrMessageEvent +from astrbot.api.provider import Provider +from astrbot.api.provider import ProviderRequest +from astrbot.core.provider.func_tool_manager import ToolSet +from astrbot.api.message_components import Image, Reply + + +class ProcessLLMRequest: + def __init__(self, context: star.Context): + self.ctx = context + cfg = context.get_config() + self.timezone = cfg.get("timezone") + if not self.timezone: + # 系统默认时区 + self.timezone = None + else: + logger.info(f"Timezone set to: {self.timezone}") + + def _ensure_persona(self, req: ProviderRequest, cfg: dict): + """确保用户人格已加载""" + if not req.conversation: + return + # persona inject + persona_id = req.conversation.persona_id or cfg.get("default_personality") + if not persona_id and persona_id != "[%None]": # [%None] 为用户取消人格 + default_persona = self.ctx.persona_manager.selected_default_persona_v3 + if default_persona: + persona_id = default_persona["name"] + persona = next( + builtins.filter( + lambda persona: persona["name"] == persona_id, + self.ctx.persona_manager.personas_v3, + ), + None, + ) + if persona: + if prompt := persona["prompt"]: + req.system_prompt += prompt + if begin_dialogs := persona["_begin_dialogs_processed"]: + req.contexts[:0] = begin_dialogs + + # tools select + tmgr = self.ctx.get_llm_tool_manager() + if (persona and persona.get("tools") is None) or not persona: + # select all + toolset = tmgr.get_full_tool_set() + for tool in toolset: + if not tool.active: + toolset.remove_tool(tool.name) + else: + toolset = ToolSet() + if persona["tools"]: + for tool_name in persona["tools"]: + tool = tmgr.get_func(tool_name) + if tool and tool.active: + toolset.add_tool(tool) + req.func_tool = toolset + logger.debug(f"Tool set for persona {persona_id}: {toolset.names()}") + + async def _ensure_img_caption( + self, req: ProviderRequest, cfg: dict, img_cap_prov_id: str + ): + try: + caption = await self._request_img_caption( + img_cap_prov_id, cfg, req.image_urls + ) + if caption: + req.prompt = f"(Image Caption: {caption})\n\n{req.prompt}" + req.image_urls = [] + except Exception as e: + logger.error(f"处理图片描述失败: {e}") + + async def _request_img_caption( + self, provider_id: str, cfg: dict, image_urls: list[str] + ) -> str: + if prov := self.ctx.get_provider_by_id(provider_id): + if isinstance(prov, Provider): + img_cap_prompt = cfg.get( + "image_caption_prompt", "Please describe the image." + ) + logger.debug(f"Processing image caption with provider: {provider_id}") + llm_resp = await prov.text_chat( + prompt=img_cap_prompt, + image_urls=image_urls, + ) + return llm_resp.completion_text + else: + raise ValueError( + f"Cannot get image caption because provider `{provider_id}` is not a valid Provider, it is {type(prov)}." + ) + else: + raise ValueError( + f"Cannot get image caption because provider `{provider_id}` is not exist." + ) + + async def process_llm_request(self, event: AstrMessageEvent, req: ProviderRequest): + """在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt""" + cfg: dict = self.ctx.get_config(umo=event.unified_msg_origin)[ + "provider_settings" + ] + + # prompt prefix + if prefix := cfg.get("prompt_prefix"): + req.prompt = prefix + req.prompt + + # user identifier + if cfg.get("identifier"): + user_id = event.message_obj.sender.user_id + user_nickname = event.message_obj.sender.nickname + req.prompt = ( + f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n{req.prompt}" + ) + + # group name identifier + if cfg.get("group_name_display") and event.message_obj.group_id: + group_name = event.message_obj.group.group_name + if group_name: + req.system_prompt += f"\nGroup name: {group_name}\n" + + # time info + if cfg.get("datetime_system_prompt"): + current_time = None + if self.timezone: + # 启用时区 + try: + now = datetime.datetime.now(zoneinfo.ZoneInfo(self.timezone)) + current_time = now.strftime("%Y-%m-%d %H:%M (%Z)") + except Exception as e: + logger.error(f"时区设置错误: {e}, 使用本地时区") + if not current_time: + current_time = ( + datetime.datetime.now().astimezone().strftime("%Y-%m-%d %H:%M (%Z)") + ) + req.system_prompt += f"\nCurrent datetime: {current_time}\n" + + img_cap_prov_id: str = cfg.get("default_image_caption_provider_id") or "" + if req.conversation: + # inject persona for this request + self._ensure_persona(req, cfg) + + # image caption + if img_cap_prov_id and req.image_urls: + await self._ensure_img_caption(req, cfg, img_cap_prov_id) + + # quote message processing + # 解析引用内容 + quote = None + for comp in event.message_obj.message: + if isinstance(comp, Reply): + quote = comp + break + if quote: + sender_info = "" + if quote.sender_nickname: + sender_info = f"(Sent by {quote.sender_nickname})" + message_str = quote.message_str or "[Empty Text]" + req.system_prompt += ( + f"\nUser is quoting a message{sender_info}.\n" + f"Here are the information of the quoted message: Text Content: {message_str}.\n" + ) + image_seg = None + if quote.chain: + for comp in quote.chain: + if isinstance(comp, Image): + image_seg = comp + break + if image_seg: + try: + prov = None + if img_cap_prov_id: + prov = self.ctx.get_provider_by_id(img_cap_prov_id) + if prov is None: + prov = self.ctx.get_using_provider(event.unified_msg_origin) + if prov and isinstance(prov, Provider): + llm_resp = await prov.text_chat( + prompt="Please describe the image content.", + image_urls=[await image_seg.convert_to_file_path()], + ) + if llm_resp.completion_text: + req.system_prompt += ( + f"Image Caption: {llm_resp.completion_text}\n" + ) + else: + logger.warning("No provider found for image captioning.") + except BaseException as e: + logger.error(f"处理引用图片失败: {e}")