diff --git a/baselines/BDS/bds.py b/baselines/BDS/bds.py index ea8bf45f..5fe82f2c 100644 --- a/baselines/BDS/bds.py +++ b/baselines/BDS/bds.py @@ -9,7 +9,7 @@ from graphgen.bases import BaseLLMWrapper from graphgen.common import init_llm -from graphgen.models import NetworkXStorage +from graphgen.storage import NetworkXStorage from graphgen.utils import create_event_loop QA_GENERATION_PROMPT = """ diff --git a/examples/evaluate/evaluate_kg/kg_evaluation_config.yaml b/examples/evaluate/evaluate_kg/kg_evaluation_config.yaml index d86d01b1..cfe0138c 100644 --- a/examples/evaluate/evaluate_kg/kg_evaluation_config.yaml +++ b/examples/evaluate/evaluate_kg/kg_evaluation_config.yaml @@ -10,7 +10,7 @@ nodes: dependencies: [] params: input_path: - - examples/input_examples/extract_demo.txt + - examples/input_examples/jsonl_demo.jsonl - id: chunk op_name: chunk @@ -39,7 +39,6 @@ nodes: dependencies: - build_kg params: + target: kg metrics: - - kg_structure - - kg_accuracy - - kg_consistency + - structure diff --git a/examples/evaluate/evaluate_qa/qa_evaluation_config.yaml b/examples/evaluate/evaluate_qa/qa_evaluation_config.yaml index 3e875143..e25b902f 100644 --- a/examples/evaluate/evaluate_qa/qa_evaluation_config.yaml +++ b/examples/evaluate/evaluate_qa/qa_evaluation_config.yaml @@ -1,7 +1,7 @@ global_params: working_dir: cache - graph_backend: kuzu # graph database backend, support: kuzu, networkx - kv_backend: rocksdb # key-value store backend, support: rocksdb, json_kv + graph_backend: networkx # graph database backend, support: kuzu, networkx + kv_backend: json_kv # key-value store backend, support: rocksdb, json_kv nodes: - id: read_files # id is unique in the pipeline, and can be referenced by other steps @@ -89,10 +89,11 @@ nodes: batch_size: 128 save_output: true params: + target: qa metrics: - - qa_length - - qa_mtld - # - qa_reward_score - # - qa_uni_score + - length + - mtld + # - reward_score + # - uni_score mtld_params: threshold: 0.7 diff --git a/examples/evaluate/evaluate_triple/evaluate_triple.sh b/examples/evaluate/evaluate_triple/evaluate_triple.sh new file mode 100644 index 00000000..14b6a5ce --- /dev/null +++ b/examples/evaluate/evaluate_triple/evaluate_triple.sh @@ -0,0 +1,2 @@ +python3 -m graphgen.run \ +--config_file examples/evaluate/evaluate_triple/triple_evaluation_config.yaml \ No newline at end of file diff --git a/examples/evaluate/evaluate_triple/triple_evaluation_config.yaml b/examples/evaluate/evaluate_triple/triple_evaluation_config.yaml new file mode 100644 index 00000000..f9cdbc4a --- /dev/null +++ b/examples/evaluate/evaluate_triple/triple_evaluation_config.yaml @@ -0,0 +1,46 @@ +global_params: + working_dir: cache + graph_backend: networkx # graph database backend, support: kuzu, networkx + kv_backend: json_kv # key-value store backend, support: rocksdb, json_kv + +nodes: + - id: read + op_name: read + type: source + dependencies: [] + params: + input_path: + - examples/input_examples/jsonl_demo.jsonl + + - id: chunk + op_name: chunk + type: map_batch + dependencies: + - read + execution_params: + replicas: 4 + params: + chunk_size: 20480 # larger chunk size for better context + chunk_overlap: 2000 + + - id: build_kg + op_name: build_kg + type: map_batch + dependencies: + - chunk + execution_params: + replicas: 1 + batch_size: 128 + + - id: evaluate + op_name: evaluate + type: aggregate + save_output: true + dependencies: + - build_kg + params: + target: triple + src_namespace: chunk + tgt_namespace: build_kg + metrics: + - accuracy diff --git a/graphgen/bases/__init__.py b/graphgen/bases/__init__.py index 0727b3fa..ab143b44 100644 --- a/graphgen/bases/__init__.py +++ b/graphgen/bases/__init__.py @@ -1,3 +1,4 @@ +from .base_evaluator import BaseKGEvaluator, BaseQAEvaluator, BaseTripleEvaluator from .base_extractor import BaseExtractor from .base_generator import BaseGenerator from .base_kg_builder import BaseKGBuilder @@ -9,5 +10,4 @@ from .base_splitter import BaseSplitter from .base_storage import BaseGraphStorage, BaseKVStorage, StorageNameSpace from .base_tokenizer import BaseTokenizer -from .base_evaluator import BaseEvaluator from .datatypes import Chunk, Config, Node, QAPair, Token diff --git a/graphgen/bases/base_evaluator.py b/graphgen/bases/base_evaluator.py index 3cc5df18..5164d3f3 100644 --- a/graphgen/bases/base_evaluator.py +++ b/graphgen/bases/base_evaluator.py @@ -1,10 +1,29 @@ from abc import ABC, abstractmethod +from typing import Any + +from .base_storage import BaseGraphStorage from .datatypes import QAPair -class BaseEvaluator(ABC): +class BaseQAEvaluator(ABC): @abstractmethod - def evaluate(self, pair: QAPair) -> float: + async def evaluate(self, pair: QAPair) -> dict[str, float]: """ Evaluate the text and return a score. """ + + +class BaseKGEvaluator(ABC): + @abstractmethod + def evaluate(self, kg: BaseGraphStorage) -> dict[str, Any]: + """ + Evaluate the whole graph and return a dict of scores. + """ + + +class BaseTripleEvaluator(ABC): + @abstractmethod + async def evaluate(self, unit: dict) -> dict[str, float]: + """ + Evaluate a node/edge and return a score. + """ diff --git a/graphgen/bases/base_generator.py b/graphgen/bases/base_generator.py index b0186167..eb204535 100644 --- a/graphgen/bases/base_generator.py +++ b/graphgen/bases/base_generator.py @@ -21,7 +21,7 @@ def build_prompt( @staticmethod @abstractmethod - def parse_response(response: str) -> Any: + def parse_response(response: str) -> list[dict]: """Parse the LLM response and return the generated QAs""" async def generate( @@ -29,64 +29,49 @@ async def generate( batch: tuple[ list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]] ], - ) -> dict[str, Any]: + ) -> list[dict]: """ Generate QAs based on a given batch. :param batch :return: QA pairs """ - result = {} prompt = self.build_prompt(batch) response = await self.llm_client.generate_answer(prompt) qa_pairs = self.parse_response(response) # generate one or more QA pairs - result.update(qa_pairs) - return result + return qa_pairs @staticmethod def format_generation_results( - results: list[dict], output_data_format: str - ) -> list[dict[str, Any]]: + result: dict, output_data_format: str + ) -> dict[str, Any]: + question = result.get("question", "") + answer = result.get("answer", "") + if "options" in result and result["options"]: + options = result["options"] + options_str = "\n".join( + [f"{key}. {options[key]}" for key in sorted(options.keys())] + ) + question += f"\nOptions:\n{options_str}" - flat_results = [] - for item in results: - for _, qa_data in item.items(): - question = qa_data.get("question", "") - answer = qa_data.get("answer", "") - if "options" in qa_data and qa_data["options"]: - options = qa_data["options"] - options_str = "\n".join( - [f"{key}. {options[key]}" for key in sorted(options.keys())] - ) - question += f"\nOptions:\n{options_str}" + if output_data_format == "Alpaca": + return { + "instruction": question, + "input": "", + "output": answer, + } - if output_data_format == "Alpaca": - flat_results.append( - { - "instruction": question, - "input": "", - "output": answer, - } - ) - elif output_data_format == "Sharegpt": - flat_results.append( - { - "conversations": [ - {"from": "human", "value": question}, - {"from": "gpt", "value": answer}, - ] - } - ) - elif output_data_format == "ChatML": - flat_results.append( - { - "messages": [ - {"role": "user", "content": question}, - {"role": "assistant", "content": answer}, - ] - } - ) - else: - raise ValueError( - f"Unknown output data format: {output_data_format}" - ) - return flat_results + if output_data_format == "Sharegpt": + return { + "conversations": [ + {"from": "human", "value": question}, + {"from": "gpt", "value": answer}, + ] + } + if output_data_format == "ChatML": + return { + "messages": [ + {"role": "user", "content": question}, + {"role": "assistant", "content": answer}, + ] + } + raise ValueError(f"Unknown output data format: {output_data_format}") diff --git a/graphgen/bases/base_operator.py b/graphgen/bases/base_operator.py index be4c737e..70e19ce1 100644 --- a/graphgen/bases/base_operator.py +++ b/graphgen/bases/base_operator.py @@ -1,19 +1,43 @@ import inspect import os from abc import ABC, abstractmethod -from typing import Iterable, Union +from typing import Iterable, Tuple, Union +import numpy as np import pandas as pd import ray +def convert_to_serializable(obj): + if isinstance(obj, np.ndarray): + return obj.tolist() + if isinstance(obj, np.generic): + return obj.item() + if isinstance(obj, dict): + return {k: convert_to_serializable(v) for k, v in obj.items()} + if isinstance(obj, list): + return [convert_to_serializable(v) for v in obj] + return obj + + class BaseOperator(ABC): - def __init__(self, working_dir: str = "cache", op_name: str = None): + def __init__( + self, + working_dir: str = "cache", + kv_backend: str = "rocksdb", + op_name: str = None, + ): # lazy import to avoid circular import + from graphgen.common import init_storage from graphgen.utils import set_logger log_dir = os.path.join(working_dir, "logs") self.op_name = op_name or self.__class__.__name__ + self.working_dir = working_dir + self.kv_backend = kv_backend + self.kv_storage = init_storage( + backend=kv_backend, working_dir=working_dir, namespace=self.op_name + ) try: ctx = ray.get_runtime_context() @@ -45,17 +69,94 @@ def __call__( logger_token = CURRENT_LOGGER_VAR.set(self.logger) try: - result = self.process(batch) + self.kv_storage.reload() + to_process, recovered = self.split(batch) + # yield recovered chunks first + if not recovered.empty: + yield recovered + + if to_process.empty: + return + + data = to_process.to_dict(orient="records") + result, meta_update = self.process(data) if inspect.isgenerator(result): - yield from result + is_first = True + for res in result: + yield pd.DataFrame([res]) + self.store([res], meta_update if is_first else {}) + is_first = False else: - yield result + yield pd.DataFrame(result) + self.store(result, meta_update) finally: CURRENT_LOGGER_VAR.reset(logger_token) - @abstractmethod - def process(self, batch): - raise NotImplementedError("Subclasses must implement the process method.") - def get_logger(self): return self.logger + + def get_meta_forward(self): + return self.kv_storage.get_by_id("_meta_forward") or {} + + def get_meta_inverse(self): + return self.kv_storage.get_by_id("_meta_inverse") or {} + + def get_trace_id(self, content: dict) -> str: + from graphgen.utils import compute_dict_hash + + return compute_dict_hash(content, prefix=f"{self.op_name}-") + + def split(self, batch: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]: + """ + Split the input batch into to_process & processed based on _meta data in KV_storage + :param batch + :return: + to_process: DataFrame of documents to be chunked + recovered: Result DataFrame of already chunked documents + """ + meta_forward = self.get_meta_forward() + meta_ids = set(meta_forward.keys()) + mask = batch["_trace_id"].isin(meta_ids) + to_process = batch[~mask] + processed = batch[mask] + + if processed.empty: + return to_process, pd.DataFrame() + + all_ids = [ + pid for tid in processed["_trace_id"] for pid in meta_forward.get(tid, []) + ] + + recovered_chunks = self.kv_storage.get_by_ids(all_ids) + recovered_chunks = [c for c in recovered_chunks if c is not None] + return to_process, pd.DataFrame(recovered_chunks) + + def store(self, results: list, meta_update: dict): + results = convert_to_serializable(results) + meta_update = convert_to_serializable(meta_update) + + batch = {res["_trace_id"]: res for res in results} + self.kv_storage.upsert(batch) + + # update forward meta + forward_meta = self.get_meta_forward() + forward_meta.update(meta_update) + self.kv_storage.update({"_meta_forward": forward_meta}) + + # update inverse meta + inverse_meta = self.get_meta_inverse() + for k, v_list in meta_update.items(): + for v in v_list: + inverse_meta[v] = k + self.kv_storage.update({"_meta_inverse": inverse_meta}) + self.kv_storage.index_done_callback() + + @abstractmethod + def process(self, batch: list) -> Tuple[Union[list, Iterable[dict]], dict]: + """ + Process the input batch and return the result. + :param batch + :return: + result: DataFrame of processed documents + meta_update: dict of meta data to be updated + """ diff --git a/graphgen/bases/base_storage.py b/graphgen/bases/base_storage.py index be50c2c7..6d91309b 100644 --- a/graphgen/bases/base_storage.py +++ b/graphgen/bases/base_storage.py @@ -39,6 +39,12 @@ def filter_keys(self, data: list[str]) -> set[str]: def upsert(self, data: dict[str, T]): raise NotImplementedError + def update(self, data: dict[str, T]): + raise NotImplementedError + + def delete(self, ids: list[str]): + raise NotImplementedError + def drop(self): raise NotImplementedError diff --git a/graphgen/bases/datatypes.py b/graphgen/bases/datatypes.py index 01d3f963..a7fc471f 100644 --- a/graphgen/bases/datatypes.py +++ b/graphgen/bases/datatypes.py @@ -31,6 +31,13 @@ class QAPair: question: str answer: str + @staticmethod + def from_dict(data: dict) -> "QAPair": + return QAPair( + question=data.get("question", ""), + answer=data.get("answer", ""), + ) + @dataclass class Token: diff --git a/graphgen/common/init_storage.py b/graphgen/common/init_storage.py index 226aeb5c..3e32371f 100644 --- a/graphgen/common/init_storage.py +++ b/graphgen/common/init_storage.py @@ -8,11 +8,11 @@ class KVStorageActor: def __init__(self, backend: str, working_dir: str, namespace: str): if backend == "json_kv": - from graphgen.models import JsonKVStorage + from graphgen.storage import JsonKVStorage self.kv = JsonKVStorage(working_dir, namespace) elif backend == "rocksdb": - from graphgen.models import RocksDBKVStorage + from graphgen.storage import RocksDBKVStorage self.kv = RocksDBKVStorage(working_dir, namespace) else: @@ -42,6 +42,12 @@ def filter_keys(self, data: list[str]) -> set[str]: def upsert(self, data: dict) -> dict: return self.kv.upsert(data) + def update(self, data: dict): + return self.kv.update(data) + + def delete(self, ids: list[str]): + return self.kv.delete(ids) + def drop(self): return self.kv.drop() @@ -55,11 +61,11 @@ def ready(self) -> bool: class GraphStorageActor: def __init__(self, backend: str, working_dir: str, namespace: str): if backend == "networkx": - from graphgen.models import NetworkXStorage + from graphgen.storage import NetworkXStorage self.graph = NetworkXStorage(working_dir, namespace) elif backend == "kuzu": - from graphgen.models import KuzuStorage + from graphgen.storage import KuzuStorage self.graph = KuzuStorage(working_dir, namespace) else: @@ -168,6 +174,12 @@ def filter_keys(self, data: list[str]) -> set[str]: def upsert(self, data: Dict[str, Any]): return ray.get(self.actor.upsert.remote(data)) + def update(self, data: Dict[str, Any]): + return ray.get(self.actor.update.remote(data)) + + def delete(self, ids: list[str]): + return ray.get(self.actor.delete.remote(ids)) + def drop(self): return ray.get(self.actor.drop.remote()) diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index 94fa1b30..bb708c15 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -1,6 +1,5 @@ from .evaluator import ( AccuracyEvaluator, - ConsistencyEvaluator, LengthEvaluator, MTLDEvaluator, RewardEvaluator, @@ -44,11 +43,4 @@ from .searcher.web.bing_search import BingSearch from .searcher.web.google_search import GoogleSearch from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter -from .storage import ( - JsonKVStorage, - KuzuStorage, - NetworkXStorage, - RocksDBCache, - RocksDBKVStorage, -) from .tokenizer import Tokenizer diff --git a/graphgen/models/evaluator/__init__.py b/graphgen/models/evaluator/__init__.py index 6091aeb5..4b7f97d0 100644 --- a/graphgen/models/evaluator/__init__.py +++ b/graphgen/models/evaluator/__init__.py @@ -1,2 +1,3 @@ -from .kg import AccuracyEvaluator, ConsistencyEvaluator, StructureEvaluator +from .kg import StructureEvaluator from .qa import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator +from .triple import AccuracyEvaluator diff --git a/graphgen/models/evaluator/kg/README.md b/graphgen/models/evaluator/kg/README.md deleted file mode 100644 index 10e26f6b..00000000 --- a/graphgen/models/evaluator/kg/README.md +++ /dev/null @@ -1,237 +0,0 @@ -# KG Quality Evaluation Module - -This module provides comprehensive quality evaluation for knowledge graphs built by GraphGen. - -## Module Structure - -The evaluation functionality is organized into modular components: - -- **`accuracy_evaluator.py`**: Entity/relation extraction quality evaluation using LLM-as-a-Judge -- **`consistency_evaluator.py`**: Attribute value conflict detection -- **`structure_evaluator.py`**: Graph structural robustness metrics - -The evaluation components are integrated in `graphgen/operators/evaluate/evaluate_kg.py`, which provides functions to create and use these evaluators. - -## Features - -### 1. Accuracy Assessment -- **Entity Extraction Quality**: Uses LLM-as-a-Judge to evaluate the quality of entity extraction from chunks - - Evaluates accuracy (correctness of extracted entities) - - Evaluates completeness (whether important entities are missed) - - Evaluates precision (naming accuracy and specificity) -- **Relation Extraction Quality**: Uses LLM-as-a-Judge to evaluate the quality of relation extraction from chunks - - Evaluates accuracy (correctness of extracted relations) - - Evaluates completeness (whether important relations are missed) - - Evaluates precision (relation description accuracy) -- Provides multi-dimensional quality scores (0-1 scale) with detailed reasoning for each chunk - -### 2. Consistency Assessment -- **Semantic Conflict Detection**: Uses LLM-as-a-Judge to detect semantic conflicts in entity attributes - - **Entity Type Conflicts**: Detects when the same entity is extracted with different types across chunks - - **Entity Description Conflicts**: Detects when entity descriptions from different chunks are semantically inconsistent - - **Relation Conflicts**: Detects when the same entity pair has conflicting relation descriptions -- Only evaluates entities with multiple source chunks (entities appearing in multiple chunks) -- Uses LLM to extract entity attributes from each chunk and compare them semantically -- Calculates conflict rate: `conflict_entities_count / total_entities` -- Returns detailed conflict information including conflict severity and reasoning - -### 3. Structural Robustness Assessment -- **Noise Ratio**: Isolated nodes / total nodes (threshold: < 15%) -- **Largest Connected Component Ratio**: Largest CC nodes / total nodes (threshold: > 90%) -- **Average Node Degree**: Average degree across all nodes (threshold: 2-5) -- **Power Law Distribution R²**: Degree distribution fit (threshold: > 0.75) - -## Usage - -### Command Line Usage - -```bash -# Run all evaluations -python -m graphgen.operators.evaluate_kg.evaluate_kg --working_dir cache - -# Run specific evaluation -python -m graphgen.operators.evaluate_kg.evaluate_kg --working_dir cache --accuracy_only - -# Specify backends -python -m graphgen.operators.evaluate_kg.evaluate_kg \ - --working_dir cache \ - --graph_backend networkx \ - --kv_backend json_kv -``` - -### Shell Script Usage - -```bash -# Basic usage -bash examples/evaluate_kg/evaluate_kg.sh - -# With custom options -bash examples/evaluate_kg/evaluate_kg.sh \ - --working_dir cache \ - --accuracy_only -``` - -## Configuration - -All evaluation thresholds use default values defined in the evaluator classes: - -- **Structure thresholds**: Defined in `StructureEvaluator` with defaults: - - `noise_ratio_threshold`: 0.15 - - `largest_cc_ratio_threshold`: 0.90 - - `avg_degree_min`: 2.0 - - `avg_degree_max`: 5.0 - - `powerlaw_r2_threshold`: 0.75 - -**Note**: Accuracy evaluation automatically loads chunks from the chunk storage and evaluates the quality of entity/relation extraction using LLM-as-a-Judge. No configuration file is needed. - -## Requirements - -- **NetworkX**: Required for structural evaluation -- **scipy**: Required for power law distribution fitting -- **numpy**: Required for numerical calculations -- **LLM Client**: Required for accuracy evaluation (configure via `TRAINEE_*` env vars) - -## Output Format - -The evaluation returns a dictionary with the following structure: - -```python -{ - "accuracy": { - "entity_accuracy": { - "overall_score": { - "mean": float, - "median": float, - "min": float, - "max": float, - "std": float - }, - "accuracy": { - "mean": float, - "median": float, - "min": float, - "max": float, - "std": float - }, - "completeness": { - "mean": float, - "median": float, - "min": float, - "max": float, - "std": float - }, - "precision": { - "mean": float, - "median": float, - "min": float, - "max": float, - "std": float - }, - "total_chunks": int, - "detailed_results": [ - { - "chunk_id": str, - "chunk_content": str, - "extracted_entities_count": int, - "accuracy": float, - "completeness": float, - "precision": float, - "overall_score": float, - "accuracy_reasoning": str, - "completeness_reasoning": str, - "precision_reasoning": str, - "issues": [str] - }, - ... - ] - }, - "relation_accuracy": { - "overall_score": { - "mean": float, - "median": float, - "min": float, - "max": float, - "std": float - }, - "accuracy": { - "mean": float, - "median": float, - "min": float, - "max": float, - "std": float - }, - "completeness": { - "mean": float, - "median": float, - "min": float, - "max": float, - "std": float - }, - "precision": { - "mean": float, - "median": float, - "min": float, - "max": float, - "std": float - }, - "total_chunks": int, - "detailed_results": [ - { - "chunk_id": str, - "chunk_content": str, - "extracted_relations_count": int, - "accuracy": float, - "completeness": float, - "precision": float, - "overall_score": float, - "accuracy_reasoning": str, - "completeness_reasoning": str, - "precision_reasoning": str, - "issues": [str] - }, - ... - ] - } - }, - "consistency": { - "conflict_rate": float, - "conflict_entities_count": int, - "total_entities": int, - "entities_checked": int, - "conflicts": [ - { - "entity_id": str, - "conflict_type": str, # "entity_type" or "description" - "conflict_severity": float, # 0-1, severity of the conflict - "conflict_reasoning": str, - "conflicting_values": [str], - "recommended_value": str, # for entity_type conflicts - "conflict_details": str # for description conflicts - }, - ... - ] - }, - "structure": { - "total_nodes": int, - "total_edges": int, - "noise_ratio": float, - "largest_cc_ratio": float, - "avg_degree": float, - "powerlaw_r2": float | None, - "thresholds": { - "noise_ratio": { "value": float, "threshold": float, "pass": bool }, - ... - } - } -} -``` - -## Notes - -- Accuracy evaluation uses LLM-as-a-Judge to evaluate extraction quality from chunks -- Accuracy evaluation automatically loads chunks from chunk storage (no need for source_text_paths) -- The evaluator associates extracted entities/relations with their source chunks using the `source_id` field -- Structural evaluation automatically converts Kuzu storage to NetworkX for analysis -- All evaluations include error handling and will return error messages if something fails -- The evaluator automatically loads graph and chunk storage from the working directory -- LLM evaluation may take time for large numbers of chunks (controlled by `max_concurrent` parameter) diff --git a/graphgen/models/evaluator/kg/__init__.py b/graphgen/models/evaluator/kg/__init__.py index 375cbc50..bbb9aa73 100644 --- a/graphgen/models/evaluator/kg/__init__.py +++ b/graphgen/models/evaluator/kg/__init__.py @@ -1,18 +1 @@ -""" -Knowledge Graph Quality Evaluator - -This module provides comprehensive quality evaluation for knowledge graphs, -1. accuracy assessment (entity/relation/triple validation), -2. consistency assessment (attribute conflict detection), and structural -3. robustness assessment (noise ratio, connectivity, degree distribution). -""" - -from .accuracy_evaluator import AccuracyEvaluator -from .consistency_evaluator import ConsistencyEvaluator from .structure_evaluator import StructureEvaluator - -__all__ = [ - "AccuracyEvaluator", - "ConsistencyEvaluator", - "StructureEvaluator", -] diff --git a/graphgen/models/evaluator/kg/accuracy_evaluator.py b/graphgen/models/evaluator/kg/accuracy_evaluator.py deleted file mode 100644 index 9663b6f8..00000000 --- a/graphgen/models/evaluator/kg/accuracy_evaluator.py +++ /dev/null @@ -1,350 +0,0 @@ -import asyncio -import json -import re -from typing import Any, Dict, List - -from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper -from graphgen.bases.datatypes import Chunk -from graphgen.templates import ACCURACY_EVALUATION_PROMPT -from graphgen.utils import detect_main_language, logger - - -class AccuracyEvaluator: - """Evaluates accuracy of entity recognition and relation extraction using LLM-as-a-Judge. - - For each chunk, uses LLM to evaluate the quality of extracted entities and relations - by comparing them with the original chunk content. Provides multi-dimensional quality - scores (accuracy, completeness, precision). - """ - - def __init__( - self, - graph_storage: BaseGraphStorage, - chunk_storage: BaseKVStorage, - llm_client: BaseLLMWrapper, - ): - self.graph_storage = graph_storage - self.chunk_storage = chunk_storage - self.llm_client = llm_client - - def evaluate(self) -> Dict[str, Any]: - """Evaluate entity and relation extraction quality using LLM-as-a-Judge. - - Returns: - Dictionary containing entity_accuracy and relation_accuracy metrics. - """ - # 1. Load all chunks from storage - chunks = self._load_chunks_from_storage() - - if not chunks: - logger.warning("No chunks found in storage") - return {"error": "No chunks found in storage"} - - logger.info(f"Found {len(chunks)} chunks to evaluate") - - # 2. Evaluate each chunk - entity_evaluations, relation_evaluations = self._evaluate_all_chunks(chunks) - - # 3. Aggregate results - return self._aggregate_evaluation_results( - entity_evaluations, relation_evaluations - ) - - def _load_chunks_from_storage(self) -> List[Chunk]: - """Load all chunks from chunk storage.""" - chunks = [] - all_chunk_data = self.chunk_storage.get_all() - - for chunk_id, chunk_data in all_chunk_data.items(): - try: - chunk = Chunk.from_dict(chunk_id, chunk_data) - chunks.append(chunk) - except Exception as e: - logger.warning(f"Failed to load chunk {chunk_id}: {e}") - continue - - return chunks - - def _get_extracted_entities_for_chunk(self, chunk_id: str) -> List[Dict]: - """Get all entities extracted from the specified chunk.""" - entities = [] - all_nodes = self.graph_storage.get_all_nodes() or [] - - for node_id, node_data in all_nodes: - if not isinstance(node_data, dict): - continue - source_ids = node_data.get("source_id", "").split("") - # Check if this chunk_id is in the source_ids - if chunk_id in [sid.strip() for sid in source_ids if sid.strip()]: - entities.append( - { - "entity_name": node_data.get("entity_name", node_id), - "entity_type": node_data.get("entity_type", ""), - "description": node_data.get("description", ""), - } - ) - - return entities - - def _get_extracted_relations_for_chunk(self, chunk_id: str) -> List[Dict]: - """Get all relations extracted from the specified chunk.""" - relations = [] - all_edges = self.graph_storage.get_all_edges() or [] - - for src_id, dst_id, edge_data in all_edges: - if not isinstance(edge_data, dict): - continue - source_ids = edge_data.get("source_id", "").split("") - # Check if this chunk_id is in the source_ids - if chunk_id in [sid.strip() for sid in source_ids if sid.strip()]: - src_node = self.graph_storage.get_node(src_id) or {} - dst_node = self.graph_storage.get_node(dst_id) or {} - relations.append( - { - "source_entity": src_node.get("entity_name", src_id), - "target_entity": dst_node.get("entity_name", dst_id), - "relationship_summary": edge_data.get("description", ""), - } - ) - - return relations - - def _evaluate_all_chunks( - self, chunks: List[Chunk] - ) -> tuple[List[Dict], List[Dict]]: - """Evaluate all chunks sequentially.""" - entity_evaluations = [] - relation_evaluations = [] - - for chunk in chunks: - try: - entities = self._get_extracted_entities_for_chunk(chunk.id) - relations = self._get_extracted_relations_for_chunk(chunk.id) - - entity_eval = self._evaluate_entity_extraction(chunk, entities) - relation_eval = self._evaluate_relation_extraction(chunk, relations) - - entity_evaluations.append(entity_eval) - relation_evaluations.append(relation_eval) - except Exception as e: - logger.error(f"Failed to evaluate chunk {chunk.id}: {e}") - continue - - return entity_evaluations, relation_evaluations - - def _evaluate_entity_extraction( - self, chunk: Chunk, extracted_entities: List[Dict] - ) -> Dict[str, Any]: - """Use LLM to evaluate entity extraction quality.""" - try: - lang = detect_main_language(chunk.content) - - prompt = ACCURACY_EVALUATION_PROMPT[lang]["ENTITY"].format( - chunk_content=chunk.content, - extracted_entities=json.dumps( - extracted_entities, ensure_ascii=False, indent=2 - ), - ) - - response = asyncio.run(self.llm_client.generate_answer(prompt)) - - # Try to parse JSON response - try: - evaluation_result = json.loads(response) - except json.JSONDecodeError: - # Try to extract JSON from markdown code blocks or other formats - json_match = re.search(r"\{.*\}", response, re.DOTALL) - if json_match: - evaluation_result = json.loads(json_match.group(0)) - else: - logger.warning( - f"Failed to parse LLM response for chunk {chunk.id}: {response[:200]}" - ) - # Return default evaluation - evaluation_result = { - "accuracy": 0.0, - "completeness": 0.0, - "precision": 0.0, - "overall_score": 0.0, - "accuracy_reasoning": "Failed to parse LLM response", - "completeness_reasoning": "", - "precision_reasoning": "", - "issues": ["LLM response parsing failed"], - } - - # Validate and calculate overall_score if not provided - if "overall_score" not in evaluation_result: - accuracy = float(evaluation_result.get("accuracy", 0.0)) - completeness = float(evaluation_result.get("completeness", 0.0)) - precision = float(evaluation_result.get("precision", 0.0)) - evaluation_result["overall_score"] = ( - 0.4 * accuracy + 0.4 * completeness + 0.2 * precision - ) - - return { - "chunk_id": chunk.id, - "chunk_content": chunk.content[:200] - if chunk.content - else "", # First 200 chars for debugging - "extracted_entities_count": len(extracted_entities), - **evaluation_result, - } - except Exception as e: - logger.error( - f"Error evaluating entity extraction for chunk {chunk.id}: {e}" - ) - return { - "chunk_id": chunk.id, - "chunk_content": chunk.content[:200] if chunk.content else "", - "extracted_entities_count": len(extracted_entities), - "accuracy": 0.0, - "completeness": 0.0, - "precision": 0.0, - "overall_score": 0.0, - "accuracy_reasoning": f"Evaluation failed: {str(e)}", - "completeness_reasoning": "", - "precision_reasoning": "", - "issues": [f"Evaluation error: {str(e)}"], - } - - def _evaluate_relation_extraction( - self, chunk: Chunk, extracted_relations: List[Dict] - ) -> Dict[str, Any]: - """Use LLM to evaluate relation extraction quality.""" - try: - lang = detect_main_language(chunk.content) - prompt = ACCURACY_EVALUATION_PROMPT[lang]["RELATION"].format( - chunk_content=chunk.content, - extracted_relations=json.dumps( - extracted_relations, ensure_ascii=False, indent=2 - ), - ) - - response = asyncio.run(self.llm_client.generate_answer(prompt)) - - # Try to parse JSON response - try: - evaluation_result = json.loads(response) - except json.JSONDecodeError: - # Try to extract JSON from markdown code blocks or other formats - json_match = re.search(r"\{.*\}", response, re.DOTALL) - if json_match: - evaluation_result = json.loads(json_match.group(0)) - else: - logger.warning( - f"Failed to parse LLM response for chunk {chunk.id}: {response[:200]}" - ) - # Return default evaluation - evaluation_result = { - "accuracy": 0.0, - "completeness": 0.0, - "precision": 0.0, - "overall_score": 0.0, - "accuracy_reasoning": "Failed to parse LLM response", - "completeness_reasoning": "", - "precision_reasoning": "", - "issues": ["LLM response parsing failed"], - } - - # Validate and calculate overall_score if not provided - if "overall_score" not in evaluation_result: - accuracy = float(evaluation_result.get("accuracy", 0.0)) - completeness = float(evaluation_result.get("completeness", 0.0)) - precision = float(evaluation_result.get("precision", 0.0)) - evaluation_result["overall_score"] = ( - 0.4 * accuracy + 0.4 * completeness + 0.2 * precision - ) - - return { - "chunk_id": chunk.id, - "chunk_content": chunk.content[:200] if chunk.content else "", - "extracted_relations_count": len(extracted_relations), - **evaluation_result, - } - except Exception as e: - logger.error( - f"Error evaluating relation extraction for chunk {chunk.id}: {e}" - ) - return { - "chunk_id": chunk.id, - "chunk_content": chunk.content[:200] if chunk.content else "", - "extracted_relations_count": len(extracted_relations), - "accuracy": 0.0, - "completeness": 0.0, - "precision": 0.0, - "overall_score": 0.0, - "accuracy_reasoning": f"Evaluation failed: {str(e)}", - "completeness_reasoning": "", - "precision_reasoning": "", - "issues": [f"Evaluation error: {str(e)}"], - } - - @staticmethod - def _aggregate_evaluation_results( - entity_evaluations: List[Dict], relation_evaluations: List[Dict] - ) -> Dict[str, Any]: - """Aggregate evaluation results from all chunks.""" - - def calculate_stats(scores: List[float]) -> Dict[str, float]: - if not scores: - return {"mean": 0.0, "median": 0.0, "min": 0.0, "max": 0.0, "std": 0.0} - sorted_scores = sorted(scores) - n = len(scores) - mean = sum(scores) / n - median = ( - sorted_scores[n // 2] - if n % 2 == 1 - else (sorted_scores[n // 2 - 1] + sorted_scores[n // 2]) / 2 - ) - variance = sum((x - mean) ** 2 for x in scores) / n - std = variance**0.5 - - return { - "mean": mean, - "median": median, - "min": min(scores), - "max": max(scores), - "std": std, - } - - # Extract scores - entity_overall_scores = [ - e.get("overall_score", 0.0) for e in entity_evaluations - ] - entity_accuracy_scores = [e.get("accuracy", 0.0) for e in entity_evaluations] - entity_completeness_scores = [ - e.get("completeness", 0.0) for e in entity_evaluations - ] - entity_precision_scores = [e.get("precision", 0.0) for e in entity_evaluations] - - relation_overall_scores = [ - r.get("overall_score", 0.0) for r in relation_evaluations - ] - relation_accuracy_scores = [ - r.get("accuracy", 0.0) for r in relation_evaluations - ] - relation_completeness_scores = [ - r.get("completeness", 0.0) for r in relation_evaluations - ] - relation_precision_scores = [ - r.get("precision", 0.0) for r in relation_evaluations - ] - - return { - "entity_accuracy": { - "overall_score": calculate_stats(entity_overall_scores), - "accuracy": calculate_stats(entity_accuracy_scores), - "completeness": calculate_stats(entity_completeness_scores), - "precision": calculate_stats(entity_precision_scores), - "total_chunks": len(entity_evaluations), - "detailed_results": entity_evaluations, - }, - "relation_accuracy": { - "overall_score": calculate_stats(relation_overall_scores), - "accuracy": calculate_stats(relation_accuracy_scores), - "completeness": calculate_stats(relation_completeness_scores), - "precision": calculate_stats(relation_precision_scores), - "total_chunks": len(relation_evaluations), - "detailed_results": relation_evaluations, - }, - } diff --git a/graphgen/models/evaluator/kg/consistency_evaluator.py b/graphgen/models/evaluator/kg/consistency_evaluator.py deleted file mode 100644 index 850dd48b..00000000 --- a/graphgen/models/evaluator/kg/consistency_evaluator.py +++ /dev/null @@ -1,388 +0,0 @@ -import asyncio -import json -import re -from typing import Any, Dict, List - -from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper -from graphgen.bases.datatypes import Chunk -from graphgen.templates.evaluation.kg.consistency_evaluation import ( - CONSISTENCY_EVALUATION_PROMPT, -) -from graphgen.utils import detect_main_language, logger - - -class ConsistencyEvaluator: - """Evaluates consistency by detecting semantic conflicts using LLM-as-a-Judge. - - For entities with multiple source chunks, compares entity_type and description - extracted from different chunks to detect semantic conflicts. - """ - - def __init__( - self, - graph_storage: BaseGraphStorage, - chunk_storage: BaseKVStorage, - llm_client: BaseLLMWrapper, - ): - self.graph_storage = graph_storage - self.chunk_storage = chunk_storage - self.llm_client = llm_client - - def evaluate(self) -> Dict[str, Any]: - """Evaluate consistency by detecting semantic conflicts.""" - all_nodes = self.graph_storage.get_all_nodes() or [] - if not all_nodes: - return {"error": "Empty graph"} - - return self._evaluate_consistency(all_nodes) - - def _evaluate_consistency(self, all_nodes: List) -> Dict[str, Any]: - """Evaluate consistency by detecting semantic conflicts.""" - # Filter entities with multiple source chunks - entities_with_multiple_sources = [] - for node_id, node_data in all_nodes: - if not isinstance(node_data, dict): - continue - source_ids = node_data.get("source_id", "").split("") - source_ids = [sid.strip() for sid in source_ids if sid.strip()] - if len(source_ids) > 1: # Only check entities from multiple chunks - entities_with_multiple_sources.append((node_id, node_data, source_ids)) - - if not entities_with_multiple_sources: - logger.info( - "No entities with multiple sources found, skipping consistency check" - ) - return { - "conflict_rate": 0.0, - "conflict_entities_count": 0, - "total_entities": len(all_nodes), - "conflicts": [], - } - - logger.info( - f"Checking consistency for {len(entities_with_multiple_sources)} entities with multiple sources" - ) - - # Evaluate entities sequentially - conflicts = [] - conflict_entities = set() - - for entity_info in entities_with_multiple_sources: - try: - entity_id, entity_conflicts = self._evaluate_entity_consistency(entity_info) - if entity_conflicts: - conflicts.extend(entity_conflicts) - conflict_entities.add(entity_id) - except Exception as e: - logger.error( - f"Failed to evaluate entity {entity_info[0]}: {e}" - ) - continue - - total_entities = len(all_nodes) - conflict_rate = ( - len(conflict_entities) / total_entities if total_entities > 0 else 0 - ) - - return { - "conflict_rate": conflict_rate, - "conflict_entities_count": len(conflict_entities), - "total_entities": total_entities, - "entities_checked": len(entities_with_multiple_sources), - "conflicts": conflicts[:100], # Limit to first 100 conflicts - } - - def _clean_entity_id(self, entity_id: str) -> str: - """Clean entity ID by removing surrounding quotes.""" - clean_id = entity_id.strip() - if (clean_id.startswith('"') and clean_id.endswith('"')) or ( - clean_id.startswith("'") and clean_id.endswith("'") - ): - clean_id = clean_id[1:-1].strip() - return clean_id - - def _evaluate_entity_consistency( - self, entity_info: tuple - ) -> tuple[str, List[Dict]]: - """Evaluate consistency for a single entity.""" - entity_id, _node_data, source_ids = entity_info - # Clean entity_id for display - clean_entity_id = self._clean_entity_id(entity_id) - conflicts = [] - - # Get chunks for this entity - chunks = self._get_entity_chunks(source_ids) - if len(chunks) < 2: - return entity_id, [] - - # Extract entity attributes from each chunk - entity_extractions = {} - for chunk in chunks: - extraction = self._extract_entity_from_chunk(entity_id, chunk) - if extraction: - entity_extractions[chunk.id] = extraction - - if len(entity_extractions) < 2: - return entity_id, [] - - # Check entity type consistency - type_extractions = { - chunk_id: ext.get("entity_type", "") - for chunk_id, ext in entity_extractions.items() - } - type_conflict = self._check_entity_type_consistency( - entity_id, type_extractions - ) - if type_conflict and type_conflict.get("has_conflict", False): - conflicts.append( - { - "entity_id": clean_entity_id, - "conflict_type": "entity_type", - "conflict_severity": type_conflict.get("conflict_severity", 0.0), - "conflict_reasoning": type_conflict.get("conflict_reasoning", ""), - "conflicting_values": type_conflict.get("conflicting_types", []), - "recommended_value": type_conflict.get("recommended_type", ""), - } - ) - - # Check entity description consistency - descriptions = { - chunk_id: ext.get("description", "") - for chunk_id, ext in entity_extractions.items() - } - desc_conflict = self._check_entity_description_consistency( - entity_id, descriptions - ) - if desc_conflict and desc_conflict.get("has_conflict", False): - conflicts.append( - { - "entity_id": clean_entity_id, - "conflict_type": "description", - "conflict_severity": desc_conflict.get("conflict_severity", 0.0), - "conflict_reasoning": desc_conflict.get("conflict_reasoning", ""), - "conflicting_values": desc_conflict.get( - "conflicting_descriptions", [] - ), - "conflict_details": desc_conflict.get("conflict_details", ""), - } - ) - - return entity_id, conflicts - - def _get_entity_chunks(self, source_ids: List[str]) -> List[Chunk]: - """Get all chunks related to an entity.""" - chunks = [] - for chunk_id in source_ids: - chunk_data = self.chunk_storage.get_by_id(chunk_id) - if chunk_data: - try: - chunk = Chunk.from_dict(chunk_id, chunk_data) - chunks.append(chunk) - except Exception as e: - logger.warning(f"Failed to load chunk {chunk_id}: {e}") - continue - return chunks - - def _extract_entity_from_chunk( - self, entity_id: str, chunk: Chunk - ) -> Dict[str, str]: - """Extract entity attributes from a chunk using LLM.""" - try: - # Clean entity_id: remove surrounding quotes if present - clean_entity_id = self._clean_entity_id(entity_id) - - # Detect language and get appropriate prompt - lang = detect_main_language(chunk.content) - prompt = CONSISTENCY_EVALUATION_PROMPT[lang]["ENTITY_EXTRACTION"].format( - entity_name=clean_entity_id, - chunk_content=chunk.content[:2000] - if chunk.content - else "", # Limit content length - ) - - response = asyncio.run(self.llm_client.generate_answer(prompt)) - - # Try to parse JSON response - try: - extraction = json.loads(response) - except json.JSONDecodeError: - # Try to extract JSON from markdown code blocks - json_match = re.search(r"\{.*\}", response, re.DOTALL) - if json_match: - extraction = json.loads(json_match.group(0)) - else: - logger.warning( - f"Failed to parse extraction response for {entity_id} in chunk {chunk.id}" - ) - return {} - - # Normalize entity_type to lowercase and validate - entity_type = extraction.get("entity_type", "").lower().strip() - # Valid preset types - valid_types = { - "concept", - "date", - "location", - "keyword", - "organization", - "person", - "event", - "work", - "nature", - "artificial", - "science", - "technology", - "mission", - "gene", - } - # If entity_type is not in valid types, default to "concept" - if entity_type not in valid_types: - if entity_type: # If LLM provided a type but it's invalid - logger.warning( - f"Invalid entity_type '{entity_type}' for entity {clean_entity_id} in chunk {chunk.id}, " - f"defaulting to 'concept'" - ) - entity_type = "concept" - - return { - "entity_type": entity_type, - "description": extraction.get("description", ""), - } - except Exception as e: - logger.error( - f"Error extracting entity {entity_id} from chunk {chunk.id}: {e}" - ) - return {} - - def _check_entity_type_consistency( - self, entity_id: str, type_extractions: Dict[str, str] - ) -> Dict[str, Any]: - """Check entity type consistency using LLM.""" - if len(set(type_extractions.values())) <= 1: - # All types are the same, no conflict - return {"has_conflict": False} - - try: - type_list = [ - f"Chunk {chunk_id}: {entity_type}" - for chunk_id, entity_type in type_extractions.items() - if entity_type - ] - - # Detect language from type extraction text - type_text = "\n".join(type_list) - lang = detect_main_language(type_text) - prompt = CONSISTENCY_EVALUATION_PROMPT[lang]["ENTITY_TYPE_CONFLICT"].format( - entity_name=entity_id, type_extractions=type_text - ) - - response = asyncio.run(self.llm_client.generate_answer(prompt)) - - # Parse JSON response - try: - result = json.loads(response) - except json.JSONDecodeError: - json_match = re.search(r"\{.*\}", response, re.DOTALL) - if json_match: - result = json.loads(json_match.group(0)) - else: - logger.warning( - f"Failed to parse conflict detection response for {entity_id}" - ) - return {"has_conflict": False} - - return result - except Exception as e: - logger.error(f"Error checking type consistency for {entity_id}: {e}") - return {"has_conflict": False} - - def _check_entity_description_consistency( - self, entity_id: str, descriptions: Dict[str, str] - ) -> Dict[str, Any]: - """Check entity description consistency using LLM.""" - # Filter out empty descriptions - valid_descriptions = {k: v for k, v in descriptions.items() if v} - if len(valid_descriptions) < 2: - return {"has_conflict": False} - - if len(set(valid_descriptions.values())) <= 1: - # All descriptions are the same, no conflict - return {"has_conflict": False} - - try: - desc_list = [ - f"Chunk {chunk_id}: {description}" - for chunk_id, description in valid_descriptions.items() - ] - - # Detect language from description text - desc_text = "\n".join(desc_list) - lang = detect_main_language(desc_text) - prompt = CONSISTENCY_EVALUATION_PROMPT[lang]["ENTITY_DESCRIPTION_CONFLICT"].format( - entity_name=entity_id, descriptions=desc_text - ) - - response = asyncio.run(self.llm_client.generate_answer(prompt)) - - # Parse JSON response - try: - result = json.loads(response) - except json.JSONDecodeError: - json_match = re.search(r"\{.*\}", response, re.DOTALL) - if json_match: - result = json.loads(json_match.group(0)) - else: - logger.warning( - f"Failed to parse conflict detection response for {entity_id}" - ) - return {"has_conflict": False} - - return result - except Exception as e: - logger.error(f"Error checking description consistency for {entity_id}: {e}") - return {"has_conflict": False} - - def _check_relation_consistency( - self, src_id: str, dst_id: str, relation_extractions: Dict[str, str] - ) -> Dict[str, Any]: - """Check relation consistency using LLM.""" - if len(set(relation_extractions.values())) <= 1: - return {"has_conflict": False} - - try: - rel_list = [ - f"Chunk {chunk_id}: {relation}" - for chunk_id, relation in relation_extractions.items() - if relation - ] - - # Detect language from relation description text - rel_text = "\n".join(rel_list) - lang = detect_main_language(rel_text) - prompt = CONSISTENCY_EVALUATION_PROMPT[lang]["RELATION_CONFLICT"].format( - source_entity=src_id, - target_entity=dst_id, - relation_descriptions=rel_text, - ) - - response = asyncio.run(self.llm_client.generate_answer(prompt)) - - # Parse JSON response - try: - result = json.loads(response) - except json.JSONDecodeError: - json_match = re.search(r"\{.*\}", response, re.DOTALL) - if json_match: - result = json.loads(json_match.group(0)) - else: - logger.warning( - f"Failed to parse relation conflict response for {src_id}->{dst_id}" - ) - return {"has_conflict": False} - - return result - except Exception as e: - logger.error( - f"Error checking relation consistency for {src_id}->{dst_id}: {e}" - ) - return {"has_conflict": False} diff --git a/graphgen/models/evaluator/kg/structure_evaluator.py b/graphgen/models/evaluator/kg/structure_evaluator.py index 58e5b812..380459ad 100644 --- a/graphgen/models/evaluator/kg/structure_evaluator.py +++ b/graphgen/models/evaluator/kg/structure_evaluator.py @@ -4,49 +4,49 @@ import numpy as np from scipy import stats -from graphgen.bases import BaseGraphStorage +from graphgen.bases import BaseGraphStorage, BaseKGEvaluator from graphgen.utils import logger -class StructureEvaluator: +class StructureEvaluator(BaseKGEvaluator): """Evaluates structural robustness of the graph.""" def __init__( self, - graph_storage: BaseGraphStorage, noise_ratio_threshold: float = 0.15, largest_cc_ratio_threshold: float = 0.90, avg_degree_min: float = 2.0, avg_degree_max: float = 5.0, powerlaw_r2_threshold: float = 0.75, ): - self.graph_storage = graph_storage self.noise_ratio_threshold = noise_ratio_threshold self.largest_cc_ratio_threshold = largest_cc_ratio_threshold self.avg_degree_min = avg_degree_min self.avg_degree_max = avg_degree_max self.powerlaw_r2_threshold = powerlaw_r2_threshold - def evaluate(self) -> Dict[str, Any]: + def evaluate(self, kg: BaseGraphStorage) -> Dict[str, Any]: """ Evaluate the structural robustness of the graph. - :return: + :return: Dictionary of structural metrics and robustness verdict. The keys include: + - total_nodes: Total number of nodes in the graph + - total_edges: Total number of edges in the graph + - noise_ratio: Ratio of isolated nodes to total nodes + - largest_cc_ratio: Ratio of largest connected component size to total nodes + - avg_degree: Average node degree + - powerlaw_r2: R² value of power law fit to degree distribution + - is_robust: Boolean indicating if the graph is structurally robust """ - storage = self.graph_storage - - total_nodes = storage.get_node_count() - if total_nodes == 0: - return {"error": "Empty graph"} - - total_edges = storage.get_edge_count() - degree_map = storage.get_all_node_degrees() + total_nodes = kg.get_node_count() + total_edges = kg.get_edge_count() + degree_map = kg.get_all_node_degrees() # Noise ratio: isolated nodes / total nodes isolated_nodes = [nid for nid, deg in degree_map.items() if deg == 0] noise_ratio = len(isolated_nodes) / total_nodes # Largest connected component - components = storage.get_connected_components(undirected=True) + components = kg.get_connected_components(undirected=True) largest_cc_ratio = ( len(max(components, key=len)) / total_nodes if components else 0 ) diff --git a/graphgen/models/evaluator/qa/length_evaluator.py b/graphgen/models/evaluator/qa/length_evaluator.py index 266edfb6..723871b8 100644 --- a/graphgen/models/evaluator/qa/length_evaluator.py +++ b/graphgen/models/evaluator/qa/length_evaluator.py @@ -1,18 +1,19 @@ - import os -from graphgen.bases import BaseEvaluator, QAPair + +from graphgen.bases import BaseQAEvaluator, QAPair from graphgen.models.tokenizer import Tokenizer -class LengthEvaluator(BaseEvaluator): +class LengthEvaluator(BaseQAEvaluator): def __init__(self, tokenizer_name: str = None): - tokenizer_model = tokenizer_name or os.environ.get("TOKENIZER_MODEL", "cl100k_base") + tokenizer_model = tokenizer_name or os.environ.get( + "TOKENIZER_MODEL", "cl100k_base" + ) self.tokenizer: Tokenizer = Tokenizer(tokenizer_model) - def evaluate(self, pair: QAPair) -> float: + async def evaluate(self, pair: QAPair) -> dict[str, float]: """ Evaluate the length of the qa pair. """ content = pair.question + pair.answer - tokens = self.tokenizer.encode(content) - return len(tokens) + return {"length": self.tokenizer.count_tokens(content)} diff --git a/graphgen/models/evaluator/qa/mtld_evaluator.py b/graphgen/models/evaluator/qa/mtld_evaluator.py index e4e18d32..c8319036 100644 --- a/graphgen/models/evaluator/qa/mtld_evaluator.py +++ b/graphgen/models/evaluator/qa/mtld_evaluator.py @@ -1,10 +1,10 @@ from typing import Set -from graphgen.bases import BaseEvaluator, QAPair +from graphgen.bases import BaseQAEvaluator, QAPair from graphgen.utils import NLTKHelper, detect_main_language -class MTLDEvaluator(BaseEvaluator): +class MTLDEvaluator(BaseQAEvaluator): """ Metrics for measuring the lexical diversity of text. """ @@ -15,7 +15,7 @@ def __init__(self, threshold: float = 0.72): self.stopwords_zh: Set[str] = set(self.nltk_helper.get_stopwords("zh")) self.threshold = threshold - def evaluate(self, pair: QAPair) -> float: + async def evaluate(self, pair: QAPair) -> dict[str, float]: """ Calculate the MTLD (Mean Token Length Diversity) score for a given text. @@ -24,7 +24,7 @@ def evaluate(self, pair: QAPair) -> float: """ text = pair.answer if not text or not text.strip(): - return 0.0 + return {"mtld": 0} lang = detect_main_language(text) tokens = self.nltk_helper.word_tokenize(text, lang) @@ -34,7 +34,7 @@ def evaluate(self, pair: QAPair) -> float: filtered_tokens = [word for word in filtered_tokens if word.isalnum()] if not filtered_tokens: - return 0 + return {"mtld": 0} # Compute forward factors forward_factors = self._compute_factors(filtered_tokens, self.threshold) @@ -43,7 +43,8 @@ def evaluate(self, pair: QAPair) -> float: backward_factors = self._compute_factors(filtered_tokens[::-1], self.threshold) # Compute average factors - return (forward_factors + backward_factors) / 2 + mtld_score = (forward_factors + backward_factors) / 2 + return {"mtld": mtld_score} @staticmethod def _compute_factors(tokens: list, threshold: float) -> float: diff --git a/graphgen/models/evaluator/qa/reward_evaluator.py b/graphgen/models/evaluator/qa/reward_evaluator.py index a7fcbc22..f43b026a 100644 --- a/graphgen/models/evaluator/qa/reward_evaluator.py +++ b/graphgen/models/evaluator/qa/reward_evaluator.py @@ -1,8 +1,9 @@ from typing import Optional -from graphgen.bases import BaseEvaluator, QAPair +from graphgen.bases import BaseQAEvaluator, QAPair -class RewardEvaluator(BaseEvaluator): + +class RewardEvaluator(BaseQAEvaluator): """ Reward Model Evaluator for single QAPair evaluation. """ @@ -15,7 +16,7 @@ def __init__( ): """ Initialize the reward evaluator. - + Args: reward_name: Model name or path on HuggingFace Hub max_length: Maximum token length for the model @@ -26,6 +27,7 @@ def __init__( import torch from transformers import AutoModelForSequenceClassification, AutoTokenizer + self.torch = torch # Set device (auto-detect if not specified) @@ -37,15 +39,17 @@ def __init__( self.model.to(self.device) self.model.eval() except Exception as e: - raise RuntimeError(f"Failed to load reward model '{reward_name}': {e}") from e + raise RuntimeError( + f"Failed to load reward model '{reward_name}': {e}" + ) from e - def evaluate(self, pair: QAPair) -> float: + async def evaluate(self, pair: QAPair) -> dict[str, float]: """ Evaluate a single question-answer pair using the reward model. - + Args: pair: QAPair containing question and answer strings - + Returns: Score as a float """ @@ -63,4 +67,4 @@ def evaluate(self, pair: QAPair) -> float: with self.torch.no_grad(): score = self.model(**inputs).logits[0].item() - return score + return {"reward_score": score} diff --git a/graphgen/models/evaluator/qa/uni_evaluator.py b/graphgen/models/evaluator/qa/uni_evaluator.py index 38406512..c07c3f01 100644 --- a/graphgen/models/evaluator/qa/uni_evaluator.py +++ b/graphgen/models/evaluator/qa/uni_evaluator.py @@ -1,14 +1,15 @@ # https://github.com/maszhongming/UniEval/tree/main -from typing import Optional, List -from graphgen.bases import BaseEvaluator, QAPair +from typing import List, Optional +from graphgen.bases import BaseQAEvaluator, QAPair -class UniEvaluator(BaseEvaluator): + +class UniEvaluator(BaseQAEvaluator): """ UniEvaluator for single QAPair evaluation across quality dimensions. - + Dimensions: naturalness, coherence, understandability - + Usage: evaluator = UniEvaluator() pair = QAPair(question="...", answer="...") @@ -34,6 +35,7 @@ def __init__( """ import torch from transformers import AutoModelForSeq2SeqLM, AutoTokenizer + self.torch = torch self.model_name = model_name or self.DEFAULT_MODEL @@ -58,10 +60,12 @@ def _build_input_text(dimension: str, question: str, answer: str) -> str: if dimension == "coherence": return f"question: Is this a coherent response? response: {answer} history: {question}" if dimension == "understandability": - return f"question: Is this an understandable response? response: {answer}" + return ( + f"question: Is this an understandable response? response: {answer}" + ) raise NotImplementedError(f"Unsupported dimension '{dimension}'") - def evaluate( + async def evaluate( self, pair: QAPair, dimensions: Optional[List[str]] = None, @@ -72,7 +76,9 @@ def evaluate( # Validate dimensions invalid = set(dimensions) - set(self.DEFAULT_DIMS) if invalid: - raise ValueError(f"Invalid dimensions: {invalid}. Available: {self.DEFAULT_DIMS}") + raise ValueError( + f"Invalid dimensions: {invalid}. Available: {self.DEFAULT_DIMS}" + ) results = {} no_token = self.torch.tensor([[self._no_id]], device=self.device) @@ -95,7 +101,9 @@ def evaluate( attention_mask=src_mask, labels=no_token, use_cache=False, - ).logits[:, 0, :] # [1, vocab_size] + ).logits[ + :, 0, : + ] # [1, vocab_size] probs = self.torch.softmax(logits, dim=-1)[0] score = probs[self._yes_id] / (probs[self._yes_id] + probs[self._no_id]) diff --git a/graphgen/models/evaluator/triple/__init__.py b/graphgen/models/evaluator/triple/__init__.py new file mode 100644 index 00000000..fedaf1be --- /dev/null +++ b/graphgen/models/evaluator/triple/__init__.py @@ -0,0 +1 @@ +from .accuracy_evaluator import AccuracyEvaluator diff --git a/graphgen/models/evaluator/triple/accuracy_evaluator.py b/graphgen/models/evaluator/triple/accuracy_evaluator.py new file mode 100644 index 00000000..6e1b7345 --- /dev/null +++ b/graphgen/models/evaluator/triple/accuracy_evaluator.py @@ -0,0 +1,94 @@ +import json +import re +from typing import Any, Dict + +from graphgen.bases import BaseLLMWrapper, BaseTripleEvaluator +from graphgen.templates import ACCURACY_EVALUATION_PROMPT +from graphgen.utils import detect_main_language, logger + + +class AccuracyEvaluator(BaseTripleEvaluator): + """Evaluates accuracy of entity recognition and relation extraction using LLM-as-a-Judge. + + For each chunk, uses LLM to evaluate the quality of extracted entities and relations + by comparing them with the original chunk content. Provides multi-dimensional quality + scores (accuracy, completeness, precision). + """ + + def __init__( + self, + llm_client: BaseLLMWrapper, + ): + self.llm_client = llm_client + + async def evaluate(self, unit: tuple) -> Dict[str, Any]: + """Evaluate entity and relation extraction quality using LLM-as-a-Judge. + + Returns: + Dictionary containing entity_accuracy and relation_accuracy metrics. + """ + chunk_content, nodes, edges = unit + lang = detect_main_language(chunk_content) + + # node + prompt = ACCURACY_EVALUATION_PROMPT[lang]["ENTITY"].format( + chunk_content=chunk_content, + extracted_entities=json.dumps(nodes, ensure_ascii=False, indent=2), + ) + + response = await self.llm_client.generate_answer(prompt) + + # Try to parse JSON response + try: + node_evaluation_result = json.loads(response) + except json.JSONDecodeError: + # Try to extract JSON from markdown code blocks or other formats + json_match = re.search(r"\{.*\}", response, re.DOTALL) + if json_match: + node_evaluation_result = json.loads(json_match.group(0)) + else: + logger.warning("Failed to parse LLM response.") + # default evaluation + node_evaluation_result = { + "accuracy": 0.0, + "completeness": 0.0, + "precision": 0.0, + "overall_score": 0.0, + "accuracy_reasoning": "Failed to parse LLM response", + "completeness_reasoning": "", + "precision_reasoning": "", + "issues": ["LLM response parsing failed"], + } + + # edge + prompt = ACCURACY_EVALUATION_PROMPT[lang]["RELATION"].format( + chunk_content=chunk_content, + extracted_relations=json.dumps(edges, ensure_ascii=False, indent=2), + ) + response = await self.llm_client.generate_answer(prompt) + # Try to parse JSON response + try: + edge_evaluation_result = json.loads(response) + except json.JSONDecodeError: + # Try to extract JSON from markdown code blocks or other formats + json_match = re.search(r"\{.*\}", response, re.DOTALL) + if json_match: + edge_evaluation_result = json.loads(json_match.group(0)) + else: + logger.warning("Failed to parse LLM response.") + # default evaluation + edge_evaluation_result = { + "accuracy": 0.0, + "completeness": 0.0, + "precision": 0.0, + "overall_score": 0.0, + "accuracy_reasoning": "Failed to parse LLM response", + "completeness_reasoning": "", + "precision_reasoning": "", + "issues": ["LLM response parsing failed"], + } + + return { + "entity_accuracy": node_evaluation_result, + "relation_accuracy": edge_evaluation_result, + } diff --git a/graphgen/models/extractor/schema_guided_extractor.py b/graphgen/models/extractor/schema_guided_extractor.py index 74801946..dee6e3e3 100644 --- a/graphgen/models/extractor/schema_guided_extractor.py +++ b/graphgen/models/extractor/schema_guided_extractor.py @@ -1,9 +1,8 @@ import json -from typing import Dict, List -from graphgen.bases import BaseExtractor, BaseLLMWrapper +from graphgen.bases import BaseExtractor, BaseLLMWrapper, Chunk from graphgen.templates import SCHEMA_GUIDED_EXTRACTION_PROMPT -from graphgen.utils import compute_dict_hash, detect_main_language, logger +from graphgen.utils import detect_main_language, logger class SchemaGuidedExtractor(BaseExtractor): @@ -59,9 +58,8 @@ def build_prompt(self, text: str) -> str: ) return prompt - async def extract(self, chunk: dict) -> dict: - _chunk_id = chunk.get("_chunk_id", "") - text = chunk.get("content", "") + async def extract(self, chunk: Chunk) -> dict: + text = chunk.content prompt = self.build_prompt(text) response = await self.llm_client.generate_answer(prompt) @@ -74,35 +72,9 @@ async def extract(self, chunk: dict) -> dict: if any(extracted_info[key] == "" for key in self.required_keys): logger.debug("Missing required keys in extraction: %s", extracted_info) return {} - main_keys_info = {key: extracted_info[key] for key in self.required_keys} logger.debug("Extracted info: %s", extracted_info) + return extracted_info - # add chunk metadata - extracted_info["_chunk_id"] = _chunk_id - - return { - compute_dict_hash(main_keys_info, prefix="extract-"): extracted_info - } except json.JSONDecodeError: logger.error("Failed to parse extraction response: %s", response) return {} - - @staticmethod - def merge_extractions(extraction_list: List[Dict[str, dict]]) -> Dict[str, dict]: - """ - Merge multiple extraction results based on their hashes. - :param extraction_list: List of extraction results, each is a dict with hash as key and record as value. - :return: Merged extraction results. - """ - merged: Dict[str, dict] = {} - for ext in extraction_list: - for h, rec in ext.items(): - if h not in merged: - merged[h] = rec.copy() - else: - for k, v in rec.items(): - if k not in merged[h] or merged[h][k] == v: - merged[h][k] = v - else: - merged[h][k] = f"{merged[h][k]}{v}" - return merged diff --git a/graphgen/models/generator/aggregated_generator.py b/graphgen/models/generator/aggregated_generator.py index 3ed2078b..3f223325 100644 --- a/graphgen/models/generator/aggregated_generator.py +++ b/graphgen/models/generator/aggregated_generator.py @@ -3,7 +3,7 @@ from graphgen.bases import BaseGenerator from graphgen.templates import AGGREGATED_GENERATION_PROMPT -from graphgen.utils import compute_content_hash, detect_main_language, logger +from graphgen.utils import detect_main_language, logger class AggregatedGenerator(BaseGenerator): @@ -101,30 +101,26 @@ async def generate( batch: tuple[ list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]] ], - ) -> dict[str, Any]: + ) -> list[dict]: """ Generate QAs based on a given batch. :param batch :return: QA pairs """ - result = {} rephrasing_prompt = self.build_prompt(batch) response = await self.llm_client.generate_answer(rephrasing_prompt) context = self.parse_rephrased_text(response) if not context: - return result + return [] question_generation_prompt = self._build_prompt_for_question_generation(context) response = await self.llm_client.generate_answer(question_generation_prompt) question = self.parse_response(response)["question"] if not question: - return result + return [] logger.debug("Question: %s", question) logger.debug("Answer: %s", context) qa_pairs = { - compute_content_hash(question): { - "question": question, - "answer": context, - } + "question": question, + "answer": context, } - result.update(qa_pairs) - return result + return [qa_pairs] diff --git a/graphgen/models/generator/atomic_generator.py b/graphgen/models/generator/atomic_generator.py index 152e6389..d045b0da 100644 --- a/graphgen/models/generator/atomic_generator.py +++ b/graphgen/models/generator/atomic_generator.py @@ -3,7 +3,7 @@ from graphgen.bases import BaseGenerator from graphgen.templates import ATOMIC_GENERATION_PROMPT -from graphgen.utils import compute_content_hash, detect_main_language, logger +from graphgen.utils import detect_main_language, logger class AtomicGenerator(BaseGenerator): @@ -23,7 +23,7 @@ def build_prompt( return prompt @staticmethod - def parse_response(response: str) -> dict: + def parse_response(response: str) -> list[dict]: """ AtomicGenerator normally generates one QA pair per response. So we just need to parse one QA pair from the response. @@ -38,15 +38,10 @@ def parse_response(response: str) -> dict: answer = answer_match.group(1).strip() else: logger.warning("Failed to parse response: %s", response) - return {} + return [] question = question.strip('"').strip("'") answer = answer.strip('"').strip("'") logger.debug("Question: %s", question) logger.debug("Answer: %s", answer) - return { - compute_content_hash(question): { - "question": question, - "answer": answer, - } - } + return [{"question": question, "answer": answer}] diff --git a/graphgen/models/generator/cot_generator.py b/graphgen/models/generator/cot_generator.py index 3a893784..88d04324 100644 --- a/graphgen/models/generator/cot_generator.py +++ b/graphgen/models/generator/cot_generator.py @@ -3,7 +3,7 @@ from graphgen.bases import BaseGenerator from graphgen.templates import COT_GENERATION_PROMPT -from graphgen.utils import compute_content_hash, detect_main_language, logger +from graphgen.utils import detect_main_language, logger class CoTGenerator(BaseGenerator): @@ -100,28 +100,25 @@ async def generate( batch: tuple[ list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]] ], - ) -> dict[str, Any]: + ) -> list[dict]: """ Generate QAs based on a given batch. :param batch :return: QA pairs """ - result = {} prompt = self.build_prompt(batch) response = await self.llm_client.generate_answer(prompt) response = self.parse_response(response) if not response: - return result + return [] question, reasoning_path = response["question"], response["reasoning_path"] prompt = self.build_prompt_for_cot_generation(batch, question, reasoning_path) cot_answer = await self.llm_client.generate_answer(prompt) logger.debug("CoT Answer: %s", cot_answer) - qa_pairs = { - compute_content_hash(question): { + return [ + { "question": question, "answer": cot_answer, "reasoning_path": reasoning_path, } - } - result.update(qa_pairs) - return result + ] diff --git a/graphgen/models/generator/fill_in_blank_generator.py b/graphgen/models/generator/fill_in_blank_generator.py index c2f43898..a26daf3e 100644 --- a/graphgen/models/generator/fill_in_blank_generator.py +++ b/graphgen/models/generator/fill_in_blank_generator.py @@ -3,7 +3,7 @@ from graphgen.bases import BaseGenerator from graphgen.templates import FILL_IN_BLANK_GENERATION_PROMPT -from graphgen.utils import compute_content_hash, detect_main_language, logger +from graphgen.utils import detect_main_language, logger class FillInBlankGenerator(BaseGenerator): @@ -12,7 +12,7 @@ def __init__(self, llm_client, num_of_questions) -> None: self.num_of_questions = num_of_questions @staticmethod - def parse_response(response: str) -> Any: + def parse_response(response: str) -> list[dict]: """ Parse fill-in-the-blank QA pairs from the LLM response. Each QA pair contains question text with placeholders and the correct answer(s). @@ -21,14 +21,14 @@ def parse_response(response: str) -> Any: :return: Dictionary mapping question hash to question data, where each value is a dict with "question", "answer", and "answers" keys """ - qa_pairs = {} + qa_pairs = [] # Extract all QA pair blocks qa_blocks = re.findall(r"(.*?)", response, re.DOTALL) if not qa_blocks: logger.warning("No QA pairs found in response: %s", response) - return {} + return qa_pairs for block in qa_blocks: # Extract and clean question text @@ -55,13 +55,13 @@ def parse_response(response: str) -> Any: logger.warning("No valid answers found in: %s", answer_text) continue - # Build result entry with question hash as key - question_hash = compute_content_hash(question) - qa_pairs[question_hash] = { - "question": question, - "answer": answer_text, # Original answer text with commas - "answers": answers, # List of individual answers: ["A8X"] or ["A8X", "八百万"] - } + qa_pairs.append( + { + "question": question, + "answer": answer_text, # Original answer text with commas + "answers": answers, # List of individual answers: ["A8X"] or ["A8X", "八百万"] + } + ) logger.debug( "Successfully parsed fill-in-the-blank question: %s", question[:50] diff --git a/graphgen/models/generator/multi_answer_generator.py b/graphgen/models/generator/multi_answer_generator.py index b5a0db5c..a341a4fd 100644 --- a/graphgen/models/generator/multi_answer_generator.py +++ b/graphgen/models/generator/multi_answer_generator.py @@ -3,7 +3,7 @@ from graphgen.bases import BaseGenerator from graphgen.templates import MAQ_GENERATION_PROMPT -from graphgen.utils import compute_content_hash, detect_main_language, logger +from graphgen.utils import detect_main_language, logger class MultiAnswerGenerator(BaseGenerator): @@ -12,7 +12,7 @@ def __init__(self, llm_client, num_of_questions) -> None: self.num_of_questions = num_of_questions @staticmethod - def parse_response(response: str) -> Any: + def parse_response(response: str) -> list[dict]: """ Parse multiple-answer QA pairs from the LLM response. Each QA pair contains question text, four options, and the correct answers (one or more). @@ -21,14 +21,14 @@ def parse_response(response: str) -> Any: :return: Dictionary mapping question hash to question data, where each value is a dict with "question", "options", and "answer" keys """ - qa_pairs = {} + qa_pairs = [] # Extract all QA pair blocks qa_blocks = re.findall(r"(.*?)", response, re.DOTALL) if not qa_blocks: logger.warning("No QA pairs found in response: %s", response) - return {} + return qa_pairs for block in qa_blocks: # Extract and clean question text @@ -61,7 +61,9 @@ def parse_response(response: str) -> Any: logger.warning("Failed to parse answer from block: %s", block) continue answer_text = ans_match.group(1).strip().strip('"').strip("'") - answers = [ans.strip().upper() for ans in answer_text.split(",") if ans.strip()] + answers = [ + ans.strip().upper() for ans in answer_text.split(",") if ans.strip() + ] invalid_answers = [ans for ans in answers if ans not in options] if invalid_answers: logger.warning( @@ -76,13 +78,13 @@ def parse_response(response: str) -> Any: logger.warning("No valid answers found in: %s", answer_text) continue - # Build result entry with question hash as key - question_hash = compute_content_hash(question) - qa_pairs[question_hash] = { - "question": question, - "options": options, # Dict like {"A": "text", "B": "text", ...} - "answer": ", ".join(answers), - } + qa_pairs.append( + { + "question": question, + "options": options, # Dict like {"A": "text", "B": "text", ...} + "answers": answers, # List of correct answers: ["A", "C"] + } + ) logger.debug("Successfully parsed MAQ: %s", question[:50]) diff --git a/graphgen/models/generator/multi_choice_generator.py b/graphgen/models/generator/multi_choice_generator.py index fcac2e1b..0c48b76d 100644 --- a/graphgen/models/generator/multi_choice_generator.py +++ b/graphgen/models/generator/multi_choice_generator.py @@ -3,7 +3,7 @@ from graphgen.bases import BaseGenerator from graphgen.templates import MCQ_GENERATION_PROMPT -from graphgen.utils import compute_content_hash, detect_main_language, logger +from graphgen.utils import detect_main_language, logger class MultiChoiceGenerator(BaseGenerator): @@ -12,7 +12,7 @@ def __init__(self, llm_client, num_of_questions) -> None: self.num_of_questions = num_of_questions @staticmethod - def parse_response(response: str) -> Any: + def parse_response(response: str) -> list[dict]: """ Parse multiple choice QA pairs from the LLM response. Each QA pair contains question text, four options, and the correct answer. @@ -21,14 +21,14 @@ def parse_response(response: str) -> Any: :return: Dictionary mapping question hash to question data, where each value is a dict with "question", "options", and "answer" keys """ - qa_pairs = {} + qa_pairs = [] # Extract all QA pair blocks qa_blocks = re.findall(r"(.*?)", response, re.DOTALL) if not qa_blocks: logger.warning("No QA pairs found in response: %s", response) - return {} + return qa_pairs for block in qa_blocks: # Extract and clean question text @@ -76,13 +76,13 @@ def parse_response(response: str) -> Any: ) continue - # Build result entry with question hash as key - question_hash = compute_content_hash(question) - qa_pairs[question_hash] = { - "question": question, - "options": options, # Dict like {"A": "text", "B": "text", ...} - "answer": answer, # Single letter: "A", "B", "C", or "D" - } + qa_pairs.append( + { + "question": question, + "options": options, # Dict like {"A": "text", "B": "text", ...} + "answer": answer, # Single letter: "A", "B", "C", or "D" + } + ) logger.debug("Successfully parsed MCQ: %s", question[:50]) diff --git a/graphgen/models/generator/multi_hop_generator.py b/graphgen/models/generator/multi_hop_generator.py index 896592e8..a19082b9 100644 --- a/graphgen/models/generator/multi_hop_generator.py +++ b/graphgen/models/generator/multi_hop_generator.py @@ -3,7 +3,7 @@ from graphgen.bases import BaseGenerator from graphgen.templates import MULTI_HOP_GENERATION_PROMPT -from graphgen.utils import compute_content_hash, detect_main_language, logger +from graphgen.utils import detect_main_language, logger class MultiHopGenerator(BaseGenerator): @@ -32,7 +32,7 @@ def build_prompt( return prompt @staticmethod - def parse_response(response: str) -> dict: + def parse_response(response: str) -> list[dict]: question_match = re.search(r"(.*?)", response, re.DOTALL) answer_match = re.search(r"(.*?)", response, re.DOTALL) @@ -41,15 +41,10 @@ def parse_response(response: str) -> dict: answer = answer_match.group(1).strip() else: logger.warning("Failed to parse response: %s", response) - return {} + return [] question = question.strip('"').strip("'") answer = answer.strip('"').strip("'") logger.debug("Question: %s", question) logger.debug("Answer: %s", answer) - return { - compute_content_hash(question): { - "question": question, - "answer": answer, - } - } + return [{"question": question, "answer": answer}] diff --git a/graphgen/models/generator/quiz_generator.py b/graphgen/models/generator/quiz_generator.py index d117092d..864e4a5d 100644 --- a/graphgen/models/generator/quiz_generator.py +++ b/graphgen/models/generator/quiz_generator.py @@ -31,12 +31,16 @@ def build_prompt( description = edges[0][2].get("description", "") template_type = edges[0][2].get("template_type", "TEMPLATE") else: - raise ValueError("Batch must contain at least one node or edge with description") + raise ValueError( + "Batch must contain at least one node or edge with description" + ) return QuizGenerator.build_prompt_for_description(description, template_type) @staticmethod - def build_prompt_for_description(description: str, template_type: str = "TEMPLATE") -> str: + def build_prompt_for_description( + description: str, template_type: str = "TEMPLATE" + ) -> str: """ Build prompt for rephrasing a single description. :param description: The description to rephrase @@ -49,17 +53,6 @@ def build_prompt_for_description(description: str, template_type: str = "TEMPLAT ) return prompt - @staticmethod - def parse_rephrased_text(response: str) -> str: - """ - Parse the rephrased text from the response. - :param response: - :return: - """ - rephrased_text = response.strip().strip('"') - logger.debug("Rephrased Text: %s", rephrased_text) - return rephrased_text - @staticmethod def parse_response(response: str) -> Any: """ @@ -67,4 +60,15 @@ def parse_response(response: str) -> Any: :param response: LLM response :return: Rephrased text """ - return QuizGenerator.parse_rephrased_text(response) + + def parse_rephrased_text(content: str) -> str: + """ + Parse the rephrased text from the response. + :param content: LLM response content + :return: + """ + rephrased_text = content.strip().strip('"') + logger.debug("Rephrased Text: %s", rephrased_text) + return rephrased_text + + return parse_rephrased_text(response) diff --git a/graphgen/models/generator/true_false_generator.py b/graphgen/models/generator/true_false_generator.py index 0ac67ced..1a1fa0d3 100644 --- a/graphgen/models/generator/true_false_generator.py +++ b/graphgen/models/generator/true_false_generator.py @@ -3,7 +3,7 @@ from graphgen.bases import BaseGenerator from graphgen.templates import TF_GENERATION_PROMPT -from graphgen.utils import compute_content_hash, detect_main_language, logger +from graphgen.utils import detect_main_language, logger class TrueFalseGenerator(BaseGenerator): @@ -12,7 +12,7 @@ def __init__(self, llm_client, num_of_questions) -> None: self.num_of_questions = num_of_questions @staticmethod - def parse_response(response: str) -> Any: + def parse_response(response: str) -> list[dict]: """ Parse true/false QA pairs from the LLM response. Each QA pair contains a statement question and True/False answer. @@ -21,14 +21,14 @@ def parse_response(response: str) -> Any: :return: Dictionary mapping question hash to question data, where each value is a dict with "question", "options", and "answer" keys """ - qa_pairs: dict[str, dict[str, Any]] = {} + qa_pairs: list[dict[str, str]] = [] # Extract all QA pair blocks qa_blocks = re.findall(r"(.*?)", response, re.DOTALL) if not qa_blocks: logger.warning("No QA pairs found in response: %s", response) - return {} + return qa_pairs for block in qa_blocks: # Extract and clean question text @@ -50,12 +50,12 @@ def parse_response(response: str) -> Any: logger.warning("Invalid answer '%s' in block: %s", answer, block) continue - # Build result entry with question hash as key - question_hash = compute_content_hash(question) - qa_pairs[question_hash] = { - "question": question, - "answer": answer, # "True" or "False" - } + qa_pairs.append( + { + "question": question, + "answer": answer, # "True" or "False" + } + ) logger.debug("Successfully parsed TF question: %s", question[:50]) diff --git a/graphgen/models/generator/vqa_generator.py b/graphgen/models/generator/vqa_generator.py index 790eef83..723bd2a6 100644 --- a/graphgen/models/generator/vqa_generator.py +++ b/graphgen/models/generator/vqa_generator.py @@ -1,9 +1,10 @@ +import json import re from typing import Any from graphgen.bases import BaseGenerator from graphgen.templates import VQA_GENERATION_PROMPT -from graphgen.utils import compute_content_hash, detect_main_language, logger +from graphgen.utils import detect_main_language, logger class VQAGenerator(BaseGenerator): @@ -32,13 +33,13 @@ def build_prompt( return prompt @staticmethod - def parse_response(response: str) -> Any: + def parse_response(response: str) -> list[dict]: """ Parse the LLM response and return the generated QAs :param response :return: QA pairs """ - qa_pairs = {} + qa_pairs = [] pattern = r"(.*?)\s*(.*?)" matches = re.findall(pattern, response, re.DOTALL) @@ -48,10 +49,12 @@ def parse_response(response: str) -> Any: answer = answer.strip().strip('"').strip("'") logger.debug("Question: %s", question) logger.debug("Answer: %s", answer) - qa_pairs[compute_content_hash(question)] = { - "question": question, - "answer": answer, - } + qa_pairs.append( + { + "question": question, + "answer": answer, + } + ) else: logger.warning("Error parsing the response %s", response) return qa_pairs @@ -61,76 +64,58 @@ async def generate( batch: tuple[ list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]] ], - ) -> dict[str, Any]: + ) -> list[dict]: """ Generate QAs based on a given batch. :param batch :return: QA pairs """ - result = {} prompt = self.build_prompt(batch) response = await self.llm_client.generate_answer(prompt) qa_pairs = self.parse_response(response) # generate one or more QA pairs nodes, _ = batch for node in nodes: node_data = node[1] - if "image_data" in node_data and node_data["image_data"]: - img_path = node_data["image_data"]["img_path"] - for qa in qa_pairs.values(): + if "metadata" in node_data and node_data["metadata"]: + metadata = json.loads(node_data["metadata"])["metadata"] + img_path = metadata.get("path", "") + for qa in qa_pairs: qa["img_path"] = img_path - result.update(qa_pairs) - return result + return qa_pairs @staticmethod - def format_generation_results( - results: list[dict], output_data_format: str - ) -> list[dict[str, Any]]: + def format_generation_results(result: dict, output_data_format: str) -> dict: + question = result.get("question", "") + answer = result.get("answer", "") + img_path = result.get("img_path", "") if output_data_format == "Alpaca": - results = [ - { - "instruction": v["question"], - "input": "", - "output": v["answer"], - "image": v.get("img_path", ""), - } - for item in results - for k, v in item.items() - ] - elif output_data_format == "Sharegpt": - results = [ - { - "conversations": [ - { - "from": "human", - "value": [ - {"text": v["question"], "image": v.get("img_path", "")} - ], - }, - {"from": "gpt", "value": [{"text": v["answer"]}]}, - ] - } - for item in results - for k, v in item.items() - ] - elif output_data_format == "ChatML": - results = [ - { - "messages": [ - { - "role": "user", - "content": [ - {"text": v["question"], "image": v.get("img_path", "")} - ], - }, - { - "role": "assistant", - "content": [{"type": "text", "text": v["answer"]}], - }, - ] - } - for item in results - for k, v in item.items() - ] - else: - raise ValueError(f"Unknown output data format: {output_data_format}") - return results + return { + "instruction": question, + "input": "", + "output": answer, + "image": img_path, + } + if output_data_format == "Sharegpt": + return { + "conversations": [ + { + "from": "human", + "value": [{"text": question, "image": img_path}], + }, + {"from": "gpt", "value": [{"text": answer}]}, + ] + } + if output_data_format == "ChatML": + return { + "messages": [ + { + "role": "user", + "content": [{"text": question, "image": img_path}], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": answer}], + }, + ] + } + raise ValueError(f"Unknown output data format: {output_data_format}") diff --git a/graphgen/models/kg_builder/light_rag_kg_builder.py b/graphgen/models/kg_builder/light_rag_kg_builder.py index 460dcea0..b23178ce 100644 --- a/graphgen/models/kg_builder/light_rag_kg_builder.py +++ b/graphgen/models/kg_builder/light_rag_kg_builder.py @@ -1,3 +1,4 @@ +import json import re from collections import Counter, defaultdict from typing import Dict, List, Tuple @@ -130,15 +131,25 @@ async def merge_nodes( set([dp["source_id"] for dp in node_data] + source_ids) ) - node_data = { + node_data_dict = { "entity_type": entity_type, "entity_name": entity_name, "description": description, "source_id": source_id, "length": self.tokenizer.count_tokens(description), } - kg_instance.upsert_node(entity_name, node_data=node_data) - return node_data + + if entity_type in ("IMAGE", "TABLE", "FORMULA"): + metadata = next( + (dp["metadata"] for dp in node_data if dp.get("metadata")), None + ) + if metadata: + node_data_dict["metadata"] = json.dumps( + metadata, ensure_ascii=False, default=str + ) + + kg_instance.upsert_node(entity_name, node_data=node_data_dict) + return node_data_dict async def merge_edges( self, diff --git a/graphgen/models/kg_builder/mm_kg_builder.py b/graphgen/models/kg_builder/mm_kg_builder.py index f352cb2a..c406b7ce 100644 --- a/graphgen/models/kg_builder/mm_kg_builder.py +++ b/graphgen/models/kg_builder/mm_kg_builder.py @@ -70,6 +70,8 @@ async def extract( entity = await handle_single_entity_extraction(attributes, chunk_id) if entity is not None: + if entity["entity_type"] == "IMAGE": + entity["metadata"] = chunk.metadata nodes[entity["entity_name"]].append(entity) continue diff --git a/graphgen/models/reader/csv_reader.py b/graphgen/models/reader/csv_reader.py index a0343d97..2f6ba4c7 100644 --- a/graphgen/models/reader/csv_reader.py +++ b/graphgen/models/reader/csv_reader.py @@ -22,7 +22,7 @@ def read(self, input_path: Union[str, List[str]]) -> Dataset: :return: Ray Dataset containing validated and filtered data. """ - ds = ray.data.read_csv(input_path) + ds = ray.data.read_csv(input_path, include_paths=True) ds = ds.map_batches(self._validate_batch, batch_format="pandas") ds = ds.filter(self._should_keep_item) return ds diff --git a/graphgen/models/reader/json_reader.py b/graphgen/models/reader/json_reader.py index 6752e042..b8bb7f76 100644 --- a/graphgen/models/reader/json_reader.py +++ b/graphgen/models/reader/json_reader.py @@ -34,10 +34,13 @@ def read(self, input_path: Union[str, List[str]]) -> ray.data.Dataset: with open(file, "r", encoding="utf-8") as f: data = json.load(f) data = self._unify_schema(data) + # add path + for item in data: + item["path"] = file file_ds: ray.data.Dataset = ray.data.from_items(data) ds = ds.union(file_ds) # type: ignore else: - ds = ray.data.read_json(input_path) + ds = ray.data.read_json(input_path, include_paths=True) ds = ds.map_batches(self._validate_batch, batch_format="pandas") ds = ds.filter(self._should_keep_item) return ds diff --git a/graphgen/models/reader/parquet_reader.py b/graphgen/models/reader/parquet_reader.py index dd289e31..cc283927 100644 --- a/graphgen/models/reader/parquet_reader.py +++ b/graphgen/models/reader/parquet_reader.py @@ -24,7 +24,7 @@ def read(self, input_path: Union[str, List[str]]) -> Dataset: if not ray.is_initialized(): ray.init() - ds = ray.data.read_parquet(input_path) + ds = ray.data.read_parquet(input_path, include_paths=True) ds = ds.map_batches(self._validate_batch, batch_format="pandas") ds = ds.filter(self._should_keep_item) return ds diff --git a/graphgen/models/reader/rdf_reader.py b/graphgen/models/reader/rdf_reader.py index 9670107a..82e7d572 100644 --- a/graphgen/models/reader/rdf_reader.py +++ b/graphgen/models/reader/rdf_reader.py @@ -118,7 +118,7 @@ def _parse_rdf_file(self, file_path: Path) -> List[Dict[str, Any]]: "id": str(subj), self.text_column: text, "properties": props, - "source_file": str(file_path), + "path": str(file_path), } docs.append(doc) diff --git a/graphgen/models/reader/txt_reader.py b/graphgen/models/reader/txt_reader.py index 51a47de2..784dbe96 100644 --- a/graphgen/models/reader/txt_reader.py +++ b/graphgen/models/reader/txt_reader.py @@ -18,13 +18,14 @@ def read( """ docs_ds = ray.data.read_binary_files( input_path, - include_paths=False, + include_paths=True, ) docs_ds = docs_ds.map( lambda row: { "type": "text", self.text_column: row["bytes"].decode("utf-8"), + "path": row["path"], } ) diff --git a/graphgen/models/storage/__init__.py b/graphgen/models/storage/__init__.py deleted file mode 100644 index 889a074c..00000000 --- a/graphgen/models/storage/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from graphgen.models.storage.graph.kuzu_storage import KuzuStorage -from graphgen.models.storage.graph.networkx_storage import NetworkXStorage -from graphgen.models.storage.kv.json_storage import JsonKVStorage -from graphgen.models.storage.kv.rocksdb_storage import RocksDBKVStorage - -from .rocksdb_cache import RocksDBCache diff --git a/graphgen/models/storage/rocksdb_cache.py b/graphgen/models/storage/rocksdb_cache.py deleted file mode 100644 index 2345b5b5..00000000 --- a/graphgen/models/storage/rocksdb_cache.py +++ /dev/null @@ -1,43 +0,0 @@ -from pathlib import Path -from typing import Any, Iterator, Optional - -# rocksdict is a lightweight C wrapper around RocksDB for Python, pylint may not recognize it -# pylint: disable=no-name-in-module -from rocksdict import Rdict - - -class RocksDBCache: - def __init__(self, cache_dir: str): - self.db_path = Path(cache_dir) - self.db = Rdict(str(self.db_path)) - - def get(self, key: str) -> Optional[Any]: - return self.db.get(key) - - def set(self, key: str, value: Any): - self.db[key] = value - - def delete(self, key: str): - try: - del self.db[key] - except KeyError: - # If the key does not exist, do nothing (deletion is idempotent for caches) - pass - - def close(self): - if hasattr(self, "db") and self.db is not None: - self.db.close() - self.db = None - - def __del__(self): - # Ensure the database is closed when the object is destroyed - self.close() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - - def __iter__(self) -> Iterator[str]: - return iter(self.db.keys()) diff --git a/graphgen/models/vis/__init__.py b/graphgen/models/vis/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/graphgen/models/vis/community_visualizer.py b/graphgen/models/vis/community_visualizer.py deleted file mode 100644 index 05551014..00000000 --- a/graphgen/models/vis/community_visualizer.py +++ /dev/null @@ -1,48 +0,0 @@ -from dataclasses import dataclass -from typing import Dict - -import matplotlib.pyplot as plt -import networkx as nx - - -@dataclass -class Visualizer: - """ - Class for visualizing graphs using NetworkX and Matplotlib. - """ - - graph: nx.Graph = None - communities: Dict[str, int] = None - layout: str = "spring" - max_nodes: int = 1000 - node_size: int = 10 - alpha: float = 0.6 - - def visualize(self, save_path: str = None): - n = self.graph.number_of_nodes() - if self.layout == "spring": - k = max(0.1, 1.0 / (n**0.5)) - pos = nx.spring_layout(self.graph, k=k, seed=42) - else: - raise ValueError(f"Unknown layout: {self.layout}") - - plt.figure(figsize=(10, 10)) - - node_colors = [self.communities.get(node, 0) for node in self.graph.nodes()] - - nx.draw_networkx_nodes( - self.graph, - pos, - node_size=self.node_size, - node_color=node_colors, - cmap=plt.cm.tab20, - alpha=self.alpha, - ) - nx.draw_networkx_edges(self.graph, pos, alpha=0.3, width=0.2) - plt.axis("off") - - if save_path: - plt.savefig(save_path, dpi=300, bbox_inches="tight") - print("Saved to", save_path) - else: - plt.show() diff --git a/graphgen/operators/build_kg/build_kg_service.py b/graphgen/operators/build_kg/build_kg_service.py index b4155dde..b5b0696b 100644 --- a/graphgen/operators/build_kg/build_kg_service.py +++ b/graphgen/operators/build_kg/build_kg_service.py @@ -1,6 +1,4 @@ -from typing import List - -import pandas as pd +from typing import Tuple from graphgen.bases import BaseGraphStorage, BaseLLMWrapper, BaseOperator from graphgen.bases.datatypes import Chunk @@ -13,9 +11,15 @@ class BuildKGService(BaseOperator): def __init__( - self, working_dir: str = "cache", graph_backend: str = "kuzu", **build_kwargs + self, + working_dir: str = "cache", + kv_backend: str = "rocksdb", + graph_backend: str = "kuzu", + **build_kwargs ): - super().__init__(working_dir=working_dir, op_name="build_kg_service") + super().__init__( + working_dir=working_dir, kv_backend=kv_backend, op_name="build_kg" + ) self.llm_client: BaseLLMWrapper = init_llm("synthesizer") self.graph_storage: BaseGraphStorage = init_storage( backend=graph_backend, working_dir=working_dir, namespace="graph" @@ -23,21 +27,15 @@ def __init__( self.build_kwargs = build_kwargs self.max_loop: int = int(self.build_kwargs.get("max_loop", 3)) - def process(self, batch: pd.DataFrame) -> pd.DataFrame: - docs = batch.to_dict(orient="records") - docs = [Chunk.from_dict(doc["_chunk_id"], doc) for doc in docs] - - # consume the chunks and build kg - nodes, edges = self.build_kg(docs) - return pd.DataFrame( - [{"node": node, "edge": []} for node in nodes] - + [{"node": [], "edge": edge} for edge in edges] - ) - - def build_kg(self, chunks: List[Chunk]) -> tuple: + def process(self, batch: list) -> Tuple[list, dict]: """ Build knowledge graph (KG) and merge into kg_instance + :return: A tuple of (results, meta_updates) + results: A list of dicts containing nodes and edges added to the KG. Each dict has the structure: + {"_trace_id": str, "node": dict, "edge": dict} + meta_updates: A dict mapping source IDs to lists of trace IDs for nodes and edges added. """ + chunks = [Chunk.from_dict(doc["_trace_id"], doc) for doc in batch] text_chunks = [chunk for chunk in chunks if chunk.type == "text"] mm_chunks = [ chunk @@ -75,4 +73,34 @@ def build_kg(self, chunks: List[Chunk]) -> tuple: self.graph_storage.index_done_callback() logger.info("Knowledge graph building completed.") - return nodes, edges + meta_updates = {} + results = [] + for node in nodes: + if not node: + continue + trace_id = node["entity_name"] + results.append( + { + "_trace_id": trace_id, + "node": node, + "edge": {}, + } + ) + source_ids = node.get("source_id", "").split("") + for source_id in source_ids: + meta_updates.setdefault(source_id, []).append(trace_id) + for edge in edges: + if not edge: + continue + trace_id = frozenset((edge["src_id"], edge["tgt_id"])) + results.append( + { + "_trace_id": str(trace_id), + "node": {}, + "edge": edge, + } + ) + source_ids = edge.get("source_id", "").split("") + for source_id in source_ids: + meta_updates.setdefault(source_id, []).append(str(trace_id)) + return results, meta_updates diff --git a/graphgen/operators/build_kg/build_text_kg.py b/graphgen/operators/build_kg/build_text_kg.py index 2a7e1b03..f0954dd4 100644 --- a/graphgen/operators/build_kg/build_text_kg.py +++ b/graphgen/operators/build_kg/build_text_kg.py @@ -30,6 +30,7 @@ def build_text_kg( desc="[2/4]Extracting entities and relationships from chunks", unit="chunk", ) + results = [res for res in results if res] nodes = defaultdict(list) edges = defaultdict(list) diff --git a/graphgen/operators/chunk/chunk_service.py b/graphgen/operators/chunk/chunk_service.py index 102c74fd..68d67914 100644 --- a/graphgen/operators/chunk/chunk_service.py +++ b/graphgen/operators/chunk/chunk_service.py @@ -1,17 +1,14 @@ import os from functools import lru_cache -from typing import Union - -import pandas as pd +from typing import Tuple, Union from graphgen.bases import BaseOperator -from graphgen.common import init_storage from graphgen.models import ( ChineseRecursiveTextSplitter, RecursiveCharacterSplitter, Tokenizer, ) -from graphgen.utils import compute_content_hash, detect_main_language +from graphgen.utils import detect_main_language _MAPPING = { "en": RecursiveCharacterSplitter, @@ -45,26 +42,25 @@ class ChunkService(BaseOperator): def __init__( self, working_dir: str = "cache", kv_backend: str = "rocksdb", **chunk_kwargs ): - super().__init__(working_dir=working_dir, op_name="chunk_service") + super().__init__( + working_dir=working_dir, kv_backend=kv_backend, op_name="chunk" + ) tokenizer_model = os.getenv("TOKENIZER_MODEL", "cl100k_base") self.tokenizer_instance: Tokenizer = Tokenizer(model_name=tokenizer_model) - self.chunk_storage = init_storage( - backend=kv_backend, - working_dir=working_dir, - namespace="chunk", - ) self.chunk_kwargs = chunk_kwargs - def process(self, batch: pd.DataFrame) -> pd.DataFrame: - docs = batch.to_dict(orient="records") - return pd.DataFrame(self.chunk_documents(docs)) - - def chunk_documents(self, new_docs: list) -> list: - chunks = [] - for doc in new_docs: - doc_id = doc.get("_doc_id") + def process(self, batch: list) -> Tuple[list, dict]: + """ + Chunk the documents in the batch. + :return: A tuple of (results, meta_updates) + results: A list of chunked documents. Each chunked document is a dict with the structure: + {"_trace_id": str, "content": str, "type": str, "metadata": {"length": int, "language": str, ...} + meta_updates: A dict mapping source document IDs to lists of trace IDs for the chunked documents. + """ + results = [] + meta_updates = {} + for doc in batch: doc_type = doc.get("type") - if doc_type == "text": doc_language = detect_main_language(doc["content"]) text_chunks = split_chunks( @@ -72,32 +68,30 @@ def chunk_documents(self, new_docs: list) -> list: language=doc_language, **self.chunk_kwargs, ) - - chunks.extend( - [ - { - "_chunk_id": compute_content_hash( - chunk_text, prefix="chunk-" - ), - "content": chunk_text, - "type": "text", - "_doc_id": doc_id, - "length": len(self.tokenizer_instance.encode(chunk_text)) + for text_chunk in text_chunks: + chunk = { + "content": text_chunk, + "type": "text", + "metadata": { + "length": len(self.tokenizer_instance.encode(text_chunk)) if self.tokenizer_instance - else len(chunk_text), + else len(text_chunk), "language": doc_language, - } - for chunk_text in text_chunks - ] - ) + }, + } + chunk["_trace_id"] = self.get_trace_id(chunk) + results.append(chunk) + meta_updates.setdefault(doc["_trace_id"], []).append( + chunk["_trace_id"] + ) else: # other types of documents(images, sequences) are not chunked - chunks.append( - { - "_chunk_id": doc_id.replace("doc-", f"{doc_type}-"), - **doc, - } - ) - self.chunk_storage.upsert({chunk["_chunk_id"]: chunk for chunk in chunks}) - self.chunk_storage.index_done_callback() - return chunks + data = doc.copy() + input_trace_id = data.pop("_trace_id") + content = data.pop("content") if "content" in data else "" + doc_type = data.pop("type") + chunk = {"content": content, "type": doc_type, "metadata": data} + chunk["_trace_id"] = self.get_trace_id(chunk) + results.append(chunk) + meta_updates.setdefault(input_trace_id, []).append(chunk["_trace_id"]) + return results, meta_updates diff --git a/graphgen/operators/evaluate/evaluate_kg.py b/graphgen/operators/evaluate/evaluate_kg.py new file mode 100644 index 00000000..5d774431 --- /dev/null +++ b/graphgen/operators/evaluate/evaluate_kg.py @@ -0,0 +1,15 @@ +from typing import Any, Dict + +from graphgen.bases import BaseGraphStorage +from graphgen.utils import logger + + +def evaluate_kg( + kg_evaluators: Dict[str, Any], + kg_instance: BaseGraphStorage, +) -> Dict[str, Any]: + results = {} + for key, kg_evaluator in kg_evaluators.items(): + results[key] = kg_evaluator.evaluate(kg_instance) + logger.info(f"KG Evaluation result for {key}: {results[key]}") + return results diff --git a/graphgen/operators/evaluate/evaluate_qa.py b/graphgen/operators/evaluate/evaluate_qa.py new file mode 100644 index 00000000..9b736443 --- /dev/null +++ b/graphgen/operators/evaluate/evaluate_qa.py @@ -0,0 +1,107 @@ +from typing import Any + +from graphgen.bases import QAPair +from graphgen.utils import run_concurrent + + +def transform_to_qa_format( + items: list[dict], format_hint: str = "auto" +) -> list[dict[str, str]]: + extractors = { + "ChatML": lambda x: ( + next( + ( + m["content"] + for m in x.get("messages", []) + if m.get("role") == "user" + ), + "", + ), + next( + ( + m["content"] + for m in x.get("messages", []) + if m.get("role") == "assistant" + ), + "", + ), + ), + "Alpaca": lambda x: ( + f"{x.get('instruction', '')}\n\n{x['input']}".strip() + if x.get("input") + else x.get("instruction", ""), + x.get("output", ""), + ), + "Sharegpt": lambda x: ( + next( + ( + c["value"] + for c in x.get("conversations", []) + if c.get("from") == "human" + ), + "", + ), + next( + ( + c["value"] + for c in x.get("conversations", []) + if c.get("from") in ("gpt", "assistant") + ), + "", + ), + ), + } + + auto_detect = { + "messages": "ChatML", + "conversations": "Sharegpt", + "instruction": "Alpaca", + } + + transformed = [] + for item in items: + fmt = format_hint + if fmt == "auto": + fmt = next( + (fmt_name for key, fmt_name in auto_detect.items() if key in item), None + ) + if not fmt: + raise ValueError( + "Could not auto-detect format. Please specify format_hint." + ) + + question, answer = extractors[fmt](item) + options = None + if "\nOptions:\n" in question: + q_part, opt_part = question.split("\nOptions:\n", 1) + question = q_part + options = { + k.strip(): v.strip() + for line in opt_part.strip().split("\n") + if "." in line + for k, v in [line.split(".", 1)] + } + + result = {"question": question.strip(), "answer": answer.strip()} + if options: + result["options"] = options + transformed.append(result) + + return transformed + + +def evaluate_qa( + qa_evaluators: dict[str, Any], items: list[dict[str, Any]] +) -> dict[str, Any]: + items = transform_to_qa_format(items) + items = [QAPair.from_dict(item) for item in items] + + results = {} + for key, qa_evaluator in qa_evaluators.items(): + result = run_concurrent( + qa_evaluator.evaluate, + items, + desc=f"Evaluating QA with {key}", + ) + results[key] = result + return results diff --git a/graphgen/operators/evaluate/evaluate_service.py b/graphgen/operators/evaluate/evaluate_service.py index 80586c1d..7f69ef75 100644 --- a/graphgen/operators/evaluate/evaluate_service.py +++ b/graphgen/operators/evaluate/evaluate_service.py @@ -1,10 +1,12 @@ -from typing import Any, Dict +from typing import Tuple -import pandas as pd - -from graphgen.bases import BaseLLMWrapper, BaseOperator, QAPair +from graphgen.bases import BaseLLMWrapper, BaseOperator from graphgen.common import init_llm, init_storage -from graphgen.utils import logger, run_concurrent +from graphgen.utils import logger + +from .evaluate_kg import evaluate_kg +from .evaluate_qa import evaluate_qa +from .evaluate_triple import evaluate_triple class EvaluateService(BaseOperator): @@ -15,167 +17,135 @@ class EvaluateService(BaseOperator): def __init__( self, + target: str, + metrics: list[str], working_dir: str = "cache", - metrics: list[str] = None, graph_backend: str = "kuzu", kv_backend: str = "rocksdb", **kwargs, ): - super().__init__(working_dir=working_dir, op_name="evaluate_service") + super().__init__( + working_dir=working_dir, kv_backend=kv_backend, op_name="evaluate" + ) self.llm_client: BaseLLMWrapper = init_llm("synthesizer") self.metrics = metrics or [] self.kwargs = kwargs self.graph_storage = init_storage( backend=graph_backend, working_dir=working_dir, namespace="graph" ) - self.chunk_storage = init_storage( - backend=kv_backend, working_dir=working_dir, namespace="chunk" - ) # Initialize evaluators - self.qa_evaluators = {} - self.kg_evaluators = {} - self._init_evaluators() - - def _init_evaluators(self): - """Initialize QA and KG evaluators based on metrics.""" - for metric in self.metrics: - if metric == "qa_length": - from graphgen.models import LengthEvaluator - - self.qa_evaluators[metric] = LengthEvaluator() - elif metric == "qa_mtld": - from graphgen.models import MTLDEvaluator - - self.qa_evaluators[metric] = MTLDEvaluator( - **self.kwargs.get("mtld_params", {}) - ) - elif metric == "qa_reward_score": - from graphgen.models import RewardEvaluator - - self.qa_evaluators[metric] = RewardEvaluator( - **self.kwargs.get("reward_params", {}) - ) - elif metric == "qa_uni_score": - from graphgen.models import UniEvaluator - - self.qa_evaluators[metric] = UniEvaluator( - **self.kwargs.get("uni_params", {}) - ) - elif metric == "kg_accuracy": - from graphgen.models import AccuracyEvaluator - - self.kg_evaluators[metric] = AccuracyEvaluator( - graph_storage=self.graph_storage, - chunk_storage=self.chunk_storage, - llm_client=self.llm_client, - ) - elif metric == "kg_consistency": - from graphgen.models import ConsistencyEvaluator - - self.kg_evaluators[metric] = ConsistencyEvaluator( - graph_storage=self.graph_storage, - chunk_storage=self.chunk_storage, - llm_client=self.llm_client, - ) - elif metric == "kg_structure": - from graphgen.models import StructureEvaluator - - self.kg_evaluators[metric] = StructureEvaluator( - graph_storage=self.graph_storage, - **self.kwargs.get("structure_params", {}), - ) - else: - raise ValueError(f"Unknown QA metric: {metric}") - - async def _process_single_qa(self, item: dict[str, Any]) -> dict[str, Any]: - try: - qa_pair = QAPair( - question=str(item.get("question", "")), - answer=str(item.get("answer", "")), + self.target = target + self.src_storage = None + self.tgt_storage = None + self.evaluators = {} + self._init_evaluators(self.target, metrics) + + def _init_evaluators(self, target: str, metrics: list[str]): + """Initialize evaluators based on target and metrics.""" + if target not in {"qa", "kg", "triple"}: + raise ValueError(f"Unknown evaluation target: {target}") + + # Delegate to target-specific initializer + getattr(self, f"_init_{target}_evaluators")(metrics) + + def _init_qa_evaluators(self, metrics: list[str]): + """Initialize QA evaluators.""" + for metric in metrics: + self.evaluators[metric] = self._create_qa_evaluator(metric) + + def _create_qa_evaluator(self, metric: str): + """Factory method for QA evaluator instances.""" + if metric == "length": + from graphgen.models import LengthEvaluator + + return LengthEvaluator() + if metric == "mtld": + from graphgen.models import MTLDEvaluator + + return MTLDEvaluator(**self.kwargs.get("mtld_params", {})) + if metric == "reward_score": + from graphgen.models import RewardEvaluator + + return RewardEvaluator(**self.kwargs.get("reward_params", {})) + if metric == "uni_score": + from graphgen.models import UniEvaluator + + return UniEvaluator(**self.kwargs.get("uni_params", {})) + raise ValueError(f"Unknown QA metric: {metric}") + + def _init_kg_evaluators(self, metrics: list[str]): + """Initialize KG evaluators.""" + for metric in metrics: + if metric != "structure": + raise ValueError(f"Unknown KG metric: {metric}") + from graphgen.models import StructureEvaluator + + self.evaluators[metric] = StructureEvaluator( + **self.kwargs.get("structure_params", {}) ) - if not qa_pair.question or not qa_pair.answer: - logger.error("Empty question or answer, skipping.") - return {} - except Exception as e: - logger.error("Error in QAPair creation: %s", str(e)) - return {} - - for metric, evaluator in self.qa_evaluators.items(): - try: - score = evaluator.evaluate(qa_pair) - if isinstance(score, dict): - for sub_metric, sub_score in score.items(): - item[f"{metric}_{sub_metric}"] = float(sub_score) - else: - item[metric] = float(score) - except Exception as e: - logger.error("Error in %s evaluation: %s", metric, str(e)) - item[metric] = None - return item - - def _evaluate_qa(self, items: list[dict[str, Any]]) -> list[dict[str, Any]]: - def transform_messages_format(items: list[dict]) -> list[dict]: - """ - Transform from [{'messages': [...]}, ...] to [{'question': '...', 'answer': '...'}, ...] - """ - transformed = [] - for item in items: - messages = item.get("messages", []) - question = next( - (m["content"] for m in messages if m.get("role") == "user"), "" - ) - answer = next( - (m["content"] for m in messages if m.get("role") == "assistant"), "" - ) - - transformed.append({"question": question, "answer": answer}) - return transformed - - if not items: - return [] - - if not self.qa_evaluators: - logger.warning("No QA evaluators initialized, skipping QA evaluation") - return [] - - items = transform_messages_format(items) - results = run_concurrent( - self._process_single_qa, - items, - desc="Evaluating QA items", - unit="item", + + def _init_triple_evaluators(self, metrics: list[str]): + """Initialize Triple evaluators.""" + self.src_storage = init_storage( + backend=self.kv_backend, + working_dir=self.working_dir, + namespace=self.kwargs["src_namespace"], + ) + self.tgt_storage = init_storage( + backend=self.kv_backend, + working_dir=self.working_dir, + namespace=self.kwargs["tgt_namespace"], ) - results = [item for item in results if item] - return results - - def _evaluate_kg(self) -> Dict[str, Any]: - results = {} - - for metric, evaluator in self.kg_evaluators.items(): - try: - logger.info("Running %s evaluation...", metric) - score = evaluator.evaluate() - results[metric] = score - except Exception as e: - logger.error("Error in %s evaluation: %s", metric, str(e)) - results[metric] = {"error": str(e)} - return results - - def process(self, batch: pd.DataFrame) -> pd.DataFrame: - # QA evaluation - if len(self.qa_evaluators) > 0: - items = batch.to_dict(orient="records") - results = self._evaluate_qa(items) - return pd.DataFrame(results) - - # KG evaluation - if len(self.kg_evaluators) > 0: - results = self._evaluate_kg() - # Convert dict to DataFrame (single row) - return pd.DataFrame([results]) + for metric in metrics: + if metric != "accuracy": + raise ValueError(f"Unknown Triple metric: {metric}") + from graphgen.models import AccuracyEvaluator + + self.evaluators[metric] = AccuracyEvaluator(llm_client=self.llm_client) + + def process(self, batch: list) -> Tuple[list, dict]: + final_results = [] + meta_updates = {} + + # 1. QA Evaluation (per item) + if self.target == "qa" and self.evaluators: + results: dict = evaluate_qa(self.evaluators, batch) + for i, item in enumerate(batch): + metrics = {} + for _, scores in results.items(): + metrics.update(scores[i]) + item.update({"metrics": metrics}) + input_trace_id = item.pop("_trace_id") + item["_trace_id"] = self.get_trace_id(item) + final_results.append(item) + meta_updates.setdefault(input_trace_id, []).append(item["_trace_id"]) + + return final_results, meta_updates + + # 2. KG evaluation + if self.target == "kg" and self.evaluators: + results = evaluate_kg( + self.evaluators, + self.graph_storage, + ) + if not results: + logger.warning("No KG evaluation results, returning empty DataFrame") + return [], {} + results["_trace_id"] = self.get_trace_id(results) + final_results.append(results) + return final_results, {} + + # 3. Triple evaluation + if self.target == "triple" and self.evaluators: + results = evaluate_triple( + self.evaluators, self.src_storage, self.tgt_storage + ) + results["_trace_id"] = "evaluate-triple-result" + final_results.append(results) + return final_results, {} # No metrics specified logger.warning("No metrics specified, returning empty DataFrame") - return pd.DataFrame() + return [], {} diff --git a/graphgen/operators/evaluate/evaluate_triple.py b/graphgen/operators/evaluate/evaluate_triple.py new file mode 100644 index 00000000..18f25327 --- /dev/null +++ b/graphgen/operators/evaluate/evaluate_triple.py @@ -0,0 +1,39 @@ +from typing import Any + +from graphgen.bases import BaseKVStorage +from graphgen.utils import logger, run_concurrent + + +def evaluate_triple( + triple_evaluators: dict[str, Any], + src_storage: BaseKVStorage, + tgt_storage: BaseKVStorage, +) -> dict[str, Any]: + forward_meta = tgt_storage.get_by_id("_meta_forward") + + tasks = [] + for chunk_id, unit_ids in forward_meta.items(): + chunk_content = str(src_storage.get_by_id(chunk_id)) + + nodes = [] + edges = [] + + for unit_id in unit_ids: + unit_data = tgt_storage.get_by_id(unit_id) + if "node" in unit_data and unit_data["node"]: + nodes.append(unit_data["node"]) + if "edge" in unit_data and unit_data["edge"]: + edges.append(unit_data["edge"]) + + tasks.append((chunk_content, nodes, edges)) + + results = {} + for key, triple_evaluator in triple_evaluators.items(): + logger.info(f"Evaluating Triples with metric: {key}...") + result = run_concurrent( + triple_evaluator.evaluate, + tasks, + desc=f"Evaluating Triples with {key}", + ) + results[key] = result + return results diff --git a/graphgen/operators/extract/extract_service.py b/graphgen/operators/extract/extract_service.py index 33987fcb..a0da1235 100644 --- a/graphgen/operators/extract/extract_service.py +++ b/graphgen/operators/extract/extract_service.py @@ -1,16 +1,19 @@ import json +from typing import Tuple -import pandas as pd - -from graphgen.bases import BaseLLMWrapper, BaseOperator +from graphgen.bases import BaseLLMWrapper, BaseOperator, Chunk from graphgen.common import init_llm from graphgen.models.extractor import SchemaGuidedExtractor from graphgen.utils import logger, run_concurrent class ExtractService(BaseOperator): - def __init__(self, working_dir: str = "cache", **extract_kwargs): - super().__init__(working_dir=working_dir, op_name="extract_service") + def __init__( + self, working_dir: str = "cache", kv_backend: str = "rocksdb", **extract_kwargs + ): + super().__init__( + working_dir=working_dir, kv_backend=kv_backend, op_name="extract" + ) self.llm_client: BaseLLMWrapper = init_llm("synthesizer") self.extract_kwargs = extract_kwargs self.method = self.extract_kwargs.get("method") @@ -22,24 +25,32 @@ def __init__(self, working_dir: str = "cache", **extract_kwargs): else: raise ValueError(f"Unsupported extraction method: {self.method}") - def process(self, batch: pd.DataFrame) -> pd.DataFrame: - items = batch.to_dict(orient="records") - return pd.DataFrame(self.extract(items)) - - def extract(self, items: list[dict]) -> list[dict]: - - logger.info("Start extracting information from %d items", len(items)) - + def process(self, batch: list) -> Tuple[list, dict]: + """ + Extract information from the batch of chunks. + :return: A tuple of (results, meta_updates) + results: A list of dicts containing extracted information. Each dict has the structure: + {"_trace_id": str, "content": dict} + meta_updates: A dict mapping source IDs to lists of trace IDs for the extracted information. + """ + logger.info("Start extracting information from %d items", len(batch)) + chunks = [Chunk.from_dict(item["_trace_id"], item) for item in batch] results = run_concurrent( self.extractor.extract, - items, + chunks, desc="Extracting information", unit="item", ) - results = self.extractor.merge_extractions(results) - results = [ - {"_extract_id": key, "extracted_data": value} - for key, value in results.items() - ] - return results + meta_updates = {} + final_results = [] + # chunk -> extracted info + for input_trace_id, result in zip( + [item["_trace_id"] for item in batch], results + ): + if not result: + continue + result = {"_trace_id": self.get_trace_id(result), "content": result} + meta_updates.setdefault(input_trace_id, []).append(result["_trace_id"]) + final_results.append(result) + return final_results, meta_updates diff --git a/graphgen/operators/generate/generate_service.py b/graphgen/operators/generate/generate_service.py index 2107876b..0fed097b 100644 --- a/graphgen/operators/generate/generate_service.py +++ b/graphgen/operators/generate/generate_service.py @@ -1,9 +1,6 @@ -import json - -import pandas as pd - -from graphgen.bases import BaseLLMWrapper, BaseOperator -from graphgen.common import init_llm +from typing import Tuple +from graphgen.bases import BaseKVStorage, BaseLLMWrapper, BaseOperator +from graphgen.common import init_llm, init_storage from graphgen.utils import logger, run_concurrent @@ -15,12 +12,18 @@ class GenerateService(BaseOperator): def __init__( self, working_dir: str = "cache", + kv_backend: str = "rocksdb", method: str = "aggregated", data_format: str = "ChatML", **generate_kwargs, ): - super().__init__(working_dir=working_dir, op_name="generate_service") + super().__init__( + working_dir=working_dir, kv_backend=kv_backend, op_name="generate" + ) self.llm_client: BaseLLMWrapper = init_llm("synthesizer") + self.generate_storage: BaseKVStorage = init_storage( + backend=kv_backend, working_dir=working_dir, namespace="generate" + ) self.method = method self.data_format = data_format @@ -76,32 +79,31 @@ def __init__( else: raise ValueError(f"Unsupported generation mode: {method}") - def process(self, batch: pd.DataFrame) -> pd.DataFrame: - items = batch.to_dict(orient="records") - return pd.DataFrame(self.generate(items)) - - def generate(self, items: list[dict]) -> list[dict]: + def process(self, batch: list) -> Tuple[list, dict]: """ Generate question-answer pairs based on nodes and edges. - :param items - :return: QA pairs """ - logger.info("[Generation] mode: %s, batches: %d", self.method, len(items)) - items = [ - (json.loads(item["nodes"]), json.loads(item["edges"])) for item in items - ] + logger.info("[Generation] mode: %s, batches: %d", self.method, len(batch)) + triples = [(item["nodes"], item["edges"]) for item in batch] results = run_concurrent( self.generator.generate, - items, - desc="[4/4]Generating QAs", + triples, + desc="Generating QAs", unit="batch", ) - # Filter out empty results - results = [res for res in results if res] - - results = self.generator.format_generation_results( - results, output_data_format=self.data_format - ) - - return results + meta_updates = {} + final_results = [] + for input_trace_id, qa_pairs in zip( + [item["_trace_id"] for item in batch], results + ): + if not qa_pairs: + continue + for qa_pair in qa_pairs: + res = self.generator.format_generation_results( + qa_pair, output_data_format=self.data_format + ) + res["_trace_id"] = self.get_trace_id(res) + final_results.append(res) + meta_updates.setdefault(input_trace_id, []).append(res["_trace_id"]) + return final_results, meta_updates diff --git a/graphgen/operators/judge/judge_service.py b/graphgen/operators/judge/judge_service.py index c7693aec..c4ccd994 100644 --- a/graphgen/operators/judge/judge_service.py +++ b/graphgen/operators/judge/judge_service.py @@ -1,7 +1,6 @@ +from typing import Tuple import math -import pandas as pd - from graphgen.bases import BaseGraphStorage, BaseLLMWrapper, BaseOperator from graphgen.common import init_llm, init_storage from graphgen.templates import STATEMENT_JUDGEMENT_PROMPT @@ -11,8 +10,15 @@ class JudgeService(BaseOperator): """Service for judging graph edges and nodes using a trainee LLM.""" - def __init__(self, working_dir: str = "cache", graph_backend: str = "kuzu"): - super().__init__(working_dir=working_dir, op_name="judge_service") + def __init__( + self, + working_dir: str = "cache", + kv_backend: str = "rocksdb", + graph_backend: str = "kuzu", + ): + super().__init__( + working_dir=working_dir, kv_backend=kv_backend, op_name="judge" + ) self.llm_client: BaseLLMWrapper = init_llm("trainee") self.graph_storage: BaseGraphStorage = init_storage( backend=graph_backend, @@ -20,12 +26,6 @@ def __init__(self, working_dir: str = "cache", graph_backend: str = "kuzu"): namespace="graph", ) - def process(self, batch: pd.DataFrame) -> pd.DataFrame: - items = batch.to_dict(orient="records") - self.graph_storage.reload() - self.judge(items) - return pd.DataFrame([{"status": "judging_completed"}]) - async def _process_single_judge(self, item: dict) -> dict: description = item["description"] try: @@ -43,20 +43,29 @@ async def _process_single_judge(self, item: dict) -> dict: item["loss"] = -math.log(0.1) return item - def judge(self, items: list[dict]) -> None: + def process(self, batch: list) -> Tuple[list, dict]: """ Judge the description in the item and compute the loss. """ + self.graph_storage.reload() + results = run_concurrent( self._process_single_judge, - items, + batch, desc="Judging descriptions", unit="description", ) - # Update the graph storage with the computed losses - for item in results: - index = item["index"] - loss = item["loss"] + + to_store = [] + meta_update = {} + + for input_trace_id, result in zip( + [item["_trace_id"] for item in batch], results + ): + if not result: + continue + index = result["index"] + loss = result["loss"] if isinstance(index, str): node_id = index node_data = self.graph_storage.get_node(node_id) @@ -67,4 +76,10 @@ def judge(self, items: list[dict]) -> None: edge_data = self.graph_storage.get_edge(edge_source, edge_target) edge_data["loss"] = loss self.graph_storage.update_edge(edge_source, edge_target, edge_data) + + result["_trace_id"] = self.get_trace_id(result) + to_store.append(result) + meta_update.setdefault(input_trace_id, []).append(result["_trace_id"]) self.graph_storage.index_done_callback() + + return results, meta_update diff --git a/graphgen/operators/partition/partition_service.py b/graphgen/operators/partition/partition_service.py index 6622e411..5b6fae3d 100644 --- a/graphgen/operators/partition/partition_service.py +++ b/graphgen/operators/partition/partition_service.py @@ -1,10 +1,7 @@ -import json import os -from typing import Iterable +from typing import Iterable, Tuple -import pandas as pd - -from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseOperator, BaseTokenizer +from graphgen.bases import BaseGraphStorage, BaseOperator, BaseTokenizer from graphgen.common import init_storage from graphgen.models import ( AnchorBFSPartitioner, @@ -21,112 +18,64 @@ class PartitionService(BaseOperator): def __init__( self, working_dir: str = "cache", - graph_backend: str = "kuzu", kv_backend: str = "rocksdb", + graph_backend: str = "kuzu", **partition_kwargs, ): - super().__init__(working_dir=working_dir, op_name="partition_service") + super().__init__( + working_dir=working_dir, kv_backend=kv_backend, op_name="partition" + ) self.kg_instance: BaseGraphStorage = init_storage( backend=graph_backend, working_dir=working_dir, namespace="graph", ) - self.chunk_storage: BaseKVStorage = init_storage( - backend=kv_backend, - working_dir=working_dir, - namespace="chunk", - ) tokenizer_model = os.getenv("TOKENIZER_MODEL", "cl100k_base") self.tokenizer_instance: BaseTokenizer = Tokenizer(model_name=tokenizer_model) - self.partition_kwargs = partition_kwargs - - def process(self, batch: pd.DataFrame) -> Iterable[pd.DataFrame]: - # this operator does not consume any batch data - # but for compatibility we keep the interface - _ = batch.to_dict(orient="records") - self.kg_instance.reload() - self.chunk_storage.reload() + method = partition_kwargs["method"] + self.method_params = partition_kwargs["method_params"] - yield from self.partition() - - def partition(self) -> Iterable[pd.DataFrame]: - method = self.partition_kwargs["method"] - method_params = self.partition_kwargs["method_params"] if method == "bfs": - logger.info("Partitioning knowledge graph using BFS method.") - partitioner = BFSPartitioner() + self.partitioner = BFSPartitioner() elif method == "dfs": - logger.info("Partitioning knowledge graph using DFS method.") - partitioner = DFSPartitioner() + self.partitioner = DFSPartitioner() elif method == "ece": - logger.info("Partitioning knowledge graph using ECE method.") # before ECE partitioning, we need to: # 'quiz' and 'judge' to get the comprehension loss if unit_sampling is not random - partitioner = ECEPartitioner() + self.partitioner = ECEPartitioner() elif method == "leiden": - logger.info("Partitioning knowledge graph using Leiden method.") - partitioner = LeidenPartitioner() + self.partitioner = LeidenPartitioner() elif method == "anchor_bfs": - logger.info("Partitioning knowledge graph using Anchor BFS method.") - partitioner = AnchorBFSPartitioner( - anchor_type=method_params.get("anchor_type"), - anchor_ids=set(method_params.get("anchor_ids", [])) - if method_params.get("anchor_ids") + self.partitioner = AnchorBFSPartitioner( + anchor_type=self.method_params.get("anchor_type"), + anchor_ids=set(self.method_params.get("anchor_ids", [])) + if self.method_params.get("anchor_ids") else None, ) else: raise ValueError(f"Unsupported partition method: {method}") - communities: Iterable = partitioner.partition( - g=self.kg_instance, **method_params - ) - - count = 0 - for community in communities: - count += 1 - batch = partitioner.community2batch(community, g=self.kg_instance) - batch = self._attach_additional_data_to_node(batch) - - yield pd.DataFrame( - { - "nodes": json.dumps(batch[0]), - "edges": json.dumps(batch[1]), - }, - index=[0], - ) - logger.info("Total communities partitioned: %d", count) - - def _attach_additional_data_to_node(self, batch: tuple) -> tuple: - """ - Attach additional data from chunk_storage to nodes in the batch. - :param batch: tuple of (nodes_data, edges_data) - :return: updated batch with additional data attached to nodes - """ - nodes_data, edges_data = batch + def process(self, batch: list) -> Tuple[Iterable[dict], dict]: + # this operator does not consume any batch data + # but for compatibility we keep the interface + self.kg_instance.reload() - for node_id, node_data in nodes_data: - entity_type = (node_data.get("entity_type") or "").lower() - if not entity_type: - continue + communities: Iterable = self.partitioner.partition( + g=self.kg_instance, **self.method_params + ) - source_ids = [ - sid.strip() - for sid in node_data.get("source_id", "").split("") - if sid.strip() - ] + def generator(): + count = 0 + for community in communities: + count += 1 + b = self.partitioner.community2batch(community, g=self.kg_instance) - # Handle images - if "image" in entity_type: - image_chunks = [ - data - for sid in source_ids - if "image" in sid.lower() - and (data := self.chunk_storage.get_by_id(sid)) - ] - if image_chunks: - # The generator expects a dictionary with an 'img_path' key, not a list of captions. - # We'll use the first image chunk found for this node. - node_data["image_data"] = json.loads(image_chunks[0]["content"]) - logger.debug("Attached image data to node %s", node_id) + result = { + "nodes": b[0], + "edges": b[1], + } + result["_trace_id"] = self.get_trace_id(result) + yield result + logger.info("Total communities partitioned: %d", count) - return nodes_data, edges_data + return generator(), {} diff --git a/graphgen/operators/quiz/quiz_service.py b/graphgen/operators/quiz/quiz_service.py index cabd4f77..7b19d1ae 100644 --- a/graphgen/operators/quiz/quiz_service.py +++ b/graphgen/operators/quiz/quiz_service.py @@ -1,9 +1,9 @@ -import pandas as pd +from typing import Tuple -from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper, BaseOperator +from graphgen.bases import BaseGraphStorage, BaseLLMWrapper, BaseOperator from graphgen.common import init_llm, init_storage from graphgen.models import QuizGenerator -from graphgen.utils import compute_dict_hash, logger, run_concurrent +from graphgen.utils import logger, run_concurrent class QuizService(BaseOperator): @@ -14,27 +14,18 @@ def __init__( kv_backend: str = "rocksdb", quiz_samples: int = 1, ): - super().__init__(working_dir=working_dir, op_name="quiz_service") + super().__init__(working_dir=working_dir, kv_backend=kv_backend, op_name="quiz") self.quiz_samples = quiz_samples self.llm_client: BaseLLMWrapper = init_llm("synthesizer") self.graph_storage: BaseGraphStorage = init_storage( backend=graph_backend, working_dir=working_dir, namespace="graph" ) - # { _quiz_id: { "description": str, "quizzes": List[Tuple[str, str]] } } - self.quiz_storage: BaseKVStorage = init_storage( - backend=kv_backend, working_dir=working_dir, namespace="quiz" - ) + # { _trace_id: { "description": str, "quizzes": List[Tuple[str, str]] } } self.generator = QuizGenerator(self.llm_client) - def process(self, batch: pd.DataFrame) -> pd.DataFrame: - data = batch.to_dict(orient="records") - self.graph_storage.reload() - return self.quiz(data) - async def _process_single_quiz(self, item: tuple) -> dict | None: # if quiz in quiz_storage exists already, directly get it - index, desc = item - _quiz_id = compute_dict_hash({"index": index, "description": desc}) + desc, index = item tasks = [] for i in range(self.quiz_samples): @@ -51,52 +42,49 @@ async def _process_single_quiz(self, item: tuple) -> dict | None: rephrased_text = self.generator.parse_rephrased_text(new_description) quizzes.append((rephrased_text, gt)) return { - "_quiz_id": _quiz_id, - "description": desc, "index": index, + "description": desc, "quizzes": quizzes, } except Exception as e: logger.error("Error when quizzing description %s: %s", item, e) return None - def quiz(self, batch) -> pd.DataFrame: + def process(self, batch: list) -> Tuple[list, dict]: """ Get all nodes and edges and quiz their descriptions using QuizGenerator. """ items = [] for item in batch: - node_data = item.get("node", []) - edge_data = item.get("edge", []) + input_id = item["_trace_id"] + node = item.get("node") + edge = item.get("edge") - if node_data: - node_id = node_data["entity_name"] - desc = node_data["description"] - items.append((node_id, desc)) - if edge_data: - edge_key = (edge_data["src_id"], edge_data["tgt_id"]) - desc = edge_data["description"] - items.append((edge_key, desc)) + if node and node.get("description"): + items.append((input_id, node["description"], node["entity_name"])) + elif edge and edge.get("description"): + edge_key = (edge["src_id"], edge["tgt_id"]) + items.append((input_id, edge["description"], edge_key)) + if not items: + return [], {} logger.info("Total descriptions to quiz: %d", len(items)) - results = run_concurrent( self._process_single_quiz, - items, + [(desc, orig_id) for (_, desc, orig_id) in items], desc=f"Quizzing batch of {len(items)} descriptions", unit="description", ) - valid_results = [res for res in results if res] - for res in valid_results: - self.quiz_storage.upsert( - { - res["_quiz_id"]: { - "description": res["description"], - "quizzes": res["quizzes"], - } - } - ) - self.quiz_storage.index_done_callback() - return pd.DataFrame(valid_results) + final_results = [] + meta_update = {} + + for (input_id, _, _), quiz_data in zip(items, results): + if quiz_data is None: + continue + quiz_data["_trace_id"] = self.get_trace_id(quiz_data) + final_results.append(quiz_data) + meta_update[input_id] = [quiz_data["_trace_id"]] + + return final_results, meta_update diff --git a/graphgen/operators/read/parallel_file_scanner.py b/graphgen/operators/read/parallel_file_scanner.py index 84219139..d8e9eb0a 100644 --- a/graphgen/operators/read/parallel_file_scanner.py +++ b/graphgen/operators/read/parallel_file_scanner.py @@ -2,17 +2,22 @@ import time from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path -from typing import Any, Dict, List, Set, Union +from typing import Any, Dict, List, Optional, Set, Union -from graphgen.models import RocksDBCache +from graphgen.bases import BaseKVStorage +from graphgen.utils import compute_content_hash, logger class ParallelFileScanner: def __init__( - self, cache_dir: str, allowed_suffix, rescan: bool = False, max_workers: int = 4 + self, + input_path_cache: BaseKVStorage, + allowed_suffix: Optional[List[str]] = None, + rescan: bool = False, + max_workers: int = 4, ): - self.cache = RocksDBCache(os.path.join(cache_dir, "input_paths.db")) - self.allowed_suffix = set(allowed_suffix) if allowed_suffix else None + self.cache = input_path_cache + self.allowed_suffix = set(allowed_suffix) if allowed_suffix else set() self.rescan = rescan self.max_workers = max_workers @@ -55,8 +60,10 @@ def _scan_files( return self._empty_result(path_str) # cache check - cache_key = f"scan::{path_str}::recursive::{recursive}" - cached = self.cache.get(cache_key) + cache_key = compute_content_hash( + f"scan::{path_str}::recursive::{recursive}", prefix="path-" + ) + cached = self.cache.get_by_id(cache_key) if cached and not self.rescan: return cached["data"] @@ -66,7 +73,9 @@ def _scan_files( try: path_stat = path.stat() if path.is_file(): - return self._scan_single_file(path, path_str, path_stat) + result = self._scan_single_file(path, path_str, path_stat) + self._cache_result(cache_key, result, path) + return result if path.is_dir(): with os.scandir(path_str) as entries: for entry in entries: @@ -113,6 +122,12 @@ def _scan_files( stats["file_count"] += sub_data["stats"].get("file_count", 0) result = {"path": path_str, "files": files, "dirs": dirs, "stats": stats} + logger.debug( + "Scanned %s: %d files, %d dirs", + path_str, + stats["file_count"], + stats["dir_count"], + ) self._cache_result(cache_key, result, path) return result @@ -174,31 +189,26 @@ def _scan_subdirs(self, dir_list: List[Dict], visited: Set[str]) -> Dict[str, An def _cache_result(self, key: str, result: Dict, path: Path): """Cache the scan result""" - self.cache.set( - key, + self.cache.upsert( { - "data": result, - "dir_mtime": path.stat().st_mtime, - "cached_at": time.time(), - }, + key: { + "data": result, + "dir_mtime": path.stat().st_mtime, + "cached_at": time.time(), + }, + } ) def _is_allowed_file(self, path: Path) -> bool: """Check if the file has an allowed suffix""" - if self.allowed_suffix is None: + if not self.allowed_suffix or len(self.allowed_suffix) == 0: return True suffix = path.suffix.lower().lstrip(".") return suffix in self.allowed_suffix - def invalidate(self, path: str): - """Invalidate cache for a specific path""" - path = Path(path).resolve() - keys = [k for k in self.cache if k.startswith(f"scan::{path}")] - for k in keys: - self.cache.delete(k) - def close(self): - self.cache.close() + self.cache.index_done_callback() + del self.cache def __enter__(self): return self diff --git a/graphgen/operators/read/read.py b/graphgen/operators/read/read.py index 3ff60c15..c13a605e 100644 --- a/graphgen/operators/read/read.py +++ b/graphgen/operators/read/read.py @@ -3,6 +3,7 @@ import ray +from graphgen.common import init_storage from graphgen.models import ( CSVReader, JSONReader, @@ -12,7 +13,7 @@ RDFReader, TXTReader, ) -from graphgen.utils import compute_mm_hash, logger +from graphgen.utils import compute_dict_hash, logger from .parallel_file_scanner import ParallelFileScanner @@ -51,6 +52,7 @@ def read( input_path: Union[str, List[str]], allowed_suffix: Optional[List[str]] = None, working_dir: Optional[str] = "cache", + kv_backend: str = "rocksdb", parallelism: int = 4, recursive: bool = True, read_nums: Optional[int] = None, @@ -62,71 +64,86 @@ def read( :param input_path: File or directory path(s) to read from :param allowed_suffix: List of allowed file suffixes (e.g., ['pdf', 'txt']) :param working_dir: Directory to cache intermediate files (PDF processing) + :param kv_backend: Backend for key-value storage :param parallelism: Number of parallel workers :param recursive: Whether to scan directories recursively :param read_nums: Limit the number of documents to read :param reader_kwargs: Additional kwargs passed to readers :return: Ray Dataset containing all documents """ + input_path_cache = init_storage( + backend=kv_backend, working_dir=working_dir, namespace="input_path" + ) + read_storage = init_storage( + backend=kv_backend, working_dir=working_dir, namespace="read" + ) try: # 1. Scan all paths to discover files logger.info("[READ] Scanning paths: %s", input_path) - scanner = ParallelFileScanner( - cache_dir=working_dir, + with ParallelFileScanner( + input_path_cache=input_path_cache, allowed_suffix=allowed_suffix, rescan=False, max_workers=parallelism if parallelism > 0 else 1, - ) - - all_files = [] - scan_results = scanner.scan(input_path, recursive=recursive) - - for result in scan_results.values(): - all_files.extend(result.get("files", [])) - - logger.info("[READ] Found %d files to process", len(all_files)) - - if not all_files: - raise ValueError("No files found to read.") - - # 2. Group files by suffix to use appropriate reader - files_by_suffix = {} - for file_info in all_files: - suffix = Path(file_info["path"]).suffix.lower().lstrip(".") - if allowed_suffix and suffix not in [ - s.lower().lstrip(".") for s in allowed_suffix - ]: - continue - files_by_suffix.setdefault(suffix, []).append(file_info["path"]) - - # 3. Create read tasks - read_tasks = [] - for suffix, file_paths in files_by_suffix.items(): - reader = _build_reader(suffix, working_dir, **reader_kwargs) - ds = reader.read(file_paths) - read_tasks.append(ds) - - # 4. Combine all datasets - if not read_tasks: - raise ValueError("No datasets created from the provided files.") - - if len(read_tasks) == 1: - combined_ds = read_tasks[0] - else: - combined_ds = read_tasks[0].union(*read_tasks[1:]) - - combined_ds = combined_ds.map( - lambda record: { - **record, - "_doc_id": compute_mm_hash(record, prefix="doc-"), - } - ) - - if read_nums is not None: - combined_ds = combined_ds.limit(read_nums) - - logger.info("[READ] Successfully read files from %s", input_path) - return combined_ds + ) as scanner: + all_files = [] + scan_results = scanner.scan(input_path, recursive=recursive) + + for result in scan_results.values(): + all_files.extend(result.get("files", [])) + + logger.info("[READ] Found %d files to process", len(all_files)) + + if not all_files: + raise ValueError("No files found to read.") + + # 2. Group files by suffix to use appropriate reader + files_by_suffix = {} + for file_info in all_files: + suffix = Path(file_info["path"]).suffix.lower().lstrip(".") + if allowed_suffix and suffix not in [ + s.lower().lstrip(".") for s in allowed_suffix + ]: + continue + files_by_suffix.setdefault(suffix, []).append(file_info["path"]) + + # 3. Create read tasks + read_tasks = [] + for suffix, file_paths in files_by_suffix.items(): + reader = _build_reader(suffix, working_dir, **reader_kwargs) + ds = reader.read(file_paths) + read_tasks.append(ds) + + # 4. Combine all datasets + if not read_tasks: + raise ValueError("No datasets created from the provided files.") + + if len(read_tasks) == 1: + combined_ds = read_tasks[0] + else: + combined_ds = read_tasks[0].union(*read_tasks[1:]) + + if read_nums is not None: + combined_ds = combined_ds.limit(read_nums) + + def add_trace_id(batch): + batch["_trace_id"] = batch.apply( + lambda row: compute_dict_hash(row, prefix="read-"), axis=1 + ) + records = batch.to_dict(orient="records") + data_to_upsert = {record["_trace_id"]: record for record in records} + read_storage.upsert(data_to_upsert) + read_storage.index_done_callback() + return batch + + combined_ds = combined_ds.map_batches(add_trace_id, batch_format="pandas") + + # sample record + for i, item in enumerate(combined_ds.take(1)): + logger.debug("[READ] Sample record %d: %s", i, item) + + logger.info("[READ] Successfully read files from %s", input_path) + return combined_ds except Exception as e: logger.error("[READ] Failed to read files from %s: %s", input_path, e) diff --git a/graphgen/storage/__init__.py b/graphgen/storage/__init__.py new file mode 100644 index 00000000..14250051 --- /dev/null +++ b/graphgen/storage/__init__.py @@ -0,0 +1,4 @@ +from .graph.kuzu_storage import KuzuStorage +from .graph.networkx_storage import NetworkXStorage +from .kv.json_storage import JsonKVStorage +from .kv.rocksdb_storage import RocksDBKVStorage diff --git a/graphgen/models/storage/graph/__init__.py b/graphgen/storage/graph/__init__.py similarity index 100% rename from graphgen/models/storage/graph/__init__.py rename to graphgen/storage/graph/__init__.py diff --git a/graphgen/models/storage/graph/kuzu_storage.py b/graphgen/storage/graph/kuzu_storage.py similarity index 100% rename from graphgen/models/storage/graph/kuzu_storage.py rename to graphgen/storage/graph/kuzu_storage.py diff --git a/graphgen/models/storage/graph/networkx_storage.py b/graphgen/storage/graph/networkx_storage.py similarity index 100% rename from graphgen/models/storage/graph/networkx_storage.py rename to graphgen/storage/graph/networkx_storage.py diff --git a/graphgen/models/storage/kv/__init__.py b/graphgen/storage/kv/__init__.py similarity index 100% rename from graphgen/models/storage/kv/__init__.py rename to graphgen/storage/kv/__init__.py diff --git a/graphgen/models/storage/kv/json_storage.py b/graphgen/storage/kv/json_storage.py similarity index 84% rename from graphgen/models/storage/kv/json_storage.py rename to graphgen/storage/kv/json_storage.py index aa7c6f42..cf7dbd7f 100644 --- a/graphgen/models/storage/kv/json_storage.py +++ b/graphgen/storage/kv/json_storage.py @@ -1,7 +1,7 @@ import os from dataclasses import dataclass -from graphgen.bases.base_storage import BaseKVStorage +from graphgen.bases.base_storage import BaseKVStorage, T from graphgen.utils import load_json, write_json @@ -51,6 +51,15 @@ def upsert(self, data: dict): self._data.update(left_data) return left_data + def update(self, data: dict[str, T]): + for k, v in data.items(): + self._data[k] = v + + def delete(self, ids: list[str]): + for _id in ids: + if _id in self._data: + del self._data[_id] + def drop(self): if self._data: self._data.clear() diff --git a/graphgen/models/storage/kv/rocksdb_storage.py b/graphgen/storage/kv/rocksdb_storage.py similarity index 90% rename from graphgen/models/storage/kv/rocksdb_storage.py rename to graphgen/storage/kv/rocksdb_storage.py index 45055b93..d1361169 100644 --- a/graphgen/models/storage/kv/rocksdb_storage.py +++ b/graphgen/storage/kv/rocksdb_storage.py @@ -68,6 +68,15 @@ def upsert(self, data: Dict[str, Any]): return left_data + def update(self, data: Dict[str, Any]): + for k, v in data.items(): + self._db[k] = v + + def delete(self, ids: List[str]): + for _id in ids: + if _id in self._db: + del self._db[_id] + def drop(self): self._db.close() Rdict.destroy(self._db_path) diff --git a/graphgen/templates/__init__.py b/graphgen/templates/__init__.py index 72ab9446..a83d81b8 100644 --- a/graphgen/templates/__init__.py +++ b/graphgen/templates/__init__.py @@ -1,6 +1,6 @@ from .coreference_resolution import COREFERENCE_RESOLUTION_PROMPT from .description_rephrasing import DESCRIPTION_REPHRASING_PROMPT -from .evaluation import ACCURACY_EVALUATION_PROMPT, CONSISTENCY_EVALUATION_PROMPT +from .evaluation import ACCURACY_EVALUATION_PROMPT from .extraction import SCHEMA_GUIDED_EXTRACTION_PROMPT from .generation import ( AGGREGATED_GENERATION_PROMPT, diff --git a/graphgen/templates/evaluation/__init__.py b/graphgen/templates/evaluation/__init__.py index 7c2676a5..93761e85 100644 --- a/graphgen/templates/evaluation/__init__.py +++ b/graphgen/templates/evaluation/__init__.py @@ -1 +1 @@ -from .kg import ACCURACY_EVALUATION_PROMPT, CONSISTENCY_EVALUATION_PROMPT +from .kg import ACCURACY_EVALUATION_PROMPT diff --git a/graphgen/templates/evaluation/kg/__init__.py b/graphgen/templates/evaluation/kg/__init__.py index db8edce6..9c500d1f 100644 --- a/graphgen/templates/evaluation/kg/__init__.py +++ b/graphgen/templates/evaluation/kg/__init__.py @@ -1,2 +1 @@ from .accuracy_evaluation import ACCURACY_EVALUATION_PROMPT -from .consistency_evaluation import CONSISTENCY_EVALUATION_PROMPT diff --git a/graphgen/templates/evaluation/kg/consistency_evaluation.py b/graphgen/templates/evaluation/kg/consistency_evaluation.py deleted file mode 100644 index 3c3dbe9e..00000000 --- a/graphgen/templates/evaluation/kg/consistency_evaluation.py +++ /dev/null @@ -1,228 +0,0 @@ -ENTITY_TYPE_CONFLICT_PROMPT_ZH = """你是一个知识图谱一致性评估专家。你的任务是判断同一个实体在不同文本块中被提取为不同的类型,是否存在语义冲突。 - -实体名称:{entity_name} - -在不同文本块中的类型提取结果: -{type_extractions} - -预设的实体类型列表(供参考): -concept, date, location, keyword, organization, person, event, work, nature, artificial, science, technology, mission, gene - -请判断这些类型是否存在语义冲突(即它们是否描述的是同一类事物,还是存在矛盾)。 -注意:如果类型只是同一概念的不同表述(如 concept 和 keyword),可能不算严重冲突。 - -请以 JSON 格式返回: -{{ - "has_conflict": , - "conflict_severity": <0-1之间的浮点数,0表示无冲突,1表示严重冲突>, - "conflict_reasoning": "<冲突判断的理由>", - "conflicting_types": ["<存在冲突的类型对>"], - "recommended_type": "<如果存在冲突,推荐的正确类型(必须是预设类型之一)>" -}} -""" - -ENTITY_TYPE_CONFLICT_PROMPT_EN = ( - """You are a Knowledge Graph Consistency Assessment Expert. """ - """Your task is to determine whether there are semantic conflicts """ - """when the same entity is extracted as different types in different text blocks. - -Entity Name: {entity_name} - -Type extraction results from different text blocks: -{type_extractions} - -Preset entity type list (for reference): -concept, date, location, keyword, organization, person, event, work, nature, """ - """artificial, science, technology, mission, gene - -Please determine whether these types have semantic conflicts """ - """(i.e., whether they describe the same category of things, """ - """or if there are contradictions). -Note: If types are just different expressions of the same concept """ - """(such as concept and keyword), it may not be considered a serious conflict. - -Please return in JSON format: -{{ - "has_conflict": , - "conflict_severity": , - "conflict_reasoning": "", - "conflicting_types": [""], - "recommended_type": "" -}} -""" -) - -ENTITY_DESCRIPTION_CONFLICT_PROMPT_ZH = """你是一个知识图谱一致性评估专家。你的任务是判断同一个实体在不同文本块中的描述是否存在语义冲突。 - -实体名称:{entity_name} - -在不同文本块中的描述: -{descriptions} - -请判断这些描述是否存在语义冲突(即它们是否描述的是同一个实体,还是存在矛盾的信息)。 - -请以 JSON 格式返回: -{{ - "has_conflict": , - "conflict_severity": <0-1之间的浮点数>, - "conflict_reasoning": "<冲突判断的理由>", - "conflicting_descriptions": ["<存在冲突的描述对>"], - "conflict_details": "<具体的冲突内容>" -}} -""" - -ENTITY_DESCRIPTION_CONFLICT_PROMPT_EN = ( - """You are a Knowledge Graph Consistency Assessment Expert. """ - """Your task is to determine whether there are semantic conflicts """ - """in the descriptions of the same entity across different text blocks. - -Entity Name: {entity_name} - -Descriptions from different text blocks: -{descriptions} - -Please determine whether these descriptions have semantic conflicts """ - """(i.e., whether they describe the same entity, """ - """or if there is contradictory information). - -Please return in JSON format: -{{ - "has_conflict": , - "conflict_severity": , - "conflict_reasoning": "", - "conflicting_descriptions": [""], - "conflict_details": "" -}} -""" -) - -RELATION_CONFLICT_PROMPT_ZH = """你是一个知识图谱一致性评估专家。你的任务是判断同一对实体在不同文本块中的关系描述是否存在语义冲突。 - -实体对:{source_entity} -> {target_entity} - -在不同文本块中的关系描述: -{relation_descriptions} - -请判断这些关系描述是否存在语义冲突。 - -请以 JSON 格式返回: -{{ - "has_conflict": , - "conflict_severity": <0-1之间的浮点数>, - "conflict_reasoning": "<冲突判断的理由>", - "conflicting_relations": ["<存在冲突的关系描述对>"] -}} -""" - -RELATION_CONFLICT_PROMPT_EN = ( - """You are a Knowledge Graph Consistency Assessment Expert. """ - """Your task is to determine whether there are semantic conflicts """ - """in the relation descriptions of the same entity pair across different text blocks. - -Entity Pair: {source_entity} -> {target_entity} - -Relation descriptions from different text blocks: -{relation_descriptions} - -Please determine whether these relation descriptions have semantic conflicts. - -Please return in JSON format: -{{ - "has_conflict": , - "conflict_severity": , - "conflict_reasoning": "", - "conflicting_relations": [""] -}} -""" -) - -ENTITY_EXTRACTION_PROMPT_ZH = """从以下文本块中提取指定实体的类型和描述。 - -**重要**:你只需要提取指定的实体,不要提取其他实体。 - -实体名称:{entity_name} - -文本块: -{chunk_content} - -请从文本块中找到并提取**仅此实体**(实体名称:{entity_name})的以下信息: - -1. entity_type: 实体类型,必须是以下预设类型之一(小写): - - concept: 概念 - - date: 日期 - - location: 地点 - - keyword: 关键词 - - organization: 组织 - - person: 人物 - - event: 事件 - - work: 作品/工作 - - nature: 自然 - - artificial: 人工 - - science: 科学 - - technology: 技术 - - mission: 任务 - - gene: 基因 - - 如果无法确定类型,请使用 "concept" 作为默认值。 - -2. description: 实体描述(简要描述该实体在文本中的作用和特征) - -请以 JSON 格式返回: -{{ - "entity_type": "<实体类型(必须是上述预设类型之一)>", - "description": "<实体描述>" -}} -""" - -ENTITY_EXTRACTION_PROMPT_EN = """Extract the type and description of the specified entity from the following text block. - -**Important**: You should only extract the specified entity, do not extract other entities. - -Entity Name: {entity_name} - -Text Block: -{chunk_content} - -Please find and extract the following information for **this entity only** (entity name: {entity_name}) from the text block: - -1. entity_type: Entity type, must be one of the following preset types (lowercase): - - concept: concept - - date: date - - location: location - - keyword: keyword - - organization: organization - - person: person - - event: event - - work: work - - nature: nature - - artificial: artificial - - science: science - - technology: technology - - mission: mission - - gene: gene - - If the type cannot be determined, please use "concept" as the default value. - -2. description: Entity description (briefly describe the role and characteristics of this entity in the text) - -Please return in JSON format: -{{ - "entity_type": "", - "description": "" -}} -""" - -CONSISTENCY_EVALUATION_PROMPT = { - "zh": { - "ENTITY_TYPE_CONFLICT": ENTITY_TYPE_CONFLICT_PROMPT_ZH, - "ENTITY_DESCRIPTION_CONFLICT": ENTITY_DESCRIPTION_CONFLICT_PROMPT_ZH, - "RELATION_CONFLICT": RELATION_CONFLICT_PROMPT_ZH, - "ENTITY_EXTRACTION": ENTITY_EXTRACTION_PROMPT_ZH, - }, - "en": { - "ENTITY_TYPE_CONFLICT": ENTITY_TYPE_CONFLICT_PROMPT_EN, - "ENTITY_DESCRIPTION_CONFLICT": ENTITY_DESCRIPTION_CONFLICT_PROMPT_EN, - "RELATION_CONFLICT": RELATION_CONFLICT_PROMPT_EN, - "ENTITY_EXTRACTION": ENTITY_EXTRACTION_PROMPT_EN, - }, -} diff --git a/graphgen/utils/__init__.py b/graphgen/utils/__init__.py index 840b2cec..48c7ceb5 100644 --- a/graphgen/utils/__init__.py +++ b/graphgen/utils/__init__.py @@ -9,12 +9,7 @@ split_string_by_multi_markers, write_json, ) -from .hash import ( - compute_args_hash, - compute_content_hash, - compute_dict_hash, - compute_mm_hash, -) +from .hash import compute_args_hash, compute_content_hash, compute_dict_hash from .help_nltk import NLTKHelper from .log import CURRENT_LOGGER_VAR, logger, set_logger from .loop import create_event_loop diff --git a/graphgen/utils/format.py b/graphgen/utils/format.py index 1f0675f1..9a687d90 100644 --- a/graphgen/utils/format.py +++ b/graphgen/utils/format.py @@ -30,7 +30,9 @@ def clean_str(input: Any) -> str: result = html.unescape(input.strip()) # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python - return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result) + result = re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result) + result = result.strip('"').strip("'") + return result async def handle_single_entity_extraction( diff --git a/graphgen/utils/hash.py b/graphgen/utils/hash.py index 04ba96e7..ce4b3d53 100644 --- a/graphgen/utils/hash.py +++ b/graphgen/utils/hash.py @@ -9,20 +9,6 @@ def compute_content_hash(content, prefix: str = ""): return prefix + md5(content.encode()).hexdigest() -def compute_mm_hash(item, prefix: str = ""): - if item.get("type") == "text" and item.get("text"): - content = item["text"].strip() - elif item.get("type") == "image" and item.get("img_path"): - content = f"image:{item['img_path']}" - elif item.get("type") == "table" and item.get("table_body"): - content = f"table:{item['table_body']}" - elif item.get("type") == "equation" and item.get("text"): - content = f"equation:{item['text']}" - else: - content = str(item) - return prefix + md5(content.encode()).hexdigest() - - def compute_dict_hash(d: dict, prefix: str = ""): items = tuple(sorted(d.items())) return prefix + md5(str(items).encode()).hexdigest() diff --git a/graphgen/utils/run_concurrent.py b/graphgen/utils/run_concurrent.py index d1a9b0e2..45a08e0f 100644 --- a/graphgen/utils/run_concurrent.py +++ b/graphgen/utils/run_concurrent.py @@ -19,23 +19,37 @@ def run_concurrent( unit: str = "item", ) -> List[R]: async def _run_all(): - tasks = [asyncio.create_task(coro_fn(item)) for item in items] + # Wrapper to return the index alongside the result + # This eliminates the need to map task IDs + async def _worker(index: int, item: T): + try: + res = await coro_fn(item) + return index, res, None + except Exception as e: + return index, None, e + + # Create tasks using the wrapper + tasks_list = [ + asyncio.create_task(_worker(i, item)) for i, item in enumerate(items) + ] - results = [] + results: List[Exception | R] = [None] * len(items) pbar = tqdm_async(total=len(items), desc=desc, unit=unit) - for future in asyncio.as_completed(tasks): - try: - result = await future - results.append(result) - except Exception as e: - logger.exception("Task failed: %s", e) - results.append(e) + # Iterate over completed tasks + for future in asyncio.as_completed(tasks_list): + # We await the wrapper, which guarantees we get the index back + idx, result, error = await future + + if error: + logger.exception(f"Task failed at index {idx}: {error}") + else: + results[idx] = result pbar.update(1) pbar.close() - return [res for res in results if not isinstance(res, Exception)] + return results loop = create_event_loop() try: diff --git a/tests/e2e_tests/evaluate/test_evaluate_triple.py b/tests/e2e_tests/evaluate/test_evaluate_triple.py new file mode 100644 index 00000000..5a52a47e --- /dev/null +++ b/tests/e2e_tests/evaluate/test_evaluate_triple.py @@ -0,0 +1,9 @@ +from pathlib import Path + +from tests.e2e_tests.conftest import run_generate_test + + +def test_evaluate_kg(tmp_path: Path): + run_generate_test( + tmp_path, "examples/evaluate/evaluate_triple/triple_evaluation_config.yaml" + )