diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index 008ef182e..0abd3ad49 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -164,7 +164,7 @@ async def insert_platform_message_history( self, platform_id: str, user_id: str, - content: list[dict], + content: dict, sender_id: str | None = None, sender_name: str | None = None, ) -> None: @@ -287,3 +287,14 @@ async def clear_preferences(self, scope: str, scope_id: str) -> None: # async def get_llm_messages(self, cid: str) -> list[LLMMessage]: # """Get all LLM messages for a specific conversation.""" # ... + + @abc.abstractmethod + async def get_session_conversations( + self, + page: int = 1, + page_size: int = 20, + search_query: str | None = None, + platform: str | None = None, + ) -> tuple[list[dict], int]: + """Get paginated session conversations with joined conversation and persona details, support search and platform filter.""" + ... diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index 88113d130..24a05f947 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -75,7 +75,9 @@ class Persona(SQLModel, table=True): __tablename__ = "personas" - id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True}) + id: int | None = Field( + primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None + ) persona_id: str = Field(max_length=255, nullable=False) system_prompt: str = Field(sa_type=Text, nullable=False) begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON) @@ -135,7 +137,9 @@ class PlatformMessageHistory(SQLModel, table=True): __tablename__ = "platform_message_history" - id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True}) + id: int | None = Field( + primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None + ) platform_id: str = Field(nullable=False) user_id: str = Field(nullable=False) # An id of group, user in platform sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform @@ -158,8 +162,8 @@ class Attachment(SQLModel, table=True): __tablename__ = "attachments" - inner_attachment_id: int = Field( - primary_key=True, sa_column_kwargs={"autoincrement": True} + inner_attachment_id: int | None = Field( + primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None ) attachment_id: str = Field( max_length=36, diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index 51378b017..d8c1684a7 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -15,10 +15,8 @@ SQLModel, ) -from sqlalchemy import select, update, delete, text +from sqlmodel import select, update, delete, text, func, or_, desc, col from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.sql import func -from sqlalchemy import or_ NOT_GIVEN = T.TypeVar("NOT_GIVEN") @@ -42,10 +40,10 @@ async def initialize(self) -> None: async def insert_platform_stats( self, - platform_id: str, - platform_type: str, - count: int = 1, - timestamp: datetime = None, + platform_id, + platform_type, + count=1, + timestamp=None, ) -> None: """Insert a new platform statistic record.""" async with self.get_db() as session: @@ -76,7 +74,9 @@ async def count_platform_stats(self) -> int: async with self.get_db() as session: session: AsyncSession result = await session.execute( - select(func.count(PlatformStat.platform_id)).select_from(PlatformStat) + select(func.count(col(PlatformStat.platform_id))).select_from( + PlatformStat + ) ) count = result.scalar_one_or_none() return count if count is not None else 0 @@ -96,7 +96,7 @@ async def get_platform_stats(self, offset_sec: int = 86400) -> T.List[PlatformSt """), {"start_time": start_time}, ) - return result.scalars().all() + return list(result.scalars().all()) # ==== # Conversation Management @@ -112,7 +112,7 @@ async def get_conversations(self, user_id=None, platform_id=None): if platform_id: query = query.where(ConversationV2.platform_id == platform_id) # order by - query = query.order_by(ConversationV2.created_at.desc()) + query = query.order_by(desc(ConversationV2.created_at)) result = await session.execute(query) return result.scalars().all() @@ -130,7 +130,7 @@ async def get_all_conversations(self, page=1, page_size=20): offset = (page - 1) * page_size result = await session.execute( select(ConversationV2) - .order_by(ConversationV2.created_at.desc()) + .order_by(desc(ConversationV2.created_at)) .offset(offset) .limit(page_size) ) @@ -151,25 +151,25 @@ async def get_filtered_conversations( if platform_ids: base_query = base_query.where( - ConversationV2.platform_id.in_(platform_ids) + col(ConversationV2.platform_id).in_(platform_ids) ) if search_query: search_query = search_query.encode("unicode_escape").decode("utf-8") base_query = base_query.where( or_( - ConversationV2.title.ilike(f"%{search_query}%"), - ConversationV2.content.ilike(f"%{search_query}%"), - ConversationV2.user_id.ilike(f"%{search_query}%"), + col(ConversationV2.title).ilike(f"%{search_query}%"), + col(ConversationV2.content).ilike(f"%{search_query}%"), + col(ConversationV2.user_id).ilike(f"%{search_query}%"), ) ) if "message_types" in kwargs and len(kwargs["message_types"]) > 0: for msg_type in kwargs["message_types"]: base_query = base_query.where( - ConversationV2.user_id.ilike(f"%:{msg_type}:%") + col(ConversationV2.user_id).ilike(f"%:{msg_type}:%") ) if "platforms" in kwargs and len(kwargs["platforms"]) > 0: base_query = base_query.where( - ConversationV2.platform_id.in_(kwargs["platforms"]) + col(ConversationV2.platform_id).in_(kwargs["platforms"]) ) # Get total count matching the filters @@ -180,7 +180,7 @@ async def get_filtered_conversations( # Get paginated results offset = (page - 1) * page_size result_query = ( - base_query.order_by(ConversationV2.created_at.desc()) + base_query.order_by(desc(ConversationV2.created_at)) .offset(offset) .limit(page_size) ) @@ -226,7 +226,7 @@ async def update_conversation(self, cid, title=None, persona_id=None, content=No session: AsyncSession async with session.begin(): query = update(ConversationV2).where( - ConversationV2.conversation_id == cid + col(ConversationV2.conversation_id) == cid ) values = {} if title is not None: @@ -246,7 +246,9 @@ async def delete_conversation(self, cid): session: AsyncSession async with session.begin(): await session.execute( - delete(ConversationV2).where(ConversationV2.conversation_id == cid) + delete(ConversationV2).where( + col(ConversationV2.conversation_id) == cid + ) ) async def delete_conversations_by_user_id(self, user_id: str) -> None: @@ -254,8 +256,115 @@ async def delete_conversations_by_user_id(self, user_id: str) -> None: session: AsyncSession async with session.begin(): await session.execute( - delete(ConversationV2).where(ConversationV2.user_id == user_id) + delete(ConversationV2).where(col(ConversationV2.user_id) == user_id) + ) + + async def get_session_conversations( + self, + page=1, + page_size=20, + search_query=None, + platform=None, + ) -> tuple[list[dict], int]: + """Get paginated session conversations with joined conversation and persona details.""" + async with self.get_db() as session: + session: AsyncSession + offset = (page - 1) * page_size + + base_query = ( + select( + col(Preference.scope_id).label("session_id"), + func.json_extract(Preference.value, "$.val").label( + "conversation_id" + ), # type: ignore + col(ConversationV2.persona_id).label("persona_id"), + col(ConversationV2.title).label("title"), + col(Persona.persona_id).label("persona_name"), + ) + .select_from(Preference) + .outerjoin( + ConversationV2, + func.json_extract(Preference.value, "$.val") + == ConversationV2.conversation_id, + ) + .outerjoin( + Persona, col(ConversationV2.persona_id) == Persona.persona_id ) + .where(Preference.scope == "umo", Preference.key == "sel_conv_id") + ) + + # 搜索筛选 + if search_query: + search_pattern = f"%{search_query}%" + base_query = base_query.where( + or_( + col(Preference.scope_id).ilike(search_pattern), + col(ConversationV2.title).ilike(search_pattern), + col(Persona.persona_id).ilike(search_pattern), + ) + ) + + # 平台筛选 + if platform: + platform_pattern = f"{platform}:%" + base_query = base_query.where( + col(Preference.scope_id).like(platform_pattern) + ) + + # 排序 + base_query = base_query.order_by(Preference.scope_id) + + # 分页结果 + result_query = base_query.offset(offset).limit(page_size) + result = await session.execute(result_query) + rows = result.fetchall() + + # 查询总数(应用相同的筛选条件) + count_base_query = ( + select(func.count(col(Preference.scope_id))) + .select_from(Preference) + .outerjoin( + ConversationV2, + func.json_extract(Preference.value, "$.val") + == ConversationV2.conversation_id, + ) + .outerjoin( + Persona, col(ConversationV2.persona_id) == Persona.persona_id + ) + .where(Preference.scope == "umo", Preference.key == "sel_conv_id") + ) + + # 应用相同的搜索和平台筛选条件到计数查询 + if search_query: + search_pattern = f"%{search_query}%" + count_base_query = count_base_query.where( + or_( + col(Preference.scope_id).ilike(search_pattern), + col(ConversationV2.title).ilike(search_pattern), + col(Persona.persona_id).ilike(search_pattern), + ) + ) + + if platform: + platform_pattern = f"{platform}:%" + count_base_query = count_base_query.where( + col(Preference.scope_id).like(platform_pattern) + ) + + total_result = await session.execute(count_base_query) + total = total_result.scalar() or 0 + + sessions_data = [ + { + "session_id": row.session_id, + "conversation_id": row.conversation_id, + "persona_id": row.persona_id, + "title": row.title, + "persona_name": row.persona_name, + } + for row in rows + ] + return sessions_data, total async def insert_platform_message_history( self, @@ -290,9 +399,9 @@ async def delete_platform_message_offset( cutoff_time = now - timedelta(seconds=offset_sec) await session.execute( delete(PlatformMessageHistory).where( - PlatformMessageHistory.platform_id == platform_id, - PlatformMessageHistory.user_id == user_id, - PlatformMessageHistory.created_at < cutoff_time, + col(PlatformMessageHistory.platform_id) == platform_id, + col(PlatformMessageHistory.user_id) == user_id, + col(PlatformMessageHistory.created_at) < cutoff_time, ) ) @@ -309,7 +418,7 @@ async def get_platform_message_history( PlatformMessageHistory.platform_id == platform_id, PlatformMessageHistory.user_id == user_id, ) - .order_by(PlatformMessageHistory.created_at.desc()) + .order_by(desc(PlatformMessageHistory.created_at)) ) result = await session.execute(query.offset(offset).limit(page_size)) return result.scalars().all() @@ -331,7 +440,7 @@ async def get_attachment_by_id(self, attachment_id): """Get an attachment by its ID.""" async with self.get_db() as session: session: AsyncSession - query = select(Attachment).where(Attachment.id == attachment_id) + query = select(Attachment).where(Attachment.attachment_id == attachment_id) result = await session.execute(query) return result.scalar_one_or_none() @@ -374,7 +483,7 @@ async def update_persona( async with self.get_db() as session: session: AsyncSession async with session.begin(): - query = update(Persona).where(Persona.persona_id == persona_id) + query = update(Persona).where(col(Persona.persona_id) == persona_id) values = {} if system_prompt is not None: values["system_prompt"] = system_prompt @@ -394,7 +503,7 @@ async def delete_persona(self, persona_id): session: AsyncSession async with session.begin(): await session.execute( - delete(Persona).where(Persona.persona_id == persona_id) + delete(Persona).where(col(Persona.persona_id) == persona_id) ) async def insert_preference_or_update(self, scope, scope_id, key, value): @@ -449,9 +558,9 @@ async def remove_preference(self, scope, scope_id, key): async with session.begin(): await session.execute( delete(Preference).where( - Preference.scope == scope, - Preference.scope_id == scope_id, - Preference.key == key, + col(Preference.scope) == scope, + col(Preference.scope_id) == scope_id, + col(Preference.key) == key, ) ) await session.commit() @@ -463,7 +572,8 @@ async def clear_preferences(self, scope, scope_id): async with session.begin(): await session.execute( delete(Preference).where( - Preference.scope == scope, Preference.scope_id == scope_id + col(Preference.scope) == scope, + col(Preference.scope_id) == scope_id, ) ) await session.commit() @@ -490,7 +600,7 @@ async def _inner(): DeprecatedPlatformStat( name=data.platform_id, count=data.count, - timestamp=data.timestamp.timestamp(), + timestamp=int(data.timestamp.timestamp()), ) ) return deprecated_stats @@ -548,7 +658,7 @@ async def _inner(): DeprecatedPlatformStat( name=platform_id, count=count, - timestamp=start_time.timestamp(), + timestamp=int(start_time.timestamp()), ) ) return deprecated_stats diff --git a/astrbot/dashboard/routes/session_management.py b/astrbot/dashboard/routes/session_management.py index 7a846c52b..161a07879 100644 --- a/astrbot/dashboard/routes/session_management.py +++ b/astrbot/dashboard/routes/session_management.py @@ -20,6 +20,7 @@ def __init__( core_lifecycle: AstrBotCoreLifecycle, ) -> None: super().__init__(context) + self.db_helper = db_helper self.routes = { "/session/list": ("GET", self.list_sessions), "/session/update_persona": ("POST", self.update_session_persona), @@ -39,22 +40,42 @@ def __init__( async def list_sessions(self): """获取所有会话的列表,包括 persona 和 provider 信息""" try: - preferences = await sp.session_get(umo=None, key="sel_conv_id", default=[]) - session_conversations = {} - for pref in preferences: - session_conversations[pref.scope_id] = pref.value["val"] + page = int(request.args.get("page", 1)) + page_size = int(request.args.get("page_size", 20)) + search_query = request.args.get("search", "") + platform = request.args.get("platform", "") + + # 获取活跃的会话数据(处于对话内的会话) + sessions_data, total = await self.db_helper.get_session_conversations( + page, page_size, search_query, platform + ) + provider_manager = self.core_lifecycle.provider_manager persona_mgr = self.core_lifecycle.persona_mgr personas = persona_mgr.personas_v3 sessions = [] - # 构建会话信息 - for session_id, conversation_id in session_conversations.items(): + # 循环补充非数据库信息,如 provider 和 session 状态 + for data in sessions_data: + session_id = data["session_id"] + conversation_id = data["conversation_id"] + conv_persona_id = data["persona_id"] + title = data["title"] + persona_name = data["persona_name"] + + # 处理 persona 显示 + if conv_persona_id == "[%None]": + persona_name = "无人格" + else: + default_persona = persona_mgr.selected_default_persona_v3 + if default_persona: + persona_name = default_persona["name"] + session_info = { "session_id": session_id, "conversation_id": conversation_id, - "persona_id": None, + "persona_id": persona_name, "chat_provider_id": None, "stt_provider_id": None, "tts_provider_id": None, @@ -79,31 +100,10 @@ async def list_sessions(self): "session_raw_name": session_id.split(":")[2] if session_id.count(":") >= 2 else session_id, + "title": title, } - # 获取对话信息 - conversation = await self.conv_mgr.get_conversation( - unified_msg_origin=session_id, conversation_id=conversation_id - ) - if conversation: - session_info["persona_id"] = conversation.persona_id - - # 查找 persona 名称 - if conversation.persona_id and conversation.persona_id != "[%None]": - for persona in personas: - if persona["name"] == conversation.persona_id: - session_info["persona_id"] = persona["name"] - break - elif conversation.persona_id == "[%None]": - session_info["persona_id"] = "无人格" - else: - # 使用默认人格 - default_persona = persona_mgr.selected_default_persona_v3 - if default_persona: - session_info["persona_id"] = default_persona["name"] - # 获取 provider 信息 - provider_manager = self.core_lifecycle.provider_manager chat_provider = provider_manager.get_using_provider( provider_type=ProviderType.CHAT_COMPLETION, umo=session_id ) @@ -172,6 +172,14 @@ async def list_sessions(self): "available_chat_providers": available_chat_providers, "available_stt_providers": available_stt_providers, "available_tts_providers": available_tts_providers, + "pagination": { + "page": page, + "page_size": page_size, + "total": total, + "total_pages": (total + page_size - 1) // page_size + if page_size > 0 + else 0, + }, } return Response().ok(result).__dict__ diff --git a/dashboard/src/views/SessionManagementPage.vue b/dashboard/src/views/SessionManagementPage.vue index 1704dfab9..5e0fbc8b9 100644 --- a/dashboard/src/views/SessionManagementPage.vue +++ b/dashboard/src/views/SessionManagementPage.vue @@ -4,13 +4,13 @@ {{ tm('sessions.activeSessions') }} - {{ sessions.length }} {{ tm('sessions.sessionCount') }} + {{ totalItems }} {{ tm('sessions.sessionCount') }} + hide-details clearable variant="solo-filled" flat class="me-4" density="compact" @update:model-value="handleSearchChange"> + density="compact" @update:model-value="handlePlatformChange"> @@ -22,8 +22,17 @@ - + - + @@ -357,7 +366,10 @@ export default { filterPlatform: null, // 分页相关 + currentPage: 1, itemsPerPage: 10, + totalItems: 0, + totalPages: 0, // 可用选项 availablePersonas: [], @@ -424,30 +436,6 @@ export default { ] }, - // 懒加载过滤会话 - 使用客户端分页 - filteredSessions() { - let filtered = this.sessions; - - // 搜索筛选 - if (this.searchQuery) { - const query = this.searchQuery.toLowerCase().trim(); - filtered = filtered.filter(session => - session.session_name.toLowerCase().includes(query) || - session.platform.toLowerCase().includes(query) || - session.persona_name?.toLowerCase().includes(query) || - session.chat_provider_name?.toLowerCase().includes(query) || - session.session_id.toLowerCase().includes(query) - ); - } - - // 平台筛选 - if (this.filterPlatform) { - filtered = filtered.filter(session => session.platform === this.filterPlatform); - } - - return filtered; - }, - platformOptions() { const platforms = [...new Set(this.sessions.map(s => s.platform))]; return platforms.map(p => ({ title: p, value: p })); @@ -494,7 +482,20 @@ export default { async loadSessions() { this.loading = true; try { - const response = await axios.get('/api/session/list'); + const params = { + page: this.currentPage, + page_size: this.itemsPerPage + }; + + // 添加搜索和平台筛选参数 + if (this.searchQuery) { + params.search = this.searchQuery; + } + if (this.filterPlatform) { + params.platform = this.filterPlatform; + } + + const response = await axios.get('/api/session/list', { params }); if (response.data.status === 'ok') { const data = response.data.data; this.sessions = data.sessions.map(session => ({ @@ -507,6 +508,13 @@ export default { this.availableChatProviders = data.available_chat_providers; this.availableSttProviders = data.available_stt_providers; this.availableTtsProviders = data.available_tts_providers; + + // 处理分页信息 + if (data.pagination) { + this.totalItems = data.pagination.total; + this.totalPages = data.pagination.total_pages; + this.currentPage = data.pagination.page; + } } else { this.showError(response.data.message || this.tm('messages.loadSessionsError')); } @@ -679,7 +687,7 @@ export default { let totalErrorCount = 0; let allErrorSessions = []; - const sessions = this.filteredSessions; + const sessions = this.sessions; try { // 定义批量操作任务 @@ -936,6 +944,25 @@ export default { session.deleting = false; }, + + // 处理分页更新事件 + handlePaginationUpdate(options) { + this.currentPage = options.page; + this.itemsPerPage = options.itemsPerPage; + this.loadSessions(); + }, + + // 处理搜索变化 + handleSearchChange() { + this.currentPage = 1; // 重置到第一页 + this.loadSessions(); + }, + + // 处理平台筛选变化 + handlePlatformChange() { + this.currentPage = 1; // 重置到第一页 + this.loadSessions(); + }, }, }