Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
297 changes: 271 additions & 26 deletions astrbot/core/platform/sources/satori/satori_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import time
import websockets
from websockets.asyncio.client import connect
from typing import Optional
from typing import Optional, List
from aiohttp import ClientSession, ClientTimeout
from websockets.asyncio.client import ClientConnection
from astrbot.api import logger
Expand All @@ -17,7 +17,7 @@
register_platform_adapter,
)
from astrbot.core.platform.astr_message_event import MessageSession
from astrbot.api.message_components import Plain, Image, At, File, Record
from astrbot.api.message_components import Plain, Image, At, File, Record, BaseMessageComponent, Reply
from xml.etree import ElementTree as ET


Expand All @@ -38,12 +38,18 @@ def __init__(
)
self.token = self.config.get("satori_token", "")
self.endpoint = self.config.get(
"satori_endpoint", "ws://127.0.0.1:5140/satori/v1/events"
"satori_endpoint", "ws://localhost:5140/satori/v1/events"
)
self.auto_reconnect = self.config.get("satori_auto_reconnect", True)
self.heartbeat_interval = self.config.get("satori_heartbeat_interval", 10)
self.reconnect_delay = self.config.get("satori_reconnect_delay", 5)

self.metadata = PlatformMetadata(
name="satori",
description="Satori 通用协议适配器",
id=self.config.get("id"),
)

