Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions astrbot/core/platform/platform_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@ class PlatformMetadata:
"""平台的默认配置模板"""
adapter_display_name: str = None
"""显示在 WebUI 配置页中的平台名称,如空则是 name"""
logo_path: str = None
"""平台适配器的 logo 文件路径(相对于插件目录)"""
3 changes: 3 additions & 0 deletions astrbot/core/platform/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ def register_platform_adapter(
desc: str,
default_config_tmpl: dict = None,
adapter_display_name: str = None,
logo_path: str = None,
):
"""用于注册平台适配器的带参装饰器。

default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。
logo_path 指定了平台适配器的 logo 文件路径,是相对于插件目录的路径。
"""

def decorator(cls):
Expand All @@ -39,6 +41,7 @@ def decorator(cls):
description=desc,
default_config_tmpl=default_config_tmpl,
adapter_display_name=adapter_display_name,
logo_path=logo_path,
)
platform_registry.append(pm)
platform_cls_map[adapter_name] = cls
Expand Down
90 changes: 88 additions & 2 deletions astrbot/dashboard/routes/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import typing
import traceback
import os
import inspect
from .route import Route, Response, RouteContext
from astrbot.core.provider.entities import ProviderType
from quart import request
Expand All @@ -13,10 +14,10 @@
from astrbot.core.utils.astrbot_path import get_astrbot_path
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.register import platform_registry
from astrbot.core.platform.register import platform_registry, platform_cls_map
from astrbot.core.provider.register import provider_registry
from astrbot.core.star.star import star_registry
from astrbot.core import logger
from astrbot.core import logger, file_token_service
from astrbot.core.provider import Provider
from astrbot.core.provider.provider import RerankProvider
import asyncio
Expand Down Expand Up @@ -149,6 +150,7 @@ def __init__(
super().__init__(context)
self.core_lifecycle = core_lifecycle
self.config: AstrBotConfig = core_lifecycle.astrbot_config
self._logo_token_cache = {} # 缓存logo token,避免重复注册
self.acm = core_lifecycle.astrbot_config_mgr
self.routes = {
"/config/abconf/new": ("POST", self.create_abconf),
Expand Down Expand Up @@ -655,16 +657,100 @@ async def post_delete_provider(self):
return Response().error(str(e)).__dict__
return Response().ok(None, "删除成功,已经实时生效~").__dict__

async def get_llm_tools(self):
"""获取函数调用工具。包含了本地加载的以及 MCP 服务的工具"""
tool_mgr = self.core_lifecycle.provider_manager.llm_tools
tools = tool_mgr.get_func_desc_openai_style()
return Response().ok(tools).__dict__

async def _register_platform_logo(self, platform, platform_default_tmpl):
"""注册平台logo文件并生成访问令牌"""
if not platform.logo_path:
return

try:
# 检查缓存
cache_key = f"{platform.name}:{platform.logo_path}"
if cache_key in self._logo_token_cache:
cached_token = self._logo_token_cache[cache_key]
# 确保platform_default_tmpl[platform.name]存在且为字典
if platform.name not in platform_default_tmpl:
platform_default_tmpl[platform.name] = {}
elif not isinstance(platform_default_tmpl[platform.name], dict):
platform_default_tmpl[platform.name] = {}
platform_default_tmpl[platform.name]["logo_token"] = cached_token
logger.debug(f"Using cached logo token for platform {platform.name}")
return

# 获取平台适配器类
platform_cls = platform_cls_map.get(platform.name)
if not platform_cls:
logger.warning(f"Platform class not found for {platform.name}")
return

# 获取插件目录路径
module_file = inspect.getfile(platform_cls)
plugin_dir = os.path.dirname(module_file)

# 解析logo文件路径
logo_file_path = os.path.join(plugin_dir, platform.logo_path)

# 检查文件是否存在并注册令牌
if os.path.exists(logo_file_path):
logo_token = await file_token_service.register_file(
logo_file_path, timeout=3600
)

# 确保platform_default_tmpl[platform.name]存在且为字典
if platform.name not in platform_default_tmpl:
platform_default_tmpl[platform.name] = {}
elif not isinstance(platform_default_tmpl[platform.name], dict):
platform_default_tmpl[platform.name] = {}

platform_default_tmpl[platform.name]["logo_token"] = logo_token

# 缓存token
self._logo_token_cache[cache_key] = logo_token

logger.debug(f"Logo token registered for platform {platform.name}")
else:
logger.warning(
f"Platform {platform.name} logo file not found: {logo_file_path}"
)

except (ImportError, AttributeError) as e:
logger.warning(
f"Failed to import required modules for platform {platform.name}: {e}"
)
except (OSError, IOError) as e:
logger.warning(f"File system error for platform {platform.name} logo: {e}")
except Exception as e:
logger.warning(
f"Unexpected error registering logo for platform {platform.name}: {e}"
)

async def _get_astrbot_config(self):
config = self.config

# 平台适配器的默认配置模板注入
platform_default_tmpl = CONFIG_METADATA_2["platform_group"]["metadata"][
"platform"
]["config_template"]

# 收集需要注册logo的平台
logo_registration_tasks = []
for platform in platform_registry:
if platform.default_config_tmpl:
platform_default_tmpl[platform.name] = platform.default_config_tmpl
# 收集logo注册任务
if platform.logo_path:
logo_registration_tasks.append(
self._register_platform_logo(platform, platform_default_tmpl)
)

# 并行执行logo注册
if logo_registration_tasks:
await asyncio.gather(*logo_registration_tasks, return_exceptions=True)

# 服务提供商的默认配置模板注入
provider_default_tmpl = CONFIG_METADATA_2["provider_group"]["metadata"][
Expand Down
12 changes: 10 additions & 2 deletions dashboard/src/views/PlatformPage.vue
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@
<v-card-actions>
<v-spacer></v-spacer>
<v-btn color="grey" variant="text" @click="handleIdConflictConfirm(false)">{{ tm('dialog.idConflict.confirm')
}}</v-btn>
}}</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
Expand Down Expand Up @@ -241,7 +241,15 @@ export default {

methods: {
// 从工具函数导入
getPlatformIcon,
getPlatformIcon(platform_id) {
// 首先检查是否有来自插件的 logo_token
const template = this.metadata['platform_group']?.metadata?.platform?.config_template?.[platform_id];
if (template && template.logo_token) {
// 通过文件服务访问插件提供的 logo
return `/api/file/${template.logo_token}`;
}
return getPlatformIcon(platform_id);
},

openTutorial() {
const tutorialUrl = getTutorialLink(this.newSelectedPlatformConfig.type);
Expand Down