diff --git a/astrbot/core/backup/exporter.py b/astrbot/core/backup/exporter.py index 77102c080..51c4a4650 100644 --- a/astrbot/core/backup/exporter.py +++ b/astrbot/core/backup/exporter.py @@ -447,6 +447,7 @@ def _generate_manifest( "version": BACKUP_MANIFEST_VERSION, "astrbot_version": VERSION, "exported_at": datetime.now(timezone.utc).isoformat(), + "origin": "exported", # 标记备份来源:exported=本实例导出, uploaded=用户上传 "schema_version": { "main_db": "v4", "kb_db": "v1", diff --git a/astrbot/dashboard/routes/backup.py b/astrbot/dashboard/routes/backup.py index bfd7b047d..a7c29de0b 100644 --- a/astrbot/dashboard/routes/backup.py +++ b/astrbot/dashboard/routes/backup.py @@ -1,13 +1,18 @@ """备份管理 API 路由""" import asyncio +import json import os import re +import shutil +import time import traceback import uuid +import zipfile from datetime import datetime from pathlib import Path +import jwt from quart import request, send_file from astrbot.core import logger @@ -22,6 +27,10 @@ from .route import Response, Route, RouteContext +# 分片上传常量 +CHUNK_SIZE = 1024 * 1024 # 1MB +UPLOAD_EXPIRE_SECONDS = 3600 # 上传会话过期时间(1小时) + def secure_filename(filename: str) -> str: """清洗文件名,移除路径遍历字符和危险字符 @@ -54,17 +63,17 @@ def secure_filename(filename: str) -> str: def generate_unique_filename(original_filename: str) -> str: - """生成唯一的文件名,添加时间戳前缀 + """生成唯一的文件名,在原文件名后添加时间戳后缀避免重名 Args: original_filename: 原始文件名(已清洗) Returns: - 唯一的文件名 + 添加了时间戳后缀的唯一文件名,格式为 {原文件名}_{YYYYMMDD_HHMMSS}.{扩展名} """ - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") name, ext = os.path.splitext(original_filename) - return f"uploaded_{timestamp}_{name}{ext}" + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + return f"{name}_{timestamp}{ext}" class BackupRoute(Route): @@ -84,21 +93,34 @@ def __init__( self.core_lifecycle = core_lifecycle self.backup_dir = get_astrbot_backups_path() self.data_dir = get_astrbot_data_path() + self.chunks_dir = os.path.join(self.backup_dir, ".chunks") # 任务状态跟踪 self.backup_tasks: dict[str, dict] = {} self.backup_progress: dict[str, dict] = {} + # 分片上传会话跟踪 + # upload_id -> {filename, total_chunks, received_chunks, last_activity, chunk_dir} + self.upload_sessions: dict[str, dict] = {} + + # 后台清理任务句柄 + self._cleanup_task: asyncio.Task | None = None + # 注册路由 self.routes = { "/backup/list": ("GET", self.list_backups), "/backup/export": ("POST", self.export_backup), - "/backup/upload": ("POST", self.upload_backup), # 上传文件 + "/backup/upload": ("POST", self.upload_backup), # 上传文件(兼容小文件) + "/backup/upload/init": ("POST", self.upload_init), # 分片上传初始化 + "/backup/upload/chunk": ("POST", self.upload_chunk), # 上传分片 + "/backup/upload/complete": ("POST", self.upload_complete), # 完成分片上传 + "/backup/upload/abort": ("POST", self.upload_abort), # 取消上传 "/backup/check": ("POST", self.check_backup), # 预检查 "/backup/import": ("POST", self.import_backup), # 确认导入 "/backup/progress": ("GET", self.get_progress), "/backup/download": ("GET", self.download_backup), "/backup/delete": ("POST", self.delete_backup), + "/backup/rename": ("POST", self.rename_backup), # 重命名备份 } self.register_routes() @@ -173,7 +195,81 @@ async def _callback(stage: str, current: int, total: int, message: str = ""): return _callback + def _ensure_cleanup_task_started(self): + """确保后台清理任务已启动(在异步上下文中延迟启动)""" + if self._cleanup_task is None or self._cleanup_task.done(): + try: + self._cleanup_task = asyncio.create_task( + self._cleanup_expired_uploads() + ) + except RuntimeError: + # 如果没有运行中的事件循环,跳过(等待下次异步调用时启动) + pass + + async def _cleanup_expired_uploads(self): + """定期清理过期的上传会话 + + 基于 last_activity 字段判断过期,避免清理活跃的上传会话。 + """ + while True: + try: + await asyncio.sleep(300) # 每5分钟检查一次 + current_time = time.time() + expired_sessions = [] + + for upload_id, session in self.upload_sessions.items(): + # 使用 last_activity 判断过期,而非 created_at + last_activity = session.get("last_activity", session["created_at"]) + if current_time - last_activity > UPLOAD_EXPIRE_SECONDS: + expired_sessions.append(upload_id) + + for upload_id in expired_sessions: + await self._cleanup_upload_session(upload_id) + logger.info(f"清理过期的上传会话: {upload_id}") + + except asyncio.CancelledError: + # 任务被取消,正常退出 + break + except Exception as e: + logger.error(f"清理过期上传会话失败: {e}") + + async def _cleanup_upload_session(self, upload_id: str): + """清理上传会话""" + if upload_id in self.upload_sessions: + session = self.upload_sessions[upload_id] + chunk_dir = session.get("chunk_dir") + if chunk_dir and os.path.exists(chunk_dir): + try: + shutil.rmtree(chunk_dir) + except Exception as e: + logger.warning(f"清理分片目录失败: {e}") + del self.upload_sessions[upload_id] + + def _get_backup_manifest(self, zip_path: str) -> dict | None: + """从备份文件读取 manifest.json + + Args: + zip_path: ZIP 文件路径 + + Returns: + dict | None: manifest 内容,如果不是有效备份则返回 None + """ + try: + with zipfile.ZipFile(zip_path, "r") as zf: + if "manifest.json" in zf.namelist(): + manifest_data = zf.read("manifest.json") + return json.loads(manifest_data.decode("utf-8")) + else: + # 没有 manifest.json,不是有效的 AstrBot 备份 + return None + except Exception as e: + logger.debug(f"读取备份 manifest 失败: {e}") + return None # 无法读取,不是有效备份 + async def list_backups(self): + # 确保后台清理任务已启动 + self._ensure_cleanup_task_started() + """获取备份列表 Query 参数: @@ -190,16 +286,34 @@ async def list_backups(self): # 获取所有备份文件 backup_files = [] for filename in os.listdir(self.backup_dir): - if filename.endswith(".zip") and filename.startswith("astrbot_backup_"): - file_path = os.path.join(self.backup_dir, filename) - stat = os.stat(file_path) - backup_files.append( - { - "filename": filename, - "size": stat.st_size, - "created_at": stat.st_mtime, - } - ) + # 只处理 .zip 文件,排除隐藏文件和目录 + if not filename.endswith(".zip") or filename.startswith("."): + continue + + file_path = os.path.join(self.backup_dir, filename) + if not os.path.isfile(file_path): + continue + + # 读取 manifest.json 获取备份信息 + # 如果返回 None,说明不是有效的 AstrBot 备份,跳过 + manifest = self._get_backup_manifest(file_path) + if manifest is None: + logger.debug(f"跳过无效备份文件: {filename}") + continue + + stat = os.stat(file_path) + backup_files.append( + { + "filename": filename, + "size": stat.st_size, + "created_at": stat.st_mtime, + "type": manifest.get( + "origin", "exported" + ), # 老版本没有 origin 默认为 exported + "astrbot_version": manifest.get("astrbot_version", "未知"), + "exported_at": manifest.get("exported_at"), + } + ) # 按创建时间倒序排序 backup_files.sort(key=lambda x: x["created_at"], reverse=True) @@ -345,6 +459,309 @@ async def upload_backup(self): logger.error(traceback.format_exc()) return Response().error(f"上传备份文件失败: {e!s}").__dict__ + async def upload_init(self): + """初始化分片上传 + + 创建一个上传会话,返回 upload_id 供后续分片上传使用。 + + JSON Body: + - filename: 原始文件名 + - total_size: 文件总大小(字节) + + 返回: + - upload_id: 上传会话 ID + - chunk_size: 分片大小(由后端决定) + - total_chunks: 分片总数(由后端根据 total_size 和 chunk_size 计算) + """ + try: + data = await request.json + filename = data.get("filename") + total_size = data.get("total_size", 0) + + if not filename: + return Response().error("缺少 filename 参数").__dict__ + + if not filename.endswith(".zip"): + return Response().error("请上传 ZIP 格式的备份文件").__dict__ + + if total_size <= 0: + return Response().error("无效的文件大小").__dict__ + + # 由后端计算分片总数,确保前后端一致 + import math + + total_chunks = math.ceil(total_size / CHUNK_SIZE) + + # 生成上传 ID + upload_id = str(uuid.uuid4()) + + # 创建分片存储目录 + chunk_dir = os.path.join(self.chunks_dir, upload_id) + Path(chunk_dir).mkdir(parents=True, exist_ok=True) + + # 清洗文件名 + safe_filename = secure_filename(filename) + unique_filename = generate_unique_filename(safe_filename) + + # 创建上传会话 + current_time = time.time() + self.upload_sessions[upload_id] = { + "filename": unique_filename, + "original_filename": filename, + "total_size": total_size, + "total_chunks": total_chunks, + "received_chunks": set(), + "created_at": current_time, + "last_activity": current_time, # 用于判断会话是否活跃 + "chunk_dir": chunk_dir, + } + + logger.info( + f"初始化分片上传: upload_id={upload_id}, " + f"filename={unique_filename}, total_chunks={total_chunks}" + ) + + return ( + Response() + .ok( + { + "upload_id": upload_id, + "chunk_size": CHUNK_SIZE, + "total_chunks": total_chunks, + "filename": unique_filename, + } + ) + .__dict__ + ) + except Exception as e: + logger.error(f"初始化分片上传失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"初始化分片上传失败: {e!s}").__dict__ + + async def upload_chunk(self): + """上传分片 + + 上传单个分片数据。 + + Form Data: + - upload_id: 上传会话 ID + - chunk_index: 分片索引(从 0 开始) + - chunk: 分片数据 + + 返回: + - received: 已接收的分片数量 + - total: 分片总数 + """ + try: + form = await request.form + files = await request.files + + upload_id = form.get("upload_id") + chunk_index_str = form.get("chunk_index") + + if not upload_id or chunk_index_str is None: + return Response().error("缺少必要参数").__dict__ + + try: + chunk_index = int(chunk_index_str) + except ValueError: + return Response().error("无效的分片索引").__dict__ + + if "chunk" not in files: + return Response().error("缺少分片数据").__dict__ + + # 验证上传会话 + if upload_id not in self.upload_sessions: + return Response().error("上传会话不存在或已过期").__dict__ + + session = self.upload_sessions[upload_id] + + # 验证分片索引 + if chunk_index < 0 or chunk_index >= session["total_chunks"]: + return Response().error("分片索引超出范围").__dict__ + + # 保存分片 + chunk_file = files["chunk"] + chunk_path = os.path.join(session["chunk_dir"], f"{chunk_index}.part") + await chunk_file.save(chunk_path) + + # 记录已接收的分片,并更新最后活动时间 + session["received_chunks"].add(chunk_index) + session["last_activity"] = time.time() # 刷新活动时间,防止活跃上传被清理 + + received_count = len(session["received_chunks"]) + total_chunks = session["total_chunks"] + + logger.debug( + f"接收分片: upload_id={upload_id}, " + f"chunk={chunk_index + 1}/{total_chunks}" + ) + + return ( + Response() + .ok( + { + "received": received_count, + "total": total_chunks, + "chunk_index": chunk_index, + } + ) + .__dict__ + ) + except Exception as e: + logger.error(f"上传分片失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"上传分片失败: {e!s}").__dict__ + + def _mark_backup_as_uploaded(self, zip_path: str) -> None: + """修改备份文件的 manifest.json,将 origin 设置为 uploaded + + 使用 zipfile 的 append 模式添加新的 manifest.json, + ZIP 规范中后添加的同名文件会覆盖先前的文件。 + + Args: + zip_path: ZIP 文件路径 + """ + try: + # 读取原有 manifest + manifest = {"origin": "uploaded", "uploaded_at": datetime.now().isoformat()} + with zipfile.ZipFile(zip_path, "r") as zf: + if "manifest.json" in zf.namelist(): + manifest_data = zf.read("manifest.json") + manifest = json.loads(manifest_data.decode("utf-8")) + manifest["origin"] = "uploaded" + manifest["uploaded_at"] = datetime.now().isoformat() + + # 使用 append 模式添加新的 manifest.json + # ZIP 规范中,后添加的同名文件会覆盖先前的 + with zipfile.ZipFile(zip_path, "a") as zf: + new_manifest = json.dumps(manifest, ensure_ascii=False, indent=2) + zf.writestr("manifest.json", new_manifest) + + logger.debug(f"已标记备份为上传来源: {zip_path}") + except Exception as e: + logger.warning(f"标记备份来源失败: {e}") + + async def upload_complete(self): + """完成分片上传 + + 合并所有分片为完整文件。 + + JSON Body: + - upload_id: 上传会话 ID + + 返回: + - filename: 合并后的文件名 + - size: 文件大小 + """ + try: + data = await request.json + upload_id = data.get("upload_id") + + if not upload_id: + return Response().error("缺少 upload_id 参数").__dict__ + + # 验证上传会话 + if upload_id not in self.upload_sessions: + return Response().error("上传会话不存在或已过期").__dict__ + + session = self.upload_sessions[upload_id] + + # 检查是否所有分片都已接收 + received = session["received_chunks"] + total = session["total_chunks"] + + if len(received) != total: + missing = set(range(total)) - received + return ( + Response() + .error(f"分片不完整,缺少: {sorted(missing)[:10]}...") + .__dict__ + ) + + # 合并分片 + chunk_dir = session["chunk_dir"] + filename = session["filename"] + + Path(self.backup_dir).mkdir(parents=True, exist_ok=True) + output_path = os.path.join(self.backup_dir, filename) + + try: + with open(output_path, "wb") as outfile: + for i in range(total): + chunk_path = os.path.join(chunk_dir, f"{i}.part") + with open(chunk_path, "rb") as chunk_file: + # 分块读取,避免内存溢出 + while True: + data_block = chunk_file.read(8192) + if not data_block: + break + outfile.write(data_block) + + file_size = os.path.getsize(output_path) + + # 标记备份为上传来源(修改 manifest.json 中的 origin 字段) + self._mark_backup_as_uploaded(output_path) + + logger.info( + f"分片上传完成: {filename}, size={file_size}, chunks={total}" + ) + + # 清理分片目录 + await self._cleanup_upload_session(upload_id) + + return ( + Response() + .ok( + { + "filename": filename, + "original_filename": session["original_filename"], + "size": file_size, + } + ) + .__dict__ + ) + except Exception as e: + # 如果合并失败,删除不完整的文件 + if os.path.exists(output_path): + os.remove(output_path) + raise e + + except Exception as e: + logger.error(f"完成分片上传失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"完成分片上传失败: {e!s}").__dict__ + + async def upload_abort(self): + """取消分片上传 + + 取消上传并清理已上传的分片。 + + JSON Body: + - upload_id: 上传会话 ID + """ + try: + data = await request.json + upload_id = data.get("upload_id") + + if not upload_id: + return Response().error("缺少 upload_id 参数").__dict__ + + if upload_id not in self.upload_sessions: + # 会话已不存在,可能已过期或已完成 + return Response().ok(message="上传已取消").__dict__ + + # 清理会话 + await self._cleanup_upload_session(upload_id) + + logger.info(f"取消分片上传: {upload_id}") + + return Response().ok(message="上传已取消").__dict__ + except Exception as e: + logger.error(f"取消上传失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"取消上传失败: {e!s}").__dict__ + async def check_backup(self): """预检查备份文件 @@ -537,12 +954,33 @@ async def download_backup(self): Query 参数: - filename: 备份文件名 (必填) + - token: JWT token (必填,用于浏览器原生下载鉴权) + + 注意: 此路由已被添加到 auth_middleware 白名单中, + 使用 URL 参数中的 token 进行鉴权,以支持浏览器原生下载。 """ try: filename = request.args.get("filename") + token = request.args.get("token") + if not filename: return Response().error("缺少参数 filename").__dict__ + if not token: + return Response().error("缺少参数 token").__dict__ + + # 验证 JWT token + try: + jwt_secret = self.config.get("dashboard", {}).get("jwt_secret") + if not jwt_secret: + return Response().error("服务器配置错误").__dict__ + + jwt.decode(token, jwt_secret, algorithms=["HS256"]) + except jwt.ExpiredSignatureError: + return Response().error("Token 已过期,请刷新页面后重试").__dict__ + except jwt.InvalidTokenError: + return Response().error("Token 无效").__dict__ + # 安全检查 - 防止路径遍历 if ".." in filename or "/" in filename or "\\" in filename: return Response().error("无效的文件名").__dict__ @@ -587,3 +1025,69 @@ async def delete_backup(self): logger.error(f"删除备份失败: {e}") logger.error(traceback.format_exc()) return Response().error(f"删除备份失败: {e!s}").__dict__ + + async def rename_backup(self): + """重命名备份文件 + + Body: + - filename: 当前文件名 (必填) + - new_name: 新文件名 (必填,不含扩展名) + """ + try: + data = await request.json + filename = data.get("filename") + new_name = data.get("new_name") + + if not filename: + return Response().error("缺少参数 filename").__dict__ + + if not new_name: + return Response().error("缺少参数 new_name").__dict__ + + # 安全检查 - 防止路径遍历 + if ".." in filename or "/" in filename or "\\" in filename: + return Response().error("无效的文件名").__dict__ + + # 清洗新文件名(移除路径和危险字符) + new_name = secure_filename(new_name) + + # 移除新文件名中的扩展名(如果有的话) + if new_name.endswith(".zip"): + new_name = new_name[:-4] + + # 验证新文件名不为空 + if not new_name or new_name.replace("_", "") == "": + return Response().error("新文件名无效").__dict__ + + # 强制使用 .zip 扩展名 + new_filename = f"{new_name}.zip" + + # 检查原文件是否存在 + old_path = os.path.join(self.backup_dir, filename) + if not os.path.exists(old_path): + return Response().error("备份文件不存在").__dict__ + + # 检查新文件名是否已存在 + new_path = os.path.join(self.backup_dir, new_filename) + if os.path.exists(new_path): + return Response().error(f"文件名 '{new_filename}' 已存在").__dict__ + + # 执行重命名 + os.rename(old_path, new_path) + + logger.info(f"备份文件重命名: {filename} -> {new_filename}") + + return ( + Response() + .ok( + { + "old_filename": filename, + "new_filename": new_filename, + } + ) + .__dict__ + ) + except Exception as e: + logger.error(f"重命名备份失败: {e}") + logger.error(traceback.format_exc()) + return Response().error(f"重命名备份失败: {e!s}").__dict__ diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index f778a5049..ad83c4886 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -115,6 +115,7 @@ async def auth_middleware(self): "/api/file", "/api/platform/webhook", "/api/stat/start-time", + "/api/backup/download", # 备份下载使用 URL 参数传递 token ] if any(request.path.startswith(prefix) for prefix in allowed_endpoints): return None diff --git a/dashboard/src/components/shared/BackupDialog.vue b/dashboard/src/components/shared/BackupDialog.vue index 629e4e559..eb0327e4e 100644 --- a/dashboard/src/components/shared/BackupDialog.vue +++ b/dashboard/src/components/shared/BackupDialog.vue @@ -110,9 +110,23 @@
{{ t('features.settings.backup.import.uploadWait') }}
++ {{ uploadProgress.message || t('features.settings.backup.import.uploadWait') }} +
++ {{ formatFileSize(uploadProgress.uploaded) }} / {{ formatFileSize(uploadProgress.total) }} + ({{ uploadProgress.percent }}%) +
+
+
+ {{ t('features.settings.backup.list.renameHint') }} +
+