From f0fcd2ab42186b641e2dad8b002be453ca0543cf Mon Sep 17 00:00:00 2001 From: anka <1350989414@qq.com> Date: Thu, 11 Sep 2025 18:25:54 +0800 Subject: [PATCH 1/3] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0=E4=BE=9B?= =?UTF-8?q?=E6=8F=92=E4=BB=B6=E4=BD=BF=E7=94=A8=E7=9A=84=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E5=AD=98=E5=8F=96=E6=96=B9=E6=B3=95=E5=8F=8A=E7=9B=91=E5=90=AC?= =?UTF-8?q?=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/star/star_tools.py | 276 +++++++++++++++++++++++++++++++- 1 file changed, 274 insertions(+), 2 deletions(-) diff --git a/astrbot/core/star/star_tools.py b/astrbot/core/star/star_tools.py index 42ed168ff..170280308 100644 --- a/astrbot/core/star/star_tools.py +++ b/astrbot/core/star/star_tools.py @@ -21,11 +21,13 @@ import inspect import os import uuid +import asyncio from pathlib import Path -from typing import Union, Awaitable, List, Optional, ClassVar +from typing import Union, Awaitable, List, Optional, ClassVar, Dict, Any, Callable +from astrbot.core import logger from astrbot.core.message.components import BaseMessageComponent from astrbot.core.message.message_event_result import MessageChain -from astrbot.api.platform import MessageMember, AstrBotMessage, MessageType +from astrbot.core.platform import MessageMember, AstrBotMessage, MessageType from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.star.context import Context from astrbot.core.star.star import star_map @@ -40,6 +42,9 @@ class StarTools: """ _context: ClassVar[Optional[Context]] = None + _shared_data: ClassVar[Dict[str, Any]] = {} + _data_listeners: ClassVar[Dict[str, List[Callable[[str, Any], Awaitable[None]]]]] = {} + @classmethod def initialize(cls, context: Context) -> None: @@ -287,3 +292,270 @@ def get_data_dir(cls, plugin_name: Optional[str] = None) -> Path: raise RuntimeError(f"无法创建目录 {data_dir}:{e!s}") from e return data_dir.resolve() + + @classmethod + def set_shared_data(cls, key: str, value: Any, plugin_name: Optional[str] = None) -> None: + """ + 设置插件间共享数据 + + Args: + key (str): 数据键名 + value (Any): 要存储的数据,支持任意数据类型 + plugin_name (Optional[str]): 插件名称,如果为None则自动检测 + + Example: + # 设置工作状态 + StarTools.set_shared_data("worker_status", True) + + # 设置复杂数据 + StarTools.set_shared_data("task_progress", { + "current": 5, + "total": 10, + "status": "processing" + }) + """ + if not plugin_name: + frame = inspect.currentframe() + module = None + if frame: + frame = frame.f_back + module = inspect.getmodule(frame) + + if not module: + raise RuntimeError("无法获取调用者模块信息") + + metadata = star_map.get(module.__name__, None) + if not metadata: + raise RuntimeError(f"无法获取模块 {module.__name__} 的元数据信息") + + plugin_name = metadata.name + + full_key = f"{plugin_name}:{key}" + cls._shared_data[full_key] = value + asyncio.create_task(cls._notify_listeners(full_key, value)) + + @classmethod + def get_shared_data(cls, key: str, plugin_name: Optional[str] = None, default: Any = None) -> Any: + """ + 获取插件间共享数据 + + Args: + key (str): 数据键名 + plugin_name (Optional[str]): 插件名称,如果为None则自动检测 + default (Any): 当数据不存在时返回的默认值 + + Returns: + Any: 存储的数据,如果不存在则返回default + + Example: + # 获取其他插件的工作状态 + status = StarTools.get_shared_data("worker_status", "other_plugin") + + # 获取当前插件的数据 + my_data = StarTools.get_shared_data("my_key") + """ + if not plugin_name: + frame = inspect.currentframe() + module = None + if frame: + frame = frame.f_back + module = inspect.getmodule(frame) + + if not module: + raise RuntimeError("无法获取调用者模块信息") + + metadata = star_map.get(module.__name__, None) + if not metadata: + raise RuntimeError(f"无法获取模块 {module.__name__} 的元数据信息") + + plugin_name = metadata.name + + full_key = f"{plugin_name}:{key}" + return cls._shared_data.get(full_key, default) + + @classmethod + def remove_shared_data(cls, key: str, plugin_name: Optional[str] = None) -> bool: + """ + 删除插件间共享数据 + + Args: + key (str): 数据键名 + plugin_name (Optional[str]): 插件名称,如果为None则自动检测 + + Returns: + bool: 是否成功删除(True表示数据存在并被删除,False表示数据不存在) + """ + if not plugin_name: + frame = inspect.currentframe() + module = None + if frame: + frame = frame.f_back + module = inspect.getmodule(frame) + + if not module: + raise RuntimeError("无法获取调用者模块信息") + + metadata = star_map.get(module.__name__, None) + if not metadata: + raise RuntimeError(f"无法获取模块 {module.__name__} 的元数据信息") + + plugin_name = metadata.name + + full_key = f"{plugin_name}:{key}" + if full_key in cls._shared_data: + del cls._shared_data[full_key] + return True + return False + + @classmethod + def list_shared_data(cls, plugin_name: Optional[str] = None) -> Dict[str, Any]: + """ + 列出指定插件的所有共享数据 + + Args: + plugin_name (Optional[str]): 插件名称,如果为None则返回所有数据 + + Returns: + Dict[str, Any]: 数据字典,键为原始键名(不包含插件前缀) + + Example: + # 获取当前插件的所有数据 + my_data = StarTools.list_shared_data() + + # 获取所有插件的数据 + all_data = StarTools.list_shared_data("") + """ + if plugin_name is None: + frame = inspect.currentframe() + module = None + if frame: + frame = frame.f_back + module = inspect.getmodule(frame) + + if not module: + raise RuntimeError("无法获取调用者模块信息") + + metadata = star_map.get(module.__name__, None) + if not metadata: + raise RuntimeError(f"无法获取模块 {module.__name__} 的元数据信息") + + plugin_name = metadata.name + + if plugin_name == "": + return dict(cls._shared_data) + + prefix = f"{plugin_name}:" + result = {} + for full_key, value in cls._shared_data.items(): + if full_key.startswith(prefix): + original_key = full_key[len(prefix):] + result[original_key] = value + return result + + @classmethod + def add_data_listener( + cls, + key: str, + callback: Callable[[str, Any], Awaitable[None]], + plugin_name: Optional[str] = None + ) -> None: + """ + 添加数据变化监听器 + + Args: + key (str): 要监听的数据键名 + callback (Callable): 回调函数,接受参数(key, new_value) + plugin_name (Optional[str]): 插件名称,如果为None则自动检测 + + Example: + async def on_worker_status_change(key: str, value: Any): + if value: + logger.INFO("哈哈我的工作完成啦!") + + StarTools.add_data_listener("worker_status", on_worker_status_change, "other_plugin") + """ + if not plugin_name: + frame = inspect.currentframe() + module = None + if frame: + frame = frame.f_back + module = inspect.getmodule(frame) + + if not module: + raise RuntimeError("无法获取调用者模块信息") + + metadata = star_map.get(module.__name__, None) + if not metadata: + raise RuntimeError(f"无法获取模块 {module.__name__} 的元数据信息") + + plugin_name = metadata.name + + full_key = f"{plugin_name}:{key}" + if full_key not in cls._data_listeners: + cls._data_listeners[full_key] = [] + cls._data_listeners[full_key].append(callback) + + @classmethod + def remove_data_listener( + cls, + key: str, + callback: Callable[[str, Any], Awaitable[None]], + plugin_name: Optional[str] = None + ) -> bool: + """ + 移除数据变化监听器 + + Args: + key (str): 数据键名 + callback (Callable): 要移除的回调函数 + plugin_name (Optional[str]): 插件名称,如果为None则自动检测 + + Returns: + bool: 是否成功移除 + """ + if not plugin_name: + frame = inspect.currentframe() + module = None + if frame: + frame = frame.f_back + module = inspect.getmodule(frame) + + if not module: + raise RuntimeError("无法获取调用者模块信息") + + metadata = star_map.get(module.__name__, None) + if not metadata: + raise RuntimeError(f"无法获取模块 {module.__name__} 的元数据信息") + + plugin_name = metadata.name + + full_key = f"{plugin_name}:{key}" + if full_key in cls._data_listeners and callback in cls._data_listeners[full_key]: + cls._data_listeners[full_key].remove(callback) + if not cls._data_listeners[full_key]: + del cls._data_listeners[full_key] + return True + return False + + @classmethod + async def _notify_listeners(cls, full_key: str, value: Any) -> None: + """ + 通知所有监听指定数据的回调函数 + + Args: + full_key (str): 完整的数据键名(包含插件前缀) + value (Any): 新的数据值 + """ + if full_key in cls._data_listeners: + tasks = [] + for callback in cls._data_listeners[full_key]: + try: + task = callback(full_key, value) + if asyncio.iscoroutine(task): + tasks.append(task) + except Exception as e: + logger.Error(f"数据监听器错误:{full_key}: {e}") + + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + From 94515deca97561d405df955113ca81749fc6fd2e Mon Sep 17 00:00:00 2001 From: anka <1350989414@qq.com> Date: Thu, 11 Sep 2025 18:43:10 +0800 Subject: [PATCH 2/3] =?UTF-8?q?refactor:=201.=E9=87=8D=E6=9E=84=E6=8A=BD?= =?UTF-8?q?=E5=8F=96=E9=80=9A=E7=94=A8=E9=80=BB=E8=BE=91=202.=20=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E8=AF=BB=E5=86=99=E9=94=81=203.=20=E5=8F=8A=E6=97=B6?= =?UTF-8?q?=E6=B8=85=E9=99=A4frame?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/star/star_tools.py | 250 +++++++++++++------------------- 1 file changed, 101 insertions(+), 149 deletions(-) diff --git a/astrbot/core/star/star_tools.py b/astrbot/core/star/star_tools.py index 170280308..bd80efdff 100644 --- a/astrbot/core/star/star_tools.py +++ b/astrbot/core/star/star_tools.py @@ -44,6 +44,7 @@ class StarTools: _context: ClassVar[Optional[Context]] = None _shared_data: ClassVar[Dict[str, Any]] = {} _data_listeners: ClassVar[Dict[str, List[Callable[[str, Any], Awaitable[None]]]]] = {} + _data_lock: ClassVar[asyncio.Lock] = asyncio.Lock() @classmethod @@ -262,27 +263,12 @@ def get_data_dir(cls, plugin_name: Optional[str] = None) -> Path: - 无法获取模块的元数据信息 - 创建目录失败(权限不足或其他IO错误) """ - if not plugin_name: - frame = inspect.currentframe() - module = None - if frame: - frame = frame.f_back - module = inspect.getmodule(frame) - - if not module: - raise RuntimeError("无法获取调用者模块信息") - - metadata = star_map.get(module.__name__, None) - - if not metadata: - raise RuntimeError(f"无法获取模块 {module.__name__} 的元数据信息") - - plugin_name = metadata.name + resolved_plugin_name = cls._get_caller_plugin_name(plugin_name) - if not plugin_name: + if not resolved_plugin_name: raise ValueError("无法获取插件名称") - data_dir = Path(os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name)) + data_dir = Path(os.path.join(get_astrbot_data_path(), "plugin_data", resolved_plugin_name)) try: data_dir.mkdir(parents=True, exist_ok=True) @@ -293,8 +279,46 @@ def get_data_dir(cls, plugin_name: Optional[str] = None) -> Path: return data_dir.resolve() + return data_dir.resolve() + + @classmethod + def _get_caller_plugin_name(cls, plugin_name: Optional[str]) -> str: + """ + 通过调用栈获取插件名称 + + Returns: + str: 插件名称 + + Raises: + RuntimeError: 当无法获取调用者模块信息或元数据信息时抛出 + """ + if plugin_name is not None: + return plugin_name + + frame = inspect.currentframe() + try: + if frame: + frame = frame.f_back + if frame: + frame = frame.f_back + + if not frame: + raise RuntimeError("无法获取调用者帧信息") + + module = inspect.getmodule(frame) + if not module: + raise RuntimeError("无法获取调用者模块信息") + + metadata = star_map.get(module.__name__, None) + if not metadata: + raise RuntimeError(f"无法获取模块 {module.__name__} 的元数据信息") + + return metadata.name + finally: + del frame + @classmethod - def set_shared_data(cls, key: str, value: Any, plugin_name: Optional[str] = None) -> None: + async def set_shared_data(cls, key: str, value: Any, plugin_name: Optional[str] = None) -> None: """ 设置插件间共享数据 @@ -314,28 +338,16 @@ def set_shared_data(cls, key: str, value: Any, plugin_name: Optional[str] = None "status": "processing" }) """ - if not plugin_name: - frame = inspect.currentframe() - module = None - if frame: - frame = frame.f_back - module = inspect.getmodule(frame) + resolved_plugin_name = cls._get_caller_plugin_name(plugin_name) + full_key = f"{resolved_plugin_name}:{key}" - if not module: - raise RuntimeError("无法获取调用者模块信息") - - metadata = star_map.get(module.__name__, None) - if not metadata: - raise RuntimeError(f"无法获取模块 {module.__name__} 的元数据信息") + async with cls._data_lock: + cls._shared_data[full_key] = value - plugin_name = metadata.name - - full_key = f"{plugin_name}:{key}" - cls._shared_data[full_key] = value - asyncio.create_task(cls._notify_listeners(full_key, value)) + await cls._notify_listeners(full_key, value) @classmethod - def get_shared_data(cls, key: str, plugin_name: Optional[str] = None, default: Any = None) -> Any: + async def get_shared_data(cls, key: str, plugin_name: Optional[str] = None, default: Any = None) -> Any: """ 获取插件间共享数据 @@ -354,27 +366,14 @@ def get_shared_data(cls, key: str, plugin_name: Optional[str] = None, default: A # 获取当前插件的数据 my_data = StarTools.get_shared_data("my_key") """ - if not plugin_name: - frame = inspect.currentframe() - module = None - if frame: - frame = frame.f_back - module = inspect.getmodule(frame) - - if not module: - raise RuntimeError("无法获取调用者模块信息") + resolved_plugin_name = cls._get_caller_plugin_name(plugin_name) + full_key = f"{resolved_plugin_name}:{key}" - metadata = star_map.get(module.__name__, None) - if not metadata: - raise RuntimeError(f"无法获取模块 {module.__name__} 的元数据信息") - - plugin_name = metadata.name - - full_key = f"{plugin_name}:{key}" - return cls._shared_data.get(full_key, default) + async with cls._data_lock: + return cls._shared_data.get(full_key, default) @classmethod - def remove_shared_data(cls, key: str, plugin_name: Optional[str] = None) -> bool: + async def remove_shared_data(cls, key: str, plugin_name: Optional[str] = None) -> bool: """ 删除插件间共享数据 @@ -385,35 +384,22 @@ def remove_shared_data(cls, key: str, plugin_name: Optional[str] = None) -> bool Returns: bool: 是否成功删除(True表示数据存在并被删除,False表示数据不存在) """ - if not plugin_name: - frame = inspect.currentframe() - module = None - if frame: - frame = frame.f_back - module = inspect.getmodule(frame) - - if not module: - raise RuntimeError("无法获取调用者模块信息") - - metadata = star_map.get(module.__name__, None) - if not metadata: - raise RuntimeError(f"无法获取模块 {module.__name__} 的元数据信息") - - plugin_name = metadata.name + resolved_plugin_name = cls._get_caller_plugin_name(plugin_name) + full_key = f"{resolved_plugin_name}:{key}" - full_key = f"{plugin_name}:{key}" - if full_key in cls._shared_data: - del cls._shared_data[full_key] - return True - return False + async with cls._data_lock: + if full_key in cls._shared_data: + del cls._shared_data[full_key] + return True + return False @classmethod - def list_shared_data(cls, plugin_name: Optional[str] = None) -> Dict[str, Any]: + async def list_shared_data(cls, plugin_name: Optional[str] = None) -> Dict[str, Any]: """ 列出指定插件的所有共享数据 Args: - plugin_name (Optional[str]): 插件名称,如果为None则返回所有数据 + plugin_name (Optional[str]): 插件名称,如果为None则返回当前插件数据,如果为空字符串则返回所有数据 Returns: Dict[str, Any]: 数据字典,键为原始键名(不包含插件前缀) @@ -425,35 +411,21 @@ def list_shared_data(cls, plugin_name: Optional[str] = None) -> Dict[str, Any]: # 获取所有插件的数据 all_data = StarTools.list_shared_data("") """ - if plugin_name is None: - frame = inspect.currentframe() - module = None - if frame: - frame = frame.f_back - module = inspect.getmodule(frame) - - if not module: - raise RuntimeError("无法获取调用者模块信息") - - metadata = star_map.get(module.__name__, None) - if not metadata: - raise RuntimeError(f"无法获取模块 {module.__name__} 的元数据信息") + async with cls._data_lock: + if plugin_name == "": + return dict(cls._shared_data) - plugin_name = metadata.name - - if plugin_name == "": - return dict(cls._shared_data) - - prefix = f"{plugin_name}:" - result = {} - for full_key, value in cls._shared_data.items(): - if full_key.startswith(prefix): - original_key = full_key[len(prefix):] - result[original_key] = value - return result + resolved_plugin_name = cls._get_caller_plugin_name(plugin_name) + prefix = f"{resolved_plugin_name}:" + result = {} + for full_key, value in cls._shared_data.items(): + if full_key.startswith(prefix): + original_key = full_key[len(prefix):] + result[original_key] = value + return result @classmethod - def add_data_listener( + async def add_data_listener( cls, key: str, callback: Callable[[str, Any], Awaitable[None]], @@ -470,33 +442,20 @@ def add_data_listener( Example: async def on_worker_status_change(key: str, value: Any): if value: - logger.INFO("哈哈我的工作完成啦!") + logger.info("哈哈我的工作完成啦!") StarTools.add_data_listener("worker_status", on_worker_status_change, "other_plugin") """ - if not plugin_name: - frame = inspect.currentframe() - module = None - if frame: - frame = frame.f_back - module = inspect.getmodule(frame) - - if not module: - raise RuntimeError("无法获取调用者模块信息") - - metadata = star_map.get(module.__name__, None) - if not metadata: - raise RuntimeError(f"无法获取模块 {module.__name__} 的元数据信息") - - plugin_name = metadata.name + resolved_plugin_name = cls._get_caller_plugin_name(plugin_name) + full_key = f"{resolved_plugin_name}:{key}" - full_key = f"{plugin_name}:{key}" - if full_key not in cls._data_listeners: - cls._data_listeners[full_key] = [] - cls._data_listeners[full_key].append(callback) + async with cls._data_lock: + if full_key not in cls._data_listeners: + cls._data_listeners[full_key] = [] + cls._data_listeners[full_key].append(callback) @classmethod - def remove_data_listener( + async def remove_data_listener( cls, key: str, callback: Callable[[str, Any], Awaitable[None]], @@ -513,29 +472,16 @@ def remove_data_listener( Returns: bool: 是否成功移除 """ - if not plugin_name: - frame = inspect.currentframe() - module = None - if frame: - frame = frame.f_back - module = inspect.getmodule(frame) - - if not module: - raise RuntimeError("无法获取调用者模块信息") - - metadata = star_map.get(module.__name__, None) - if not metadata: - raise RuntimeError(f"无法获取模块 {module.__name__} 的元数据信息") + resolved_plugin_name = cls._get_caller_plugin_name(plugin_name) + full_key = f"{resolved_plugin_name}:{key}" - plugin_name = metadata.name - - full_key = f"{plugin_name}:{key}" - if full_key in cls._data_listeners and callback in cls._data_listeners[full_key]: - cls._data_listeners[full_key].remove(callback) - if not cls._data_listeners[full_key]: - del cls._data_listeners[full_key] - return True - return False + async with cls._data_lock: + if full_key in cls._data_listeners and callback in cls._data_listeners[full_key]: + cls._data_listeners[full_key].remove(callback) + if not cls._data_listeners[full_key]: + del cls._data_listeners[full_key] + return True + return False @classmethod async def _notify_listeners(cls, full_key: str, value: Any) -> None: @@ -546,16 +492,22 @@ async def _notify_listeners(cls, full_key: str, value: Any) -> None: full_key (str): 完整的数据键名(包含插件前缀) value (Any): 新的数据值 """ - if full_key in cls._data_listeners: + listeners = [] + async with cls._data_lock: + if full_key in cls._data_listeners: + listeners = cls._data_listeners[full_key].copy() + + if listeners: tasks = [] - for callback in cls._data_listeners[full_key]: + for callback in listeners: try: task = callback(full_key, value) if asyncio.iscoroutine(task): tasks.append(task) except Exception as e: - logger.Error(f"数据监听器错误:{full_key}: {e}") + logger.error(f"数据监听器错误:{full_key}: {e}") if tasks: await asyncio.gather(*tasks, return_exceptions=True) + From 13faab2a917051e13f5631cceca94a0ff27499f2 Mon Sep 17 00:00:00 2001 From: anka <1350989414@qq.com> Date: Thu, 11 Sep 2025 20:22:49 +0800 Subject: [PATCH 3/3] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E4=B8=80=E5=A4=84?= =?UTF-8?q?=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/star/star_tools.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/astrbot/core/star/star_tools.py b/astrbot/core/star/star_tools.py index bd80efdff..877241606 100644 --- a/astrbot/core/star/star_tools.py +++ b/astrbot/core/star/star_tools.py @@ -279,8 +279,6 @@ def get_data_dir(cls, plugin_name: Optional[str] = None) -> Path: return data_dir.resolve() - return data_dir.resolve() - @classmethod def _get_caller_plugin_name(cls, plugin_name: Optional[str]) -> str: """