Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions astrbot/core/pipeline/waking_check/stage.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable

from astrbot import logger
from astrbot.core.message.components import At, AtAll, Reply
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.message_type import MessageType
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.filter.permission import PermissionTypeFilter
from astrbot.core.star.session_plugin_manager import SessionPluginManager
Expand All @@ -13,6 +14,23 @@
from ..context import PipelineContext
from ..stage import Stage, register_stage

UNIQUE_SESSION_ID_BUILDERS: dict[str, Callable[[AstrMessageEvent], str | None]] = {
"aiocqhttp": lambda e: f"{e.get_sender_id()}_{e.get_group_id()}",
"slack": lambda e: f"{e.get_sender_id()}_{e.get_group_id()}",
"dingtalk": lambda e: e.get_sender_id(),
"qq_official": lambda e: e.get_sender_id(),
"qq_official_webhook": lambda e: e.get_sender_id(),
"lark": lambda e: f"{e.get_sender_id()}%{e.get_group_id()}",
"misskey": lambda e: f"{e.get_session_id()}_{e.get_sender_id()}",
"wechatpadpro": lambda e: f"{e.get_group_id()}#{e.get_sender_id()}",
}


def build_unique_session_id(event: AstrMessageEvent) -> str | None:
platform = event.get_platform_name()
builder = UNIQUE_SESSION_ID_BUILDERS.get(platform)
return builder(event) if builder else None


@register_stage
class WakingCheckStage(Stage):
Expand Down Expand Up @@ -53,18 +71,27 @@ async def initialize(self, ctx: PipelineContext) -> None:
self.disable_builtin_commands = self.ctx.astrbot_config.get(
"disable_builtin_commands", False
)
platform_settings = self.ctx.astrbot_config.get("platform_settings", {})
self.unique_session = platform_settings.get("unique_session", False)

async def process(
self,
event: AstrMessageEvent,
) -> None | AsyncGenerator[None, None]:
# apply unique session
if self.unique_session and event.message_obj.type == MessageType.GROUP_MESSAGE:
sid = build_unique_session_id(event)
if sid:
event.session_id = sid

# ignore bot self message
if (
self.ignore_bot_self_message
and event.get_self_id() == event.get_sender_id()
):
# 忽略机器人自己发送的消息
event.stop_event()
return

