Skip to content

Commit e6970be

Browse files
authored
Merge pull request #48 from Serverless-Devs/add-adb-retrieve
feat(tests): add unit tests for knowledgebase module and API
2 parents 84a901c + 5aedc34 commit e6970be

File tree

16 files changed

+5117
-13
lines changed

16 files changed

+5117
-13
lines changed

agentrun/knowledgebase/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""KnowledgeBase 模块 / KnowledgeBase Module"""
22

33
from .api import (
4+
ADBDataAPI,
45
BailianDataAPI,
56
get_data_api,
67
KnowledgeBaseControlAPI,
@@ -10,6 +11,8 @@
1011
from .client import KnowledgeBaseClient
1112
from .knowledgebase import KnowledgeBase
1213
from .model import (
14+
ADBProviderSettings,
15+
ADBRetrieveSettings,
1316
BailianProviderSettings,
1417
BailianRetrieveSettings,
1518
KnowledgeBaseCreateInput,
@@ -33,17 +36,20 @@
3336
"KnowledgeBaseDataAPI",
3437
"RagFlowDataAPI",
3538
"BailianDataAPI",
39+
"ADBDataAPI",
3640
"get_data_api",
3741
# enums
3842
"KnowledgeBaseProvider",
3943
# provider settings
4044
"ProviderSettings",
4145
"RagFlowProviderSettings",
4246
"BailianProviderSettings",
47+
"ADBProviderSettings",
4348
# retrieve settings
4449
"RetrieveSettings",
4550
"RagFlowRetrieveSettings",
4651
"BailianRetrieveSettings",
52+
"ADBRetrieveSettings",
4753
# api model
4854
"KnowledgeBaseCreateInput",
4955
"KnowledgeBaseUpdateInput",

agentrun/knowledgebase/__knowledgebase_async_template.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
from .api.data import get_data_api
1616
from .model import (
17+
ADBProviderSettings,
18+
ADBRetrieveSettings,
1719
BailianProviderSettings,
1820
BailianRetrieveSettings,
1921
KnowledgeBaseCreateInput,
@@ -294,6 +296,54 @@ def _get_data_api(self, config: Optional[Config] = None):
294296
**self.retrieve_settings
295297
)
296298

299+
elif provider == KnowledgeBaseProvider.ADB:
300+
# ADB 设置 / ADB settings
301+
if self.provider_settings:
302+
if isinstance(self.provider_settings, ADBProviderSettings):
303+
converted_provider_settings = self.provider_settings
304+
elif isinstance(self.provider_settings, dict):
305+
# ADB provider_settings 使用 PascalCase 键名,需要转换为 snake_case
306+
# ADB provider_settings uses PascalCase keys, need to convert to snake_case
307+
converted_provider_settings = ADBProviderSettings(
308+
db_instance_id=self.provider_settings.get(
309+
"DBInstanceId", ""
310+
),
311+
namespace=self.provider_settings.get("Namespace", ""),
312+
namespace_password=self.provider_settings.get(
313+
"NamespacePassword", ""
314+
),
315+
embedding_model=self.provider_settings.get(
316+
"EmbeddingModel"
317+
),
318+
metrics=self.provider_settings.get("Metrics"),
319+
metadata=self.provider_settings.get("Metadata"),
320+
)
321+
322+
if self.retrieve_settings:
323+
if isinstance(self.retrieve_settings, ADBRetrieveSettings):
324+
converted_retrieve_settings = self.retrieve_settings
325+
elif isinstance(self.retrieve_settings, dict):
326+
# ADB retrieve_settings 使用 PascalCase 键名,需要转换为 snake_case
327+
# ADB retrieve_settings uses PascalCase keys, need to convert to snake_case
328+
converted_retrieve_settings = ADBRetrieveSettings(
329+
top_k=self.retrieve_settings.get("TopK"),
330+
use_full_text_retrieval=self.retrieve_settings.get(
331+
"UseFullTextRetrieval"
332+
),
333+
rerank_factor=self.retrieve_settings.get(
334+
"RerankFactor"
335+
),
336+
recall_window=self.retrieve_settings.get(
337+
"RecallWindow"
338+
),
339+
hybrid_search=self.retrieve_settings.get(
340+
"HybridSearch"
341+
),
342+
hybrid_search_args=self.retrieve_settings.get(
343+
"HybridSearchArgs"
344+
),
345+
)
346+
297347
return get_data_api(
298348
provider=provider,
299349
knowledge_base_name=self.knowledge_base_name or "",

agentrun/knowledgebase/api/__data_async_template.py

Lines changed: 222 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
提供知识库检索功能的数据链路 API。
44
Provides data API for knowledge base retrieval operations.
55
6-
根据不同的 provider 类型(ragflow / bailian)分发到不同的实现。
7-
Dispatches to different implementations based on provider type (ragflow / bailian).
6+
根据不同的 provider 类型(ragflow / bailian / adb)分发到不同的实现。
7+
Dispatches to different implementations based on provider type (ragflow / bailian / adb).
88
"""
99

