Merged
14 changes: 10 additions & 4 deletions astrbot/core/provider/sources/anthropic_source.py
@@ -68,7 +68,7 @@ def _prepare_payload(self, messages: list[dict]):
blocks = []
if isinstance(message["content"], str):
blocks.append({"type": "text", "text": message["content"]})
- if "tool_calls" in message:
+ if "tool_calls" in message and isinstance(message["tool_calls"], list):
Contributor
issue (bug_risk): Guarding on list type for tool_calls may silently drop valid iterable values.

Previously any truthy tool_calls iterable would be iterated; now non-list iterables (e.g., tuples, generators) are silently ignored. To avoid this behavior change going unnoticed, either normalize with message["tool_calls"] = list(message["tool_calls"]) or raise/log when tool_calls is present but not a list so type issues are surfaced rather than skipped.

for tool_call in message["tool_calls"]:
blocks.append( # noqa: PERF401
{
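The reviewer's suggestion above can be sketched as a small normalization helper. This is a hypothetical illustration, not code from the PR; the function name `normalize_tool_calls` and the logging behavior are assumptions.

```python
import logging

def normalize_tool_calls(message: dict) -> None:
    """Coerce any iterable tool_calls value to a list, logging when coercion happens.

    Hypothetical sketch of the review suggestion: instead of silently
    skipping non-list iterables (tuples, generators), surface the type
    mismatch and keep the values.
    """
    tool_calls = message.get("tool_calls")
    if tool_calls is None or isinstance(tool_calls, list):
        return
    logging.warning(
        "tool_calls was %s, coercing to list", type(tool_calls).__name__
    )
    message["tool_calls"] = list(tool_calls)
```

With this in place, the guard in `_prepare_payload` could stay a plain `isinstance(..., list)` check while tuples and generators are still honored.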
@@ -132,6 +132,9 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:

extra_body = self.provider_config.get("custom_extra_body", {})

+ if "max_tokens" not in payloads:
+     payloads["max_tokens"] = 1024

completion = await self.client.messages.create(
**payloads, stream=False, extra_body=extra_body
)
@@ -181,6 +184,9 @@ async def _query_stream(
usage = TokenUsage()
extra_body = self.provider_config.get("custom_extra_body", {})

+ if "max_tokens" not in payloads:
+     payloads["max_tokens"] = 1024

async with self.client.messages.stream(
**payloads, extra_body=extra_body
) as stream:
@@ -342,11 +348,11 @@ async def text_chat(

async def text_chat_stream(
self,
-     prompt,
+     prompt=None,
      session_id=None,
-     image_urls=...,
+     image_urls=None,
      func_tool=None,
-     contexts=...,
+     contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
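A plausible motivation for swapping the `...` (Ellipsis) defaults for `None` (this reasoning is inferred, the PR does not state it): `Ellipsis` is truthy and fails the usual `is None` / falsiness checks, so downstream code that guards with `if image_urls:` would try to iterate the sentinel. A minimal sketch of the difference:

```python
# With an Ellipsis default, the body must special-case the sentinel.
def with_ellipsis_default(image_urls=...):
    return [] if image_urls is ... else list(image_urls)

# With a None default, the conventional falsy check just works.
def with_none_default(image_urls=None):
    return list(image_urls) if image_urls else []
```

`bool(...)` is `True`, so any truthiness-based guard silently passes the sentinel through, whereas `None` behaves as callers expect.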