Skip to content

Commit 57f3930

Browse files
committed
feat(table-selection): add LLM-based table selection for SQL generation
1 parent 9f89134 commit 57f3930

File tree

11 files changed

+456
-12
lines changed

11 files changed

+456
-12
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
"""add table_select_answer column to chat_record
2+
3+
Revision ID: 054_table_select
4+
Revises: 5755c0b95839
5+
Create Date: 2025-12-23
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
11+
revision = '054_table_select'
12+
down_revision = '5755c0b95839'
13+
branch_labels = None
14+
depends_on = None
15+
16+
17+
def upgrade():
18+
op.add_column('chat_record', sa.Column('table_select_answer', sa.Text(), nullable=True))
19+
20+
21+
def downgrade():
22+
op.drop_column('chat_record', 'table_select_answer')

backend/apps/chat/curd/chat.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,26 @@ def save_select_datasource_answer(session: SessionDep, record_id: int, answer: s
694694
return result
695695

696696

697+
def save_table_select_answer(session: SessionDep, record_id: int, answer: str) -> ChatRecord:
698+
"""保存 LLM 表选择的结果到 ChatRecord"""
699+
if not record_id:
700+
raise Exception("Record id cannot be None")
701+
record = get_chat_record_by_id(session, record_id)
702+
703+
record.table_select_answer = answer
704+
705+
result = ChatRecord(**record.model_dump())
706+
707+
stmt = update(ChatRecord).where(and_(ChatRecord.id == record.id)).values(
708+
table_select_answer=record.table_select_answer,
709+
)
710+
711+
session.execute(stmt)
712+
session.commit()
713+
714+
return result
715+
716+
697717
def save_recommend_question_answer(session: SessionDep, record_id: int,
698718
answer: dict = None, articles_number: Optional[int] = 4) -> ChatRecord:
699719
if not record_id:

backend/apps/chat/models/chat_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class OperationEnum(Enum):
4040
GENERATE_SQL_WITH_PERMISSIONS = '5'
4141
CHOOSE_DATASOURCE = '6'
4242
GENERATE_DYNAMIC_SQL = '7'
43+
SELECT_TABLE = '8' # LLM 表选择
4344

4445

4546
class ChatFinishStep(Enum):
@@ -115,6 +116,7 @@ class ChatRecord(SQLModel, table=True):
115116
recommended_question_answer: str = Field(sa_column=Column(Text, nullable=True))
116117
recommended_question: str = Field(sa_column=Column(Text, nullable=True))
117118
datasource_select_answer: str = Field(sa_column=Column(Text, nullable=True))
119+
table_select_answer: str = Field(sa_column=Column(Text, nullable=True))
118120
finish: bool = Field(sa_column=Column(Boolean, nullable=True, default=False))
119121
error: str = Field(sa_column=Column(Text, nullable=True))
120122
analysis_record_id: int = Field(sa_column=Column(BigInteger, nullable=True))
@@ -140,6 +142,7 @@ class ChatRecordResult(BaseModel):
140142
predict_data: Optional[str] = None
141143
recommended_question: Optional[str] = None
142144
datasource_select_answer: Optional[str] = None
145+
table_select_answer: Optional[str] = None
143146
finish: Optional[bool] = None
144147
error: Optional[str] = None
145148
analysis_record_id: Optional[int] = None

backend/apps/chat/task/llm.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,16 @@ def __init__(self, session: Session, current_user: CurrentUser, chat_question: C
121121
if not ds:
122122
raise SingleMessageError("No available datasource configuration found")
123123
chat_question.engine = (ds.type_name if ds.type != 'excel' else 'PostgreSQL') + get_version(ds)
124-
chat_question.db_schema = get_table_schema(session=session, current_user=current_user, ds=ds,
125-
question=chat_question.question, embedding=embedding)
124+
# 延迟 get_table_schema 调用到 init_record 之后,以便记录 LLM 表选择日志
125+
self._pending_schema_params = {
126+
'session': session,
127+
'current_user': current_user,
128+
'ds': ds,
129+
'question': chat_question.question,
130+
'embedding': embedding,
131+
'history_questions': history_questions,
132+
'config': config
133+
}
126134

127135
self.generate_sql_logs = list_generate_sql_logs(session=session, chart_id=chat_id)
128136
self.generate_chart_logs = list_generate_chart_logs(session=session, chart_id=chat_id)
@@ -230,6 +238,22 @@ def init_messages(self):
230238

231239
def init_record(self, session: Session) -> ChatRecord:
232240
self.record = save_question(session=session, current_user=self.current_user, question=self.chat_question)
241+
242+
# 如果有延迟的 schema 获取,现在执行(此时 record 已存在,可以记录 LLM 表选择日志)
243+
if hasattr(self, '_pending_schema_params') and self._pending_schema_params:
244+
params = self._pending_schema_params
245+
self.chat_question.db_schema = get_table_schema(
246+
session=params['session'],
247+
current_user=params['current_user'],
248+
ds=params['ds'],
249+
question=params['question'],
250+
embedding=params['embedding'],
251+
history_questions=params['history_questions'],
252+
config=params['config'],
253+
record_id=self.record.id
254+
)
255+
self._pending_schema_params = None
256+
233257
return self.record
234258

235259
def get_record(self):
@@ -355,7 +379,9 @@ def generate_recommend_questions_task(self, _session: Session):
355379
session=_session,
356380
current_user=self.current_user, ds=self.ds,
357381
question=self.chat_question.question,
358-
embedding=False)
382+
embedding=False,
383+
config=self.config,
384+
record_id=self.record.id)
359385

360386
guess_msg: List[Union[BaseMessage, dict[str, Any]]] = []
361387
guess_msg.append(SystemMessage(content=self.chat_question.guess_sys_question(self.articles_number)))
@@ -500,7 +526,9 @@ def select_datasource(self, _session: Session):
500526
self.ds)
501527
self.chat_question.db_schema = get_table_schema(session=_session,
502528
current_user=self.current_user, ds=self.ds,
503-
question=self.chat_question.question)
529+
question=self.chat_question.question,
530+
config=self.config,
531+
record_id=self.record.id)
504532
_engine_type = self.chat_question.engine
505533
_chat.engine_type = _ds.type_name
506534
# save chat
@@ -1003,7 +1031,9 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
10031031
session=_session,
10041032
current_user=self.current_user,
10051033
ds=self.ds,
1006-
question=self.chat_question.question)
1034+
question=self.chat_question.question,
1035+
config=self.config,
1036+
record_id=self.record.id)
10071037
else:
10081038
self.validate_history_ds(_session)
10091039

backend/apps/datasource/crud/datasource.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
from sqlbot_xpack.permissions.models.ds_rules import DsRules
88
from sqlmodel import select
99

10+
from apps.ai_model.model_factory import LLMConfig
1011
from apps.datasource.crud.permission import get_column_permission_fields, get_row_permission_filters, is_normal_user
1112
from apps.datasource.embedding.table_embedding import calc_table_embedding
13+
from apps.datasource.llm_select.table_selection import calc_table_llm_selection
1214
from apps.datasource.utils.utils import aes_decrypt
1315
from apps.db.constant import DB
1416
from apps.db.db import get_tables, get_fields, exec_sql, check_connection
@@ -425,7 +427,8 @@ def get_table_obj_by_ds(session: SessionDep, current_user: CurrentUser, ds: Core
425427

426428

427429
def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource, question: str,
428-
embedding: bool = True) -> str:
430+
embedding: bool = True, history_questions: List[str] = None,
431+
config: LLMConfig = None, lang: str = "中文", record_id: int = None) -> str:
429432
schema_str = ""
430433
table_objs = get_table_obj_by_ds(session=session, current_user=current_user, ds=ds)
431434
if len(table_objs) == 0:
@@ -434,7 +437,12 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
434437
schema_str += f"【DB_ID】 {db_name}\n【Schema】\n"
435438
tables = []
436439
all_tables = [] # temp save all tables
440+
441+
# 构建 table_name -> table_obj 映射,用于 LLM 表选择
442+
table_name_to_obj = {}
437443
for obj in table_objs:
444+
table_name_to_obj[obj.table.table_name] = obj
445+
438446
schema_table = ''
439447
schema_table += f"# Table: {db_name}.{obj.table.table_name}" if ds.type != "mysql" and ds.type != "es" else f"# Table: {obj.table.table_name}"
440448
table_comment = ''
@@ -462,16 +470,36 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
462470
tables.append(t_obj)
463471
all_tables.append(t_obj)
464472

465-
# do table embedding
466-
if embedding and tables and settings.TABLE_EMBEDDING_ENABLED:
467-
tables = calc_table_embedding(tables, question)
473+
# do table selection
474+
used_llm_selection = False # 标记是否使用了 LLM 表选择
475+
if embedding and tables:
476+
if settings.TABLE_LLM_SELECTION_ENABLED and config:
477+
# 使用 LLM 表选择
478+
selected_table_names = calc_table_llm_selection(
479+
config=config,
480+
table_objs=table_objs,
481+
question=question,
482+
ds_table_relation=ds.table_relation,
483+
history_questions=history_questions,
484+
lang=lang,
485+
session=session,
486+
record_id=record_id
487+
)
488+
if selected_table_names:
489+
# 根据选中的表名筛选 tables
490+
selected_table_ids = [table_name_to_obj[name].table.id for name in selected_table_names if name in table_name_to_obj]
491+
tables = [t for t in tables if t.get('id') in selected_table_ids]
492+
used_llm_selection = True # LLM 成功选择了表
493+
elif settings.TABLE_EMBEDDING_ENABLED:
494+
# 使用 RAG 表选择
495+
tables = calc_table_embedding(tables, question, history_questions)
468496
# splice schema
469497
if tables:
470498
for s in tables:
471499
schema_str += s.get('schema_table')
472500

473-
# field relation
474-
if tables and ds.table_relation:
501+
# field relation - LLM 表选择模式下不补全关联表,完全信任 LLM 的选择结果
502+
if tables and ds.table_relation and not used_llm_selection:
475503
relations = list(filter(lambda x: x.get('shape') == 'edge', ds.table_relation))
476504
if relations:
477505
# Complete the missing table
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Author: SQLBot
2+
# Date: 2025/12/23

0 commit comments

Comments
 (0)