1010
from abc import ABC, abstractmethod
1111
from typing import Any, Dict, List, Optional, Union
1212

1313
from alibabacloud_bailian20231229 import models as bailian_models
14+
from alibabacloud_gpdb20160503 import models as gpdb_models
1415
import httpx
1516

1617
from agentrun.utils.config import Config
@@ -19,6 +20,8 @@
1920
from agentrun.utils.log import logger
2021

2122
from ..model import (
23+
ADBProviderSettings,
24+
ADBRetrieveSettings,
2225
BailianProviderSettings,
2326
BailianRetrieveSettings,
2427
KnowledgeBaseProvider,
@@ -347,15 +350,213 @@ async def retrieve_async(
347350
}
348351

349352

353+
class ADBDataAPI(KnowledgeBaseDataAPI, ControlAPI):
    """Data API for ADB (AnalyticDB for PostgreSQL) knowledge bases.

    Retrieval is performed by building a GPDB SDK ``QueryContentRequest``
    and sending it through the GPDB client's ``query_content_async`` call.
    """

    # Optional ADBRetrieveSettings attributes that are copied verbatim into
    # the QueryContent request when they are not None.
    _RETRIEVE_FIELDS = (
        "top_k",
        "use_full_text_retrieval",
        "rerank_factor",
        "recall_window",
        "hybrid_search",
        "hybrid_search_args",
    )

    # Attributes read from every match in the response. A missing attribute
    # becomes None in the result. The order fixes the key order of each
    # match dict.
    _MATCH_FIELDS = (
        "content",
        "score",
        "id",
        "file_name",
        "file_url",
        "metadata",
        "rerank_score",
        "retrieval_source",
    )

    def __init__(
        self,
        knowledge_base_name: str,
        config: Optional[Config] = None,
        provider_settings: Optional[ADBProviderSettings] = None,
        retrieve_settings: Optional[ADBRetrieveSettings] = None,
    ):
        """Initialize the ADB knowledge base data API.

        Args:
            knowledge_base_name: Knowledge base name; sent as the
                ``collection`` of the QueryContent request.
            config: Optional configuration, forwarded to both base classes.
            provider_settings: ADB provider settings (instance id,
                namespace, namespace password, optional metrics).
            retrieve_settings: Optional ADB retrieval tuning settings.
        """
        # Both bases are initialized explicitly, each with the arguments
        # it expects.
        KnowledgeBaseDataAPI.__init__(self, knowledge_base_name, config)
        ControlAPI.__init__(self, config)
        self.provider_settings = provider_settings
        self.retrieve_settings = retrieve_settings

    def _build_query_content_request(
        self, query: str, config: Optional[Config] = None
    ) -> gpdb_models.QueryContentRequest:
        """Build the GPDB QueryContent request for ``query``.

        Args:
            query: Query text.
            config: Optional per-call configuration, merged with the
                instance configuration via ``Config.with_configs``.

        Returns:
            A ``QueryContentRequest`` populated from the provider settings
            plus any retrieve settings that are not None.

        Raises:
            ValueError: If ``provider_settings`` is None.
        """
        settings = self.provider_settings
        if settings is None:
            raise ValueError("provider_settings is required for ADB retrieval")

        cfg = Config.with_configs(self.config, config)

        # Mandatory request parameters.
        request_params: Dict[str, Any] = {
            "content": query,
            "dbinstance_id": settings.db_instance_id,
            "namespace": settings.namespace,
            "namespace_password": settings.namespace_password,
            "collection": self.knowledge_base_name,
            "region_id": cfg.get_region_id(),
        }

        # Optional provider-level setting.
        if settings.metrics is not None:
            request_params["metrics"] = settings.metrics

        # Optional retrieval settings: only explicitly set values are sent,
        # so unset ones are left out of the request.
        if self.retrieve_settings:
            for field in self._RETRIEVE_FIELDS:
                value = getattr(self.retrieve_settings, field)
                if value is not None:
                    request_params[field] = value

        return gpdb_models.QueryContentRequest(**request_params)

    def _parse_query_content_response(
        self, response: gpdb_models.QueryContentResponse, query: str
    ) -> Dict[str, Any]:
        """Convert a QueryContent response into the retrieval result dict.

        Args:
            response: GPDB QueryContent response object.
            query: Original query text, echoed back in the result.

        Returns:
            A dict with ``data`` (list of match dicts), ``query``,
            ``knowledge_base_name`` and ``request_id`` (None when the
            response has no body or the body has no such attribute).
        """
        body = response.body
        all_matches: List[Dict[str, Any]] = []

        if body and body.matches:
            for match in body.matches.match_list or []:
                # getattr with a default tolerates SDK model versions that
                # lack some of these attributes.
                all_matches.append(
                    {
                        field: getattr(match, field, None)
                        for field in self._MATCH_FIELDS
                    }
                )

        return {
            "data": all_matches,
            "query": query,
            "knowledge_base_name": self.knowledge_base_name,
            "request_id": getattr(body, "request_id", None) if body else None,
        }

    async def retrieve_async(
        self,
        query: str,
        config: Optional[Config] = None,
    ) -> Dict[str, Any]:
        """Retrieve from the ADB knowledge base asynchronously.

        Any exception raised while building the request, creating the
        client, calling the API or parsing the response is logged as a
        warning and reported through the returned dict instead of being
        propagated.

        Args:
            query: Query text.
            config: Optional per-call configuration.

        Returns:
            On success, the dict produced by
            ``_parse_query_content_response``. On failure, a dict with
            ``error`` set to True and ``data`` holding the failure message.
        """
        try:
            # Build the request first: it validates provider_settings, so a
            # missing configuration fails before any client is created.
            request = self._build_query_content_request(query, config)

            # NOTE(review): _get_gpdb_client is not defined in this class;
            # presumably inherited from ControlAPI -- confirm.
            client = self._get_gpdb_client(config)

            # Log identifying fields only. The request object carries
            # namespace_password, which must not end up in log output.
            logger.debug(
                "ADB QueryContent request: "
                f"collection='{self.knowledge_base_name}', "
                f"dbinstance_id='{self.provider_settings.db_instance_id}', "
                f"namespace='{self.provider_settings.namespace}', "
                f"query='{query}'"
            )

            response = await client.query_content_async(request)
            logger.debug(f"ADB QueryContent response: {response}")

            return self._parse_query_content_response(response, query)

        except Exception as e:
            # Best-effort contract: report the failure in the result
            # instead of raising to the caller.
            logger.warning(
                "Failed to retrieve from ADB knowledge base "
                f"'{self.knowledge_base_name}': {e}"
            )
            return {
                "data": f"Failed to retrieve: {e}",
                "query": query,
                "knowledge_base_name": self.knowledge_base_name,
                "error": True,
            }
541+
542+
350543
def get_data_api(
351544
provider: KnowledgeBaseProvider,
352545
knowledge_base_name: str,
353546
config: Optional[Config] = None,
354547
provider_settings: Optional[
355-
Union[RagFlowProviderSettings, BailianProviderSettings]
548+
Union[
549+
RagFlowProviderSettings,
550+
BailianProviderSettings,
551+
ADBProviderSettings,
552+
]
356553
] = None,
357554
retrieve_settings: Optional[
358-
Union[RagFlowRetrieveSettings, BailianRetrieveSettings]
555+
Union[
556+
RagFlowRetrieveSettings,
557+
BailianRetrieveSettings,
558+
ADBRetrieveSettings,
559+
]
359560
] = None,
360561
credential_name: Optional[str] = None,
361562
) -> KnowledgeBaseDataAPI:
@@ -410,5 +611,22 @@ def get_data_api(
410611
provider_settings=bailian_provider_settings,
411612
retrieve_settings=bailian_retrieve_settings,
412613
)
614+
elif provider == KnowledgeBaseProvider.ADB or provider == "adb":
615+
adb_provider_settings = (
616+
provider_settings
617+
if isinstance(provider_settings, ADBProviderSettings)
618+
else None
619+
)
620+
adb_retrieve_settings = (
621+
retrieve_settings
622+
if isinstance(retrieve_settings, ADBRetrieveSettings)
623+
else None
624+
)
625+
return ADBDataAPI(
626+
knowledge_base_name,
627+
config,
628+
provider_settings=adb_provider_settings,
629+
retrieve_settings=adb_retrieve_settings,
630+
)
413631
else:
414632
raise ValueError(f"Unsupported provider type: {provider}")

agentrun/knowledgebase/api/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from .control import KnowledgeBaseControlAPI
44
from .data import (
5+
ADBDataAPI,
56
BailianDataAPI,
67
get_data_api,
78
KnowledgeBaseDataAPI,
@@ -15,5 +16,6 @@
1516
"KnowledgeBaseDataAPI",
1617
"RagFlowDataAPI",
1718
"BailianDataAPI",
19+
"ADBDataAPI",
1820
"get_data_api",
1921
]

0 commit comments

Comments
 (0)