diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 8831c7e14..88b0d80ef 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -207,6 +207,18 @@ "callback_server_host": "0.0.0.0", "port": 6195, }, + "企业微信智能机器人": { + "id": "wecom_ai_bot", + "type": "wecom_ai_bot", + "enable": True, + "wecomaibot_init_respond_text": "💭 思考中...", + "wecomaibot_friend_message_welcome_text": "", + "wecom_ai_bot_name": "", + "token": "", + "encoding_aes_key": "", + "callback_server_host": "0.0.0.0", + "port": 6198, + }, "飞书(Lark)": { "id": "lark", "type": "lark", @@ -447,10 +459,25 @@ "type": "string", "hint": "aiocqhttp 适配器的反向 Websocket Token。未设置则不启用 Token 验证。", }, + "wecom_ai_bot_name": { + "description": "企业微信智能机器人的名字", + "type": "string", + "hint": "请务必填写正确,否则无法使用一些指令。", + }, + "wecomaibot_init_respond_text": { + "description": "企业微信智能机器人初始响应文本", + "type": "string", + "hint": "当机器人收到消息时,首先回复的文本内容。留空则使用默认值。", + }, + "wecomaibot_friend_message_welcome_text": { + "description": "企业微信智能机器人私聊欢迎语", + "type": "string", + "hint": "当用户当天进入智能机器人单聊会话,回复欢迎语,留空则不回复。", + }, "lark_bot_name": { "description": "飞书机器人的名字", "type": "string", - "hint": "请务必填对,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。", + "hint": "请务必填写正确,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。", }, "discord_token": { "description": "Discord Bot Token", diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py index f1c3988a6..7a38ec03f 100644 --- a/astrbot/core/pipeline/scheduler.py +++ b/astrbot/core/pipeline/scheduler.py @@ -74,7 +74,7 @@ async def execute(self, event: AstrMessageEvent): await self._process_stages(event) # 如果没有发送操作, 则发送一个空消息, 以便于后续的处理 - if event.get_platform_name() == "webchat": + if event.get_platform_name() in ["webchat", "wecom_ai_bot"]: await event.send(None) logger.debug("pipeline 执行完毕。") diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index f0d7c2e4a..7090c669c 100644 --- a/astrbot/core/platform/manager.py +++ b/astrbot/core/platform/manager.py @@ -82,6 +82,10 @@ async def load_platform(self, platform_config: dict): from .sources.wecom.wecom_adapter import ( WecomPlatformAdapter, # noqa: F401 ) + case "wecom_ai_bot": + from .sources.wecom_ai_bot.wecomai_adapter import ( + WecomAIBotAdapter, # noqa: F401 + ) case "weixin_official_account": from .sources.weixin_official_account.weixin_offacc_adapter import ( WeixinOfficialAccountPlatformAdapter, # noqa: F401 diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index 43da100f4..faec122ac 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -91,7 +91,6 @@ async def convert_message(self, data: tuple) -> AstrBotMessage: abm = AstrBotMessage() abm.self_id = "webchat" - abm.tag = "webchat" abm.sender = MessageMember(username, username) abm.type = MessageType.FRIEND_MESSAGE diff --git a/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py b/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py new file mode 100644 index 000000000..5332942b9 --- /dev/null +++ b/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py @@ -0,0 +1,289 @@ +#!/usr/bin/env python +# -*- encoding:utf-8 -*- + +"""对企业微信发送给企业后台的消息加解密示例代码. +@copyright: Copyright (c) 1998-2020 Tencent Inc. + +""" +# ------------------------------------------------------------------------ + +import logging +import base64 +import random +import hashlib +import time +import struct +from Crypto.Cipher import AES +import socket +import json + +from . import ierror + +""" +关于Crypto.Cipher模块,ImportError: No module named 'Crypto'解决方案 +请到官方网站 https://www.dlitz.net/software/pycrypto/ 下载pycrypto。 +下载后,按照README中的“Installation”小节的提示进行pycrypto安装。 +""" + + +class FormatException(Exception): + pass + + +def throw_exception(message, exception_class=FormatException): + """my define raise exception function""" + raise exception_class(message) + + +class SHA1: + """计算企业微信的消息签名接口""" + + def getSHA1(self, token, timestamp, nonce, encrypt): + """用SHA1算法生成安全签名 + @param token: 票据 + @param timestamp: 时间戳 + @param encrypt: 密文 + @param nonce: 随机字符串 + @return: 安全签名 + """ + try: + # 确保所有输入都是字符串类型 + if isinstance(encrypt, bytes): + encrypt = encrypt.decode("utf-8") + + sortlist = [str(token), str(timestamp), str(nonce), str(encrypt)] + sortlist.sort() + sha = hashlib.sha1() + sha.update("".join(sortlist).encode("utf-8")) + return ierror.WXBizMsgCrypt_OK, sha.hexdigest() + + except Exception as e: + print(e) + return ierror.WXBizMsgCrypt_ComputeSignature_Error, None + + +class JsonParse: + """提供提取消息格式中的密文及生成回复消息格式的接口""" + + # json消息模板 + AES_TEXT_RESPONSE_TEMPLATE = """{ + "encrypt": "%(msg_encrypt)s", + "msgsignature": "%(msg_signaturet)s", + "timestamp": "%(timestamp)s", + "nonce": "%(nonce)s" + }""" + + def extract(self, jsontext): + """提取出json数据包中的加密消息 + @param jsontext: 待提取的json字符串 + @return: 提取出的加密消息字符串 + """ + try: + json_dict = json.loads(jsontext) + return ierror.WXBizMsgCrypt_OK, json_dict["encrypt"] + except Exception as e: + print(e) + return ierror.WXBizMsgCrypt_ParseJson_Error, None + + def generate(self, encrypt, signature, timestamp, nonce): + """生成json消息 + @param encrypt: 加密后的消息密文 + @param signature: 安全签名 + @param timestamp: 时间戳 + @param nonce: 随机字符串 + @return: 生成的json字符串 + """ + resp_dict = { + "msg_encrypt": encrypt, + "msg_signaturet": signature, + "timestamp": timestamp, + "nonce": nonce, + } + resp_json = self.AES_TEXT_RESPONSE_TEMPLATE % resp_dict + return resp_json + + +class PKCS7Encoder: + """提供基于PKCS7算法的加解密接口""" + + block_size = 32 + + def encode(self, text): + """对需要加密的明文进行填充补位 + @param text: 需要进行填充补位操作的明文(bytes类型) + @return: 补齐明文字符串(bytes类型) + """ + text_length = len(text) + # 计算需要填充的位数 + amount_to_pad = self.block_size - (text_length % self.block_size) + if amount_to_pad == 0: + amount_to_pad = self.block_size + # 获得补位所用的字符 + pad = bytes([amount_to_pad]) + # 确保text是bytes类型 + if isinstance(text, str): + text = text.encode("utf-8") + return text + pad * amount_to_pad + + def decode(self, decrypted): + """删除解密后明文的补位字符 + @param decrypted: 解密后的明文 + @return: 删除补位字符后的明文 + """ + pad = ord(decrypted[-1]) + if pad < 1 or pad > 32: + pad = 0 + return decrypted[:-pad] + + +class Prpcrypt(object): + """提供接收和推送给企业微信消息的加解密接口""" + + def __init__(self, key): + # self.key = base64.b64decode(key+"=") + self.key = key + # 设置加解密模式为AES的CBC模式 + self.mode = AES.MODE_CBC + + def encrypt(self, text, receiveid): + """对明文进行加密 + @param text: 需要加密的明文 + @return: 加密得到的字符串 + """ + # 16位随机字符串添加到明文开头 + text = text.encode() + text = ( + self.get_random_str() + + struct.pack("I", socket.htonl(len(text))) + + text + + receiveid.encode() + ) + + # 使用自定义的填充方式对明文进行补位填充 + pkcs7 = PKCS7Encoder() + text = pkcs7.encode(text) + # 加密 + cryptor = AES.new(self.key, self.mode, self.key[:16]) # type: ignore + try: + ciphertext = cryptor.encrypt(text) + # 使用BASE64对加密后的字符串进行编码 + return ierror.WXBizMsgCrypt_OK, base64.b64encode(ciphertext) + except Exception as e: + logger = logging.getLogger("astrbot") + logger.error(e) + return ierror.WXBizMsgCrypt_EncryptAES_Error, None + + def decrypt(self, text, receiveid): + """对解密后的明文进行补位删除 + @param text: 密文 + @return: 删除填充补位后的明文 + """ + try: + cryptor = AES.new(self.key, self.mode, self.key[:16]) # type: ignore + # 使用BASE64对密文进行解码,然后AES-CBC解密 + plain_text = cryptor.decrypt(base64.b64decode(text)) + except Exception as e: + print(e) + return ierror.WXBizMsgCrypt_DecryptAES_Error, None + try: + pad = plain_text[-1] + # 去掉补位字符串 + # pkcs7 = PKCS7Encoder() + # plain_text = pkcs7.encode(plain_text) + # 去除16位随机字符串 + content = plain_text[16:-pad] + json_len = socket.ntohl(struct.unpack("I", content[:4])[0]) + json_content = content[4 : json_len + 4].decode("utf-8") + from_receiveid = content[json_len + 4 :].decode("utf-8") + except Exception as e: + print(e) + return ierror.WXBizMsgCrypt_IllegalBuffer, None + if from_receiveid != receiveid: + print("receiveid not match", receiveid, from_receiveid) + return ierror.WXBizMsgCrypt_ValidateCorpid_Error, None + return 0, json_content + + def get_random_str(self): + """随机生成16位字符串 + @return: 16位字符串 + """ + return str(random.randint(1000000000000000, 9999999999999999)).encode() + + +class WXBizJsonMsgCrypt(object): + # 构造函数 + def __init__(self, sToken, sEncodingAESKey, sReceiveId): + try: + self.key = base64.b64decode(sEncodingAESKey + "=") + assert len(self.key) == 32 + except Exception as e: + throw_exception(f"[error]: EncodingAESKey invalid: {e}", FormatException) + # return ierror.WXBizMsgCrypt_IllegalAesKey,None + self.m_sToken = sToken + self.m_sReceiveId = sReceiveId + + # 验证URL + # @param sMsgSignature: 签名串,对应URL参数的msg_signature + # @param sTimeStamp: 时间戳,对应URL参数的timestamp + # @param sNonce: 随机串,对应URL参数的nonce + # @param sEchoStr: 随机串,对应URL参数的echostr + # @param sReplyEchoStr: 解密之后的echostr,当return返回0时有效 + # @return:成功0,失败返回对应的错误码 + + def VerifyURL(self, sMsgSignature, sTimeStamp, sNonce, sEchoStr): + sha1 = SHA1() + ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, sEchoStr) + if ret != 0: + return ret, None + if not signature == sMsgSignature: + return ierror.WXBizMsgCrypt_ValidateSignature_Error, None + pc = Prpcrypt(self.key) + ret, sReplyEchoStr = pc.decrypt(sEchoStr, self.m_sReceiveId) + return ret, sReplyEchoStr + + def EncryptMsg(self, sReplyMsg, sNonce, timestamp=None): + # 将企业回复用户的消息加密打包 + # @param sReplyMsg: 企业号待回复用户的消息,json格式的字符串 + # @param sTimeStamp: 时间戳,可以自己生成,也可以用URL参数的timestamp,如为None则自动用当前时间 + # @param sNonce: 随机串,可以自己生成,也可以用URL参数的nonce + # sEncryptMsg: 加密后的可以直接回复用户的密文,包括msg_signature, timestamp, nonce, encrypt的json格式的字符串, + # return:成功0,sEncryptMsg,失败返回对应的错误码None + pc = Prpcrypt(self.key) + ret, encrypt = pc.encrypt(sReplyMsg, self.m_sReceiveId) + encrypt = encrypt.decode("utf-8") # type: ignore + if ret != 0: + return ret, None + if timestamp is None: + timestamp = str(int(time.time())) + # 生成安全签名 + sha1 = SHA1() + ret, signature = sha1.getSHA1(self.m_sToken, timestamp, sNonce, encrypt) + if ret != 0: + return ret, None + jsonParse = JsonParse() + return ret, jsonParse.generate(encrypt, signature, timestamp, sNonce) + + def DecryptMsg(self, sPostData, sMsgSignature, sTimeStamp, sNonce): + # 检验消息的真实性,并且获取解密后的明文 + # @param sMsgSignature: 签名串,对应URL参数的msg_signature + # @param sTimeStamp: 时间戳,对应URL参数的timestamp + # @param sNonce: 随机串,对应URL参数的nonce + # @param sPostData: 密文,对应POST请求的数据 + # json_content: 解密后的原文,当return返回0时有效 + # @return: 成功0,失败返回对应的错误码 + # 验证安全签名 + jsonParse = JsonParse() + ret, encrypt = jsonParse.extract(sPostData) + if ret != 0: + return ret, None + sha1 = SHA1() + ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, encrypt) + if ret != 0: + return ret, None + if not signature == sMsgSignature: + print("signature not match") + print(signature) + return ierror.WXBizMsgCrypt_ValidateSignature_Error, None + pc = Prpcrypt(self.key) + ret, json_content = pc.decrypt(encrypt, self.m_sReceiveId) + return ret, json_content diff --git a/astrbot/core/platform/sources/wecom_ai_bot/__init__.py b/astrbot/core/platform/sources/wecom_ai_bot/__init__.py new file mode 100644 index 000000000..7da900030 --- /dev/null +++ b/astrbot/core/platform/sources/wecom_ai_bot/__init__.py @@ -0,0 +1,17 @@ +""" +企业微信智能机器人平台适配器包 +""" + +from .wecomai_adapter import WecomAIBotAdapter +from .wecomai_api import WecomAIBotAPIClient +from .wecomai_event import WecomAIBotMessageEvent +from .wecomai_server import WecomAIBotServer +from .wecomai_utils import WecomAIBotConstants + +__all__ = [ + "WecomAIBotAdapter", + "WecomAIBotAPIClient", + "WecomAIBotMessageEvent", + "WecomAIBotServer", + "WecomAIBotConstants", +] diff --git a/astrbot/core/platform/sources/wecom_ai_bot/ierror.py b/astrbot/core/platform/sources/wecom_ai_bot/ierror.py new file mode 100644 index 000000000..cc1bf221e --- /dev/null +++ b/astrbot/core/platform/sources/wecom_ai_bot/ierror.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +######################################################################### +# Author: jonyqin +# Created Time: Thu 11 Sep 2014 01:53:58 PM CST +# File Name: ierror.py +# Description:定义错误码含义 +######################################################################### +WXBizMsgCrypt_OK = 0 +WXBizMsgCrypt_ValidateSignature_Error = -40001 +WXBizMsgCrypt_ParseJson_Error = -40002 +WXBizMsgCrypt_ComputeSignature_Error = -40003 +WXBizMsgCrypt_IllegalAesKey = -40004 +WXBizMsgCrypt_ValidateCorpid_Error = -40005 +WXBizMsgCrypt_EncryptAES_Error = -40006 +WXBizMsgCrypt_DecryptAES_Error = -40007 +WXBizMsgCrypt_IllegalBuffer = -40008 +WXBizMsgCrypt_EncodeBase64_Error = -40009 +WXBizMsgCrypt_DecodeBase64_Error = -40010 +WXBizMsgCrypt_GenReturnJson_Error = -40011 diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py new file mode 100644 index 000000000..830d8de58 --- /dev/null +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py @@ -0,0 +1,445 @@ +""" +企业微信智能机器人平台适配器 +基于企业微信智能机器人 API 的消息平台适配器,支持 HTTP 回调 +参考webchat_adapter.py的队列机制,实现异步消息处理和流式响应 +""" + +import time +import asyncio +import uuid +import hashlib +import base64 +from typing import Awaitable, Any, Dict, Optional, Callable + + +from astrbot.api.platform import ( + Platform, + AstrBotMessage, + MessageMember, + MessageType, + PlatformMetadata, +) +from astrbot.api.event import MessageChain +from astrbot.api.message_components import Plain, At, Image +from astrbot.api import logger +from astrbot.core.platform.astr_message_event import MessageSesion +from ...register import register_platform_adapter + +from .wecomai_api import ( + WecomAIBotAPIClient, + WecomAIBotMessageParser, + WecomAIBotStreamMessageBuilder, +) +from .wecomai_event import WecomAIBotMessageEvent +from .wecomai_server import WecomAIBotServer +from .wecomai_queue_mgr import wecomai_queue_mgr, WecomAIQueueMgr +from .wecomai_utils import ( + WecomAIBotConstants, + format_session_id, + generate_random_string, + process_encrypted_image, +) + + +class WecomAIQueueListener: + """企业微信智能机器人队列监听器,参考webchat的QueueListener设计""" + + def __init__( + self, queue_mgr: WecomAIQueueMgr, callback: Callable[[dict], Awaitable[None]] + ) -> None: + self.queue_mgr = queue_mgr + self.callback = callback + self.running_tasks = set() + + async def listen_to_queue(self, session_id: str): + """监听特定会话的队列""" + queue = self.queue_mgr.get_or_create_queue(session_id) + while True: + try: + data = await queue.get() + await self.callback(data) + except Exception as e: + logger.error(f"处理会话 {session_id} 消息时发生错误: {e}") + break + + async def run(self): + """监控新会话队列并启动监听器""" + monitored_sessions = set() + + while True: + # 检查新会话 + current_sessions = set(self.queue_mgr.queues.keys()) + new_sessions = current_sessions - monitored_sessions + + # 为新会话启动监听器 + for session_id in new_sessions: + task = asyncio.create_task(self.listen_to_queue(session_id)) + self.running_tasks.add(task) + task.add_done_callback(self.running_tasks.discard) + monitored_sessions.add(session_id) + logger.debug(f"[WecomAI] 为会话启动监听器: {session_id}") + + # 清理已不存在的会话 + removed_sessions = monitored_sessions - current_sessions + monitored_sessions -= removed_sessions + + # 清理过期的待处理响应 + self.queue_mgr.cleanup_expired_responses() + + await asyncio.sleep(1) # 每秒检查一次新会话 + + +@register_platform_adapter( + "wecom_ai_bot", "企业微信智能机器人适配器,支持 HTTP 回调接收消息" +) +class WecomAIBotAdapter(Platform): + """企业微信智能机器人适配器""" + + def __init__( + self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue + ) -> None: + super().__init__(event_queue) + + self.config = platform_config + self.settings = platform_settings + + # 初始化配置参数 + self.token = self.config["token"] + self.encoding_aes_key = self.config["encoding_aes_key"] + self.port = int(self.config["port"]) + self.host = self.config.get("callback_server_host", "0.0.0.0") + self.bot_name = self.config.get("wecom_ai_bot_name", "") + self.initial_respond_text = self.config.get( + "wecomaibot_init_respond_text", "💭 思考中..." + ) + self.friend_message_welcome_text = self.config.get( + "wecomaibot_friend_message_welcome_text", "" + ) + + # 平台元数据 + self.metadata = PlatformMetadata( + name="wecom_ai_bot", + description="企业微信智能机器人适配器,支持 HTTP 回调接收消息", + id=self.config.get("id", "wecom_ai_bot"), + ) + + # 初始化 API 客户端 + self.api_client = WecomAIBotAPIClient(self.token, self.encoding_aes_key) + + # 初始化 HTTP 服务器 + self.server = WecomAIBotServer( + host=self.host, + port=self.port, + api_client=self.api_client, + message_handler=self._process_message, + ) + + # 事件循环和关闭信号 + self.shutdown_event = asyncio.Event() + + # 队列监听器 + self.queue_listener = WecomAIQueueListener( + wecomai_queue_mgr, self._handle_queued_message + ) + + async def _handle_queued_message(self, data: dict): + """处理队列中的消息,类似webchat的callback""" + try: + abm = await self.convert_message(data) + await self.handle_msg(abm) + except Exception as e: + logger.error(f"处理队列消息时发生异常: {e}") + + async def _process_message( + self, message_data: Dict[str, Any], callback_params: Dict[str, str] + ) -> Optional[str]: + """处理接收到的消息 + + Args: + message_data: 解密后的消息数据 + callback_params: 回调参数 (nonce, timestamp) + + Returns: + 加密后的响应消息,无需响应时返回 None + """ + msgtype = message_data.get("msgtype") + if not msgtype: + logger.warning(f"消息类型未知,忽略: {message_data}") + return None + session_id = self._extract_session_id(message_data) + if msgtype in ("text", "image", "mixed"): + # user sent a text / image / mixed message + try: + # create a brand-new unique stream_id for this message session + stream_id = f"{session_id}_{generate_random_string(10)}" + await self._enqueue_message( + message_data, callback_params, stream_id, session_id + ) + wecomai_queue_mgr.set_pending_response(stream_id, callback_params) + + resp = WecomAIBotStreamMessageBuilder.make_text_stream( + stream_id, self.initial_respond_text, False + ) + return await self.api_client.encrypt_message( + resp, callback_params["nonce"], callback_params["timestamp"] + ) + except Exception as e: + logger.error("处理消息时发生异常: %s", e) + return None + elif msgtype == "stream": + # wechat server is requesting for updates of a stream + stream_id = message_data["stream"]["id"] + if not wecomai_queue_mgr.has_back_queue(stream_id): + logger.error(f"Cannot find back queue for stream_id: {stream_id}") + + # 返回结束标志,告诉微信服务器流已结束 + end_message = WecomAIBotStreamMessageBuilder.make_text_stream( + stream_id, "", True + ) + resp = await self.api_client.encrypt_message( + end_message, + callback_params["nonce"], + callback_params["timestamp"], + ) + return resp + queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id) + if queue.empty(): + logger.debug( + f"No new messages in back queue for stream_id: {stream_id}" + ) + return None + + # aggregate all delta chains in the back queue + latest_plain_content = "" + image_base64 = [] + finish = False + while not queue.empty(): + msg = await queue.get() + if msg["type"] == "plain": + latest_plain_content = msg["data"] or "" + elif msg["type"] == "image": + image_base64.append(msg["image_data"]) + elif msg["type"] == "end": + # stream end + finish = True + wecomai_queue_mgr.remove_queues(stream_id) + break + else: + pass + logger.debug( + f"Aggregated content: {latest_plain_content}, image: {len(image_base64)}, finish: {finish}" + ) + if latest_plain_content or image_base64: + msg_items = [] + if finish and image_base64: + for img_b64 in image_base64: + # get md5 of image + img_data = base64.b64decode(img_b64) + img_md5 = hashlib.md5(img_data).hexdigest() + msg_items.append( + { + "msgtype": WecomAIBotConstants.MSG_TYPE_IMAGE, + "image": {"base64": img_b64, "md5": img_md5}, + } + ) + image_base64 = [] + + plain_message = WecomAIBotStreamMessageBuilder.make_mixed_stream( + stream_id, latest_plain_content, msg_items, finish + ) + encrypted_message = await self.api_client.encrypt_message( + plain_message, + callback_params["nonce"], + callback_params["timestamp"], + ) + if encrypted_message: + logger.debug( + f"Stream message sent successfully, stream_id: {stream_id}" + ) + else: + logger.error("消息加密失败") + return encrypted_message + return None + elif msgtype == "event": + event = message_data.get("event") + if event == "enter_chat" and self.friend_message_welcome_text: + # 用户进入会话,发送欢迎消息 + try: + resp = WecomAIBotStreamMessageBuilder.make_text( + self.friend_message_welcome_text + ) + return await self.api_client.encrypt_message( + resp, + callback_params["nonce"], + callback_params["timestamp"], + ) + except Exception as e: + logger.error("处理欢迎消息时发生异常: %s", e) + return None + pass + + def _extract_session_id(self, message_data: Dict[str, Any]) -> str: + """从消息数据中提取会话ID""" + user_id = message_data.get("from", {}).get("userid", "default_user") + return format_session_id("wecomai", user_id) + + async def _enqueue_message( + self, + message_data: Dict[str, Any], + callback_params: Dict[str, str], + stream_id: str, + session_id: str, + ): + """将消息放入队列进行异步处理""" + input_queue = wecomai_queue_mgr.get_or_create_queue(stream_id) + _ = wecomai_queue_mgr.get_or_create_back_queue(stream_id) + message_payload = { + "message_data": message_data, + "callback_params": callback_params, + "session_id": session_id, + "stream_id": stream_id, + } + await input_queue.put(message_payload) + logger.debug(f"[WecomAI] 消息已入队: {stream_id}") + + async def convert_message(self, payload: dict) -> AstrBotMessage: + """转换队列中的消息数据为AstrBotMessage,类似webchat的convert_message""" + message_data = payload["message_data"] + session_id = payload["session_id"] + # callback_params = payload["callback_params"] # 保留但暂时不使用 + + # 解析消息内容 + msgtype = message_data.get("msgtype") + content = "" + image_base64 = [] + + _img_url_to_process = [] + msg_items = [] + + if msgtype == WecomAIBotConstants.MSG_TYPE_TEXT: + content = WecomAIBotMessageParser.parse_text_message(message_data) + elif msgtype == WecomAIBotConstants.MSG_TYPE_IMAGE: + _img_url_to_process.append( + WecomAIBotMessageParser.parse_image_message(message_data) + ) + elif msgtype == WecomAIBotConstants.MSG_TYPE_MIXED: + # 提取混合消息中的文本内容 + msg_items = WecomAIBotMessageParser.parse_mixed_message(message_data) + text_parts = [] + for item in msg_items or []: + if item.get("msgtype") == WecomAIBotConstants.MSG_TYPE_TEXT: + text_content = item.get("text", {}).get("content", "") + if text_content: + text_parts.append(text_content) + elif item.get("msgtype") == WecomAIBotConstants.MSG_TYPE_IMAGE: + image_url = item.get("image", {}).get("url", "") + if image_url: + _img_url_to_process.append(image_url) + content = " ".join(text_parts) if text_parts else "" + else: + content = f"[{msgtype}消息]" + + # 并行处理图片下载和解密 + if _img_url_to_process: + tasks = [ + process_encrypted_image(url, self.encoding_aes_key) + for url in _img_url_to_process + ] + results = await asyncio.gather(*tasks) + for success, result in results: + if success: + image_base64.append(result) + else: + logger.error(f"处理加密图片失败: {result}") + + # 构建 AstrBotMessage + abm = AstrBotMessage() + abm.self_id = self.bot_name + abm.message_str = content or "[未知消息]" + abm.message_id = str(uuid.uuid4()) + abm.timestamp = int(time.time()) + abm.raw_message = payload + + # 发送者信息 + abm.sender = MessageMember( + user_id=message_data.get("from", {}).get("userid", "unknown"), + nickname=message_data.get("from", {}).get("userid", "unknown"), + ) + + # 消息类型 + abm.type = ( + MessageType.GROUP_MESSAGE + if message_data.get("chattype") == "group" + else MessageType.FRIEND_MESSAGE + ) + abm.session_id = session_id + + # 消息内容 + abm.message = [] + + # 处理 At + if self.bot_name and f"@{self.bot_name}" in abm.message_str: + abm.message_str = abm.message_str.replace(f"@{self.bot_name}", "").strip() + abm.message.append(At(qq=self.bot_name, name=self.bot_name)) + abm.message.append(Plain(abm.message_str)) + if image_base64: + for img_b64 in image_base64: + abm.message.append(Image.fromBase64(img_b64)) + + logger.debug(f"WecomAIAdapter: {abm.message}") + return abm + + async def send_by_session( + self, session: MessageSesion, message_chain: MessageChain + ): + """通过会话发送消息""" + # 企业微信智能机器人主要通过回调响应,这里记录日志 + logger.info("会话发送消息: %s -> %s", session.session_id, message_chain) + await super().send_by_session(session, message_chain) + + def run(self) -> Awaitable[Any]: + """运行适配器,同时启动HTTP服务器和队列监听器""" + logger.info("启动企业微信智能机器人适配器,监听 %s:%d", self.host, self.port) + + async def run_both(): + # 同时运行HTTP服务器和队列监听器 + await asyncio.gather( + self.server.start_server(), + self.queue_listener.run(), + ) + + return run_both() + + async def terminate(self): + """终止适配器""" + logger.info("企业微信智能机器人适配器正在关闭...") + self.shutdown_event.set() + await self.server.shutdown() + + def meta(self) -> PlatformMetadata: + """获取平台元数据""" + return self.metadata + + async def handle_msg(self, message: AstrBotMessage): + """处理消息,创建消息事件并提交到事件队列""" + try: + message_event = WecomAIBotMessageEvent( + message_str=message.message_str, + message_obj=message, + platform_meta=self.meta(), + session_id=message.session_id, + api_client=self.api_client, + ) + + self.commit_event(message_event) + + except Exception as e: + logger.error("处理消息时发生异常: %s", e) + + def get_client(self) -> WecomAIBotAPIClient: + """获取 API 客户端""" + return self.api_client + + def get_server(self) -> WecomAIBotServer: + """获取 HTTP 服务器实例""" + return self.server diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py new file mode 100644 index 000000000..540bf06b6 --- /dev/null +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py @@ -0,0 +1,378 @@ +""" +企业微信智能机器人 API 客户端 +处理消息加密解密、API 调用等 +""" + +import json +import base64 +import hashlib +from typing import Dict, Any, Optional, Tuple, Union +from Crypto.Cipher import AES +import aiohttp + +from .WXBizJsonMsgCrypt import WXBizJsonMsgCrypt +from .wecomai_utils import WecomAIBotConstants +from astrbot import logger + + +class WecomAIBotAPIClient: + """企业微信智能机器人 API 客户端""" + + def __init__(self, token: str, encoding_aes_key: str): + """初始化 API 客户端 + + Args: + token: 企业微信机器人 Token + encoding_aes_key: 消息加密密钥 + """ + self.token = token + self.encoding_aes_key = encoding_aes_key + self.wxcpt = WXBizJsonMsgCrypt(token, encoding_aes_key, "") # receiveid 为空串 + + async def decrypt_message( + self, encrypted_data: bytes, msg_signature: str, timestamp: str, nonce: str + ) -> Tuple[int, Optional[Dict[str, Any]]]: + """解密企业微信消息 + + Args: + encrypted_data: 加密的消息数据 + msg_signature: 消息签名 + timestamp: 时间戳 + nonce: 随机数 + + Returns: + (错误码, 解密后的消息数据字典) + """ + try: + ret, decrypted_msg = self.wxcpt.DecryptMsg( + encrypted_data, msg_signature, timestamp, nonce + ) + + if ret != WecomAIBotConstants.SUCCESS: + logger.error(f"消息解密失败,错误码: {ret}") + return ret, None + + # 解析 JSON + if decrypted_msg: + try: + message_data = json.loads(decrypted_msg) + logger.debug(f"解密成功,消息内容: {message_data}") + return WecomAIBotConstants.SUCCESS, message_data + except json.JSONDecodeError as e: + logger.error(f"JSON 解析失败: {e}, 原始消息: {decrypted_msg}") + return WecomAIBotConstants.PARSE_XML_ERROR, None + else: + logger.error("解密消息为空") + return WecomAIBotConstants.DECRYPT_ERROR, None + + except Exception as e: + logger.error(f"解密过程发生异常: {e}") + return WecomAIBotConstants.DECRYPT_ERROR, None + + async def encrypt_message( + self, plain_message: str, nonce: str, timestamp: str + ) -> Optional[str]: + """加密消息 + + Args: + plain_message: 明文消息 + nonce: 随机数 + timestamp: 时间戳 + + Returns: + 加密后的消息,失败时返回 None + """ + try: + ret, encrypted_msg = self.wxcpt.EncryptMsg(plain_message, nonce, timestamp) + + if ret != WecomAIBotConstants.SUCCESS: + logger.error(f"消息加密失败,错误码: {ret}") + return None + + logger.debug("消息加密成功") + return encrypted_msg + + except Exception as e: + logger.error(f"加密过程发生异常: {e}") + return None + + def verify_url( + self, msg_signature: str, timestamp: str, nonce: str, echostr: str + ) -> str: + """验证回调 URL + + Args: + msg_signature: 消息签名 + timestamp: 时间戳 + nonce: 随机数 + echostr: 验证字符串 + + Returns: + 验证结果字符串 + """ + try: + ret, echo_result = self.wxcpt.VerifyURL( + msg_signature, timestamp, nonce, echostr + ) + + if ret != WecomAIBotConstants.SUCCESS: + logger.error(f"URL 验证失败,错误码: {ret}") + return "verify fail" + + logger.info("URL 验证成功") + return echo_result if echo_result else "verify fail" + + except Exception as e: + logger.error(f"URL 验证发生异常: {e}") + return "verify fail" + + async def process_encrypted_image( + self, image_url: str, aes_key_base64: Optional[str] = None + ) -> Tuple[bool, Union[bytes, str]]: + """下载并解密加密图片 + + Args: + image_url: 加密图片的 URL + aes_key_base64: Base64 编码的 AES 密钥,如果为 None 则使用实例的密钥 + + Returns: + (是否成功, 图片数据或错误信息) + """ + try: + # 下载图片 + logger.info(f"开始下载加密图片: {image_url}") + + async with aiohttp.ClientSession() as session: + async with session.get(image_url, timeout=15) as response: + if response.status != 200: + error_msg = f"图片下载失败,状态码: {response.status}" + logger.error(error_msg) + return False, error_msg + + encrypted_data = await response.read() + logger.info(f"图片下载成功,大小: {len(encrypted_data)} 字节") + + # 准备解密密钥 + if aes_key_base64 is None: + aes_key_base64 = self.encoding_aes_key + + if not aes_key_base64: + raise ValueError("AES 密钥不能为空") + + # Base64 解码密钥 + aes_key = base64.b64decode( + aes_key_base64 + "=" * (-len(aes_key_base64) % 4) + ) + if len(aes_key) != 32: + raise ValueError("无效的 AES 密钥长度: 应为 32 字节") + + iv = aes_key[:16] # 初始向量为密钥前 16 字节 + + # 解密图片数据 + cipher = AES.new(aes_key, AES.MODE_CBC, iv) + decrypted_data = cipher.decrypt(encrypted_data) + + # 去除 PKCS#7 填充 + pad_len = decrypted_data[-1] + if pad_len > 32: # AES-256 块大小为 32 字节 + raise ValueError("无效的填充长度 (大于32字节)") + + decrypted_data = decrypted_data[:-pad_len] + logger.info(f"图片解密成功,解密后大小: {len(decrypted_data)} 字节") + + return True, decrypted_data + + except aiohttp.ClientError as e: + error_msg = f"图片下载失败: {str(e)}" + logger.error(error_msg) + return False, error_msg + + except ValueError as e: + error_msg = f"参数错误: {str(e)}" + logger.error(error_msg) + return False, error_msg + + except Exception as e: + error_msg = f"图片处理异常: {str(e)}" + logger.error(error_msg) + return False, error_msg + + +class WecomAIBotStreamMessageBuilder: + """企业微信智能机器人流消息构建器""" + + @staticmethod + def make_text_stream(stream_id: str, content: str, finish: bool = False) -> str: + """构建文本流消息 + + Args: + stream_id: 流 ID + content: 文本内容 + finish: 是否结束 + + Returns: + JSON 格式的流消息字符串 + """ + plain = { + "msgtype": WecomAIBotConstants.MSG_TYPE_STREAM, + "stream": {"id": stream_id, "finish": finish, "content": content}, + } + return json.dumps(plain, ensure_ascii=False) + + @staticmethod + def make_image_stream( + stream_id: str, image_data: bytes, finish: bool = False + ) -> str: + """构建图片流消息 + + Args: + stream_id: 流 ID + image_data: 图片二进制数据 + finish: 是否结束 + + Returns: + JSON 格式的流消息字符串 + """ + image_md5 = hashlib.md5(image_data).hexdigest() + image_base64 = base64.b64encode(image_data).decode("utf-8") + + plain = { + "msgtype": WecomAIBotConstants.MSG_TYPE_STREAM, + "stream": { + "id": stream_id, + "finish": finish, + "msg_item": [ + { + "msgtype": WecomAIBotConstants.MSG_TYPE_IMAGE, + "image": {"base64": image_base64, "md5": image_md5}, + } + ], + }, + } + return json.dumps(plain, ensure_ascii=False) + + @staticmethod + def make_mixed_stream( + stream_id: str, content: str, msg_items: list, finish: bool = False + ) -> str: + """构建混合类型流消息 + + Args: + stream_id: 流 ID + content: 文本内容 + msg_items: 消息项列表 + finish: 是否结束 + + Returns: + JSON 格式的流消息字符串 + """ + plain = { + "msgtype": WecomAIBotConstants.MSG_TYPE_STREAM, + "stream": {"id": stream_id, "finish": finish, "msg_item": msg_items}, + } + if content: + plain["stream"]["content"] = content + return json.dumps(plain, ensure_ascii=False) + + @staticmethod + def make_text(content: str) -> str: + """构建文本消息 + + Args: + content: 文本内容 + + Returns: + JSON 格式的文本消息字符串 + """ + plain = {"msgtype": "text", "text": {"content": content}} + return json.dumps(plain, ensure_ascii=False) + + +class WecomAIBotMessageParser: + """企业微信智能机器人消息解析器""" + + @staticmethod + def parse_text_message(data: Dict[str, Any]) -> Optional[str]: + """解析文本消息 + + Args: + data: 消息数据 + + Returns: + 文本内容,解析失败返回 None + """ + try: + return data.get("text", {}).get("content") + except (KeyError, TypeError): + logger.warning("文本消息解析失败") + return None + + @staticmethod + def parse_image_message(data: Dict[str, Any]) -> Optional[str]: + """解析图片消息 + + Args: + data: 消息数据 + + Returns: + 图片 URL,解析失败返回 None + """ + try: + return data.get("image", {}).get("url") + except (KeyError, TypeError): + logger.warning("图片消息解析失败") + return None + + @staticmethod + def parse_stream_message(data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """解析流消息 + + Args: + data: 消息数据 + + Returns: + 流消息数据,解析失败返回 None + """ + try: + stream_data = data.get("stream", {}) + return { + "id": stream_data.get("id"), + "finish": stream_data.get("finish"), + "content": stream_data.get("content"), + "msg_item": stream_data.get("msg_item", []), + } + except (KeyError, TypeError): + logger.warning("流消息解析失败") + return None + + @staticmethod + def parse_mixed_message(data: Dict[str, Any]) -> Optional[list]: + """解析混合消息 + + Args: + data: 消息数据 + + Returns: + 消息项列表,解析失败返回 None + """ + try: + return data.get("mixed", {}).get("msg_item", []) + except (KeyError, TypeError): + logger.warning("混合消息解析失败") + return None + + @staticmethod + def parse_event_message(data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """解析事件消息 + + Args: + data: 消息数据 + + Returns: + 事件数据,解析失败返回 None + """ + try: + return data.get("event", {}) + except (KeyError, TypeError): + logger.warning("事件消息解析失败") + return None diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py new file mode 100644 index 000000000..2d7ec91ca --- /dev/null +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py @@ -0,0 +1,149 @@ +""" +企业微信智能机器人事件处理模块,处理消息事件的发送和接收 +""" + +from astrbot.api.event import AstrMessageEvent, MessageChain +from astrbot.api.message_components import ( + Image, + Plain, +) +from astrbot.api import logger + +from .wecomai_api import WecomAIBotAPIClient +from .wecomai_queue_mgr import wecomai_queue_mgr + + +class WecomAIBotMessageEvent(AstrMessageEvent): + """企业微信智能机器人消息事件""" + + def __init__( + self, + message_str: str, + message_obj, + platform_meta, + session_id: str, + api_client: WecomAIBotAPIClient, + ): + """初始化消息事件 + + Args: + message_str: 消息字符串 + message_obj: 消息对象 + platform_meta: 平台元数据 + session_id: 会话 ID + api_client: API 客户端 + """ + super().__init__(message_str, message_obj, platform_meta, session_id) + self.api_client = api_client + + @staticmethod + async def _send( + message_chain: MessageChain, + stream_id: str, + streaming: bool = False, + ): + back_queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id) + + if not message_chain: + await back_queue.put( + { + "type": "end", + "data": "", + "streaming": False, + } + ) + return "" + + data = "" + for comp in message_chain.chain: + if isinstance(comp, Plain): + data = comp.text + await back_queue.put( + { + "type": "plain", + "data": data, + "streaming": streaming, + "session_id": stream_id, + } + ) + elif isinstance(comp, Image): + # 处理图片消息 + try: + image_base64 = await comp.convert_to_base64() + if image_base64: + await back_queue.put( + { + "type": "image", + "image_data": image_base64, + "streaming": streaming, + "session_id": stream_id, + } + ) + else: + logger.warning("图片数据为空,跳过") + except Exception as e: + logger.error("处理图片消息失败: %s", e) + else: + logger.warning(f"[WecomAI] 不支持的消息组件类型: {type(comp)}, 跳过") + + return data + + async def send(self, message: MessageChain): + """发送消息""" + raw = self.message_obj.raw_message + assert isinstance(raw, dict), ( + "wecom_ai_bot platform event raw_message should be a dict" + ) + stream_id = raw.get("stream_id", self.session_id) + await WecomAIBotMessageEvent._send(message, stream_id) + await super().send(message) + + async def send_streaming(self, generator, use_fallback=False): + """流式发送消息,参考webchat的send_streaming设计""" + final_data = "" + raw = self.message_obj.raw_message + assert isinstance(raw, dict), ( + "wecom_ai_bot platform event raw_message should be a dict" + ) + stream_id = raw.get("stream_id", self.session_id) + back_queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id) + + # 企业微信智能机器人不支持增量发送,因此我们需要在这里将增量内容累积起来,积累发送 + increment_plain = "" + async for chain in generator: + # 累积增量内容,并改写 Plain 段 + chain.squash_plain() + for comp in chain.chain: + if isinstance(comp, Plain): + comp.text = increment_plain + comp.text + increment_plain = comp.text + break + + if chain.type == "break" and final_data: + # 分割符 + await back_queue.put( + { + "type": "break", # break means a segment end + "data": final_data, + "streaming": True, + "session_id": self.session_id, + } + ) + final_data = "" + continue + + final_data += await WecomAIBotMessageEvent._send( + chain, + stream_id=stream_id, + streaming=True, + ) + + await back_queue.put( + { + "type": "complete", # complete means we return the final result + "data": final_data, + "streaming": True, + "session_id": self.session_id, + } + ) + await super().send_streaming(generator, use_fallback) diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py new file mode 100644 index 000000000..1367301c9 --- /dev/null +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py @@ -0,0 +1,148 @@ +""" +企业微信智能机器人队列管理器 +参考 webchat_queue_mgr.py,为企业微信智能机器人实现队列机制 +支持异步消息处理和流式响应 +""" + +import asyncio +from typing import Dict, Any, Optional +from astrbot.api import logger + + +class WecomAIQueueMgr: + """企业微信智能机器人队列管理器""" + + def __init__(self) -> None: + self.queues: Dict[str, asyncio.Queue] = {} + """StreamID 到输入队列的映射 - 用于接收用户消息""" + + self.back_queues: Dict[str, asyncio.Queue] = {} + """StreamID 到输出队列的映射 - 用于发送机器人响应""" + + self.pending_responses: Dict[str, Dict[str, Any]] = {} + """待处理的响应缓存,用于流式响应""" + + def get_or_create_queue(self, session_id: str) -> asyncio.Queue: + """获取或创建指定会话的输入队列 + + Args: + session_id: 会话ID + + Returns: + 输入队列实例 + """ + if session_id not in self.queues: + self.queues[session_id] = asyncio.Queue() + logger.debug(f"[WecomAI] 创建输入队列: {session_id}") + return self.queues[session_id] + + def get_or_create_back_queue(self, session_id: str) -> asyncio.Queue: + """获取或创建指定会话的输出队列 + + Args: + session_id: 会话ID + + Returns: + 输出队列实例 + """ + if session_id not in self.back_queues: + self.back_queues[session_id] = asyncio.Queue() + logger.debug(f"[WecomAI] 创建输出队列: {session_id}") + return self.back_queues[session_id] + + def remove_queues(self, session_id: str): + """移除指定会话的所有队列 + + Args: + session_id: 会话ID + """ + if session_id in self.queues: + del self.queues[session_id] + logger.debug(f"[WecomAI] 移除输入队列: {session_id}") + + if session_id in self.back_queues: + del self.back_queues[session_id] + logger.debug(f"[WecomAI] 移除输出队列: {session_id}") + + if session_id in self.pending_responses: + del self.pending_responses[session_id] + logger.debug(f"[WecomAI] 移除待处理响应: {session_id}") + + def has_queue(self, session_id: str) -> bool: + """检查是否存在指定会话的队列 + + Args: + session_id: 会话ID + + Returns: + 是否存在队列 + """ + return session_id in self.queues + + def has_back_queue(self, session_id: str) -> bool: + """检查是否存在指定会话的输出队列 + + Args: + session_id: 会话ID + + Returns: + 是否存在输出队列 + """ + return session_id in self.back_queues + + def set_pending_response(self, session_id: str, callback_params: Dict[str, str]): + """设置待处理的响应参数 + + Args: + session_id: 会话ID + callback_params: 回调参数(nonce, timestamp等) + """ + self.pending_responses[session_id] = { + "callback_params": callback_params, + "timestamp": asyncio.get_event_loop().time(), + } + logger.debug(f"[WecomAI] 设置待处理响应: {session_id}") + + def get_pending_response(self, session_id: str) -> Optional[Dict[str, Any]]: + """获取待处理的响应参数 + + Args: + session_id: 会话ID + + Returns: + 响应参数,如果不存在则返回None + """ + return self.pending_responses.get(session_id) + + def cleanup_expired_responses(self, max_age_seconds: int = 300): + """清理过期的待处理响应 + + Args: + max_age_seconds: 最大存活时间(秒) + """ + current_time = asyncio.get_event_loop().time() + expired_sessions = [] + + for session_id, response_data in self.pending_responses.items(): + if current_time - response_data["timestamp"] > max_age_seconds: + expired_sessions.append(session_id) + + for session_id in expired_sessions: + del self.pending_responses[session_id] + logger.debug(f"[WecomAI] 清理过期响应: {session_id}") + + def get_stats(self) -> Dict[str, int]: + """获取队列统计信息 + + Returns: + 统计信息字典 + """ + return { + "input_queues": len(self.queues), + "output_queues": len(self.back_queues), + "pending_responses": len(self.pending_responses), + } + + +# 全局队列管理器实例 +wecomai_queue_mgr = WecomAIQueueMgr() diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py new file mode 100644 index 000000000..bbb69d041 --- /dev/null +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py @@ -0,0 +1,166 @@ +""" +企业微信智能机器人 HTTP 服务器 +处理企业微信智能机器人的 HTTP 回调请求 +""" + +import asyncio +from typing import Dict, Any, Optional, Callable + +import quart +from astrbot.api import logger + +from .wecomai_api import WecomAIBotAPIClient +from .wecomai_utils import WecomAIBotConstants + + +class WecomAIBotServer: + """企业微信智能机器人 HTTP 服务器""" + + def __init__( + self, + host: str, + port: int, + api_client: WecomAIBotAPIClient, + message_handler: Optional[ + Callable[[Dict[str, Any], Dict[str, str]], Any] + ] = None, + ): + """初始化服务器 + + Args: + host: 监听地址 + port: 监听端口 + api_client: API客户端实例 + message_handler: 消息处理回调函数 + """ + self.host = host + self.port = port + self.api_client = api_client + self.message_handler = message_handler + + self.app = quart.Quart(__name__) + self._setup_routes() + + self.shutdown_event = asyncio.Event() + + def _setup_routes(self): + """设置 Quart 路由""" + + # 使用 Quart 的 add_url_rule 方法添加路由 + self.app.add_url_rule( + "/webhook/wecom-ai-bot", + view_func=self.verify_url, + methods=["GET"], + ) + + self.app.add_url_rule( + "/webhook/wecom-ai-bot", + view_func=self.handle_message, + methods=["POST"], + ) + + async def verify_url(self): + """验证回调 URL""" + args = quart.request.args + msg_signature = args.get("msg_signature") + timestamp = args.get("timestamp") + nonce = args.get("nonce") + echostr = args.get("echostr") + + if not all([msg_signature, timestamp, nonce, echostr]): + logger.error("URL 验证参数缺失") + return "verify fail", 400 + + # 类型检查确保不为 None + assert msg_signature is not None + assert timestamp is not None + assert nonce is not None + assert echostr is not None + + logger.info("收到企业微信智能机器人 WebHook URL 验证请求。") + result = self.api_client.verify_url(msg_signature, timestamp, nonce, echostr) + return result, 200, {"Content-Type": "text/plain"} + + async def handle_message(self): + """处理消息回调""" + args = quart.request.args + msg_signature = args.get("msg_signature") + timestamp = args.get("timestamp") + nonce = args.get("nonce") + + if not all([msg_signature, timestamp, nonce]): + logger.error("消息回调参数缺失") + return "缺少必要参数", 400 + + # 类型检查确保不为 None + assert msg_signature is not None + assert timestamp is not None + assert nonce is not None + + logger.debug( + f"收到消息回调,msg_signature={msg_signature}, timestamp={timestamp}, nonce={nonce}" + ) + + try: + # 获取请求体 + post_data = await quart.request.get_data() + + # 确保 post_data 是 bytes 类型 + if isinstance(post_data, str): + post_data = post_data.encode("utf-8") + + # 解密消息 + ret_code, message_data = await self.api_client.decrypt_message( + post_data, msg_signature, timestamp, nonce + ) + + if ret_code != WecomAIBotConstants.SUCCESS or not message_data: + logger.error("消息解密失败,错误码: %d", ret_code) + return "消息解密失败", 400 + + # 调用消息处理器 + response = None + if self.message_handler: + try: + response = await self.message_handler( + message_data, {"nonce": nonce, "timestamp": timestamp} + ) + except Exception as e: + logger.error("消息处理器执行异常: %s", e) + return "消息处理异常", 500 + + if response: + return response, 200, {"Content-Type": "text/plain"} + else: + return "success", 200, {"Content-Type": "text/plain"} + + except Exception as e: + logger.error("处理消息时发生异常: %s", e) + return "内部服务器错误", 500 + + async def start_server(self): + """启动服务器""" + logger.info("启动企业微信智能机器人服务器,监听 %s:%d", self.host, self.port) + + try: + await self.app.run_task( + host=self.host, + port=self.port, + shutdown_trigger=self.shutdown_trigger, + ) + except Exception as e: + logger.error("服务器运行异常: %s", e) + raise + + async def shutdown_trigger(self): + """关闭触发器""" + await self.shutdown_event.wait() + + async def shutdown(self): + """关闭服务器""" + logger.info("企业微信智能机器人服务器正在关闭...") + self.shutdown_event.set() + + def get_app(self): + """获取 Quart 应用实例""" + return self.app diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py new file mode 100644 index 000000000..dccb2e260 --- /dev/null +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py @@ -0,0 +1,199 @@ +""" +企业微信智能机器人工具模块 +提供常量定义、工具函数和辅助方法 +""" + +import string +import random +import hashlib +import base64 +import aiohttp +import asyncio +from Crypto.Cipher import AES +from typing import Any, Tuple +from astrbot.api import logger + + +# 常量定义 +class WecomAIBotConstants: + """企业微信智能机器人常量""" + + # 消息类型 + MSG_TYPE_TEXT = "text" + MSG_TYPE_IMAGE = "image" + MSG_TYPE_MIXED = "mixed" + MSG_TYPE_STREAM = "stream" + MSG_TYPE_EVENT = "event" + + # 流消息状态 + STREAM_CONTINUE = False + STREAM_FINISH = True + + # 错误码 + SUCCESS = 0 + DECRYPT_ERROR = -40001 + VALIDATE_SIGNATURE_ERROR = -40002 + PARSE_XML_ERROR = -40003 + COMPUTE_SIGNATURE_ERROR = -40004 + ILLEGAL_AES_KEY = -40005 + VALIDATE_APPID_ERROR = -40006 + ENCRYPT_AES_ERROR = -40007 + ILLEGAL_BUFFER = -40008 + + +def generate_random_string(length: int = 10) -> str: + """生成随机字符串 + + Args: + length: 字符串长度,默认为 10 + + Returns: + 随机字符串 + """ + letters = string.ascii_letters + string.digits + return "".join(random.choice(letters) for _ in range(length)) + + +def calculate_image_md5(image_data: bytes) -> str: + """计算图片数据的 MD5 值 + + Args: + image_data: 图片二进制数据 + + Returns: + MD5 哈希值(十六进制字符串) + """ + return hashlib.md5(image_data).hexdigest() + + +def encode_image_base64(image_data: bytes) -> str: + """将图片数据编码为 Base64 + + Args: + image_data: 图片二进制数据 + + Returns: + Base64 编码的字符串 + """ + return base64.b64encode(image_data).decode("utf-8") + + +def format_session_id(session_type: str, session_id: str) -> str: + """格式化会话 ID + + Args: + session_type: 会话类型 ("user", "group") + session_id: 原始会话 ID + + Returns: + 格式化后的会话 ID + """ + return f"wecom_ai_bot_{session_type}_{session_id}" + + +def parse_session_id(formatted_session_id: str) -> Tuple[str, str]: + """解析格式化的会话 ID + + Args: + formatted_session_id: 格式化的会话 ID + + Returns: + (会话类型, 原始会话ID) + """ + parts = formatted_session_id.split("_", 3) + if ( + len(parts) >= 4 + and parts[0] == "wecom" + and parts[1] == "ai" + and parts[2] == "bot" + ): + return parts[3], "_".join(parts[4:]) if len(parts) > 4 else "" + return "user", formatted_session_id + + +def safe_json_loads(json_str: str, default: Any = None) -> Any: + """安全地解析 JSON 字符串 + + Args: + json_str: JSON 字符串 + default: 解析失败时的默认值 + + Returns: + 解析结果或默认值 + """ + import json + + try: + return json.loads(json_str) + except (json.JSONDecodeError, TypeError) as e: + logger.warning(f"JSON 解析失败: {e}, 原始字符串: {json_str}") + return default + + +def format_error_response(error_code: int, error_msg: str) -> str: + """格式化错误响应 + + Args: + error_code: 错误码 + error_msg: 错误信息 + + Returns: + 格式化的错误响应字符串 + """ + return f"Error {error_code}: {error_msg}" + + +async def process_encrypted_image( + image_url: str, aes_key_base64: str +) -> Tuple[bool, str]: + """下载并解密加密图片 + + Args: + image_url: 加密图片的URL + aes_key_base64: Base64编码的AES密钥(与回调加解密相同) + + Returns: + Tuple[bool, str]: status 为 True 时 data 是解密后的图片数据的 base64 编码, + status 为 False 时 data 是错误信息 + """ + # 1. 下载加密图片 + logger.info("开始下载加密图片: %s", image_url) + try: + async with aiohttp.ClientSession() as session: + async with session.get(image_url, timeout=15) as response: + response.raise_for_status() + encrypted_data = await response.read() + logger.info("图片下载成功,大小: %d 字节", len(encrypted_data)) + except (aiohttp.ClientError, asyncio.TimeoutError) as e: + error_msg = f"下载图片失败: {str(e)}" + logger.error(error_msg) + return False, error_msg + + # 2. 准备AES密钥和IV + if not aes_key_base64: + raise ValueError("AES密钥不能为空") + + # Base64解码密钥 (自动处理填充) + aes_key = base64.b64decode(aes_key_base64 + "=" * (-len(aes_key_base64) % 4)) + if len(aes_key) != 32: + raise ValueError("无效的AES密钥长度: 应为32字节") + + iv = aes_key[:16] # 初始向量为密钥前16字节 + + # 3. 解密图片数据 + cipher = AES.new(aes_key, AES.MODE_CBC, iv) + decrypted_data = cipher.decrypt(encrypted_data) + + # 4. 去除PKCS#7填充 (Python 3兼容写法) + pad_len = decrypted_data[-1] # 直接获取最后一个字节的整数值 + if pad_len > 32: # AES-256块大小为32字节 + raise ValueError("无效的填充长度 (大于32字节)") + + decrypted_data = decrypted_data[:-pad_len] + logger.info("图片解密成功,解密后大小: %d 字节", len(decrypted_data)) + + # 5. 转换为base64编码 + base64_data = base64.b64encode(decrypted_data).decode("utf-8") + logger.info("图片已转换为base64编码,编码后长度: %d", len(base64_data)) + + return True, base64_data diff --git a/dashboard/src/utils/platformUtils.js b/dashboard/src/utils/platformUtils.js index 2656c56c7..660ed7812 100644 --- a/dashboard/src/utils/platformUtils.js +++ b/dashboard/src/utils/platformUtils.js @@ -10,7 +10,7 @@ export function getPlatformIcon(name) { if (name === 'aiocqhttp' || name === 'qq_official' || name === 'qq_official_webhook') { return new URL('@/assets/images/platform_logos/qq.png', import.meta.url).href - } else if (name === 'wecom') { + } else if (name === 'wecom' || name === 'wecom_ai_bot') { return new URL('@/assets/images/platform_logos/wecom.png', import.meta.url).href } else if (name === 'wechatpadpro' || name === 'weixin_official_account' || name === 'wechat') { return new URL('@/assets/images/platform_logos/wechat.png', import.meta.url).href