Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions astrbot/core/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@
"show_tool_use_status": False,
"streaming_segmented": False,
"max_agent_step": 30,
"error_handling": {
"retry_on_failure": 0,
"report_error_message": True,
"fallback_response": "",
"retry_delay": 0,
},
},
"provider_stt_settings": {
"enable": False,
Expand Down Expand Up @@ -2086,6 +2092,32 @@
},
},
},
"error_handling": {
"description": "错误处理",
"type": "object",
"items": {
"provider_settings.error_handling.retry_on_failure": {
"description": "LLM 请求失败时重试次数",
"type": "int",
"hint": "当请求失败时自动重试的次数。0 表示不重试。",
},
"provider_settings.error_handling.report_error_message": {
"description": "向用户报告详细错误",
"type": "bool",
"hint": "是否将详细的技术性错误信息作为消息发送给用户。如果关闭,将使用下方的备用回复。",
},
"provider_settings.error_handling.fallback_response": {
"description": "备用回复",
"type": "string",
"hint": "当“向用户报告详细错误”被关闭且所有重试都失败时,发送给用户的固定回复。如果留空,则不发送任何消息(静默失败)。",
},
"provider_settings.error_handling.retry_delay": {
"description": "LLM 请求失败时重试间隔(秒)",
"type": "float",
"hint": "当请求失败并进行重试时,每次重试之间的等待时间(秒)。设置为 0 则不等待。",
},
},
},
},
},
"platform_group": {
Expand Down
168 changes: 121 additions & 47 deletions astrbot/core/pipeline/process_stage/method/llm_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,62 +221,94 @@ async def on_agent_done(self, run_context, llm_response):


async def run_agent(
agent_runner: AgentRunner, max_step: int = 30, show_tool_use: bool = True
agent_runner: AgentRunner,
max_step: int = 30,
show_tool_use: bool = True,
retry_on_failure: int = 0,
report_error_message: bool = True,
fallback_response: str = "",
retry_delay: float = 0,
) -> AsyncGenerator[MessageChain, None]:
step_idx = 0
astr_event = agent_runner.run_context.event
while step_idx < max_step:
step_idx += 1
success = False
# Retry loop
for attempt in range(retry_on_failure + 1):
step_idx = 0
try:
async for resp in agent_runner.step():
if astr_event.is_stopped():
return
if resp.type == "tool_call_result":
msg_chain = resp.data["chain"]
if msg_chain.type == "tool_direct_result":
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
resp.data["chain"].type = "tool_call_result"
await astr_event.send(resp.data["chain"])
while step_idx < max_step:
step_idx += 1
async for resp in agent_runner.step():
if astr_event.is_stopped():
return
if resp.type == "tool_call_result":
msg_chain = resp.data["chain"]
if msg_chain.type == "tool_direct_result":
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
resp.data["chain"].type = "tool_call_result"
await astr_event.send(resp.data["chain"])
continue
# 对于其他情况,暂时先不处理
continue
elif resp.type == "tool_call":
if agent_runner.streaming:
# 用来标记流式响应需要分节
yield MessageChain(chain=[], type="break")
if show_tool_use or astr_event.get_platform_name() == "webchat":
resp.data["chain"].type = "tool_call"
await astr_event.send(resp.data["chain"])
continue
# 对于其他情况,暂时先不处理
continue
elif resp.type == "tool_call":
if agent_runner.streaming:
# 用来标记流式响应需要分节
yield MessageChain(chain=[], type="break")
if show_tool_use or astr_event.get_platform_name() == "webchat":
resp.data["chain"].type = "tool_call"
await astr_event.send(resp.data["chain"])
continue

if not agent_runner.streaming:
content_typ = (
ResultContentType.LLM_RESULT
if resp.type == "llm_result"
else ResultContentType.GENERAL_RESULT
)
astr_event.set_result(
MessageEventResult(
chain=resp.data["chain"].chain,
result_content_type=content_typ,
if not agent_runner.streaming:
content_typ = (
ResultContentType.LLM_RESULT
if resp.type == "llm_result"
else ResultContentType.GENERAL_RESULT
)
)
yield
astr_event.clear_result()
else:
if resp.type == "streaming_delta":
yield resp.data["chain"] # MessageChain
if agent_runner.done():
astr_event.set_result(
MessageEventResult(
chain=resp.data["chain"].chain,
result_content_type=content_typ,
)
)
yield
astr_event.clear_result()
else:
if resp.type == "streaming_delta":
yield resp.data["chain"] # MessageChain
if agent_runner.done():
success = True
break
if success:
break

except Exception as e:
logger.error(traceback.format_exc())
astr_event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
)
logger.error(
f"Attempt {attempt + 1}/{retry_on_failure + 1} failed: {traceback.format_exc()}"
)
return
if attempt >= retry_on_failure:
# All retries exhausted, handle final error
if report_error_message:
astr_event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
)
)
elif fallback_response:
astr_event.set_result(
MessageEventResult().message(fallback_response)
)
# If report_error_message is False and fallback_response is empty, do nothing (silent fail).
return

# --- 新增的重试等待逻辑 ---
if retry_delay > 0:
logger.debug(f"Waiting for {retry_delay} seconds before retrying...")
await asyncio.sleep(retry_delay)

# If retries are left, continue to the next attempt
continue
if success:
asyncio.create_task(
Metric.upload(
llm_tick=1,
Expand Down Expand Up @@ -304,6 +336,32 @@ async def initialize(self, ctx: PipelineContext) -> None:
self.max_step = 30
self.show_tool_use: bool = settings.get("show_tool_use_status", True)

# Load error handling settings
error_handling_settings = settings.get("error_handling", {})

# Validate retry_on_failure
retry_on_failure = error_handling_settings.get("retry_on_failure", 0)
if not isinstance(retry_on_failure, int) or not (0 <= retry_on_failure <= 10):
logger.warning(
f"Invalid value for retry_on_failure: {retry_on_failure}. Must be an integer between 0 and 10. Using default 0."
)
retry_on_failure = 0
self.retry_on_failure = retry_on_failure

# Validate and load retry_delay
retry_delay = error_handling_settings.get("retry_delay", 0)
if not isinstance(retry_delay, (int, float)) or not (0 <= retry_delay <= 10):
logger.warning(
f"Invalid value for retry_delay: {retry_delay}. Must be a number between 0 and 10. Using default 0."
)
retry_delay = 0
self.retry_delay = retry_delay

self.report_error_message = error_handling_settings.get(
"report_error_message", True
)
self.fallback_response = error_handling_settings.get("fallback_response", "")

for bwp in self.bot_wake_prefixs:
if self.provider_wake_prefix.startswith(bwp):
logger.info(
Expand Down Expand Up @@ -477,7 +535,15 @@ async def process(
MessageEventResult()
.set_result_content_type(ResultContentType.STREAMING_RESULT)
.set_async_stream(
run_agent(agent_runner, self.max_step, self.show_tool_use)
run_agent(
agent_runner,
self.max_step,
self.show_tool_use,
self.retry_on_failure,
self.report_error_message,
self.fallback_response,
self.retry_delay,
)
)
)
yield
Expand All @@ -496,7 +562,15 @@ async def process(
)
)
else:
async for _ in run_agent(agent_runner, self.max_step, self.show_tool_use):
async for _ in run_agent(
agent_runner,
self.max_step,
self.show_tool_use,
self.retry_on_failure,
self.report_error_message,
self.fallback_response,
self.retry_delay,
):
yield

await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
Expand Down
Loading