From e015513f1fd7516446e8ea5c79ef06dde2216650 Mon Sep 17 00:00:00 2001 From: Dallas98 <990259227@qq.com> Date: Wed, 21 Jan 2026 12:05:44 +0800 Subject: [PATCH 1/3] feat: add model configuration routes and update schema exports --- .../app/module/system/interface/__init__.py | 4 +- .../module/system/interface/model_config.py | 81 +++++++ .../app/module/system/schema/__init__.py | 18 +- .../app/module/system/schema/model_config.py | 63 ++++++ .../system/service/model_config_service.py | 209 ++++++++++++++++++ 5 files changed, 372 insertions(+), 3 deletions(-) create mode 100644 runtime/datamate-python/app/module/system/interface/model_config.py create mode 100644 runtime/datamate-python/app/module/system/schema/model_config.py create mode 100644 runtime/datamate-python/app/module/system/service/model_config_service.py diff --git a/runtime/datamate-python/app/module/system/interface/__init__.py b/runtime/datamate-python/app/module/system/interface/__init__.py index acd102360..3fdc4091e 100644 --- a/runtime/datamate-python/app/module/system/interface/__init__.py +++ b/runtime/datamate-python/app/module/system/interface/__init__.py @@ -1,7 +1,9 @@ from fastapi import APIRouter from .about import router as about_router +from .model_config import router as model_config_router router = APIRouter() -router.include_router(about_router) \ No newline at end of file +router.include_router(about_router) +router.include_router(model_config_router) \ No newline at end of file diff --git a/runtime/datamate-python/app/module/system/interface/model_config.py b/runtime/datamate-python/app/module/system/interface/model_config.py new file mode 100644 index 000000000..614f704ab --- /dev/null +++ b/runtime/datamate-python/app/module/system/interface/model_config.py @@ -0,0 +1,81 @@ +""" +模型配置 REST 接口:与 Java ModelConfigController 路径、语义一致,响应使用 StandardResponse。 +""" +from fastapi import APIRouter, Depends, Query + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db.session import get_db +from app.module.shared.schema import StandardResponse, PaginatedData +from app.module.system.schema import ( + CreateModelRequest, + QueryModelRequest, + ModelConfigResponse, + ProviderItem, + ModelType, +) +from app.module.system.service import model_config_service + +router = APIRouter(prefix="/models", tags=["models"]) + + +@router.get("/providers", response_model=StandardResponse[list[ProviderItem]]) +async def get_providers(): + """获取厂商列表,与 Java GET /models/providers 一致。""" + data = await model_config_service.get_providers() + return StandardResponse(code=200, message="success", data=data) + + +@router.get("/list", response_model=StandardResponse[PaginatedData[ModelConfigResponse]]) +async def get_models( + page: int = Query(0, ge=0, description="页码,从 0 开始"), + size: int = Query(20, gt=0, le=500, description="每页大小"), + provider: str | None = Query(None, description="模型提供商"), + type: ModelType | None = Query(None, description="模型类型"), + isEnabled: bool | None = Query(None, description="是否启用"), + isDefault: bool | None = Query(None, description="是否默认"), + db: AsyncSession = Depends(get_db), +): + """分页查询模型列表,与 Java GET /models/list 一致。""" + q = QueryModelRequest( + page=page, + size=size, + provider=provider, + type=type, + isEnabled=isEnabled, + isDefault=isDefault, + ) + data = await model_config_service.get_models(db, q) + return StandardResponse(code=200, message="success", data=data) + + +@router.post("/create", response_model=StandardResponse[ModelConfigResponse]) +async def create_model(req: 
CreateModelRequest, db: AsyncSession = Depends(get_db)): + """创建模型配置,与 Java POST /models/create 一致。""" + data = await model_config_service.create_model(db, req) + return StandardResponse(code=200, message="success", data=data) + + +@router.get("/{model_id}", response_model=StandardResponse[ModelConfigResponse]) +async def get_model_detail(model_id: str, db: AsyncSession = Depends(get_db)): + """获取模型详情,与 Java GET /models/{modelId} 一致。""" + data = await model_config_service.get_model_detail(db, model_id) + return StandardResponse(code=200, message="success", data=data) + + +@router.put("/{model_id}", response_model=StandardResponse[ModelConfigResponse]) +async def update_model( + model_id: str, + req: CreateModelRequest, + db: AsyncSession = Depends(get_db), +): + """更新模型配置,与 Java PUT /models/{modelId} 一致。""" + data = await model_config_service.update_model(db, model_id, req) + return StandardResponse(code=200, message="success", data=data) + + +@router.delete("/{model_id}", response_model=StandardResponse[None]) +async def delete_model(model_id: str, db: AsyncSession = Depends(get_db)): + """删除模型配置,与 Java DELETE /models/{modelId} 一致。""" + await model_config_service.delete_model(db, model_id) + return StandardResponse(code=200, message="success", data=None) diff --git a/runtime/datamate-python/app/module/system/schema/__init__.py b/runtime/datamate-python/app/module/system/schema/__init__.py index 4c5f542fa..70372b855 100644 --- a/runtime/datamate-python/app/module/system/schema/__init__.py +++ b/runtime/datamate-python/app/module/system/schema/__init__.py @@ -1,3 +1,17 @@ from .health import HealthResponse - -__all__ = ["HealthResponse"] \ No newline at end of file +from .model_config import ( + ModelType, + CreateModelRequest, + QueryModelRequest, + ModelConfigResponse, + ProviderItem, +) + +__all__ = [ + "HealthResponse", + "ModelType", + "CreateModelRequest", + "QueryModelRequest", + "ModelConfigResponse", + "ProviderItem", +] \ No newline at end of file diff --git a/runtime/datamate-python/app/module/system/schema/model_config.py b/runtime/datamate-python/app/module/system/schema/model_config.py new file mode 100644 index 000000000..bef559cdb --- /dev/null +++ b/runtime/datamate-python/app/module/system/schema/model_config.py @@ -0,0 +1,63 @@ +""" +模型配置 DTO:与 Java 版 t_model_config 接口保持一致。 +""" +from datetime import datetime +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, Field + + +class ModelType(str, Enum): + """模型类型枚举,与 Java ModelType 一致。""" + CHAT = "CHAT" + EMBEDDING = "EMBEDDING" + + +# --- 请求 DTO --- + + +class CreateModelRequest(BaseModel): + """创建/更新模型配置请求,与 Java CreateModelRequest 一致。""" + modelName: str = Field(..., min_length=1, description="模型名称(如 qwen2)") + provider: str = Field(..., min_length=1, description="模型提供商(如 Ollama、OpenAI、DeepSeek)") + baseUrl: str = Field(..., min_length=1, description="API 基础地址") + apiKey: Optional[str] = Field(None, description="API 密钥(无密钥则为空)") + type: ModelType = Field(..., description="模型类型(如 chat、embedding)") + isEnabled: Optional[bool] = Field(None, description="是否启用:1-启用,0-禁用") + isDefault: Optional[bool] = Field(None, description="是否默认:1-默认,0-非默认") + + +class QueryModelRequest(BaseModel): + """模型查询请求,与 Java QueryModelRequest + PagingQuery 一致。page 从 0 开始,size 默认 20。""" + page: int = Field(0, ge=0, description="页码,从 0 开始") + size: int = Field(20, gt=0, le=500, description="每页大小") + provider: Optional[str] = Field(None, description="模型提供商") + type: Optional[ModelType] = Field(None, description="模型类型") 
+ isEnabled: Optional[bool] = Field(None, description="是否启用") + isDefault: Optional[bool] = Field(None, description="是否默认") + + +# --- 响应 DTO --- + + +class ModelConfigResponse(BaseModel): + """模型配置响应,与 Java ModelConfig 实体字段一致。""" + id: str + modelName: str + provider: str + baseUrl: str + apiKey: str = "" + type: str # "CHAT" | "EMBEDDING" + isEnabled: bool = True + isDefault: bool = False + createdAt: Optional[datetime] = None + updatedAt: Optional[datetime] = None + createdBy: Optional[str] = None + updatedBy: Optional[str] = None + + +class ProviderItem(BaseModel): + """厂商项,仅含 provider、baseUrl,与 Java getProviders() 返回结构一致。""" + provider: str + baseUrl: str diff --git a/runtime/datamate-python/app/module/system/service/model_config_service.py b/runtime/datamate-python/app/module/system/service/model_config_service.py new file mode 100644 index 000000000..3fd8926d4 --- /dev/null +++ b/runtime/datamate-python/app/module/system/service/model_config_service.py @@ -0,0 +1,209 @@ +""" +模型配置应用服务:与 Java ModelConfigApplicationService 行为一致。 +包含分页、详情、增改删、isDefault 互斥逻辑及健康检查。 +""" +import math +import uuid +from datetime import datetime, timezone +from typing import Optional + +from fastapi import HTTPException +from sqlalchemy import select, func, update, delete +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.logging import get_logger +from app.db.models.model_config import ModelConfig +from app.module.system.schema.model_config import ( + ModelType, + CreateModelRequest, + QueryModelRequest, + ModelConfigResponse, + ProviderItem, +) +from app.module.shared.schema import PaginatedData +from app.module.system.service.common_service import get_chat_client, get_openai_client + +logger = get_logger(__name__) + +# 固定厂商列表,与 Java getProviders() 一致 +PROVIDERS = [ + ProviderItem(provider="ModelEngine", baseUrl="http://localhost:9981"), + ProviderItem(provider="Ollama", baseUrl="http://localhost:11434"), + ProviderItem(provider="OpenAI", baseUrl="https://api.openai.com/v1"), + ProviderItem(provider="DeepSeek", baseUrl="https://api.deepseek.com/v1"), + ProviderItem(provider="火山方舟", baseUrl="https://ark.cn-beijing.volces.com/api/v3"), + ProviderItem(provider="阿里云百炼", baseUrl="https://dashscope.aliyuncs.com/compatible-mode/v1"), + ProviderItem(provider="硅基流动", baseUrl="https://api.siliconflow.cn/v1"), + ProviderItem(provider="智谱AI", baseUrl="https://open.bigmodel.cn/api/paas/v4"), +] + + +def _check_model_health(model_name: str, base_url: str, api_key: Optional[str], model_type: ModelType) -> None: + """对配置做一次最小化模型调用进行健康检查,失败则抛出 MODEL_HEALTH_CHECK_FAILED。""" + # 构造满足 common_service 接口的只读对象 + class _ModelLike: + def __init__(self, name: str, url: str, key: str, t: str): + self.model_name = name + self.base_url = url + self.api_key = key or "" + self.type = t + + m = _ModelLike(model_name, base_url, api_key or "", model_type.value) + try: + if model_type == ModelType.CHAT: + chat = get_chat_client(m) # type: ignore[arg-type] + chat.invoke("hello") + else: + emb = get_openai_client(m) # type: ignore[arg-type] + emb.embed_query("text") + except Exception as e: + logger.error("Model health check failed: model=%s type=%s err=%s", model_name, model_type, e) + raise HTTPException(status_code=400, detail="模型健康检查失败") from e + + +def _orm_to_response(row: ModelConfig) -> ModelConfigResponse: + return ModelConfigResponse( + id=row.id, + modelName=row.model_name, + provider=row.provider, + baseUrl=row.base_url, + apiKey=row.api_key or "", + type=row.type or "CHAT", + isEnabled=bool(row.is_enabled) if 
row.is_enabled is not None else True, + isDefault=bool(row.is_default) if row.is_default is not None else False, + createdAt=row.created_at, + updatedAt=row.updated_at, + createdBy=row.created_by, + updatedBy=row.updated_by, + ) + + +async def get_providers() -> list[ProviderItem]: + """返回固定厂商列表,与 Java getProviders() 一致。""" + return list(PROVIDERS) + + +async def get_models(db: AsyncSession, q: QueryModelRequest) -> PaginatedData[ModelConfigResponse]: + """分页查询,支持 provider/type/isEnabled/isDefault;page 从 0 开始,与 Java 一致。""" + query = select(ModelConfig) + if q.provider is not None and q.provider != "": + query = query.where(ModelConfig.provider == q.provider) + if q.type is not None: + query = query.where(ModelConfig.type == q.type.value) + if q.isEnabled is not None: + query = query.where(ModelConfig.is_enabled == (1 if q.isEnabled else 0)) + if q.isDefault is not None: + query = query.where(ModelConfig.is_default == (1 if q.isDefault else 0)) + + total = (await db.execute(select(func.count()).select_from(query.subquery()))).scalar_one() + size = max(1, min(500, q.size)) + offset = max(0, q.page) * size + rows = ( + await db.execute( + query.order_by(ModelConfig.created_at.desc()).offset(offset).limit(size) + ) + ).scalars().all() + total_pages = math.ceil(total / size) if total else 0 + # 响应中 page 使用 1-based,与 PaginatedData 注释一致 + current = max(0, q.page) + 1 + return PaginatedData( + page=current, + size=size, + total_elements=total, + total_pages=total_pages, + content=[_orm_to_response(r) for r in rows], + ) + + +async def get_model_detail(db: AsyncSession, model_id: str) -> ModelConfigResponse: + """获取模型详情,不存在则 404。""" + r = (await db.execute(select(ModelConfig).where(ModelConfig.id == model_id))).scalar_one_or_none() + if not r: + raise HTTPException(status_code=404, detail="模型配置不存在") + return _orm_to_response(r) + + +async def create_model(db: AsyncSession, req: CreateModelRequest) -> ModelConfigResponse: + """创建模型:健康检查后 saveAndSetDefault;isEnabled 恒为 True。""" + _check_model_health(req.modelName, req.baseUrl, req.apiKey, req.type) + + # 同类型下是否已有默认 + existing = ( + await db.execute( + select(ModelConfig).where( + ModelConfig.type == req.type.value, + ModelConfig.is_default == 1, + ) + ) + ).scalar_one_or_none() + + is_default: bool + if existing is None: + is_default = True + else: + # 清除同类型默认 + await db.execute( + update(ModelConfig) + .where(ModelConfig.type == req.type.value, ModelConfig.is_default == 1) + .values(is_default=0) + ) + is_default = req.isDefault if req.isDefault is not None else False + + now = datetime.now(timezone.utc) + entity = ModelConfig( + id=str(uuid.uuid4()), + model_name=req.modelName, + provider=req.provider, + base_url=req.baseUrl, + api_key=req.apiKey or "", + type=req.type.value, + is_enabled=1, + is_default=1 if is_default else 0, + created_at=now, + updated_at=now, + created_by=None, + updated_by=None, + ) + db.add(entity) + await db.commit() + await db.refresh(entity) + return _orm_to_response(entity) + + +async def update_model(db: AsyncSession, model_id: str, req: CreateModelRequest) -> ModelConfigResponse: + """更新模型:存在性校验、健康检查后 updateAndSetDefault;isEnabled 恒为 True。""" + entity = (await db.execute(select(ModelConfig).where(ModelConfig.id == model_id))).scalar_one_or_none() + if not entity: + raise HTTPException(status_code=404, detail="模型配置不存在") + + entity.model_name = req.modelName + entity.provider = req.provider + entity.base_url = req.baseUrl + entity.api_key = req.apiKey or "" + entity.type = req.type.value + entity.is_enabled = 1 + + 
_check_model_health(req.modelName, req.baseUrl, req.apiKey, req.type) + + want_default = req.isDefault if req.isDefault is not None else False + if (entity.is_default != 1) and want_default: + await db.execute( + update(ModelConfig) + .where(ModelConfig.type == req.type.value, ModelConfig.is_default == 1) + .values(is_default=0) + ) + entity.is_default = 1 if want_default else 0 + entity.updated_at = datetime.now(timezone.utc) + + await db.commit() + await db.refresh(entity) + return _orm_to_response(entity) + + +async def delete_model(db: AsyncSession, model_id: str) -> None: + """删除模型配置。""" + entity = (await db.execute(select(ModelConfig).where(ModelConfig.id == model_id))).scalar_one_or_none() + if not entity: + raise HTTPException(status_code=404, detail="模型配置不存在") + await db.delete(entity) + await db.commit() From e107489bc8841553ed40963b24288104a26f317d Mon Sep 17 00:00:00 2001 From: Dallas98 <990259227@qq.com> Date: Wed, 21 Jan 2026 15:00:25 +0800 Subject: [PATCH 2/3] feat: implement LLMFactory for unified model creation and health checks; add is_deleted field to model config --- .../datamate-python/app/core/llm/__init__.py | 7 + .../datamate-python/app/core/llm/factory.py | 71 ++++ .../app/db/models/model_config.py | 1 + .../generation/service/generation_service.py | 15 +- .../app/module/rag/interface/rag_interface.py | 7 +- .../app/module/rag/service/rag_service.py | 15 +- .../module/system/interface/model_config.py | 30 +- .../module/system/service/common_service.py | 39 +-- .../system/service/model_config_service.py | 302 +++++++++--------- scripts/db/setting-management-init.sql | 4 +- 10 files changed, 277 insertions(+), 214 deletions(-) create mode 100644 runtime/datamate-python/app/core/llm/__init__.py create mode 100644 runtime/datamate-python/app/core/llm/factory.py diff --git a/runtime/datamate-python/app/core/llm/__init__.py b/runtime/datamate-python/app/core/llm/__init__.py new file mode 100644 index 000000000..8222271e5 --- /dev/null +++ b/runtime/datamate-python/app/core/llm/__init__.py @@ -0,0 +1,7 @@ +# app/core/llm/__init__.py +""" +LangChain 模型工厂:统一创建 Chat、Embedding 及健康检查,便于各模块复用。 +""" +from .factory import LLMFactory + +__all__ = ["LLMFactory"] diff --git a/runtime/datamate-python/app/core/llm/factory.py b/runtime/datamate-python/app/core/llm/factory.py new file mode 100644 index 000000000..f571a40ff --- /dev/null +++ b/runtime/datamate-python/app/core/llm/factory.py @@ -0,0 +1,71 @@ +# app/core/llm/factory.py +""" +LangChain 模型工厂:基于 OpenAI 兼容接口封装 Chat / Embedding 的创建、健康检查与同步调用。 +便于模型配置、RAG、生成、评估等模块统一使用,避免分散的 get_chat_client / get_openai_client。 +""" +from typing import Literal + +from langchain_core.language_models import BaseChatModel +from langchain_core.embeddings import Embeddings +from langchain_openai import ChatOpenAI, OpenAIEmbeddings +from pydantic import SecretStr + + +class LLMFactory: + """基于 LangChain 的 Chat / Embedding 工厂,面向 OpenAI 兼容 API。""" + + @staticmethod + def create_chat( + model_name: str, + base_url: str, + api_key: str | None = None, + ) -> BaseChatModel: + """创建对话模型,兼容 OpenAI 及任意 base_url 的 OpenAI 兼容服务。""" + return ChatOpenAI( + model=model_name, + base_url=base_url or None, + api_key=SecretStr(api_key or ""), + ) + + @staticmethod + def create_embedding( + model_name: str, + base_url: str, + api_key: str | None = None, + ) -> Embeddings: + """创建嵌入模型,兼容 OpenAI 及任意 base_url 的 OpenAI 兼容服务。""" + return OpenAIEmbeddings( + model=model_name, + base_url=base_url or None, + api_key=SecretStr(api_key or ""), + ) + + @staticmethod + def 
check_health( + model_name: str, + base_url: str, + api_key: str | None, + model_type: Literal["CHAT", "EMBEDDING"], + ) -> None: + """对配置做一次最小化调用进行健康检查,失败则抛出。""" + if model_type == "CHAT": + model = LLMFactory.create_chat(model_name, base_url, api_key) + model.invoke("hello") + else: + model = LLMFactory.create_embedding(model_name, base_url, api_key) + model.embed_query("text") + + @staticmethod + def get_embedding_dimension( + model_name: str, + base_url: str, + api_key: str | None = None, + ) -> int: + """创建 Embedding 模型并返回向量维度。""" + emb = LLMFactory.create_embedding(model_name, base_url, api_key) + return len(emb.embed_query("text")) + + @staticmethod + def invoke_sync(chat_model: BaseChatModel, prompt: str) -> str: + """同步调用对话模型并返回 content,供 run_in_executor 等场景使用。""" + return chat_model.invoke(prompt).content diff --git a/runtime/datamate-python/app/db/models/model_config.py b/runtime/datamate-python/app/db/models/model_config.py index be75043f9..b0b1ee38f 100644 --- a/runtime/datamate-python/app/db/models/model_config.py +++ b/runtime/datamate-python/app/db/models/model_config.py @@ -41,6 +41,7 @@ class ModelConfig(Base): # 使用 Integer 存储 TINYINT,后续可在业务层将 0/1 转为 bool is_enabled = Column(Integer, nullable=False, default=1, comment="是否启用:1-启用,0-禁用") is_default = Column(Integer, nullable=False, default=0, comment="是否默认:1-默认,0-非默认") + is_deleted = Column(Integer, nullable=False, default=0, comment="是否删除:1-已删除,0-未删除") created_at = Column(TIMESTAMP, nullable=True, comment="创建时间") updated_at = Column(TIMESTAMP, nullable=True, comment="更新时间") diff --git a/runtime/datamate-python/app/module/generation/service/generation_service.py b/runtime/datamate-python/app/module/generation/service/generation_service.py index 2a085fd37..9bea09072 100644 --- a/runtime/datamate-python/app/module/generation/service/generation_service.py +++ b/runtime/datamate-python/app/module/generation/service/generation_service.py @@ -24,7 +24,8 @@ from app.module.shared.common.document_loaders import load_documents from app.module.shared.common.text_split import DocumentSplitter from app.module.shared.util.model_chat import extract_json_substring -from app.module.system.service.common_service import chat, get_model_by_id, get_chat_client +from app.core.llm import LLMFactory +from app.module.system.service.common_service import get_model_by_id def _filter_docs(split_docs, chunk_size): @@ -171,8 +172,12 @@ async def _process_single_file( # 为本文件构建模型 client question_model = await get_model_by_id(self.db, question_cfg.model_id) answer_model = await get_model_by_id(self.db, answer_cfg.model_id) - question_chat = get_chat_client(question_model) - answer_chat = get_chat_client(answer_model) + question_chat = LLMFactory.create_chat( + question_model.model_name, question_model.base_url, question_model.api_key + ) + answer_chat = LLMFactory.create_chat( + answer_model.model_name, answer_model.base_url, answer_model.api_key + ) # 分批次从 DB 读取并处理 chunk batch_size = 100 @@ -356,7 +361,7 @@ async def _generate_questions_for_one_chunk( loop = asyncio.get_running_loop() raw_answer = await loop.run_in_executor( None, - chat, + LLMFactory.invoke_sync, question_chat, prompt, ) @@ -400,7 +405,7 @@ async def process_single_question(question: str): loop = asyncio.get_running_loop() answer = await loop.run_in_executor( None, - chat, + LLMFactory.invoke_sync, answer_chat, prompt_local, ) diff --git a/runtime/datamate-python/app/module/rag/interface/rag_interface.py b/runtime/datamate-python/app/module/rag/interface/rag_interface.py index 
d7bd907b5..b66af6463 100644 --- a/runtime/datamate-python/app/module/rag/interface/rag_interface.py +++ b/runtime/datamate-python/app/module/rag/interface/rag_interface.py @@ -9,12 +9,12 @@ router = APIRouter(prefix="/rag", tags=["rag"]) @router.post("/process/{knowledge_base_id}") -async def process_knowledge_base(knowledge_base_id: str, db: AsyncSession = Depends(get_db)): +async def process_knowledge_base(knowledge_base_id: str, rag_service: RAGService = Depends()): """ Process all unprocessed files in a knowledge base. """ try: - await RAGService(db).init_graph_rag(knowledge_base_id) + await rag_service.init_graph_rag(knowledge_base_id) return StandardResponse( code=200, message="Processing started for knowledge base.", @@ -24,12 +24,11 @@ async def process_knowledge_base(knowledge_base_id: str, db: AsyncSession = Depe raise HTTPException(status_code=500, detail=str(e)) @router.post("/query") -async def query_knowledge_graph(payload: QueryRequest, db: AsyncSession = Depends(get_db)): +async def query_knowledge_graph(payload: QueryRequest, rag_service: RAGService = Depends()): """ Query the knowledge graph with the given query text and knowledge base ID. """ try: - rag_service = RAGService(db) result = await rag_service.query_rag(payload.query, payload.knowledge_base_id) return StandardResponse(code=200, message="success", data=result) except HTTPException: diff --git a/runtime/datamate-python/app/module/rag/service/rag_service.py b/runtime/datamate-python/app/module/rag/service/rag_service.py index 67cdfc15d..42c9f0ce8 100644 --- a/runtime/datamate-python/app/module/rag/service/rag_service.py +++ b/runtime/datamate-python/app/module/rag/service/rag_service.py @@ -9,7 +9,6 @@ from app.core.logging import get_logger from app.db.models.dataset_management import DatasetFiles from app.db.models.knowledge_gen import RagFile, RagKnowledgeBase -from app.db.models.model_config import ModelConfig from app.db.session import get_db, AsyncSessionLocal from app.module.shared.common.document_loaders import load_documents from .graph_rag import ( @@ -18,7 +17,8 @@ build_llm_model_func, initialize_rag, ) -from ...system.service.common_service import get_embedding_dimension, get_openai_client +from app.core.llm import LLMFactory +from ...system.service.common_service import get_model_by_id logger = get_logger(__name__) @@ -27,10 +27,10 @@ class RAGService: def __init__( self, db: AsyncSession = Depends(get_db), - background_tasks: BackgroundTasks | None = None, + ): self.db = db - self.background_tasks = background_tasks + self.background_tasks = None self.rag = None async def get_unprocessed_files(self, knowledge_base_id: str) -> Sequence[RagFile]: @@ -54,7 +54,9 @@ async def init_graph_rag(self, knowledge_base_id: str): embedding_model.model_name, embedding_model.base_url, embedding_model.api_key, - embedding_dim=get_embedding_dimension(get_openai_client(embedding_model)), + embedding_dim=LLMFactory.get_embedding_dimension( + embedding_model.model_name, embedding_model.base_url, embedding_model.api_key + ), ) kb_working_dir = os.path.join(DEFAULT_WORKING_DIR, kb.name) @@ -127,8 +129,7 @@ async def _get_knowledge_base(self, knowledge_base_id: str): async def _get_model_config(self, model_id: Optional[str]): if not model_id: raise ValueError("Model ID is required for initializing RAG.") - result = await self.db.execute(select(ModelConfig).where(ModelConfig.id == model_id)) - model = result.scalars().first() + model = await get_model_by_id(self.db, model_id) if not model: raise ValueError(f"Model config 
with ID {model_id} not found.") return model diff --git a/runtime/datamate-python/app/module/system/interface/model_config.py b/runtime/datamate-python/app/module/system/interface/model_config.py index 614f704ab..869c52c3d 100644 --- a/runtime/datamate-python/app/module/system/interface/model_config.py +++ b/runtime/datamate-python/app/module/system/interface/model_config.py @@ -1,11 +1,9 @@ """ 模型配置 REST 接口:与 Java ModelConfigController 路径、语义一致,响应使用 StandardResponse。 +db 通过 ModelConfigService 的 Depends(get_db) 注入,不在本层传递。 """ from fastapi import APIRouter, Depends, Query -from sqlalchemy.ext.asyncio import AsyncSession - -from app.db.session import get_db from app.module.shared.schema import StandardResponse, PaginatedData from app.module.system.schema import ( CreateModelRequest, @@ -14,15 +12,15 @@ ProviderItem, ModelType, ) -from app.module.system.service import model_config_service +from app.module.system.service.model_config_service import ModelConfigService router = APIRouter(prefix="/models", tags=["models"]) @router.get("/providers", response_model=StandardResponse[list[ProviderItem]]) -async def get_providers(): +async def get_providers(svc: ModelConfigService = Depends()): """获取厂商列表,与 Java GET /models/providers 一致。""" - data = await model_config_service.get_providers() + data = await svc.get_providers() return StandardResponse(code=200, message="success", data=data) @@ -34,7 +32,7 @@ async def get_models( type: ModelType | None = Query(None, description="模型类型"), isEnabled: bool | None = Query(None, description="是否启用"), isDefault: bool | None = Query(None, description="是否默认"), - db: AsyncSession = Depends(get_db), + svc: ModelConfigService = Depends(), ): """分页查询模型列表,与 Java GET /models/list 一致。""" q = QueryModelRequest( @@ -45,21 +43,21 @@ async def get_models( isEnabled=isEnabled, isDefault=isDefault, ) - data = await model_config_service.get_models(db, q) + data = await svc.get_models(q) return StandardResponse(code=200, message="success", data=data) @router.post("/create", response_model=StandardResponse[ModelConfigResponse]) -async def create_model(req: CreateModelRequest, db: AsyncSession = Depends(get_db)): +async def create_model(req: CreateModelRequest, svc: ModelConfigService = Depends()): """创建模型配置,与 Java POST /models/create 一致。""" - data = await model_config_service.create_model(db, req) + data = await svc.create_model(req) return StandardResponse(code=200, message="success", data=data) @router.get("/{model_id}", response_model=StandardResponse[ModelConfigResponse]) -async def get_model_detail(model_id: str, db: AsyncSession = Depends(get_db)): +async def get_model_detail(model_id: str, svc: ModelConfigService = Depends()): """获取模型详情,与 Java GET /models/{modelId} 一致。""" - data = await model_config_service.get_model_detail(db, model_id) + data = await svc.get_model_detail(model_id) return StandardResponse(code=200, message="success", data=data) @@ -67,15 +65,15 @@ async def get_model_detail(model_id: str, db: AsyncSession = Depends(get_db)): async def update_model( model_id: str, req: CreateModelRequest, - db: AsyncSession = Depends(get_db), + svc: ModelConfigService = Depends(), ): """更新模型配置,与 Java PUT /models/{modelId} 一致。""" - data = await model_config_service.update_model(db, model_id, req) + data = await svc.update_model(model_id, req) return StandardResponse(code=200, message="success", data=data) @router.delete("/{model_id}", response_model=StandardResponse[None]) -async def delete_model(model_id: str, db: AsyncSession = Depends(get_db)): +async def 
delete_model(model_id: str, svc: ModelConfigService = Depends()): """删除模型配置,与 Java DELETE /models/{modelId} 一致。""" - await model_config_service.delete_model(db, model_id) + await svc.delete_model(model_id) return StandardResponse(code=200, message="success", data=None) diff --git a/runtime/datamate-python/app/module/system/service/common_service.py b/runtime/datamate-python/app/module/system/service/common_service.py index 2cf5578bb..b8c4cd071 100644 --- a/runtime/datamate-python/app/module/system/service/common_service.py +++ b/runtime/datamate-python/app/module/system/service/common_service.py @@ -1,8 +1,6 @@ +"""通用系统服务:仅保留与 DB 直接相关的查询。LLM 创建与调用统一使用 app.core.llm.LLMFactory。""" from typing import Optional -from langchain_core.language_models import BaseChatModel -from langchain_openai import ChatOpenAI, OpenAIEmbeddings -from pydantic import SecretStr from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -10,34 +8,9 @@ async def get_model_by_id(db: AsyncSession, model_id: str) -> Optional[ModelConfig]: - """根据模型ID获取 ModelConfig 记录。""" - result = await db.execute(select(ModelConfig).where(ModelConfig.id == model_id)) - return result.scalar_one_or_none() - - -def get_chat_client(model: ModelConfig) -> BaseChatModel: - return ChatOpenAI( - model=model.model_name, - base_url=model.base_url, - api_key=SecretStr(model.api_key), - ) - - -def chat(model: BaseChatModel, prompt: str) -> str: - """使用指定模型进行聊天""" - response = model.invoke(prompt) - return response.content - - -# 实例化对象 -def get_openai_client(model: ModelConfig) -> OpenAIEmbeddings: - return OpenAIEmbeddings( - model=model.model_name, - base_url=model.base_url, - api_key=SecretStr(model.api_key), + """根据模型ID获取未删除的 ModelConfig 记录。""" + q = select(ModelConfig).where(ModelConfig.id == model_id).where( + (ModelConfig.is_deleted == 0) | (ModelConfig.is_deleted.is_(None)) ) - -# 获取嵌入向量维度 -def get_embedding_dimension(model: OpenAIEmbeddings) -> int: - """获取 OpenAI 模型的嵌入向量维度""" - return len(model.embed_query(model.model)) + result = await db.execute(q) + return result.scalar_one_or_none() diff --git a/runtime/datamate-python/app/module/system/service/model_config_service.py b/runtime/datamate-python/app/module/system/service/model_config_service.py index 3fd8926d4..c46e3f416 100644 --- a/runtime/datamate-python/app/module/system/service/model_config_service.py +++ b/runtime/datamate-python/app/module/system/service/model_config_service.py @@ -1,18 +1,22 @@ """ 模型配置应用服务:与 Java ModelConfigApplicationService 行为一致。 包含分页、详情、增改删、isDefault 互斥逻辑及健康检查。 +通过 Depends(get_db) 注入 db,不在接口层传递;健康检查使用 core.llm.LLMFactory。 """ import math import uuid from datetime import datetime, timezone from typing import Optional -from fastapi import HTTPException -from sqlalchemy import select, func, update, delete +from fastapi import Depends, HTTPException +from sqlalchemy import select, func, update from sqlalchemy.ext.asyncio import AsyncSession from app.core.logging import get_logger +from app.core.llm import LLMFactory from app.db.models.model_config import ModelConfig +from app.db.session import get_db +from app.module.shared.schema import PaginatedData from app.module.system.schema.model_config import ( ModelType, CreateModelRequest, @@ -20,8 +24,6 @@ ModelConfigResponse, ProviderItem, ) -from app.module.shared.schema import PaginatedData -from app.module.system.service.common_service import get_chat_client, get_openai_client logger = get_logger(__name__) @@ -35,32 +37,10 @@ ProviderItem(provider="阿里云百炼", 
baseUrl="https://dashscope.aliyuncs.com/compatible-mode/v1"), ProviderItem(provider="硅基流动", baseUrl="https://api.siliconflow.cn/v1"), ProviderItem(provider="智谱AI", baseUrl="https://open.bigmodel.cn/api/paas/v4"), + ProviderItem(provider="自定义模型", baseUrl=""), ] -def _check_model_health(model_name: str, base_url: str, api_key: Optional[str], model_type: ModelType) -> None: - """对配置做一次最小化模型调用进行健康检查,失败则抛出 MODEL_HEALTH_CHECK_FAILED。""" - # 构造满足 common_service 接口的只读对象 - class _ModelLike: - def __init__(self, name: str, url: str, key: str, t: str): - self.model_name = name - self.base_url = url - self.api_key = key or "" - self.type = t - - m = _ModelLike(model_name, base_url, api_key or "", model_type.value) - try: - if model_type == ModelType.CHAT: - chat = get_chat_client(m) # type: ignore[arg-type] - chat.invoke("hello") - else: - emb = get_openai_client(m) # type: ignore[arg-type] - emb.embed_query("text") - except Exception as e: - logger.error("Model health check failed: model=%s type=%s err=%s", model_name, model_type, e) - raise HTTPException(status_code=400, detail="模型健康检查失败") from e - - def _orm_to_response(row: ModelConfig) -> ModelConfigResponse: return ModelConfigResponse( id=row.id, @@ -78,132 +58,158 @@ def _orm_to_response(row: ModelConfig) -> ModelConfigResponse: ) -async def get_providers() -> list[ProviderItem]: - """返回固定厂商列表,与 Java getProviders() 一致。""" - return list(PROVIDERS) - - -async def get_models(db: AsyncSession, q: QueryModelRequest) -> PaginatedData[ModelConfigResponse]: - """分页查询,支持 provider/type/isEnabled/isDefault;page 从 0 开始,与 Java 一致。""" - query = select(ModelConfig) - if q.provider is not None and q.provider != "": - query = query.where(ModelConfig.provider == q.provider) - if q.type is not None: - query = query.where(ModelConfig.type == q.type.value) - if q.isEnabled is not None: - query = query.where(ModelConfig.is_enabled == (1 if q.isEnabled else 0)) - if q.isDefault is not None: - query = query.where(ModelConfig.is_default == (1 if q.isDefault else 0)) - - total = (await db.execute(select(func.count()).select_from(query.subquery()))).scalar_one() - size = max(1, min(500, q.size)) - offset = max(0, q.page) * size - rows = ( - await db.execute( - query.order_by(ModelConfig.created_at.desc()).offset(offset).limit(size) - ) - ).scalars().all() - total_pages = math.ceil(total / size) if total else 0 - # 响应中 page 使用 1-based,与 PaginatedData 注释一致 - current = max(0, q.page) + 1 - return PaginatedData( - page=current, - size=size, - total_elements=total, - total_pages=total_pages, - content=[_orm_to_response(r) for r in rows], - ) +class ModelConfigService: + """模型配置服务:db 通过 FastAPI Depends(get_db) 注入,不在路由中传递。""" + def __init__(self, db: AsyncSession = Depends(get_db)): + self.db = db -async def get_model_detail(db: AsyncSession, model_id: str) -> ModelConfigResponse: - """获取模型详情,不存在则 404。""" - r = (await db.execute(select(ModelConfig).where(ModelConfig.id == model_id))).scalar_one_or_none() - if not r: - raise HTTPException(status_code=404, detail="模型配置不存在") - return _orm_to_response(r) + async def get_providers(self) -> list[ProviderItem]: + """返回固定厂商列表,与 Java getProviders() 一致。""" + return list(PROVIDERS) + async def get_models(self, q: QueryModelRequest) -> PaginatedData[ModelConfigResponse]: + """分页查询,支持 provider/type/isEnabled/isDefault;排除已删除;page 从 0 开始。""" + query = select(ModelConfig).where( + (ModelConfig.is_deleted == 0) | (ModelConfig.is_deleted.is_(None)) + ) + if q.provider is not None and q.provider != "": + query = query.where(ModelConfig.provider 
== q.provider) + if q.type is not None: + query = query.where(ModelConfig.type == q.type.value) + if q.isEnabled is not None: + query = query.where(ModelConfig.is_enabled == (1 if q.isEnabled else 0)) + if q.isDefault is not None: + query = query.where(ModelConfig.is_default == (1 if q.isDefault else 0)) + + total = (await self.db.execute(select(func.count()).select_from(query.subquery()))).scalar_one() + size = max(1, min(500, q.size)) + offset = max(0, q.page) * size + rows = ( + await self.db.execute( + query.order_by(ModelConfig.created_at.desc()).offset(offset).limit(size) + ) + ).scalars().all() + total_pages = math.ceil(total / size) if total else 0 + current = max(0, q.page) + 1 + return PaginatedData( + page=current, + size=size, + total_elements=total, + total_pages=total_pages, + content=[_orm_to_response(r) for r in rows], + ) -async def create_model(db: AsyncSession, req: CreateModelRequest) -> ModelConfigResponse: - """创建模型:健康检查后 saveAndSetDefault;isEnabled 恒为 True。""" - _check_model_health(req.modelName, req.baseUrl, req.apiKey, req.type) + async def get_model_detail(self, model_id: str) -> ModelConfigResponse: + """获取模型详情,已删除或不存在则 404。""" + query = select(ModelConfig).where(ModelConfig.id == model_id).where( + (ModelConfig.is_deleted == 0) | (ModelConfig.is_deleted.is_(None)) + ) + r = (await self.db.execute(query)).scalar_one_or_none() + if not r: + raise HTTPException(status_code=404, detail="模型配置不存在") + return _orm_to_response(r) + + async def create_model(self, req: CreateModelRequest) -> ModelConfigResponse: + """创建模型:健康检查后 saveAndSetDefault;isEnabled 恒为 True。""" + try: + LLMFactory.check_health(req.modelName, req.baseUrl, req.apiKey, req.type.value) + except Exception as e: + logger.error("Model health check failed: model=%s type=%s err=%s", req.modelName, req.type, e) + raise HTTPException(status_code=400, detail="模型健康检查失败") from e + + existing = ( + await self.db.execute( + select(ModelConfig).where( + (ModelConfig.is_deleted == 0) | (ModelConfig.is_deleted.is_(None)), + ModelConfig.type == req.type.value, + ModelConfig.is_default == 1, + ) + ) + ).scalar_one_or_none() - # 同类型下是否已有默认 - existing = ( - await db.execute( - select(ModelConfig).where( - ModelConfig.type == req.type.value, - ModelConfig.is_default == 1, + is_default: bool + if existing is None: + is_default = True + else: + await self.db.execute( + update(ModelConfig) + .where(ModelConfig.type == req.type.value, ModelConfig.is_default == 1) + .values(is_default=0) ) + is_default = req.isDefault if req.isDefault is not None else False + + now = datetime.now(timezone.utc) + entity = ModelConfig( + id=str(uuid.uuid4()), + model_name=req.modelName, + provider=req.provider, + base_url=req.baseUrl, + api_key=req.apiKey or "", + type=req.type.value, + is_enabled=1, + is_default=1 if is_default else 0, + is_deleted=0, + created_at=now, + updated_at=now, + created_by=None, + updated_by=None, ) - ).scalar_one_or_none() - - is_default: bool - if existing is None: - is_default = True - else: - # 清除同类型默认 - await db.execute( - update(ModelConfig) - .where(ModelConfig.type == req.type.value, ModelConfig.is_default == 1) - .values(is_default=0) - ) - is_default = req.isDefault if req.isDefault is not None else False - - now = datetime.now(timezone.utc) - entity = ModelConfig( - id=str(uuid.uuid4()), - model_name=req.modelName, - provider=req.provider, - base_url=req.baseUrl, - api_key=req.apiKey or "", - type=req.type.value, - is_enabled=1, - is_default=1 if is_default else 0, - created_at=now, - updated_at=now, - 
created_by=None, - updated_by=None, - ) - db.add(entity) - await db.commit() - await db.refresh(entity) - return _orm_to_response(entity) - - -async def update_model(db: AsyncSession, model_id: str, req: CreateModelRequest) -> ModelConfigResponse: - """更新模型:存在性校验、健康检查后 updateAndSetDefault;isEnabled 恒为 True。""" - entity = (await db.execute(select(ModelConfig).where(ModelConfig.id == model_id))).scalar_one_or_none() - if not entity: - raise HTTPException(status_code=404, detail="模型配置不存在") - - entity.model_name = req.modelName - entity.provider = req.provider - entity.base_url = req.baseUrl - entity.api_key = req.apiKey or "" - entity.type = req.type.value - entity.is_enabled = 1 - - _check_model_health(req.modelName, req.baseUrl, req.apiKey, req.type) - - want_default = req.isDefault if req.isDefault is not None else False - if (entity.is_default != 1) and want_default: - await db.execute( - update(ModelConfig) - .where(ModelConfig.type == req.type.value, ModelConfig.is_default == 1) - .values(is_default=0) + self.db.add(entity) + await self.db.commit() + await self.db.refresh(entity) + return _orm_to_response(entity) + + async def update_model(self, model_id: str, req: CreateModelRequest) -> ModelConfigResponse: + """更新模型:存在性校验、健康检查后 updateAndSetDefault;isEnabled 恒为 True。""" + res = await self.db.execute( + select(ModelConfig).where(ModelConfig.id == model_id).where( + (ModelConfig.is_deleted == 0) | (ModelConfig.is_deleted.is_(None)) + ) ) - entity.is_default = 1 if want_default else 0 - entity.updated_at = datetime.now(timezone.utc) - - await db.commit() - await db.refresh(entity) - return _orm_to_response(entity) - - -async def delete_model(db: AsyncSession, model_id: str) -> None: - """删除模型配置。""" - entity = (await db.execute(select(ModelConfig).where(ModelConfig.id == model_id))).scalar_one_or_none() - if not entity: - raise HTTPException(status_code=404, detail="模型配置不存在") - await db.delete(entity) - await db.commit() + entity = res.scalar_one_or_none() + if not entity: + raise HTTPException(status_code=404, detail="模型配置不存在") + + try: + LLMFactory.check_health(req.modelName, req.baseUrl, req.apiKey, req.type.value) + except Exception as e: + logger.error("Model health check failed: model=%s type=%s err=%s", req.modelName, req.type, e) + raise HTTPException(status_code=400, detail="模型健康检查失败") from e + + entity.model_name = req.modelName + entity.provider = req.provider + entity.base_url = req.baseUrl + entity.api_key = req.apiKey or "" + entity.type = req.type.value + entity.is_enabled = 1 + + want_default = req.isDefault if req.isDefault is not None else False + if (entity.is_default != 1) and want_default: + await self.db.execute( + update(ModelConfig) + .where(ModelConfig.type == req.type.value, ModelConfig.is_default == 1) + .values(is_default=0) + ) + entity.is_default = 1 if want_default else 0 + entity.updated_at = datetime.now(timezone.utc) + + await self.db.commit() + await self.db.refresh(entity) + return _orm_to_response(entity) + + async def delete_model(self, model_id: str) -> None: + """软删除模型配置。""" + entity = ( + await self.db.execute( + select(ModelConfig).where(ModelConfig.id == model_id).where( + (ModelConfig.is_deleted == 0) | (ModelConfig.is_deleted.is_(None)) + ) + ) + ).scalar_one_or_none() + if not entity: + raise HTTPException(status_code=404, detail="模型配置不存在") + entity.is_deleted = 1 + entity.updated_at = datetime.now(timezone.utc) + await self.db.commit() + await self.db.refresh(entity) diff --git a/scripts/db/setting-management-init.sql 
b/scripts/db/setting-management-init.sql index 138cd870b..3f8aefb88 100644 --- a/scripts/db/setting-management-init.sql +++ b/scripts/db/setting-management-init.sql @@ -12,6 +12,7 @@ CREATE TABLE IF NOT EXISTS t_model_config type VARCHAR(50) NOT NULL, is_enabled BOOLEAN DEFAULT TRUE, is_default BOOLEAN DEFAULT FALSE, + is_deleted BOOLEAN DEFAULT FALSE, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, created_by VARCHAR(255), @@ -28,6 +29,7 @@ COMMENT ON COLUMN t_model_config.api_key IS 'API 密钥(无密钥则为空)' COMMENT ON COLUMN t_model_config.type IS '模型类型(如 chat、embedding)'; COMMENT ON COLUMN t_model_config.is_enabled IS '是否启用:1-启用,0-禁用'; COMMENT ON COLUMN t_model_config.is_default IS '是否默认:1-默认,0-非默认'; +COMMENT ON COLUMN t_model_config.is_deleted IS '是否删除:1-已删除,0-未删除'; COMMENT ON COLUMN t_model_config.created_at IS '创建时间'; COMMENT ON COLUMN t_model_config.updated_at IS '更新时间'; COMMENT ON COLUMN t_model_config.created_by IS '创建者'; @@ -36,7 +38,7 @@ COMMENT ON COLUMN t_model_config.updated_by IS '更新者'; -- 添加唯一约束 ALTER TABLE t_model_config ADD CONSTRAINT uk_model_provider - UNIQUE (model_name, provider); + UNIQUE (model_name, provider, created_by); COMMENT ON CONSTRAINT uk_model_provider ON t_model_config IS '避免同一提供商下模型名称重复'; From 477af37f465d959e3d0b1b32e03deb1ac68c6abf Mon Sep 17 00:00:00 2001 From: Dallas98 <990259227@qq.com> Date: Mon, 26 Jan 2026 10:44:04 +0800 Subject: [PATCH 3/3] feat: implement LLMFactory for unified model creation and health checks; add is_deleted field to model config --- .../setting/domain/entity/ModelConfig.java | 2 +- .../db/models/{model_config.py => models.py} | 23 ++-- .../module/evaluation/interface/evaluation.py | 6 +- .../module/evaluation/service/evaluation.py | 8 +- .../generation/service/generation_service.py | 2 +- .../app/module/rag/service/rag_service.py | 18 ++-- .../{core => module/shared}/llm/__init__.py | 0 .../{core => module/shared}/llm/factory.py | 2 +- .../app/module/system/interface/__init__.py | 4 +- .../interface/{model_config.py => models.py} | 30 +++--- .../app/module/system/schema/__init__.py | 8 +- .../schema/{model_config.py => models.py} | 16 ++- .../module/system/service/common_service.py | 10 +- ...el_config_service.py => models_service.py} | 101 ++++++++---------- runtime/datamate-python/poetry.lock | 52 ++++++++- runtime/datamate-python/pyproject.toml | 1 + scripts/db/setting-management-init.sql | 38 +++---- 17 files changed, 180 insertions(+), 141 deletions(-) rename runtime/datamate-python/app/db/models/{model_config.py => models.py} (68%) rename runtime/datamate-python/app/{core => module/shared}/llm/__init__.py (100%) rename runtime/datamate-python/app/{core => module/shared}/llm/factory.py (97%) rename runtime/datamate-python/app/module/system/interface/{model_config.py => models.py} (73%) rename runtime/datamate-python/app/module/system/schema/{model_config.py => models.py} (80%) rename runtime/datamate-python/app/module/system/service/{model_config_service.py => models_service.py} (67%) diff --git a/backend/shared/domain-common/src/main/java/com/datamate/common/setting/domain/entity/ModelConfig.java b/backend/shared/domain-common/src/main/java/com/datamate/common/setting/domain/entity/ModelConfig.java index 0b153efe9..f0e97e46b 100644 --- a/backend/shared/domain-common/src/main/java/com/datamate/common/setting/domain/entity/ModelConfig.java +++ b/backend/shared/domain-common/src/main/java/com/datamate/common/setting/domain/entity/ModelConfig.java @@ -12,7 +12,7 @@ */ @Getter @Setter 
-@TableName("t_model_config") +@TableName("t_models") @Builder @ToString @NoArgsConstructor diff --git a/runtime/datamate-python/app/db/models/model_config.py b/runtime/datamate-python/app/db/models/models.py similarity index 68% rename from runtime/datamate-python/app/db/models/model_config.py rename to runtime/datamate-python/app/db/models/models.py index 680a5c717..9e32cdaf3 100644 --- a/runtime/datamate-python/app/db/models/model_config.py +++ b/runtime/datamate-python/app/db/models/models.py @@ -1,18 +1,12 @@ -from sqlalchemy import Column, String, Integer, TIMESTAMP, select +from sqlalchemy import Boolean, Column, String, TIMESTAMP from app.db.models.base_entity import BaseEntity -async def get_model_by_id(db_session, model_id: str): - """根据 ID 获取单个模型配置。""" - result =await db_session.execute(select(ModelConfig).where(ModelConfig.id == model_id)) - model_config = result.scalar_one_or_none() - return model_config +class Models(BaseEntity): + """模型配置表,对应表 t_models -class ModelConfig(BaseEntity): - """模型配置表,对应表 t_model_config - - CREATE TABLE IF NOT EXISTS t_model_config ( + CREATE TABLE IF NOT EXISTS t_models ( id VARCHAR(36) PRIMARY KEY COMMENT '主键ID', model_name VARCHAR(100) NOT NULL COMMENT '模型名称(如 qwen2)', provider VARCHAR(50) NOT NULL COMMENT '模型提供商(如 Ollama、OpenAI、DeepSeek)', @@ -29,7 +23,7 @@ class ModelConfig(BaseEntity): ) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4 COMMENT ='模型配置表'; """ - __tablename__ = "t_model_config" + __tablename__ = "t_models" id = Column(String(36), primary_key=True, index=True, comment="主键ID") model_name = Column(String(100), nullable=False, comment="模型名称(如 qwen2)") @@ -38,10 +32,9 @@ class ModelConfig(BaseEntity): api_key = Column(String(512), nullable=False, default="", comment="API 密钥(无密钥则为空)") type = Column(String(50), nullable=False, comment="模型类型(如 chat、embedding)") - # 使用 Integer 存储 TINYINT,后续可在业务层将 0/1 转为 bool - is_enabled = Column(Integer, nullable=False, default=1, comment="是否启用:1-启用,0-禁用") - is_default = Column(Integer, nullable=False, default=0, comment="是否默认:1-默认,0-非默认") - is_deleted = Column(Integer, nullable=False, default=0, comment="是否删除:1-已删除,0-未删除") + is_enabled = Column(Boolean, nullable=False, default=True, comment="是否启用") + is_default = Column(Boolean, nullable=False, default=False, comment="是否默认") + is_deleted = Column(Boolean, nullable=False, default=False, comment="是否删除") __table_args__ = ( # 与 DDL 中的 uk_model_provider 保持一致 diff --git a/runtime/datamate-python/app/module/evaluation/interface/evaluation.py b/runtime/datamate-python/app/module/evaluation/interface/evaluation.py index a32a90142..a718b4194 100644 --- a/runtime/datamate-python/app/module/evaluation/interface/evaluation.py +++ b/runtime/datamate-python/app/module/evaluation/interface/evaluation.py @@ -80,8 +80,8 @@ async def create_evaluation_task( if existing_task.scalar_one_or_none(): raise HTTPException(status_code=400, detail=f"Evaluation task with name '{request.name}' already exists") - model_config = await get_model_by_id(db, request.eval_config.model_id) - if not model_config: + models = await get_model_by_id(db, request.eval_config.model_id) + if not models: raise HTTPException(status_code=400, detail=f"Model with id '{request.eval_config.model_id}' not found") # 创建评估任务 @@ -96,7 +96,7 @@ async def create_evaluation_task( eval_prompt=request.eval_prompt, eval_config=json.dumps({ "modelId": request.eval_config.model_id, - "modelName": model_config.model_name, + "modelName": models.model_name, "dimensions": request.eval_config.dimensions, }), 
status=TaskStatus.PENDING.value, diff --git a/runtime/datamate-python/app/module/evaluation/service/evaluation.py b/runtime/datamate-python/app/module/evaluation/service/evaluation.py index 7994c0cf0..02d617a62 100644 --- a/runtime/datamate-python/app/module/evaluation/service/evaluation.py +++ b/runtime/datamate-python/app/module/evaluation/service/evaluation.py @@ -43,7 +43,7 @@ def get_eval_prompt(self, item: EvaluationItem) -> str: async def execute(self): eval_config = json.loads(self.task.eval_config) - model_config = await get_model_by_id(self.db, eval_config.get("modelId")) + models = await get_model_by_id(self.db, eval_config.get("modelId")) semaphore = asyncio.Semaphore(10) files = (await self.db.execute( select(EvaluationFile).where(EvaluationFile.task_id == self.task.id) @@ -55,7 +55,7 @@ async def execute(self): for file in files: items = (await self.db.execute(query.where(EvaluationItem.file_id == file.file_id))).scalars().all() tasks = [ - self.evaluate_item(model_config, item, semaphore) + self.evaluate_item(models, item, semaphore) for item in items ] await asyncio.gather(*tasks, return_exceptions=True) @@ -64,13 +64,13 @@ async def execute(self): self.task.eval_process = evaluated_count / total await self.db.commit() - async def evaluate_item(self, model_config, item: EvaluationItem, semaphore: asyncio.Semaphore): + async def evaluate_item(self, models, item: EvaluationItem, semaphore: asyncio.Semaphore): async with semaphore: max_try = 3 while max_try > 0: prompt_text = self.get_eval_prompt(item) resp_text = await asyncio.to_thread( - call_openai_style_model, model_config.base_url, model_config.api_key, model_config.model_name, + call_openai_style_model, models.base_url, models.api_key, models.model_name, prompt_text, ) resp_text = extract_json_substring(resp_text) diff --git a/runtime/datamate-python/app/module/generation/service/generation_service.py b/runtime/datamate-python/app/module/generation/service/generation_service.py index 9bea09072..f5d0ad985 100644 --- a/runtime/datamate-python/app/module/generation/service/generation_service.py +++ b/runtime/datamate-python/app/module/generation/service/generation_service.py @@ -24,7 +24,7 @@ from app.module.shared.common.document_loaders import load_documents from app.module.shared.common.text_split import DocumentSplitter from app.module.shared.util.model_chat import extract_json_substring -from app.core.llm import LLMFactory +from app.module.shared.llm import LLMFactory from app.module.system.service.common_service import get_model_by_id diff --git a/runtime/datamate-python/app/module/rag/service/rag_service.py b/runtime/datamate-python/app/module/rag/service/rag_service.py index 42c9f0ce8..1af9e49ba 100644 --- a/runtime/datamate-python/app/module/rag/service/rag_service.py +++ b/runtime/datamate-python/app/module/rag/service/rag_service.py @@ -2,7 +2,7 @@ import asyncio from typing import Optional, Sequence -from fastapi import BackgroundTasks, Depends +from fastapi import Depends from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -17,7 +17,7 @@ build_llm_model_func, initialize_rag, ) -from app.core.llm import LLMFactory +from app.module.shared.llm import LLMFactory from ...system.service.common_service import get_model_by_id logger = get_logger(__name__) @@ -44,8 +44,8 @@ async def get_unprocessed_files(self, knowledge_base_id: str) -> Sequence[RagFil async def init_graph_rag(self, knowledge_base_id: str): kb = await self._get_knowledge_base(knowledge_base_id) - embedding_model = await 
self._get_model_config(kb.embedding_model) - chat_model = await self._get_model_config(kb.chat_model) + embedding_model = await self._get_models(kb.embedding_model) + chat_model = await self._get_models(kb.chat_model) llm_callable = await build_llm_model_func( chat_model.model_name, chat_model.base_url, chat_model.api_key @@ -126,13 +126,13 @@ async def _get_knowledge_base(self, knowledge_base_id: str): raise ValueError(f"Knowledge base with ID {knowledge_base_id} not found.") return knowledge_base - async def _get_model_config(self, model_id: Optional[str]): + async def _get_models(self, model_id: Optional[str]): if not model_id: raise ValueError("Model ID is required for initializing RAG.") - model = await get_model_by_id(self.db, model_id) - if not model: - raise ValueError(f"Model config with ID {model_id} not found.") - return model + models = await get_model_by_id(self.db, model_id) + if not models: + raise ValueError(f"Models with ID {model_id} not found.") + return models async def query_rag(self, query: str, knowledge_base_id: str) -> str: if not self.rag: diff --git a/runtime/datamate-python/app/core/llm/__init__.py b/runtime/datamate-python/app/module/shared/llm/__init__.py similarity index 100% rename from runtime/datamate-python/app/core/llm/__init__.py rename to runtime/datamate-python/app/module/shared/llm/__init__.py diff --git a/runtime/datamate-python/app/core/llm/factory.py b/runtime/datamate-python/app/module/shared/llm/factory.py similarity index 97% rename from runtime/datamate-python/app/core/llm/factory.py rename to runtime/datamate-python/app/module/shared/llm/factory.py index f571a40ff..4713600df 100644 --- a/runtime/datamate-python/app/core/llm/factory.py +++ b/runtime/datamate-python/app/module/shared/llm/factory.py @@ -45,7 +45,7 @@ def check_health( model_name: str, base_url: str, api_key: str | None, - model_type: Literal["CHAT", "EMBEDDING"], + model_type: Literal["CHAT", "EMBEDDING"] | str, ) -> None: """对配置做一次最小化调用进行健康检查,失败则抛出。""" if model_type == "CHAT": diff --git a/runtime/datamate-python/app/module/system/interface/__init__.py b/runtime/datamate-python/app/module/system/interface/__init__.py index 3fdc4091e..15438e988 100644 --- a/runtime/datamate-python/app/module/system/interface/__init__.py +++ b/runtime/datamate-python/app/module/system/interface/__init__.py @@ -1,9 +1,9 @@ from fastapi import APIRouter from .about import router as about_router -from .model_config import router as model_config_router +from app.module.system.interface.models import router as models_router router = APIRouter() router.include_router(about_router) -router.include_router(model_config_router) \ No newline at end of file +router.include_router(models_router) diff --git a/runtime/datamate-python/app/module/system/interface/model_config.py b/runtime/datamate-python/app/module/system/interface/models.py similarity index 73% rename from runtime/datamate-python/app/module/system/interface/model_config.py rename to runtime/datamate-python/app/module/system/interface/models.py index 869c52c3d..f268a9516 100644 --- a/runtime/datamate-python/app/module/system/interface/model_config.py +++ b/runtime/datamate-python/app/module/system/interface/models.py @@ -1,30 +1,26 @@ -""" -模型配置 REST 接口:与 Java ModelConfigController 路径、语义一致,响应使用 StandardResponse。 -db 通过 ModelConfigService 的 Depends(get_db) 注入,不在本层传递。 -""" from fastapi import APIRouter, Depends, Query from app.module.shared.schema import StandardResponse, PaginatedData -from app.module.system.schema import ( +from 
app.module.system.schema.models import ( CreateModelRequest, QueryModelRequest, - ModelConfigResponse, + ModelsResponse, ProviderItem, ModelType, ) -from app.module.system.service.model_config_service import ModelConfigService +from app.module.system.service.models_service import ModelsService router = APIRouter(prefix="/models", tags=["models"]) @router.get("/providers", response_model=StandardResponse[list[ProviderItem]]) -async def get_providers(svc: ModelConfigService = Depends()): +async def get_providers(svc: ModelsService = Depends()): """获取厂商列表,与 Java GET /models/providers 一致。""" data = await svc.get_providers() return StandardResponse(code=200, message="success", data=data) -@router.get("/list", response_model=StandardResponse[PaginatedData[ModelConfigResponse]]) +@router.get("/list", response_model=StandardResponse[PaginatedData[ModelsResponse]]) async def get_models( page: int = Query(0, ge=0, description="页码,从 0 开始"), size: int = Query(20, gt=0, le=500, description="每页大小"), @@ -32,7 +28,7 @@ async def get_models( type: ModelType | None = Query(None, description="模型类型"), isEnabled: bool | None = Query(None, description="是否启用"), isDefault: bool | None = Query(None, description="是否默认"), - svc: ModelConfigService = Depends(), + svc: ModelsService = Depends(), ): """分页查询模型列表,与 Java GET /models/list 一致。""" q = QueryModelRequest( @@ -47,25 +43,25 @@ async def get_models( return StandardResponse(code=200, message="success", data=data) -@router.post("/create", response_model=StandardResponse[ModelConfigResponse]) -async def create_model(req: CreateModelRequest, svc: ModelConfigService = Depends()): +@router.post("/create", response_model=StandardResponse[ModelsResponse]) +async def create_model(req: CreateModelRequest, svc: ModelsService = Depends()): """创建模型配置,与 Java POST /models/create 一致。""" data = await svc.create_model(req) return StandardResponse(code=200, message="success", data=data) -@router.get("/{model_id}", response_model=StandardResponse[ModelConfigResponse]) -async def get_model_detail(model_id: str, svc: ModelConfigService = Depends()): +@router.get("/{model_id}", response_model=StandardResponse[ModelsResponse]) +async def get_model_detail(model_id: str, svc: ModelsService = Depends()): """获取模型详情,与 Java GET /models/{modelId} 一致。""" data = await svc.get_model_detail(model_id) return StandardResponse(code=200, message="success", data=data) -@router.put("/{model_id}", response_model=StandardResponse[ModelConfigResponse]) +@router.put("/{model_id}", response_model=StandardResponse[ModelsResponse]) async def update_model( model_id: str, req: CreateModelRequest, - svc: ModelConfigService = Depends(), + svc: ModelsService = Depends(), ): """更新模型配置,与 Java PUT /models/{modelId} 一致。""" data = await svc.update_model(model_id, req) @@ -73,7 +69,7 @@ async def update_model( @router.delete("/{model_id}", response_model=StandardResponse[None]) -async def delete_model(model_id: str, svc: ModelConfigService = Depends()): +async def delete_model(model_id: str, svc: ModelsService = Depends()): """删除模型配置,与 Java DELETE /models/{modelId} 一致。""" await svc.delete_model(model_id) return StandardResponse(code=200, message="success", data=None) diff --git a/runtime/datamate-python/app/module/system/schema/__init__.py b/runtime/datamate-python/app/module/system/schema/__init__.py index 70372b855..9ae5e6dbe 100644 --- a/runtime/datamate-python/app/module/system/schema/__init__.py +++ b/runtime/datamate-python/app/module/system/schema/__init__.py @@ -1,9 +1,9 @@ from .health import HealthResponse 
-from .model_config import ( +from .models import ( ModelType, CreateModelRequest, QueryModelRequest, - ModelConfigResponse, + ModelsResponse, ProviderItem, ) @@ -12,6 +12,6 @@ "ModelType", "CreateModelRequest", "QueryModelRequest", - "ModelConfigResponse", + "ModelsResponse", "ProviderItem", -] \ No newline at end of file +] diff --git a/runtime/datamate-python/app/module/system/schema/model_config.py b/runtime/datamate-python/app/module/system/schema/models.py similarity index 80% rename from runtime/datamate-python/app/module/system/schema/model_config.py rename to runtime/datamate-python/app/module/system/schema/models.py index bef559cdb..bfa5d4ae6 100644 --- a/runtime/datamate-python/app/module/system/schema/model_config.py +++ b/runtime/datamate-python/app/module/system/schema/models.py @@ -1,11 +1,11 @@ """ -模型配置 DTO:与 Java 版 t_model_config 接口保持一致。 +模型配置 DTO:与 Java 版 t_models 接口保持一致。 """ from datetime import datetime from enum import Enum from typing import Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field class ModelType(str, Enum): @@ -19,6 +19,8 @@ class ModelType(str, Enum): class CreateModelRequest(BaseModel): """创建/更新模型配置请求,与 Java CreateModelRequest 一致。""" + model_config = ConfigDict(str_strip_whitespace=True, extra="forbid") + modelName: str = Field(..., min_length=1, description="模型名称(如 qwen2)") provider: str = Field(..., min_length=1, description="模型提供商(如 Ollama、OpenAI、DeepSeek)") baseUrl: str = Field(..., min_length=1, description="API 基础地址") @@ -30,6 +32,8 @@ class CreateModelRequest(BaseModel): class QueryModelRequest(BaseModel): """模型查询请求,与 Java QueryModelRequest + PagingQuery 一致。page 从 0 开始,size 默认 20。""" + model_config = ConfigDict(str_strip_whitespace=True, extra="forbid") + page: int = Field(0, ge=0, description="页码,从 0 开始") size: int = Field(20, gt=0, le=500, description="每页大小") provider: Optional[str] = Field(None, description="模型提供商") @@ -41,8 +45,10 @@ class QueryModelRequest(BaseModel): # --- 响应 DTO --- -class ModelConfigResponse(BaseModel): - """模型配置响应,与 Java ModelConfig 实体字段一致。""" +class ModelsResponse(BaseModel): + """模型配置响应,与 Java Models 实体字段一致。""" + model_config = ConfigDict(str_strip_whitespace=True, extra="forbid") + id: str modelName: str provider: str @@ -59,5 +65,7 @@ class ModelConfigResponse(BaseModel): class ProviderItem(BaseModel): """厂商项,仅含 provider、baseUrl,与 Java getProviders() 返回结构一致。""" + model_config = ConfigDict(str_strip_whitespace=True, extra="forbid") + provider: str baseUrl: str diff --git a/runtime/datamate-python/app/module/system/service/common_service.py b/runtime/datamate-python/app/module/system/service/common_service.py index b8c4cd071..a74be5f42 100644 --- a/runtime/datamate-python/app/module/system/service/common_service.py +++ b/runtime/datamate-python/app/module/system/service/common_service.py @@ -4,13 +4,13 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.db.models.model_config import ModelConfig +from app.db.models.models import Models -async def get_model_by_id(db: AsyncSession, model_id: str) -> Optional[ModelConfig]: - """根据模型ID获取未删除的 ModelConfig 记录。""" - q = select(ModelConfig).where(ModelConfig.id == model_id).where( - (ModelConfig.is_deleted == 0) | (ModelConfig.is_deleted.is_(None)) +async def get_model_by_id(db: AsyncSession, model_id: str) -> Optional[Models]: + """根据模型ID获取未删除的 Models 记录。""" + q = select(Models).where(Models.id == model_id).where( + (Models.is_deleted == False) | (Models.is_deleted.is_(None)) ) result = await 
db.execute(q) return result.scalar_one_or_none() diff --git a/runtime/datamate-python/app/module/system/service/model_config_service.py b/runtime/datamate-python/app/module/system/service/models_service.py similarity index 67% rename from runtime/datamate-python/app/module/system/service/model_config_service.py rename to runtime/datamate-python/app/module/system/service/models_service.py index c46e3f416..874416b72 100644 --- a/runtime/datamate-python/app/module/system/service/model_config_service.py +++ b/runtime/datamate-python/app/module/system/service/models_service.py @@ -1,33 +1,26 @@ -""" -模型配置应用服务:与 Java ModelConfigApplicationService 行为一致。 -包含分页、详情、增改删、isDefault 互斥逻辑及健康检查。 -通过 Depends(get_db) 注入 db,不在接口层传递;健康检查使用 core.llm.LLMFactory。 -""" import math import uuid from datetime import datetime, timezone -from typing import Optional from fastapi import Depends, HTTPException from sqlalchemy import select, func, update from sqlalchemy.ext.asyncio import AsyncSession +from app.module.shared.llm import LLMFactory from app.core.logging import get_logger -from app.core.llm import LLMFactory -from app.db.models.model_config import ModelConfig +from app.db.models.models import Models from app.db.session import get_db from app.module.shared.schema import PaginatedData -from app.module.system.schema.model_config import ( - ModelType, +from app.module.system.schema.models import ( CreateModelRequest, QueryModelRequest, - ModelConfigResponse, + ModelsResponse, ProviderItem, ) logger = get_logger(__name__) -# 固定厂商列表,与 Java getProviders() 一致 +# 固定厂商列表 PROVIDERS = [ ProviderItem(provider="ModelEngine", baseUrl="http://localhost:9981"), ProviderItem(provider="Ollama", baseUrl="http://localhost:11434"), @@ -41,8 +34,8 @@ ] -def _orm_to_response(row: ModelConfig) -> ModelConfigResponse: - return ModelConfigResponse( +def _orm_to_response(row: Models) -> ModelsResponse: + return ModelsResponse( id=row.id, modelName=row.model_name, provider=row.provider, @@ -58,7 +51,7 @@ def _orm_to_response(row: ModelConfig) -> ModelConfigResponse: ) -class ModelConfigService: +class ModelsService: """模型配置服务:db 通过 FastAPI Depends(get_db) 注入,不在路由中传递。""" def __init__(self, db: AsyncSession = Depends(get_db)): @@ -68,26 +61,26 @@ async def get_providers(self) -> list[ProviderItem]: """返回固定厂商列表,与 Java getProviders() 一致。""" return list(PROVIDERS) - async def get_models(self, q: QueryModelRequest) -> PaginatedData[ModelConfigResponse]: + async def get_models(self, q: QueryModelRequest) -> PaginatedData[ModelsResponse]: """分页查询,支持 provider/type/isEnabled/isDefault;排除已删除;page 从 0 开始。""" - query = select(ModelConfig).where( - (ModelConfig.is_deleted == 0) | (ModelConfig.is_deleted.is_(None)) + query = select(Models).where( + (Models.is_deleted == False) | (Models.is_deleted.is_(None)) ) if q.provider is not None and q.provider != "": - query = query.where(ModelConfig.provider == q.provider) + query = query.where(Models.provider == q.provider) if q.type is not None: - query = query.where(ModelConfig.type == q.type.value) + query = query.where(Models.type == q.type.value) if q.isEnabled is not None: - query = query.where(ModelConfig.is_enabled == (1 if q.isEnabled else 0)) + query = query.where(Models.is_enabled == q.isEnabled) if q.isDefault is not None: - query = query.where(ModelConfig.is_default == (1 if q.isDefault else 0)) + query = query.where(Models.is_default == q.isDefault) total = (await self.db.execute(select(func.count()).select_from(query.subquery()))).scalar_one() size = max(1, min(500, q.size)) offset = max(0, 
q.page) * size rows = ( await self.db.execute( - query.order_by(ModelConfig.created_at.desc()).offset(offset).limit(size) + query.order_by(Models.created_at.desc()).offset(offset).limit(size) ) ).scalars().all() total_pages = math.ceil(total / size) if total else 0 @@ -100,17 +93,17 @@ async def get_models(self, q: QueryModelRequest) -> PaginatedData[ModelConfigRes content=[_orm_to_response(r) for r in rows], ) - async def get_model_detail(self, model_id: str) -> ModelConfigResponse: + async def get_model_detail(self, model_id: str) -> ModelsResponse: """获取模型详情,已删除或不存在则 404。""" - query = select(ModelConfig).where(ModelConfig.id == model_id).where( - (ModelConfig.is_deleted == 0) | (ModelConfig.is_deleted.is_(None)) + query = select(Models).where(Models.id == model_id).where( + (Models.is_deleted == False) | (Models.is_deleted.is_(None)) ) r = (await self.db.execute(query)).scalar_one_or_none() if not r: raise HTTPException(status_code=404, detail="模型配置不存在") return _orm_to_response(r) - async def create_model(self, req: CreateModelRequest) -> ModelConfigResponse: + async def create_model(self, req: CreateModelRequest) -> ModelsResponse: """创建模型:健康检查后 saveAndSetDefault;isEnabled 恒为 True。""" try: LLMFactory.check_health(req.modelName, req.baseUrl, req.apiKey, req.type.value) @@ -120,10 +113,10 @@ async def create_model(self, req: CreateModelRequest) -> ModelConfigResponse: existing = ( await self.db.execute( - select(ModelConfig).where( - (ModelConfig.is_deleted == 0) | (ModelConfig.is_deleted.is_(None)), - ModelConfig.type == req.type.value, - ModelConfig.is_default == 1, + select(Models).where( + (Models.is_deleted == False) | (Models.is_deleted.is_(None)), + Models.type == req.type.value, + Models.is_default == True, ) ) ).scalar_one_or_none() @@ -133,38 +126,36 @@ async def create_model(self, req: CreateModelRequest) -> ModelConfigResponse: is_default = True else: await self.db.execute( - update(ModelConfig) - .where(ModelConfig.type == req.type.value, ModelConfig.is_default == 1) - .values(is_default=0) + update(Models) + .where(Models.type == req.type.value, Models.is_default == True) + .values(is_default=False) ) is_default = req.isDefault if req.isDefault is not None else False - now = datetime.now(timezone.utc) - entity = ModelConfig( + now = datetime.now(timezone.utc).replace(tzinfo=None) + entity = Models( id=str(uuid.uuid4()), model_name=req.modelName, provider=req.provider, base_url=req.baseUrl, api_key=req.apiKey or "", type=req.type.value, - is_enabled=1, - is_default=1 if is_default else 0, - is_deleted=0, + is_enabled=True, + is_default=is_default, + is_deleted=False, created_at=now, updated_at=now, - created_by=None, - updated_by=None, ) self.db.add(entity) await self.db.commit() await self.db.refresh(entity) return _orm_to_response(entity) - async def update_model(self, model_id: str, req: CreateModelRequest) -> ModelConfigResponse: + async def update_model(self, model_id: str, req: CreateModelRequest) -> ModelsResponse: """更新模型:存在性校验、健康检查后 updateAndSetDefault;isEnabled 恒为 True。""" res = await self.db.execute( - select(ModelConfig).where(ModelConfig.id == model_id).where( - (ModelConfig.is_deleted == 0) | (ModelConfig.is_deleted.is_(None)) + select(Models).where(Models.id == model_id).where( + (Models.is_deleted == False) | (Models.is_deleted.is_(None)) ) ) entity = res.scalar_one_or_none() @@ -182,17 +173,17 @@ async def update_model(self, model_id: str, req: CreateModelRequest) -> ModelCon entity.base_url = req.baseUrl entity.api_key = req.apiKey or "" entity.type = 
req.type.value - entity.is_enabled = 1 + entity.is_enabled = True want_default = req.isDefault if req.isDefault is not None else False - if (entity.is_default != 1) and want_default: + if (entity.is_default is not True) and want_default: await self.db.execute( - update(ModelConfig) - .where(ModelConfig.type == req.type.value, ModelConfig.is_default == 1) - .values(is_default=0) + update(Models) + .where(Models.type == req.type.value, Models.is_default == True) + .values(is_default=False) ) - entity.is_default = 1 if want_default else 0 - entity.updated_at = datetime.now(timezone.utc) + entity.is_default = want_default + entity.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) await self.db.commit() await self.db.refresh(entity) @@ -202,14 +193,14 @@ async def delete_model(self, model_id: str) -> None: """软删除模型配置。""" entity = ( await self.db.execute( - select(ModelConfig).where(ModelConfig.id == model_id).where( - (ModelConfig.is_deleted == 0) | (ModelConfig.is_deleted.is_(None)) + select(Models).where(Models.id == model_id).where( + (Models.is_deleted == False) | (Models.is_deleted.is_(None)) ) ) ).scalar_one_or_none() if not entity: raise HTTPException(status_code=404, detail="模型配置不存在") - entity.is_deleted = 1 - entity.updated_at = datetime.now(timezone.utc) + entity.is_deleted = True + entity.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) await self.db.commit() await self.db.refresh(entity) diff --git a/runtime/datamate-python/poetry.lock b/runtime/datamate-python/poetry.lock index e3e20bad4..1556f130c 100644 --- a/runtime/datamate-python/poetry.lock +++ b/runtime/datamate-python/poetry.lock @@ -1301,6 +1301,18 @@ files = [ [package.extras] all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"] +[[package]] +name = "iniconfig" +version = "2.3.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12"}, + {file = "iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730"}, +] + [[package]] name = "jiter" version = "0.12.0" @@ -2822,6 +2834,22 @@ all = ["pipmaster[audit]", "pipmaster[dev]"] audit = ["pip-audit (>=2.5.0)"] dev = ["black", "cowsay", "mypy (>=1.0)", "myst-parser (>=0.17)", "pip-audit (>=2.5.0)", "pytest (>=7.0)", "pytest-asyncio (>=0.20)", "ruff (>=0.1.0)", "sphinx (>=5.0)", "sphinx-rtd-theme (>=1.0)"] +[[package]] +name = "pluggy" +version = "1.6.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746"}, + {file = "pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["coverage", "pytest", "pytest-benchmark"] + [[package]] name = "propcache" version = "0.4.1" @@ -3293,6 +3321,28 @@ files = [ {file = "pypinyin-0.55.0.tar.gz", hash = "sha256:b5711b3a0c6f76e67408ec6b2e3c4987a3a806b7c528076e7c7b86fcf0eaa66b"}, ] +[[package]] +name = "pytest" +version = "9.0.2" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "pytest-9.0.2-py3-none-any.whl", hash = 
"sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b"}, + {file = "pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11"}, +] + +[package.dependencies] +colorama = {version = ">=0.4", markers = "sys_platform == \"win32\""} +iniconfig = ">=1.0.1" +packaging = ">=22" +pluggy = ">=1.5,<2" +pygments = ">=2.7.2" + +[package.extras] +dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests", "setuptools", "xmlschema"] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -5379,4 +5429,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt [metadata] lock-version = "2.1" python-versions = ">=3.12,<4.0.0" -content-hash = "e48bb875f1482eaca13359f7673b202ae84f5e8b95273ca28372f5d427bc5d7e" +content-hash = "9ff6942c5f12288e198a6a9aff9e0fd6550403b9fe36b694192cf771334a8f70" diff --git a/runtime/datamate-python/pyproject.toml b/runtime/datamate-python/pyproject.toml index ab43e50a8..14407902a 100644 --- a/runtime/datamate-python/pyproject.toml +++ b/runtime/datamate-python/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "fastapi-mcp (>=0.4.0,<0.5.0)", "asyncpg (>=0.31.0,<0.32.0)", "lightrag-hku (==1.4.9.8)", + "pytest (>=9.0.2,<10.0.0)", ] diff --git a/scripts/db/setting-management-init.sql b/scripts/db/setting-management-init.sql index 3f8aefb88..cc0ddbd8d 100644 --- a/scripts/db/setting-management-init.sql +++ b/scripts/db/setting-management-init.sql @@ -2,7 +2,7 @@ \c datamate; -- 模型配置表 -CREATE TABLE IF NOT EXISTS t_model_config +CREATE TABLE IF NOT EXISTS t_models ( id VARCHAR(36) PRIMARY KEY, model_name VARCHAR(100) NOT NULL, @@ -20,27 +20,27 @@ CREATE TABLE IF NOT EXISTS t_model_config ); -- 添加注释 -COMMENT ON TABLE t_model_config IS '模型配置表'; -COMMENT ON COLUMN t_model_config.id IS '主键ID'; -COMMENT ON COLUMN t_model_config.model_name IS '模型名称(如 qwen2)'; -COMMENT ON COLUMN t_model_config.provider IS '模型提供商(如 Ollama、OpenAI、DeepSeek)'; -COMMENT ON COLUMN t_model_config.base_url IS 'API 基础地址'; -COMMENT ON COLUMN t_model_config.api_key IS 'API 密钥(无密钥则为空)'; -COMMENT ON COLUMN t_model_config.type IS '模型类型(如 chat、embedding)'; -COMMENT ON COLUMN t_model_config.is_enabled IS '是否启用:1-启用,0-禁用'; -COMMENT ON COLUMN t_model_config.is_default IS '是否默认:1-默认,0-非默认'; -COMMENT ON COLUMN t_model_config.is_deleted IS '是否删除:1-已删除,0-未删除'; -COMMENT ON COLUMN t_model_config.created_at IS '创建时间'; -COMMENT ON COLUMN t_model_config.updated_at IS '更新时间'; -COMMENT ON COLUMN t_model_config.created_by IS '创建者'; -COMMENT ON COLUMN t_model_config.updated_by IS '更新者'; +COMMENT ON TABLE t_models IS '模型配置表'; +COMMENT ON COLUMN t_models.id IS '主键ID'; +COMMENT ON COLUMN t_models.model_name IS '模型名称(如 qwen2)'; +COMMENT ON COLUMN t_models.provider IS '模型提供商(如 Ollama、OpenAI、DeepSeek)'; +COMMENT ON COLUMN t_models.base_url IS 'API 基础地址'; +COMMENT ON COLUMN t_models.api_key IS 'API 密钥(无密钥则为空)'; +COMMENT ON COLUMN t_models.type IS '模型类型(如 chat、embedding)'; +COMMENT ON COLUMN t_models.is_enabled IS '是否启用:1-启用,0-禁用'; +COMMENT ON COLUMN t_models.is_default IS '是否默认:1-默认,0-非默认'; +COMMENT ON COLUMN t_models.is_deleted IS '是否删除:1-已删除,0-未删除'; +COMMENT ON COLUMN t_models.created_at IS '创建时间'; +COMMENT ON COLUMN t_models.updated_at IS '更新时间'; +COMMENT ON COLUMN t_models.created_by IS '创建者'; +COMMENT ON COLUMN t_models.updated_by IS '更新者'; -- 添加唯一约束 -ALTER TABLE t_model_config +ALTER TABLE t_models ADD CONSTRAINT uk_model_provider UNIQUE (model_name, provider, created_by); -COMMENT ON CONSTRAINT uk_model_provider ON 
t_model_config +COMMENT ON CONSTRAINT uk_model_provider ON t_models IS '避免同一提供商下模型名称重复'; -- 创建触发器用于自动更新 updated_at @@ -52,9 +52,9 @@ BEGIN END; $$ language 'plpgsql'; -DROP TRIGGER IF EXISTS update_t_model_config_updated_at ON t_model_config; +DROP TRIGGER IF EXISTS update_t_model_config_updated_at ON t_models; CREATE TRIGGER update_t_model_config_updated_at - BEFORE UPDATE ON t_model_config + BEFORE UPDATE ON t_models FOR EACH ROW EXECUTE FUNCTION update_updated_at_column();
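
Since this series adds pytest to the runtime dependencies and moves the router to app.module.system.interface.models, a minimal test sketch may help show how the renamed /models routes can be exercised without a database. This is a sketch under assumptions, not part of the patch: it assumes the app.* package imports cleanly outside the running server, that httpx is available for FastAPI's TestClient, and that StandardResponse serializes its code/message/data fields under those names. FakeModelsService and the module name test_models_routes.py are hypothetical.

# test_models_routes.py -- hypothetical module name; a sketch, not part of this series.
from fastapi import FastAPI
from fastapi.testclient import TestClient

from app.module.system.interface.models import router
from app.module.system.schema.models import ProviderItem
from app.module.system.service.models_service import ModelsService


class FakeModelsService:
    """Hypothetical stand-in for ModelsService; implements only what this test calls."""

    async def get_providers(self) -> list[ProviderItem]:
        return [ProviderItem(provider="Ollama", baseUrl="http://localhost:11434")]


api = FastAPI()
api.include_router(router)
# Override the class-based dependency so the route never opens a database session.
api.dependency_overrides[ModelsService] = FakeModelsService


def test_get_providers_returns_standard_response():
    client = TestClient(api)
    resp = client.get("/models/providers")
    assert resp.status_code == 200
    body = resp.json()
    assert body["code"] == 200
    assert body["data"][0]["provider"] == "Ollama"

Overriding the class-based ModelsService dependency (rather than get_db) keeps the sketch independent of the SQLAlchemy session wiring; the list/create/update/delete routes could be covered the same way by adding the corresponding async methods to the fake.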