From 561a56dcd9341280be9ad6c7226af50e74656110 Mon Sep 17 00:00:00 2001 From: anka <1350989414@qq.com> Date: Sun, 21 Sep 2025 18:23:17 +0800 Subject: [PATCH 1/3] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E7=A9=BAkey?= =?UTF-8?q?=E5=AF=BC=E8=87=B4=E7=9A=84=E6=97=A0=E6=B3=95=E5=88=9B=E5=BB=BA?= =?UTF-8?q?Provider=E5=AF=B9=E8=B1=A1=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/provider/provider.py | 3 +- .../core/provider/sources/anthropic_source.py | 26 ++++++++----- .../core/provider/sources/gemini_source.py | 38 ++++++++++--------- .../core/provider/sources/openai_source.py | 2 +- 4 files changed, 40 insertions(+), 29 deletions(-) diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 01618767c..f03948b33 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -68,7 +68,8 @@ def get_current_key(self) -> str: def get_keys(self) -> List[str]: """获得提供商 Key""" - return self.provider_config.get("key", []) + keys = self.provider_config.get("key", [""]) + return keys if keys else [""] @abc.abstractmethod def set_key(self, key: str): diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index aaff177e5..23f39374b 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -33,7 +33,7 @@ def __init__( ) self.chosen_api_key: str = "" - self.api_keys: List = provider_config.get("key", []) + self.api_keys: List = super().get_keys() self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else "" self.base_url = provider_config.get("api_base", "https://api.anthropic.com") self.timeout = provider_config.get("timeout", 120) @@ -70,9 +70,13 @@ def _prepare_payload(self, messages: list[dict]): { "type": "tool_use", "name": tool_call["function"]["name"], - "input": json.loads(tool_call["function"]["arguments"]) - if isinstance(tool_call["function"]["arguments"], str) - else tool_call["function"]["arguments"], + "input": ( + json.loads(tool_call["function"]["arguments"]) + if isinstance( + tool_call["function"]["arguments"], str + ) + else tool_call["function"]["arguments"] + ), "id": tool_call["id"], } ) @@ -175,9 +179,9 @@ async def _query_stream( # 累积 JSON 输入 if "input_json" not in tool_use_buffer[event.index]: tool_use_buffer[event.index]["input_json"] = "" - tool_use_buffer[event.index]["input_json"] += ( - event.delta.partial_json - ) + tool_use_buffer[event.index][ + "input_json" + ] += event.delta.partial_json elif event.type == "content_block_stop": # 内容块结束 @@ -355,9 +359,11 @@ async def assemble_context(self, text: str, image_urls: List[str] = None): "source": { "type": "base64", "media_type": mime_type, - "data": image_data.split("base64,")[1] - if "base64," in image_data - else image_data, + "data": ( + image_data.split("base64,")[1] + if "base64," in image_data + else image_data + ), }, } ) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index cc4475b6b..b14a9bdcb 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -3,7 +3,7 @@ import json import logging import random -from typing import Optional +from typing import Optional, List from collections.abc import AsyncGenerator from google import genai @@ -60,7 +60,7 @@ def __init__( provider_settings, default_persona, ) - self.api_keys: list = provider_config.get("key", []) + self.api_keys: List = super().get_keys() self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else "" self.timeout: int = int(provider_config.get("timeout", 180)) @@ -218,19 +218,21 @@ async def _prepare_query_config( response_modalities=modalities, tools=tool_list, safety_settings=self.safety_settings if self.safety_settings else None, - thinking_config=types.ThinkingConfig( - thinking_budget=min( - int( - self.provider_config.get("gm_thinking_config", {}).get( - "budget", 0 - ) + thinking_config=( + types.ThinkingConfig( + thinking_budget=min( + int( + self.provider_config.get("gm_thinking_config", {}).get( + "budget", 0 + ) + ), + 24576, ), - 24576, - ), - ) - if "gemini-2.5-flash" in self.get_model() - and hasattr(types.ThinkingConfig, "thinking_budget") - else None, + ) + if "gemini-2.5-flash" in self.get_model() + and hasattr(types.ThinkingConfig, "thinking_budget") + else None + ), automatic_function_calling=types.AutomaticFunctionCallingConfig( disable=True ), @@ -274,9 +276,11 @@ def append_or_extend( if role == "user": if isinstance(content, list): parts = [ - types.Part.from_text(text=item["text"] or " ") - if item["type"] == "text" - else process_image_url(item["image_url"]) + ( + types.Part.from_text(text=item["text"] or " ") + if item["type"] == "text" + else process_image_url(item["image_url"]) + ) for item in content ] else: diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 5b199ed96..81342ad53 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -38,7 +38,7 @@ def __init__( default_persona, ) self.chosen_api_key = None - self.api_keys: List = provider_config.get("key", []) + self.api_keys: List = super().get_keys() self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None self.timeout = provider_config.get("timeout", 120) if isinstance(self.timeout, str): From 868a534d02dfb2ecf9cc262f10f96bc31b239b18 Mon Sep 17 00:00:00 2001 From: anka <1350989414@qq.com> Date: Sun, 21 Sep 2025 18:27:30 +0800 Subject: [PATCH 2/3] style: format code --- astrbot/core/provider/sources/anthropic_source.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index 23f39374b..57bffdc81 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -179,9 +179,9 @@ async def _query_stream( # 累积 JSON 输入 if "input_json" not in tool_use_buffer[event.index]: tool_use_buffer[event.index]["input_json"] = "" - tool_use_buffer[event.index][ - "input_json" - ] += event.delta.partial_json + tool_use_buffer[event.index]["input_json"] += ( + event.delta.partial_json + ) elif event.type == "content_block_stop": # 内容块结束 From 3589a409ad244d396d980eee83632e403be93368 Mon Sep 17 00:00:00 2001 From: anka <1350989414@qq.com> Date: Sun, 21 Sep 2025 18:29:16 +0800 Subject: [PATCH 3/3] Update astrbot/core/provider/provider.py Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> --- astrbot/core/provider/provider.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index f03948b33..66e253715 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -69,7 +69,7 @@ def get_current_key(self) -> str: def get_keys(self) -> List[str]: """获得提供商 Key""" keys = self.provider_config.get("key", [""]) - return keys if keys else [""] + return keys or [""] @abc.abstractmethod def set_key(self, key: str):