Merged
Changes from all commits
@@ -7,6 +7,7 @@
 from typing import Tuple, List, Any, Callable
 
 from fitframework import fit_logger
+from fitframework.api.decorators import fitable
 from llama_index.core.node_parser import (
     SentenceSplitter,
     TokenTextSplitter,
Expand All @@ -17,11 +18,11 @@
from llama_index.core.schema import Document as LDocument
from llama_index.embeddings.openai import OpenAIEmbedding

from .callable_registers import register_callable_tool
from .node_utils import to_llama_index_document
from .types.semantic_splitter_options import SemanticSplitterOptions


def sentence_splitter(text: str, separator: str, chunk_size: int, chunk_overlap: int, **kwargs) -> List[str]:
@fitable("llama.tools.sentence_splitter", "default")
def sentence_splitter(text: str, separator: str, chunk_size: int, chunk_overlap: int) -> List[str]:
"""Parse text with a preference for complete sentences."""
if len(text) == 0:
return []
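The hunk above swaps the old list-based `register_callable_tool` registration for a per-function `@fitable` decorator. The tool's body is mostly collapsed here, but given the module's imports it presumably wraps llama_index's `SentenceSplitter`; a minimal sketch of that underlying call, with illustrative text and parameter values not taken from this PR:

```python
from llama_index.core.node_parser import SentenceSplitter

# Sentence-aware chunking: boundaries prefer sentence ends where possible.
splitter = SentenceSplitter(separator=" ", chunk_size=64, chunk_overlap=8)

text = "LlamaIndex parses documents into chunks. Sentence-aware splitting tries to keep sentences intact."
chunks = splitter.split_text(text)  # returns List[str], like the tool itself
print(chunks)
```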
@@ -38,7 +39,8 @@ def sentence_splitter(text: str, separator: str, chunk_size: int, chunk_overlap:
         return []
 
 
-def token_text_splitter(text: str, separator: str, chunk_size: int, chunk_overlap: int, **kwargs) -> List[str]:
+@fitable("llama.tools.token_text_splitter", "default")
+def token_text_splitter(text: str, separator: str, chunk_size: int, chunk_overlap: int) -> List[str]:
     """Splitting text that looks at word tokens."""
     if len(text) == 0:
         return []
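`token_text_splitter` gets the same treatment: unused `**kwargs` dropped, `@fitable` added. For contrast with the sentence variant, a hedged sketch of the corresponding llama_index call, again with illustrative values:

```python
from llama_index.core.node_parser import TokenTextSplitter

# Token-based chunking: boundaries are chosen purely by token count.
splitter = TokenTextSplitter(separator=" ", chunk_size=32, chunk_overlap=4)
chunks = splitter.split_text("one two three four five " * 20)
print(len(chunks), chunks[0])
```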
@@ -55,14 +57,15 @@ def token_text_splitter(text: str, separator: str, chunk_size: int, chunk_overla
         return []
 
 
-def semantic_splitter(buffer_size: int, breakpoint_percentile_threshold: int, docs: List[LDocument], **kwargs) \
+# @fitable("llama.tools.semantic_splitter", "default")
+def semantic_splitter(buffer_size: int, breakpoint_percentile_threshold: int, docs: List[LDocument], options: SemanticSplitterOptions) \
         -> List[BaseNode]:
     """Splitting text that looks at word tokens."""
     if len(docs) == 0:
         return []
-    api_key = kwargs.get("api_key")
-    model_name = kwargs.get("model_name")
-    api_base = kwargs.get("api_base")
+    api_key = options.api_key
+    model_name = options.model_name
+    api_base = options.api_base
 
     embed_model = OpenAIEmbedding(model_name=model_name, api_base=api_base, api_key=api_key, max_tokens=4096)
 
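Here the loosely-typed `**kwargs` is replaced by a typed `options` parameter. The definition of `SemanticSplitterOptions` lives in `.types.semantic_splitter_options` and is not part of this diff; judging from the three attributes the new code reads, a minimal equivalent would look roughly like this (the field types, and the use of a dataclass, are assumptions):

```python
from dataclasses import dataclass

@dataclass
class SemanticSplitterOptions:
    # Only these three fields are read by semantic_splitter in this diff;
    # the real class may carry more.
    api_key: str
    model_name: str
    api_base: str
```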
@@ -80,8 +83,9 @@ def semantic_splitter(buffer_size: int, breakpoint_percentile_threshold: int, do
         return []
 
 
+# @fitable("llama.tools.sentence_window_node_parser", "default")
 def sentence_window_node_parser(window_size: int, window_metadata_key: str, original_text_metadata_key: str,
-                                docs: List[LDocument], **kwargs) -> List[BaseNode]:
+                                docs: List[LDocument]) -> List[BaseNode]:
     """Splitting text that looks at word tokens."""
     if len(docs) == 0:
         return []
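Note that for the two node-parser variants the `@fitable` lines are added commented out, so only the plain text splitters are actually registered by this change. The body of `sentence_window_node_parser` is collapsed in this view; from its parameter names it presumably wraps llama_index's `SentenceWindowNodeParser`. A sketch with illustrative values:

```python
from llama_index.core.node_parser import SentenceWindowNodeParser
from llama_index.core.schema import Document

parser = SentenceWindowNodeParser.from_defaults(
    window_size=3,                             # sentences of context on each side
    window_metadata_key="window",
    original_text_metadata_key="original_text",
)
nodes = parser.get_nodes_from_documents(
    [Document(text="First sentence. Second sentence. Third sentence.")]
)
print(nodes[0].metadata["window"])  # the surrounding sentence window
```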
@@ -96,26 +100,4 @@ def sentence_window_node_parser(window_size: int, window_metadata_key: str, orig
     except BaseException:
         fit_logger.error("Invoke semantic splitter failed.")
         traceback.print_exc()
-        return []
-
-
-# Tuple structure: (tool_func, config_args, return_description)
-splitter_basic_toolkit: List[Tuple[Callable[..., Any], List[str], str]] = [
-    (sentence_splitter, ["text", "separator", "chunk_size", "chunk_overlap"], "Split sentences by sentence."),
-    (token_text_splitter, ["text", "separator", "chunk_size", "chunk_overlap"], "Split sentences by token."),
-    (semantic_splitter,
-     ["docs", "buffer_size", "breakpoint_percentile_threshold", "chunk_overlap", "model_name", "api_key", "api_base"],
-     "Split sentences by semantic."),
-    (sentence_window_node_parser, ["docs", "window_size", "window_metadata_key", "original_text_metadata_key"],
-     "Splits all documents into individual sentences")
-]
-
-for tool in splitter_basic_toolkit:
-    register_callable_tool(tool, sentence_splitter.__module__, "llama_index.rag.toolkit")
-
-if __name__ == '__main__':
-    import time
-    from .llama_schema_helper import dump_llama_schema
-
-    current_timestamp = time.strftime('%Y%m%d%H%M%S')
-    dump_llama_schema(splitter_basic_toolkit, f"./llama_tool_schema-{str(current_timestamp)}.json")
+        return []
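With the `splitter_basic_toolkit` registration list and the `__main__` schema dump removed, the splitters are no longer invoked from this module; the `@fitable`-registered ones are resolved by the FIT runtime instead. For illustration only, a direct call to the refactored `semantic_splitter` would look roughly like this (credentials and model name are placeholders, and `SemanticSplitterOptions` is constructed per the sketch above):

```python
from llama_index.core.schema import Document as LDocument

# Hypothetical direct invocation; in production the runtime supplies these.
options = SemanticSplitterOptions(
    api_key="sk-...",                     # placeholder credential
    model_name="text-embedding-3-small",  # illustrative model choice
    api_base="https://api.openai.com/v1",
)
nodes = semantic_splitter(
    buffer_size=1,
    breakpoint_percentile_threshold=95,
    docs=[LDocument(text="Semantic splitting groups sentences by embedding similarity.")],
    options=options,
)
```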