From fa8e202179fd7b8b3ef53d88e65837bc4849fbb4 Mon Sep 17 00:00:00 2001 From: RC-CHN <1051989940@qq.com> Date: Thu, 18 Dec 2025 11:10:12 +0800 Subject: [PATCH 01/17] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E8=BF=81=E7=A7=BB=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/backup/__init__.py | 9 + astrbot/core/backup/exporter.py | 426 +++++++++++++++ astrbot/core/backup/importer.py | 493 ++++++++++++++++++ astrbot/dashboard/routes/__init__.py | 2 + astrbot/dashboard/routes/backup.py | 453 ++++++++++++++++ astrbot/dashboard/server.py | 2 + .../src/components/shared/BackupDialog.vue | 490 +++++++++++++++++ .../i18n/locales/en-US/features/settings.json | 48 +- .../i18n/locales/zh-CN/features/settings.json | 48 +- dashboard/src/views/Settings.vue | 16 + 10 files changed, 1985 insertions(+), 2 deletions(-) create mode 100644 astrbot/core/backup/__init__.py create mode 100644 astrbot/core/backup/exporter.py create mode 100644 astrbot/core/backup/importer.py create mode 100644 astrbot/dashboard/routes/backup.py create mode 100644 dashboard/src/components/shared/BackupDialog.vue diff --git a/astrbot/core/backup/__init__.py b/astrbot/core/backup/__init__.py new file mode 100644 index 000000000..114bc0b22 --- /dev/null +++ b/astrbot/core/backup/__init__.py @@ -0,0 +1,9 @@ +"""AstrBot 备份与恢复模块 + +提供数据导出和导入功能,支持用户在服务器迁移时一键备份和恢复所有数据。 +""" + +from .exporter import AstrBotExporter +from .importer import AstrBotImporter + +__all__ = ["AstrBotExporter", "AstrBotImporter"] diff --git a/astrbot/core/backup/exporter.py b/astrbot/core/backup/exporter.py new file mode 100644 index 000000000..f814bb915 --- /dev/null +++ b/astrbot/core/backup/exporter.py @@ -0,0 +1,426 @@ +"""AstrBot 数据导出器 + +负责将所有数据导出为 ZIP 备份文件。 +导出格式为 JSON,这是数据库无关的方案,支持未来向 MySQL/PostgreSQL 迁移。 +""" + +import hashlib +import json +import os +import zipfile +from datetime import datetime, timezone +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from sqlalchemy import select +from sqlmodel import SQLModel + +from astrbot.core import logger +from astrbot.core.config.default import VERSION +from astrbot.core.db import BaseDatabase +from astrbot.core.db.po import ( + Attachment, + CommandConfig, + CommandConflict, + ConversationV2, + Persona, + PlatformMessageHistory, + PlatformSession, + PlatformStat, + Preference, +) +from astrbot.core.knowledge_base.models import ( + KBDocument, + KBMedia, + KnowledgeBase, +) + +if TYPE_CHECKING: + from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager + + +# 主数据库模型类映射 +MAIN_DB_MODELS: dict[str, type[SQLModel]] = { + "platform_stats": PlatformStat, + "conversations": ConversationV2, + "personas": Persona, + "preferences": Preference, + "platform_message_history": PlatformMessageHistory, + "platform_sessions": PlatformSession, + "attachments": Attachment, + "command_configs": CommandConfig, + "command_conflicts": CommandConflict, +} + +# 知识库元数据模型类映射 +KB_METADATA_MODELS: dict[str, type[SQLModel]] = { + "knowledge_bases": KnowledgeBase, + "kb_documents": KBDocument, + "kb_media": KBMedia, +} + + +class AstrBotExporter: + """AstrBot 数据导出器 + + 导出内容: + - 主数据库所有表(data/data_v4.db) + - 知识库元数据(data/knowledge_base/kb.db) + - 每个知识库的向量文档数据 + - 配置文件(data/cmd_config.json) + - 附件文件 + - 知识库多媒体文件 + """ + + def __init__( + self, + main_db: BaseDatabase, + kb_manager: "KnowledgeBaseManager | None" = None, + config_path: str = "data/cmd_config.json", + attachments_dir: str = "data/attachments", + ): + self.main_db = main_db + self.kb_manager = kb_manager + self.config_path = config_path + self.attachments_dir = attachments_dir + self._checksums: dict[str, str] = {} + + async def export_all( + self, + output_dir: str = "data/backups", + progress_callback: Any | None = None, + ) -> str: + """导出所有数据到 ZIP 文件 + + Args: + output_dir: 输出目录 + progress_callback: 进度回调函数,接收参数 (stage, current, total, message) + + Returns: + str: 生成的 ZIP 文件路径 + """ + # 确保输出目录存在 + Path(output_dir).mkdir(parents=True, exist_ok=True) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + zip_filename = f"astrbot_backup_{timestamp}.zip" + zip_path = os.path.join(output_dir, zip_filename) + + logger.info(f"开始导出备份到 {zip_path}") + + try: + with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf: + # 1. 导出主数据库 + if progress_callback: + await progress_callback("main_db", 0, 100, "正在导出主数据库...") + main_data = await self._export_main_database() + main_db_json = json.dumps( + main_data, ensure_ascii=False, indent=2, default=str + ) + zf.writestr("databases/main_db.json", main_db_json) + self._add_checksum("databases/main_db.json", main_db_json) + if progress_callback: + await progress_callback("main_db", 100, 100, "主数据库导出完成") + + # 2. 导出知识库数据 + kb_meta_data: dict[str, Any] = { + "knowledge_bases": [], + "kb_documents": [], + "kb_media": [], + } + if self.kb_manager: + if progress_callback: + await progress_callback( + "kb_metadata", 0, 100, "正在导出知识库元数据..." + ) + kb_meta_data = await self._export_kb_metadata() + kb_meta_json = json.dumps( + kb_meta_data, ensure_ascii=False, indent=2, default=str + ) + zf.writestr("databases/kb_metadata.json", kb_meta_json) + self._add_checksum("databases/kb_metadata.json", kb_meta_json) + if progress_callback: + await progress_callback( + "kb_metadata", 100, 100, "知识库元数据导出完成" + ) + + # 导出每个知识库的文档数据 + kb_insts = self.kb_manager.kb_insts + total_kbs = len(kb_insts) + for idx, (kb_id, kb_helper) in enumerate(kb_insts.items()): + if progress_callback: + await progress_callback( + "kb_documents", + idx, + total_kbs, + f"正在导出知识库 {kb_helper.kb.kb_name} 的文档数据...", + ) + doc_data = await self._export_kb_documents(kb_helper) + doc_json = json.dumps( + doc_data, ensure_ascii=False, indent=2, default=str + ) + doc_path = f"databases/kb_{kb_id}/documents.json" + zf.writestr(doc_path, doc_json) + self._add_checksum(doc_path, doc_json) + + # 导出 FAISS 索引文件 + await self._export_faiss_index(zf, kb_helper, kb_id) + + # 导出知识库多媒体文件 + await self._export_kb_media_files(zf, kb_helper, kb_id) + + if progress_callback: + await progress_callback( + "kb_documents", total_kbs, total_kbs, "知识库文档导出完成" + ) + + # 3. 导出配置文件 + if progress_callback: + await progress_callback("config", 0, 100, "正在导出配置文件...") + if os.path.exists(self.config_path): + with open(self.config_path, encoding="utf-8") as f: + config_content = f.read() + zf.writestr("config/cmd_config.json", config_content) + self._add_checksum("config/cmd_config.json", config_content) + if progress_callback: + await progress_callback("config", 100, 100, "配置文件导出完成") + + # 4. 导出附件文件 + if progress_callback: + await progress_callback("attachments", 0, 100, "正在导出附件...") + await self._export_attachments(zf, main_data.get("attachments", [])) + if progress_callback: + await progress_callback("attachments", 100, 100, "附件导出完成") + + # 5. 生成 manifest + if progress_callback: + await progress_callback("manifest", 0, 100, "正在生成清单...") + manifest = self._generate_manifest(main_data, kb_meta_data) + manifest_json = json.dumps(manifest, ensure_ascii=False, indent=2) + zf.writestr("manifest.json", manifest_json) + if progress_callback: + await progress_callback("manifest", 100, 100, "清单生成完成") + + logger.info(f"备份导出完成: {zip_path}") + return zip_path + + except Exception as e: + logger.error(f"备份导出失败: {e}") + # 清理失败的文件 + if os.path.exists(zip_path): + os.remove(zip_path) + raise + + async def _export_main_database(self) -> dict[str, list[dict]]: + """导出主数据库所有表""" + export_data: dict[str, list[dict]] = {} + + async with self.main_db.get_db() as session: + for table_name, model_class in MAIN_DB_MODELS.items(): + try: + result = await session.execute(select(model_class)) + records = result.scalars().all() + export_data[table_name] = [ + self._model_to_dict(record) for record in records + ] + logger.debug( + f"导出表 {table_name}: {len(export_data[table_name])} 条记录" + ) + except Exception as e: + logger.warning(f"导出表 {table_name} 失败: {e}") + export_data[table_name] = [] + + return export_data + + async def _export_kb_metadata(self) -> dict[str, list[dict]]: + """导出知识库元数据库""" + if not self.kb_manager: + return {"knowledge_bases": [], "kb_documents": [], "kb_media": []} + + export_data: dict[str, list[dict]] = {} + + async with self.kb_manager.kb_db.get_db() as session: + for table_name, model_class in KB_METADATA_MODELS.items(): + try: + result = await session.execute(select(model_class)) + records = result.scalars().all() + export_data[table_name] = [ + self._model_to_dict(record) for record in records + ] + logger.debug( + f"导出知识库表 {table_name}: {len(export_data[table_name])} 条记录" + ) + except Exception as e: + logger.warning(f"导出知识库表 {table_name} 失败: {e}") + export_data[table_name] = [] + + return export_data + + async def _export_kb_documents(self, kb_helper: Any) -> dict[str, Any]: + """导出知识库的文档块数据""" + try: + from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB + + vec_db: FaissVecDB = kb_helper.vec_db + if not vec_db or not vec_db.document_storage: + return {"documents": []} + + # 获取所有文档 + docs = await vec_db.document_storage.get_documents( + metadata_filters={}, + offset=0, + limit=None, # 获取全部 + ) + + return {"documents": docs} + except Exception as e: + logger.warning(f"导出知识库文档失败: {e}") + return {"documents": []} + + async def _export_faiss_index( + self, + zf: zipfile.ZipFile, + kb_helper: Any, + kb_id: str, + ) -> None: + """导出 FAISS 索引文件""" + try: + index_path = kb_helper.kb_dir / "index.faiss" + if index_path.exists(): + archive_path = f"databases/kb_{kb_id}/index.faiss" + zf.write(str(index_path), archive_path) + logger.debug(f"导出 FAISS 索引: {archive_path}") + except Exception as e: + logger.warning(f"导出 FAISS 索引失败: {e}") + + async def _export_kb_media_files( + self, zf: zipfile.ZipFile, kb_helper: Any, kb_id: str + ) -> None: + """导出知识库的多媒体文件""" + try: + media_dir = kb_helper.kb_medias_dir + if not media_dir.exists(): + return + + for root, _, files in os.walk(media_dir): + for file in files: + file_path = Path(root) / file + # 计算相对路径 + rel_path = file_path.relative_to(kb_helper.kb_dir) + archive_path = f"files/kb_media/{kb_id}/{rel_path}" + zf.write(str(file_path), archive_path) + except Exception as e: + logger.warning(f"导出知识库媒体文件失败: {e}") + + async def _export_attachments( + self, zf: zipfile.ZipFile, attachments: list[dict] + ) -> None: + """导出附件文件""" + for attachment in attachments: + try: + file_path = attachment.get("path", "") + if file_path and os.path.exists(file_path): + # 使用 attachment_id 作为文件名 + attachment_id = attachment.get("attachment_id", "") + ext = os.path.splitext(file_path)[1] + archive_path = f"files/attachments/{attachment_id}{ext}" + zf.write(file_path, archive_path) + except Exception as e: + logger.warning(f"导出附件失败: {e}") + + def _model_to_dict(self, record: Any) -> dict: + """将 SQLModel 实例转换为字典 + + 这是数据库无关的序列化方式,支持未来迁移到其他数据库。 + """ + # 使用 SQLModel 内置的 model_dump 方法(如果可用) + if hasattr(record, "model_dump"): + data = record.model_dump(mode="python") + # 处理 datetime 类型 + for key, value in data.items(): + if isinstance(value, datetime): + data[key] = value.isoformat() + return data + + # 回退到手动提取 + data = {} + # 使用 inspect 获取表信息 + from sqlalchemy import inspect as sa_inspect + + mapper = sa_inspect(record.__class__) + for column in mapper.columns: + value = getattr(record, column.name) + # 处理 datetime 类型 - 统一转为 ISO 格式字符串 + if isinstance(value, datetime): + value = value.isoformat() + data[column.name] = value + return data + + def _add_checksum(self, path: str, content: str | bytes) -> None: + """计算并添加文件校验和""" + if isinstance(content, str): + content = content.encode("utf-8") + checksum = hashlib.sha256(content).hexdigest() + self._checksums[path] = f"sha256:{checksum}" + + def _generate_manifest( + self, main_data: dict[str, list[dict]], kb_meta_data: dict[str, list[dict]] + ) -> dict: + """生成备份清单""" + # 收集知识库 ID + kb_document_tables = {} + if self.kb_manager: + for kb_id in self.kb_manager.kb_insts.keys(): + kb_document_tables[kb_id] = "documents" + + # 收集附件文件列表 + attachment_files = [] + for attachment in main_data.get("attachments", []): + attachment_id = attachment.get("attachment_id", "") + path = attachment.get("path", "") + if attachment_id and path: + ext = os.path.splitext(path)[1] + attachment_files.append(f"{attachment_id}{ext}") + + # 收集知识库媒体文件 + kb_media_files: dict[str, list[str]] = {} + if self.kb_manager: + for kb_id, kb_helper in self.kb_manager.kb_insts.items(): + media_files: list[str] = [] + media_dir = kb_helper.kb_medias_dir + if media_dir.exists(): + for root, _, files in os.walk(media_dir): + for file in files: + media_files.append(file) + if media_files: + kb_media_files[kb_id] = media_files + + manifest = { + "version": "1.0", + "astrbot_version": VERSION, + "exported_at": datetime.now(timezone.utc).isoformat(), + "schema_version": { + "main_db": "v4", + "kb_db": "v1", + }, + "tables": { + "main_db": list(main_data.keys()), + "kb_metadata": list(kb_meta_data.keys()), + "kb_documents": kb_document_tables, + }, + "files": { + "attachments": attachment_files, + "kb_media": kb_media_files, + }, + "checksums": self._checksums, + "statistics": { + "main_db": { + table: len(records) for table, records in main_data.items() + }, + "kb_metadata": { + table: len(records) for table, records in kb_meta_data.items() + }, + }, + } + + return manifest diff --git a/astrbot/core/backup/importer.py b/astrbot/core/backup/importer.py new file mode 100644 index 000000000..3f9b9958b --- /dev/null +++ b/astrbot/core/backup/importer.py @@ -0,0 +1,493 @@ +"""AstrBot 数据导入器 + +负责从 ZIP 备份文件恢复所有数据。 +导入时进行严格的版本校验,仅允许相同版本的 AstrBot 进行导入。 +""" + +import json +import os +import shutil +import zipfile +from datetime import datetime +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from sqlalchemy import delete +from sqlmodel import SQLModel + +from astrbot.core import logger +from astrbot.core.config.default import VERSION +from astrbot.core.db import BaseDatabase +from astrbot.core.db.po import ( + Attachment, + CommandConfig, + CommandConflict, + ConversationV2, + Persona, + PlatformMessageHistory, + PlatformSession, + PlatformStat, + Preference, +) +from astrbot.core.knowledge_base.models import ( + KBDocument, + KBMedia, + KnowledgeBase, +) + +if TYPE_CHECKING: + from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager + + +# 主数据库模型类映射 +MAIN_DB_MODELS: dict[str, type[SQLModel]] = { + "platform_stats": PlatformStat, + "conversations": ConversationV2, + "personas": Persona, + "preferences": Preference, + "platform_message_history": PlatformMessageHistory, + "platform_sessions": PlatformSession, + "attachments": Attachment, + "command_configs": CommandConfig, + "command_conflicts": CommandConflict, +} + +# 知识库元数据模型类映射 +KB_METADATA_MODELS: dict[str, type[SQLModel]] = { + "knowledge_bases": KnowledgeBase, + "kb_documents": KBDocument, + "kb_media": KBMedia, +} + + +class ImportResult: + """导入结果""" + + def __init__(self): + self.success = True + self.imported_tables: dict[str, int] = {} + self.imported_files: dict[str, int] = {} + self.warnings: list[str] = [] + self.errors: list[str] = [] + + def add_warning(self, msg: str) -> None: + self.warnings.append(msg) + logger.warning(msg) + + def add_error(self, msg: str) -> None: + self.errors.append(msg) + self.success = False + logger.error(msg) + + def to_dict(self) -> dict: + return { + "success": self.success, + "imported_tables": self.imported_tables, + "imported_files": self.imported_files, + "warnings": self.warnings, + "errors": self.errors, + } + + +class AstrBotImporter: + """AstrBot 数据导入器 + + 导入备份文件中的所有数据,包括: + - 主数据库所有表 + - 知识库元数据和文档 + - 配置文件 + - 附件文件 + - 知识库多媒体文件 + """ + + def __init__( + self, + main_db: BaseDatabase, + kb_manager: "KnowledgeBaseManager | None" = None, + config_path: str = "data/cmd_config.json", + attachments_dir: str = "data/attachments", + kb_root_dir: str = "data/knowledge_base", + ): + self.main_db = main_db + self.kb_manager = kb_manager + self.config_path = config_path + self.attachments_dir = attachments_dir + self.kb_root_dir = kb_root_dir + + async def import_all( + self, + zip_path: str, + mode: str = "replace", # "replace" 清空后导入 + progress_callback: Any | None = None, + ) -> ImportResult: + """从 ZIP 文件导入所有数据 + + Args: + zip_path: ZIP 备份文件路径 + mode: 导入模式,目前仅支持 "replace"(清空后导入) + progress_callback: 进度回调函数,接收参数 (stage, current, total, message) + + Returns: + ImportResult: 导入结果 + """ + result = ImportResult() + + if not os.path.exists(zip_path): + result.add_error(f"备份文件不存在: {zip_path}") + return result + + logger.info(f"开始从 {zip_path} 导入备份") + + try: + with zipfile.ZipFile(zip_path, "r") as zf: + # 1. 读取并验证 manifest + if progress_callback: + await progress_callback("validate", 0, 100, "正在验证备份文件...") + + try: + manifest_data = zf.read("manifest.json") + manifest = json.loads(manifest_data) + except KeyError: + result.add_error("备份文件缺少 manifest.json") + return result + except json.JSONDecodeError as e: + result.add_error(f"manifest.json 格式错误: {e}") + return result + + # 版本校验 + try: + self._validate_version(manifest) + except ValueError as e: + result.add_error(str(e)) + return result + + if progress_callback: + await progress_callback("validate", 100, 100, "验证完成") + + # 2. 导入主数据库 + if progress_callback: + await progress_callback("main_db", 0, 100, "正在导入主数据库...") + + try: + main_data_content = zf.read("databases/main_db.json") + main_data = json.loads(main_data_content) + + if mode == "replace": + await self._clear_main_db() + + imported = await self._import_main_database(main_data) + result.imported_tables.update(imported) + except Exception as e: + result.add_error(f"导入主数据库失败: {e}") + return result + + if progress_callback: + await progress_callback("main_db", 100, 100, "主数据库导入完成") + + # 3. 导入知识库 + if self.kb_manager and "databases/kb_metadata.json" in zf.namelist(): + if progress_callback: + await progress_callback("kb", 0, 100, "正在导入知识库...") + + try: + kb_meta_content = zf.read("databases/kb_metadata.json") + kb_meta_data = json.loads(kb_meta_content) + + if mode == "replace": + await self._clear_kb_data() + + await self._import_knowledge_bases(zf, kb_meta_data, result) + except Exception as e: + result.add_warning(f"导入知识库失败: {e}") + + if progress_callback: + await progress_callback("kb", 100, 100, "知识库导入完成") + + # 4. 导入配置文件 + if progress_callback: + await progress_callback("config", 0, 100, "正在导入配置文件...") + + if "config/cmd_config.json" in zf.namelist(): + try: + config_content = zf.read("config/cmd_config.json") + # 备份现有配置 + if os.path.exists(self.config_path): + backup_path = f"{self.config_path}.bak" + shutil.copy2(self.config_path, backup_path) + + with open(self.config_path, "wb") as f: + f.write(config_content) + result.imported_files["config"] = 1 + except Exception as e: + result.add_warning(f"导入配置文件失败: {e}") + + if progress_callback: + await progress_callback("config", 100, 100, "配置文件导入完成") + + # 5. 导入附件文件 + if progress_callback: + await progress_callback("attachments", 0, 100, "正在导入附件...") + + attachment_count = await self._import_attachments( + zf, main_data.get("attachments", []) + ) + result.imported_files["attachments"] = attachment_count + + if progress_callback: + await progress_callback("attachments", 100, 100, "附件导入完成") + + logger.info(f"备份导入完成: {result.to_dict()}") + return result + + except zipfile.BadZipFile: + result.add_error("无效的 ZIP 文件") + return result + except Exception as e: + result.add_error(f"导入失败: {e}") + return result + + def _validate_version(self, manifest: dict) -> None: + """验证版本兼容性 - 仅允许相同版本导入""" + backup_version = manifest.get("astrbot_version") + if not backup_version: + raise ValueError("备份文件缺少版本信息") + + if backup_version != VERSION: + raise ValueError( + f"版本不匹配: 备份版本 {backup_version}, 当前版本 {VERSION}。" + f"请使用相同版本的 AstrBot 进行导入。" + ) + + async def _clear_main_db(self) -> None: + """清空主数据库所有表""" + async with self.main_db.get_db() as session: + async with session.begin(): + for table_name, model_class in MAIN_DB_MODELS.items(): + try: + await session.execute(delete(model_class)) + logger.debug(f"已清空表 {table_name}") + except Exception as e: + logger.warning(f"清空表 {table_name} 失败: {e}") + + async def _clear_kb_data(self) -> None: + """清空知识库数据""" + if not self.kb_manager: + return + + # 清空知识库元数据表 + async with self.kb_manager.kb_db.get_db() as session: + async with session.begin(): + for table_name, model_class in KB_METADATA_MODELS.items(): + try: + await session.execute(delete(model_class)) + logger.debug(f"已清空知识库表 {table_name}") + except Exception as e: + logger.warning(f"清空知识库表 {table_name} 失败: {e}") + + # 删除知识库文件目录 + for kb_id in list(self.kb_manager.kb_insts.keys()): + try: + kb_helper = self.kb_manager.kb_insts[kb_id] + await kb_helper.terminate() + if kb_helper.kb_dir.exists(): + shutil.rmtree(kb_helper.kb_dir) + except Exception as e: + logger.warning(f"清理知识库 {kb_id} 失败: {e}") + + self.kb_manager.kb_insts.clear() + + async def _import_main_database( + self, data: dict[str, list[dict]] + ) -> dict[str, int]: + """导入主数据库数据""" + imported: dict[str, int] = {} + + async with self.main_db.get_db() as session: + async with session.begin(): + for table_name, rows in data.items(): + model_class = MAIN_DB_MODELS.get(table_name) + if not model_class: + logger.warning(f"未知的表: {table_name}") + continue + + count = 0 + for row in rows: + try: + # 转换 datetime 字符串为 datetime 对象 + row = self._convert_datetime_fields(row, model_class) + obj = model_class(**row) + session.add(obj) + count += 1 + except Exception as e: + logger.warning(f"导入记录到 {table_name} 失败: {e}") + + imported[table_name] = count + logger.debug(f"导入表 {table_name}: {count} 条记录") + + return imported + + async def _import_knowledge_bases( + self, + zf: zipfile.ZipFile, + kb_meta_data: dict[str, list[dict]], + result: ImportResult, + ) -> None: + """导入知识库数据""" + if not self.kb_manager: + return + + # 1. 导入知识库元数据 + async with self.kb_manager.kb_db.get_db() as session: + async with session.begin(): + for table_name, rows in kb_meta_data.items(): + model_class = KB_METADATA_MODELS.get(table_name) + if not model_class: + continue + + count = 0 + for row in rows: + try: + row = self._convert_datetime_fields(row, model_class) + obj = model_class(**row) + session.add(obj) + count += 1 + except Exception as e: + logger.warning(f"导入知识库记录到 {table_name} 失败: {e}") + + result.imported_tables[f"kb_{table_name}"] = count + + # 2. 导入每个知识库的文档和文件 + for kb_data in kb_meta_data.get("knowledge_bases", []): + kb_id = kb_data.get("kb_id") + if not kb_id: + continue + + # 创建知识库目录 + kb_dir = Path(self.kb_root_dir) / kb_id + kb_dir.mkdir(parents=True, exist_ok=True) + + # 导入文档数据 + doc_path = f"databases/kb_{kb_id}/documents.json" + if doc_path in zf.namelist(): + try: + doc_content = zf.read(doc_path) + doc_data = json.loads(doc_content) + + # 导入到文档存储数据库 + await self._import_kb_documents(kb_id, doc_data) + except Exception as e: + result.add_warning(f"导入知识库 {kb_id} 的文档失败: {e}") + + # 导入 FAISS 索引 + faiss_path = f"databases/kb_{kb_id}/index.faiss" + if faiss_path in zf.namelist(): + try: + target_path = kb_dir / "index.faiss" + with zf.open(faiss_path) as src, open(target_path, "wb") as dst: + dst.write(src.read()) + except Exception as e: + result.add_warning(f"导入知识库 {kb_id} 的 FAISS 索引失败: {e}") + + # 导入媒体文件 + media_prefix = f"files/kb_media/{kb_id}/" + for name in zf.namelist(): + if name.startswith(media_prefix): + try: + rel_path = name[len(media_prefix) :] + target_path = kb_dir / rel_path + target_path.parent.mkdir(parents=True, exist_ok=True) + with zf.open(name) as src, open(target_path, "wb") as dst: + dst.write(src.read()) + except Exception as e: + result.add_warning(f"导入媒体文件 {name} 失败: {e}") + + # 3. 重新加载知识库实例 + await self.kb_manager.load_kbs() + + async def _import_kb_documents(self, kb_id: str, doc_data: dict) -> None: + """导入知识库文档到向量数据库""" + from astrbot.core.db.vec_db.faiss_impl.document_storage import DocumentStorage + + kb_dir = Path(self.kb_root_dir) / kb_id + doc_db_path = kb_dir / "doc.db" + + # 初始化文档存储 + doc_storage = DocumentStorage(str(doc_db_path)) + await doc_storage.initialize() + + try: + documents = doc_data.get("documents", []) + for doc in documents: + try: + await doc_storage.insert_document( + doc_id=doc.get("doc_id", ""), + text=doc.get("text", ""), + metadata=json.loads(doc.get("metadata", "{}")), + ) + except Exception as e: + logger.warning(f"导入文档块失败: {e}") + finally: + await doc_storage.close() + + async def _import_attachments( + self, + zf: zipfile.ZipFile, + attachments: list[dict], + ) -> int: + """导入附件文件""" + count = 0 + + # 确保附件目录存在 + Path(self.attachments_dir).mkdir(parents=True, exist_ok=True) + + attachment_prefix = "files/attachments/" + for name in zf.namelist(): + if name.startswith(attachment_prefix) and name != attachment_prefix: + try: + # 从附件记录中找到原始路径 + attachment_id = os.path.splitext(os.path.basename(name))[0] + original_path = None + for att in attachments: + if att.get("attachment_id") == attachment_id: + original_path = att.get("path") + break + + if original_path: + target_path = Path(original_path) + else: + target_path = Path(self.attachments_dir) / os.path.basename( + name + ) + + target_path.parent.mkdir(parents=True, exist_ok=True) + with zf.open(name) as src, open(target_path, "wb") as dst: + dst.write(src.read()) + count += 1 + except Exception as e: + logger.warning(f"导入附件 {name} 失败: {e}") + + return count + + def _convert_datetime_fields(self, row: dict, model_class: type) -> dict: + """转换 datetime 字符串字段为 datetime 对象""" + result = row.copy() + + # 获取模型的 datetime 字段 + from sqlalchemy import inspect as sa_inspect + + try: + mapper = sa_inspect(model_class) + for column in mapper.columns: + if column.name in result and result[column.name] is not None: + # 检查是否是 datetime 类型的列 + from sqlalchemy import DateTime + + if isinstance(column.type, DateTime): + value = result[column.name] + if isinstance(value, str): + # 解析 ISO 格式的日期时间字符串 + result[column.name] = datetime.fromisoformat(value) + except Exception: + pass + + return result diff --git a/astrbot/dashboard/routes/__init__.py b/astrbot/dashboard/routes/__init__.py index 951db956c..bca1a2268 100644 --- a/astrbot/dashboard/routes/__init__.py +++ b/astrbot/dashboard/routes/__init__.py @@ -1,4 +1,5 @@ from .auth import AuthRoute +from .backup import BackupRoute from .chat import ChatRoute from .command import CommandRoute from .config import ConfigRoute @@ -17,6 +18,7 @@ __all__ = [ "AuthRoute", + "BackupRoute", "ChatRoute", "CommandRoute", "ConfigRoute", diff --git a/astrbot/dashboard/routes/backup.py b/astrbot/dashboard/routes/backup.py new file mode 100644 index 000000000..4856cf06b --- /dev/null +++ b/astrbot/dashboard/routes/backup.py @@ -0,0 +1,453 @@ +"""备份管理 API 路由""" + +import asyncio +import os +import traceback +import uuid +from pathlib import Path + +from quart import request, send_file + +from astrbot.core import logger +from astrbot.core.backup.exporter import AstrBotExporter +from astrbot.core.backup.importer import AstrBotImporter +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.db import BaseDatabase + +from .route import Response, Route, RouteContext + + +class BackupRoute(Route): + """备份管理路由 + + 提供备份导出、导入、列表等 API 接口 + """ + + def __init__( + self, + context: RouteContext, + db: BaseDatabase, + core_lifecycle: AstrBotCoreLifecycle, + ) -> None: + super().__init__(context) + self.db = db + self.core_lifecycle = core_lifecycle + self.backup_dir = "data/backups" + + # 任务状态跟踪 + self.backup_tasks: dict[str, dict] = {} + self.backup_progress: dict[str, dict] = {} + + # 注册路由 + self.routes = { + "/backup/list": ("GET", self.list_backups), + "/backup/export": ("POST", self.export_backup), + "/backup/import": ("POST", self.import_backup), + "/backup/progress": ("GET", self.get_progress), + "/backup/download": ("GET", self.download_backup), + "/backup/delete": ("POST", self.delete_backup), + } + self.register_routes() + + def _init_task(self, task_id: str, task_type: str, status: str = "pending") -> None: + """初始化任务状态""" + self.backup_tasks[task_id] = { + "type": task_type, + "status": status, + "result": None, + "error": None, + } + self.backup_progress[task_id] = { + "status": status, + "stage": "waiting", + "current": 0, + "total": 100, + "message": "", + } + + def _set_task_result( + self, + task_id: str, + status: str, + result: dict | None = None, + error: str | None = None, + ) -> None: + """设置任务结果""" + if task_id in self.backup_tasks: + self.backup_tasks[task_id]["status"] = status + self.backup_tasks[task_id]["result"] = result + self.backup_tasks[task_id]["error"] = error + if task_id in self.backup_progress: + self.backup_progress[task_id]["status"] = status + + def _update_progress( + self, + task_id: str, + *, + status: str | None = None, + stage: str | None = None, + current: int | None = None, + total: int | None = None, + message: str | None = None, + ) -> None: + """更新任务进度""" + if task_id not in self.backup_progress: + return + p = self.backup_progress[task_id] + if status is not None: + p["status"] = status + if stage is not None: + p["stage"] = stage + if current is not None: + p["current"] = current + if total is not None: + p["total"] = total + if message is not None: + p["message"] = message + + def _make_progress_callback(self, task_id: str): + """创建进度回调函数""" + + async def _callback(stage: str, current: int, total: int, message: str = ""): + self._update_progress( + task_id, + status="processing", + stage=stage, + current=current, + total=total, + message=message, + ) + + return _callback + + async def list_backups(self): + """获取备份列表 + + Query 参数: + - page: 页码 (默认 1) + - page_size: 每页数量 (默认 20) + """ + try: + page = request.args.get("page", 1, type=int) + page_size = request.args.get("page_size", 20, type=int) + + # 确保备份目录存在 + Path(self.backup_dir).mkdir(parents=True, exist_ok=True) + + # 获取所有备份文件 + backup_files = [] + for filename in os.listdir(self.backup_dir): + if filename.endswith(".zip") and filename.startswith("astrbot_backup_"): + file_path = os.path.join(self.backup_dir, filename) + stat = os.stat(file_path) + backup_files.append( + { + "filename": filename, + "size": stat.st_size, + "created_at": stat.st_mtime, + } + ) + + # 按创建时间倒序排序 + backup_files.sort(key=lambda x: x["created_at"], reverse=True) + + # 分页 + start = (page - 1) * page_size + end = start + page_size + items = backup_files[start:end] + + return ( + Response() + .ok( + { + "items": items, + "total": len(backup_files), + "page": page, + "page_size": page_size, + } + ) + .__dict__ + ) + except Exception as e: + logger.error(f"获取备份列表失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"获取备份列表失败: {e!s}").__dict__ + + async def export_backup(self): + """创建备份 + + 返回: + - task_id: 任务ID,用于查询导出进度 + """ + try: + # 生成任务ID + task_id = str(uuid.uuid4()) + + # 初始化任务状态 + self._init_task(task_id, "export", "pending") + + # 启动后台导出任务 + asyncio.create_task(self._background_export_task(task_id)) + + return ( + Response() + .ok( + { + "task_id": task_id, + "message": "export task created, processing in background", + } + ) + .__dict__ + ) + except Exception as e: + logger.error(f"创建备份失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"创建备份失败: {e!s}").__dict__ + + async def _background_export_task(self, task_id: str): + """后台导出任务""" + try: + self._update_progress(task_id, status="processing", message="正在初始化...") + + # 获取知识库管理器 + kb_manager = getattr(self.core_lifecycle, "kb_manager", None) + + exporter = AstrBotExporter( + main_db=self.db, + kb_manager=kb_manager, + config_path="data/cmd_config.json", + attachments_dir="data/attachments", + ) + + # 创建进度回调 + progress_callback = self._make_progress_callback(task_id) + + # 执行导出 + zip_path = await exporter.export_all( + output_dir=self.backup_dir, + progress_callback=progress_callback, + ) + + # 设置成功结果 + self._set_task_result( + task_id, + "completed", + result={ + "filename": os.path.basename(zip_path), + "path": zip_path, + "size": os.path.getsize(zip_path), + }, + ) + except Exception as e: + logger.error(f"后台导出任务 {task_id} 失败: {e}") + logger.error(traceback.format_exc()) + self._set_task_result(task_id, "failed", error=str(e)) + + async def import_backup(self): + """导入备份 + + 支持两种方式: + 1. multipart/form-data 文件上传 + 2. JSON 指定已存在的备份文件名 + + Form Data: + - file: 备份文件 (可选) + + JSON Body: + - filename: 已存在的备份文件名 (可选) + + 返回: + - task_id: 任务ID,用于查询导入进度 + """ + try: + zip_path = None + content_type = request.content_type or "" + + if "multipart/form-data" in content_type: + # 文件上传模式 + files = await request.files + if "file" not in files: + return Response().error("缺少备份文件").__dict__ + + file = files["file"] + if not file.filename or not file.filename.endswith(".zip"): + return Response().error("请上传 ZIP 格式的备份文件").__dict__ + + # 保存上传的文件 + Path(self.backup_dir).mkdir(parents=True, exist_ok=True) + zip_path = os.path.join(self.backup_dir, file.filename) + await file.save(zip_path) + else: + # JSON 模式 - 使用已存在的文件 + data = await request.json + filename = data.get("filename") + if not filename: + return Response().error("缺少 filename 参数").__dict__ + + zip_path = os.path.join(self.backup_dir, filename) + if not os.path.exists(zip_path): + return Response().error(f"备份文件不存在: {filename}").__dict__ + + # 生成任务ID + task_id = str(uuid.uuid4()) + + # 初始化任务状态 + self._init_task(task_id, "import", "pending") + + # 启动后台导入任务 + asyncio.create_task(self._background_import_task(task_id, zip_path)) + + return ( + Response() + .ok( + { + "task_id": task_id, + "message": "import task created, processing in background", + } + ) + .__dict__ + ) + except Exception as e: + logger.error(f"导入备份失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"导入备份失败: {e!s}").__dict__ + + async def _background_import_task(self, task_id: str, zip_path: str): + """后台导入任务""" + try: + self._update_progress(task_id, status="processing", message="正在初始化...") + + # 获取知识库管理器 + kb_manager = getattr(self.core_lifecycle, "kb_manager", None) + + importer = AstrBotImporter( + main_db=self.db, + kb_manager=kb_manager, + config_path="data/cmd_config.json", + attachments_dir="data/attachments", + ) + + # 创建进度回调 + progress_callback = self._make_progress_callback(task_id) + + # 执行导入 + result = await importer.import_all( + zip_path=zip_path, + mode="replace", + progress_callback=progress_callback, + ) + + # 设置结果 + if result.success: + self._set_task_result( + task_id, + "completed", + result=result.to_dict(), + ) + else: + self._set_task_result( + task_id, + "failed", + error="; ".join(result.errors), + ) + except Exception as e: + logger.error(f"后台导入任务 {task_id} 失败: {e}") + logger.error(traceback.format_exc()) + self._set_task_result(task_id, "failed", error=str(e)) + + async def get_progress(self): + """获取任务进度 + + Query 参数: + - task_id: 任务 ID (必填) + """ + try: + task_id = request.args.get("task_id") + if not task_id: + return Response().error("缺少参数 task_id").__dict__ + + if task_id not in self.backup_tasks: + return Response().error("找不到该任务").__dict__ + + task_info = self.backup_tasks[task_id] + status = task_info["status"] + + response_data = { + "task_id": task_id, + "type": task_info["type"], + "status": status, + } + + # 如果任务正在处理,返回进度信息 + if status == "processing" and task_id in self.backup_progress: + response_data["progress"] = self.backup_progress[task_id] + + # 如果任务完成,返回结果 + if status == "completed": + response_data["result"] = task_info["result"] + + # 如果任务失败,返回错误信息 + if status == "failed": + response_data["error"] = task_info["error"] + + return Response().ok(response_data).__dict__ + except Exception as e: + logger.error(f"获取任务进度失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"获取任务进度失败: {e!s}").__dict__ + + async def download_backup(self): + """下载备份文件 + + Query 参数: + - filename: 备份文件名 (必填) + """ + try: + filename = request.args.get("filename") + if not filename: + return Response().error("缺少参数 filename").__dict__ + + # 安全检查 - 防止路径遍历 + if ".." in filename or "/" in filename or "\\" in filename: + return Response().error("无效的文件名").__dict__ + + file_path = os.path.join(self.backup_dir, filename) + if not os.path.exists(file_path): + return Response().error("备份文件不存在").__dict__ + + return await send_file( + file_path, + as_attachment=True, + attachment_filename=filename, + ) + except Exception as e: + logger.error(f"下载备份失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"下载备份失败: {e!s}").__dict__ + + async def delete_backup(self): + """删除备份文件 + + Body: + - filename: 备份文件名 (必填) + """ + try: + data = await request.json + filename = data.get("filename") + if not filename: + return Response().error("缺少参数 filename").__dict__ + + # 安全检查 - 防止路径遍历 + if ".." in filename or "/" in filename or "\\" in filename: + return Response().error("无效的文件名").__dict__ + + file_path = os.path.join(self.backup_dir, filename) + if not os.path.exists(file_path): + return Response().error("备份文件不存在").__dict__ + + os.remove(file_path) + return Response().ok(message="删除备份成功").__dict__ + except Exception as e: + logger.error(f"删除备份失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"删除备份失败: {e!s}").__dict__ diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 6d6530c90..ad258b824 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -19,6 +19,7 @@ from astrbot.core.utils.io import get_local_ip_addresses from .routes import * +from .routes.backup import BackupRoute from .routes.platform import PlatformRoute from .routes.route import Response, RouteContext from .routes.session_management import SessionManagementRoute @@ -85,6 +86,7 @@ def __init__( self.t2i_route = T2iRoute(self.context, core_lifecycle) self.kb_route = KnowledgeBaseRoute(self.context, core_lifecycle) self.platform_route = PlatformRoute(self.context, core_lifecycle) + self.backup_route = BackupRoute(self.context, db, core_lifecycle) self.app.add_url_rule( "/api/plug/", diff --git a/dashboard/src/components/shared/BackupDialog.vue b/dashboard/src/components/shared/BackupDialog.vue new file mode 100644 index 000000000..103b6d30b --- /dev/null +++ b/dashboard/src/components/shared/BackupDialog.vue @@ -0,0 +1,490 @@ + + + + + \ No newline at end of file diff --git a/dashboard/src/i18n/locales/en-US/features/settings.json b/dashboard/src/i18n/locales/en-US/features/settings.json index 0a494ca3e..ca35ed367 100644 --- a/dashboard/src/i18n/locales/en-US/features/settings.json +++ b/dashboard/src/i18n/locales/en-US/features/settings.json @@ -18,6 +18,11 @@ "title": "Data Migration to v4.0.0", "subtitle": "If you encounter data compatibility issues, you can manually start the database migration assistant", "button": "Start Migration Assistant" + }, + "backup": { + "title": "Backup & Restore", + "subtitle": "Export or import all AstrBot data for easy migration to a new server", + "button": "Backup Manager" } }, "sidebar": { @@ -29,5 +34,46 @@ "mainItems": "Main Modules", "moreItems": "More Features" } + }, + "backup": { + "dialog": { + "title": "Backup Manager" + }, + "tabs": { + "export": "Export Backup", + "import": "Import Backup", + "list": "Backup List" + }, + "export": { + "title": "Create Backup", + "description": "Export all data as a ZIP backup file, including database, knowledge base, config and attachments.", + "includes": "Backup includes: Main database, Knowledge bases (metadata + vector index + documents), Config files, Attachment files", + "button": "Start Export", + "processing": "Exporting...", + "wait": "Please wait, packaging data...", + "completed": "Export Completed!", + "download": "Download Backup", + "another": "Create New Backup", + "failed": "Export Failed", + "retry": "Retry" + }, + "import": { + "title": "Import Backup", + "warning": "⚠️ Import will clear and overwrite existing data! Please make sure you have backed up your current data. Only backup files from the same AstrBot version are supported.", + "selectFile": "Select backup file (.zip)", + "button": "Start Import", + "processing": "Importing...", + "wait": "Please wait, restoring data...", + "completed": "Import Completed!", + "restartRequired": "Data has been successfully imported. It is recommended to restart AstrBot immediately for all changes to take effect.", + "restartNow": "Restart Now", + "failed": "Import Failed", + "retry": "Retry" + }, + "list": { + "empty": "No backup files", + "refresh": "Refresh List", + "confirmDelete": "Are you sure you want to delete this backup file? This action cannot be undone." + } } -} \ No newline at end of file +} \ No newline at end of file diff --git a/dashboard/src/i18n/locales/zh-CN/features/settings.json b/dashboard/src/i18n/locales/zh-CN/features/settings.json index bb6700f60..9b61e9a89 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/settings.json +++ b/dashboard/src/i18n/locales/zh-CN/features/settings.json @@ -18,6 +18,11 @@ "title": "数据迁移到 v4.0.0 格式", "subtitle": "如果您遇到数据兼容性问题,可以手动启动数据库迁移助手", "button": "启动迁移助手" + }, + "backup": { + "title": "数据备份与恢复", + "subtitle": "导出或导入 AstrBot 的所有数据,方便迁移到新服务器", + "button": "备份管理" } }, "sidebar": { @@ -29,5 +34,46 @@ "mainItems": "主要模块", "moreItems": "更多功能" } + }, + "backup": { + "dialog": { + "title": "备份管理" + }, + "tabs": { + "export": "导出备份", + "import": "导入备份", + "list": "备份列表" + }, + "export": { + "title": "创建备份", + "description": "将所有数据导出为 ZIP 备份文件,包括数据库、知识库、配置和附件。", + "includes": "备份包含:主数据库、知识库(元数据+向量索引+文档)、配置文件、附件文件", + "button": "开始导出", + "processing": "正在导出...", + "wait": "请稍候,正在打包数据...", + "completed": "导出完成!", + "download": "下载备份", + "another": "创建新备份", + "failed": "导出失败", + "retry": "重试" + }, + "import": { + "title": "导入备份", + "warning": "⚠️ 导入将会清空并覆盖现有数据!请确保已备份当前数据。仅支持相同版本的 AstrBot 备份文件。", + "selectFile": "选择备份文件 (.zip)", + "button": "开始导入", + "processing": "正在导入...", + "wait": "请稍候,正在恢复数据...", + "completed": "导入完成!", + "restartRequired": "数据已成功导入。建议立即重启 AstrBot 以使所有更改生效。", + "restartNow": "立即重启", + "failed": "导入失败", + "retry": "重试" + }, + "list": { + "empty": "暂无备份文件", + "refresh": "刷新列表", + "confirmDelete": "确定要删除这个备份文件吗?此操作不可撤销。" + } } -} \ No newline at end of file +} \ No newline at end of file diff --git a/dashboard/src/views/Settings.vue b/dashboard/src/views/Settings.vue index 338d0394d..1c56119ab 100644 --- a/dashboard/src/views/Settings.vue +++ b/dashboard/src/views/Settings.vue @@ -17,6 +17,13 @@ {{ tm('system.title') }} + + + mdi-backup-restore + {{ tm('system.backup.button') }} + + + {{ tm('system.restart.button') }} @@ -30,6 +37,7 @@ + @@ -40,12 +48,14 @@ import WaitingForRestart from '@/components/shared/WaitingForRestart.vue'; import ProxySelector from '@/components/shared/ProxySelector.vue'; import MigrationDialog from '@/components/shared/MigrationDialog.vue'; import SidebarCustomizer from '@/components/shared/SidebarCustomizer.vue'; +import BackupDialog from '@/components/shared/BackupDialog.vue'; import { useModuleI18n } from '@/i18n/composables'; const { tm } = useModuleI18n('features/settings'); const wfr = ref(null); const migrationDialog = ref(null); +const backupDialog = ref(null); const restartAstrBot = () => { axios.post('/api/stat/restart-core').then(() => { @@ -65,4 +75,10 @@ const startMigration = async () => { } } } + +const openBackupDialog = () => { + if (backupDialog.value) { + backupDialog.value.open(); + } +} \ No newline at end of file From 70cda9ae473fa8576054658c5169cebfe4a51825 Mon Sep 17 00:00:00 2001 From: RC-CHN <1051989940@qq.com> Date: Thu, 18 Dec 2025 11:27:03 +0800 Subject: [PATCH 02/17] =?UTF-8?q?test:=20=E6=B7=BB=E5=8A=A0=E8=BF=81?= =?UTF-8?q?=E7=A7=BB=E7=9B=B8=E5=85=B3=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_backup.py | 427 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 427 insertions(+) create mode 100644 tests/test_backup.py diff --git a/tests/test_backup.py b/tests/test_backup.py new file mode 100644 index 000000000..52e2e89ef --- /dev/null +++ b/tests/test_backup.py @@ -0,0 +1,427 @@ +"""备份功能单元测试""" + +import json +import os +import zipfile +from datetime import datetime +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from astrbot.core.backup.exporter import ( + KB_METADATA_MODELS, + MAIN_DB_MODELS, + AstrBotExporter, +) +from astrbot.core.backup.importer import AstrBotImporter, ImportResult +from astrbot.core.config.default import VERSION +from astrbot.core.db.po import ( + ConversationV2, +) + + +@pytest.fixture +def temp_backup_dir(tmp_path): + """创建临时备份目录""" + backup_dir = tmp_path / "backups" + backup_dir.mkdir() + return backup_dir + + +@pytest.fixture +def temp_data_dir(tmp_path): + """创建临时数据目录""" + data_dir = tmp_path / "data" + data_dir.mkdir() + + # 创建配置文件 + config_path = data_dir / "cmd_config.json" + config_path.write_text(json.dumps({"test": "config"})) + + # 创建附件目录 + attachments_dir = data_dir / "attachments" + attachments_dir.mkdir() + + return data_dir + + +@pytest.fixture +def mock_main_db(): + """创建模拟的主数据库""" + db = MagicMock() + + # 模拟异步上下文管理器 + session = AsyncMock() + db.get_db = MagicMock( + return_value=AsyncMock(__aenter__=AsyncMock(return_value=session)) + ) + + return db + + +@pytest.fixture +def mock_kb_manager(): + """创建模拟的知识库管理器""" + kb_manager = MagicMock() + kb_manager.kb_insts = {} + + # 模拟 kb_db + kb_db = MagicMock() + session = AsyncMock() + kb_db.get_db = MagicMock( + return_value=AsyncMock(__aenter__=AsyncMock(return_value=session)) + ) + kb_manager.kb_db = kb_db + + return kb_manager + + +class TestImportResult: + """ImportResult 类测试""" + + def test_init(self): + """测试初始化""" + result = ImportResult() + assert result.success is True + assert result.imported_tables == {} + assert result.imported_files == {} + assert result.warnings == [] + assert result.errors == [] + + def test_add_warning(self): + """测试添加警告""" + result = ImportResult() + result.add_warning("test warning") + assert "test warning" in result.warnings + assert result.success is True # 警告不影响成功状态 + + def test_add_error(self): + """测试添加错误""" + result = ImportResult() + result.add_error("test error") + assert "test error" in result.errors + assert result.success is False # 错误会导致失败 + + def test_to_dict(self): + """测试转换为字典""" + result = ImportResult() + result.imported_tables = {"test_table": 10} + result.add_warning("warning") + + d = result.to_dict() + assert d["success"] is True + assert d["imported_tables"] == {"test_table": 10} + assert "warning" in d["warnings"] + + +class TestAstrBotExporter: + """AstrBotExporter 类测试""" + + def test_init(self, mock_main_db, mock_kb_manager, temp_data_dir): + """测试初始化""" + exporter = AstrBotExporter( + main_db=mock_main_db, + kb_manager=mock_kb_manager, + config_path=str(temp_data_dir / "cmd_config.json"), + attachments_dir=str(temp_data_dir / "attachments"), + ) + assert exporter.main_db is mock_main_db + assert exporter.kb_manager is mock_kb_manager + + def test_model_to_dict_with_model_dump(self): + """测试 _model_to_dict 使用 model_dump 方法""" + exporter = AstrBotExporter(main_db=MagicMock()) + + # 创建一个有 model_dump 方法的模拟对象 + mock_record = MagicMock() + mock_record.model_dump.return_value = {"id": 1, "name": "test"} + + result = exporter._model_to_dict(mock_record) + assert result == {"id": 1, "name": "test"} + + def test_model_to_dict_with_datetime(self): + """测试 _model_to_dict 处理 datetime 字段""" + exporter = AstrBotExporter(main_db=MagicMock()) + + now = datetime.now() + mock_record = MagicMock() + mock_record.model_dump.return_value = {"id": 1, "created_at": now} + + result = exporter._model_to_dict(mock_record) + assert result["created_at"] == now.isoformat() + + def test_add_checksum(self): + """测试添加校验和""" + exporter = AstrBotExporter(main_db=MagicMock()) + + exporter._add_checksum("test.json", '{"test": "data"}') + + assert "test.json" in exporter._checksums + assert exporter._checksums["test.json"].startswith("sha256:") + + def test_generate_manifest(self, mock_main_db, mock_kb_manager): + """测试生成清单""" + exporter = AstrBotExporter( + main_db=mock_main_db, + kb_manager=mock_kb_manager, + ) + + main_data = { + "platform_stats": [{"id": 1}], + "conversations": [], + "attachments": [], + } + kb_meta_data = { + "knowledge_bases": [], + "kb_documents": [], + } + + manifest = exporter._generate_manifest(main_data, kb_meta_data) + + assert manifest["version"] == "1.0" + assert manifest["astrbot_version"] == VERSION + assert "exported_at" in manifest + assert "tables" in manifest + assert "statistics" in manifest + assert manifest["statistics"]["main_db"]["platform_stats"] == 1 + + @pytest.mark.asyncio + async def test_export_all_creates_zip( + self, mock_main_db, temp_backup_dir, temp_data_dir + ): + """测试导出创建 ZIP 文件""" + # 设置模拟数据库返回空数据 + session = AsyncMock() + result = MagicMock() + result.scalars.return_value.all.return_value = [] + session.execute = AsyncMock(return_value=result) + + mock_main_db.get_db.return_value = AsyncMock( + __aenter__=AsyncMock(return_value=session), + __aexit__=AsyncMock(return_value=None), + ) + + exporter = AstrBotExporter( + main_db=mock_main_db, + kb_manager=None, + config_path=str(temp_data_dir / "cmd_config.json"), + attachments_dir=str(temp_data_dir / "attachments"), + ) + + zip_path = await exporter.export_all(output_dir=str(temp_backup_dir)) + + assert os.path.exists(zip_path) + assert zip_path.endswith(".zip") + assert "astrbot_backup_" in zip_path + + # 验证 ZIP 文件内容 + with zipfile.ZipFile(zip_path, "r") as zf: + namelist = zf.namelist() + assert "manifest.json" in namelist + assert "databases/main_db.json" in namelist + assert "config/cmd_config.json" in namelist + + +class TestAstrBotImporter: + """AstrBotImporter 类测试""" + + def test_init(self, mock_main_db, mock_kb_manager, temp_data_dir): + """测试初始化""" + importer = AstrBotImporter( + main_db=mock_main_db, + kb_manager=mock_kb_manager, + config_path=str(temp_data_dir / "cmd_config.json"), + attachments_dir=str(temp_data_dir / "attachments"), + ) + assert importer.main_db is mock_main_db + assert importer.kb_manager is mock_kb_manager + + def test_validate_version_match(self): + """测试版本匹配验证""" + importer = AstrBotImporter(main_db=MagicMock()) + + manifest = {"astrbot_version": VERSION} + # 不应该抛出异常 + importer._validate_version(manifest) + + def test_validate_version_mismatch(self): + """测试版本不匹配验证""" + importer = AstrBotImporter(main_db=MagicMock()) + + manifest = {"astrbot_version": "0.0.1"} + with pytest.raises(ValueError, match="版本不匹配"): + importer._validate_version(manifest) + + def test_validate_version_missing(self): + """测试缺少版本信息""" + importer = AstrBotImporter(main_db=MagicMock()) + + manifest = {} + with pytest.raises(ValueError, match="缺少版本信息"): + importer._validate_version(manifest) + + def test_convert_datetime_fields(self): + """测试 datetime 字段转换""" + importer = AstrBotImporter(main_db=MagicMock()) + + # 使用 ConversationV2 作为测试模型(它有 created_at 和 updated_at 字段) + row = { + "conversation_id": "test-123", + "platform_id": "test", + "user_id": "user1", + "created_at": "2024-01-01T12:00:00", + "updated_at": "2024-01-01T12:00:00", + } + + result = importer._convert_datetime_fields(row, ConversationV2) + + # created_at 应该被转换为 datetime 对象 + assert isinstance(result["created_at"], datetime) + assert isinstance(result["updated_at"], datetime) + + @pytest.mark.asyncio + async def test_import_file_not_exists(self, mock_main_db, tmp_path): + """测试导入不存在的文件""" + importer = AstrBotImporter(main_db=mock_main_db) + + result = await importer.import_all(str(tmp_path / "nonexistent.zip")) + + assert result.success is False + assert any("不存在" in err for err in result.errors) + + @pytest.mark.asyncio + async def test_import_invalid_zip(self, mock_main_db, tmp_path): + """测试导入无效的 ZIP 文件""" + # 创建一个无效的文件 + invalid_zip = tmp_path / "invalid.zip" + invalid_zip.write_text("not a zip file") + + importer = AstrBotImporter(main_db=mock_main_db) + result = await importer.import_all(str(invalid_zip)) + + assert result.success is False + assert any("无效" in err or "ZIP" in err for err in result.errors) + + @pytest.mark.asyncio + async def test_import_missing_manifest(self, mock_main_db, tmp_path): + """测试导入缺少 manifest 的 ZIP 文件""" + # 创建一个没有 manifest 的 ZIP 文件 + zip_path = tmp_path / "no_manifest.zip" + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr("test.txt", "test content") + + importer = AstrBotImporter(main_db=mock_main_db) + result = await importer.import_all(str(zip_path)) + + assert result.success is False + assert any("manifest" in err.lower() for err in result.errors) + + @pytest.mark.asyncio + async def test_import_version_mismatch(self, mock_main_db, tmp_path): + """测试导入版本不匹配的备份""" + # 创建一个版本不匹配的备份 + zip_path = tmp_path / "old_version.zip" + manifest = { + "version": "1.0", + "astrbot_version": "0.0.1", # 错误的版本 + "tables": {"main_db": []}, + } + + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr("manifest.json", json.dumps(manifest)) + + importer = AstrBotImporter(main_db=mock_main_db) + result = await importer.import_all(str(zip_path)) + + assert result.success is False + assert any("版本不匹配" in err for err in result.errors) + + +class TestModelMappings: + """测试模型映射配置""" + + def test_main_db_models_not_empty(self): + """测试主数据库模型映射非空""" + assert len(MAIN_DB_MODELS) > 0 + + def test_main_db_models_contain_expected_tables(self): + """测试主数据库模型映射包含预期的表""" + expected_tables = [ + "platform_stats", + "conversations", + "personas", + "preferences", + "attachments", + ] + for table in expected_tables: + assert table in MAIN_DB_MODELS, f"Missing table: {table}" + + def test_kb_metadata_models_not_empty(self): + """测试知识库元数据模型映射非空""" + assert len(KB_METADATA_MODELS) > 0 + + def test_kb_metadata_models_contain_expected_tables(self): + """测试知识库元数据模型映射包含预期的表""" + expected_tables = [ + "knowledge_bases", + "kb_documents", + "kb_media", + ] + for table in expected_tables: + assert table in KB_METADATA_MODELS, f"Missing table: {table}" + + +class TestBackupIntegration: + """备份集成测试""" + + @pytest.mark.asyncio + async def test_export_import_roundtrip(self, tmp_path): + """测试导出-导入往返""" + backup_dir = tmp_path / "backups" + backup_dir.mkdir() + + data_dir = tmp_path / "data" + data_dir.mkdir() + + config_path = data_dir / "cmd_config.json" + config_path.write_text(json.dumps({"setting": "value"})) + + attachments_dir = data_dir / "attachments" + attachments_dir.mkdir() + + # 创建模拟数据库 + mock_db = MagicMock() + session = AsyncMock() + result = MagicMock() + result.scalars.return_value.all.return_value = [] + session.execute = AsyncMock(return_value=result) + + mock_db.get_db.return_value = AsyncMock( + __aenter__=AsyncMock(return_value=session), + __aexit__=AsyncMock(return_value=None), + ) + + # 导出 + exporter = AstrBotExporter( + main_db=mock_db, + kb_manager=None, + config_path=str(config_path), + attachments_dir=str(attachments_dir), + ) + + zip_path = await exporter.export_all(output_dir=str(backup_dir)) + assert os.path.exists(zip_path) + + # 验证 ZIP 内容 + with zipfile.ZipFile(zip_path, "r") as zf: + # 读取 manifest + manifest = json.loads(zf.read("manifest.json")) + assert manifest["astrbot_version"] == VERSION + + # 读取配置 + config = json.loads(zf.read("config/cmd_config.json")) + assert config["setting"] == "value" + + # 读取主数据库 + main_db = json.loads(zf.read("databases/main_db.json")) + assert "platform_stats" in main_db From 57bacab49efd4d0d6c70615c1d764d2f17b82527 Mon Sep 17 00:00:00 2001 From: RC-CHN <1051989940@qq.com> Date: Thu, 18 Dec 2025 14:34:20 +0800 Subject: [PATCH 03/17] =?UTF-8?q?feat:=20=E5=A4=87=E4=BB=BD=E6=8F=92?= =?UTF-8?q?=E4=BB=B6=E5=8F=8A=E7=9B=B8=E5=85=B3=E6=8C=81=E4=B9=85=E5=8C=96?= =?UTF-8?q?=E7=9B=AE=E5=BD=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/backup/exporter.py | 93 ++++++++++++++++++++++- astrbot/core/backup/importer.py | 117 +++++++++++++++++++++++++++++ astrbot/dashboard/routes/backup.py | 2 + tests/test_backup.py | 10 ++- 4 files changed, 216 insertions(+), 6 deletions(-) diff --git a/astrbot/core/backup/exporter.py b/astrbot/core/backup/exporter.py index f814bb915..42caa2c19 100644 --- a/astrbot/core/backup/exporter.py +++ b/astrbot/core/backup/exporter.py @@ -60,6 +60,17 @@ } +# 需要备份的目录列表 +BACKUP_DIRECTORIES = { + "plugins": "data/plugins", # 插件本体 + "plugin_data": "data/plugin_data", # 插件数据 + "config": "data/config", # 配置目录 + "t2i_templates": "data/t2i_templates", # T2I 模板 + "webchat": "data/webchat", # WebChat 数据 + "temp": "data/temp", # 临时文件 +} + + class AstrBotExporter: """AstrBot 数据导出器 @@ -70,6 +81,12 @@ class AstrBotExporter: - 配置文件(data/cmd_config.json) - 附件文件 - 知识库多媒体文件 + - 插件目录(data/plugins) + - 插件数据目录(data/plugin_data) + - 配置目录(data/config) + - T2I 模板目录(data/t2i_templates) + - WebChat 数据目录(data/webchat) + - 临时文件目录(data/temp) """ def __init__( @@ -78,11 +95,13 @@ def __init__( kb_manager: "KnowledgeBaseManager | None" = None, config_path: str = "data/cmd_config.json", attachments_dir: str = "data/attachments", + data_root: str = "data", ): self.main_db = main_db self.kb_manager = kb_manager self.config_path = config_path self.attachments_dir = attachments_dir + self.data_root = data_root self._checksums: dict[str, str] = {} async def export_all( @@ -192,10 +211,19 @@ async def export_all( if progress_callback: await progress_callback("attachments", 100, 100, "附件导出完成") - # 5. 生成 manifest + # 5. 导出插件和其他目录 + if progress_callback: + await progress_callback( + "directories", 0, 100, "正在导出插件和数据目录..." + ) + dir_stats = await self._export_directories(zf) + if progress_callback: + await progress_callback("directories", 100, 100, "目录导出完成") + + # 6. 生成 manifest if progress_callback: await progress_callback("manifest", 0, 100, "正在生成清单...") - manifest = self._generate_manifest(main_data, kb_meta_data) + manifest = self._generate_manifest(main_data, kb_meta_data, dir_stats) manifest_json = json.dumps(manifest, ensure_ascii=False, indent=2) zf.writestr("manifest.json", manifest_json) if progress_callback: @@ -312,6 +340,56 @@ async def _export_kb_media_files( except Exception as e: logger.warning(f"导出知识库媒体文件失败: {e}") + async def _export_directories( + self, zf: zipfile.ZipFile + ) -> dict[str, dict[str, int]]: + """导出插件和其他数据目录 + + Returns: + dict: 每个目录的统计信息 {dir_name: {"files": count, "size": bytes}} + """ + stats: dict[str, dict[str, int]] = {} + + for dir_name, dir_path in BACKUP_DIRECTORIES.items(): + full_path = Path(self.data_root).parent / dir_path + if not full_path.exists(): + logger.debug(f"目录不存在,跳过: {dir_path}") + continue + + file_count = 0 + total_size = 0 + + try: + for root, dirs, files in os.walk(full_path): + # 跳过 __pycache__ 目录 + dirs[:] = [d for d in dirs if d != "__pycache__"] + + for file in files: + # 跳过 .pyc 文件 + if file.endswith(".pyc"): + continue + + file_path = Path(root) / file + try: + # 计算相对路径 + rel_path = file_path.relative_to(full_path) + archive_path = f"directories/{dir_name}/{rel_path}" + zf.write(str(file_path), archive_path) + file_count += 1 + total_size += file_path.stat().st_size + except Exception as e: + logger.warning(f"导出文件 {file_path} 失败: {e}") + + stats[dir_name] = {"files": file_count, "size": total_size} + logger.debug( + f"导出目录 {dir_name}: {file_count} 个文件, {total_size} 字节" + ) + except Exception as e: + logger.warning(f"导出目录 {dir_path} 失败: {e}") + stats[dir_name] = {"files": 0, "size": 0} + + return stats + async def _export_attachments( self, zf: zipfile.ZipFile, attachments: list[dict] ) -> None: @@ -364,9 +442,14 @@ def _add_checksum(self, path: str, content: str | bytes) -> None: self._checksums[path] = f"sha256:{checksum}" def _generate_manifest( - self, main_data: dict[str, list[dict]], kb_meta_data: dict[str, list[dict]] + self, + main_data: dict[str, list[dict]], + kb_meta_data: dict[str, list[dict]], + dir_stats: dict[str, dict[str, int]] | None = None, ) -> dict: """生成备份清单""" + if dir_stats is None: + dir_stats = {} # 收集知识库 ID kb_document_tables = {} if self.kb_manager: @@ -396,7 +479,7 @@ def _generate_manifest( kb_media_files[kb_id] = media_files manifest = { - "version": "1.0", + "version": "1.1", # 升级版本号,支持目录备份 "astrbot_version": VERSION, "exported_at": datetime.now(timezone.utc).isoformat(), "schema_version": { @@ -412,6 +495,7 @@ def _generate_manifest( "attachments": attachment_files, "kb_media": kb_media_files, }, + "directories": list(dir_stats.keys()), "checksums": self._checksums, "statistics": { "main_db": { @@ -420,6 +504,7 @@ def _generate_manifest( "kb_metadata": { table: len(records) for table, records in kb_meta_data.items() }, + "directories": dir_stats, }, } diff --git a/astrbot/core/backup/importer.py b/astrbot/core/backup/importer.py index 3f9b9958b..793669029 100644 --- a/astrbot/core/backup/importer.py +++ b/astrbot/core/backup/importer.py @@ -67,6 +67,7 @@ def __init__(self): self.success = True self.imported_tables: dict[str, int] = {} self.imported_files: dict[str, int] = {} + self.imported_directories: dict[str, int] = {} self.warnings: list[str] = [] self.errors: list[str] = [] @@ -84,11 +85,23 @@ def to_dict(self) -> dict: "success": self.success, "imported_tables": self.imported_tables, "imported_files": self.imported_files, + "imported_directories": self.imported_directories, "warnings": self.warnings, "errors": self.errors, } +# 需要恢复的目录列表 +BACKUP_DIRECTORIES = { + "plugins": "data/plugins", # 插件本体 + "plugin_data": "data/plugin_data", # 插件数据 + "config": "data/config", # 配置目录 + "t2i_templates": "data/t2i_templates", # T2I 模板 + "webchat": "data/webchat", # WebChat 数据 + "temp": "data/temp", # 临时文件 +} + + class AstrBotImporter: """AstrBot 数据导入器 @@ -98,6 +111,12 @@ class AstrBotImporter: - 配置文件 - 附件文件 - 知识库多媒体文件 + - 插件目录(data/plugins) + - 插件数据目录(data/plugin_data) + - 配置目录(data/config) + - T2I 模板目录(data/t2i_templates) + - WebChat 数据目录(data/webchat) + - 临时文件目录(data/temp) """ def __init__( @@ -107,12 +126,14 @@ def __init__( config_path: str = "data/cmd_config.json", attachments_dir: str = "data/attachments", kb_root_dir: str = "data/knowledge_base", + data_root: str = "data", ): self.main_db = main_db self.kb_manager = kb_manager self.config_path = config_path self.attachments_dir = attachments_dir self.kb_root_dir = kb_root_dir + self.data_root = data_root async def import_all( self, @@ -236,6 +257,18 @@ async def import_all( if progress_callback: await progress_callback("attachments", 100, 100, "附件导入完成") + # 6. 导入插件和其他目录 + if progress_callback: + await progress_callback( + "directories", 0, 100, "正在导入插件和数据目录..." + ) + + dir_stats = await self._import_directories(zf, manifest, result) + result.imported_directories = dir_stats + + if progress_callback: + await progress_callback("directories", 100, 100, "目录导入完成") + logger.info(f"备份导入完成: {result.to_dict()}") return result @@ -468,6 +501,90 @@ async def _import_attachments( return count + async def _import_directories( + self, + zf: zipfile.ZipFile, + manifest: dict, + result: ImportResult, + ) -> dict[str, int]: + """导入插件和其他数据目录 + + Args: + zf: ZIP 文件对象 + manifest: 备份清单 + result: 导入结果对象 + + Returns: + dict: 每个目录导入的文件数量 + """ + dir_stats: dict[str, int] = {} + + # 检查备份版本是否支持目录备份 + backup_version = manifest.get("version", "1.0") + if backup_version < "1.1": + logger.info("备份版本不支持目录备份,跳过目录导入") + return dir_stats + + backed_up_dirs = manifest.get("directories", []) + + for dir_name in backed_up_dirs: + if dir_name not in BACKUP_DIRECTORIES: + result.add_warning(f"未知的目录类型: {dir_name}") + continue + + target_dir = Path(self.data_root).parent / BACKUP_DIRECTORIES[dir_name] + archive_prefix = f"directories/{dir_name}/" + + file_count = 0 + + try: + # 获取该目录下的所有文件 + dir_files = [ + name + for name in zf.namelist() + if name.startswith(archive_prefix) and name != archive_prefix + ] + + if not dir_files: + continue + + # 备份现有目录(如果存在) + if target_dir.exists(): + backup_path = Path(f"{target_dir}.bak") + if backup_path.exists(): + shutil.rmtree(backup_path) + shutil.move(str(target_dir), str(backup_path)) + logger.debug(f"已备份现有目录 {target_dir} 到 {backup_path}") + + # 创建目标目录 + target_dir.mkdir(parents=True, exist_ok=True) + + # 解压文件 + for name in dir_files: + try: + # 计算相对路径 + rel_path = name[len(archive_prefix) :] + if not rel_path: # 跳过目录条目 + continue + + target_path = target_dir / rel_path + target_path.parent.mkdir(parents=True, exist_ok=True) + + with zf.open(name) as src, open(target_path, "wb") as dst: + dst.write(src.read()) + file_count += 1 + except Exception as e: + result.add_warning(f"导入文件 {name} 失败: {e}") + + dir_stats[dir_name] = file_count + logger.debug(f"导入目录 {dir_name}: {file_count} 个文件") + + except Exception as e: + result.add_warning(f"导入目录 {dir_name} 失败: {e}") + dir_stats[dir_name] = 0 + + return dir_stats + def _convert_datetime_fields(self, row: dict, model_class: type) -> dict: """转换 datetime 字符串字段为 datetime 对象""" result = row.copy() diff --git a/astrbot/dashboard/routes/backup.py b/astrbot/dashboard/routes/backup.py index 4856cf06b..063bad32e 100644 --- a/astrbot/dashboard/routes/backup.py +++ b/astrbot/dashboard/routes/backup.py @@ -217,6 +217,7 @@ async def _background_export_task(self, task_id: str): kb_manager=kb_manager, config_path="data/cmd_config.json", attachments_dir="data/attachments", + data_root="data", ) # 创建进度回调 @@ -325,6 +326,7 @@ async def _background_import_task(self, task_id: str, zip_path: str): kb_manager=kb_manager, config_path="data/cmd_config.json", attachments_dir="data/attachments", + data_root="data", ) # 创建进度回调 diff --git a/tests/test_backup.py b/tests/test_backup.py index 52e2e89ef..54fbd568d 100644 --- a/tests/test_backup.py +++ b/tests/test_backup.py @@ -175,15 +175,21 @@ def test_generate_manifest(self, mock_main_db, mock_kb_manager): "knowledge_bases": [], "kb_documents": [], } + dir_stats = { + "plugins": {"files": 10, "size": 1024}, + "plugin_data": {"files": 5, "size": 512}, + } - manifest = exporter._generate_manifest(main_data, kb_meta_data) + manifest = exporter._generate_manifest(main_data, kb_meta_data, dir_stats) - assert manifest["version"] == "1.0" + assert manifest["version"] == "1.1" # 升级版本号,支持目录备份 assert manifest["astrbot_version"] == VERSION assert "exported_at" in manifest assert "tables" in manifest assert "statistics" in manifest + assert "directories" in manifest assert manifest["statistics"]["main_db"]["platform_stats"] == 1 + assert manifest["statistics"]["directories"] == dir_stats @pytest.mark.asyncio async def test_export_all_creates_zip( From 372628a8e708ddb675cf5f02e44a786bcc9e445c Mon Sep 17 00:00:00 2001 From: RC-CHN <1051989940@qq.com> Date: Thu, 18 Dec 2025 15:07:24 +0800 Subject: [PATCH 04/17] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E7=89=88?= =?UTF-8?q?=E6=9C=AC=E5=8F=B7=E6=AF=94=E8=BE=83=E9=80=BB=E8=BE=91=EF=BC=8C?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E7=9B=B8=E5=85=B3=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/backup/importer.py | 49 ++++++++++++++++++++++++++++-- tests/test_backup.py | 54 ++++++++++++++++++++++++++++++++- 2 files changed, 100 insertions(+), 3 deletions(-) diff --git a/astrbot/core/backup/importer.py b/astrbot/core/backup/importer.py index 793669029..2d8d7ab45 100644 --- a/astrbot/core/backup/importer.py +++ b/astrbot/core/backup/importer.py @@ -39,6 +39,51 @@ from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager +def parse_version(version_str: str) -> tuple[int, ...]: + """将版本字符串解析为数值元组用于比较 + + Args: + version_str: 版本字符串,如 "1.0", "1.10", "2.0.1" + + Returns: + 数值元组,如 (1, 0), (1, 10), (2, 0, 1) + """ + try: + parts = version_str.split(".") + return tuple(int(p) for p in parts) + except (ValueError, AttributeError): + # 解析失败时返回 (0,),确保能够比较 + return (0,) + + +def compare_versions(v1: str, v2: str) -> int: + """比较两个版本号 + + Args: + v1: 第一个版本字符串 + v2: 第二个版本字符串 + + Returns: + -1 如果 v1 < v2 + 0 如果 v1 == v2 + 1 如果 v1 > v2 + """ + t1 = parse_version(v1) + t2 = parse_version(v2) + + # 补齐长度以便比较 + max_len = max(len(t1), len(t2)) + t1 = t1 + (0,) * (max_len - len(t1)) + t2 = t2 + (0,) * (max_len - len(t2)) + + if t1 < t2: + return -1 + elif t1 > t2: + return 1 + else: + return 0 + + # 主数据库模型类映射 MAIN_DB_MODELS: dict[str, type[SQLModel]] = { "platform_stats": PlatformStat, @@ -519,9 +564,9 @@ async def _import_directories( """ dir_stats: dict[str, int] = {} - # 检查备份版本是否支持目录备份 + # 检查备份版本是否支持目录备份(需要版本 >= 1.1) backup_version = manifest.get("version", "1.0") - if backup_version < "1.1": + if compare_versions(backup_version, "1.1") < 0: logger.info("备份版本不支持目录备份,跳过目录导入") return dir_stats diff --git a/tests/test_backup.py b/tests/test_backup.py index 54fbd568d..da88e9cba 100644 --- a/tests/test_backup.py +++ b/tests/test_backup.py @@ -13,7 +13,12 @@ MAIN_DB_MODELS, AstrBotExporter, ) -from astrbot.core.backup.importer import AstrBotImporter, ImportResult +from astrbot.core.backup.importer import ( + AstrBotImporter, + ImportResult, + compare_versions, + parse_version, +) from astrbot.core.config.default import VERSION from astrbot.core.db.po import ( ConversationV2, @@ -343,6 +348,53 @@ async def test_import_version_mismatch(self, mock_main_db, tmp_path): assert any("版本不匹配" in err for err in result.errors) +class TestVersionComparison: + """版本比较函数测试""" + + def test_parse_version_simple(self): + """测试解析简单版本号""" + assert parse_version("1.0") == (1, 0) + assert parse_version("2.1") == (2, 1) + + def test_parse_version_multi_digit(self): + """测试解析多位数版本号""" + assert parse_version("1.10") == (1, 10) + assert parse_version("1.10.2") == (1, 10, 2) + assert parse_version("10.20.30") == (10, 20, 30) + + def test_parse_version_invalid(self): + """测试解析无效版本号""" + assert parse_version("invalid") == (0,) + assert parse_version("") == (0,) + assert parse_version("1.x.2") == (0,) + + def test_compare_versions_equal(self): + """测试版本相等""" + assert compare_versions("1.0", "1.0") == 0 + assert compare_versions("1.0.0", "1.0") == 0 + assert compare_versions("2.10", "2.10") == 0 + + def test_compare_versions_less_than(self): + """测试版本小于""" + assert compare_versions("1.0", "1.1") == -1 + assert compare_versions("1.9", "1.10") == -1 # 关键测试:多位数版本比较 + assert compare_versions("1.2", "1.10") == -1 + assert compare_versions("1.0", "2.0") == -1 + + def test_compare_versions_greater_than(self): + """测试版本大于""" + assert compare_versions("1.1", "1.0") == 1 + assert compare_versions("1.10", "1.9") == 1 # 关键测试:多位数版本比较 + assert compare_versions("1.10", "1.2") == 1 + assert compare_versions("2.0", "1.0") == 1 + + def test_compare_versions_different_lengths(self): + """测试不同长度版本比较""" + assert compare_versions("1.0", "1.0.0") == 0 + assert compare_versions("1.0", "1.0.1") == -1 + assert compare_versions("1.0.1", "1.0") == 1 + + class TestModelMappings: """测试模型映射配置""" From ce1af61f7cbc5ec9b5f251b7d7d326024d85eafd Mon Sep 17 00:00:00 2001 From: RC-CHN <1051989940@qq.com> Date: Thu, 18 Dec 2025 15:12:53 +0800 Subject: [PATCH 05/17] =?UTF-8?q?fix:=20=E6=B8=85=E6=B4=97=E6=96=87?= =?UTF-8?q?=E4=BB=B6=E5=90=8D=EF=BC=8C=E6=B7=BB=E5=8A=A0=E7=9B=B8=E5=85=B3?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/dashboard/routes/backup.py | 56 +++++++++++++++++++++++++++- tests/test_backup.py | 59 ++++++++++++++++++++++++++++++ 2 files changed, 114 insertions(+), 1 deletion(-) diff --git a/astrbot/dashboard/routes/backup.py b/astrbot/dashboard/routes/backup.py index 063bad32e..cabb92176 100644 --- a/astrbot/dashboard/routes/backup.py +++ b/astrbot/dashboard/routes/backup.py @@ -2,8 +2,10 @@ import asyncio import os +import re import traceback import uuid +from datetime import datetime from pathlib import Path from quart import request, send_file @@ -17,6 +19,50 @@ from .route import Response, Route, RouteContext +def secure_filename(filename: str) -> str: + """清洗文件名,移除路径遍历字符和危险字符 + + Args: + filename: 原始文件名 + + Returns: + 安全的文件名 + """ + # 仅保留文件名部分,移除路径 + filename = os.path.basename(filename) + + # 替换路径分隔符和危险字符 + filename = filename.replace("..", "_") + filename = filename.replace("/", "_") + filename = filename.replace("\\", "_") + + # 仅保留字母、数字、下划线、连字符、点 + filename = re.sub(r"[^\w\-.]", "_", filename) + + # 移除前导点(隐藏文件) + filename = filename.lstrip(".") + + # 如果文件名为空,生成一个默认名称 + if not filename: + filename = "backup" + + return filename + + +def generate_unique_filename(original_filename: str) -> str: + """生成唯一的文件名,添加时间戳前缀 + + Args: + original_filename: 原始文件名(已清洗) + + Returns: + 唯一的文件名 + """ + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + name, ext = os.path.splitext(original_filename) + return f"uploaded_{timestamp}_{name}{ext}" + + class BackupRoute(Route): """备份管理路由 @@ -274,10 +320,18 @@ async def import_backup(self): if not file.filename or not file.filename.endswith(".zip"): return Response().error("请上传 ZIP 格式的备份文件").__dict__ + # 清洗文件名并生成唯一名称,防止路径遍历和覆盖 + safe_filename = secure_filename(file.filename) + unique_filename = generate_unique_filename(safe_filename) + # 保存上传的文件 Path(self.backup_dir).mkdir(parents=True, exist_ok=True) - zip_path = os.path.join(self.backup_dir, file.filename) + zip_path = os.path.join(self.backup_dir, unique_filename) await file.save(zip_path) + + logger.info( + f"上传的备份文件已保存: {unique_filename} (原始名称: {file.filename})" + ) else: # JSON 模式 - 使用已存在的文件 data = await request.json diff --git a/tests/test_backup.py b/tests/test_backup.py index da88e9cba..8b8a3bb3f 100644 --- a/tests/test_backup.py +++ b/tests/test_backup.py @@ -2,6 +2,7 @@ import json import os +import re import zipfile from datetime import datetime from unittest.mock import AsyncMock, MagicMock @@ -23,6 +24,10 @@ from astrbot.core.db.po import ( ConversationV2, ) +from astrbot.dashboard.routes.backup import ( + generate_unique_filename, + secure_filename, +) @pytest.fixture @@ -348,6 +353,60 @@ async def test_import_version_mismatch(self, mock_main_db, tmp_path): assert any("版本不匹配" in err for err in result.errors) +class TestSecureFilename: + """安全文件名函数测试""" + + def test_secure_filename_normal(self): + """测试正常文件名""" + assert secure_filename("backup.zip") == "backup.zip" + assert secure_filename("my_backup_2024.zip") == "my_backup_2024.zip" + + def test_secure_filename_path_traversal(self): + """测试路径遍历攻击""" + assert ".." not in secure_filename("../../../etc/passwd") + assert "/" not in secure_filename("/etc/passwd") + assert "\\" not in secure_filename("..\\..\\windows\\system32") + + def test_secure_filename_with_path(self): + """测试带路径的文件名""" + result = secure_filename("/path/to/backup.zip") + assert result == "backup.zip" + + result = secure_filename("C:\\Users\\test\\backup.zip") + assert result == "backup.zip" + + def test_secure_filename_special_chars(self): + """测试特殊字符""" + result = secure_filename('backup<>:"|?*.zip') + # 特殊字符应被替换为下划线 + assert "<" not in result + assert ">" not in result + assert ":" not in result + assert '"' not in result + assert "|" not in result + assert "?" not in result + assert "*" not in result + + def test_secure_filename_hidden_file(self): + """测试隐藏文件(前导点)""" + result = secure_filename(".hidden_backup.zip") + assert not result.startswith(".") + + def test_secure_filename_empty(self): + """测试空文件名""" + assert secure_filename("") == "backup" + assert secure_filename("...") == "backup" + + def test_generate_unique_filename(self): + """测试生成唯一文件名""" + result = generate_unique_filename("backup.zip") + # 应包含 uploaded_ 前缀和时间戳 + assert result.startswith("uploaded_") + assert result.endswith("_backup.zip") + # 应包含时间戳格式 YYYYMMDD_HHMMSS + assert re.search(r"uploaded_\d{8}_\d{6}_backup\.zip", result) + + class TestVersionComparison: """版本比较函数测试""" From 0d592afa8b0b277f6d016c9f3b0ad6be9ef08861 Mon Sep 17 00:00:00 2001 From: RC-CHN <1051989940@qq.com> Date: Thu, 18 Dec 2025 15:17:06 +0800 Subject: [PATCH 06/17] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E5=AE=89?= =?UTF-8?q?=E5=85=A8=E6=96=87=E4=BB=B6=E5=90=8D=E6=B5=8B=E8=AF=95=E7=94=A8?= =?UTF-8?q?=E4=BE=8B=E6=96=AD=E8=A8=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/dashboard/routes/backup.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/astrbot/dashboard/routes/backup.py b/astrbot/dashboard/routes/backup.py index cabb92176..be87b7bf7 100644 --- a/astrbot/dashboard/routes/backup.py +++ b/astrbot/dashboard/routes/backup.py @@ -28,22 +28,22 @@ def secure_filename(filename: str) -> str: Returns: 安全的文件名 """ + # 跨平台处理:先将反斜杠替换为正斜杠,再取文件名 + filename = filename.replace("\\", "/") # 仅保留文件名部分,移除路径 filename = os.path.basename(filename) - # 替换路径分隔符和危险字符 + # 替换路径遍历字符 filename = filename.replace("..", "_") - filename = filename.replace("/", "_") - filename = filename.replace("\\", "_") # 仅保留字母、数字、下划线、连字符、点 filename = re.sub(r"[^\w\-.]", "_", filename) - # 移除前导点(隐藏文件) - filename = filename.lstrip(".") + # 移除前导点(隐藏文件)和尾部点 + filename = filename.strip(".") - # 如果文件名为空,生成一个默认名称 - if not filename: + # 如果文件名为空或只包含下划线,生成一个默认名称 + if not filename or filename.replace("_", "") == "": filename = "backup" return filename From 39af9abc176b36fc1d131278a7f3ae71a82ae095 Mon Sep 17 00:00:00 2001 From: RC-CHN <1051989940@qq.com> Date: Thu, 18 Dec 2025 15:29:39 +0800 Subject: [PATCH 07/17] =?UTF-8?q?refactor:=20=E4=BC=98=E5=8C=96=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=EF=BC=8C=E4=B8=BA=E5=A4=87=E4=BB=BD=E6=A8=A1=E5=9D=97?= =?UTF-8?q?=E6=8F=90=E5=8F=96=E5=85=AC=E7=94=A8=E5=B8=B8=E9=87=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/backup/__init__.py | 18 +++++++++- astrbot/core/backup/constants.py | 62 ++++++++++++++++++++++++++++++++ astrbot/core/backup/exporter.py | 57 +++++------------------------ astrbot/core/backup/importer.py | 54 ++++------------------------ tests/test_backup.py | 7 ++-- 5 files changed, 97 insertions(+), 101 deletions(-) create mode 100644 astrbot/core/backup/constants.py diff --git a/astrbot/core/backup/__init__.py b/astrbot/core/backup/__init__.py index 114bc0b22..8b1ef62d6 100644 --- a/astrbot/core/backup/__init__.py +++ b/astrbot/core/backup/__init__.py @@ -3,7 +3,23 @@ 提供数据导出和导入功能,支持用户在服务器迁移时一键备份和恢复所有数据。 """ +# 从 constants 模块导入共享常量 +from .constants import ( + BACKUP_DIRECTORIES, + BACKUP_MANIFEST_VERSION, + KB_METADATA_MODELS, + MAIN_DB_MODELS, +) + +# 导入导出器和导入器 from .exporter import AstrBotExporter from .importer import AstrBotImporter -__all__ = ["AstrBotExporter", "AstrBotImporter"] +__all__ = [ + "AstrBotExporter", + "AstrBotImporter", + "MAIN_DB_MODELS", + "KB_METADATA_MODELS", + "BACKUP_DIRECTORIES", + "BACKUP_MANIFEST_VERSION", +] diff --git a/astrbot/core/backup/constants.py b/astrbot/core/backup/constants.py new file mode 100644 index 000000000..6a8160cdf --- /dev/null +++ b/astrbot/core/backup/constants.py @@ -0,0 +1,62 @@ +"""AstrBot 备份模块共享常量 + +此文件定义了导出器和导入器共享的常量,确保两端配置一致。 +""" + +from sqlmodel import SQLModel + +from astrbot.core.db.po import ( + Attachment, + CommandConfig, + CommandConflict, + ConversationV2, + Persona, + PlatformMessageHistory, + PlatformSession, + PlatformStat, + Preference, +) +from astrbot.core.knowledge_base.models import ( + KBDocument, + KBMedia, + KnowledgeBase, +) + +# ============================================================ +# 共享常量 - 确保导出和导入端配置一致 +# ============================================================ + +# 主数据库模型类映射 +MAIN_DB_MODELS: dict[str, type[SQLModel]] = { + "platform_stats": PlatformStat, + "conversations": ConversationV2, + "personas": Persona, + "preferences": Preference, + "platform_message_history": PlatformMessageHistory, + "platform_sessions": PlatformSession, + "attachments": Attachment, + "command_configs": CommandConfig, + "command_conflicts": CommandConflict, +} + +# 知识库元数据模型类映射 +KB_METADATA_MODELS: dict[str, type[SQLModel]] = { + "knowledge_bases": KnowledgeBase, + "kb_documents": KBDocument, + "kb_media": KBMedia, +} + +# 需要备份的目录列表 +# 键:备份文件中的目录名称 +# 值:相对于项目根目录的实际路径 +BACKUP_DIRECTORIES: dict[str, str] = { + "plugins": "data/plugins", # 插件本体 + "plugin_data": "data/plugin_data", # 插件数据 + "config": "data/config", # 配置目录 + "t2i_templates": "data/t2i_templates", # T2I 模板 + "webchat": "data/webchat", # WebChat 数据 + "temp": "data/temp", # 临时文件 +} + +# 备份清单版本号 +BACKUP_MANIFEST_VERSION = "1.1" diff --git a/astrbot/core/backup/exporter.py b/astrbot/core/backup/exporter.py index 42caa2c19..6f2e73c2e 100644 --- a/astrbot/core/backup/exporter.py +++ b/astrbot/core/backup/exporter.py @@ -13,64 +13,23 @@ from typing import TYPE_CHECKING, Any from sqlalchemy import select -from sqlmodel import SQLModel from astrbot.core import logger from astrbot.core.config.default import VERSION from astrbot.core.db import BaseDatabase -from astrbot.core.db.po import ( - Attachment, - CommandConfig, - CommandConflict, - ConversationV2, - Persona, - PlatformMessageHistory, - PlatformSession, - PlatformStat, - Preference, -) -from astrbot.core.knowledge_base.models import ( - KBDocument, - KBMedia, - KnowledgeBase, + +# 从共享常量模块导入 +from .constants import ( + BACKUP_DIRECTORIES, + BACKUP_MANIFEST_VERSION, + KB_METADATA_MODELS, + MAIN_DB_MODELS, ) if TYPE_CHECKING: from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager -# 主数据库模型类映射 -MAIN_DB_MODELS: dict[str, type[SQLModel]] = { - "platform_stats": PlatformStat, - "conversations": ConversationV2, - "personas": Persona, - "preferences": Preference, - "platform_message_history": PlatformMessageHistory, - "platform_sessions": PlatformSession, - "attachments": Attachment, - "command_configs": CommandConfig, - "command_conflicts": CommandConflict, -} - -# 知识库元数据模型类映射 -KB_METADATA_MODELS: dict[str, type[SQLModel]] = { - "knowledge_bases": KnowledgeBase, - "kb_documents": KBDocument, - "kb_media": KBMedia, -} - - -# 需要备份的目录列表 -BACKUP_DIRECTORIES = { - "plugins": "data/plugins", # 插件本体 - "plugin_data": "data/plugin_data", # 插件数据 - "config": "data/config", # 配置目录 - "t2i_templates": "data/t2i_templates", # T2I 模板 - "webchat": "data/webchat", # WebChat 数据 - "temp": "data/temp", # 临时文件 -} - - class AstrBotExporter: """AstrBot 数据导出器 @@ -479,7 +438,7 @@ def _generate_manifest( kb_media_files[kb_id] = media_files manifest = { - "version": "1.1", # 升级版本号,支持目录备份 + "version": BACKUP_MANIFEST_VERSION, "astrbot_version": VERSION, "exported_at": datetime.now(timezone.utc).isoformat(), "schema_version": { diff --git a/astrbot/core/backup/importer.py b/astrbot/core/backup/importer.py index 2d8d7ab45..fbbcc7392 100644 --- a/astrbot/core/backup/importer.py +++ b/astrbot/core/backup/importer.py @@ -13,26 +13,16 @@ from typing import TYPE_CHECKING, Any from sqlalchemy import delete -from sqlmodel import SQLModel from astrbot.core import logger from astrbot.core.config.default import VERSION from astrbot.core.db import BaseDatabase -from astrbot.core.db.po import ( - Attachment, - CommandConfig, - CommandConflict, - ConversationV2, - Persona, - PlatformMessageHistory, - PlatformSession, - PlatformStat, - Preference, -) -from astrbot.core.knowledge_base.models import ( - KBDocument, - KBMedia, - KnowledgeBase, + +# 从共享常量模块导入 +from .constants import ( + BACKUP_DIRECTORIES, + KB_METADATA_MODELS, + MAIN_DB_MODELS, ) if TYPE_CHECKING: @@ -84,27 +74,6 @@ def compare_versions(v1: str, v2: str) -> int: return 0 -# 主数据库模型类映射 -MAIN_DB_MODELS: dict[str, type[SQLModel]] = { - "platform_stats": PlatformStat, - "conversations": ConversationV2, - "personas": Persona, - "preferences": Preference, - "platform_message_history": PlatformMessageHistory, - "platform_sessions": PlatformSession, - "attachments": Attachment, - "command_configs": CommandConfig, - "command_conflicts": CommandConflict, -} - -# 知识库元数据模型类映射 -KB_METADATA_MODELS: dict[str, type[SQLModel]] = { - "knowledge_bases": KnowledgeBase, - "kb_documents": KBDocument, - "kb_media": KBMedia, -} - - class ImportResult: """导入结果""" @@ -136,17 +105,6 @@ def to_dict(self) -> dict: } -# 需要恢复的目录列表 -BACKUP_DIRECTORIES = { - "plugins": "data/plugins", # 插件本体 - "plugin_data": "data/plugin_data", # 插件数据 - "config": "data/config", # 配置目录 - "t2i_templates": "data/t2i_templates", # T2I 模板 - "webchat": "data/webchat", # WebChat 数据 - "temp": "data/temp", # 临时文件 -} - - class AstrBotImporter: """AstrBot 数据导入器 diff --git a/tests/test_backup.py b/tests/test_backup.py index 8b8a3bb3f..cc6ce251b 100644 --- a/tests/test_backup.py +++ b/tests/test_backup.py @@ -9,11 +9,12 @@ import pytest -from astrbot.core.backup.exporter import ( +from astrbot.core.backup import ( + BACKUP_MANIFEST_VERSION, KB_METADATA_MODELS, MAIN_DB_MODELS, - AstrBotExporter, ) +from astrbot.core.backup.exporter import AstrBotExporter from astrbot.core.backup.importer import ( AstrBotImporter, ImportResult, @@ -192,7 +193,7 @@ def test_generate_manifest(self, mock_main_db, mock_kb_manager): manifest = exporter._generate_manifest(main_data, kb_meta_data, dir_stats) - assert manifest["version"] == "1.1" # 升级版本号,支持目录备份 + assert manifest["version"] == BACKUP_MANIFEST_VERSION assert manifest["astrbot_version"] == VERSION assert "exported_at" in manifest assert "tables" in manifest From 25cb8649a0189b7852c33a257a373c679d5d450b Mon Sep 17 00:00:00 2001 From: RC-CHN <1051989940@qq.com> Date: Thu, 18 Dec 2025 15:52:28 +0800 Subject: [PATCH 08/17] =?UTF-8?q?feat:=20=E4=BF=AE=E6=94=B9=E5=A4=87?= =?UTF-8?q?=E4=BB=BD=E7=89=88=E6=9C=AC=E6=A0=A1=E9=AA=8C=E9=80=BB=E8=BE=91?= =?UTF-8?q?=EF=BC=8C=E5=85=81=E8=AE=B8=E5=BC=BA=E5=88=B6=E5=B0=8F=E7=89=88?= =?UTF-8?q?=E6=9C=AC=E9=97=B4=E5=AF=BC=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/backup/__init__.py | 3 +- astrbot/core/backup/importer.py | 220 ++++++++++++++++- astrbot/dashboard/routes/backup.py | 165 +++++++++---- .../src/components/shared/BackupDialog.vue | 149 ++++++++++-- .../i18n/locales/en-US/features/settings.json | 11 +- .../i18n/locales/zh-CN/features/settings.json | 11 +- tests/test_backup.py | 221 +++++++++++++++++- 7 files changed, 706 insertions(+), 74 deletions(-) diff --git a/astrbot/core/backup/__init__.py b/astrbot/core/backup/__init__.py index 8b1ef62d6..4b1db3f38 100644 --- a/astrbot/core/backup/__init__.py +++ b/astrbot/core/backup/__init__.py @@ -13,11 +13,12 @@ # 导入导出器和导入器 from .exporter import AstrBotExporter -from .importer import AstrBotImporter +from .importer import AstrBotImporter, ImportPreCheckResult __all__ = [ "AstrBotExporter", "AstrBotImporter", + "ImportPreCheckResult", "MAIN_DB_MODELS", "KB_METADATA_MODELS", "BACKUP_DIRECTORIES", diff --git a/astrbot/core/backup/importer.py b/astrbot/core/backup/importer.py index fbbcc7392..ec9a76017 100644 --- a/astrbot/core/backup/importer.py +++ b/astrbot/core/backup/importer.py @@ -1,13 +1,17 @@ """AstrBot 数据导入器 负责从 ZIP 备份文件恢复所有数据。 -导入时进行严格的版本校验,仅允许相同版本的 AstrBot 进行导入。 +导入时进行版本校验: +- 主版本(前两位)不同时直接拒绝导入 +- 小版本(第三位)不同时提示警告,用户可选择强制导入 +- 版本匹配时也需要用户确认 """ import json import os import shutil import zipfile +from dataclasses import dataclass, field from datetime import datetime from pathlib import Path from typing import TYPE_CHECKING, Any @@ -74,6 +78,50 @@ def compare_versions(v1: str, v2: str) -> int: return 0 +@dataclass +class ImportPreCheckResult: + """导入预检查结果 + + 用于在实际导入前检查备份文件的版本兼容性, + 并返回确认信息让用户决定是否继续导入。 + """ + + # 检查是否通过(文件有效且版本可导入) + valid: bool = False + # 是否可以导入(版本兼容) + can_import: bool = False + # 版本状态: match(完全匹配), minor_diff(小版本差异), major_diff(主版本不同,拒绝) + version_status: str = "" + # 备份文件中的 AstrBot 版本 + backup_version: str = "" + # 当前运行的 AstrBot 版本 + current_version: str = VERSION + # 备份创建时间 + backup_time: str = "" + # 确认消息(显示给用户) + confirm_message: str = "" + # 警告消息列表 + warnings: list[str] = field(default_factory=list) + # 错误消息(如果检查失败) + error: str = "" + # 备份包含的内容摘要 + backup_summary: dict = field(default_factory=dict) + + def to_dict(self) -> dict: + return { + "valid": self.valid, + "can_import": self.can_import, + "version_status": self.version_status, + "backup_version": self.backup_version, + "current_version": self.current_version, + "backup_time": self.backup_time, + "confirm_message": self.confirm_message, + "warnings": self.warnings, + "error": self.error, + "backup_summary": self.backup_summary, + } + + class ImportResult: """导入结果""" @@ -138,6 +186,156 @@ def __init__( self.kb_root_dir = kb_root_dir self.data_root = data_root + def pre_check(self, zip_path: str) -> ImportPreCheckResult: + """预检查备份文件 + + 在实际导入前检查备份文件的有效性和版本兼容性。 + 返回检查结果供前端显示确认对话框。 + + Args: + zip_path: ZIP 备份文件路径 + + Returns: + ImportPreCheckResult: 预检查结果 + """ + result = ImportPreCheckResult() + result.current_version = VERSION + + if not os.path.exists(zip_path): + result.error = f"备份文件不存在: {zip_path}" + return result + + try: + with zipfile.ZipFile(zip_path, "r") as zf: + # 读取 manifest + try: + manifest_data = zf.read("manifest.json") + manifest = json.loads(manifest_data) + except KeyError: + result.error = "备份文件缺少 manifest.json,不是有效的 AstrBot 备份" + return result + except json.JSONDecodeError as e: + result.error = f"manifest.json 格式错误: {e}" + return result + + # 提取基本信息 + result.backup_version = manifest.get("astrbot_version", "未知") + result.backup_time = manifest.get("created_at", "未知") + result.valid = True + + # 构建备份摘要 + result.backup_summary = { + "tables": list(manifest.get("tables", {}).keys()), + "has_knowledge_bases": manifest.get("has_knowledge_bases", False), + "has_config": manifest.get("has_config", False), + "directories": manifest.get("directories", []), + } + + # 检查版本兼容性 + version_check = self._check_version_compatibility(result.backup_version) + result.version_status = version_check["status"] + result.can_import = version_check["can_import"] + + if version_check["status"] == "major_diff": + result.warnings.append(version_check["message"]) + result.confirm_message = ( + f"⛔ 无法导入:备份版本 {result.backup_version} 与当前版本 {VERSION} " + f"的主版本号不同,跨主版本导入可能导致数据损坏。" + ) + elif version_check["status"] == "minor_diff": + result.warnings.append(version_check["message"]) + result.confirm_message = ( + f"⚠️ 版本差异警告\n\n" + f"备份版本: {result.backup_version}\n" + f"当前版本: {VERSION}\n\n" + f"小版本差异通常是兼容的,但可能存在少量数据结构变化。\n" + f"导入将会清空并覆盖现有的所有数据!\n\n" + f"是否确认继续导入?" + ) + else: + # 版本匹配 + result.confirm_message = ( + f"✅ 版本匹配\n\n" + f"备份版本: {result.backup_version}\n" + f"备份时间: {result.backup_time}\n\n" + f"⚠️ 导入将会清空并覆盖现有的所有数据,包括:\n" + f"• 主数据库(对话记录、配置等)\n" + f"• 知识库数据\n" + f"• 插件及插件数据\n" + f"• 配置文件\n\n" + f"此操作不可撤销!是否确认继续?" + ) + + return result + + except zipfile.BadZipFile: + result.error = "无效的 ZIP 文件" + return result + except Exception as e: + result.error = f"检查备份文件失败: {e}" + return result + + def _check_version_compatibility(self, backup_version: str) -> dict: + """检查版本兼容性 + + 规则: + - 主版本(前两位,如 4.9)必须一致,否则拒绝 + - 小版本(第三位,如 4.9.1 vs 4.9.2)不同时,警告但允许导入 + + Returns: + dict: {status, can_import, message} + """ + if not backup_version: + return { + "status": "major_diff", + "can_import": False, + "message": "备份文件缺少版本信息", + } + + backup_parts = parse_version(backup_version) + current_parts = parse_version(VERSION) + + # 补齐到至少 2 位用于主版本比较 + backup_major = ( + backup_parts[:2] + if len(backup_parts) >= 2 + else backup_parts + (0,) * (2 - len(backup_parts)) + ) + current_major = ( + current_parts[:2] + if len(current_parts) >= 2 + else current_parts + (0,) * (2 - len(current_parts)) + ) + + if backup_major != current_major: + return { + "status": "major_diff", + "can_import": False, + "message": ( + f"主版本不兼容: 备份版本 {backup_version}, 当前版本 {VERSION}。" + f"跨主版本导入可能导致数据损坏,请使用相同主版本的 AstrBot。" + ), + } + + # 比较完整版本 + backup_full = backup_parts + (0,) * (3 - len(backup_parts)) + current_full = current_parts + (0,) * (3 - len(current_parts)) + + if backup_full != current_full: + return { + "status": "minor_diff", + "can_import": True, + "message": ( + f"小版本差异: 备份版本 {backup_version}, 当前版本 {VERSION}。" + ), + } + + return { + "status": "match", + "can_import": True, + "message": "版本匹配", + } + async def import_all( self, zip_path: str, @@ -283,16 +481,24 @@ async def import_all( return result def _validate_version(self, manifest: dict) -> None: - """验证版本兼容性 - 仅允许相同版本导入""" + """验证版本兼容性 - 仅允许相同主版本导入 + + 注意:此方法仅在 import_all 中调用,用于双重校验。 + 前端应先调用 pre_check 获取详细的版本信息并让用户确认。 + """ backup_version = manifest.get("astrbot_version") if not backup_version: raise ValueError("备份文件缺少版本信息") - if backup_version != VERSION: - raise ValueError( - f"版本不匹配: 备份版本 {backup_version}, 当前版本 {VERSION}。" - f"请使用相同版本的 AstrBot 进行导入。" - ) + # 使用新的版本兼容性检查 + version_check = self._check_version_compatibility(backup_version) + + if version_check["status"] == "major_diff": + raise ValueError(version_check["message"]) + + # minor_diff 和 match 都允许导入 + if version_check["status"] == "minor_diff": + logger.warning(f"版本差异警告: {version_check['message']}") async def _clear_main_db(self) -> None: """清空主数据库所有表""" diff --git a/astrbot/dashboard/routes/backup.py b/astrbot/dashboard/routes/backup.py index be87b7bf7..0445810b7 100644 --- a/astrbot/dashboard/routes/backup.py +++ b/astrbot/dashboard/routes/backup.py @@ -88,7 +88,9 @@ def __init__( self.routes = { "/backup/list": ("GET", self.list_backups), "/backup/export": ("POST", self.export_backup), - "/backup/import": ("POST", self.import_backup), + "/backup/upload": ("POST", self.upload_backup), # 上传文件 + "/backup/check": ("POST", self.check_backup), # 预检查 + "/backup/import": ("POST", self.import_backup), # 确认导入 "/backup/progress": ("GET", self.get_progress), "/backup/download": ("GET", self.download_backup), "/backup/delete": ("POST", self.delete_backup), @@ -290,58 +292,137 @@ async def _background_export_task(self, task_id: str): logger.error(traceback.format_exc()) self._set_task_result(task_id, "failed", error=str(e)) - async def import_backup(self): - """导入备份 + async def upload_backup(self): + """上传备份文件 - 支持两种方式: - 1. multipart/form-data 文件上传 - 2. JSON 指定已存在的备份文件名 + 将备份文件上传到服务器,返回保存的文件名。 + 上传后应调用 check_backup 进行预检查。 Form Data: - - file: 备份文件 (可选) + - file: 备份文件 (.zip) + + 返回: + - filename: 保存的文件名 + """ + try: + files = await request.files + if "file" not in files: + return Response().error("缺少备份文件").__dict__ + + file = files["file"] + if not file.filename or not file.filename.endswith(".zip"): + return Response().error("请上传 ZIP 格式的备份文件").__dict__ + + # 清洗文件名并生成唯一名称,防止路径遍历和覆盖 + safe_filename = secure_filename(file.filename) + unique_filename = generate_unique_filename(safe_filename) + + # 保存上传的文件 + Path(self.backup_dir).mkdir(parents=True, exist_ok=True) + zip_path = os.path.join(self.backup_dir, unique_filename) + await file.save(zip_path) + + logger.info( + f"上传的备份文件已保存: {unique_filename} (原始名称: {file.filename})" + ) + + return ( + Response() + .ok( + { + "filename": unique_filename, + "original_filename": file.filename, + "size": os.path.getsize(zip_path), + } + ) + .__dict__ + ) + except Exception as e: + logger.error(f"上传备份文件失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"上传备份文件失败: {e!s}").__dict__ + + async def check_backup(self): + """预检查备份文件 + + 检查备份文件的版本兼容性,返回确认信息。 + 用户确认后调用 import_backup 执行导入。 + + JSON Body: + - filename: 已上传的备份文件名 + + 返回: + - ImportPreCheckResult: 预检查结果 + """ + try: + data = await request.json + filename = data.get("filename") + if not filename: + return Response().error("缺少 filename 参数").__dict__ + + # 安全检查 - 防止路径遍历 + if ".." in filename or "/" in filename or "\\" in filename: + return Response().error("无效的文件名").__dict__ + + zip_path = os.path.join(self.backup_dir, filename) + if not os.path.exists(zip_path): + return Response().error(f"备份文件不存在: {filename}").__dict__ + + # 获取知识库管理器(用于构造 importer) + kb_manager = getattr(self.core_lifecycle, "kb_manager", None) + + importer = AstrBotImporter( + main_db=self.db, + kb_manager=kb_manager, + config_path="data/cmd_config.json", + attachments_dir="data/attachments", + data_root="data", + ) + + # 执行预检查 + check_result = importer.pre_check(zip_path) + + return Response().ok(check_result.to_dict()).__dict__ + except Exception as e: + logger.error(f"预检查备份文件失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"预检查备份文件失败: {e!s}").__dict__ + + async def import_backup(self): + """执行备份导入 + + 在用户确认后执行实际的导入操作。 + 需要先调用 upload_backup 上传文件,再调用 check_backup 预检查。 JSON Body: - - filename: 已存在的备份文件名 (可选) + - filename: 已上传的备份文件名(必填) + - confirmed: 用户已确认(必填,必须为 true) 返回: - task_id: 任务ID,用于查询导入进度 """ try: - zip_path = None - content_type = request.content_type or "" - - if "multipart/form-data" in content_type: - # 文件上传模式 - files = await request.files - if "file" not in files: - return Response().error("缺少备份文件").__dict__ - - file = files["file"] - if not file.filename or not file.filename.endswith(".zip"): - return Response().error("请上传 ZIP 格式的备份文件").__dict__ - - # 清洗文件名并生成唯一名称,防止路径遍历和覆盖 - safe_filename = secure_filename(file.filename) - unique_filename = generate_unique_filename(safe_filename) - - # 保存上传的文件 - Path(self.backup_dir).mkdir(parents=True, exist_ok=True) - zip_path = os.path.join(self.backup_dir, unique_filename) - await file.save(zip_path) - - logger.info( - f"上传的备份文件已保存: {unique_filename} (原始名称: {file.filename})" + data = await request.json + filename = data.get("filename") + confirmed = data.get("confirmed", False) + + if not filename: + return Response().error("缺少 filename 参数").__dict__ + + if not confirmed: + return ( + Response() + .error("请先确认导入。导入将会清空并覆盖现有数据,此操作不可撤销。") + .__dict__ ) - else: - # JSON 模式 - 使用已存在的文件 - data = await request.json - filename = data.get("filename") - if not filename: - return Response().error("缺少 filename 参数").__dict__ - - zip_path = os.path.join(self.backup_dir, filename) - if not os.path.exists(zip_path): - return Response().error(f"备份文件不存在: {filename}").__dict__ + + # 安全检查 - 防止路径遍历 + if ".." in filename or "/" in filename or "\\" in filename: + return Response().error("无效的文件名").__dict__ + + zip_path = os.path.join(self.backup_dir, filename) + if not os.path.exists(zip_path): + return Response().error(f"备份文件不存在: {filename}").__dict__ # 生成任务ID task_id = str(uuid.uuid4()) diff --git a/dashboard/src/components/shared/BackupDialog.vue b/dashboard/src/components/shared/BackupDialog.vue index 103b6d30b..4ad2207fa 100644 --- a/dashboard/src/components/shared/BackupDialog.vue +++ b/dashboard/src/components/shared/BackupDialog.vue @@ -76,6 +76,7 @@ +