|
3 | 3 | 提供知识库检索功能的数据链路 API。 |
4 | 4 | Provides data API for knowledge base retrieval operations. |
5 | 5 |
|
6 | | -根据不同的 provider 类型(ragflow / bailian)分发到不同的实现。 |
7 | | -Dispatches to different implementations based on provider type (ragflow / bailian). |
| 6 | +根据不同的 provider 类型(ragflow / bailian / adb)分发到不同的实现。 |
| 7 | +Dispatches to different implementations based on provider type (ragflow / bailian / adb). |
8 | 8 | """ |
9 | 9 |
|
10 | 10 | from abc import ABC, abstractmethod |
11 | 11 | from typing import Any, Dict, List, Optional, Union |
12 | 12 |
|
13 | 13 | from alibabacloud_bailian20231229 import models as bailian_models |
| 14 | +from alibabacloud_gpdb20160503 import models as gpdb_models |
14 | 15 | import httpx |
15 | 16 |
|
16 | 17 | from agentrun.utils.config import Config |
|
19 | 20 | from agentrun.utils.log import logger |
20 | 21 |
|
21 | 22 | from ..model import ( |
| 23 | + ADBProviderSettings, |
| 24 | + ADBRetrieveSettings, |
22 | 25 | BailianProviderSettings, |
23 | 26 | BailianRetrieveSettings, |
24 | 27 | KnowledgeBaseProvider, |
@@ -347,15 +350,213 @@ async def retrieve_async( |
347 | 350 | } |
348 | 351 |
|
349 | 352 |
|
class ADBDataAPI(KnowledgeBaseDataAPI, ControlAPI):
    """Data API for ADB (AnalyticDB for PostgreSQL) knowledge bases.

    Retrieval is performed by building a GPDB SDK ``QueryContentRequest``,
    sending it through the GPDB client obtained from ``_get_gpdb_client``,
    and flattening the response into a plain dictionary.
    """

    # Optional retrieve settings. Each name is both the attribute read from
    # ``retrieve_settings`` and the keyword passed to ``QueryContentRequest``;
    # a value of ``None`` means "not set" and the keyword is omitted.
    _OPTIONAL_RETRIEVE_FIELDS = (
        "top_k",
        "use_full_text_retrieval",
        "rerank_factor",
        "recall_window",
        "hybrid_search",
        "hybrid_search_args",
    )

    # Attributes copied from every match into the result dictionary, in
    # output order. A match lacking an attribute yields ``None`` for it.
    _MATCH_FIELDS = (
        "content",
        "score",
        "id",
        "file_name",
        "file_url",
        "metadata",
        "rerank_score",
        "retrieval_source",
    )

    def __init__(
        self,
        knowledge_base_name: str,
        config: Optional[Config] = None,
        provider_settings: Optional[ADBProviderSettings] = None,
        retrieve_settings: Optional[ADBRetrieveSettings] = None,
    ):
        """Initialize the ADB knowledge base data API.

        Args:
            knowledge_base_name: Knowledge base name; sent as the
                ``collection`` of every QueryContent request.
            config: Optional configuration, forwarded to both base classes.
            provider_settings: ADB provider settings. May be omitted here,
                but retrieval fails without them.
            retrieve_settings: Optional ADB retrieve settings.
        """
        # Both bases are initialized explicitly, each with its own arguments.
        KnowledgeBaseDataAPI.__init__(self, knowledge_base_name, config)
        ControlAPI.__init__(self, config)
        self.provider_settings = provider_settings
        self.retrieve_settings = retrieve_settings

    def _build_query_content_request(
        self, query: str, config: Optional[Config] = None
    ) -> gpdb_models.QueryContentRequest:
        """Build the QueryContent request for a query.

        Args:
            query: Query text.
            config: Optional per-call configuration, merged with the
                instance configuration via ``Config.with_configs``.

        Returns:
            The populated GPDB ``QueryContentRequest``.

        Raises:
            ValueError: If ``provider_settings`` is ``None``.
        """
        provider = self.provider_settings
        if provider is None:
            raise ValueError("provider_settings is required for ADB retrieval")

        cfg = Config.with_configs(self.config, config)

        # Mandatory request parameters.
        request_params: Dict[str, Any] = {
            "content": query,
            "dbinstance_id": provider.db_instance_id,
            "namespace": provider.namespace,
            "namespace_password": provider.namespace_password,
            "collection": self.knowledge_base_name,
            "region_id": cfg.get_region_id(),
        }

        # Optional provider setting.
        if provider.metrics is not None:
            request_params["metrics"] = provider.metrics

        # Optional retrieve settings: forward only the ones explicitly set.
        settings = self.retrieve_settings
        if settings:
            for field in self._OPTIONAL_RETRIEVE_FIELDS:
                value = getattr(settings, field)
                if value is not None:
                    request_params[field] = value

        return gpdb_models.QueryContentRequest(**request_params)

    def _parse_query_content_response(
        self, response: gpdb_models.QueryContentResponse, query: str
    ) -> Dict[str, Any]:
        """Convert a QueryContent response into a result dictionary.

        Args:
            response: GPDB ``QueryContentResponse``.
            query: Original query text, echoed back in the result.

        Returns:
            A dictionary with keys ``data`` (list of match dictionaries,
            empty when the response has no matches), ``query``,
            ``knowledge_base_name`` and ``request_id`` (``None`` when the
            response has no body or the body has no such attribute).
        """
        body = response.body
        matches = body.matches if body else None
        match_list = (matches.match_list or []) if matches else []

        all_matches: List[Dict[str, Any]] = [
            {field: getattr(match, field, None) for field in self._MATCH_FIELDS}
            for match in match_list
        ]

        return {
            "data": all_matches,
            "query": query,
            "knowledge_base_name": self.knowledge_base_name,
            "request_id": getattr(body, "request_id", None) if body else None,
        }

    async def retrieve_async(
        self,
        query: str,
        config: Optional[Config] = None,
    ) -> Dict[str, Any]:
        """Retrieve from the ADB knowledge base asynchronously.

        Calls the QueryContent operation through the GPDB SDK client.
        Failures are not raised: any exception is logged as a warning and
        reported through the returned dictionary.

        Args:
            query: Query text.
            config: Optional per-call configuration.

        Returns:
            On success, the dictionary produced by
            ``_parse_query_content_response``. On failure, a dictionary with
            ``data`` set to an error message string, ``query``,
            ``knowledge_base_name`` and ``error`` set to ``True``.
        """
        try:
            # Checked before the client is requested so that a missing
            # configuration never reaches the SDK.
            if self.provider_settings is None:
                raise ValueError(
                    "provider_settings is required for ADB retrieval"
                )

            client = self._get_gpdb_client(config)

            request = self._build_query_content_request(query, config)
            # The request carries namespace_password, so the request object
            # itself is never logged; only non-secret identifiers are.
            logger.debug(
                "ADB QueryContent request: "
                f"collection={self.knowledge_base_name}, "
                f"dbinstance_id={self.provider_settings.db_instance_id}, "
                f"namespace={self.provider_settings.namespace}"
            )

            response = await client.query_content_async(request)
            logger.debug(f"ADB QueryContent response: {response}")

            return self._parse_query_content_response(response, query)

        except Exception as e:
            # Deliberate best-effort: callers get an error payload instead
            # of an exception.
            logger.warning(
                "Failed to retrieve from ADB knowledge base "
                f"'{self.knowledge_base_name}': {e}"
            )
            return {
                "data": f"Failed to retrieve: {e}",
                "query": query,
                "knowledge_base_name": self.knowledge_base_name,
                "error": True,
            }
| 541 | + |
| 542 | + |
350 | 543 | def get_data_api( |
351 | 544 | provider: KnowledgeBaseProvider, |
352 | 545 | knowledge_base_name: str, |
353 | 546 | config: Optional[Config] = None, |
354 | 547 | provider_settings: Optional[ |
355 | | - Union[RagFlowProviderSettings, BailianProviderSettings] |
| 548 | + Union[ |
| 549 | + RagFlowProviderSettings, |
| 550 | + BailianProviderSettings, |
| 551 | + ADBProviderSettings, |
| 552 | + ] |
356 | 553 | ] = None, |
357 | 554 | retrieve_settings: Optional[ |
358 | | - Union[RagFlowRetrieveSettings, BailianRetrieveSettings] |
| 555 | + Union[ |
| 556 | + RagFlowRetrieveSettings, |
| 557 | + BailianRetrieveSettings, |
| 558 | + ADBRetrieveSettings, |
| 559 | + ] |
359 | 560 | ] = None, |
360 | 561 | credential_name: Optional[str] = None, |
361 | 562 | ) -> KnowledgeBaseDataAPI: |
@@ -410,5 +611,22 @@ def get_data_api( |
410 | 611 | provider_settings=bailian_provider_settings, |
411 | 612 | retrieve_settings=bailian_retrieve_settings, |
412 | 613 | ) |
| 614 | + elif provider == KnowledgeBaseProvider.ADB or provider == "adb": |
| 615 | + adb_provider_settings = ( |
| 616 | + provider_settings |
| 617 | + if isinstance(provider_settings, ADBProviderSettings) |
| 618 | + else None |
| 619 | + ) |
| 620 | + adb_retrieve_settings = ( |
| 621 | + retrieve_settings |
| 622 | + if isinstance(retrieve_settings, ADBRetrieveSettings) |
| 623 | + else None |
| 624 | + ) |
| 625 | + return ADBDataAPI( |
| 626 | + knowledge_base_name, |
| 627 | + config, |
| 628 | + provider_settings=adb_provider_settings, |
| 629 | + retrieve_settings=adb_retrieve_settings, |
| 630 | + ) |
413 | 631 | else: |
414 | 632 | raise ValueError(f"Unsupported provider type: {provider}") |
0 commit comments