diff --git a/backend/shared/domain-common/src/main/java/com/datamate/common/setting/domain/entity/ModelConfig.java b/backend/shared/domain-common/src/main/java/com/datamate/common/setting/domain/entity/ModelConfig.java index 0b153efe9..f0e97e46b 100644 --- a/backend/shared/domain-common/src/main/java/com/datamate/common/setting/domain/entity/ModelConfig.java +++ b/backend/shared/domain-common/src/main/java/com/datamate/common/setting/domain/entity/ModelConfig.java @@ -12,7 +12,7 @@ */ @Getter @Setter -@TableName("t_model_config") +@TableName("t_models") @Builder @ToString @NoArgsConstructor diff --git a/runtime/datamate-python/app/db/models/model_config.py b/runtime/datamate-python/app/db/models/models.py similarity index 71% rename from runtime/datamate-python/app/db/models/model_config.py rename to runtime/datamate-python/app/db/models/models.py index bbea7bed1..9e32cdaf3 100644 --- a/runtime/datamate-python/app/db/models/model_config.py +++ b/runtime/datamate-python/app/db/models/models.py @@ -1,18 +1,12 @@ -from sqlalchemy import Column, String, Integer, TIMESTAMP, select +from sqlalchemy import Boolean, Column, String, TIMESTAMP from app.db.models.base_entity import BaseEntity -async def get_model_by_id(db_session, model_id: str): - """根据 ID 获取单个模型配置。""" - result =await db_session.execute(select(ModelConfig).where(ModelConfig.id == model_id)) - model_config = result.scalar_one_or_none() - return model_config +class Models(BaseEntity): + """模型配置表,对应表 t_models -class ModelConfig(BaseEntity): - """模型配置表,对应表 t_model_config - - CREATE TABLE IF NOT EXISTS t_model_config ( + CREATE TABLE IF NOT EXISTS t_models ( id VARCHAR(36) PRIMARY KEY COMMENT '主键ID', model_name VARCHAR(100) NOT NULL COMMENT '模型名称(如 qwen2)', provider VARCHAR(50) NOT NULL COMMENT '模型提供商(如 Ollama、OpenAI、DeepSeek)', @@ -29,7 +23,7 @@ class ModelConfig(BaseEntity): ) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4 COMMENT ='模型配置表'; """ - __tablename__ = "t_model_config" + __tablename__ = "t_models" id = Column(String(36), primary_key=True, index=True, comment="主键ID") model_name = Column(String(100), nullable=False, comment="模型名称(如 qwen2)") @@ -38,9 +32,9 @@ class ModelConfig(BaseEntity): api_key = Column(String(512), nullable=False, default="", comment="API 密钥(无密钥则为空)") type = Column(String(50), nullable=False, comment="模型类型(如 chat、embedding)") - # 使用 Integer 存储 TINYINT,后续可在业务层将 0/1 转为 bool - is_enabled = Column(Integer, nullable=False, default=1, comment="是否启用:1-启用,0-禁用") - is_default = Column(Integer, nullable=False, default=0, comment="是否默认:1-默认,0-非默认") + is_enabled = Column(Boolean, nullable=False, default=True, comment="是否启用") + is_default = Column(Boolean, nullable=False, default=False, comment="是否默认") + is_deleted = Column(Boolean, nullable=False, default=False, comment="是否删除") __table_args__ = ( # 与 DDL 中的 uk_model_provider 保持一致 diff --git a/runtime/datamate-python/app/module/evaluation/interface/evaluation.py b/runtime/datamate-python/app/module/evaluation/interface/evaluation.py index a32a90142..a718b4194 100644 --- a/runtime/datamate-python/app/module/evaluation/interface/evaluation.py +++ b/runtime/datamate-python/app/module/evaluation/interface/evaluation.py @@ -80,8 +80,8 @@ async def create_evaluation_task( if existing_task.scalar_one_or_none(): raise HTTPException(status_code=400, detail=f"Evaluation task with name '{request.name}' already exists") - model_config = await get_model_by_id(db, request.eval_config.model_id) - if not model_config: + models = await get_model_by_id(db, request.eval_config.model_id) + if 
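
Note on the ORM change above: `is_enabled` / `is_default` / `is_deleted` are now mapped as SQLAlchemy `Boolean` columns rather than `Integer`-backed TINYINT, so call sites can filter on real booleans. A minimal sketch under that assumption (the helper and session fixture are illustrative, not part of this patch):

```python
# Hypothetical helper showing the new Boolean columns in a query; not part
# of this patch.
from sqlalchemy import select

from app.db.models.models import Models


async def list_enabled_chat_models(db_session):
    """Return enabled, non-deleted CHAT models using the Boolean columns."""
    stmt = (
        select(Models)
        .where(Models.is_enabled.is_(True))
        .where(Models.is_deleted.is_(False))
        .where(Models.type == "CHAT")
    )
    result = await db_session.execute(stmt)
    return result.scalars().all()
```
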
not models: raise HTTPException(status_code=400, detail=f"Model with id '{request.eval_config.model_id}' not found") # 创建评估任务 @@ -96,7 +96,7 @@ async def create_evaluation_task( eval_prompt=request.eval_prompt, eval_config=json.dumps({ "modelId": request.eval_config.model_id, - "modelName": model_config.model_name, + "modelName": models.model_name, "dimensions": request.eval_config.dimensions, }), status=TaskStatus.PENDING.value, diff --git a/runtime/datamate-python/app/module/evaluation/service/evaluation.py b/runtime/datamate-python/app/module/evaluation/service/evaluation.py index 7994c0cf0..02d617a62 100644 --- a/runtime/datamate-python/app/module/evaluation/service/evaluation.py +++ b/runtime/datamate-python/app/module/evaluation/service/evaluation.py @@ -43,7 +43,7 @@ def get_eval_prompt(self, item: EvaluationItem) -> str: async def execute(self): eval_config = json.loads(self.task.eval_config) - model_config = await get_model_by_id(self.db, eval_config.get("modelId")) + models = await get_model_by_id(self.db, eval_config.get("modelId")) semaphore = asyncio.Semaphore(10) files = (await self.db.execute( select(EvaluationFile).where(EvaluationFile.task_id == self.task.id) @@ -55,7 +55,7 @@ async def execute(self): for file in files: items = (await self.db.execute(query.where(EvaluationItem.file_id == file.file_id))).scalars().all() tasks = [ - self.evaluate_item(model_config, item, semaphore) + self.evaluate_item(models, item, semaphore) for item in items ] await asyncio.gather(*tasks, return_exceptions=True) @@ -64,13 +64,13 @@ async def execute(self): self.task.eval_process = evaluated_count / total await self.db.commit() - async def evaluate_item(self, model_config, item: EvaluationItem, semaphore: asyncio.Semaphore): + async def evaluate_item(self, models, item: EvaluationItem, semaphore: asyncio.Semaphore): async with semaphore: max_try = 3 while max_try > 0: prompt_text = self.get_eval_prompt(item) resp_text = await asyncio.to_thread( - call_openai_style_model, model_config.base_url, model_config.api_key, model_config.model_name, + call_openai_style_model, models.base_url, models.api_key, models.model_name, prompt_text, ) resp_text = extract_json_substring(resp_text) diff --git a/runtime/datamate-python/app/module/generation/service/generation_service.py b/runtime/datamate-python/app/module/generation/service/generation_service.py index 2a085fd37..f5d0ad985 100644 --- a/runtime/datamate-python/app/module/generation/service/generation_service.py +++ b/runtime/datamate-python/app/module/generation/service/generation_service.py @@ -24,7 +24,8 @@ from app.module.shared.common.document_loaders import load_documents from app.module.shared.common.text_split import DocumentSplitter from app.module.shared.util.model_chat import extract_json_substring -from app.module.system.service.common_service import chat, get_model_by_id, get_chat_client +from app.module.shared.llm import LLMFactory +from app.module.system.service.common_service import get_model_by_id def _filter_docs(split_docs, chunk_size): @@ -171,8 +172,12 @@ async def _process_single_file( # 为本文件构建模型 client question_model = await get_model_by_id(self.db, question_cfg.model_id) answer_model = await get_model_by_id(self.db, answer_cfg.model_id) - question_chat = get_chat_client(question_model) - answer_chat = get_chat_client(answer_model) + question_chat = LLMFactory.create_chat( + question_model.model_name, question_model.base_url, question_model.api_key + ) + answer_chat = LLMFactory.create_chat( + answer_model.model_name, 
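
The evaluation runner above bounds concurrency with `asyncio.Semaphore(10)` and pushes each blocking model call off the event loop. The pattern in isolation, as a sketch (`call_model` and `items` are placeholders):

```python
# Bounded-concurrency sketch: at most 10 items in flight, blocking calls
# moved to worker threads, and one failure does not cancel the batch.
import asyncio


async def evaluate_all(items, call_model):
    semaphore = asyncio.Semaphore(10)

    async def evaluate_one(item):
        async with semaphore:
            # call_model is synchronous; run it off the event loop
            return await asyncio.to_thread(call_model, item)

    return await asyncio.gather(
        *(evaluate_one(item) for item in items),
        return_exceptions=True,  # collect per-item exceptions instead of raising
    )
```
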
answer_model.base_url, answer_model.api_key + ) # 分批次从 DB 读取并处理 chunk batch_size = 100 @@ -356,7 +361,7 @@ async def _generate_questions_for_one_chunk( loop = asyncio.get_running_loop() raw_answer = await loop.run_in_executor( None, - chat, + LLMFactory.invoke_sync, question_chat, prompt, ) @@ -400,7 +405,7 @@ async def process_single_question(question: str): loop = asyncio.get_running_loop() answer = await loop.run_in_executor( None, - chat, + LLMFactory.invoke_sync, answer_chat, prompt_local, ) diff --git a/runtime/datamate-python/app/module/rag/interface/rag_interface.py b/runtime/datamate-python/app/module/rag/interface/rag_interface.py index d7bd907b5..b66af6463 100644 --- a/runtime/datamate-python/app/module/rag/interface/rag_interface.py +++ b/runtime/datamate-python/app/module/rag/interface/rag_interface.py @@ -9,12 +9,12 @@ router = APIRouter(prefix="/rag", tags=["rag"]) @router.post("/process/{knowledge_base_id}") -async def process_knowledge_base(knowledge_base_id: str, db: AsyncSession = Depends(get_db)): +async def process_knowledge_base(knowledge_base_id: str, rag_service: RAGService = Depends()): """ Process all unprocessed files in a knowledge base. """ try: - await RAGService(db).init_graph_rag(knowledge_base_id) + await rag_service.init_graph_rag(knowledge_base_id) return StandardResponse( code=200, message="Processing started for knowledge base.", @@ -24,12 +24,11 @@ async def process_knowledge_base(knowledge_base_id: str, db: AsyncSession = Depe raise HTTPException(status_code=500, detail=str(e)) @router.post("/query") -async def query_knowledge_graph(payload: QueryRequest, db: AsyncSession = Depends(get_db)): +async def query_knowledge_graph(payload: QueryRequest, rag_service: RAGService = Depends()): """ Query the knowledge graph with the given query text and knowledge base ID. 
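
The generation service now builds clients via `LLMFactory.create_chat` and funnels synchronous calls through `LLMFactory.invoke_sync` inside `run_in_executor`. A hedged sketch of that combination (model name and endpoint are placeholders):

```python
# Sketch combining create_chat with invoke_sync via the default executor,
# mirroring the generation service above; model settings are placeholders.
import asyncio

from app.module.shared.llm import LLMFactory


async def generate_once(prompt: str) -> str:
    chat = LLMFactory.create_chat("qwen2", "http://localhost:11434/v1")
    loop = asyncio.get_running_loop()
    # invoke_sync blocks on the HTTP call, so hand it to a worker thread
    return await loop.run_in_executor(None, LLMFactory.invoke_sync, chat, prompt)
```
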
""" try: - rag_service = RAGService(db) result = await rag_service.query_rag(payload.query, payload.knowledge_base_id) return StandardResponse(code=200, message="success", data=result) except HTTPException: diff --git a/runtime/datamate-python/app/module/rag/service/rag_service.py b/runtime/datamate-python/app/module/rag/service/rag_service.py index 67cdfc15d..1af9e49ba 100644 --- a/runtime/datamate-python/app/module/rag/service/rag_service.py +++ b/runtime/datamate-python/app/module/rag/service/rag_service.py @@ -2,14 +2,13 @@ import asyncio from typing import Optional, Sequence -from fastapi import BackgroundTasks, Depends +from fastapi import Depends from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.core.logging import get_logger from app.db.models.dataset_management import DatasetFiles from app.db.models.knowledge_gen import RagFile, RagKnowledgeBase -from app.db.models.model_config import ModelConfig from app.db.session import get_db, AsyncSessionLocal from app.module.shared.common.document_loaders import load_documents from .graph_rag import ( @@ -18,7 +17,8 @@ build_llm_model_func, initialize_rag, ) -from ...system.service.common_service import get_embedding_dimension, get_openai_client +from app.module.shared.llm import LLMFactory +from ...system.service.common_service import get_model_by_id logger = get_logger(__name__) @@ -27,10 +27,10 @@ class RAGService: def __init__( self, db: AsyncSession = Depends(get_db), - background_tasks: BackgroundTasks | None = None, + ): self.db = db - self.background_tasks = background_tasks + self.background_tasks = None self.rag = None async def get_unprocessed_files(self, knowledge_base_id: str) -> Sequence[RagFile]: @@ -44,8 +44,8 @@ async def get_unprocessed_files(self, knowledge_base_id: str) -> Sequence[RagFil async def init_graph_rag(self, knowledge_base_id: str): kb = await self._get_knowledge_base(knowledge_base_id) - embedding_model = await self._get_model_config(kb.embedding_model) - chat_model = await self._get_model_config(kb.chat_model) + embedding_model = await self._get_models(kb.embedding_model) + chat_model = await self._get_models(kb.chat_model) llm_callable = await build_llm_model_func( chat_model.model_name, chat_model.base_url, chat_model.api_key @@ -54,7 +54,9 @@ async def init_graph_rag(self, knowledge_base_id: str): embedding_model.model_name, embedding_model.base_url, embedding_model.api_key, - embedding_dim=get_embedding_dimension(get_openai_client(embedding_model)), + embedding_dim=LLMFactory.get_embedding_dimension( + embedding_model.model_name, embedding_model.base_url, embedding_model.api_key + ), ) kb_working_dir = os.path.join(DEFAULT_WORKING_DIR, kb.name) @@ -124,14 +126,13 @@ async def _get_knowledge_base(self, knowledge_base_id: str): raise ValueError(f"Knowledge base with ID {knowledge_base_id} not found.") return knowledge_base - async def _get_model_config(self, model_id: Optional[str]): + async def _get_models(self, model_id: Optional[str]): if not model_id: raise ValueError("Model ID is required for initializing RAG.") - result = await self.db.execute(select(ModelConfig).where(ModelConfig.id == model_id)) - model = result.scalars().first() - if not model: - raise ValueError(f"Model config with ID {model_id} not found.") - return model + models = await get_model_by_id(self.db, model_id) + if not models: + raise ValueError(f"Models with ID {model_id} not found.") + return models async def query_rag(self, query: str, knowledge_base_id: str) -> str: if not self.rag: diff --git 
a/runtime/datamate-python/app/module/shared/llm/__init__.py b/runtime/datamate-python/app/module/shared/llm/__init__.py
new file mode 100644
index 000000000..8222271e5
--- /dev/null
+++ b/runtime/datamate-python/app/module/shared/llm/__init__.py
@@ -0,0 +1,7 @@
+# app/module/shared/llm/__init__.py
+"""
+LangChain 模型工厂:统一创建 Chat、Embedding 及健康检查,便于各模块复用。
+"""
+from .factory import LLMFactory
+
+__all__ = ["LLMFactory"]
diff --git a/runtime/datamate-python/app/module/shared/llm/factory.py b/runtime/datamate-python/app/module/shared/llm/factory.py
new file mode 100644
index 000000000..4713600df
--- /dev/null
+++ b/runtime/datamate-python/app/module/shared/llm/factory.py
@@ -0,0 +1,71 @@
+# app/module/shared/llm/factory.py
+"""
+LangChain 模型工厂:基于 OpenAI 兼容接口封装 Chat / Embedding 的创建、健康检查与同步调用。
+便于模型配置、RAG、生成、评估等模块统一使用,避免分散的 get_chat_client / get_openai_client。
+"""
+from typing import Literal
+
+from langchain_core.language_models import BaseChatModel
+from langchain_core.embeddings import Embeddings
+from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+from pydantic import SecretStr
+
+
+class LLMFactory:
+    """基于 LangChain 的 Chat / Embedding 工厂,面向 OpenAI 兼容 API。"""
+
+    @staticmethod
+    def create_chat(
+        model_name: str,
+        base_url: str,
+        api_key: str | None = None,
+    ) -> BaseChatModel:
+        """创建对话模型,兼容 OpenAI 及任意 base_url 的 OpenAI 兼容服务。"""
+        return ChatOpenAI(
+            model=model_name,
+            base_url=base_url or None,
+            api_key=SecretStr(api_key or ""),
+        )
+
+    @staticmethod
+    def create_embedding(
+        model_name: str,
+        base_url: str,
+        api_key: str | None = None,
+    ) -> Embeddings:
+        """创建嵌入模型,兼容 OpenAI 及任意 base_url 的 OpenAI 兼容服务。"""
+        return OpenAIEmbeddings(
+            model=model_name,
+            base_url=base_url or None,
+            api_key=SecretStr(api_key or ""),
+        )
+
+    @staticmethod
+    def check_health(
+        model_name: str,
+        base_url: str,
+        api_key: str | None,
+        model_type: Literal["CHAT", "EMBEDDING"] | str,
+    ) -> None:
+        """对配置做一次最小化调用进行健康检查,失败则抛出。"""
+        if model_type == "CHAT":
+            model = LLMFactory.create_chat(model_name, base_url, api_key)
+            model.invoke("hello")
+        else:
+            model = LLMFactory.create_embedding(model_name, base_url, api_key)
+            model.embed_query("text")
+
+    @staticmethod
+    def get_embedding_dimension(
+        model_name: str,
+        base_url: str,
+        api_key: str | None = None,
+    ) -> int:
+        """创建 Embedding 模型并返回向量维度。"""
+        emb = LLMFactory.create_embedding(model_name, base_url, api_key)
+        return len(emb.embed_query("text"))
+
+    @staticmethod
+    def invoke_sync(chat_model: BaseChatModel, prompt: str) -> str:
+        """同步调用对话模型并返回 content,供 run_in_executor 等场景使用。"""
+        return chat_model.invoke(prompt).content
diff --git a/runtime/datamate-python/app/module/system/interface/__init__.py b/runtime/datamate-python/app/module/system/interface/__init__.py
index acd102360..15438e988 100644
--- a/runtime/datamate-python/app/module/system/interface/__init__.py
+++ b/runtime/datamate-python/app/module/system/interface/__init__.py
@@ -1,7 +1,9 @@
 from fastapi import APIRouter
 
 from .about import router as about_router
+from app.module.system.interface.models import router as models_router
 
 router = APIRouter()
 
-router.include_router(about_router)
\ No newline at end of file
+router.include_router(about_router)
+router.include_router(models_router)
diff --git a/runtime/datamate-python/app/module/system/interface/models.py b/runtime/datamate-python/app/module/system/interface/models.py
new file mode 100644
index 000000000..f268a9516
--- /dev/null
+++ b/runtime/datamate-python/app/module/system/interface/models.py
@@ -0,0 +1,75 @@
+from fastapi import 
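
For orientation, a hedged usage sketch of the factory defined above; the endpoint and model names are illustrative only:

```python
# Illustrative usage of LLMFactory; endpoint and model names are placeholders.
from app.module.shared.llm import LLMFactory

chat = LLMFactory.create_chat("qwen2", "http://localhost:11434/v1")
print(LLMFactory.invoke_sync(chat, "Say hello in one word."))

# The RAG module uses the same probe to size its vector store:
dim = LLMFactory.get_embedding_dimension("bge-m3", "http://localhost:11434/v1")
print(dim)
```
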
APIRouter, Depends, Query + +from app.module.shared.schema import StandardResponse, PaginatedData +from app.module.system.schema.models import ( + CreateModelRequest, + QueryModelRequest, + ModelsResponse, + ProviderItem, + ModelType, +) +from app.module.system.service.models_service import ModelsService + +router = APIRouter(prefix="/models", tags=["models"]) + + +@router.get("/providers", response_model=StandardResponse[list[ProviderItem]]) +async def get_providers(svc: ModelsService = Depends()): + """获取厂商列表,与 Java GET /models/providers 一致。""" + data = await svc.get_providers() + return StandardResponse(code=200, message="success", data=data) + + +@router.get("/list", response_model=StandardResponse[PaginatedData[ModelsResponse]]) +async def get_models( + page: int = Query(0, ge=0, description="页码,从 0 开始"), + size: int = Query(20, gt=0, le=500, description="每页大小"), + provider: str | None = Query(None, description="模型提供商"), + type: ModelType | None = Query(None, description="模型类型"), + isEnabled: bool | None = Query(None, description="是否启用"), + isDefault: bool | None = Query(None, description="是否默认"), + svc: ModelsService = Depends(), +): + """分页查询模型列表,与 Java GET /models/list 一致。""" + q = QueryModelRequest( + page=page, + size=size, + provider=provider, + type=type, + isEnabled=isEnabled, + isDefault=isDefault, + ) + data = await svc.get_models(q) + return StandardResponse(code=200, message="success", data=data) + + +@router.post("/create", response_model=StandardResponse[ModelsResponse]) +async def create_model(req: CreateModelRequest, svc: ModelsService = Depends()): + """创建模型配置,与 Java POST /models/create 一致。""" + data = await svc.create_model(req) + return StandardResponse(code=200, message="success", data=data) + + +@router.get("/{model_id}", response_model=StandardResponse[ModelsResponse]) +async def get_model_detail(model_id: str, svc: ModelsService = Depends()): + """获取模型详情,与 Java GET /models/{modelId} 一致。""" + data = await svc.get_model_detail(model_id) + return StandardResponse(code=200, message="success", data=data) + + +@router.put("/{model_id}", response_model=StandardResponse[ModelsResponse]) +async def update_model( + model_id: str, + req: CreateModelRequest, + svc: ModelsService = Depends(), +): + """更新模型配置,与 Java PUT /models/{modelId} 一致。""" + data = await svc.update_model(model_id, req) + return StandardResponse(code=200, message="success", data=data) + + +@router.delete("/{model_id}", response_model=StandardResponse[None]) +async def delete_model(model_id: str, svc: ModelsService = Depends()): + """删除模型配置,与 Java DELETE /models/{modelId} 一致。""" + await svc.delete_model(model_id) + return StandardResponse(code=200, message="success", data=None) diff --git a/runtime/datamate-python/app/module/system/schema/__init__.py b/runtime/datamate-python/app/module/system/schema/__init__.py index 4c5f542fa..9ae5e6dbe 100644 --- a/runtime/datamate-python/app/module/system/schema/__init__.py +++ b/runtime/datamate-python/app/module/system/schema/__init__.py @@ -1,3 +1,17 @@ from .health import HealthResponse - -__all__ = ["HealthResponse"] \ No newline at end of file +from .models import ( + ModelType, + CreateModelRequest, + QueryModelRequest, + ModelsResponse, + ProviderItem, +) + +__all__ = [ + "HealthResponse", + "ModelType", + "CreateModelRequest", + "QueryModelRequest", + "ModelsResponse", + "ProviderItem", +] diff --git a/runtime/datamate-python/app/module/system/schema/models.py b/runtime/datamate-python/app/module/system/schema/models.py new file mode 100644 index 
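
A hypothetical smoke run against the routes defined above, assuming the router is mounted at the application root; host, port, and payload values are illustrative:

```python
# Hypothetical smoke test of the /models routes with httpx; all values are
# illustrative and the mount point is assumed to be the app root.
import httpx

payload = {
    "modelName": "qwen2",
    "provider": "Ollama",
    "baseUrl": "http://localhost:11434/v1",
    "apiKey": "",
    "type": "CHAT",
}

with httpx.Client(base_url="http://localhost:8000") as client:
    created = client.post("/models/create", json=payload).json()["data"]
    listed = client.get("/models/list", params={"page": 0, "size": 20}).json()
    detail = client.get(f"/models/{created['id']}").json()["data"]
    client.delete(f"/models/{created['id']}")  # soft delete, see service below
```
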
000000000..bfa5d4ae6
--- /dev/null
+++ b/runtime/datamate-python/app/module/system/schema/models.py
@@ -0,0 +1,71 @@
+"""
+模型配置 DTO:与 Java 版 t_models 接口保持一致。
+"""
+from datetime import datetime
+from enum import Enum
+from typing import Optional
+
+from pydantic import BaseModel, ConfigDict, Field
+
+
+class ModelType(str, Enum):
+    """模型类型枚举,与 Java ModelType 一致。"""
+    CHAT = "CHAT"
+    EMBEDDING = "EMBEDDING"
+
+
+# --- 请求 DTO ---
+
+
+class CreateModelRequest(BaseModel):
+    """创建/更新模型配置请求,与 Java CreateModelRequest 一致。"""
+    model_config = ConfigDict(str_strip_whitespace=True, extra="forbid")
+
+    modelName: str = Field(..., min_length=1, description="模型名称(如 qwen2)")
+    provider: str = Field(..., min_length=1, description="模型提供商(如 Ollama、OpenAI、DeepSeek)")
+    baseUrl: str = Field(..., min_length=1, description="API 基础地址")
+    apiKey: Optional[str] = Field(None, description="API 密钥(无密钥则为空)")
+    type: ModelType = Field(..., description="模型类型(如 chat、embedding)")
+    isEnabled: Optional[bool] = Field(None, description="是否启用")
+    isDefault: Optional[bool] = Field(None, description="是否默认")
+
+
+class QueryModelRequest(BaseModel):
+    """模型查询请求,与 Java QueryModelRequest + PagingQuery 一致。page 从 0 开始,size 默认 20。"""
+    model_config = ConfigDict(str_strip_whitespace=True, extra="forbid")
+
+    page: int = Field(0, ge=0, description="页码,从 0 开始")
+    size: int = Field(20, gt=0, le=500, description="每页大小")
+    provider: Optional[str] = Field(None, description="模型提供商")
+    type: Optional[ModelType] = Field(None, description="模型类型")
+    isEnabled: Optional[bool] = Field(None, description="是否启用")
+    isDefault: Optional[bool] = Field(None, description="是否默认")
+
+
+# --- 响应 DTO ---
+
+
+class ModelsResponse(BaseModel):
+    """模型配置响应,与 Java Models 实体字段一致。"""
+    model_config = ConfigDict(str_strip_whitespace=True, extra="forbid")
+
+    id: str
+    modelName: str
+    provider: str
+    baseUrl: str
+    apiKey: str = ""
+    type: str  # "CHAT" | "EMBEDDING"
+    isEnabled: bool = True
+    isDefault: bool = False
+    createdAt: Optional[datetime] = None
+    updatedAt: Optional[datetime] = None
+    createdBy: Optional[str] = None
+    updatedBy: Optional[str] = None
+
+
+class ProviderItem(BaseModel):
+    """厂商项,仅含 provider、baseUrl,与 Java getProviders() 返回结构一致。"""
+    model_config = ConfigDict(str_strip_whitespace=True, extra="forbid")
+
+    provider: str
+    baseUrl: str
diff --git a/runtime/datamate-python/app/module/system/service/common_service.py b/runtime/datamate-python/app/module/system/service/common_service.py
index 2cf5578bb..a74be5f42 100644
--- a/runtime/datamate-python/app/module/system/service/common_service.py
+++ b/runtime/datamate-python/app/module/system/service/common_service.py
@@ -1,43 +1,16 @@
+"""通用系统服务:仅保留与 DB 直接相关的查询。LLM 创建与调用统一使用 app.module.shared.llm.LLMFactory。"""
 from typing import Optional
 
-from langchain_core.language_models import BaseChatModel
-from langchain_openai import ChatOpenAI, OpenAIEmbeddings
-from pydantic import SecretStr
 from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
 
-from app.db.models.model_config import ModelConfig
+from app.db.models.models import Models
 
 
-async def get_model_by_id(db: AsyncSession, model_id: str) -> Optional[ModelConfig]:
-    """根据模型ID获取 ModelConfig 记录。"""
-    result = await db.execute(select(ModelConfig).where(ModelConfig.id == model_id))
-    return result.scalar_one_or_none()
-
-
-def get_chat_client(model: ModelConfig) -> BaseChatModel:
-    return ChatOpenAI(
-        model=model.model_name,
-        base_url=model.base_url,
-        api_key=SecretStr(model.api_key),
-    )
-
-
-def chat(model: 
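
Quick check of the request DTO defined above: `extra="forbid"` makes unexpected keys fail fast and `str_strip_whitespace` trims inputs. Values are illustrative:

```python
# Illustrative round-trip of CreateModelRequest; values are placeholders.
from app.module.system.schema.models import CreateModelRequest, ModelType

req = CreateModelRequest(
    modelName="  qwen2  ",  # str_strip_whitespace trims this
    provider="Ollama",
    baseUrl="http://localhost:11434/v1",
    type=ModelType.CHAT,
)
assert req.modelName == "qwen2"
assert req.type is ModelType.CHAT
# extra="forbid": passing an unknown key such as foo=1 raises ValidationError
```
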
BaseChatModel, prompt: str) -> str: - """使用指定模型进行聊天""" - response = model.invoke(prompt) - return response.content - - -# 实例化对象 -def get_openai_client(model: ModelConfig) -> OpenAIEmbeddings: - return OpenAIEmbeddings( - model=model.model_name, - base_url=model.base_url, - api_key=SecretStr(model.api_key), +async def get_model_by_id(db: AsyncSession, model_id: str) -> Optional[Models]: + """根据模型ID获取未删除的 Models 记录。""" + q = select(Models).where(Models.id == model_id).where( + (Models.is_deleted == False) | (Models.is_deleted.is_(None)) ) - -# 获取嵌入向量维度 -def get_embedding_dimension(model: OpenAIEmbeddings) -> int: - """获取 OpenAI 模型的嵌入向量维度""" - return len(model.embed_query(model.model)) + result = await db.execute(q) + return result.scalar_one_or_none() diff --git a/runtime/datamate-python/app/module/system/service/models_service.py b/runtime/datamate-python/app/module/system/service/models_service.py new file mode 100644 index 000000000..874416b72 --- /dev/null +++ b/runtime/datamate-python/app/module/system/service/models_service.py @@ -0,0 +1,206 @@ +import math +import uuid +from datetime import datetime, timezone + +from fastapi import Depends, HTTPException +from sqlalchemy import select, func, update +from sqlalchemy.ext.asyncio import AsyncSession + +from app.module.shared.llm import LLMFactory +from app.core.logging import get_logger +from app.db.models.models import Models +from app.db.session import get_db +from app.module.shared.schema import PaginatedData +from app.module.system.schema.models import ( + CreateModelRequest, + QueryModelRequest, + ModelsResponse, + ProviderItem, +) + +logger = get_logger(__name__) + +# 固定厂商列表 +PROVIDERS = [ + ProviderItem(provider="ModelEngine", baseUrl="http://localhost:9981"), + ProviderItem(provider="Ollama", baseUrl="http://localhost:11434"), + ProviderItem(provider="OpenAI", baseUrl="https://api.openai.com/v1"), + ProviderItem(provider="DeepSeek", baseUrl="https://api.deepseek.com/v1"), + ProviderItem(provider="火山方舟", baseUrl="https://ark.cn-beijing.volces.com/api/v3"), + ProviderItem(provider="阿里云百炼", baseUrl="https://dashscope.aliyuncs.com/compatible-mode/v1"), + ProviderItem(provider="硅基流动", baseUrl="https://api.siliconflow.cn/v1"), + ProviderItem(provider="智谱AI", baseUrl="https://open.bigmodel.cn/api/paas/v4"), + ProviderItem(provider="自定义模型", baseUrl=""), +] + + +def _orm_to_response(row: Models) -> ModelsResponse: + return ModelsResponse( + id=row.id, + modelName=row.model_name, + provider=row.provider, + baseUrl=row.base_url, + apiKey=row.api_key or "", + type=row.type or "CHAT", + isEnabled=bool(row.is_enabled) if row.is_enabled is not None else True, + isDefault=bool(row.is_default) if row.is_default is not None else False, + createdAt=row.created_at, + updatedAt=row.updated_at, + createdBy=row.created_by, + updatedBy=row.updated_by, + ) + + +class ModelsService: + """模型配置服务:db 通过 FastAPI Depends(get_db) 注入,不在路由中传递。""" + + def __init__(self, db: AsyncSession = Depends(get_db)): + self.db = db + + async def get_providers(self) -> list[ProviderItem]: + """返回固定厂商列表,与 Java getProviders() 一致。""" + return list(PROVIDERS) + + async def get_models(self, q: QueryModelRequest) -> PaginatedData[ModelsResponse]: + """分页查询,支持 provider/type/isEnabled/isDefault;排除已删除;page 从 0 开始。""" + query = select(Models).where( + (Models.is_deleted == False) | (Models.is_deleted.is_(None)) + ) + if q.provider is not None and q.provider != "": + query = query.where(Models.provider == q.provider) + if q.type is not None: + query = query.where(Models.type == 
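
The soft-delete predicate `(is_deleted == False) | is_(None)` now repeats across the queries here and below; a sketch of factoring it out under the same semantics (`not_deleted` is a hypothetical helper, not part of this patch):

```python
# Hypothetical helper factoring out the repeated soft-delete predicate.
from sqlalchemy import or_, select

from app.db.models.models import Models


def not_deleted():
    """Live rows: is_deleted is FALSE, or NULL for pre-migration rows."""
    return or_(Models.is_deleted.is_(False), Models.is_deleted.is_(None))


def live_models_for_provider(provider: str):
    return select(Models).where(not_deleted(), Models.provider == provider)
```
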
q.type.value) + if q.isEnabled is not None: + query = query.where(Models.is_enabled == q.isEnabled) + if q.isDefault is not None: + query = query.where(Models.is_default == q.isDefault) + + total = (await self.db.execute(select(func.count()).select_from(query.subquery()))).scalar_one() + size = max(1, min(500, q.size)) + offset = max(0, q.page) * size + rows = ( + await self.db.execute( + query.order_by(Models.created_at.desc()).offset(offset).limit(size) + ) + ).scalars().all() + total_pages = math.ceil(total / size) if total else 0 + current = max(0, q.page) + 1 + return PaginatedData( + page=current, + size=size, + total_elements=total, + total_pages=total_pages, + content=[_orm_to_response(r) for r in rows], + ) + + async def get_model_detail(self, model_id: str) -> ModelsResponse: + """获取模型详情,已删除或不存在则 404。""" + query = select(Models).where(Models.id == model_id).where( + (Models.is_deleted == False) | (Models.is_deleted.is_(None)) + ) + r = (await self.db.execute(query)).scalar_one_or_none() + if not r: + raise HTTPException(status_code=404, detail="模型配置不存在") + return _orm_to_response(r) + + async def create_model(self, req: CreateModelRequest) -> ModelsResponse: + """创建模型:健康检查后 saveAndSetDefault;isEnabled 恒为 True。""" + try: + LLMFactory.check_health(req.modelName, req.baseUrl, req.apiKey, req.type.value) + except Exception as e: + logger.error("Model health check failed: model=%s type=%s err=%s", req.modelName, req.type, e) + raise HTTPException(status_code=400, detail="模型健康检查失败") from e + + existing = ( + await self.db.execute( + select(Models).where( + (Models.is_deleted == False) | (Models.is_deleted.is_(None)), + Models.type == req.type.value, + Models.is_default == True, + ) + ) + ).scalar_one_or_none() + + is_default: bool + if existing is None: + is_default = True + else: + await self.db.execute( + update(Models) + .where(Models.type == req.type.value, Models.is_default == True) + .values(is_default=False) + ) + is_default = req.isDefault if req.isDefault is not None else False + + now = datetime.now(timezone.utc).replace(tzinfo=None) + entity = Models( + id=str(uuid.uuid4()), + model_name=req.modelName, + provider=req.provider, + base_url=req.baseUrl, + api_key=req.apiKey or "", + type=req.type.value, + is_enabled=True, + is_default=is_default, + is_deleted=False, + created_at=now, + updated_at=now, + ) + self.db.add(entity) + await self.db.commit() + await self.db.refresh(entity) + return _orm_to_response(entity) + + async def update_model(self, model_id: str, req: CreateModelRequest) -> ModelsResponse: + """更新模型:存在性校验、健康检查后 updateAndSetDefault;isEnabled 恒为 True。""" + res = await self.db.execute( + select(Models).where(Models.id == model_id).where( + (Models.is_deleted == False) | (Models.is_deleted.is_(None)) + ) + ) + entity = res.scalar_one_or_none() + if not entity: + raise HTTPException(status_code=404, detail="模型配置不存在") + + try: + LLMFactory.check_health(req.modelName, req.baseUrl, req.apiKey, req.type.value) + except Exception as e: + logger.error("Model health check failed: model=%s type=%s err=%s", req.modelName, req.type, e) + raise HTTPException(status_code=400, detail="模型健康检查失败") from e + + entity.model_name = req.modelName + entity.provider = req.provider + entity.base_url = req.baseUrl + entity.api_key = req.apiKey or "" + entity.type = req.type.value + entity.is_enabled = True + + want_default = req.isDefault if req.isDefault is not None else False + if (entity.is_default is not True) and want_default: + await self.db.execute( + update(Models) + 
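
Both create_model and update_model preserve an invariant of at most one default model per type by demoting existing defaults before setting a new one. The step in isolation, sketched (the helper name is illustrative):

```python
# Sketch of the single-default-per-type step; set_default is illustrative.
from sqlalchemy import update

from app.db.models.models import Models


async def set_default(db, model_id: str, model_type: str) -> None:
    # Demote any current default of this type...
    await db.execute(
        update(Models)
        .where(Models.type == model_type, Models.is_default.is_(True))
        .values(is_default=False)
    )
    # ...then promote the chosen row.
    await db.execute(
        update(Models).where(Models.id == model_id).values(is_default=True)
    )
    await db.commit()
```
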
.where(Models.type == req.type.value, Models.is_default == True) + .values(is_default=False) + ) + entity.is_default = want_default + entity.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + + await self.db.commit() + await self.db.refresh(entity) + return _orm_to_response(entity) + + async def delete_model(self, model_id: str) -> None: + """软删除模型配置。""" + entity = ( + await self.db.execute( + select(Models).where(Models.id == model_id).where( + (Models.is_deleted == False) | (Models.is_deleted.is_(None)) + ) + ) + ).scalar_one_or_none() + if not entity: + raise HTTPException(status_code=404, detail="模型配置不存在") + entity.is_deleted = True + entity.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + await self.db.commit() + await self.db.refresh(entity) diff --git a/runtime/datamate-python/poetry.lock b/runtime/datamate-python/poetry.lock index e3e20bad4..1556f130c 100644 --- a/runtime/datamate-python/poetry.lock +++ b/runtime/datamate-python/poetry.lock @@ -1301,6 +1301,18 @@ files = [ [package.extras] all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"] +[[package]] +name = "iniconfig" +version = "2.3.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12"}, + {file = "iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730"}, +] + [[package]] name = "jiter" version = "0.12.0" @@ -2822,6 +2834,22 @@ all = ["pipmaster[audit]", "pipmaster[dev]"] audit = ["pip-audit (>=2.5.0)"] dev = ["black", "cowsay", "mypy (>=1.0)", "myst-parser (>=0.17)", "pip-audit (>=2.5.0)", "pytest (>=7.0)", "pytest-asyncio (>=0.20)", "ruff (>=0.1.0)", "sphinx (>=5.0)", "sphinx-rtd-theme (>=1.0)"] +[[package]] +name = "pluggy" +version = "1.6.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746"}, + {file = "pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["coverage", "pytest", "pytest-benchmark"] + [[package]] name = "propcache" version = "0.4.1" @@ -3293,6 +3321,28 @@ files = [ {file = "pypinyin-0.55.0.tar.gz", hash = "sha256:b5711b3a0c6f76e67408ec6b2e3c4987a3a806b7c528076e7c7b86fcf0eaa66b"}, ] +[[package]] +name = "pytest" +version = "9.0.2" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b"}, + {file = "pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11"}, +] + +[package.dependencies] +colorama = {version = ">=0.4", markers = "sys_platform == \"win32\""} +iniconfig = ">=1.0.1" +packaging = ">=22" +pluggy = ">=1.5,<2" +pygments = ">=2.7.2" + +[package.extras] +dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests", "setuptools", "xmlschema"] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -5379,4 +5429,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt 
[metadata] lock-version = "2.1" python-versions = ">=3.12,<4.0.0" -content-hash = "e48bb875f1482eaca13359f7673b202ae84f5e8b95273ca28372f5d427bc5d7e" +content-hash = "9ff6942c5f12288e198a6a9aff9e0fd6550403b9fe36b694192cf771334a8f70" diff --git a/runtime/datamate-python/pyproject.toml b/runtime/datamate-python/pyproject.toml index ab43e50a8..14407902a 100644 --- a/runtime/datamate-python/pyproject.toml +++ b/runtime/datamate-python/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "fastapi-mcp (>=0.4.0,<0.5.0)", "asyncpg (>=0.31.0,<0.32.0)", "lightrag-hku (==1.4.9.8)", + "pytest (>=9.0.2,<10.0.0)", ] diff --git a/scripts/db/setting-management-init.sql b/scripts/db/setting-management-init.sql index 138cd870b..cc0ddbd8d 100644 --- a/scripts/db/setting-management-init.sql +++ b/scripts/db/setting-management-init.sql @@ -2,7 +2,7 @@ \c datamate; -- 模型配置表 -CREATE TABLE IF NOT EXISTS t_model_config +CREATE TABLE IF NOT EXISTS t_models ( id VARCHAR(36) PRIMARY KEY, model_name VARCHAR(100) NOT NULL, @@ -12,6 +12,7 @@ CREATE TABLE IF NOT EXISTS t_model_config type VARCHAR(50) NOT NULL, is_enabled BOOLEAN DEFAULT TRUE, is_default BOOLEAN DEFAULT FALSE, + is_deleted BOOLEAN DEFAULT FALSE, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, created_by VARCHAR(255), @@ -19,26 +20,27 @@ CREATE TABLE IF NOT EXISTS t_model_config ); -- 添加注释 -COMMENT ON TABLE t_model_config IS '模型配置表'; -COMMENT ON COLUMN t_model_config.id IS '主键ID'; -COMMENT ON COLUMN t_model_config.model_name IS '模型名称(如 qwen2)'; -COMMENT ON COLUMN t_model_config.provider IS '模型提供商(如 Ollama、OpenAI、DeepSeek)'; -COMMENT ON COLUMN t_model_config.base_url IS 'API 基础地址'; -COMMENT ON COLUMN t_model_config.api_key IS 'API 密钥(无密钥则为空)'; -COMMENT ON COLUMN t_model_config.type IS '模型类型(如 chat、embedding)'; -COMMENT ON COLUMN t_model_config.is_enabled IS '是否启用:1-启用,0-禁用'; -COMMENT ON COLUMN t_model_config.is_default IS '是否默认:1-默认,0-非默认'; -COMMENT ON COLUMN t_model_config.created_at IS '创建时间'; -COMMENT ON COLUMN t_model_config.updated_at IS '更新时间'; -COMMENT ON COLUMN t_model_config.created_by IS '创建者'; -COMMENT ON COLUMN t_model_config.updated_by IS '更新者'; +COMMENT ON TABLE t_models IS '模型配置表'; +COMMENT ON COLUMN t_models.id IS '主键ID'; +COMMENT ON COLUMN t_models.model_name IS '模型名称(如 qwen2)'; +COMMENT ON COLUMN t_models.provider IS '模型提供商(如 Ollama、OpenAI、DeepSeek)'; +COMMENT ON COLUMN t_models.base_url IS 'API 基础地址'; +COMMENT ON COLUMN t_models.api_key IS 'API 密钥(无密钥则为空)'; +COMMENT ON COLUMN t_models.type IS '模型类型(如 chat、embedding)'; +COMMENT ON COLUMN t_models.is_enabled IS '是否启用:1-启用,0-禁用'; +COMMENT ON COLUMN t_models.is_default IS '是否默认:1-默认,0-非默认'; +COMMENT ON COLUMN t_models.is_deleted IS '是否删除:1-已删除,0-未删除'; +COMMENT ON COLUMN t_models.created_at IS '创建时间'; +COMMENT ON COLUMN t_models.updated_at IS '更新时间'; +COMMENT ON COLUMN t_models.created_by IS '创建者'; +COMMENT ON COLUMN t_models.updated_by IS '更新者'; -- 添加唯一约束 -ALTER TABLE t_model_config +ALTER TABLE t_models ADD CONSTRAINT uk_model_provider - UNIQUE (model_name, provider); + UNIQUE (model_name, provider, created_by); -COMMENT ON CONSTRAINT uk_model_provider ON t_model_config +COMMENT ON CONSTRAINT uk_model_provider ON t_models IS '避免同一提供商下模型名称重复'; -- 创建触发器用于自动更新 updated_at @@ -50,9 +52,9 @@ BEGIN END; $$ language 'plpgsql'; -DROP TRIGGER IF EXISTS update_t_model_config_updated_at ON t_model_config; +DROP TRIGGER IF EXISTS update_t_model_config_updated_at ON t_models; CREATE TRIGGER update_t_model_config_updated_at - BEFORE UPDATE ON t_model_config + 
BEFORE UPDATE ON t_models FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();
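
With pytest now declared in pyproject.toml, a minimal endpoint test could look like the sketch below; it assumes pytest-asyncio is installed and that the FastAPI instance is exposed as app.main:app:

```python
# Minimal endpoint test sketch; assumes pytest-asyncio and app.main:app.
import pytest
from httpx import ASGITransport, AsyncClient

from app.main import app  # assumed application entry point


@pytest.mark.asyncio
async def test_get_providers():
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        resp = await client.get("/models/providers")
    assert resp.status_code == 200
    body = resp.json()
    assert body["code"] == 200
    assert any(p["provider"] == "OpenAI" for p in body["data"])
```
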