From f0277420e475db64ff47ee49667fe068f2a216ab Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 28 Dec 2025 19:09:13 +0800 Subject: [PATCH 1/7] perf: support extended thinking for Anthropic, DeepSeek reasoning mode, and Gemini text part thought signatures to improve multi-turn reasoning performance. --- astrbot/builtin_stars/astrbot/main.py | 12 +- astrbot/core/agent/message.py | 24 +++- .../agent/runners/tool_loop_agent_runner.py | 31 ++++- astrbot/core/astr_agent_hooks.py | 6 + astrbot/core/config/default.py | 14 +- .../method/agent_sub_stages/internal.py | 45 ++++--- .../core/pipeline/result_decorate/stage.py | 122 ++++++++++-------- astrbot/core/provider/entities.py | 10 +- .../core/provider/sources/anthropic_source.py | 64 ++++++++- .../core/provider/sources/gemini_source.py | 43 +++++- .../core/provider/sources/openai_source.py | 42 +++--- astrbot/core/provider/sources/xai_source.py | 29 +++++ 12 files changed, 312 insertions(+), 130 deletions(-) create mode 100644 astrbot/core/provider/sources/xai_source.py diff --git a/astrbot/builtin_stars/astrbot/main.py b/astrbot/builtin_stars/astrbot/main.py index 09859ab95..b3ea355b1 100644 --- a/astrbot/builtin_stars/astrbot/main.py +++ b/astrbot/builtin_stars/astrbot/main.py @@ -100,16 +100,8 @@ async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest): logger.error(f"ltm: {e}") @filter.on_llm_response() - async def inject_reasoning(self, event: AstrMessageEvent, resp: LLMResponse): - """在 LLM 响应后基于配置注入思考过程文本 / 在 LLM 响应后记录对话""" - umo = event.unified_msg_origin - cfg = self.context.get_config(umo).get("provider_settings", {}) - show_reasoning = cfg.get("display_reasoning_text", False) - if show_reasoning and resp.reasoning_content: - resp.completion_text = ( - f"🤔 思考: {resp.reasoning_content}\n\n{resp.completion_text}" - ) - + async def record_llm_resp_to_ltm(self, event: AstrMessageEvent, resp: LLMResponse): + """在 LLM 响应后记录对话""" if self.ltm and self.ltm_enabled(event): try: await self.ltm.after_req_llm(event, resp) diff --git a/astrbot/core/agent/message.py b/astrbot/core/agent/message.py index c51ce5008..582b1eef2 100644 --- a/astrbot/core/agent/message.py +++ b/astrbot/core/agent/message.py @@ -12,7 +12,7 @@ class ContentPart(BaseModel): __content_part_registry: ClassVar[dict[str, type["ContentPart"]]] = {} - type: str + type: Literal["text", "think", "image_url", "audio_url"] def __init_subclass__(cls, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) @@ -63,6 +63,28 @@ class TextPart(ContentPart): text: str +class ThinkPart(ContentPart): + """ + >>> ThinkPart(think="I think I need to think about this.").model_dump() + {'type': 'think', 'think': 'I think I need to think about this.', 'encrypted': None} + """ + + type: str = "think" + think: str + encrypted: str | None = None + """Encrypted thinking content, or signature.""" + + def merge_in_place(self, other: Any) -> bool: + if not isinstance(other, ThinkPart): + return False + if self.encrypted: + return False + self.think += other.think + if other.encrypted: + self.encrypted = other.encrypted + return True + + class ImageURLPart(ContentPart): """ >>> ImageURLPart(image_url="http://example.com/image.jpg").model_dump() diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 88e302ad7..4b0c601b4 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -13,6 +13,7 @@ ) from astrbot import logger 
+from astrbot.core.agent.message import TextPart, ThinkPart from astrbot.core.message.components import Json from astrbot.core.message.message_event_result import ( MessageChain, @@ -169,13 +170,20 @@ async def step(self): self.final_llm_resp = llm_resp self._transition_state(AgentState.DONE) self.stats.end_time = time.time() + # record the final assistant message - self.run_context.messages.append( - Message( - role="assistant", - content=llm_resp.completion_text or "*No response*", - ), - ) + parts = [] + if llm_resp.reasoning_content or llm_resp.reasoning_signature: + parts.append( + ThinkPart( + think=llm_resp.reasoning_content, + encrypted=llm_resp.reasoning_signature, + ) + ) + parts.append(TextPart(text=llm_resp.completion_text or "*No response*")) + self.run_context.messages.append(Message(role="assistant", content=parts)) + + # call the on_agent_done hook try: await self.agent_hooks.on_agent_done(self.run_context, llm_resp) except Exception as e: @@ -214,10 +222,19 @@ async def step(self): data=AgentResponseData(chain=result), ) # 将结果添加到上下文中 + parts = [] + if llm_resp.reasoning_content or llm_resp.reasoning_signature: + parts.append( + ThinkPart( + think=llm_resp.reasoning_content, + encrypted=llm_resp.reasoning_signature, + ) + ) + parts.append(TextPart(text=llm_resp.completion_text or "*No response*")) tool_calls_result = ToolCallsResult( tool_calls_info=AssistantMessageSegment( tool_calls=llm_resp.to_openai_to_calls_model(), - content=llm_resp.completion_text, + content=parts, ), tool_calls_result=tool_call_result_blocks, ) diff --git a/astrbot/core/astr_agent_hooks.py b/astrbot/core/astr_agent_hooks.py index f394fc947..9d85de0cc 100644 --- a/astrbot/core/astr_agent_hooks.py +++ b/astrbot/core/astr_agent_hooks.py @@ -13,6 +13,12 @@ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]): async def on_agent_done(self, run_context, llm_response): # 执行事件钩子 + if llm_response and llm_response.reasoning_content: + # we will use this in result_decorate stage to inject reasoning content to chain + run_context.context.event.set_extra( + "_llm_reasoning_content", llm_response.reasoning_content + ) + await call_event_hook( run_context.context.event, EventType.OnLLMResponseEvent, diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index f10678bf7..a13e3f432 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -905,6 +905,7 @@ class ChatProviderTemplate(TypedDict): "key": [], "api_base": "https://api.anthropic.com/v1", "timeout": 120, + "anth_thinking_config": {"budget": 0}, }, "Moonshot": { "id": "moonshot", @@ -920,7 +921,7 @@ class ChatProviderTemplate(TypedDict): "xAI": { "id": "xai", "provider": "xai", - "type": "openai_chat_completion", + "type": "xai_chat_completion", "provider_type": "chat_completion", "enable": True, "key": [], @@ -1787,6 +1788,17 @@ class ChatProviderTemplate(TypedDict): }, }, }, + "anth_thinking_config": { + "description": "Thinking Config", + "type": "object", + "items": { + "budget": { + "description": "Thinking Budget", + "type": "int", + "hint": "Anthropic thinking.budget_tokens param. Must >= 1024. 
See: https://platform.claude.com/docs/en/build-with-claude/extended-thinking", + }, + }, + }, "minimax-group-id": { "type": "string", "description": "用户组", diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 6c7c857a7..d00b0abde 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -6,6 +6,7 @@ from collections.abc import AsyncGenerator from astrbot.core import logger +from astrbot.core.agent.message import Message from astrbot.core.agent.tool import ToolSet from astrbot.core.astr_agent_context import AstrAgentContext from astrbot.core.conversation_mgr import Conversation @@ -294,6 +295,7 @@ async def _save_to_history( event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse | None, + all_messages: list[Message], ): if ( not req @@ -307,31 +309,23 @@ async def _save_to_history( logger.debug("LLM 响应为空,不保存记录。") return - if req.contexts is None: - req.contexts = [] - - # 历史上下文 - messages = copy.deepcopy(req.contexts) - # 这一轮对话请求的用户输入 - messages.append(await req.assemble_context()) - # 这一轮对话的 LLM 响应 - if req.tool_calls_result: - if not isinstance(req.tool_calls_result, list): - messages.extend(req.tool_calls_result.to_openai_messages()) - elif isinstance(req.tool_calls_result, list): - for tcr in req.tool_calls_result: - messages.extend(tcr.to_openai_messages()) - messages.append( - { - "role": "assistant", - "content": llm_response.completion_text or "*No response*", - } - ) - messages = list(filter(lambda item: "_no_save" not in item, messages)) + # using agent context messages to save to history + message_to_save = [] + for message in all_messages: + if message.role == "system": + # we do not save system messages to history + continue + if message.role in ["assistant", "user"] and getattr( + message, "_no_save", None + ): + # we do not save user and assistant messages that are marked as _no_save + continue + message_to_save.append(message.model_dump()) + await self.conv_manager.update_conversation( event.unified_msg_origin, req.conversation.cid, - history=messages, + history=message_to_save, ) def _fix_messages(self, messages: list[dict]) -> list[dict]: @@ -513,7 +507,12 @@ async def process( # 恢复备份的 contexts req.contexts = backup_contexts - await self._save_to_history(event, req, agent_runner.get_final_llm_resp()) + await self._save_to_history( + event, + req, + agent_runner.get_final_llm_resp(), + agent_runner.run_context.messages, + ) # 异步处理 WebChat 特殊情况 if event.get_platform_name() == "webchat": diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index 7647ef022..38aa8bdd4 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -98,6 +98,9 @@ async def initialize(self, ctx: PipelineContext): self.content_safe_check_stage = stage_cls() await self.content_safe_check_stage.initialize(ctx) + provider_cfg = ctx.astrbot_config.get("provider_settings", {}) + self.show_reasoning = provider_cfg.get("display_reasoning_text", False) + def _split_text_by_words(self, text: str) -> list[str]: """使用分段词列表分段文本""" if not self.split_words_pattern: @@ -254,70 +257,75 @@ async def process( event.unified_msg_origin, ) - if ( - self.ctx.astrbot_config["provider_tts_settings"]["enable"] + should_tts = ( + 
bool(self.ctx.astrbot_config["provider_tts_settings"]["enable"]) and result.is_llm_result() and SessionServiceManager.should_process_tts_request(event) - ): - should_tts = self.tts_trigger_probability >= 1.0 or ( - self.tts_trigger_probability > 0.0 - and random.random() <= self.tts_trigger_probability + and random.random() <= self.tts_trigger_probability + and tts_provider + ) + if not tts_provider: + logger.warning( + f"会话 {event.unified_msg_origin} 未配置文本转语音模型。", ) - if not should_tts: - logger.debug("跳过 TTS:触发概率未命中。") - elif not tts_provider: - logger.warning( - f"会话 {event.unified_msg_origin} 未配置文本转语音模型。", - ) - else: - new_chain = [] - for comp in result.chain: - if isinstance(comp, Plain) and len(comp.text) > 1: - try: - logger.info(f"TTS 请求: {comp.text}") - audio_path = await tts_provider.get_audio(comp.text) - logger.info(f"TTS 结果: {audio_path}") - if not audio_path: - logger.error( - f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}", - ) - new_chain.append(comp) - continue - - use_file_service = self.ctx.astrbot_config[ - "provider_tts_settings" - ]["use_file_service"] - callback_api_base = self.ctx.astrbot_config[ - "callback_api_base" - ] - dual_output = self.ctx.astrbot_config[ - "provider_tts_settings" - ]["dual_output"] - - url = None - if use_file_service and callback_api_base: - token = await file_token_service.register_file( - audio_path, - ) - url = f"{callback_api_base}/api/file/{token}" - logger.debug(f"已注册:{url}") - - new_chain.append( - Record( - file=url or audio_path, - url=url or audio_path, - ), + if ( + not should_tts + and self.show_reasoning + and event.get_extra("_llm_reasoning_content") + ): + # inject reasoning content to chain + reasoning_content = event.get_extra("_llm_reasoning_content") + result.chain.insert(0, Plain(f"🤔 思考: {reasoning_content}\n")) + + if should_tts and tts_provider: + new_chain = [] + for comp in result.chain: + if isinstance(comp, Plain) and len(comp.text) > 1: + try: + logger.info(f"TTS 请求: {comp.text}") + audio_path = await tts_provider.get_audio(comp.text) + logger.info(f"TTS 结果: {audio_path}") + if not audio_path: + logger.error( + f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}", ) - if dual_output: - new_chain.append(comp) - except Exception: - logger.error(traceback.format_exc()) - logger.error("TTS 失败,使用文本发送。") new_chain.append(comp) - else: + continue + + use_file_service = self.ctx.astrbot_config[ + "provider_tts_settings" + ]["use_file_service"] + callback_api_base = self.ctx.astrbot_config[ + "callback_api_base" + ] + dual_output = self.ctx.astrbot_config[ + "provider_tts_settings" + ]["dual_output"] + + url = None + if use_file_service and callback_api_base: + token = await file_token_service.register_file( + audio_path, + ) + url = f"{callback_api_base}/api/file/{token}" + logger.debug(f"已注册:{url}") + + new_chain.append( + Record( + file=url or audio_path, + url=url or audio_path, + ), + ) + if dual_output: + new_chain.append(comp) + except Exception: + logger.error(traceback.format_exc()) + logger.error("TTS 失败,使用文本发送。") new_chain.append(comp) - result.chain = new_chain + else: + new_chain.append(comp) + result.chain = new_chain # 文本转图片 elif ( diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index 8f1bc442e..a02874c04 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -272,6 +272,8 @@ class LLMResponse: """Tool call extra content. 
tool_call_id -> extra_content dict""" reasoning_content: str = "" """The reasoning content extracted from the LLM, if any.""" + reasoning_signature: str | None = None + """The signature of the reasoning content, if any.""" raw_completion: ( ChatCompletion | GenerateContentResponse | AnthropicMessage | None @@ -292,12 +294,14 @@ class LLMResponse: def __init__( self, role: str, - completion_text: str = "", + completion_text: str | None = None, result_chain: MessageChain | None = None, tools_call_args: list[dict[str, Any]] | None = None, tools_call_name: list[str] | None = None, tools_call_ids: list[str] | None = None, tools_call_extra_content: dict[str, dict[str, Any]] | None = None, + reasoning_content: str | None = None, + reasoning_signature: str | None = None, raw_completion: ChatCompletion | GenerateContentResponse | AnthropicMessage @@ -317,6 +321,8 @@ def __init__( raw_completion (ChatCompletion, optional): 原始响应, OpenAI 格式. Defaults to None. """ + if reasoning_content is None: + reasoning_content = "" if tools_call_args is None: tools_call_args = [] if tools_call_name is None: @@ -333,6 +339,8 @@ def __init__( self.tools_call_name = tools_call_name self.tools_call_ids = tools_call_ids self.tools_call_extra_content = tools_call_extra_content + self.reasoning_content = reasoning_content + self.reasoning_signature = reasoning_signature self.raw_completion = raw_completion self.is_chunk = is_chunk diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index cdb8cd430..57cf09db9 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -48,6 +48,8 @@ def __init__( base_url=self.base_url, ) + self.thinking_config = provider_config.get("anth_thinking_config", {}) + self.set_model(provider_config.get("model", "unknown")) def _prepare_payload(self, messages: list[dict]): @@ -64,11 +66,32 @@ def _prepare_payload(self, messages: list[dict]): new_messages = [] for message in messages: if message["role"] == "system": - system_prompt = message["content"] + system_prompt = message["content"] or "" elif message["role"] == "assistant": blocks = [] - if isinstance(message["content"], str): + reasoning_content = "" + thinking_signature = "" + if isinstance(message["content"], str) and message["content"].strip(): blocks.append({"type": "text", "text": message["content"]}) + elif isinstance(message["content"], list): + for part in message["content"]: + if part.get("type") == "think": + # only pick the last think part for now + reasoning_content = part.get("think") + thinking_signature = part.get("encrypted") + else: + blocks.append(part) + + if reasoning_content and thinking_signature: + blocks.insert( + 0, + { + "type": "thinking", + "thinking": reasoning_content, + "signature": thinking_signature, + }, + ) + if "tool_calls" in message and isinstance(message["tool_calls"], list): for tool_call in message["tool_calls"]: blocks.append( # noqa: PERF401 @@ -100,7 +123,7 @@ def _prepare_payload(self, messages: list[dict]): { "type": "tool_result", "tool_use_id": message["tool_call_id"], - "content": message["content"], + "content": message["content"] or "", }, ], }, @@ -135,6 +158,11 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: if "max_tokens" not in payloads: payloads["max_tokens"] = 1024 + if self.thinking_config.get("budget"): + payloads["thinking"] = { + "budget_tokens": self.thinking_config.get("budget"), + "type": "enabled", + } completion = 
await self.client.messages.create( **payloads, stream=False, extra_body=extra_body @@ -153,6 +181,11 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: completion_text = str(content_block.text).strip() llm_response.completion_text = completion_text + if content_block.type == "thinking": + reasoning_content = str(content_block.thinking).strip() + llm_response.reasoning_content = reasoning_content + llm_response.reasoning_signature = content_block.signature + if content_block.type == "tool_use": llm_response.tools_call_args.append(content_block.input) llm_response.tools_call_name.append(content_block.name) @@ -184,15 +217,23 @@ async def _query_stream( id = None usage = TokenUsage() extra_body = self.provider_config.get("custom_extra_body", {}) + reasoning_content = "" + reasoning_signature = "" if "max_tokens" not in payloads: payloads["max_tokens"] = 1024 + if self.thinking_config.get("budget"): + payloads["thinking"] = { + "budget_tokens": self.thinking_config.get("budget"), + "type": "enabled", + } async with self.client.messages.stream( **payloads, extra_body=extra_body ) as stream: assert isinstance(stream, anthropic.AsyncMessageStream) async for event in stream: + print(f"event: {event}") if event.type == "message_start": # the usage contains input token usage id = event.message.id @@ -226,6 +267,21 @@ async def _query_stream( usage=usage, id=id, ) + elif event.delta.type == "thinking_delta": + # 思考增量 + reasoning = event.delta.thinking + if reasoning: + yield LLMResponse( + role="assistant", + reasoning_content=reasoning, + is_chunk=True, + usage=usage, + id=id, + reasoning_signature=reasoning_signature or None, + ) + reasoning_content += reasoning + elif event.delta.type == "signature_delta": + reasoning_signature = event.delta.signature elif event.delta.type == "input_json_delta": # 工具调用参数增量 if event.index in tool_use_buffer: @@ -282,6 +338,8 @@ async def _query_stream( is_chunk=False, usage=usage, id=id, + reasoning_content=reasoning_content, + reasoning_signature=reasoning_signature or None, ) if final_tool_calls: diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 46358ac26..e9b360212 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -321,9 +321,40 @@ def append_or_extend( append_or_extend(gemini_contents, parts, types.UserContent) elif role == "assistant": - if content: + if isinstance(content, str): parts = [types.Part.from_text(text=content)] append_or_extend(gemini_contents, parts, types.ModelContent) + elif isinstance(content, list): + parts = [] + reasoning_content = None + thinking_signature = None + text = "" + for part in content: + # for most cases, assistant content only contains two parts: think and text + if part.get("type") == "think": + reasoning_content = part.get("think") or None + thinking_signature = part.get("encrypted") or None + else: + text += part.get("text") + + if thinking_signature and isinstance(thinking_signature, str): + try: + thinking_signature = base64.b64decode(thinking_signature) + except Exception as e: + logger.warning( + f"Failed to decode google gemini thinking signature: {e}", + exc_info=True, + ) + thinking_signature = None + parts.append( + types.Part( + text=text, + thought_signature=thinking_signature, + thought=reasoning_content, + ) + ) + append_or_extend(gemini_contents, parts, types.ModelContent) + elif not native_tool_enabled and "tool_calls" in message: parts = [] for tool 
in message["tool_calls"]: @@ -441,7 +472,8 @@ def _process_content_parts( for part in result_parts: if part.text: chain.append(Comp.Plain(part.text)) - elif ( + + if ( part.function_call and part.function_call.name is not None and part.function_call.args is not None @@ -458,13 +490,18 @@ def _process_content_parts( llm_response.tools_call_extra_content[tool_call_id] = { "google": {"thought_signature": ts_bs64} } - elif ( + + if ( part.inline_data and part.inline_data.mime_type and part.inline_data.mime_type.startswith("image/") and part.inline_data.data ): chain.append(Comp.Image.fromBytes(part.inline_data.data)) + + if ts := part.thought_signature: + # only keep the last thinking signature + llm_response.reasoning_signature = base64.b64encode(ts).decode("utf-8") return MessageChain(chain=chain) async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 1212e8b00..b369b17a4 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -74,28 +74,6 @@ def __init__(self, provider_config, provider_settings) -> None: self.reasoning_key = "reasoning_content" - def _maybe_inject_xai_search(self, payloads: dict, **kwargs): - """当开启 xAI 原生搜索时,向请求体注入 Live Search 参数。 - - - 仅在 provider_config.xai_native_search 为 True 时生效 - - 默认注入 {"mode": "auto"} - - 允许通过 kwargs 使用 xai_search_mode 覆盖(on/auto/off) - """ - if not bool(self.provider_config.get("xai_native_search", False)): - return - - mode = kwargs.get("xai_search_mode", "auto") - mode = str(mode).lower() - if mode not in ("auto", "on", "off"): - mode = "auto" - - # off 时不注入,保持与未开启一致 - if mode == "off": - return - - # OpenAI SDK 不识别的字段会在 _query/_query_stream 中放入 extra_body - payloads["search_parameters"] = {"mode": mode} - async def get_models(self): try: models_str = [] @@ -381,11 +359,27 @@ async def _prepare_chat_payload( payloads = {"messages": context_query, "model": model} - # xAI origin search tool inject - self._maybe_inject_xai_search(payloads, **kwargs) + self._finally_convert_payload(payloads) return payloads, context_query + def _finally_convert_payload(self, payloads: dict): + """Finally convert the payload. 
Such as think part conversion, tool inject.""" + for message in payloads.get("messages", []): + if message.get("role") == "assistant" and isinstance( + message.get("content"), list + ): + reasoning_content = "" + new_content = [] # not including think part + for part in message["content"]: + if part.get("type") == "think": + reasoning_content += part.get("think") + else: + new_content.append(part) + message["content"] = new_content + # reasoning key is "reasoning_content" + message["reasoning_content"] = reasoning_content + async def _handle_api_error( self, e: Exception, diff --git a/astrbot/core/provider/sources/xai_source.py b/astrbot/core/provider/sources/xai_source.py new file mode 100644 index 000000000..a050412d3 --- /dev/null +++ b/astrbot/core/provider/sources/xai_source.py @@ -0,0 +1,29 @@ +from ..register import register_provider_adapter +from .openai_source import ProviderOpenAIOfficial + + +@register_provider_adapter( + "xai_chat_completion", "xAI Chat Completion Provider Adapter" +) +class ProviderXAI(ProviderOpenAIOfficial): + def __init__( + self, + provider_config: dict, + provider_settings: dict, + ) -> None: + super().__init__(provider_config, provider_settings) + + def _maybe_inject_xai_search(self, payloads: dict): + """当开启 xAI 原生搜索时,向请求体注入 Live Search 参数。 + + - 仅在 provider_config.xai_native_search 为 True 时生效 + - 默认注入 {"mode": "auto"} + """ + if not bool(self.provider_config.get("xai_native_search", False)): + return + # OpenAI SDK 不识别的字段会在 _query/_query_stream 中放入 extra_body + payloads["search_parameters"] = {"mode": "auto"} + + def _finally_convert_payload(self, payloads: dict): + self._maybe_inject_xai_search(payloads) + super()._finally_convert_payload(payloads) From 797a74fb4c3eae48ba2fd5f5618d7e8dc7cbf1cb Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 28 Dec 2025 19:12:29 +0800 Subject: [PATCH 2/7] chore: remove verbose --- astrbot/core/provider/sources/anthropic_source.py | 1 - 1 file changed, 1 deletion(-) diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index 57cf09db9..7ce36b0f5 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -233,7 +233,6 @@ async def _query_stream( ) as stream: assert isinstance(stream, anthropic.AsyncMessageStream) async for event in stream: - print(f"event: {event}") if event.type == "message_start": # the usage contains input token usage id = event.message.id From d1f069f0b42fc459e4e070ab7647e941cdad41b6 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 28 Dec 2025 19:14:14 +0800 Subject: [PATCH 3/7] perf --- astrbot/core/pipeline/result_decorate/stage.py | 2 +- astrbot/core/provider/sources/gemini_source.py | 2 +- astrbot/core/provider/sources/openai_source.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index 38aa8bdd4..529d0b263 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -264,7 +264,7 @@ async def process( and random.random() <= self.tts_trigger_probability and tts_provider ) - if not tts_provider: + if should_tts and not tts_provider: logger.warning( f"会话 {event.unified_msg_origin} 未配置文本转语音模型。", ) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index e9b360212..65d66b721 100644 --- 
a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -335,7 +335,7 @@ def append_or_extend( reasoning_content = part.get("think") or None thinking_signature = part.get("encrypted") or None else: - text += part.get("text") + text += str(part.get("text")) if thinking_signature and isinstance(thinking_signature, str): try: diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index b369b17a4..c9a0dd347 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -373,7 +373,7 @@ def _finally_convert_payload(self, payloads: dict): new_content = [] # not including think part for part in message["content"]: if part.get("type") == "think": - reasoning_content += part.get("think") + reasoning_content += str(part.get("think")) else: new_content.append(part) message["content"] = new_content From 8ee3f4ffd3083010deeb6e18524bbd1dcd605f28 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Mon, 29 Dec 2025 12:41:07 +0800 Subject: [PATCH 4/7] refactor: remove special tools handling for deepseek-reasoner model in openai source --- astrbot/core/provider/sources/openai_source.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index c9a0dd347..29d959d05 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -112,10 +112,6 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: model = payloads.get("model", "").lower() - # 针对 deepseek 模型的特殊处理:deepseek-reasoner调用必须移除 tools ,否则将被切换至 deepseek-chat - if model == "deepseek-reasoner" and "tools" in payloads: - del payloads["tools"] - completion = await self.client.chat.completions.create( **payloads, stream=False, From 6b25475b81ec4e90e2f08f83206fd9da7adaba8f Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Mon, 29 Dec 2025 14:10:25 +0800 Subject: [PATCH 5/7] fix: improve error handling and logging in InternalAgentSubStage processing --- .../method/agent_sub_stages/internal.py | 323 +++++++++--------- 1 file changed, 167 insertions(+), 156 deletions(-) diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index d00b0abde..6aa7de85e 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -349,179 +349,190 @@ async def process( ) -> AsyncGenerator[None, None]: req: ProviderRequest | None = None - provider = self._select_provider(event) - if provider is None: - return - if not isinstance(provider, Provider): - logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。") - return - - streaming_response = self.streaming_response - if (enable_streaming := event.get_extra("enable_streaming")) is not None: - streaming_response = bool(enable_streaming) - - logger.debug("ready to request llm provider") - async with session_lock_manager.acquire_lock(event.unified_msg_origin): - logger.debug("acquired session lock for llm request") - if event.get_extra("provider_request"): - req = event.get_extra("provider_request") - assert isinstance(req, ProviderRequest), ( - "provider_request 必须是 ProviderRequest 类型。" + try: + provider = self._select_provider(event) + if provider is None: + return + if not 
isinstance(provider, Provider): + logger.error( + f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。" ) + return - if req.conversation: - req.contexts = json.loads(req.conversation.history) - - else: - req = ProviderRequest() - req.prompt = "" - req.image_urls = [] - if sel_model := event.get_extra("selected_model"): - req.model = sel_model - if provider_wake_prefix and not event.message_str.startswith( - provider_wake_prefix - ): - return - - req.prompt = event.message_str[len(provider_wake_prefix) :] - # func_tool selection 现在已经转移到 astrbot/builtin_stars/astrbot 插件中进行选择。 - # req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager() - for comp in event.message_obj.message: - if isinstance(comp, Image): - image_path = await comp.convert_to_file_path() - req.image_urls.append(image_path) - - conversation = await self._get_session_conv(event) - req.conversation = conversation - req.contexts = json.loads(conversation.history) - - event.set_extra("provider_request", req) - - # fix contexts json str - if isinstance(req.contexts, str): - req.contexts = json.loads(req.contexts) - - # apply file extract - if self.file_extract_enabled: - try: - await self._apply_file_extract(event, req) - except Exception as e: - logger.error(f"Error occurred while applying file extract: {e}") + streaming_response = self.streaming_response + if (enable_streaming := event.get_extra("enable_streaming")) is not None: + streaming_response = bool(enable_streaming) + + logger.debug("ready to request llm provider") + async with session_lock_manager.acquire_lock(event.unified_msg_origin): + logger.debug("acquired session lock for llm request") + if event.get_extra("provider_request"): + req = event.get_extra("provider_request") + assert isinstance(req, ProviderRequest), ( + "provider_request 必须是 ProviderRequest 类型。" + ) - if not req.prompt and not req.image_urls: - return + if req.conversation: + req.contexts = json.loads(req.conversation.history) - # call event hook - if await call_event_hook(event, EventType.OnLLMRequestEvent, req): - return + else: + req = ProviderRequest() + req.prompt = "" + req.image_urls = [] + if sel_model := event.get_extra("selected_model"): + req.model = sel_model + if provider_wake_prefix and not event.message_str.startswith( + provider_wake_prefix + ): + return + + req.prompt = event.message_str[len(provider_wake_prefix) :] + # func_tool selection 现在已经转移到 astrbot/builtin_stars/astrbot 插件中进行选择。 + # req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager() + for comp in event.message_obj.message: + if isinstance(comp, Image): + image_path = await comp.convert_to_file_path() + req.image_urls.append(image_path) + + conversation = await self._get_session_conv(event) + req.conversation = conversation + req.contexts = json.loads(conversation.history) + + event.set_extra("provider_request", req) + + # fix contexts json str + if isinstance(req.contexts, str): + req.contexts = json.loads(req.contexts) + + # apply file extract + if self.file_extract_enabled: + try: + await self._apply_file_extract(event, req) + except Exception as e: + logger.error(f"Error occurred while applying file extract: {e}") + + if not req.prompt and not req.image_urls: + return - # apply knowledge base feature - await self._apply_kb(event, req) + # call event hook + if await call_event_hook(event, EventType.OnLLMRequestEvent, req): + return - # truncate contexts to fit max length - if req.contexts: - req.contexts = self._truncate_contexts(req.contexts) - self._fix_messages(req.contexts) + # apply knowledge base 
feature + await self._apply_kb(event, req) - # session_id - if not req.session_id: - req.session_id = event.unified_msg_origin + # truncate contexts to fit max length + if req.contexts: + req.contexts = self._truncate_contexts(req.contexts) + self._fix_messages(req.contexts) - # check provider modalities, if provider does not support image/tool_use, clear them in request. - self._modalities_fix(provider, req) + # session_id + if not req.session_id: + req.session_id = event.unified_msg_origin - # filter tools, only keep tools from this pipeline's selected plugins - self._plugin_tool_fix(event, req) + # check provider modalities, if provider does not support image/tool_use, clear them in request. + self._modalities_fix(provider, req) - stream_to_general = ( - self.unsupported_streaming_strategy == "turn_off" - and not event.platform_meta.support_streaming_message - ) - # 备份 req.contexts - backup_contexts = copy.deepcopy(req.contexts) + # filter tools, only keep tools from this pipeline's selected plugins + self._plugin_tool_fix(event, req) - # run agent - agent_runner = AgentRunner() - logger.debug( - f"handle provider[id: {provider.provider_config['id']}] request: {req}", - ) - astr_agent_ctx = AstrAgentContext( - context=self.ctx.plugin_manager.context, - event=event, - ) - await agent_runner.reset( - provider=provider, - request=req, - run_context=AgentContextWrapper( - context=astr_agent_ctx, - tool_call_timeout=self.tool_call_timeout, - ), - tool_executor=FunctionToolExecutor(), - agent_hooks=MAIN_AGENT_HOOKS, - streaming=streaming_response, - ) + stream_to_general = ( + self.unsupported_streaming_strategy == "turn_off" + and not event.platform_meta.support_streaming_message + ) + # 备份 req.contexts + backup_contexts = copy.deepcopy(req.contexts) - if streaming_response and not stream_to_general: - # 流式响应 - event.set_result( - MessageEventResult() - .set_result_content_type(ResultContentType.STREAMING_RESULT) - .set_async_stream( - run_agent( - agent_runner, - self.max_step, - self.show_tool_use, - show_reasoning=self.show_reasoning, - ), + # run agent + agent_runner = AgentRunner() + logger.debug( + f"handle provider[id: {provider.provider_config['id']}] request: {req}", + ) + astr_agent_ctx = AstrAgentContext( + context=self.ctx.plugin_manager.context, + event=event, + ) + await agent_runner.reset( + provider=provider, + request=req, + run_context=AgentContextWrapper( + context=astr_agent_ctx, + tool_call_timeout=self.tool_call_timeout, ), + tool_executor=FunctionToolExecutor(), + agent_hooks=MAIN_AGENT_HOOKS, + streaming=streaming_response, ) - yield - if agent_runner.done(): - if final_llm_resp := agent_runner.get_final_llm_resp(): - if final_llm_resp.completion_text: - chain = ( - MessageChain() - .message(final_llm_resp.completion_text) - .chain - ) - elif final_llm_resp.result_chain: - chain = final_llm_resp.result_chain.chain - else: - chain = MessageChain().chain - event.set_result( - MessageEventResult( - chain=chain, - result_content_type=ResultContentType.STREAMING_FINISH, + + if streaming_response and not stream_to_general: + # 流式响应 + event.set_result( + MessageEventResult() + .set_result_content_type(ResultContentType.STREAMING_RESULT) + .set_async_stream( + run_agent( + agent_runner, + self.max_step, + self.show_tool_use, + show_reasoning=self.show_reasoning, ), - ) - else: - async for _ in run_agent( - agent_runner, - self.max_step, - self.show_tool_use, - stream_to_general, - show_reasoning=self.show_reasoning, - ): + ), + ) yield + if agent_runner.done(): + if 
final_llm_resp := agent_runner.get_final_llm_resp(): + if final_llm_resp.completion_text: + chain = ( + MessageChain() + .message(final_llm_resp.completion_text) + .chain + ) + elif final_llm_resp.result_chain: + chain = final_llm_resp.result_chain.chain + else: + chain = MessageChain().chain + event.set_result( + MessageEventResult( + chain=chain, + result_content_type=ResultContentType.STREAMING_FINISH, + ), + ) + else: + async for _ in run_agent( + agent_runner, + self.max_step, + self.show_tool_use, + stream_to_general, + show_reasoning=self.show_reasoning, + ): + yield + + # 恢复备份的 contexts + req.contexts = backup_contexts + + await self._save_to_history( + event, + req, + agent_runner.get_final_llm_resp(), + agent_runner.run_context.messages, + ) - # 恢复备份的 contexts - req.contexts = backup_contexts + # 异步处理 WebChat 特殊情况 + if event.get_platform_name() == "webchat": + asyncio.create_task(self._handle_webchat(event, req, provider)) - await self._save_to_history( - event, - req, - agent_runner.get_final_llm_resp(), - agent_runner.run_context.messages, + asyncio.create_task( + Metric.upload( + llm_tick=1, + model_name=agent_runner.provider.get_model(), + provider_type=agent_runner.provider.meta().type, + ), ) - # 异步处理 WebChat 特殊情况 - if event.get_platform_name() == "webchat": - asyncio.create_task(self._handle_webchat(event, req, provider)) - - asyncio.create_task( - Metric.upload( - llm_tick=1, - model_name=agent_runner.provider.get_model(), - provider_type=agent_runner.provider.meta().type, - ), - ) + except Exception as e: + logger.error(f"Error occurred while processing agent: {e}") + await event.send( + MessageChain().message( + f"Error occurred while processing agent request: {e}" + ) + ) From f4bdffa9ae7bf0e2dc156385895928c74d02b692 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Mon, 29 Dec 2025 14:10:31 +0800 Subject: [PATCH 6/7] refactor: remove unused reasoning content from Gemini source processing --- astrbot/core/provider/sources/gemini_source.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 65d66b721..97c072d0e 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -326,13 +326,11 @@ def append_or_extend( append_or_extend(gemini_contents, parts, types.ModelContent) elif isinstance(content, list): parts = [] - reasoning_content = None thinking_signature = None text = "" for part in content: # for most cases, assistant content only contains two parts: think and text if part.get("type") == "think": - reasoning_content = part.get("think") or None thinking_signature = part.get("encrypted") or None else: text += str(part.get("text")) @@ -350,7 +348,6 @@ def append_or_extend( types.Part( text=text, thought_signature=thinking_signature, - thought=reasoning_content, ) ) append_or_extend(gemini_contents, parts, types.ModelContent) From 5e6300daed98571483d5fa4bddec5a69bcae3999 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Mon, 29 Dec 2025 14:19:53 +0800 Subject: [PATCH 7/7] refactor: enhance modality determination logic in useProviderSources --- .../src/composables/useProviderSources.ts | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/dashboard/src/composables/useProviderSources.ts b/dashboard/src/composables/useProviderSources.ts index 41dcc1c61..e8bf58f45 100644 --- a/dashboard/src/composables/useProviderSources.ts +++ 
b/dashboard/src/composables/useProviderSources.ts @@ -508,12 +508,19 @@ export function useProviderSources(options: UseProviderSourcesOptions) { const sourceId = editableProviderSource.value?.id || selectedProviderSource.value.id const newId = `${sourceId}/${modelName}` - const modalities = ['text'] - if (supportsImageInput(getModelMetadata(modelName))) { - modalities.push('image') - } - if (supportsToolCall(getModelMetadata(modelName))) { - modalities.push('tool_use') + const metadata = getModelMetadata(modelName) + let modalities: string[] + + if (!metadata) { + modalities = ['text', 'image', 'tool_use'] + } else { + modalities = ['text'] + if (supportsImageInput(metadata)) { + modalities.push('image') + } + if (supportsToolCall(metadata)) { + modalities.push('tool_use') + } } const newProvider = {
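
Reviewer note on the cross-provider mechanism in this series: reasoning output is persisted as a "think" content part (`ThinkPart`) on assistant messages, then re-expanded into each provider's native format on the next request — an Anthropic `thinking` block (anthropic_source), a Gemini `thought_signature` part (gemini_source), or a flat `reasoning_content` field for OpenAI-compatible backends (openai_source). The sketch below is a standalone illustration of the Anthropic direction using plain dicts; it mirrors `_prepare_payload` but is not the actual AstrBot code, and the sample signature value is made up.

    # Standalone sketch (plain dicts stand in for AstrBot's Message/ThinkPart models).
    def assistant_to_anthropic_blocks(message: dict) -> list[dict]:
        """Expand a stored assistant message into Anthropic content blocks.

        A think part is only re-sent as a `thinking` block when a signature is
        also present, matching the diff's behavior: Anthropic verifies the
        signature on replayed thinking blocks in multi-turn requests.
        """
        blocks: list[dict] = []
        reasoning, signature = "", ""
        content = message.get("content")
        if isinstance(content, str) and content.strip():
            blocks.append({"type": "text", "text": content})
        elif isinstance(content, list):
            for part in content:
                if part.get("type") == "think":
                    # only the last think part wins, as in _prepare_payload
                    reasoning = part.get("think") or ""
                    signature = part.get("encrypted") or ""
                else:
                    blocks.append(part)
        if reasoning and signature:
            # thinking block must come first in the content array
            blocks.insert(
                0,
                {"type": "thinking", "thinking": reasoning, "signature": signature},
            )
        return blocks

    stored = {
        "role": "assistant",
        "content": [
            # "EuYB..." is a hypothetical signature, not a real value
            {"type": "think", "think": "User wants X; plan: call tool Y first.", "encrypted": "EuYB..."},
            {"type": "text", "text": "Here is X."},
        ],
    }
    print(assistant_to_anthropic_blocks(stored))
    # -> [{'type': 'thinking', ...}, {'type': 'text', 'text': 'Here is X.'}]

The other two paths in the diff handle the same stored part asymmetrically: the Gemini path base64-decodes the stored `encrypted` string back to bytes before attaching it as `thought_signature` on a `types.Part`, while the OpenAI-compatible path (`_finally_convert_payload`) strips think parts out of `content` and moves their text into a `reasoning_content` field on the message, since those backends have no signed-thinking concept.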