From e0fe798e08a23d54957ff140500635a6afe2d7f0 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Tue, 14 Oct 2025 11:28:24 +0800 Subject: [PATCH 1/4] stage --- astrbot/core/config/default.py | 9 + .../sources/wecom_ai_bot/WXBizJsonMsgCrypt.py | 289 ++++++++++++++ .../platform/sources/wecom_ai_bot/__init__.py | 17 + .../sources/wecom_ai_bot/demo/demo_server.py | 326 +++++++++++++++ .../platform/sources/wecom_ai_bot/ierror.py | 20 + .../sources/wecom_ai_bot/wecomai_adapter.py | 376 ++++++++++++++++++ .../sources/wecom_ai_bot/wecomai_api.py | 362 +++++++++++++++++ .../sources/wecom_ai_bot/wecomai_event.py | 182 +++++++++ .../sources/wecom_ai_bot/wecomai_queue_mgr.py | 148 +++++++ .../sources/wecom_ai_bot/wecomai_server.py | 172 ++++++++ .../sources/wecom_ai_bot/wecomai_utils.py | 173 ++++++++ pyproject.toml | 1 + 12 files changed, 2075 insertions(+) create mode 100644 astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py create mode 100644 astrbot/core/platform/sources/wecom_ai_bot/__init__.py create mode 100644 astrbot/core/platform/sources/wecom_ai_bot/demo/demo_server.py create mode 100644 astrbot/core/platform/sources/wecom_ai_bot/ierror.py create mode 100644 astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py create mode 100644 astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py create mode 100644 astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py create mode 100644 astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py create mode 100644 astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py create mode 100644 astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 8831c7e14..7e68127ec 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -207,6 +207,15 @@ "callback_server_host": "0.0.0.0", "port": 6195, }, + "企业微信智能机器人": { + "id": "wecom_ai_bot", + "type": "wecom_ai_bot", + "enable": True, + "token": "", + "encoding_aes_key": "", + "callback_server_host": "0.0.0.0", + "port": 6198, + }, "飞书(Lark)": { "id": "lark", "type": "lark", diff --git a/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py b/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py new file mode 100644 index 000000000..88963b1a0 --- /dev/null +++ b/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py @@ -0,0 +1,289 @@ +#!/usr/bin/env python +# -*- encoding:utf-8 -*- + +"""对企业微信发送给企业后台的消息加解密示例代码. +@copyright: Copyright (c) 1998-2020 Tencent Inc. + +""" +# ------------------------------------------------------------------------ + +import logging +import base64 +import random +import hashlib +import time +import struct +from Crypto.Cipher import AES +import socket +import json + +import ierror + +""" +关于Crypto.Cipher模块,ImportError: No module named 'Crypto'解决方案 +请到官方网站 https://www.dlitz.net/software/pycrypto/ 下载pycrypto。 +下载后,按照README中的“Installation”小节的提示进行pycrypto安装。 +""" + + +class FormatException(Exception): + pass + + +def throw_exception(message, exception_class=FormatException): + """my define raise exception function""" + raise exception_class(message) + + +class SHA1: + """计算企业微信的消息签名接口""" + + def getSHA1(self, token, timestamp, nonce, encrypt): + """用SHA1算法生成安全签名 + @param token: 票据 + @param timestamp: 时间戳 + @param encrypt: 密文 + @param nonce: 随机字符串 + @return: 安全签名 + """ + try: + # 确保所有输入都是字符串类型 + if isinstance(encrypt, bytes): + encrypt = encrypt.decode("utf-8") + + sortlist = [str(token), str(timestamp), str(nonce), str(encrypt)] + sortlist.sort() + sha = hashlib.sha1() + sha.update("".join(sortlist).encode("utf-8")) + return ierror.WXBizMsgCrypt_OK, sha.hexdigest() + + except Exception as e: + print(e) + return ierror.WXBizMsgCrypt_ComputeSignature_Error, None + + +class JsonParse: + """提供提取消息格式中的密文及生成回复消息格式的接口""" + + # json消息模板 + AES_TEXT_RESPONSE_TEMPLATE = """{ + "encrypt": "%(msg_encrypt)s", + "msgsignature": "%(msg_signaturet)s", + "timestamp": "%(timestamp)s", + "nonce": "%(nonce)s" + }""" + + def extract(self, jsontext): + """提取出json数据包中的加密消息 + @param jsontext: 待提取的json字符串 + @return: 提取出的加密消息字符串 + """ + try: + json_dict = json.loads(jsontext) + return ierror.WXBizMsgCrypt_OK, json_dict["encrypt"] + except Exception as e: + print(e) + return ierror.WXBizMsgCrypt_ParseJson_Error, None + + def generate(self, encrypt, signature, timestamp, nonce): + """生成json消息 + @param encrypt: 加密后的消息密文 + @param signature: 安全签名 + @param timestamp: 时间戳 + @param nonce: 随机字符串 + @return: 生成的json字符串 + """ + resp_dict = { + "msg_encrypt": encrypt, + "msg_signaturet": signature, + "timestamp": timestamp, + "nonce": nonce, + } + resp_json = self.AES_TEXT_RESPONSE_TEMPLATE % resp_dict + return resp_json + + +class PKCS7Encoder: + """提供基于PKCS7算法的加解密接口""" + + block_size = 32 + + def encode(self, text): + """对需要加密的明文进行填充补位 + @param text: 需要进行填充补位操作的明文(bytes类型) + @return: 补齐明文字符串(bytes类型) + """ + text_length = len(text) + # 计算需要填充的位数 + amount_to_pad = self.block_size - (text_length % self.block_size) + if amount_to_pad == 0: + amount_to_pad = self.block_size + # 获得补位所用的字符 + pad = bytes([amount_to_pad]) + # 确保text是bytes类型 + if isinstance(text, str): + text = text.encode("utf-8") + return text + pad * amount_to_pad + + def decode(self, decrypted): + """删除解密后明文的补位字符 + @param decrypted: 解密后的明文 + @return: 删除补位字符后的明文 + """ + pad = ord(decrypted[-1]) + if pad < 1 or pad > 32: + pad = 0 + return decrypted[:-pad] + + +class Prpcrypt(object): + """提供接收和推送给企业微信消息的加解密接口""" + + def __init__(self, key): + # self.key = base64.b64decode(key+"=") + self.key = key + # 设置加解密模式为AES的CBC模式 + self.mode = AES.MODE_CBC + + def encrypt(self, text, receiveid): + """对明文进行加密 + @param text: 需要加密的明文 + @return: 加密得到的字符串 + """ + # 16位随机字符串添加到明文开头 + text = text.encode() + text = ( + self.get_random_str() + + struct.pack("I", socket.htonl(len(text))) + + text + + receiveid.encode() + ) + + # 使用自定义的填充方式对明文进行补位填充 + pkcs7 = PKCS7Encoder() + text = pkcs7.encode(text) + # 加密 + cryptor = AES.new(self.key, self.mode, self.key[:16]) # type: ignore + try: + ciphertext = cryptor.encrypt(text) + # 使用BASE64对加密后的字符串进行编码 + return ierror.WXBizMsgCrypt_OK, base64.b64encode(ciphertext) + except Exception as e: + logger = logging.getLogger("astrbot") + logger.error(e) + return ierror.WXBizMsgCrypt_EncryptAES_Error, None + + def decrypt(self, text, receiveid): + """对解密后的明文进行补位删除 + @param text: 密文 + @return: 删除填充补位后的明文 + """ + try: + cryptor = AES.new(self.key, self.mode, self.key[:16]) # type: ignore + # 使用BASE64对密文进行解码,然后AES-CBC解密 + plain_text = cryptor.decrypt(base64.b64decode(text)) + except Exception as e: + print(e) + return ierror.WXBizMsgCrypt_DecryptAES_Error, None + try: + pad = plain_text[-1] + # 去掉补位字符串 + # pkcs7 = PKCS7Encoder() + # plain_text = pkcs7.encode(plain_text) + # 去除16位随机字符串 + content = plain_text[16:-pad] + json_len = socket.ntohl(struct.unpack("I", content[:4])[0]) + json_content = content[4 : json_len + 4].decode("utf-8") + from_receiveid = content[json_len + 4 :].decode("utf-8") + except Exception as e: + print(e) + return ierror.WXBizMsgCrypt_IllegalBuffer, None + if from_receiveid != receiveid: + print("receiveid not match", receiveid, from_receiveid) + return ierror.WXBizMsgCrypt_ValidateCorpid_Error, None + return 0, json_content + + def get_random_str(self): + """随机生成16位字符串 + @return: 16位字符串 + """ + return str(random.randint(1000000000000000, 9999999999999999)).encode() + + +class WXBizJsonMsgCrypt(object): + # 构造函数 + def __init__(self, sToken, sEncodingAESKey, sReceiveId): + try: + self.key = base64.b64decode(sEncodingAESKey + "=") + assert len(self.key) == 32 + except Exception as e: + throw_exception(f"[error]: EncodingAESKey invalid: {e}", FormatException) + # return ierror.WXBizMsgCrypt_IllegalAesKey,None + self.m_sToken = sToken + self.m_sReceiveId = sReceiveId + + # 验证URL + # @param sMsgSignature: 签名串,对应URL参数的msg_signature + # @param sTimeStamp: 时间戳,对应URL参数的timestamp + # @param sNonce: 随机串,对应URL参数的nonce + # @param sEchoStr: 随机串,对应URL参数的echostr + # @param sReplyEchoStr: 解密之后的echostr,当return返回0时有效 + # @return:成功0,失败返回对应的错误码 + + def VerifyURL(self, sMsgSignature, sTimeStamp, sNonce, sEchoStr): + sha1 = SHA1() + ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, sEchoStr) + if ret != 0: + return ret, None + if not signature == sMsgSignature: + return ierror.WXBizMsgCrypt_ValidateSignature_Error, None + pc = Prpcrypt(self.key) + ret, sReplyEchoStr = pc.decrypt(sEchoStr, self.m_sReceiveId) + return ret, sReplyEchoStr + + def EncryptMsg(self, sReplyMsg, sNonce, timestamp=None): + # 将企业回复用户的消息加密打包 + # @param sReplyMsg: 企业号待回复用户的消息,json格式的字符串 + # @param sTimeStamp: 时间戳,可以自己生成,也可以用URL参数的timestamp,如为None则自动用当前时间 + # @param sNonce: 随机串,可以自己生成,也可以用URL参数的nonce + # sEncryptMsg: 加密后的可以直接回复用户的密文,包括msg_signature, timestamp, nonce, encrypt的json格式的字符串, + # return:成功0,sEncryptMsg,失败返回对应的错误码None + pc = Prpcrypt(self.key) + ret, encrypt = pc.encrypt(sReplyMsg, self.m_sReceiveId) + encrypt = encrypt.decode("utf-8") # type: ignore + if ret != 0: + return ret, None + if timestamp is None: + timestamp = str(int(time.time())) + # 生成安全签名 + sha1 = SHA1() + ret, signature = sha1.getSHA1(self.m_sToken, timestamp, sNonce, encrypt) + if ret != 0: + return ret, None + jsonParse = JsonParse() + return ret, jsonParse.generate(encrypt, signature, timestamp, sNonce) + + def DecryptMsg(self, sPostData, sMsgSignature, sTimeStamp, sNonce): + # 检验消息的真实性,并且获取解密后的明文 + # @param sMsgSignature: 签名串,对应URL参数的msg_signature + # @param sTimeStamp: 时间戳,对应URL参数的timestamp + # @param sNonce: 随机串,对应URL参数的nonce + # @param sPostData: 密文,对应POST请求的数据 + # json_content: 解密后的原文,当return返回0时有效 + # @return: 成功0,失败返回对应的错误码 + # 验证安全签名 + jsonParse = JsonParse() + ret, encrypt = jsonParse.extract(sPostData) + if ret != 0: + return ret, None + sha1 = SHA1() + ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, encrypt) + if ret != 0: + return ret, None + if not signature == sMsgSignature: + print("signature not match") + print(signature) + return ierror.WXBizMsgCrypt_ValidateSignature_Error, None + pc = Prpcrypt(self.key) + ret, json_content = pc.decrypt(encrypt, self.m_sReceiveId) + return ret, json_content diff --git a/astrbot/core/platform/sources/wecom_ai_bot/__init__.py b/astrbot/core/platform/sources/wecom_ai_bot/__init__.py new file mode 100644 index 000000000..7da900030 --- /dev/null +++ b/astrbot/core/platform/sources/wecom_ai_bot/__init__.py @@ -0,0 +1,17 @@ +""" +企业微信智能机器人平台适配器包 +""" + +from .wecomai_adapter import WecomAIBotAdapter +from .wecomai_api import WecomAIBotAPIClient +from .wecomai_event import WecomAIBotMessageEvent +from .wecomai_server import WecomAIBotServer +from .wecomai_utils import WecomAIBotConstants + +__all__ = [ + "WecomAIBotAdapter", + "WecomAIBotAPIClient", + "WecomAIBotMessageEvent", + "WecomAIBotServer", + "WecomAIBotConstants", +] diff --git a/astrbot/core/platform/sources/wecom_ai_bot/demo/demo_server.py b/astrbot/core/platform/sources/wecom_ai_bot/demo/demo_server.py new file mode 100644 index 000000000..cd394290f --- /dev/null +++ b/astrbot/core/platform/sources/wecom_ai_bot/demo/demo_server.py @@ -0,0 +1,326 @@ +#!/usr/bin/env python +# coding=utf-8 +# 文档:https://developer.work.weixin.qq.com/document/path/101039 + +from fastapi import FastAPI, Request, HTTPException +from fastapi.responses import Response +import uvicorn +import os +import logging +import json +import random +import string +import time +import base64 +import hashlib +from astrbot.core.platform.sources.wecom_ai_bot.WXBizJsonMsgCrypt import ( + WXBizJsonMsgCrypt, +) +from Crypto.Cipher import AES +import requests + +app = FastAPI() + +# 常量定义 +CACHE_DIR = "/tmp/llm_demo_cache" +MAX_STEPS = 10 + +# 配置日志 +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +def _generate_random_string(length): + letters = string.ascii_letters + string.digits + return "".join(random.choice(letters) for _ in range(length)) + + +def _process_encrypted_image(image_url, aes_key_base64): + """ + 下载并解密加密图片 + + 参数: + image_url: 加密图片的URL + aes_key_base64: Base64编码的AES密钥(与回调加解密相同) + + 返回: + tuple: (status: bool, data: bytes/str) + status为True时data是解密后的图片数据, + status为False时data是错误信息 + """ + try: + # 1. 下载加密图片 + logger.info("开始下载加密图片: %s", image_url) + response = requests.get(image_url, timeout=15) + response.raise_for_status() + encrypted_data = response.content + logger.info("图片下载成功,大小: %d 字节", len(encrypted_data)) + + # 2. 准备AES密钥和IV + if not aes_key_base64: + raise ValueError("AES密钥不能为空") + + # Base64解码密钥 (自动处理填充) + aes_key = base64.b64decode(aes_key_base64 + "=" * (-len(aes_key_base64) % 4)) + if len(aes_key) != 32: + raise ValueError("无效的AES密钥长度: 应为32字节") + + iv = aes_key[:16] # 初始向量为密钥前16字节 + + # 3. 解密图片数据 + cipher = AES.new(aes_key, AES.MODE_CBC, iv) + decrypted_data = cipher.decrypt(encrypted_data) + + # 4. 去除PKCS#7填充 (Python 3兼容写法) + pad_len = decrypted_data[-1] # 直接获取最后一个字节的整数值 + if pad_len > 32: # AES-256块大小为32字节 + raise ValueError("无效的填充长度 (大于32字节)") + + decrypted_data = decrypted_data[:-pad_len] + logger.info("图片解密成功,解密后大小: %d 字节", len(decrypted_data)) + + return True, decrypted_data + + except requests.exceptions.RequestException as e: + error_msg = f"图片下载失败 : {str(e)}" + logger.error(error_msg) + return False, error_msg + + except ValueError as e: + error_msg = f"参数错误 : {str(e)}" + logger.error(error_msg) + return False, error_msg + + except Exception as e: + error_msg = f"图片处理异常 : {str(e)}" + logger.error(error_msg) + return False, error_msg + + +def MakeTextStream(stream_id, content, finish): + plain = { + "msgtype": "stream", + "stream": {"id": stream_id, "finish": finish, "content": content}, + } + return json.dumps(plain, ensure_ascii=False) + + +def MakeImageStream(stream_id, image_data, finish): + image_md5 = hashlib.md5(image_data).hexdigest() + image_base64 = base64.b64encode(image_data).decode("utf-8") + + plain = { + "msgtype": "stream", + "stream": { + "id": stream_id, + "finish": finish, + "msg_item": [ + { + "msgtype": "image", + "image": {"base64": image_base64, "md5": image_md5}, + } + ], + }, + } + return json.dumps(plain) + + +def EncryptMessage(receiveid, nonce, timestamp, stream): + logger.info( + "开始加密消息,receiveid=%s, nonce=%s, timestamp=%s", + receiveid, + nonce, + timestamp, + ) + logger.debug("发送流消息: %s", stream) + + wxcpt = WXBizJsonMsgCrypt( + os.getenv("Token", ""), os.getenv("EncodingAESKey", ""), receiveid + ) + ret, resp = wxcpt.EncryptMsg(stream, nonce, timestamp) + if ret != 0: + logger.error("加密失败,错误码: %d", ret) + return + + stream_id = json.loads(stream)["stream"]["id"] + finish = json.loads(stream)["stream"]["finish"] + logger.info( + "回调处理完成, 返回加密的流消息, stream_id=%s, finish=%s", stream_id, finish + ) + logger.debug("加密后的消息: %s", resp) + + return resp + + +# TODO 这里模拟一个大模型的行为 +class LLMDemo: + def __init__(self): + self.cache_dir = CACHE_DIR + if not os.path.exists(self.cache_dir): + os.makedirs(self.cache_dir) + + def invoke(self, question): + stream_id = _generate_random_string(10) # 生成一个随机字符串作为任务ID + # 创建任务缓存文件 + cache_file = os.path.join(self.cache_dir, "%s.json" % stream_id) + with open(cache_file, "w", encoding="utf-8") as f: + json.dump( + { + "question": question, + "created_time": time.time(), + "current_step": 0, + "max_steps": MAX_STEPS, + }, + f, + ) + return stream_id + + def get_answer(self, stream_id): + cache_file = os.path.join(self.cache_dir, "%s.json" % stream_id) + if not os.path.exists(cache_file): + return "任务不存在或已过期" + + with open(cache_file, "r", encoding="utf-8") as f: + task_data = json.load(f) + + # 更新缓存 + current_step = task_data["current_step"] + 1 + task_data["current_step"] = current_step + with open(cache_file, "w", encoding="utf-8") as f: + json.dump(task_data, f) + + response = "收到问题:%s\n" % task_data["question"] + for i in range(current_step): + response += "处理步骤 %d: 已完成\n" % (i) + + return response + + def is_task_finish(self, stream_id): + cache_file = os.path.join(self.cache_dir, "%s.json" % stream_id) + if not os.path.exists(cache_file): + return True + + with open(cache_file, "r", encoding="utf-8") as f: + task_data = json.load(f) + + return task_data["current_step"] >= task_data["max_steps"] + + +@app.get("/ai-bot/callback/demo/{botid}") +async def verify_url( + request: Request, + botid: str, + msg_signature: str, + timestamp: str, + nonce: str, + echostr: str, +): + # 企业创建的自能机器人的 VerifyUrl 请求, receiveid 是空串 + receiveid = "" + wxcpt = WXBizJsonMsgCrypt( + os.getenv("Token", ""), os.getenv("EncodingAESKey", ""), receiveid + ) + + ret, echostr = wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr) + + if ret != 0: + echostr = "verify fail" + + return Response(content=echostr, media_type="text/plain") + + +@app.post("/ai-bot/callback/demo/{botid}") +async def handle_message( + request: Request, + botid: str, + msg_signature: str = None, + timestamp: str = None, + nonce: str = None, +): + query_params = dict(request.query_params) + if not all([msg_signature, timestamp, nonce]): + raise HTTPException(status_code=400, detail="缺少必要参数") + logger.info( + "收到消息,botid=%s, msg_signature=%s, timestamp=%s, nonce=%s", + botid, + msg_signature, + timestamp, + nonce, + ) + + post_data = await request.body() + + # 智能机器人的 receiveid 是空串 + receiveid = "" + wxcpt = WXBizJsonMsgCrypt( + os.getenv("Token", ""), os.getenv("EncodingAESKey", ""), receiveid + ) + + ret, msg = wxcpt.DecryptMsg(post_data, msg_signature, timestamp, nonce) + + if ret != 0: + raise HTTPException(status_code=400, detail="解密失败") + + data = json.loads(msg) + logger.debug("Decrypted data: %s", data) + if "msgtype" not in data: + logger.info("不认识的事件: %s", data) + return Response(content="success", media_type="text/plain") + + msgtype = data["msgtype"] + if msgtype == "text": + content = data["text"]["content"] + + # 询问大模型产生回复 + llm = LLMDemo() + stream_id = llm.invoke(content) + answer = llm.get_answer(stream_id) + finish = llm.is_task_finish(stream_id) + + stream = MakeTextStream(stream_id, answer, finish) + resp = EncryptMessage(receiveid, nonce, timestamp, stream) + return Response(content=resp, media_type="text/plain") + elif msgtype == "stream": # case stream + # 询问大模型最新的回复 + stream_id = data["stream"]["id"] + llm = LLMDemo() + answer = llm.get_answer(stream_id) + finish = llm.is_task_finish(stream_id) + + stream = MakeTextStream(stream_id, answer, finish) + resp = EncryptMessage(receiveid, nonce, timestamp, stream) + return Response(content=resp, media_type="text/plain") + elif msgtype == "image": + # 从环境变量获取AES密钥 + aes_key = os.getenv("EncodingAESKey", "") + + # 调用图片处理函数 + success, result = _process_encrypted_image(data["image"]["url"], aes_key) + if not success: + logger.error("图片处理失败: %s", result) + return + + # 这里简单处理直接原图回复 + decrypted_data = result + stream_id = _generate_random_string(10) + finish = True + + stream = MakeImageStream(stream_id, decrypted_data, finish) + resp = EncryptMessage(receiveid, nonce, timestamp, stream) + return Response(content=resp, media_type="text/plain") + elif msgtype == "mixed": + # TODO 处理图文混排消息 + logger.warning("需要支持mixed消息类型") + elif msgtype == "event": + # TODO 一些事件的处理 + logger.warning("需要支持event消息类型: %s", data) + return + else: + logger.warning("不支持的消息类型: %s", msgtype) + return + + +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=80) diff --git a/astrbot/core/platform/sources/wecom_ai_bot/ierror.py b/astrbot/core/platform/sources/wecom_ai_bot/ierror.py new file mode 100644 index 000000000..cc1bf221e --- /dev/null +++ b/astrbot/core/platform/sources/wecom_ai_bot/ierror.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +######################################################################### +# Author: jonyqin +# Created Time: Thu 11 Sep 2014 01:53:58 PM CST +# File Name: ierror.py +# Description:定义错误码含义 +######################################################################### +WXBizMsgCrypt_OK = 0 +WXBizMsgCrypt_ValidateSignature_Error = -40001 +WXBizMsgCrypt_ParseJson_Error = -40002 +WXBizMsgCrypt_ComputeSignature_Error = -40003 +WXBizMsgCrypt_IllegalAesKey = -40004 +WXBizMsgCrypt_ValidateCorpid_Error = -40005 +WXBizMsgCrypt_EncryptAES_Error = -40006 +WXBizMsgCrypt_DecryptAES_Error = -40007 +WXBizMsgCrypt_IllegalBuffer = -40008 +WXBizMsgCrypt_EncodeBase64_Error = -40009 +WXBizMsgCrypt_DecodeBase64_Error = -40010 +WXBizMsgCrypt_GenReturnJson_Error = -40011 diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py new file mode 100644 index 000000000..49dc9caeb --- /dev/null +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py @@ -0,0 +1,376 @@ +""" +企业微信智能机器人平台适配器 +基于企业微信智能机器人 API 的消息平台适配器,支持 HTTP 回调 +参考webchat_adapter.py的队列机制,实现异步消息处理和流式响应 +""" + +import time +import asyncio +import uuid +from typing import Awaitable, Any, Dict, Optional, Callable + + +from astrbot.api.platform import ( + Platform, + AstrBotMessage, + MessageMember, + MessageType, + PlatformMetadata, +) +from astrbot.api.event import MessageChain +from astrbot.api.message_components import Plain +from astrbot.api import logger +from astrbot.core.platform.astr_message_event import MessageSesion +from ...register import register_platform_adapter + +from .wecomai_api import ( + WecomAIBotAPIClient, + WecomAIBotMessageParser, + WecomAIBotStreamMessageBuilder, +) +from .wecomai_event import WecomAIBotMessageEvent +from .wecomai_server import WecomAIBotServer +from .wecomai_queue_mgr import wecomai_queue_mgr, WecomAIQueueMgr +from .wecomai_utils import ( + WecomAIBotConstants, + validate_config, + format_session_id, + generate_random_string, +) + + +class WecomAIQueueListener: + """企业微信智能机器人队列监听器,参考webchat的QueueListener设计""" + + def __init__( + self, queue_mgr: WecomAIQueueMgr, callback: Callable[[tuple], Awaitable[None]] + ) -> None: + self.queue_mgr = queue_mgr + self.callback = callback + self.running_tasks = set() + + async def listen_to_queue(self, session_id: str): + """监听特定会话的队列""" + queue = self.queue_mgr.get_or_create_queue(session_id) + while True: + try: + data = await queue.get() + await self.callback(data) + except Exception as e: + logger.error(f"处理会话 {session_id} 消息时发生错误: {e}") + break + + async def run(self): + """监控新会话队列并启动监听器""" + monitored_sessions = set() + + while True: + # 检查新会话 + current_sessions = set(self.queue_mgr.queues.keys()) + new_sessions = current_sessions - monitored_sessions + + # 为新会话启动监听器 + for session_id in new_sessions: + task = asyncio.create_task(self.listen_to_queue(session_id)) + self.running_tasks.add(task) + task.add_done_callback(self.running_tasks.discard) + monitored_sessions.add(session_id) + logger.debug(f"[WecomAI] 为会话启动监听器: {session_id}") + + # 清理已不存在的会话 + removed_sessions = monitored_sessions - current_sessions + monitored_sessions -= removed_sessions + + # 清理过期的待处理响应 + self.queue_mgr.cleanup_expired_responses() + + await asyncio.sleep(1) # 每秒检查一次新会话 + + +@register_platform_adapter( + "wecom_ai_bot", "企业微信智能机器人适配器,支持 HTTP 回调接收消息" +) +class WecomAIBotAdapter(Platform): + """企业微信智能机器人适配器""" + + def __init__( + self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue + ) -> None: + super().__init__(event_queue) + + self.config = platform_config + self.settings = platform_settings + + # 验证配置 + is_valid, error_msg = validate_config(self.config) + if not is_valid: + raise ValueError(f"配置验证失败: {error_msg}") + + # 初始化配置参数 + self.token = self.config["token"] + self.encoding_aes_key = self.config["encoding_aes_key"] + self.port = int(self.config["port"]) + self.host = self.config.get("callback_server_host", "0.0.0.0") + self.bot_id = self.config.get("bot_id", "default") + + # 平台元数据 + self.metadata = PlatformMetadata( + name="wecom_ai_bot", + description="企业微信智能机器人适配器,支持 HTTP 回调接收消息", + id=self.config.get("id", "wecom_ai_bot"), + ) + + # 初始化 API 客户端 + self.api_client = WecomAIBotAPIClient(self.token, self.encoding_aes_key) + + # 初始化 HTTP 服务器 + self.server = WecomAIBotServer( + host=self.host, + port=self.port, + bot_id=self.bot_id, + api_client=self.api_client, + message_handler=self._process_message, + ) + + # 事件循环和关闭信号 + self.shutdown_event = asyncio.Event() + + # 队列监听器 + self.queue_listener = WecomAIQueueListener( + wecomai_queue_mgr, self._handle_queued_message + ) + + async def _handle_queued_message(self, data: tuple): + """处理队列中的消息,类似webchat的callback""" + try: + abm = await self.convert_message(data) + await self.handle_msg(abm) + except Exception as e: + logger.error(f"处理队列消息时发生异常: {e}") + + async def _process_message( + self, message_data: Dict[str, Any], callback_params: Dict[str, str] + ) -> Optional[str]: + """处理接收到的消息 + + Args: + message_data: 解密后的消息数据 + callback_params: 回调参数 (nonce, timestamp) + + Returns: + 加密后的响应消息,无需响应时返回 None + """ + msgtype = message_data.get("msgtype") + if not msgtype: + logger.warning(f"消息类型未知,忽略: {message_data}") + return None + session_id = self._extract_session_id(message_data) + if msgtype == "text": + # user sent a text message + try: + # create a brand-new unique stream_id for this message session + stream_id = f"{session_id}_{generate_random_string(10)}" + await self._enqueue_message( + message_data, callback_params, stream_id, session_id + ) + wecomai_queue_mgr.set_pending_response(stream_id, callback_params) + + resp = WecomAIBotStreamMessageBuilder.make_text_stream( + stream_id, "思考中...", False + ) + return await self.api_client.encrypt_message( + resp, callback_params["nonce"], callback_params["timestamp"] + ) + except Exception as e: + logger.error("处理消息时发生异常: %s", e) + return None + elif msgtype == "stream": + # wechat server is requesting for updates of a stream + stream_id = message_data["stream"]["id"] + if not wecomai_queue_mgr.has_back_queue(stream_id): + logger.error(f"Cannot find back queue for stream_id: {stream_id}") + return None + queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id) + if queue.empty(): + logger.debug( + f"No new messages in back queue for stream_id: {stream_id}" + ) + return None + + # aggregate all delta messages in the back queue + aggregated_content = "" + finish = False + while not queue.empty(): + msg = await queue.get() + if msg["type"] == "plain": + aggregated_content += msg["data"] + elif msg["type"] == "end": + finish = True + else: + pass + logger.debug(f"Aggregated content: {aggregated_content}, finish: {finish}") + if aggregated_content: + plain_message = WecomAIBotStreamMessageBuilder.make_text_stream( + stream_id, aggregated_content, finish + ) + encrypted_message = await self.api_client.encrypt_message( + plain_message, + callback_params["nonce"], + callback_params["timestamp"], + ) + if encrypted_message: + logger.debug( + f"Stream message sent successfully, stream_id: {stream_id}" + ) + else: + logger.error("消息加密失败") + return encrypted_message + return None + elif msgtype == "image": + pass + + def _extract_session_id(self, message_data: Dict[str, Any]) -> str: + """从消息数据中提取会话ID""" + user_id = message_data.get("from", {}).get("userid", "default_user") + return format_session_id("wecomai", user_id) + + async def _enqueue_message( + self, + message_data: Dict[str, Any], + callback_params: Dict[str, str], + stream_id: str, + session_id: str, + ): + """将消息放入队列进行异步处理""" + input_queue = wecomai_queue_mgr.get_or_create_queue(stream_id) + _ = wecomai_queue_mgr.get_or_create_back_queue(stream_id) + message_payload = { + "message_data": message_data, + "callback_params": callback_params, + "session_id": session_id, + } + + await input_queue.put(("wecomai", stream_id, message_payload)) + logger.debug(f"[WecomAI] 消息已入队: {stream_id}") + + async def convert_message(self, data: tuple) -> AstrBotMessage: + """转换队列中的消息数据为AstrBotMessage,类似webchat的convert_message""" + platform, stream_id, payload = data + + message_data = payload["message_data"] + session_id = payload["session_id"] + # callback_params = payload["callback_params"] # 保留但暂时不使用 + + # 解析消息内容 + msgtype = message_data.get("msgtype") + content = "" + + if msgtype == WecomAIBotConstants.MSG_TYPE_TEXT: + content = WecomAIBotMessageParser.parse_text_message(message_data) + elif msgtype == WecomAIBotConstants.MSG_TYPE_IMAGE: + content = "[图片消息]" + elif msgtype == WecomAIBotConstants.MSG_TYPE_MIXED: + # 提取混合消息中的文本内容 + msg_items = WecomAIBotMessageParser.parse_mixed_message(message_data) + text_parts = [] + for item in msg_items or []: + if item.get("msgtype") == WecomAIBotConstants.MSG_TYPE_TEXT: + text_content = item.get("text", {}).get("content", "") + if text_content: + text_parts.append(text_content) + content = " ".join(text_parts) if text_parts else "[混合消息]" + else: + content = f"[{msgtype}消息]" + + # 构建AstrBotMessage + abm = AstrBotMessage() + abm.self_id = self.config.get("id", "wecom_ai_bot") + abm.message_str = content or "[未知消息]" + abm.message_id = str(uuid.uuid4()) + abm.timestamp = int(time.time()) + abm.raw_message = data + + # 发送者信息 + abm.sender = MessageMember( + user_id="wecom_user", + nickname="WeChat Work User", + ) + + # 消息类型 + abm.type = MessageType.FRIEND_MESSAGE + abm.session_id = session_id + + # 消息内容 + abm.message = [Plain(text=content or "[未知消息]")] + + logger.debug(f"WecomAIAdapter: {abm.message}") + return abm + + async def send_by_session( + self, session: MessageSesion, message_chain: MessageChain + ): + """通过会话发送消息""" + # 企业微信智能机器人主要通过回调响应,这里记录日志 + logger.info("会话发送消息: %s -> %s", session.session_id, message_chain) + await super().send_by_session(session, message_chain) + + def run(self) -> Awaitable[Any]: + """运行适配器,同时启动HTTP服务器和队列监听器""" + logger.info("启动企业微信智能机器人适配器,监听 %s:%d", self.host, self.port) + + async def run_both(): + # 同时运行HTTP服务器和队列监听器 + await asyncio.gather( + self.server.start_server(), + self.queue_listener.run(), + ) + + return run_both() + + async def terminate(self): + """终止适配器""" + logger.info("企业微信智能机器人适配器正在关闭...") + self.shutdown_event.set() + await self.server.shutdown() + + def meta(self) -> PlatformMetadata: + """获取平台元数据""" + return self.metadata + + async def handle_msg(self, message: AstrBotMessage): + """处理消息,创建消息事件并提交到事件队列""" + try: + # 从原始消息中获取回调参数 + if isinstance(message.raw_message, tuple) and len(message.raw_message) == 3: + _, _, payload = message.raw_message + callback_params = payload["callback_params"] + else: + # 如果没有有效的回调参数,使用默认值 + callback_params = { + "nonce": "default", + "timestamp": str(int(time.time())), + } + logger.warning("使用默认回调参数") + + message_event = WecomAIBotMessageEvent( + message_str=message.message_str, + message_obj=message, + platform_meta=self.meta(), + session_id=message.session_id, + api_client=self.api_client, + callback_params=callback_params, + ) + + self.commit_event(message_event) + logger.debug("消息事件已提交: %s", message.message_str) + + except Exception as e: + logger.error("处理消息时发生异常: %s", e) + + def get_client(self) -> WecomAIBotAPIClient: + """获取 API 客户端""" + return self.api_client + + def get_server(self) -> WecomAIBotServer: + """获取 HTTP 服务器实例""" + return self.server diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py new file mode 100644 index 000000000..73898ac4a --- /dev/null +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py @@ -0,0 +1,362 @@ +""" +企业微信智能机器人 API 客户端 +处理消息加密解密、API 调用等 +""" + +import json +import logging +import base64 +import hashlib +from typing import Dict, Any, Optional, Tuple, Union +from Crypto.Cipher import AES +import aiohttp + +from .WXBizJsonMsgCrypt import WXBizJsonMsgCrypt +from .wecomai_utils import WecomAIBotConstants + +logger = logging.getLogger(__name__) + + +class WecomAIBotAPIClient: + """企业微信智能机器人 API 客户端""" + + def __init__(self, token: str, encoding_aes_key: str): + """初始化 API 客户端 + + Args: + token: 企业微信机器人 Token + encoding_aes_key: 消息加密密钥 + """ + self.token = token + self.encoding_aes_key = encoding_aes_key + self.wxcpt = WXBizJsonMsgCrypt(token, encoding_aes_key, "") # receiveid 为空串 + + async def decrypt_message( + self, encrypted_data: bytes, msg_signature: str, timestamp: str, nonce: str + ) -> Tuple[int, Optional[Dict[str, Any]]]: + """解密企业微信消息 + + Args: + encrypted_data: 加密的消息数据 + msg_signature: 消息签名 + timestamp: 时间戳 + nonce: 随机数 + + Returns: + (错误码, 解密后的消息数据字典) + """ + try: + ret, decrypted_msg = self.wxcpt.DecryptMsg( + encrypted_data, msg_signature, timestamp, nonce + ) + + if ret != WecomAIBotConstants.SUCCESS: + logger.error(f"消息解密失败,错误码: {ret}") + return ret, None + + # 解析 JSON + if decrypted_msg: + try: + message_data = json.loads(decrypted_msg) + logger.debug("解密成功,消息内容: %s", message_data) + return WecomAIBotConstants.SUCCESS, message_data + except json.JSONDecodeError as e: + logger.error("JSON 解析失败: %s, 原始消息: %s", e, decrypted_msg) + return WecomAIBotConstants.PARSE_XML_ERROR, None + else: + logger.error("解密消息为空") + return WecomAIBotConstants.DECRYPT_ERROR, None + + except Exception as e: + logger.error(f"解密过程发生异常: {e}") + return WecomAIBotConstants.DECRYPT_ERROR, None + + async def encrypt_message( + self, plain_message: str, nonce: str, timestamp: str + ) -> Optional[str]: + """加密消息 + + Args: + plain_message: 明文消息 + nonce: 随机数 + timestamp: 时间戳 + + Returns: + 加密后的消息,失败时返回 None + """ + try: + ret, encrypted_msg = self.wxcpt.EncryptMsg(plain_message, nonce, timestamp) + + if ret != WecomAIBotConstants.SUCCESS: + logger.error(f"消息加密失败,错误码: {ret}") + return None + + logger.debug("消息加密成功") + return encrypted_msg + + except Exception as e: + logger.error(f"加密过程发生异常: {e}") + return None + + def verify_url( + self, msg_signature: str, timestamp: str, nonce: str, echostr: str + ) -> str: + """验证回调 URL + + Args: + msg_signature: 消息签名 + timestamp: 时间戳 + nonce: 随机数 + echostr: 验证字符串 + + Returns: + 验证结果字符串 + """ + try: + ret, echo_result = self.wxcpt.VerifyURL( + msg_signature, timestamp, nonce, echostr + ) + + if ret != WecomAIBotConstants.SUCCESS: + logger.error("URL 验证失败,错误码: %d", ret) + return "verify fail" + + logger.info("URL 验证成功") + return echo_result if echo_result else "verify fail" + + except Exception as e: + logger.error("URL 验证发生异常: %s", e) + return "verify fail" + + async def process_encrypted_image( + self, image_url: str, aes_key_base64: Optional[str] = None + ) -> Tuple[bool, Union[bytes, str]]: + """下载并解密加密图片 + + Args: + image_url: 加密图片的 URL + aes_key_base64: Base64 编码的 AES 密钥,如果为 None 则使用实例的密钥 + + Returns: + (是否成功, 图片数据或错误信息) + """ + try: + # 下载图片 + logger.info(f"开始下载加密图片: {image_url}") + + async with aiohttp.ClientSession() as session: + async with session.get(image_url, timeout=15) as response: + if response.status != 200: + error_msg = f"图片下载失败,状态码: {response.status}" + logger.error(error_msg) + return False, error_msg + + encrypted_data = await response.read() + logger.info(f"图片下载成功,大小: {len(encrypted_data)} 字节") + + # 准备解密密钥 + if aes_key_base64 is None: + aes_key_base64 = self.encoding_aes_key + + if not aes_key_base64: + raise ValueError("AES 密钥不能为空") + + # Base64 解码密钥 + aes_key = base64.b64decode( + aes_key_base64 + "=" * (-len(aes_key_base64) % 4) + ) + if len(aes_key) != 32: + raise ValueError("无效的 AES 密钥长度: 应为 32 字节") + + iv = aes_key[:16] # 初始向量为密钥前 16 字节 + + # 解密图片数据 + cipher = AES.new(aes_key, AES.MODE_CBC, iv) + decrypted_data = cipher.decrypt(encrypted_data) + + # 去除 PKCS#7 填充 + pad_len = decrypted_data[-1] + if pad_len > 32: # AES-256 块大小为 32 字节 + raise ValueError("无效的填充长度 (大于32字节)") + + decrypted_data = decrypted_data[:-pad_len] + logger.info(f"图片解密成功,解密后大小: {len(decrypted_data)} 字节") + + return True, decrypted_data + + except aiohttp.ClientError as e: + error_msg = f"图片下载失败: {str(e)}" + logger.error(error_msg) + return False, error_msg + + except ValueError as e: + error_msg = f"参数错误: {str(e)}" + logger.error(error_msg) + return False, error_msg + + except Exception as e: + error_msg = f"图片处理异常: {str(e)}" + logger.error(error_msg) + return False, error_msg + + +class WecomAIBotStreamMessageBuilder: + """企业微信智能机器人流消息构建器""" + + @staticmethod + def make_text_stream(stream_id: str, content: str, finish: bool = False) -> str: + """构建文本流消息 + + Args: + stream_id: 流 ID + content: 文本内容 + finish: 是否结束 + + Returns: + JSON 格式的流消息字符串 + """ + plain = { + "msgtype": WecomAIBotConstants.MSG_TYPE_STREAM, + "stream": {"id": stream_id, "finish": finish, "content": content}, + } + return json.dumps(plain, ensure_ascii=False) + + @staticmethod + def make_image_stream( + stream_id: str, image_data: bytes, finish: bool = False + ) -> str: + """构建图片流消息 + + Args: + stream_id: 流 ID + image_data: 图片二进制数据 + finish: 是否结束 + + Returns: + JSON 格式的流消息字符串 + """ + image_md5 = hashlib.md5(image_data).hexdigest() + image_base64 = base64.b64encode(image_data).decode("utf-8") + + plain = { + "msgtype": WecomAIBotConstants.MSG_TYPE_STREAM, + "stream": { + "id": stream_id, + "finish": finish, + "msg_item": [ + { + "msgtype": WecomAIBotConstants.MSG_TYPE_IMAGE, + "image": {"base64": image_base64, "md5": image_md5}, + } + ], + }, + } + return json.dumps(plain, ensure_ascii=False) + + @staticmethod + def make_mixed_stream(stream_id: str, msg_items: list, finish: bool = False) -> str: + """构建混合类型流消息 + + Args: + stream_id: 流 ID + msg_items: 消息项列表 + finish: 是否结束 + + Returns: + JSON 格式的流消息字符串 + """ + plain = { + "msgtype": WecomAIBotConstants.MSG_TYPE_STREAM, + "stream": {"id": stream_id, "finish": finish, "msg_item": msg_items}, + } + return json.dumps(plain, ensure_ascii=False) + + +class WecomAIBotMessageParser: + """企业微信智能机器人消息解析器""" + + @staticmethod + def parse_text_message(data: Dict[str, Any]) -> Optional[str]: + """解析文本消息 + + Args: + data: 消息数据 + + Returns: + 文本内容,解析失败返回 None + """ + try: + return data.get("text", {}).get("content") + except (KeyError, TypeError): + logger.warning("文本消息解析失败") + return None + + @staticmethod + def parse_image_message(data: Dict[str, Any]) -> Optional[str]: + """解析图片消息 + + Args: + data: 消息数据 + + Returns: + 图片 URL,解析失败返回 None + """ + try: + return data.get("image", {}).get("url") + except (KeyError, TypeError): + logger.warning("图片消息解析失败") + return None + + @staticmethod + def parse_stream_message(data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """解析流消息 + + Args: + data: 消息数据 + + Returns: + 流消息数据,解析失败返回 None + """ + try: + stream_data = data.get("stream", {}) + return { + "id": stream_data.get("id"), + "finish": stream_data.get("finish"), + "content": stream_data.get("content"), + "msg_item": stream_data.get("msg_item", []), + } + except (KeyError, TypeError): + logger.warning("流消息解析失败") + return None + + @staticmethod + def parse_mixed_message(data: Dict[str, Any]) -> Optional[list]: + """解析混合消息 + + Args: + data: 消息数据 + + Returns: + 消息项列表,解析失败返回 None + """ + try: + return data.get("mixed", {}).get("msg_item", []) + except (KeyError, TypeError): + logger.warning("混合消息解析失败") + return None + + @staticmethod + def parse_event_message(data: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """解析事件消息 + + Args: + data: 消息数据 + + Returns: + 事件数据,解析失败返回 None + """ + try: + return data.get("event", {}) + except (KeyError, TypeError): + logger.warning("事件消息解析失败") + return None diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py new file mode 100644 index 000000000..93126b84d --- /dev/null +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py @@ -0,0 +1,182 @@ +""" +企业微信智能机器人事件处理模块,处理消息事件的发送和接收 +""" + +import uuid +from typing import AsyncGenerator, Dict, Any, Optional + +from astrbot.api.event import AstrMessageEvent, MessageChain +from astrbot.api.message_components import ( + Image, + Plain, +) +from astrbot.api.platform import Group +from astrbot.api import logger + +from .wecomai_api import WecomAIBotAPIClient +from .wecomai_queue_mgr import wecomai_queue_mgr + + +class WecomAIBotMessageEvent(AstrMessageEvent): + """企业微信智能机器人消息事件""" + + def __init__( + self, + message_str: str, + message_obj, + platform_meta, + session_id: str, + api_client: WecomAIBotAPIClient, + callback_params: Dict[str, str], + ): + """初始化消息事件 + + Args: + message_str: 消息字符串 + message_obj: 消息对象 + platform_meta: 平台元数据 + session_id: 会话 ID + api_client: API 客户端 + """ + super().__init__(message_str, message_obj, platform_meta, session_id) + self.api_client = api_client + + @staticmethod + async def _send( + message_chain: MessageChain, + session_id: str, + streaming: bool = False, + ): + session_key = session_id.split("!")[-1] if "!" in session_id else session_id + back_queue = wecomai_queue_mgr.get_or_create_back_queue(session_key) + + if not message_chain: + await back_queue.put( + { + "type": "end", + "data": "", + "streaming": False, + } + ) + return "" + + data = "" + for comp in message_chain.chain: + if isinstance(comp, Plain): + data = comp.text + await back_queue.put( + { + "type": "plain", + "data": data, + "streaming": streaming, + "session_id": session_id, + } + ) + elif isinstance(comp, Image): + # 处理图片消息 + try: + image_base64 = await comp.convert_to_base64() + if image_base64: + data = f"[IMAGE]{str(uuid.uuid4())}" + await back_queue.put( + { + "type": "image", + "data": data, + "image_data": image_base64, + "streaming": streaming, + "session_id": session_id, + } + ) + else: + logger.warning("图片数据为空,跳过") + except Exception as e: + logger.error("处理图片消息失败: %s", e) + else: + # 其他类型的组件转换为文本 + text_data = str(comp) + data += text_data + await back_queue.put( + { + "type": "component", + "data": text_data, + "streaming": streaming, + "session_id": session_id, + } + ) + + return data + + async def send(self, message: MessageChain): + """发送消息""" + await WecomAIBotMessageEvent._send(message, self.session_id) + await super().send(message) + + async def send_streaming( + self, generator: AsyncGenerator, use_fallback: bool = False + ): + """流式发送消息,参考webchat的send_streaming设计""" + final_data = "" + session_key = ( + self.session_id.split("!")[-1] + if "!" in self.session_id + else self.session_id + ) + back_queue = wecomai_queue_mgr.get_or_create_back_queue(session_key) + + async for chain in generator: + if chain.type == "break" and final_data: + # 分割符 + await back_queue.put( + { + "type": "break", # break means a segment end + "data": final_data, + "streaming": True, + "session_id": self.session_id, + } + ) + final_data = "" + continue + + final_data += await WecomAIBotMessageEvent._send( + chain, + session_id=self.session_id, + streaming=True, + ) + + await back_queue.put( + { + "type": "complete", # complete means we return the final result + "data": final_data, + "streaming": True, + "session_id": self.session_id, + } + ) + await super().send_streaming(generator, use_fallback) + + async def get_group(self, group_id=None, **kwargs) -> Optional[Group]: + """获取群组信息""" + return None + + def get_sender_id(self) -> str: + """获取发送者 ID""" + return getattr(self.message_obj, "sender", {}).get("user_id", "unknown") + + def get_sender_name(self) -> str: + """获取发送者名称""" + return getattr(self.message_obj, "sender", {}).get("nickname", "Unknown") + + def get_group_id(self) -> Optional[str]: + """获取群组 ID""" + return None + + def is_private_message(self) -> bool: + """是否为私聊消息""" + return True + + def is_group_message(self) -> bool: + """是否为群消息""" + return False + + def get_raw_message(self) -> Any: + """获取原始消息数据""" + return getattr(self.message_obj, "raw_message", {}) diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py new file mode 100644 index 000000000..1367301c9 --- /dev/null +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py @@ -0,0 +1,148 @@ +""" +企业微信智能机器人队列管理器 +参考 webchat_queue_mgr.py,为企业微信智能机器人实现队列机制 +支持异步消息处理和流式响应 +""" + +import asyncio +from typing import Dict, Any, Optional +from astrbot.api import logger + + +class WecomAIQueueMgr: + """企业微信智能机器人队列管理器""" + + def __init__(self) -> None: + self.queues: Dict[str, asyncio.Queue] = {} + """StreamID 到输入队列的映射 - 用于接收用户消息""" + + self.back_queues: Dict[str, asyncio.Queue] = {} + """StreamID 到输出队列的映射 - 用于发送机器人响应""" + + self.pending_responses: Dict[str, Dict[str, Any]] = {} + """待处理的响应缓存,用于流式响应""" + + def get_or_create_queue(self, session_id: str) -> asyncio.Queue: + """获取或创建指定会话的输入队列 + + Args: + session_id: 会话ID + + Returns: + 输入队列实例 + """ + if session_id not in self.queues: + self.queues[session_id] = asyncio.Queue() + logger.debug(f"[WecomAI] 创建输入队列: {session_id}") + return self.queues[session_id] + + def get_or_create_back_queue(self, session_id: str) -> asyncio.Queue: + """获取或创建指定会话的输出队列 + + Args: + session_id: 会话ID + + Returns: + 输出队列实例 + """ + if session_id not in self.back_queues: + self.back_queues[session_id] = asyncio.Queue() + logger.debug(f"[WecomAI] 创建输出队列: {session_id}") + return self.back_queues[session_id] + + def remove_queues(self, session_id: str): + """移除指定会话的所有队列 + + Args: + session_id: 会话ID + """ + if session_id in self.queues: + del self.queues[session_id] + logger.debug(f"[WecomAI] 移除输入队列: {session_id}") + + if session_id in self.back_queues: + del self.back_queues[session_id] + logger.debug(f"[WecomAI] 移除输出队列: {session_id}") + + if session_id in self.pending_responses: + del self.pending_responses[session_id] + logger.debug(f"[WecomAI] 移除待处理响应: {session_id}") + + def has_queue(self, session_id: str) -> bool: + """检查是否存在指定会话的队列 + + Args: + session_id: 会话ID + + Returns: + 是否存在队列 + """ + return session_id in self.queues + + def has_back_queue(self, session_id: str) -> bool: + """检查是否存在指定会话的输出队列 + + Args: + session_id: 会话ID + + Returns: + 是否存在输出队列 + """ + return session_id in self.back_queues + + def set_pending_response(self, session_id: str, callback_params: Dict[str, str]): + """设置待处理的响应参数 + + Args: + session_id: 会话ID + callback_params: 回调参数(nonce, timestamp等) + """ + self.pending_responses[session_id] = { + "callback_params": callback_params, + "timestamp": asyncio.get_event_loop().time(), + } + logger.debug(f"[WecomAI] 设置待处理响应: {session_id}") + + def get_pending_response(self, session_id: str) -> Optional[Dict[str, Any]]: + """获取待处理的响应参数 + + Args: + session_id: 会话ID + + Returns: + 响应参数,如果不存在则返回None + """ + return self.pending_responses.get(session_id) + + def cleanup_expired_responses(self, max_age_seconds: int = 300): + """清理过期的待处理响应 + + Args: + max_age_seconds: 最大存活时间(秒) + """ + current_time = asyncio.get_event_loop().time() + expired_sessions = [] + + for session_id, response_data in self.pending_responses.items(): + if current_time - response_data["timestamp"] > max_age_seconds: + expired_sessions.append(session_id) + + for session_id in expired_sessions: + del self.pending_responses[session_id] + logger.debug(f"[WecomAI] 清理过期响应: {session_id}") + + def get_stats(self) -> Dict[str, int]: + """获取队列统计信息 + + Returns: + 统计信息字典 + """ + return { + "input_queues": len(self.queues), + "output_queues": len(self.back_queues), + "pending_responses": len(self.pending_responses), + } + + +# 全局队列管理器实例 +wecomai_queue_mgr = WecomAIQueueMgr() diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py new file mode 100644 index 000000000..1ea151e0f --- /dev/null +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py @@ -0,0 +1,172 @@ +""" +企业微信智能机器人 HTTP 服务器 +处理企业微信智能机器人的 HTTP 回调请求 +""" + +import asyncio +from typing import Dict, Any, Optional, Callable + +import quart +from astrbot.api import logger + +from .wecomai_api import WecomAIBotAPIClient +from .wecomai_utils import WecomAIBotConstants + + +class WecomAIBotServer: + """企业微信智能机器人 HTTP 服务器""" + + def __init__( + self, + host: str, + port: int, + bot_id: str, + api_client: WecomAIBotAPIClient, + message_handler: Optional[ + Callable[[Dict[str, Any], Dict[str, str]], Any] + ] = None, + ): + """初始化服务器 + + Args: + host: 监听地址 + port: 监听端口 + bot_id: 机器人ID + api_client: API客户端实例 + message_handler: 消息处理回调函数 + """ + self.host = host + self.port = port + self.bot_id = bot_id + self.api_client = api_client + self.message_handler = message_handler + + self.app = quart.Quart(__name__) + self._setup_routes() + + self.shutdown_event = asyncio.Event() + + def _setup_routes(self): + """设置 Quart 路由""" + + # 使用 Quart 的 add_url_rule 方法添加路由 + self.app.add_url_rule( + f"/ai-bot/callback/{self.bot_id}", + view_func=self.verify_url, + methods=["GET"], + ) + + self.app.add_url_rule( + f"/ai-bot/callback/{self.bot_id}", + view_func=self.handle_message, + methods=["POST"], + ) + + async def verify_url(self): + """验证回调 URL""" + args = quart.request.args + msg_signature = args.get("msg_signature") + timestamp = args.get("timestamp") + nonce = args.get("nonce") + echostr = args.get("echostr") + + if not all([msg_signature, timestamp, nonce, echostr]): + logger.error("URL 验证参数缺失") + return "verify fail", 400 + + # 类型检查确保不为 None + assert msg_signature is not None + assert timestamp is not None + assert nonce is not None + assert echostr is not None + + logger.info("收到 URL 验证请求") + result = self.api_client.verify_url(msg_signature, timestamp, nonce, echostr) + return result, 200, {"Content-Type": "text/plain"} + + async def handle_message(self): + """处理消息回调""" + args = quart.request.args + msg_signature = args.get("msg_signature") + timestamp = args.get("timestamp") + nonce = args.get("nonce") + + if not all([msg_signature, timestamp, nonce]): + logger.error("消息回调参数缺失") + return "缺少必要参数", 400 + + # 类型检查确保不为 None + assert msg_signature is not None + assert timestamp is not None + assert nonce is not None + + logger.info( + "收到消息回调,msg_signature=%s, timestamp=%s, nonce=%s", + msg_signature, + timestamp, + nonce, + ) + + try: + # 获取请求体 + post_data = await quart.request.get_data() + + # 确保 post_data 是 bytes 类型 + if isinstance(post_data, str): + post_data = post_data.encode("utf-8") + + # 解密消息 + ret_code, message_data = await self.api_client.decrypt_message( + post_data, msg_signature, timestamp, nonce + ) + + if ret_code != WecomAIBotConstants.SUCCESS or not message_data: + logger.error("消息解密失败,错误码: %d", ret_code) + return "消息解密失败", 400 + + # 调用消息处理器 + response = None + if self.message_handler: + try: + response = await self.message_handler( + message_data, {"nonce": nonce, "timestamp": timestamp} + ) + except Exception as e: + logger.error("消息处理器执行异常: %s", e) + return "消息处理异常", 500 + + if response: + return response, 200, {"Content-Type": "text/plain"} + else: + return "success", 200, {"Content-Type": "text/plain"} + + except Exception as e: + logger.error("处理消息时发生异常: %s", e) + return "内部服务器错误", 500 + + async def start_server(self): + """启动服务器""" + logger.info("启动企业微信智能机器人服务器,监听 %s:%d", self.host, self.port) + + try: + await self.app.run_task( + host=self.host, + port=self.port, + shutdown_trigger=self.shutdown_trigger, + ) + except Exception as e: + logger.error("服务器运行异常: %s", e) + raise + + async def shutdown_trigger(self): + """关闭触发器""" + await self.shutdown_event.wait() + + async def shutdown(self): + """关闭服务器""" + logger.info("企业微信智能机器人服务器正在关闭...") + self.shutdown_event.set() + + def get_app(self): + """获取 Quart 应用实例""" + return self.app diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py new file mode 100644 index 000000000..953424560 --- /dev/null +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py @@ -0,0 +1,173 @@ +""" +企业微信智能机器人工具模块 +提供常量定义、工具函数和辅助方法 +""" + +import logging +import string +import random +import hashlib +import base64 +from typing import Dict, Any, Tuple + +logger = logging.getLogger(__name__) + + +# 常量定义 +class WecomAIBotConstants: + """企业微信智能机器人常量""" + + # 消息类型 + MSG_TYPE_TEXT = "text" + MSG_TYPE_IMAGE = "image" + MSG_TYPE_MIXED = "mixed" + MSG_TYPE_STREAM = "stream" + MSG_TYPE_EVENT = "event" + + # 流消息状态 + STREAM_CONTINUE = False + STREAM_FINISH = True + + # 错误码 + SUCCESS = 0 + DECRYPT_ERROR = -40001 + VALIDATE_SIGNATURE_ERROR = -40002 + PARSE_XML_ERROR = -40003 + COMPUTE_SIGNATURE_ERROR = -40004 + ILLEGAL_AES_KEY = -40005 + VALIDATE_APPID_ERROR = -40006 + ENCRYPT_AES_ERROR = -40007 + ILLEGAL_BUFFER = -40008 + + +def generate_random_string(length: int = 10) -> str: + """生成随机字符串 + + Args: + length: 字符串长度,默认为 10 + + Returns: + 随机字符串 + """ + letters = string.ascii_letters + string.digits + return "".join(random.choice(letters) for _ in range(length)) + + +def calculate_image_md5(image_data: bytes) -> str: + """计算图片数据的 MD5 值 + + Args: + image_data: 图片二进制数据 + + Returns: + MD5 哈希值(十六进制字符串) + """ + return hashlib.md5(image_data).hexdigest() + + +def encode_image_base64(image_data: bytes) -> str: + """将图片数据编码为 Base64 + + Args: + image_data: 图片二进制数据 + + Returns: + Base64 编码的字符串 + """ + return base64.b64encode(image_data).decode("utf-8") + + +def validate_config(config: Dict[str, Any]) -> Tuple[bool, str]: + """验证配置参数 + + Args: + config: 配置字典 + + Returns: + (是否有效, 错误信息) + """ + required_fields = ["token", "encoding_aes_key", "callback_url", "port"] + + for field in required_fields: + if not config.get(field): + return False, f"缺少必要配置项: {field}" + + # 验证端口号 + try: + port = int(config.get("port", 0)) + if port <= 0 or port > 65535: + return False, "端口号必须在 1-65535 范围内" + except (ValueError, TypeError): + return False, "端口号必须是有效的数字" + + # 验证 AES 密钥长度 + encoding_aes_key = config.get("encoding_aes_key", "") + if len(encoding_aes_key) != 43: + return False, "EncodingAESKey 长度必须为 43 位" + + return True, "" + + +def format_session_id(session_type: str, session_id: str) -> str: + """格式化会话 ID + + Args: + session_type: 会话类型 ("user", "group") + session_id: 原始会话 ID + + Returns: + 格式化后的会话 ID + """ + return f"wecom_ai_bot_{session_type}_{session_id}" + + +def parse_session_id(formatted_session_id: str) -> Tuple[str, str]: + """解析格式化的会话 ID + + Args: + formatted_session_id: 格式化的会话 ID + + Returns: + (会话类型, 原始会话ID) + """ + parts = formatted_session_id.split("_", 3) + if ( + len(parts) >= 4 + and parts[0] == "wecom" + and parts[1] == "ai" + and parts[2] == "bot" + ): + return parts[3], "_".join(parts[4:]) if len(parts) > 4 else "" + return "user", formatted_session_id + + +def safe_json_loads(json_str: str, default: Any = None) -> Any: + """安全地解析 JSON 字符串 + + Args: + json_str: JSON 字符串 + default: 解析失败时的默认值 + + Returns: + 解析结果或默认值 + """ + import json + + try: + return json.loads(json_str) + except (json.JSONDecodeError, TypeError) as e: + logger.warning(f"JSON 解析失败: {e}, 原始字符串: {json_str}") + return default + + +def format_error_response(error_code: int, error_msg: str) -> str: + """格式化错误响应 + + Args: + error_code: 错误码 + error_msg: 错误信息 + + Returns: + 格式化的错误响应字符串 + """ + return f"Error {error_code}: {error_msg}" diff --git a/pyproject.toml b/pyproject.toml index 500a6e698..36320141f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ dependencies = [ "wechatpy>=1.8.18", "audioop-lts ; python_full_version >= '3.13'", "click>=8.2.1", + "fastapi>=0.119.0", ] [project.scripts] From 2222d94c7f853650b3356a837d40aee612d85073 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Tue, 14 Oct 2025 14:39:36 +0800 Subject: [PATCH 2/4] stage --- astrbot/core/config/default.py | 20 +- astrbot/core/pipeline/scheduler.py | 2 +- astrbot/core/platform/manager.py | 4 + .../sources/webchat/webchat_adapter.py | 1 - .../sources/wecom_ai_bot/WXBizJsonMsgCrypt.py | 2 +- .../sources/wecom_ai_bot/demo/demo_server.py | 326 ------------------ .../sources/wecom_ai_bot/wecomai_adapter.py | 112 +++--- .../sources/wecom_ai_bot/wecomai_api.py | 25 +- .../sources/wecom_ai_bot/wecomai_event.py | 89 ++--- .../sources/wecom_ai_bot/wecomai_server.py | 16 +- .../sources/wecom_ai_bot/wecomai_utils.py | 37 +- pyproject.toml | 1 - 12 files changed, 146 insertions(+), 489 deletions(-) delete mode 100644 astrbot/core/platform/sources/wecom_ai_bot/demo/demo_server.py diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 7e68127ec..88b0d80ef 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -211,6 +211,9 @@ "id": "wecom_ai_bot", "type": "wecom_ai_bot", "enable": True, + "wecomaibot_init_respond_text": "💭 思考中...", + "wecomaibot_friend_message_welcome_text": "", + "wecom_ai_bot_name": "", "token": "", "encoding_aes_key": "", "callback_server_host": "0.0.0.0", @@ -456,10 +459,25 @@ "type": "string", "hint": "aiocqhttp 适配器的反向 Websocket Token。未设置则不启用 Token 验证。", }, + "wecom_ai_bot_name": { + "description": "企业微信智能机器人的名字", + "type": "string", + "hint": "请务必填写正确,否则无法使用一些指令。", + }, + "wecomaibot_init_respond_text": { + "description": "企业微信智能机器人初始响应文本", + "type": "string", + "hint": "当机器人收到消息时,首先回复的文本内容。留空则使用默认值。", + }, + "wecomaibot_friend_message_welcome_text": { + "description": "企业微信智能机器人私聊欢迎语", + "type": "string", + "hint": "当用户当天进入智能机器人单聊会话,回复欢迎语,留空则不回复。", + }, "lark_bot_name": { "description": "飞书机器人的名字", "type": "string", - "hint": "请务必填对,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。", + "hint": "请务必填写正确,否则 @ 机器人将无法唤醒,只能通过前缀唤醒。", }, "discord_token": { "description": "Discord Bot Token", diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py index f1c3988a6..7a38ec03f 100644 --- a/astrbot/core/pipeline/scheduler.py +++ b/astrbot/core/pipeline/scheduler.py @@ -74,7 +74,7 @@ async def execute(self, event: AstrMessageEvent): await self._process_stages(event) # 如果没有发送操作, 则发送一个空消息, 以便于后续的处理 - if event.get_platform_name() == "webchat": + if event.get_platform_name() in ["webchat", "wecom_ai_bot"]: await event.send(None) logger.debug("pipeline 执行完毕。") diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index f0d7c2e4a..7090c669c 100644 --- a/astrbot/core/platform/manager.py +++ b/astrbot/core/platform/manager.py @@ -82,6 +82,10 @@ async def load_platform(self, platform_config: dict): from .sources.wecom.wecom_adapter import ( WecomPlatformAdapter, # noqa: F401 ) + case "wecom_ai_bot": + from .sources.wecom_ai_bot.wecomai_adapter import ( + WecomAIBotAdapter, # noqa: F401 + ) case "weixin_official_account": from .sources.weixin_official_account.weixin_offacc_adapter import ( WeixinOfficialAccountPlatformAdapter, # noqa: F401 diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index 43da100f4..faec122ac 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -91,7 +91,6 @@ async def convert_message(self, data: tuple) -> AstrBotMessage: abm = AstrBotMessage() abm.self_id = "webchat" - abm.tag = "webchat" abm.sender = MessageMember(username, username) abm.type = MessageType.FRIEND_MESSAGE diff --git a/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py b/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py index 88963b1a0..5332942b9 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py @@ -17,7 +17,7 @@ import socket import json -import ierror +from . import ierror """ 关于Crypto.Cipher模块,ImportError: No module named 'Crypto'解决方案 diff --git a/astrbot/core/platform/sources/wecom_ai_bot/demo/demo_server.py b/astrbot/core/platform/sources/wecom_ai_bot/demo/demo_server.py deleted file mode 100644 index cd394290f..000000000 --- a/astrbot/core/platform/sources/wecom_ai_bot/demo/demo_server.py +++ /dev/null @@ -1,326 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 -# 文档:https://developer.work.weixin.qq.com/document/path/101039 - -from fastapi import FastAPI, Request, HTTPException -from fastapi.responses import Response -import uvicorn -import os -import logging -import json -import random -import string -import time -import base64 -import hashlib -from astrbot.core.platform.sources.wecom_ai_bot.WXBizJsonMsgCrypt import ( - WXBizJsonMsgCrypt, -) -from Crypto.Cipher import AES -import requests - -app = FastAPI() - -# 常量定义 -CACHE_DIR = "/tmp/llm_demo_cache" -MAX_STEPS = 10 - -# 配置日志 -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" -) -logger = logging.getLogger(__name__) - - -def _generate_random_string(length): - letters = string.ascii_letters + string.digits - return "".join(random.choice(letters) for _ in range(length)) - - -def _process_encrypted_image(image_url, aes_key_base64): - """ - 下载并解密加密图片 - - 参数: - image_url: 加密图片的URL - aes_key_base64: Base64编码的AES密钥(与回调加解密相同) - - 返回: - tuple: (status: bool, data: bytes/str) - status为True时data是解密后的图片数据, - status为False时data是错误信息 - """ - try: - # 1. 下载加密图片 - logger.info("开始下载加密图片: %s", image_url) - response = requests.get(image_url, timeout=15) - response.raise_for_status() - encrypted_data = response.content - logger.info("图片下载成功,大小: %d 字节", len(encrypted_data)) - - # 2. 准备AES密钥和IV - if not aes_key_base64: - raise ValueError("AES密钥不能为空") - - # Base64解码密钥 (自动处理填充) - aes_key = base64.b64decode(aes_key_base64 + "=" * (-len(aes_key_base64) % 4)) - if len(aes_key) != 32: - raise ValueError("无效的AES密钥长度: 应为32字节") - - iv = aes_key[:16] # 初始向量为密钥前16字节 - - # 3. 解密图片数据 - cipher = AES.new(aes_key, AES.MODE_CBC, iv) - decrypted_data = cipher.decrypt(encrypted_data) - - # 4. 去除PKCS#7填充 (Python 3兼容写法) - pad_len = decrypted_data[-1] # 直接获取最后一个字节的整数值 - if pad_len > 32: # AES-256块大小为32字节 - raise ValueError("无效的填充长度 (大于32字节)") - - decrypted_data = decrypted_data[:-pad_len] - logger.info("图片解密成功,解密后大小: %d 字节", len(decrypted_data)) - - return True, decrypted_data - - except requests.exceptions.RequestException as e: - error_msg = f"图片下载失败 : {str(e)}" - logger.error(error_msg) - return False, error_msg - - except ValueError as e: - error_msg = f"参数错误 : {str(e)}" - logger.error(error_msg) - return False, error_msg - - except Exception as e: - error_msg = f"图片处理异常 : {str(e)}" - logger.error(error_msg) - return False, error_msg - - -def MakeTextStream(stream_id, content, finish): - plain = { - "msgtype": "stream", - "stream": {"id": stream_id, "finish": finish, "content": content}, - } - return json.dumps(plain, ensure_ascii=False) - - -def MakeImageStream(stream_id, image_data, finish): - image_md5 = hashlib.md5(image_data).hexdigest() - image_base64 = base64.b64encode(image_data).decode("utf-8") - - plain = { - "msgtype": "stream", - "stream": { - "id": stream_id, - "finish": finish, - "msg_item": [ - { - "msgtype": "image", - "image": {"base64": image_base64, "md5": image_md5}, - } - ], - }, - } - return json.dumps(plain) - - -def EncryptMessage(receiveid, nonce, timestamp, stream): - logger.info( - "开始加密消息,receiveid=%s, nonce=%s, timestamp=%s", - receiveid, - nonce, - timestamp, - ) - logger.debug("发送流消息: %s", stream) - - wxcpt = WXBizJsonMsgCrypt( - os.getenv("Token", ""), os.getenv("EncodingAESKey", ""), receiveid - ) - ret, resp = wxcpt.EncryptMsg(stream, nonce, timestamp) - if ret != 0: - logger.error("加密失败,错误码: %d", ret) - return - - stream_id = json.loads(stream)["stream"]["id"] - finish = json.loads(stream)["stream"]["finish"] - logger.info( - "回调处理完成, 返回加密的流消息, stream_id=%s, finish=%s", stream_id, finish - ) - logger.debug("加密后的消息: %s", resp) - - return resp - - -# TODO 这里模拟一个大模型的行为 -class LLMDemo: - def __init__(self): - self.cache_dir = CACHE_DIR - if not os.path.exists(self.cache_dir): - os.makedirs(self.cache_dir) - - def invoke(self, question): - stream_id = _generate_random_string(10) # 生成一个随机字符串作为任务ID - # 创建任务缓存文件 - cache_file = os.path.join(self.cache_dir, "%s.json" % stream_id) - with open(cache_file, "w", encoding="utf-8") as f: - json.dump( - { - "question": question, - "created_time": time.time(), - "current_step": 0, - "max_steps": MAX_STEPS, - }, - f, - ) - return stream_id - - def get_answer(self, stream_id): - cache_file = os.path.join(self.cache_dir, "%s.json" % stream_id) - if not os.path.exists(cache_file): - return "任务不存在或已过期" - - with open(cache_file, "r", encoding="utf-8") as f: - task_data = json.load(f) - - # 更新缓存 - current_step = task_data["current_step"] + 1 - task_data["current_step"] = current_step - with open(cache_file, "w", encoding="utf-8") as f: - json.dump(task_data, f) - - response = "收到问题:%s\n" % task_data["question"] - for i in range(current_step): - response += "处理步骤 %d: 已完成\n" % (i) - - return response - - def is_task_finish(self, stream_id): - cache_file = os.path.join(self.cache_dir, "%s.json" % stream_id) - if not os.path.exists(cache_file): - return True - - with open(cache_file, "r", encoding="utf-8") as f: - task_data = json.load(f) - - return task_data["current_step"] >= task_data["max_steps"] - - -@app.get("/ai-bot/callback/demo/{botid}") -async def verify_url( - request: Request, - botid: str, - msg_signature: str, - timestamp: str, - nonce: str, - echostr: str, -): - # 企业创建的自能机器人的 VerifyUrl 请求, receiveid 是空串 - receiveid = "" - wxcpt = WXBizJsonMsgCrypt( - os.getenv("Token", ""), os.getenv("EncodingAESKey", ""), receiveid - ) - - ret, echostr = wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr) - - if ret != 0: - echostr = "verify fail" - - return Response(content=echostr, media_type="text/plain") - - -@app.post("/ai-bot/callback/demo/{botid}") -async def handle_message( - request: Request, - botid: str, - msg_signature: str = None, - timestamp: str = None, - nonce: str = None, -): - query_params = dict(request.query_params) - if not all([msg_signature, timestamp, nonce]): - raise HTTPException(status_code=400, detail="缺少必要参数") - logger.info( - "收到消息,botid=%s, msg_signature=%s, timestamp=%s, nonce=%s", - botid, - msg_signature, - timestamp, - nonce, - ) - - post_data = await request.body() - - # 智能机器人的 receiveid 是空串 - receiveid = "" - wxcpt = WXBizJsonMsgCrypt( - os.getenv("Token", ""), os.getenv("EncodingAESKey", ""), receiveid - ) - - ret, msg = wxcpt.DecryptMsg(post_data, msg_signature, timestamp, nonce) - - if ret != 0: - raise HTTPException(status_code=400, detail="解密失败") - - data = json.loads(msg) - logger.debug("Decrypted data: %s", data) - if "msgtype" not in data: - logger.info("不认识的事件: %s", data) - return Response(content="success", media_type="text/plain") - - msgtype = data["msgtype"] - if msgtype == "text": - content = data["text"]["content"] - - # 询问大模型产生回复 - llm = LLMDemo() - stream_id = llm.invoke(content) - answer = llm.get_answer(stream_id) - finish = llm.is_task_finish(stream_id) - - stream = MakeTextStream(stream_id, answer, finish) - resp = EncryptMessage(receiveid, nonce, timestamp, stream) - return Response(content=resp, media_type="text/plain") - elif msgtype == "stream": # case stream - # 询问大模型最新的回复 - stream_id = data["stream"]["id"] - llm = LLMDemo() - answer = llm.get_answer(stream_id) - finish = llm.is_task_finish(stream_id) - - stream = MakeTextStream(stream_id, answer, finish) - resp = EncryptMessage(receiveid, nonce, timestamp, stream) - return Response(content=resp, media_type="text/plain") - elif msgtype == "image": - # 从环境变量获取AES密钥 - aes_key = os.getenv("EncodingAESKey", "") - - # 调用图片处理函数 - success, result = _process_encrypted_image(data["image"]["url"], aes_key) - if not success: - logger.error("图片处理失败: %s", result) - return - - # 这里简单处理直接原图回复 - decrypted_data = result - stream_id = _generate_random_string(10) - finish = True - - stream = MakeImageStream(stream_id, decrypted_data, finish) - resp = EncryptMessage(receiveid, nonce, timestamp, stream) - return Response(content=resp, media_type="text/plain") - elif msgtype == "mixed": - # TODO 处理图文混排消息 - logger.warning("需要支持mixed消息类型") - elif msgtype == "event": - # TODO 一些事件的处理 - logger.warning("需要支持event消息类型: %s", data) - return - else: - logger.warning("不支持的消息类型: %s", msgtype) - return - - -if __name__ == "__main__": - uvicorn.run(app, host="0.0.0.0", port=80) diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py index 49dc9caeb..fa6a4f14e 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py @@ -18,7 +18,7 @@ PlatformMetadata, ) from astrbot.api.event import MessageChain -from astrbot.api.message_components import Plain +from astrbot.api.message_components import Plain, At from astrbot.api import logger from astrbot.core.platform.astr_message_event import MessageSesion from ...register import register_platform_adapter @@ -33,7 +33,6 @@ from .wecomai_queue_mgr import wecomai_queue_mgr, WecomAIQueueMgr from .wecomai_utils import ( WecomAIBotConstants, - validate_config, format_session_id, generate_random_string, ) @@ -43,7 +42,7 @@ class WecomAIQueueListener: """企业微信智能机器人队列监听器,参考webchat的QueueListener设计""" def __init__( - self, queue_mgr: WecomAIQueueMgr, callback: Callable[[tuple], Awaitable[None]] + self, queue_mgr: WecomAIQueueMgr, callback: Callable[[dict], Awaitable[None]] ) -> None: self.queue_mgr = queue_mgr self.callback = callback @@ -101,17 +100,18 @@ def __init__( self.config = platform_config self.settings = platform_settings - # 验证配置 - is_valid, error_msg = validate_config(self.config) - if not is_valid: - raise ValueError(f"配置验证失败: {error_msg}") - # 初始化配置参数 self.token = self.config["token"] self.encoding_aes_key = self.config["encoding_aes_key"] self.port = int(self.config["port"]) self.host = self.config.get("callback_server_host", "0.0.0.0") - self.bot_id = self.config.get("bot_id", "default") + self.bot_name = self.config.get("wecom_ai_bot_name", "") + self.initial_respond_text = self.config.get( + "wecomaibot_init_respond_text", "思考中..." + ) + self.friend_message_welcome_text = self.config.get( + "wecomaibot_friend_message_welcome_text", "" + ) # 平台元数据 self.metadata = PlatformMetadata( @@ -127,7 +127,6 @@ def __init__( self.server = WecomAIBotServer( host=self.host, port=self.port, - bot_id=self.bot_id, api_client=self.api_client, message_handler=self._process_message, ) @@ -140,7 +139,7 @@ def __init__( wecomai_queue_mgr, self._handle_queued_message ) - async def _handle_queued_message(self, data: tuple): + async def _handle_queued_message(self, data: dict): """处理队列中的消息,类似webchat的callback""" try: abm = await self.convert_message(data) @@ -176,7 +175,7 @@ async def _process_message( wecomai_queue_mgr.set_pending_response(stream_id, callback_params) resp = WecomAIBotStreamMessageBuilder.make_text_stream( - stream_id, "思考中...", False + stream_id, self.initial_respond_text, False ) return await self.api_client.encrypt_message( resp, callback_params["nonce"], callback_params["timestamp"] @@ -189,7 +188,17 @@ async def _process_message( stream_id = message_data["stream"]["id"] if not wecomai_queue_mgr.has_back_queue(stream_id): logger.error(f"Cannot find back queue for stream_id: {stream_id}") - return None + + # 返回结束标志,告诉微信服务器流已结束 + end_message = WecomAIBotStreamMessageBuilder.make_text_stream( + stream_id, "", True + ) + resp = await self.api_client.encrypt_message( + end_message, + callback_params["nonce"], + callback_params["timestamp"], + ) + return resp queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id) if queue.empty(): logger.debug( @@ -197,21 +206,25 @@ async def _process_message( ) return None - # aggregate all delta messages in the back queue - aggregated_content = "" + # aggregate all delta chains in the back queue + latest_plain_content = "" finish = False while not queue.empty(): msg = await queue.get() if msg["type"] == "plain": - aggregated_content += msg["data"] + latest_plain_content = msg["data"] + elif msg["type"] == "image": + pass elif msg["type"] == "end": finish = True else: pass - logger.debug(f"Aggregated content: {aggregated_content}, finish: {finish}") - if aggregated_content: + logger.debug( + f"Aggregated content: {latest_plain_content}, finish: {finish}" + ) + if latest_plain_content: plain_message = WecomAIBotStreamMessageBuilder.make_text_stream( - stream_id, aggregated_content, finish + stream_id, latest_plain_content, finish ) encrypted_message = await self.api_client.encrypt_message( plain_message, @@ -228,6 +241,23 @@ async def _process_message( return None elif msgtype == "image": pass + elif msgtype == "event": + event = message_data.get("event") + if event == "enter_chat" and self.friend_message_welcome_text: + # 用户进入会话,发送欢迎消息 + try: + resp = WecomAIBotStreamMessageBuilder.make_text( + self.friend_message_welcome_text + ) + return await self.api_client.encrypt_message( + resp, + callback_params["nonce"], + callback_params["timestamp"], + ) + except Exception as e: + logger.error("处理欢迎消息时发生异常: %s", e) + return None + pass def _extract_session_id(self, message_data: Dict[str, Any]) -> str: """从消息数据中提取会话ID""" @@ -248,15 +278,13 @@ async def _enqueue_message( "message_data": message_data, "callback_params": callback_params, "session_id": session_id, + "stream_id": stream_id, } - - await input_queue.put(("wecomai", stream_id, message_payload)) + await input_queue.put(message_payload) logger.debug(f"[WecomAI] 消息已入队: {stream_id}") - async def convert_message(self, data: tuple) -> AstrBotMessage: + async def convert_message(self, payload: dict) -> AstrBotMessage: """转换队列中的消息数据为AstrBotMessage,类似webchat的convert_message""" - platform, stream_id, payload = data - message_data = payload["message_data"] session_id = payload["session_id"] # callback_params = payload["callback_params"] # 保留但暂时不使用 @@ -284,24 +312,34 @@ async def convert_message(self, data: tuple) -> AstrBotMessage: # 构建AstrBotMessage abm = AstrBotMessage() - abm.self_id = self.config.get("id", "wecom_ai_bot") + abm.self_id = self.bot_name abm.message_str = content or "[未知消息]" abm.message_id = str(uuid.uuid4()) abm.timestamp = int(time.time()) - abm.raw_message = data + abm.raw_message = payload # 发送者信息 abm.sender = MessageMember( - user_id="wecom_user", - nickname="WeChat Work User", + user_id=message_data.get("from", {}).get("userid", "unknown"), + nickname=message_data.get("from", {}).get("userid", "unknown"), ) # 消息类型 - abm.type = MessageType.FRIEND_MESSAGE + abm.type = ( + MessageType.GROUP_MESSAGE + if message_data.get("chattype") == "group" + else MessageType.FRIEND_MESSAGE + ) abm.session_id = session_id # 消息内容 - abm.message = [Plain(text=content or "[未知消息]")] + abm.message = [] + + # 处理 At + if self.bot_name and f"@{self.bot_name}" in abm.message_str: + abm.message_str = abm.message_str.replace(f"@{self.bot_name}", "").strip() + abm.message.append(At(qq=self.bot_name, name=self.bot_name)) + abm.message.append(Plain(abm.message_str)) logger.debug(f"WecomAIAdapter: {abm.message}") return abm @@ -340,29 +378,15 @@ def meta(self) -> PlatformMetadata: async def handle_msg(self, message: AstrBotMessage): """处理消息,创建消息事件并提交到事件队列""" try: - # 从原始消息中获取回调参数 - if isinstance(message.raw_message, tuple) and len(message.raw_message) == 3: - _, _, payload = message.raw_message - callback_params = payload["callback_params"] - else: - # 如果没有有效的回调参数,使用默认值 - callback_params = { - "nonce": "default", - "timestamp": str(int(time.time())), - } - logger.warning("使用默认回调参数") - message_event = WecomAIBotMessageEvent( message_str=message.message_str, message_obj=message, platform_meta=self.meta(), session_id=message.session_id, api_client=self.api_client, - callback_params=callback_params, ) self.commit_event(message_event) - logger.debug("消息事件已提交: %s", message.message_str) except Exception as e: logger.error("处理消息时发生异常: %s", e) diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py index 73898ac4a..10621dc69 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py @@ -4,7 +4,6 @@ """ import json -import logging import base64 import hashlib from typing import Dict, Any, Optional, Tuple, Union @@ -13,8 +12,7 @@ from .WXBizJsonMsgCrypt import WXBizJsonMsgCrypt from .wecomai_utils import WecomAIBotConstants - -logger = logging.getLogger(__name__) +from astrbot import logger class WecomAIBotAPIClient: @@ -58,10 +56,10 @@ async def decrypt_message( if decrypted_msg: try: message_data = json.loads(decrypted_msg) - logger.debug("解密成功,消息内容: %s", message_data) + logger.debug(f"解密成功,消息内容: {message_data}") return WecomAIBotConstants.SUCCESS, message_data except json.JSONDecodeError as e: - logger.error("JSON 解析失败: %s, 原始消息: %s", e, decrypted_msg) + logger.error(f"JSON 解析失败: {e}, 原始消息: {decrypted_msg}") return WecomAIBotConstants.PARSE_XML_ERROR, None else: logger.error("解密消息为空") @@ -118,14 +116,14 @@ def verify_url( ) if ret != WecomAIBotConstants.SUCCESS: - logger.error("URL 验证失败,错误码: %d", ret) + logger.error(f"URL 验证失败,错误码: {ret}") return "verify fail" logger.info("URL 验证成功") return echo_result if echo_result else "verify fail" except Exception as e: - logger.error("URL 验证发生异常: %s", e) + logger.error(f"URL 验证发生异常: {e}") return "verify fail" async def process_encrypted_image( @@ -271,6 +269,19 @@ def make_mixed_stream(stream_id: str, msg_items: list, finish: bool = False) -> } return json.dumps(plain, ensure_ascii=False) + @staticmethod + def make_text(content: str) -> str: + """构建文本消息 + + Args: + content: 文本内容 + + Returns: + JSON 格式的文本消息字符串 + """ + plain = {"msgtype": "text", "text": {"content": content}} + return json.dumps(plain, ensure_ascii=False) + class WecomAIBotMessageParser: """企业微信智能机器人消息解析器""" diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py index 93126b84d..2d7ec91ca 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py @@ -2,15 +2,11 @@ 企业微信智能机器人事件处理模块,处理消息事件的发送和接收 """ -import uuid -from typing import AsyncGenerator, Dict, Any, Optional - from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.message_components import ( Image, Plain, ) -from astrbot.api.platform import Group from astrbot.api import logger from .wecomai_api import WecomAIBotAPIClient @@ -27,7 +23,6 @@ def __init__( platform_meta, session_id: str, api_client: WecomAIBotAPIClient, - callback_params: Dict[str, str], ): """初始化消息事件 @@ -44,11 +39,10 @@ def __init__( @staticmethod async def _send( message_chain: MessageChain, - session_id: str, + stream_id: str, streaming: bool = False, ): - session_key = session_id.split("!")[-1] if "!" in session_id else session_id - back_queue = wecomai_queue_mgr.get_or_create_back_queue(session_key) + back_queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id) if not message_chain: await back_queue.put( @@ -69,7 +63,7 @@ async def _send( "type": "plain", "data": data, "streaming": streaming, - "session_id": session_id, + "session_id": stream_id, } ) elif isinstance(comp, Image): @@ -77,14 +71,12 @@ async def _send( try: image_base64 = await comp.convert_to_base64() if image_base64: - data = f"[IMAGE]{str(uuid.uuid4())}" await back_queue.put( { "type": "image", - "data": data, "image_data": image_base64, "streaming": streaming, - "session_id": session_id, + "session_id": stream_id, } ) else: @@ -92,38 +84,41 @@ async def _send( except Exception as e: logger.error("处理图片消息失败: %s", e) else: - # 其他类型的组件转换为文本 - text_data = str(comp) - data += text_data - await back_queue.put( - { - "type": "component", - "data": text_data, - "streaming": streaming, - "session_id": session_id, - } - ) + logger.warning(f"[WecomAI] 不支持的消息组件类型: {type(comp)}, 跳过") return data async def send(self, message: MessageChain): """发送消息""" - await WecomAIBotMessageEvent._send(message, self.session_id) + raw = self.message_obj.raw_message + assert isinstance(raw, dict), ( + "wecom_ai_bot platform event raw_message should be a dict" + ) + stream_id = raw.get("stream_id", self.session_id) + await WecomAIBotMessageEvent._send(message, stream_id) await super().send(message) - async def send_streaming( - self, generator: AsyncGenerator, use_fallback: bool = False - ): + async def send_streaming(self, generator, use_fallback=False): """流式发送消息,参考webchat的send_streaming设计""" final_data = "" - session_key = ( - self.session_id.split("!")[-1] - if "!" in self.session_id - else self.session_id + raw = self.message_obj.raw_message + assert isinstance(raw, dict), ( + "wecom_ai_bot platform event raw_message should be a dict" ) - back_queue = wecomai_queue_mgr.get_or_create_back_queue(session_key) + stream_id = raw.get("stream_id", self.session_id) + back_queue = wecomai_queue_mgr.get_or_create_back_queue(stream_id) + # 企业微信智能机器人不支持增量发送,因此我们需要在这里将增量内容累积起来,积累发送 + increment_plain = "" async for chain in generator: + # 累积增量内容,并改写 Plain 段 + chain.squash_plain() + for comp in chain.chain: + if isinstance(comp, Plain): + comp.text = increment_plain + comp.text + increment_plain = comp.text + break + if chain.type == "break" and final_data: # 分割符 await back_queue.put( @@ -139,7 +134,7 @@ async def send_streaming( final_data += await WecomAIBotMessageEvent._send( chain, - session_id=self.session_id, + stream_id=stream_id, streaming=True, ) @@ -152,31 +147,3 @@ async def send_streaming( } ) await super().send_streaming(generator, use_fallback) - - async def get_group(self, group_id=None, **kwargs) -> Optional[Group]: - """获取群组信息""" - return None - - def get_sender_id(self) -> str: - """获取发送者 ID""" - return getattr(self.message_obj, "sender", {}).get("user_id", "unknown") - - def get_sender_name(self) -> str: - """获取发送者名称""" - return getattr(self.message_obj, "sender", {}).get("nickname", "Unknown") - - def get_group_id(self) -> Optional[str]: - """获取群组 ID""" - return None - - def is_private_message(self) -> bool: - """是否为私聊消息""" - return True - - def is_group_message(self) -> bool: - """是否为群消息""" - return False - - def get_raw_message(self) -> Any: - """获取原始消息数据""" - return getattr(self.message_obj, "raw_message", {}) diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py index 1ea151e0f..bbb69d041 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py @@ -20,7 +20,6 @@ def __init__( self, host: str, port: int, - bot_id: str, api_client: WecomAIBotAPIClient, message_handler: Optional[ Callable[[Dict[str, Any], Dict[str, str]], Any] @@ -31,13 +30,11 @@ def __init__( Args: host: 监听地址 port: 监听端口 - bot_id: 机器人ID api_client: API客户端实例 message_handler: 消息处理回调函数 """ self.host = host self.port = port - self.bot_id = bot_id self.api_client = api_client self.message_handler = message_handler @@ -51,13 +48,13 @@ def _setup_routes(self): # 使用 Quart 的 add_url_rule 方法添加路由 self.app.add_url_rule( - f"/ai-bot/callback/{self.bot_id}", + "/webhook/wecom-ai-bot", view_func=self.verify_url, methods=["GET"], ) self.app.add_url_rule( - f"/ai-bot/callback/{self.bot_id}", + "/webhook/wecom-ai-bot", view_func=self.handle_message, methods=["POST"], ) @@ -80,7 +77,7 @@ async def verify_url(self): assert nonce is not None assert echostr is not None - logger.info("收到 URL 验证请求") + logger.info("收到企业微信智能机器人 WebHook URL 验证请求。") result = self.api_client.verify_url(msg_signature, timestamp, nonce, echostr) return result, 200, {"Content-Type": "text/plain"} @@ -100,11 +97,8 @@ async def handle_message(self): assert timestamp is not None assert nonce is not None - logger.info( - "收到消息回调,msg_signature=%s, timestamp=%s, nonce=%s", - msg_signature, - timestamp, - nonce, + logger.debug( + f"收到消息回调,msg_signature={msg_signature}, timestamp={timestamp}, nonce={nonce}" ) try: diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py index 953424560..f32ec990e 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py @@ -3,14 +3,12 @@ 提供常量定义、工具函数和辅助方法 """ -import logging import string import random import hashlib import base64 -from typing import Dict, Any, Tuple - -logger = logging.getLogger(__name__) +from typing import Any, Tuple +from astrbot.api import logger # 常量定义 @@ -77,37 +75,6 @@ def encode_image_base64(image_data: bytes) -> str: return base64.b64encode(image_data).decode("utf-8") -def validate_config(config: Dict[str, Any]) -> Tuple[bool, str]: - """验证配置参数 - - Args: - config: 配置字典 - - Returns: - (是否有效, 错误信息) - """ - required_fields = ["token", "encoding_aes_key", "callback_url", "port"] - - for field in required_fields: - if not config.get(field): - return False, f"缺少必要配置项: {field}" - - # 验证端口号 - try: - port = int(config.get("port", 0)) - if port <= 0 or port > 65535: - return False, "端口号必须在 1-65535 范围内" - except (ValueError, TypeError): - return False, "端口号必须是有效的数字" - - # 验证 AES 密钥长度 - encoding_aes_key = config.get("encoding_aes_key", "") - if len(encoding_aes_key) != 43: - return False, "EncodingAESKey 长度必须为 43 位" - - return True, "" - - def format_session_id(session_type: str, session_id: str) -> str: """格式化会话 ID diff --git a/pyproject.toml b/pyproject.toml index 36320141f..500a6e698 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,6 @@ dependencies = [ "wechatpy>=1.8.18", "audioop-lts ; python_full_version >= '3.13'", "click>=8.2.1", - "fastapi>=0.119.0", ] [project.scripts] From 13ddff4df269326ea0559a18ac3d303e00e7a677 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Tue, 14 Oct 2025 17:38:50 +0800 Subject: [PATCH 3/4] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E5=9B=BE?= =?UTF-8?q?=E7=89=87=E6=94=B6=E5=8F=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../sources/wecom_ai_bot/wecomai_adapter.py | 75 +++++++++++++++---- .../sources/wecom_ai_bot/wecomai_api.py | 7 +- .../sources/wecom_ai_bot/wecomai_utils.py | 59 +++++++++++++++ 3 files changed, 125 insertions(+), 16 deletions(-) diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py index fa6a4f14e..830d8de58 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py @@ -7,6 +7,8 @@ import time import asyncio import uuid +import hashlib +import base64 from typing import Awaitable, Any, Dict, Optional, Callable @@ -18,7 +20,7 @@ PlatformMetadata, ) from astrbot.api.event import MessageChain -from astrbot.api.message_components import Plain, At +from astrbot.api.message_components import Plain, At, Image from astrbot.api import logger from astrbot.core.platform.astr_message_event import MessageSesion from ...register import register_platform_adapter @@ -35,6 +37,7 @@ WecomAIBotConstants, format_session_id, generate_random_string, + process_encrypted_image, ) @@ -107,7 +110,7 @@ def __init__( self.host = self.config.get("callback_server_host", "0.0.0.0") self.bot_name = self.config.get("wecom_ai_bot_name", "") self.initial_respond_text = self.config.get( - "wecomaibot_init_respond_text", "思考中..." + "wecomaibot_init_respond_text", "💭 思考中..." ) self.friend_message_welcome_text = self.config.get( "wecomaibot_friend_message_welcome_text", "" @@ -164,8 +167,8 @@ async def _process_message( logger.warning(f"消息类型未知,忽略: {message_data}") return None session_id = self._extract_session_id(message_data) - if msgtype == "text": - # user sent a text message + if msgtype in ("text", "image", "mixed"): + # user sent a text / image / mixed message try: # create a brand-new unique stream_id for this message session stream_id = f"{session_id}_{generate_random_string(10)}" @@ -208,23 +211,41 @@ async def _process_message( # aggregate all delta chains in the back queue latest_plain_content = "" + image_base64 = [] finish = False while not queue.empty(): msg = await queue.get() if msg["type"] == "plain": - latest_plain_content = msg["data"] + latest_plain_content = msg["data"] or "" elif msg["type"] == "image": - pass + image_base64.append(msg["image_data"]) elif msg["type"] == "end": + # stream end finish = True + wecomai_queue_mgr.remove_queues(stream_id) + break else: pass logger.debug( - f"Aggregated content: {latest_plain_content}, finish: {finish}" + f"Aggregated content: {latest_plain_content}, image: {len(image_base64)}, finish: {finish}" ) - if latest_plain_content: - plain_message = WecomAIBotStreamMessageBuilder.make_text_stream( - stream_id, latest_plain_content, finish + if latest_plain_content or image_base64: + msg_items = [] + if finish and image_base64: + for img_b64 in image_base64: + # get md5 of image + img_data = base64.b64decode(img_b64) + img_md5 = hashlib.md5(img_data).hexdigest() + msg_items.append( + { + "msgtype": WecomAIBotConstants.MSG_TYPE_IMAGE, + "image": {"base64": img_b64, "md5": img_md5}, + } + ) + image_base64 = [] + + plain_message = WecomAIBotStreamMessageBuilder.make_mixed_stream( + stream_id, latest_plain_content, msg_items, finish ) encrypted_message = await self.api_client.encrypt_message( plain_message, @@ -239,8 +260,6 @@ async def _process_message( logger.error("消息加密失败") return encrypted_message return None - elif msgtype == "image": - pass elif msgtype == "event": event = message_data.get("event") if event == "enter_chat" and self.friend_message_welcome_text: @@ -292,11 +311,17 @@ async def convert_message(self, payload: dict) -> AstrBotMessage: # 解析消息内容 msgtype = message_data.get("msgtype") content = "" + image_base64 = [] + + _img_url_to_process = [] + msg_items = [] if msgtype == WecomAIBotConstants.MSG_TYPE_TEXT: content = WecomAIBotMessageParser.parse_text_message(message_data) elif msgtype == WecomAIBotConstants.MSG_TYPE_IMAGE: - content = "[图片消息]" + _img_url_to_process.append( + WecomAIBotMessageParser.parse_image_message(message_data) + ) elif msgtype == WecomAIBotConstants.MSG_TYPE_MIXED: # 提取混合消息中的文本内容 msg_items = WecomAIBotMessageParser.parse_mixed_message(message_data) @@ -306,11 +331,28 @@ async def convert_message(self, payload: dict) -> AstrBotMessage: text_content = item.get("text", {}).get("content", "") if text_content: text_parts.append(text_content) - content = " ".join(text_parts) if text_parts else "[混合消息]" + elif item.get("msgtype") == WecomAIBotConstants.MSG_TYPE_IMAGE: + image_url = item.get("image", {}).get("url", "") + if image_url: + _img_url_to_process.append(image_url) + content = " ".join(text_parts) if text_parts else "" else: content = f"[{msgtype}消息]" - # 构建AstrBotMessage + # 并行处理图片下载和解密 + if _img_url_to_process: + tasks = [ + process_encrypted_image(url, self.encoding_aes_key) + for url in _img_url_to_process + ] + results = await asyncio.gather(*tasks) + for success, result in results: + if success: + image_base64.append(result) + else: + logger.error(f"处理加密图片失败: {result}") + + # 构建 AstrBotMessage abm = AstrBotMessage() abm.self_id = self.bot_name abm.message_str = content or "[未知消息]" @@ -340,6 +382,9 @@ async def convert_message(self, payload: dict) -> AstrBotMessage: abm.message_str = abm.message_str.replace(f"@{self.bot_name}", "").strip() abm.message.append(At(qq=self.bot_name, name=self.bot_name)) abm.message.append(Plain(abm.message_str)) + if image_base64: + for img_b64 in image_base64: + abm.message.append(Image.fromBase64(img_b64)) logger.debug(f"WecomAIAdapter: {abm.message}") return abm diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py index 10621dc69..540bf06b6 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py @@ -252,11 +252,14 @@ def make_image_stream( return json.dumps(plain, ensure_ascii=False) @staticmethod - def make_mixed_stream(stream_id: str, msg_items: list, finish: bool = False) -> str: + def make_mixed_stream( + stream_id: str, content: str, msg_items: list, finish: bool = False + ) -> str: """构建混合类型流消息 Args: stream_id: 流 ID + content: 文本内容 msg_items: 消息项列表 finish: 是否结束 @@ -267,6 +270,8 @@ def make_mixed_stream(stream_id: str, msg_items: list, finish: bool = False) -> "msgtype": WecomAIBotConstants.MSG_TYPE_STREAM, "stream": {"id": stream_id, "finish": finish, "msg_item": msg_items}, } + if content: + plain["stream"]["content"] = content return json.dumps(plain, ensure_ascii=False) @staticmethod diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py index f32ec990e..dccb2e260 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py @@ -7,6 +7,9 @@ import random import hashlib import base64 +import aiohttp +import asyncio +from Crypto.Cipher import AES from typing import Any, Tuple from astrbot.api import logger @@ -138,3 +141,59 @@ def format_error_response(error_code: int, error_msg: str) -> str: 格式化的错误响应字符串 """ return f"Error {error_code}: {error_msg}" + + +async def process_encrypted_image( + image_url: str, aes_key_base64: str +) -> Tuple[bool, str]: + """下载并解密加密图片 + + Args: + image_url: 加密图片的URL + aes_key_base64: Base64编码的AES密钥(与回调加解密相同) + + Returns: + Tuple[bool, str]: status 为 True 时 data 是解密后的图片数据的 base64 编码, + status 为 False 时 data 是错误信息 + """ + # 1. 下载加密图片 + logger.info("开始下载加密图片: %s", image_url) + try: + async with aiohttp.ClientSession() as session: + async with session.get(image_url, timeout=15) as response: + response.raise_for_status() + encrypted_data = await response.read() + logger.info("图片下载成功,大小: %d 字节", len(encrypted_data)) + except (aiohttp.ClientError, asyncio.TimeoutError) as e: + error_msg = f"下载图片失败: {str(e)}" + logger.error(error_msg) + return False, error_msg + + # 2. 准备AES密钥和IV + if not aes_key_base64: + raise ValueError("AES密钥不能为空") + + # Base64解码密钥 (自动处理填充) + aes_key = base64.b64decode(aes_key_base64 + "=" * (-len(aes_key_base64) % 4)) + if len(aes_key) != 32: + raise ValueError("无效的AES密钥长度: 应为32字节") + + iv = aes_key[:16] # 初始向量为密钥前16字节 + + # 3. 解密图片数据 + cipher = AES.new(aes_key, AES.MODE_CBC, iv) + decrypted_data = cipher.decrypt(encrypted_data) + + # 4. 去除PKCS#7填充 (Python 3兼容写法) + pad_len = decrypted_data[-1] # 直接获取最后一个字节的整数值 + if pad_len > 32: # AES-256块大小为32字节 + raise ValueError("无效的填充长度 (大于32字节)") + + decrypted_data = decrypted_data[:-pad_len] + logger.info("图片解密成功,解密后大小: %d 字节", len(decrypted_data)) + + # 5. 转换为base64编码 + base64_data = base64.b64encode(decrypted_data).decode("utf-8") + logger.info("图片已转换为base64编码,编码后长度: %d", len(base64_data)) + + return True, base64_data From 9bf26230543faf99b196f0ac55fc8a92652b95eb Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Tue, 14 Oct 2025 17:52:25 +0800 Subject: [PATCH 4/4] feat: add support for wecom_ai_bot in getPlatformIcon function --- dashboard/src/utils/platformUtils.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dashboard/src/utils/platformUtils.js b/dashboard/src/utils/platformUtils.js index 2656c56c7..660ed7812 100644 --- a/dashboard/src/utils/platformUtils.js +++ b/dashboard/src/utils/platformUtils.js @@ -10,7 +10,7 @@ export function getPlatformIcon(name) { if (name === 'aiocqhttp' || name === 'qq_official' || name === 'qq_official_webhook') { return new URL('@/assets/images/platform_logos/qq.png', import.meta.url).href - } else if (name === 'wecom') { + } else if (name === 'wecom' || name === 'wecom_ai_bot') { return new URL('@/assets/images/platform_logos/wecom.png', import.meta.url).href } else if (name === 'wechatpadpro' || name === 'weixin_official_account' || name === 'wechat') { return new URL('@/assets/images/platform_logos/wechat.png', import.meta.url).href