diff --git a/astrbot/core/platform/platform_metadata.py b/astrbot/core/platform/platform_metadata.py index 7fb7f9d3e..37f8527a1 100644 --- a/astrbot/core/platform/platform_metadata.py +++ b/astrbot/core/platform/platform_metadata.py @@ -14,3 +14,5 @@ class PlatformMetadata: """平台的默认配置模板""" adapter_display_name: str = None """显示在 WebUI 配置页中的平台名称,如空则是 name""" + logo_path: str = None + """平台适配器的 logo 文件路径(相对于插件目录)""" diff --git a/astrbot/core/platform/register.py b/astrbot/core/platform/register.py index fa65392a8..97c33a43e 100644 --- a/astrbot/core/platform/register.py +++ b/astrbot/core/platform/register.py @@ -13,10 +13,12 @@ def register_platform_adapter( desc: str, default_config_tmpl: dict = None, adapter_display_name: str = None, + logo_path: str = None, ): """用于注册平台适配器的带参装饰器。 default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。 + logo_path 指定了平台适配器的 logo 文件路径,是相对于插件目录的路径。 """ def decorator(cls): @@ -39,6 +41,7 @@ def decorator(cls): description=desc, default_config_tmpl=default_config_tmpl, adapter_display_name=adapter_display_name, + logo_path=logo_path, ) platform_registry.append(pm) platform_cls_map[adapter_name] = cls diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 1cc2319a0..bb0b723bf 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -1,6 +1,7 @@ import typing import traceback import os +import inspect from .route import Route, Response, RouteContext from astrbot.core.provider.entities import ProviderType from quart import request @@ -13,10 +14,10 @@ from astrbot.core.utils.astrbot_path import get_astrbot_path from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.core_lifecycle import AstrBotCoreLifecycle -from astrbot.core.platform.register import platform_registry +from astrbot.core.platform.register import platform_registry, platform_cls_map from astrbot.core.provider.register import provider_registry from astrbot.core.star.star import star_registry -from astrbot.core import logger +from astrbot.core import logger, file_token_service from astrbot.core.provider import Provider from astrbot.core.provider.provider import RerankProvider import asyncio @@ -149,6 +150,7 @@ def __init__( super().__init__(context) self.core_lifecycle = core_lifecycle self.config: AstrBotConfig = core_lifecycle.astrbot_config + self._logo_token_cache = {} # 缓存logo token,避免重复注册 self.acm = core_lifecycle.astrbot_config_mgr self.routes = { "/config/abconf/new": ("POST", self.create_abconf), @@ -655,6 +657,78 @@ async def post_delete_provider(self): return Response().error(str(e)).__dict__ return Response().ok(None, "删除成功,已经实时生效~").__dict__ + async def get_llm_tools(self): + """获取函数调用工具。包含了本地加载的以及 MCP 服务的工具""" + tool_mgr = self.core_lifecycle.provider_manager.llm_tools + tools = tool_mgr.get_func_desc_openai_style() + return Response().ok(tools).__dict__ + + async def _register_platform_logo(self, platform, platform_default_tmpl): + """注册平台logo文件并生成访问令牌""" + if not platform.logo_path: + return + + try: + # 检查缓存 + cache_key = f"{platform.name}:{platform.logo_path}" + if cache_key in self._logo_token_cache: + cached_token = self._logo_token_cache[cache_key] + # 确保platform_default_tmpl[platform.name]存在且为字典 + if platform.name not in platform_default_tmpl: + platform_default_tmpl[platform.name] = {} + elif not isinstance(platform_default_tmpl[platform.name], dict): + platform_default_tmpl[platform.name] = {} + platform_default_tmpl[platform.name]["logo_token"] = cached_token + logger.debug(f"Using cached logo token for platform {platform.name}") + return + + # 获取平台适配器类 + platform_cls = platform_cls_map.get(platform.name) + if not platform_cls: + logger.warning(f"Platform class not found for {platform.name}") + return + + # 获取插件目录路径 + module_file = inspect.getfile(platform_cls) + plugin_dir = os.path.dirname(module_file) + + # 解析logo文件路径 + logo_file_path = os.path.join(plugin_dir, platform.logo_path) + + # 检查文件是否存在并注册令牌 + if os.path.exists(logo_file_path): + logo_token = await file_token_service.register_file( + logo_file_path, timeout=3600 + ) + + # 确保platform_default_tmpl[platform.name]存在且为字典 + if platform.name not in platform_default_tmpl: + platform_default_tmpl[platform.name] = {} + elif not isinstance(platform_default_tmpl[platform.name], dict): + platform_default_tmpl[platform.name] = {} + + platform_default_tmpl[platform.name]["logo_token"] = logo_token + + # 缓存token + self._logo_token_cache[cache_key] = logo_token + + logger.debug(f"Logo token registered for platform {platform.name}") + else: + logger.warning( + f"Platform {platform.name} logo file not found: {logo_file_path}" + ) + + except (ImportError, AttributeError) as e: + logger.warning( + f"Failed to import required modules for platform {platform.name}: {e}" + ) + except (OSError, IOError) as e: + logger.warning(f"File system error for platform {platform.name} logo: {e}") + except Exception as e: + logger.warning( + f"Unexpected error registering logo for platform {platform.name}: {e}" + ) + async def _get_astrbot_config(self): config = self.config @@ -662,9 +736,21 @@ async def _get_astrbot_config(self): platform_default_tmpl = CONFIG_METADATA_2["platform_group"]["metadata"][ "platform" ]["config_template"] + + # 收集需要注册logo的平台 + logo_registration_tasks = [] for platform in platform_registry: if platform.default_config_tmpl: platform_default_tmpl[platform.name] = platform.default_config_tmpl + # 收集logo注册任务 + if platform.logo_path: + logo_registration_tasks.append( + self._register_platform_logo(platform, platform_default_tmpl) + ) + + # 并行执行logo注册 + if logo_registration_tasks: + await asyncio.gather(*logo_registration_tasks, return_exceptions=True) # 服务提供商的默认配置模板注入 provider_default_tmpl = CONFIG_METADATA_2["provider_group"]["metadata"][ diff --git a/dashboard/src/views/PlatformPage.vue b/dashboard/src/views/PlatformPage.vue index e135908e3..1d14b3ffb 100644 --- a/dashboard/src/views/PlatformPage.vue +++ b/dashboard/src/views/PlatformPage.vue @@ -114,7 +114,7 @@ {{ tm('dialog.idConflict.confirm') - }} + }} @@ -241,7 +241,15 @@ export default { methods: { // 从工具函数导入 - getPlatformIcon, + getPlatformIcon(platform_id) { + // 首先检查是否有来自插件的 logo_token + const template = this.metadata['platform_group']?.metadata?.platform?.config_template?.[platform_id]; + if (template && template.logo_token) { + // 通过文件服务访问插件提供的 logo + return `/api/file/${template.logo_token}`; + } + return getPlatformIcon(platform_id); + }, openTutorial() { const tutorialUrl = getTutorialLink(this.newSelectedPlatformConfig.type);