import os
import traceback
from enum import Enum, unique
from typing import List

from fitframework import fit_logger, fitable
from llama_index.core.base.base_selector import SingleSelection
from llama_index.core.postprocessor import SimilarityPostprocessor, SentenceEmbeddingOptimizer, LLMRerank, \
LongContextReorder, FixedRecencyPostprocessor
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.prompts import PromptType, PromptTemplate
from llama_index.core.selectors import LLMSingleSelector, LLMMultiSelector
from llama_index.core.selectors.prompts import DEFAULT_SINGLE_SELECT_PROMPT_TMPL, DEFAULT_MULTI_SELECT_PROMPT_TMPL
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI

from .types.document import Document
from .types.llm_rerank_options import LLMRerankOptions
from .types.embedding_options import EmbeddingOptions
from .types.retriever_options import RetrieverOptions
from .types.llm_choice_selector_options import LLMChoiceSelectorOptions
from .node_utils import document_to_query_node, query_node_to_document

os.environ["no_proxy"] = "*"


def __invoke_postprocessor(postprocessor: BaseNodePostprocessor, nodes: List[Document],
                           query_str: str) -> List[Document]:
    # ... (body not shown in this diff excerpt)
    return nodes


@fitable("llama.tools.similarity_filter", "default")
def similarity_filter(nodes: List[Document], query_str: str, options: RetrieverOptions) -> List[Document]:
    """Remove documents that are below a similarity score threshold."""
    if options is None:
        options = RetrieverOptions()
    postprocessor = SimilarityPostprocessor(similarity_cutoff=options.similarity_cutoff)
    return __invoke_postprocessor(postprocessor, nodes, query_str)
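
# Usage sketch (hypothetical): drop weakly related chunks before answer synthesis.
# It relies only on the `similarity_cutoff` field read above; `docs` stands for
# documents produced by an earlier retrieval step.
#
#     options = RetrieverOptions()
#     options.similarity_cutoff = 0.5
#     kept = similarity_filter(docs, "How do I configure the reranker?", options)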


@fitable("llama.tools.sentence_embedding_optimizer", "default")
def sentence_embedding_optimizer(nodes: List[Document], query_str: str, options: EmbeddingOptions) -> List[Document]:
    """Optimize a text chunk given the query by shortening the input text."""
    if options is None:
        options = EmbeddingOptions()
    embed_model = OpenAIEmbedding(model_name=options.model_name, api_base=options.api_base, api_key=options.api_key)
    optimizer = SentenceEmbeddingOptimizer(embed_model=embed_model, percentile_cutoff=options.percentile_cutoff,
                                           threshold_cutoff=options.threshold_cutoff)
    return __invoke_postprocessor(optimizer, nodes, query_str)
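
# Usage sketch (hypothetical): trim each chunk to its most query-relevant sentences
# via an OpenAI-compatible embedding endpoint. The attribute names mirror the fields
# read above; the endpoint URL below is a placeholder, not a project default.
#
#     options = EmbeddingOptions()
#     options.model_name = "bce-embedding-base_v1"
#     options.api_base = "http://localhost:8010/v1"
#     options.api_key = "EMPTY"
#     options.percentile_cutoff = 0.5  # keep roughly the top half of the sentences
#     trimmed = sentence_embedding_optimizer(docs, query_str, options)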


@fitable("llama.tools.llm_rerank", "default")
def llm_rerank(nodes: List[Document], query_str: str, options: LLMRerankOptions) -> List[Document]:
    """
    Re-order nodes by asking the LLM to return the relevant documents and a score of how relevant they are.
    Returns the top N ranked nodes.
    """
    if options is None:
        options = LLMRerankOptions()
    llm = OpenAI(model=options.model_name, api_base=options.api_base, api_key=options.api_key)
    choice_select_prompt = PromptTemplate(options.prompt, prompt_type=PromptType.CHOICE_SELECT)
    llm_rerank_obj = LLMRerank(llm=llm, choice_select_prompt=choice_select_prompt,
                               choice_batch_size=options.choice_batch_size,
                               top_n=options.top_n)
    return __invoke_postprocessor(llm_rerank_obj, nodes, query_str)
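
# Usage sketch (hypothetical): have an OpenAI-compatible chat model score the retrieved
# chunks against the query and keep the top N. The endpoint and model name below are
# placeholders; `options.prompt` is assumed to carry a choice-select template by default.
#
#     options = LLMRerankOptions()
#     options.model_name = "Qwen1.5-14B-Chat"
#     options.api_base = "http://localhost:8000/v1"
#     options.api_key = "EMPTY"
#     options.top_n = 5
#     reranked = llm_rerank(docs, query_str, options)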


@fitable("llama.tools.long_context_rerank", "default")
def long_context_rerank(nodes: List[Document], query_str: str) -> List[Document]:
    """Re-order the retrieved nodes, which can be helpful in cases where a large top-k is needed."""
    return __invoke_postprocessor(LongContextReorder(), nodes, query_str)
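
# Usage sketch (hypothetical): counteract the "lost in the middle" effect by moving the
# most relevant nodes toward the start and end of the context window.
#
#     reordered = long_context_rerank(docs, query_str)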


@unique
class SelectorMode(Enum):
    SINGLE = "single"
    MULTI = "multi"


@fitable("llama.tools.llm_choice_selector", "default")
def llm_choice_selector(choice: List[str], query_str: str, options: LLMChoiceSelectorOptions) -> List[SingleSelection]:
    """LLM-based selector that chooses one or multiple out of many options."""
    if len(choice) == 0:
        return []
    if options is None:
        options = LLMChoiceSelectorOptions()
    if options.mode.lower() not in [m.value for m in SelectorMode]:
        raise ValueError(f"Invalid mode {options.mode}.")

    llm = OpenAI(model=options.model_name, api_base=options.api_base, api_key=options.api_key, max_tokens=4096)
    if options.mode.lower() == SelectorMode.SINGLE.value:
        selector_prompt = options.prompt or DEFAULT_SINGLE_SELECT_PROMPT_TMPL
        selector = LLMSingleSelector.from_defaults(llm=llm, prompt_template_str=selector_prompt)
    else:
        multi_selector_prompt = options.prompt or DEFAULT_MULTI_SELECT_PROMPT_TMPL
        selector = LLMMultiSelector.from_defaults(llm=llm, prompt_template_str=multi_selector_prompt)
    try:
        return selector.select(choice, query_str).selections
    except Exception:
        # (exception logging not shown in this diff excerpt)
        return []
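
# Usage sketch (hypothetical): route a question to one of several candidate tools or
# indices. "single" returns exactly one selection, "multi" may return several; each
# SingleSelection carries the chosen index and the LLM's reasoning.
#
#     options = LLMChoiceSelectorOptions()
#     options.mode = "multi"
#     selections = llm_choice_selector(
#         ["weather lookup", "document search", "calculator"],
#         "What will the temperature be in Paris tomorrow?",
#         options,
#     )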


@fitable("llama.tools.fixed_recency", "default")
def fixed_recency(nodes: List[Document], top_k: int, date_key: str, query_str: str) -> List[Document]:
    """Return the top K nodes sorted by date."""
    postprocessor = FixedRecencyPostprocessor(
        top_k=top_k, date_key=date_key if date_key else "date"
    )
    return __invoke_postprocessor(postprocessor, nodes, query_str)
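
# Usage sketch (hypothetical): keep only the three most recent documents, assuming each
# document carries a date under the "date" metadata key (the fallback used above).
#
#     freshest = fixed_recency(docs, 3, "date", query_str)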

