Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions astrbot/core/platform/astr_message_event.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import abc
import asyncio
import copy
import hashlib
import re
import uuid
Expand Down Expand Up @@ -29,6 +30,9 @@


class AstrMessageEvent(abc.ABC):
# extras 中可安全清理的瞬态字段清单;子类可按需扩展
TRANSIENT_EXTRA_KEYS: set[str] = set()

def __init__(
self,
message_str: str,
Expand Down Expand Up @@ -71,6 +75,8 @@ def __init__(

# back_compability
self.platform = platform_meta
# 可选的绕过标记,避免被 SessionWaiter 再次截获
self._bypass_session_waiter = False

def get_platform_name(self):
"""获取这个事件所属的平台的类型(如 aiocqhttp, slack, discord 等)。
Expand Down Expand Up @@ -278,6 +284,22 @@ def clear_result(self):
"""清除消息事件的结果。"""
self._result = None

def clone_for_llm(self) -> "AstrMessageEvent":
"""浅拷贝并重置状态,以便重新走默认 LLM 流程。"""
new_event: AstrMessageEvent = copy.copy(self)
new_event.clear_result()
# 保留非瞬态 extras,避免跨管线上下文丢失
new_event._extras = self._extras.copy()
for key in self.TRANSIENT_EXTRA_KEYS:
new_event._extras.pop(key, None)
new_event._has_send_oper = False
new_event.call_llm = False
new_event.is_wake = False
new_event.is_at_or_wake_command = False
new_event.plugins_name = None
new_event._bypass_session_waiter = False
return new_event

"""消息链相关"""

def make_result(self) -> MessageEventResult:
Expand Down
40 changes: 37 additions & 3 deletions astrbot/core/utils/session_waiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Any

import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.core.platform import AstrMessageEvent

USER_SESSIONS: dict[str, "SessionWaiter"] = {} # 存储 SessionWaiter 实例
Expand All @@ -29,6 +30,33 @@ def __init__(self):

self.history_chains: list[list[Comp.BaseMessageComponent]] = []

def fallback_to_llm(
self,
event_queue: asyncio.Queue,
event: AstrMessageEvent,
*,
stop_session: bool = True,
) -> AstrMessageEvent:
"""将当前事件重新入队,由默认 LLM 流程处理,适用于非预期输入的兜底。

Args:
event_queue: 事件队列
event: 当前事件
stop_session: 是否结束当前 SessionWaiter。False 时仅兜底当前输入,继续等待后续输入。
"""
if not stop_session:
logger.warning(
"fallback_to_llm(stop_session=False) 会保留当前会话,默认会话拦截可能导致兜底无效,"
"建议谨慎使用或在后续输入中自行终止会话。",
)
new_event = event.clone_for_llm()
new_event._bypass_session_waiter = not stop_session
event_queue.put_nowait(new_event)
event.stop_event()
if stop_session:
self.stop()
return new_event

def stop(self, error: Exception = None):
"""立即结束这个会话"""
if not self.future.done():
Expand Down Expand Up @@ -147,11 +175,15 @@ def _cleanup(self, error: Exception = None):
self.session_controller.stop(error)

@classmethod
async def trigger(cls, session_id: str, event: AstrMessageEvent):
"""外部输入触发会话处理"""
async def trigger(cls, session_id: str, event: AstrMessageEvent) -> bool:
"""外部输入触发会话处理

Returns:
bool: 是否成功触发处理。False 表示会话不存在或已结束。
"""
session = USER_SESSIONS.get(session_id)
if not session or session.session_controller.future.done():
return
return False

async with session._lock:
if not session.session_controller.future.done():
Expand All @@ -164,6 +196,8 @@ async def trigger(cls, session_id: str, event: AstrMessageEvent):
await session.handler(session.session_controller, event)
except Exception as e:
session.session_controller.stop(e)
return True
return False


def session_waiter(timeout: int = 30, record_history_chains: bool = False):
Expand Down
17 changes: 8 additions & 9 deletions packages/session_controller/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
from sys import maxsize

import astrbot.api.message_components as Comp
Expand All @@ -7,7 +6,6 @@
from astrbot.api.star import Context, Star
from astrbot.core.utils.session_waiter import (
FILTERS,
USER_SESSIONS,
SessionController,
SessionWaiter,
session_waiter,
Expand All @@ -23,10 +21,12 @@ def __init__(self, context: Context):
@filter.event_message_type(filter.EventMessageType.ALL, priority=maxsize)
async def handle_session_control_agent(self, event: AstrMessageEvent):
"""会话控制代理"""
if getattr(event, "_bypass_session_waiter", False):
return
for session_filter in FILTERS:
session_id = session_filter.filter(event)
if session_id in USER_SESSIONS:
await SessionWaiter.trigger(session_id, event)
handled = await SessionWaiter.trigger(session_id, event)
if handled:
event.stop_event()

@filter.event_message_type(filter.EventMessageType.ALL, priority=maxsize - 1)
Expand Down Expand Up @@ -96,11 +96,10 @@ async def empty_mention_waiter(
0,
Comp.At(qq=event.get_self_id(), name=event.get_self_id()),
)
new_event = copy.copy(event)
# 重新推入事件队列
self.context.get_event_queue().put_nowait(new_event)
event.stop_event()
controller.stop()
controller.fallback_to_llm(
self.context.get_event_queue(),
event,
)

try:
await empty_mention_waiter(event)
Expand Down