# 设置 sender 身份
event.message_str = event.message_str.strip()
for admin_id in self.ctx.astrbot_config["admins_id"]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def __init__(
super().__init__(platform_config, event_queue)

self.settings = platform_settings
self.unique_session = platform_settings["unique_session"]
self.host = platform_config["ws_reverse_host"]
self.port = platform_config["ws_reverse_port"]

Expand Down Expand Up @@ -136,14 +135,11 @@ async def _convert_handle_request_event(self, event: Event) -> AstrBotMessage:
abm.group_id = str(event.group_id)
else:
abm.type = MessageType.FRIEND_MESSAGE
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = str(abm.sender.user_id) + "_" + str(event.group_id)
else:
abm.session_id = (
str(event.group_id)
if abm.type == MessageType.GROUP_MESSAGE
else abm.sender.user_id
)
abm.session_id = (
str(event.group_id)
if abm.type == MessageType.GROUP_MESSAGE
else abm.sender.user_id
)
abm.message_str = ""
abm.message = []
abm.timestamp = int(time.time())
Expand All @@ -164,16 +160,11 @@ async def _convert_handle_notice_event(self, event: Event) -> AstrBotMessage:
abm.type = MessageType.GROUP_MESSAGE
else:
abm.type = MessageType.FRIEND_MESSAGE
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = (
str(abm.sender.user_id) + "_" + str(event.group_id)
) # 也保留群组 id
else:
abm.session_id = (
str(event.group_id)
if abm.type == MessageType.GROUP_MESSAGE
else abm.sender.user_id
)
abm.session_id = (
str(event.group_id)
if abm.type == MessageType.GROUP_MESSAGE
else abm.sender.user_id
)
abm.message_str = ""
abm.message = []
abm.raw_message = event
Expand Down Expand Up @@ -210,16 +201,11 @@ async def _convert_handle_message_event(
abm.group.group_name = event.get("group_name", "N/A")
elif event["message_type"] == "private":
abm.type = MessageType.FRIEND_MESSAGE
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = (
abm.sender.user_id + "_" + str(event.group_id)
) # 也保留群组 id
else:
abm.session_id = (
str(event.group_id)
if abm.type == MessageType.GROUP_MESSAGE
else abm.sender.user_id
)
abm.session_id = (
str(event.group_id)
if abm.type == MessageType.GROUP_MESSAGE
else abm.sender.user_id
)

abm.message_id = str(event.message_id)
abm.message = []
Expand Down
7 changes: 1 addition & 6 deletions astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@ def __init__(
) -> None:
super().__init__(platform_config, event_queue)

self.unique_session = platform_settings["unique_session"]

self.client_id = platform_config["client_id"]
self.client_secret = platform_config["client_secret"]

Expand Down Expand Up @@ -129,10 +127,7 @@ async def convert_msg(
if id := self._id_to_sid(user.dingtalk_id):
abm.message.append(At(qq=id))
abm.group_id = message.conversation_id
if self.unique_session:
abm.session_id = abm.sender.user_id
else:
abm.session_id = abm.group_id
abm.session_id = abm.group_id
else:
abm.session_id = abm.sender.user_id

Expand Down
12 changes: 2 additions & 10 deletions astrbot/core/platform/sources/lark/lark_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,6 @@ def __init__(
) -> None:
super().__init__(platform_config, event_queue)

self.unique_session = platform_settings["unique_session"]

self.appid = platform_config["app_id"]
self.appsecret = platform_config["app_secret"]
self.domain = platform_config.get("domain", lark.FEISHU_DOMAIN)
Expand Down Expand Up @@ -317,14 +315,8 @@ async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
user_id=event.event.sender.sender_id.open_id,
nickname=event.event.sender.sender_id.open_id[:8],
)
# 独立会话
if not self.unique_session:
if abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = abm.group_id
else:
abm.session_id = abm.sender.user_id
elif abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = f"{abm.sender.user_id}%{abm.group_id}" # 也保留群组id
if abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = abm.group_id
else:
abm.session_id = abm.sender.user_id

Expand Down
5 changes: 0 additions & 5 deletions astrbot/core/platform/sources/misskey/misskey_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,6 @@ def __init__(
except Exception:
self.max_download_bytes = None

self.unique_session = platform_settings["unique_session"]

self.api: MisskeyAPI | None = None
self._running = False
self.client_self_id = ""
Expand Down Expand Up @@ -641,7 +639,6 @@ async def convert_message(self, raw_data: dict[str, Any]) -> AstrBotMessage:
sender_info,
self.client_self_id,
is_chat=False,
unique_session=self.unique_session,
)
cache_user_info(
self._user_cache,
Expand Down Expand Up @@ -690,7 +687,6 @@ async def convert_chat_message(self, raw_data: dict[str, Any]) -> AstrBotMessage
sender_info,
self.client_self_id,
is_chat=True,
unique_session=self.unique_session,
)
cache_user_info(
self._user_cache,
Expand Down Expand Up @@ -720,7 +716,6 @@ async def convert_room_message(self, raw_data: dict[str, Any]) -> AstrBotMessage
self.client_self_id,
is_chat=False,
room_id=room_id,
unique_session=self.unique_session,
)

cache_user_info(
Expand Down
3 changes: 0 additions & 3 deletions astrbot/core/platform/sources/misskey/misskey_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,6 @@ def create_base_message(
client_self_id: str,
is_chat: bool = False,
room_id: str | None = None,
unique_session: bool = False,
) -> AstrBotMessage:
"""创建基础消息对象"""
message = AstrBotMessage()
Expand All @@ -353,8 +352,6 @@ def create_base_message(
if room_id:
session_prefix = "room"
session_id = f"{session_prefix}%{room_id}"
if unique_session:
session_id += f"_{sender_info['sender_id']}"
message.type = MessageType.GROUP_MESSAGE
message.group_id = room_id
elif is_chat:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,8 @@ async def on_group_at_message_create(self, message: botpy.message.GroupMessage):
message,
MessageType.GROUP_MESSAGE,
)
abm.session_id = (
abm.sender.user_id
if self.platform.unique_session
else cast(str, message.group_openid)
)
abm.group_id = cast(str, message.group_openid)
abm.session_id = abm.group_id
self._commit(abm)

# 收到频道消息
Expand All @@ -57,9 +54,8 @@ async def on_at_message_create(self, message: botpy.message.Message):
message,
MessageType.GROUP_MESSAGE,
)
abm.session_id = (
abm.sender.user_id if self.platform.unique_session else message.channel_id
)
abm.group_id = message.channel_id
abm.session_id = abm.group_id
self._commit(abm)

# 收到私聊消息
Expand Down Expand Up @@ -104,7 +100,6 @@ def __init__(

self.appid = platform_config["appid"]
self.secret = platform_config["secret"]
self.unique_session: bool = platform_settings["unique_session"]
qq_group = platform_config["enable_group_c2c"]
guild_dm = platform_config["enable_guild_direct_message"]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,8 @@ async def on_group_at_message_create(self, message: botpy.message.GroupMessage):
message,
MessageType.GROUP_MESSAGE,
)
abm.session_id = (
abm.sender.user_id
if self.platform.unique_session
else cast(str, message.group_openid)
)
abm.group_id = cast(str, message.group_openid)
abm.session_id = abm.group_id
self._commit(abm)

# 收到频道消息
Expand All @@ -48,9 +45,8 @@ async def on_at_message_create(self, message: botpy.message.Message):
message,
MessageType.GROUP_MESSAGE,
)
abm.session_id = (
abm.sender.user_id if self.platform.unique_session else message.channel_id
)
abm.group_id = message.channel_id
abm.session_id = abm.group_id
self._commit(abm)

# 收到私聊消息
Expand Down Expand Up @@ -95,7 +91,6 @@ def __init__(

self.appid = platform_config["appid"]
self.secret = platform_config["secret"]
self.unique_session = platform_settings["unique_session"]
self.unified_webhook_mode = platform_config.get("unified_webhook_mode", False)

intents = botpy.Intents(
Expand Down
9 changes: 3 additions & 6 deletions astrbot/core/platform/sources/slack/slack_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def __init__(
) -> None:
super().__init__(platform_config, event_queue)
self.settings = platform_settings
self.unique_session = platform_settings.get("unique_session", False)

self.bot_token = platform_config.get("bot_token")
self.app_token = platform_config.get("app_token")
Expand Down Expand Up @@ -147,12 +146,10 @@ async def convert_message(self, event: dict) -> AstrBotMessage:
abm.group_id = channel_id

# 设置会话ID
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = f"{user_id}_{channel_id}"
if abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = abm.group_id
else:
abm.session_id = (
channel_id if abm.type == MessageType.GROUP_MESSAGE else user_id
)
abm.session_id = user_id

abm.message_id = event.get("client_msg_id", uuid.uuid4().hex)
abm.timestamp = int(float(event.get("ts", time.time())))
Expand Down
1 change: 0 additions & 1 deletion astrbot/core/platform/sources/webchat/webchat_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def __init__(
super().__init__(platform_config, event_queue)

self.settings = platform_settings
self.unique_session = platform_settings["unique_session"]
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
os.makedirs(self.imgs_dir, exist_ok=True)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def __init__(
self._shutdown_event = None
self.wxnewpass = None
self.settings = platform_settings
self.unique_session = platform_settings.get("unique_session", False)

self.metadata = PlatformMetadata(
name="wechatpadpro",
Expand Down Expand Up @@ -509,11 +508,10 @@ async def _process_chat_type(
if accurate_nickname:
abm.sender.nickname = accurate_nickname

# 对于群聊,session_id 可以是群聊 ID 或发送者 ID + 群聊 ID (如果 unique_session 为 True)
if self.unique_session:
abm.session_id = f"{from_user_name}#{abm.sender.user_id}"
if abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = abm.group_id
else:
abm.session_id = from_user_name
abm.session_id = abm.sender.user_id

msg_source = raw_message.get("msg_source", "")
if self.wxid in msg_source:
Expand Down