self.ws: Optional[ClientConnection] = None
self.session: Optional[ClientSession] = None
self.sequence = 0
Expand All @@ -63,7 +69,7 @@ async def send_by_session(
await super().send_by_session(session, message_chain)

def meta(self) -> PlatformMetadata:
    """Return the platform metadata of this adapter.

    The ``PlatformMetadata`` instance stored in ``self.metadata`` is returned
    instead of building a fresh object on every call. ``__init__`` creates
    that instance with the adapter name, its description and the ``id`` read
    from the adapter configuration, so the configured id is part of the
    returned metadata.

    Returns:
        The ``PlatformMetadata`` object held in ``self.metadata``.
    """
    return self.metadata

def _is_websocket_closed(self, ws) -> bool:
"""检查WebSocket连接是否已关闭"""
Expand Down Expand Up @@ -312,12 +318,52 @@ async def convert_satori_message(

abm.self_id = login.get("user", {}).get("id", "")

# 消息链
abm.message = []

content = message.get("content", "")
abm.message = await self.parse_satori_elements(content)

quote = message.get("quote")
content_for_parsing = content # 副本

# 提取<quote>标签
if "<quote" in content:
try:
quote_info = await self._extract_quote_element(content)
if quote_info:
quote = quote_info["quote"]
content_for_parsing = quote_info["content_without_quote"]
except Exception as e:
logger.error(f"解析<quote>标签时发生错误: {e}, 错误内容: {content}")

if quote:
# 引用消息
quote_abm = await self._convert_quote_message(quote)
if quote_abm:
sender_id = quote_abm.sender.user_id
if isinstance(sender_id, str) and sender_id.isdigit():
sender_id = int(sender_id)
elif not isinstance(sender_id, int):
sender_id = 0 # 默认值

reply_component = Reply(
id=quote_abm.message_id,
chain=quote_abm.message,
sender_id=quote_abm.sender.user_id,
sender_nickname=quote_abm.sender.nickname,
time=quote_abm.timestamp,
message_str=quote_abm.message_str,
text=quote_abm.message_str,
qq=sender_id,
)
abm.message.append(reply_component)

# 解析消息内容
content_elements = await self.parse_satori_elements(content_for_parsing)
abm.message.extend(content_elements)

# parse message_str
abm.message_str = ""
for comp in abm.message:
for comp in content_elements:
if isinstance(comp, Plain):
abm.message_str += comp.text

Expand All @@ -333,6 +379,155 @@ async def convert_satori_message(
logger.error(f"转换 Satori 消息失败: {e}")
return None

def _extract_namespace_prefixes(self, content: str) -> set:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (code-quality): 我们发现了这些问题:


解释

此函数的质量得分低于 25% 的质量阈值。
此得分是方法长度、认知复杂度和工作内存的组合。

如何解决这个问题?

重构此函数以使其更短、更具可读性可能是有益的。

  • 通过将部分功能提取到自己的函数中来减少函数长度。这是你能做的最重要的事情——理想情况下,一个函数应该少于 10 行。
  • 减少嵌套,例如通过引入守卫子句来提前返回。
  • 确保变量的作用域紧密,以便使用相关概念的代码在函数内部紧密地放在一起,而不是分散开来。
Original comment in English

issue (code-quality): We've found these issues:


Explanation

The quality score for this function is below the quality threshold of 25%.
This score is a combination of the method length, cognitive complexity and working memory.

How can you solve this?

It might be worth refactoring this function to make it shorter and more readable.

  • Reduce the function length by extracting pieces of functionality out into
    their own functions. This is the most important thing you can do - ideally a
    function should be less than 10 lines.
  • Reduce nesting, perhaps by introducing guard clauses to return early.
  • Ensure that variables are tightly scoped, so that code using related concepts
    sits together within the function rather than being scattered.

"""提取XML内容中的命名空间前缀"""
prefixes = set()

# 查找所有标签
i = 0
while i < len(content):
# 查找开始标签
if content[i] == '<' and i + 1 < len(content) and content[i + 1] != '/':
# 找到标签结束位置
tag_end = content.find('>', i)
if tag_end != -1:
# 提取标签内容
tag_content = content[i + 1:tag_end]
# 检查是否有命名空间前缀
if ':' in tag_content and 'xmlns:' not in tag_content:
# 分割标签名
parts = tag_content.split()
if parts:
tag_name = parts[0]
if ':' in tag_name:
prefix = tag_name.split(':')[0]
# 确保是有效的命名空间前缀
if prefix.isalnum() or prefix.replace('_', '').isalnum():
prefixes.add(prefix)
i = tag_end + 1
else:
i += 1
# 查找结束标签
elif content[i] == '<' and i + 1 < len(content) and content[i + 1] == '/':
# 找到标签结束位置
tag_end = content.find('>', i)
if tag_end != -1:
# 提取标签内容
tag_content = content[i + 2:tag_end]
# 检查是否有命名空间前缀
if ':' in tag_content:
prefix = tag_content.split(':')[0]
# 确保是有效的命名空间前缀
if prefix.isalnum() or prefix.replace('_', '').isalnum():
prefixes.add(prefix)
i = tag_end + 1
else:
i += 1
else:
i += 1

return prefixes

async def _extract_quote_element(self, content: str) -> Optional[dict]:
"""提取<quote>标签信息"""
try:
# 处理命名空间前缀问题
processed_content = content
if ':' in content and not content.startswith('<root'):
prefixes = self._extract_namespace_prefixes(content)

# 构建命名空间声明
ns_declarations = ' '.join([f'xmlns:{prefix}="http://temp.uri/{prefix}"' for prefix in prefixes])

# 包装内容
processed_content = f"<root {ns_declarations}>{content}</root>"
elif not content.startswith('<root'):
processed_content = f"<root>{content}</root>"
else:
processed_content = content

root = ET.fromstring(processed_content)

# 查找<quote>标签
quote_element = None
for elem in root.iter():
tag_name = elem.tag
if '}' in tag_name:
tag_name = tag_name.split('}')[1]
if tag_name.lower() == "quote":
quote_element = elem
break
Comment on lines +452 to +459
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: 命名空间剥离假设命名空间标签中始终存在 '}'。

如果 XML 格式发生变化或使用不同的分隔符,此方法可能会失效。使用 xml.etree.ElementTree 的命名空间处理将使代码更健壮。

Suggested change
quote_element = None
for elem in root.iter():
tag_name = elem.tag
if '}' in tag_name:
tag_name = tag_name.split('}')[1]
if tag_name.lower() == "quote":
quote_element = elem
break
quote_element = None
for elem in root.iter():
# Use ElementTree's QName to get the local name robustly
try:
local_name = ET.QName(elem.tag).localname
except AttributeError:
# Fallback for older Python versions or non-namespaced tags
local_name = elem.tag.split('}')[-1] if '}' in elem.tag else elem.tag
if local_name.lower() == "quote":
quote_element = elem
break
Original comment in English

suggestion: Namespace stripping assumes '}' is always present for namespaced tags.

This approach may break if the XML format changes or uses different delimiters. Using xml.etree.ElementTree's namespace handling would make the code more robust.

Suggested change
quote_element = None
for elem in root.iter():
tag_name = elem.tag
if '}' in tag_name:
tag_name = tag_name.split('}')[1]
if tag_name.lower() == "quote":
quote_element = elem
break
quote_element = None
for elem in root.iter():
# Use ElementTree's QName to get the local name robustly
try:
local_name = ET.QName(elem.tag).localname
except AttributeError:
# Fallback for older Python versions or non-namespaced tags
local_name = elem.tag.split('}')[-1] if '}' in elem.tag else elem.tag
if local_name.lower() == "quote":
quote_element = elem
break


if quote_element is not None:
# 提取quote标签的属性
quote_id = quote_element.get("id", "")

# 提取<quote>标签内部的内容
inner_content = ""
if quote_element.text:
inner_content += quote_element.text
for child in quote_element:
inner_content += ET.tostring(child, encoding='unicode', method='xml')
if child.tail:
inner_content += child.tail

# 构造移除了<quote>标签的内容
content_without_quote = content.replace(
ET.tostring(quote_element, encoding='unicode', method='xml'), "")
Comment on lines +475 to +476
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): 使用字符串替换来移除 可能不适用于所有情况。

如果 XML 格式发生变化,使用字符串替换可能会失败。请使用 XML 解析器可靠地移除 元素。

Original comment in English

issue (bug_risk): Using string replace to remove may not work for all cases.

Using string replacement may fail if XML formatting changes. Use an XML parser to reliably remove the element.


return {
"quote": {
"id": quote_id,
"content": inner_content
},
"content_without_quote": content_without_quote
}

return None
except Exception as e:
logger.error(f"提取<quote>标签时发生错误: {e}")
return None

async def _convert_quote_message(self, quote: dict) -> Optional[AstrBotMessage]:
    """Build an ``AstrBotMessage`` describing a quoted (replied-to) message.

    Args:
        quote: Quote data. The keys read here are ``id``, ``author`` (a dict
            with ``id`` and ``nick``/``name``), ``user_id``, ``content`` and
            ``timestamp``.

    Returns:
        The populated message, or ``None`` if any step raises (the error is
        logged).
    """
    try:
        converted = AstrBotMessage()
        converted.message_id = quote.get("id", "")

        # Prefer explicit author data; otherwise fall back to the bare
        # user_id with a placeholder nickname.
        if author := quote.get("author", {}):
            sender = MessageMember(
                user_id=author.get("id", ""),
                nickname=author.get("nick", author.get("name", "")),
            )
        else:
            sender = MessageMember(
                user_id=quote.get("user_id", ""),
                nickname="内容",
            )
        converted.sender = sender

        # The quoted content uses the same element syntax as a normal message.
        converted.message = await self.parse_satori_elements(
            quote.get("content", "")
        )

        # The plain-text view is the concatenation of all Plain components.
        converted.message_str = "".join(
            part.text for part in converted.message if isinstance(part, Plain)
        )

        converted.timestamp = int(quote.get("timestamp", time.time()))

        # Use a placeholder when the quote carries no visible text.
        if not converted.message_str.strip():
            converted.message_str = "[引用消息]"

        return converted
    except Exception as e:
        logger.error(f"转换引用消息失败: {e}")
        return None

async def parse_satori_elements(self, content: str) -> list:
"""解析 Satori 消息元素"""
elements = []
Expand All @@ -341,12 +536,30 @@ async def parse_satori_elements(self, content: str) -> list:
return elements

try:
wrapped_content = f"<root>{content}</root>"
root = ET.fromstring(wrapped_content)
# 处理命名空间前缀问题
processed_content = content
if ':' in content and not content.startswith('<root'):
prefixes = self._extract_namespace_prefixes(content)

# 构建命名空间声明
ns_declarations = ' '.join([f'xmlns:{prefix}="http://temp.uri/{prefix}"' for prefix in prefixes])

# 包装内容
processed_content = f"<root {ns_declarations}>{content}</root>"
elif not content.startswith('<root'):
processed_content = f"<root>{content}</root>"
else:
processed_content = content

root = ET.fromstring(processed_content)
await self._parse_xml_node(root, elements)
except ET.ParseError as e:
raise ValueError(f"解析 Satori 元素时发生解析错误: {e}")
logger.error(f"解析 Satori 元素时发生解析错误: {e}, 错误内容: {content}")
# 如果解析失败,将整个内容当作纯文本
if content.strip():
elements.append(Plain(text=content))
except Exception as e:
logger.error(f"解析 Satori 元素时发生未知错误: {e}")
raise e

# 如果没有解析到任何元素,将整个内容当作纯文本
Expand All @@ -361,9 +574,14 @@ async def _parse_xml_node(self, node: ET.Element, elements: list) -> None:
elements.append(Plain(text=node.text))

for child in node:
tag_name = child.tag.lower()
# 获取标签名,去除命名空间前缀
tag_name = child.tag
if '}' in tag_name:
tag_name = tag_name.split('}')[1]
tag_name = tag_name.lower()

attrs = child.attrib

if tag_name == "at":
user_id = attrs.get("id") or attrs.get("name", "")
elements.append(At(qq=user_id, name=user_id))
Expand All @@ -372,31 +590,58 @@ async def _parse_xml_node(self, node: ET.Element, elements: list) -> None:
src = attrs.get("src", "")
if not src:
continue
if src.startswith("data:image/"):
src = src.split(",")[1]
elements.append(Image.fromBase64(src))
elif src.startswith("http"):
elements.append(Image.fromURL(src))
else:
logger.error(f"未知的图片 src 格式: {str(src)[:16]}")
elements.append(Image(file=src))

elif tag_name == "file":
src = attrs.get("src", "")
name = attrs.get("name", "文件")
if src:
elements.append(File(file=src, name=name))
elements.append(File(name=name, file=src))

elif tag_name in ("audio", "record"):
src = attrs.get("src", "")
if not src:
continue
if src.startswith("data:audio/"):
src = src.split(",")[1]
elements.append(Record.fromBase64(src))
elif src.startswith("http"):
elements.append(Record.fromURL(src))
elements.append(Record(file=src))

elif tag_name == "quote":
# quote标签已经被特殊处理
pass

elif tag_name == "face":
face_id = attrs.get("id", "")
face_name = attrs.get("name", "")
face_platform = attrs.get("platform", "")
face_type = attrs.get("type", "")

if face_name:
elements.append(Plain(text=f"[表情:{face_name}]"))
elif face_id and face_type:
elements.append(Plain(text=f"[表情ID:{face_id},类型:{face_type}]"))
elif face_id:
elements.append(Plain(text=f"[表情ID:{face_id}]"))
else:
elements.append(Plain(text="[表情]"))

elif tag_name == "ark":
# 作为纯文本添加到消息链中
data = attrs.get("data", "")
if data:
import html
decoded_data = html.unescape(data)
elements.append(Plain(text=f"[ARK卡片数据: {decoded_data}]"))
else:
elements.append(Plain(text="[ARK卡片]"))

elif tag_name == "json":
# JSON标签 视为ARK卡片消息
data = attrs.get("data", "")
if data:
import html
decoded_data = html.unescape(data)
elements.append(Plain(text=f"[ARK卡片数据: {decoded_data}]"))
else:
logger.error(f"未知的音频 src 格式: {str(src)[:16]}")
elements.append(Plain(text="[JSON卡片]"))

else:
# 未知标签,递归处理其内容
Expand Down
Loading