Dev #283
@@ -12,7 +12,7 @@
  */
 @Getter
 @Setter
-@TableName("t_model_config")
+@TableName("t_models")
 @Builder
 @ToString
 @NoArgsConstructor
@@ -1,18 +1,12 @@
-from sqlalchemy import Column, String, Integer, TIMESTAMP, select
+from sqlalchemy import Boolean, Column, String, TIMESTAMP
 
 from app.db.models.base_entity import BaseEntity
 
 
-async def get_model_by_id(db_session, model_id: str):
-    """Fetch a single model configuration by ID."""
-    result = await db_session.execute(select(ModelConfig).where(ModelConfig.id == model_id))
-    model_config = result.scalar_one_or_none()
-    return model_config
-
-class ModelConfig(BaseEntity):
-    """Model configuration table, corresponding to table t_model_config
+class Models(BaseEntity):
+    """Model configuration table, corresponding to table t_models
 
-    CREATE TABLE IF NOT EXISTS t_model_config (
+    CREATE TABLE IF NOT EXISTS t_models (
         id VARCHAR(36) PRIMARY KEY COMMENT 'Primary key ID',
         model_name VARCHAR(100) NOT NULL COMMENT 'Model name (e.g. qwen2)',
         provider VARCHAR(50) NOT NULL COMMENT 'Model provider (e.g. Ollama, OpenAI, DeepSeek)',
@@ -29,7 +23,7 @@ class ModelConfig(BaseEntity):
     ) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4 COMMENT = 'Model configuration table';
     """
 
-    __tablename__ = "t_model_config"
+    __tablename__ = "t_models"
 
     id = Column(String(36), primary_key=True, index=True, comment="Primary key ID")
     model_name = Column(String(100), nullable=False, comment="Model name (e.g. qwen2)")
@@ -38,9 +32,9 @@ class ModelConfig(BaseEntity):
     api_key = Column(String(512), nullable=False, default="", comment="API key (empty if none)")
     type = Column(String(50), nullable=False, comment="Model type (e.g. chat, embedding)")
 
-    # Store the TINYINT as Integer; the business layer can map 0/1 to bool
-    is_enabled = Column(Integer, nullable=False, default=1, comment="Enabled: 1-enabled, 0-disabled")
-    is_default = Column(Integer, nullable=False, default=0, comment="Default: 1-default, 0-not default")
+    is_enabled = Column(Boolean, nullable=False, default=True, comment="Whether enabled")
+    is_default = Column(Boolean, nullable=False, default=False, comment="Whether default")
+    is_deleted = Column(Boolean, nullable=False, default=False, comment="Whether deleted")
 
     __table_args__ = (
         # Keep consistent with uk_model_provider in the DDL
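Note: the module-local get_model_by_id helper is deleted here; the other files in this PR import it from app.module.system.service.common_service instead. A minimal sketch of what that shared helper presumably looks like after the rename, assuming it mirrors the removed local version (the Models import path is an assumption):

    # Sketch, not the committed code: assumes the shared helper mirrors the
    # removed module-local one, retargeted at the renamed Models entity.
    from sqlalchemy import select

    from app.db.models.model_config import Models  # import path assumed

    async def get_model_by_id(db_session, model_id: str):
        """Fetch a single model configuration by ID, or None if absent."""
        result = await db_session.execute(select(Models).where(Models.id == model_id))
        return result.scalar_one_or_none()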
@@ -80,8 +80,8 @@ async def create_evaluation_task(
     if existing_task.scalar_one_or_none():
         raise HTTPException(status_code=400, detail=f"Evaluation task with name '{request.name}' already exists")
 
-    model_config = await get_model_by_id(db, request.eval_config.model_id)
-    if not model_config:
+    models = await get_model_by_id(db, request.eval_config.model_id)
+    if not models:
         raise HTTPException(status_code=400, detail=f"Model with id '{request.eval_config.model_id}' not found")
 
     # Create the evaluation task
@@ -96,7 +96,7 @@
         eval_prompt=request.eval_prompt,
         eval_config=json.dumps({
             "modelId": request.eval_config.model_id,
-            "modelName": model_config.model_name,
+            "modelName": models.model_name,
             "dimensions": request.eval_config.dimensions,
         }),
         status=TaskStatus.PENDING.value,
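For reference, the eval_config column stores the JSON serialized above, and the task executor below reads it back with json.loads before resolving the model. A small illustration of that round trip (the ID and dimension names are made-up examples, not values from this PR):

    import json

    # What create_evaluation_task writes:
    eval_config = json.dumps({
        "modelId": "hypothetical-model-id",
        "modelName": "qwen2",
        "dimensions": ["accuracy", "fluency"],  # illustrative dimension names
    })

    # What the executor later reads back:
    model_id = json.loads(eval_config).get("modelId")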
@@ -43,7 +43,7 @@ def get_eval_prompt(self, item: EvaluationItem) -> str:
 
     async def execute(self):
         eval_config = json.loads(self.task.eval_config)
-        model_config = await get_model_by_id(self.db, eval_config.get("modelId"))
+        models = await get_model_by_id(self.db, eval_config.get("modelId"))
         semaphore = asyncio.Semaphore(10)
         files = (await self.db.execute(
             select(EvaluationFile).where(EvaluationFile.task_id == self.task.id)
@@ -55,7 +55,7 @@ async def execute(self):
         for file in files:
             items = (await self.db.execute(query.where(EvaluationItem.file_id == file.file_id))).scalars().all()
             tasks = [
-                self.evaluate_item(model_config, item, semaphore)
+                self.evaluate_item(models, item, semaphore)
                 for item in items
             ]
             await asyncio.gather(*tasks, return_exceptions=True)
@@ -64,13 +64,13 @@ async def execute(self):
             self.task.eval_process = evaluated_count / total
             await self.db.commit()
 
-    async def evaluate_item(self, model_config, item: EvaluationItem, semaphore: asyncio.Semaphore):
+    async def evaluate_item(self, models, item: EvaluationItem, semaphore: asyncio.Semaphore):
         async with semaphore:
             max_try = 3
             while max_try > 0:
                 prompt_text = self.get_eval_prompt(item)
                 resp_text = await asyncio.to_thread(
-                    call_openai_style_model, model_config.base_url, model_config.api_key, model_config.model_name,
+                    call_openai_style_model, models.base_url, models.api_key, models.model_name,
                     prompt_text,
                 )
                 resp_text = extract_json_substring(resp_text)
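The executor bounds concurrency with asyncio.Semaphore(10), and gather(..., return_exceptions=True) keeps one failing item from cancelling the rest. A self-contained sketch of the same pattern, with a sleep standing in for the model call:

    import asyncio

    async def evaluate_all(items, limit: int = 10):
        semaphore = asyncio.Semaphore(limit)

        async def evaluate_one(item):
            async with semaphore:         # at most `limit` evaluations in flight
                await asyncio.sleep(0.1)  # stand-in for the model call
                return item

        # return_exceptions=True mirrors the gather call in execute() above
        return await asyncio.gather(*(evaluate_one(i) for i in items),
                                    return_exceptions=True)

    asyncio.run(evaluate_all(range(25)))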
@@ -24,7 +24,8 @@
 from app.module.shared.common.document_loaders import load_documents
 from app.module.shared.common.text_split import DocumentSplitter
 from app.module.shared.util.model_chat import extract_json_substring
-from app.module.system.service.common_service import chat, get_model_by_id, get_chat_client
+from app.module.shared.llm import LLMFactory
+from app.module.system.service.common_service import get_model_by_id
 
 
 def _filter_docs(split_docs, chunk_size):
@@ -171,8 +172,12 @@ async def _process_single_file(
         # Build the model clients for this file
         question_model = await get_model_by_id(self.db, question_cfg.model_id)
         answer_model = await get_model_by_id(self.db, answer_cfg.model_id)
-        question_chat = get_chat_client(question_model)
-        answer_chat = get_chat_client(answer_model)
+        question_chat = LLMFactory.create_chat(
+            question_model.model_name, question_model.base_url, question_model.api_key
+        )
+        answer_chat = LLMFactory.create_chat(
+            answer_model.model_name, answer_model.base_url, answer_model.api_key
+        )
 
         # Read and process chunks from the DB in batches
         batch_size = 100
@@ -356,7 +361,7 @@ async def _generate_questions_for_one_chunk(
         loop = asyncio.get_running_loop()
         raw_answer = await loop.run_in_executor(
             None,
-            chat,
+            LLMFactory.invoke_sync,
             question_chat,
             prompt,
         )
@@ -400,7 +405,7 @@ async def process_single_question(question: str):
         loop = asyncio.get_running_loop()
         answer = await loop.run_in_executor(
             None,
-            chat,
+            LLMFactory.invoke_sync,
             answer_chat,
             prompt_local,
         )
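LLMFactory.invoke_sync is a blocking call, so the service keeps dispatching it through run_in_executor, which runs the callable on the default thread pool and forwards the positional arguments. A sketch of the pattern with a stand-in for the blocking function:

    import asyncio

    def blocking_invoke(chat_client, prompt: str) -> str:
        # Stand-in for LLMFactory.invoke_sync(chat_client, prompt)
        return f"answer to: {prompt}"

    async def ask(chat_client, prompt: str) -> str:
        loop = asyncio.get_running_loop()
        # Arguments after the callable are passed through, as in the diff above
        return await loop.run_in_executor(None, blocking_invoke, chat_client, prompt)

    print(asyncio.run(ask(object(), "hello")))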
@@ -9,12 +9,12 @@
 router = APIRouter(prefix="/rag", tags=["rag"])
 
 @router.post("/process/{knowledge_base_id}")
-async def process_knowledge_base(knowledge_base_id: str, db: AsyncSession = Depends(get_db)):
+async def process_knowledge_base(knowledge_base_id: str, rag_service: RAGService = Depends()):
     """
     Process all unprocessed files in a knowledge base.
     """
     try:
-        await RAGService(db).init_graph_rag(knowledge_base_id)
+        await rag_service.init_graph_rag(knowledge_base_id)
         return StandardResponse(
             code=200,
             message="Processing started for knowledge base.",
@@ -24,12 +24,11 @@ async def process_knowledge_base(knowledge_base_id: str, db: AsyncSession = Depe
         raise HTTPException(status_code=500, detail=str(e))
 
 @router.post("/query")
-async def query_knowledge_graph(payload: QueryRequest, db: AsyncSession = Depends(get_db)):
+async def query_knowledge_graph(payload: QueryRequest, rag_service: RAGService = Depends()):
     """
     Query the knowledge graph with the given query text and knowledge base ID.
     """
     try:
-        rag_service = RAGService(db)
         result = await rag_service.query_rag(payload.query, payload.knowledge_base_id)
         return StandardResponse(code=200, message="success", data=result)
     except HTTPException:
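The handlers now receive RAGService through a bare Depends(): FastAPI takes the parameter annotation as the dependency, builds a RAGService per request, and resolves the Depends(get_db) inside its __init__, so neither route constructs the service by hand. A minimal self-contained sketch of the mechanism:

    from fastapi import Depends, FastAPI

    app = FastAPI()

    def get_db():
        return "db-session"  # stand-in for the real async session dependency

    class RAGService:
        def __init__(self, db: str = Depends(get_db)):
            self.db = db     # FastAPI injects the resolved sub-dependency

    @app.get("/demo")
    async def demo(rag_service: RAGService = Depends()):  # bare Depends() uses the annotation
        return {"db": rag_service.db}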
29 changes: 15 additions & 14 deletions runtime/datamate-python/app/module/rag/service/rag_service.py

@@ -2,14 +2,13 @@
 import asyncio
 from typing import Optional, Sequence
 
-from fastapi import BackgroundTasks, Depends
+from fastapi import Depends
 from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from app.core.logging import get_logger
 from app.db.models.dataset_management import DatasetFiles
 from app.db.models.knowledge_gen import RagFile, RagKnowledgeBase
-from app.db.models.model_config import ModelConfig
 from app.db.session import get_db, AsyncSessionLocal
 from app.module.shared.common.document_loaders import load_documents
 from .graph_rag import (
@@ -18,7 +17,8 @@
     build_llm_model_func,
     initialize_rag,
 )
-from ...system.service.common_service import get_embedding_dimension, get_openai_client
+from app.module.shared.llm import LLMFactory
+from ...system.service.common_service import get_model_by_id
 
 logger = get_logger(__name__)
 
@@ -27,10 +27,10 @@ class RAGService:
     def __init__(
         self,
         db: AsyncSession = Depends(get_db),
-        background_tasks: BackgroundTasks | None = None,
+
     ):
         self.db = db
-        self.background_tasks = background_tasks
+        self.background_tasks = None
         self.rag = None
 
     async def get_unprocessed_files(self, knowledge_base_id: str) -> Sequence[RagFile]:
@@ -44,8 +44,8 @@ async def get_unprocessed_files(self, knowledge_base_id: str) -> Sequence[RagFil
 
     async def init_graph_rag(self, knowledge_base_id: str):
         kb = await self._get_knowledge_base(knowledge_base_id)
-        embedding_model = await self._get_model_config(kb.embedding_model)
-        chat_model = await self._get_model_config(kb.chat_model)
+        embedding_model = await self._get_models(kb.embedding_model)
+        chat_model = await self._get_models(kb.chat_model)
 
         llm_callable = await build_llm_model_func(
             chat_model.model_name, chat_model.base_url, chat_model.api_key
@@ -54,7 +54,9 @@ async def init_graph_rag(self, knowledge_base_id: str):
             embedding_model.model_name,
             embedding_model.base_url,
             embedding_model.api_key,
-            embedding_dim=get_embedding_dimension(get_openai_client(embedding_model)),
+            embedding_dim=LLMFactory.get_embedding_dimension(
+                embedding_model.model_name, embedding_model.base_url, embedding_model.api_key
+            ),
         )
 
         kb_working_dir = os.path.join(DEFAULT_WORKING_DIR, kb.name)
@@ -124,14 +126,13 @@ async def _get_knowledge_base(self, knowledge_base_id: str):
             raise ValueError(f"Knowledge base with ID {knowledge_base_id} not found.")
         return knowledge_base
 
-    async def _get_model_config(self, model_id: Optional[str]):
+    async def _get_models(self, model_id: Optional[str]):
         if not model_id:
             raise ValueError("Model ID is required for initializing RAG.")
-        result = await self.db.execute(select(ModelConfig).where(ModelConfig.id == model_id))
-        model = result.scalars().first()
-        if not model:
-            raise ValueError(f"Model config with ID {model_id} not found.")
-        return model
+        models = await get_model_by_id(self.db, model_id)
+        if not models:
+            raise ValueError(f"Models with ID {model_id} not found.")
+        return models
 
     async def query_rag(self, query: str, knowledge_base_id: str) -> str:
         if not self.rag:
7 changes: 7 additions & 0 deletions runtime/datamate-python/app/module/shared/llm/__init__.py

@@ -0,0 +1,7 @@
+# app/module/shared/llm/__init__.py
+"""
+LangChain model factory: one place to create chat and embedding clients and run health checks, for reuse across modules.
+"""
+from .factory import LLMFactory
+
+__all__ = ["LLMFactory"]
71 changes: 71 additions & 0 deletions runtime/datamate-python/app/module/shared/llm/factory.py

@@ -0,0 +1,71 @@
+# app/module/shared/llm/factory.py
+"""
+LangChain model factory: wraps creation, health checks, and synchronous invocation of Chat / Embedding models on OpenAI-compatible APIs.
+Gives the model-config, RAG, generation, and evaluation modules one shared entry point instead of scattered get_chat_client / get_openai_client helpers.
+"""
+from typing import Literal
+
+from langchain_core.language_models import BaseChatModel
+from langchain_core.embeddings import Embeddings
+from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+from pydantic import SecretStr
+
+
+class LLMFactory:
+    """LangChain-based Chat / Embedding factory for OpenAI-compatible APIs."""
+
+    @staticmethod
+    def create_chat(
+        model_name: str,
+        base_url: str,
+        api_key: str | None = None,
+    ) -> BaseChatModel:
+        """Create a chat model; works with OpenAI and any OpenAI-compatible base_url."""
+        return ChatOpenAI(
+            model=model_name,
+            base_url=base_url or None,
+            api_key=SecretStr(api_key or ""),
+        )
+
+    @staticmethod
+    def create_embedding(
+        model_name: str,
+        base_url: str,
+        api_key: str | None = None,
+    ) -> Embeddings:
+        """Create an embedding model; works with OpenAI and any OpenAI-compatible base_url."""
+        return OpenAIEmbeddings(
+            model=model_name,
+            base_url=base_url or None,
+            api_key=SecretStr(api_key or ""),
+        )
+
+    @staticmethod
+    def check_health(
+        model_name: str,
+        base_url: str,
+        api_key: str | None,
+        model_type: Literal["CHAT", "EMBEDDING"] | str,
+    ) -> None:
+        """Run one minimal call against the configuration as a health check; raises on failure."""
+        if model_type == "CHAT":
+            model = LLMFactory.create_chat(model_name, base_url, api_key)
+            model.invoke("hello")
+        else:
+            model = LLMFactory.create_embedding(model_name, base_url, api_key)
+            model.embed_query("text")
+
+    @staticmethod
+    def get_embedding_dimension(
+        model_name: str,
+        base_url: str,
+        api_key: str | None = None,
+    ) -> int:
+        """Create an embedding model and return its vector dimension."""
+        emb = LLMFactory.create_embedding(model_name, base_url, api_key)
+        return len(emb.embed_query("text"))
+
+    @staticmethod
+    def invoke_sync(chat_model: BaseChatModel, prompt: str) -> str:
+        """Invoke a chat model synchronously and return its content; for run_in_executor-style use."""
+        return chat_model.invoke(prompt).content
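Taken together, callers use the factory roughly like this (the endpoint and model names are placeholders, not values from this PR):

    from app.module.shared.llm import LLMFactory

    # One-shot synchronous chat call
    chat = LLMFactory.create_chat("qwen2", "http://localhost:11434/v1")
    print(LLMFactory.invoke_sync(chat, "hello"))

    # Probe a configuration before saving it; raises on failure
    LLMFactory.check_health("qwen2", "http://localhost:11434/v1", None, "CHAT")

    # Size vectors for the RAG store
    dim = LLMFactory.get_embedding_dimension("bge-m3", "http://localhost:11434/v1")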
@@ -1,7 +1,9 @@
 from fastapi import APIRouter
 
 from .about import router as about_router
+from app.module.system.interface.models import router as models_router
 
 router = APIRouter()
 
-router.include_router(about_router)
+router.include_router(about_router)
+router.include_router(models_router)