2 changes: 1 addition & 1 deletion baselines/BDS/bds.py
@@ -9,7 +9,7 @@

 from graphgen.bases import BaseLLMWrapper
 from graphgen.common import init_llm
-from graphgen.models import NetworkXStorage
+from graphgen.storage import NetworkXStorage
 from graphgen.utils import create_event_loop

 QA_GENERATION_PROMPT = """
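The only change in this file is the import path: NetworkXStorage moved from graphgen.models to graphgen.storage. A minimal sketch of the updated import, assuming graphgen is installed from this branch:

# NetworkXStorage now lives in graphgen.storage, not graphgen.models
from graphgen.storage import NetworkXStorage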
7 changes: 3 additions & 4 deletions examples/evaluate/evaluate_kg/kg_evaluation_config.yaml
@@ -10,7 +10,7 @@ nodes:
     dependencies: []
     params:
       input_path:
-        - examples/input_examples/extract_demo.txt
+        - examples/input_examples/jsonl_demo.jsonl

   - id: chunk
     op_name: chunk
@@ -39,7 +39,6 @@ nodes:
     dependencies:
      - build_kg
     params:
+      target: kg
       metrics:
-        - kg_structure
-        - kg_accuracy
-        - kg_consistency
+        - structure
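The evaluate node now declares an explicit target: kg and drops the kg_ prefix from metric names (kg_structure, kg_accuracy, kg_consistency become structure, etc.). As a rough illustration of what a structure-style metric can report over a NetworkX-backed graph, the sketch below computes basic graph statistics; the metric names and exact computations are assumptions, not graphgen's implementation:

import networkx as nx

def structure_metrics(g: nx.Graph) -> dict:
    """Basic structural statistics for a knowledge graph (illustrative only)."""
    n = g.number_of_nodes()
    return {
        "nodes": float(n),
        "edges": float(g.number_of_edges()),
        "density": nx.density(g) if n > 1 else 0.0,  # guard: density divides by n*(n-1)
        "connected_components": float(nx.number_connected_components(g)) if n else 0.0,
    }

g = nx.Graph()
g.add_edge("aspirin", "acetylsalicylic acid", relation="synonym_of")
print(structure_metrics(g))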
13 changes: 7 additions & 6 deletions examples/evaluate/evaluate_qa/qa_evaluation_config.yaml
@@ -1,7 +1,7 @@
 global_params:
   working_dir: cache
-  graph_backend: kuzu # graph database backend, support: kuzu, networkx
-  kv_backend: rocksdb # key-value store backend, support: rocksdb, json_kv
+  graph_backend: networkx # graph database backend, support: kuzu, networkx
+  kv_backend: json_kv # key-value store backend, support: rocksdb, json_kv

 nodes:
   - id: read_files # id is unique in the pipeline, and can be referenced by other steps
@@ -89,10 +89,11 @@ nodes:
       batch_size: 128
     save_output: true
     params:
+      target: qa
       metrics:
-        - qa_length
-        - qa_mtld
-        # - qa_reward_score
-        # - qa_uni_score
+        - length
+        - mtld
+        # - reward_score
+        # - uni_score
       mtld_params:
         threshold: 0.7
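Besides switching the default backends to networkx/json_kv, this config adds target: qa and renames the metrics (qa_length to length, qa_mtld to mtld, and so on). The mtld metric measures lexical diversity, and mtld_params.threshold is the type-token-ratio cutoff at which one diversity "factor" completes. A simplified one-directional MTLD sketch, under the assumption that graphgen follows the standard algorithm:

def mtld_one_direction(tokens: list[str], threshold: float = 0.7) -> float:
    """Measure of Textual Lexical Diversity, forward pass only (illustrative)."""
    factors = 0.0
    types: set[str] = set()
    count = 0
    for tok in tokens:
        count += 1
        types.add(tok.lower())
        if len(types) / count <= threshold:  # TTR fell to the cutoff: close a factor
            factors += 1.0
            types, count = set(), 0
    if count:  # credit the unfinished trailing factor proportionally
        ttr = len(types) / count
        factors += (1.0 - ttr) / (1.0 - threshold)
    return len(tokens) / factors if factors else float(len(tokens))

print(mtld_one_direction("the cat sat on the mat while the dog sat too".split()))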
2 changes: 2 additions & 0 deletions examples/evaluate/evaluate_triple/evaluate_triple.sh
@@ -0,0 +1,2 @@
+python3 -m graphgen.run \
+    --config_file examples/evaluate/evaluate_triple/triple_evaluation_config.yaml
46 changes: 46 additions & 0 deletions examples/evaluate/evaluate_triple/triple_evaluation_config.yaml
@@ -0,0 +1,46 @@
+global_params:
+  working_dir: cache
+  graph_backend: networkx # graph database backend, support: kuzu, networkx
+  kv_backend: json_kv # key-value store backend, support: rocksdb, json_kv
+
+nodes:
+  - id: read
+    op_name: read
+    type: source
+    dependencies: []
+    params:
+      input_path:
+        - examples/input_examples/jsonl_demo.jsonl
+
+  - id: chunk
+    op_name: chunk
+    type: map_batch
+    dependencies:
+      - read
+    execution_params:
+      replicas: 4
+    params:
+      chunk_size: 20480 # larger chunk size for better context
+      chunk_overlap: 2000
+
+  - id: build_kg
+    op_name: build_kg
+    type: map_batch
+    dependencies:
+      - chunk
+    execution_params:
+      replicas: 1
+      batch_size: 128
+
+  - id: evaluate
+    op_name: evaluate
+    type: aggregate
+    save_output: true
+    dependencies:
+      - build_kg
+    params:
+      target: triple
+      src_namespace: chunk
+      tgt_namespace: build_kg
+      metrics:
+        - accuracy
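This new config evaluates extracted triples rather than QA pairs: src_namespace/tgt_namespace presumably tell the evaluate op to pair each build_kg output with the chunk it was extracted from, and accuracy is the only metric. A hedged sketch of the idea behind such a metric, using an LLM as judge; the prompt, pairing, and helper name are illustrative assumptions (generate_answer mirrors the client call visible in base_generator.py below):

async def judge_triple(
    llm, chunk_text: str, head: str, relation: str, tail: str
) -> dict[str, float]:
    """Ask an LLM whether an extracted triple is supported by its source chunk."""
    prompt = (
        f"Source text:\n{chunk_text}\n\n"
        f"Does the text support the triple ({head}, {relation}, {tail})? "
        "Answer YES or NO."
    )
    answer = await llm.generate_answer(prompt)  # assumed client, as in BaseGenerator
    return {"accuracy": 1.0 if "YES" in answer.upper() else 0.0}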
2 changes: 1 addition & 1 deletion graphgen/bases/__init__.py
@@ -1,3 +1,4 @@
+from .base_evaluator import BaseKGEvaluator, BaseQAEvaluator, BaseTripleEvaluator
 from .base_extractor import BaseExtractor
 from .base_generator import BaseGenerator
 from .base_kg_builder import BaseKGBuilder
@@ -9,5 +10,4 @@
 from .base_splitter import BaseSplitter
 from .base_storage import BaseGraphStorage, BaseKVStorage, StorageNameSpace
 from .base_tokenizer import BaseTokenizer
-from .base_evaluator import BaseEvaluator
 from .datatypes import Chunk, Config, Node, QAPair, Token
23 changes: 21 additions & 2 deletions graphgen/bases/base_evaluator.py
@@ -1,10 +1,29 @@
 from abc import ABC, abstractmethod
+from typing import Any

+from .base_storage import BaseGraphStorage
 from .datatypes import QAPair


-class BaseEvaluator(ABC):
+class BaseQAEvaluator(ABC):
     @abstractmethod
-    def evaluate(self, pair: QAPair) -> float:
+    async def evaluate(self, pair: QAPair) -> dict[str, float]:
         """
         Evaluate the text and return a score.
         """
+
+
+class BaseKGEvaluator(ABC):
+    @abstractmethod
+    def evaluate(self, kg: BaseGraphStorage) -> dict[str, Any]:
+        """
+        Evaluate the whole graph and return a dict of scores.
+        """
+
+
+class BaseTripleEvaluator(ABC):
+    @abstractmethod
+    async def evaluate(self, unit: dict) -> dict[str, float]:
+        """
+        Evaluate a node/edge and return a score.
+        """
83 changes: 34 additions & 49 deletions graphgen/bases/base_generator.py
@@ -21,72 +21,57 @@ def build_prompt(

     @staticmethod
     @abstractmethod
-    def parse_response(response: str) -> Any:
+    def parse_response(response: str) -> list[dict]:
         """Parse the LLM response and return the generated QAs"""

     async def generate(
         self,
         batch: tuple[
             list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
         ],
-    ) -> dict[str, Any]:
+    ) -> list[dict]:
         """
         Generate QAs based on a given batch.
         :param batch
         :return: QA pairs
         """
-        result = {}
         prompt = self.build_prompt(batch)
         response = await self.llm_client.generate_answer(prompt)
         qa_pairs = self.parse_response(response)  # generate one or more QA pairs
-        result.update(qa_pairs)
-        return result
+        return qa_pairs

     @staticmethod
     def format_generation_results(
-        results: list[dict], output_data_format: str
-    ) -> list[dict[str, Any]]:
+        result: dict, output_data_format: str
+    ) -> dict[str, Any]:
+        question = result.get("question", "")
+        answer = result.get("answer", "")
+        if "options" in result and result["options"]:
+            options = result["options"]
+            options_str = "\n".join(
+                [f"{key}. {options[key]}" for key in sorted(options.keys())]
+            )
+            question += f"\nOptions:\n{options_str}"

-        flat_results = []
-        for item in results:
-            for _, qa_data in item.items():
-                question = qa_data.get("question", "")
-                answer = qa_data.get("answer", "")
-                if "options" in qa_data and qa_data["options"]:
-                    options = qa_data["options"]
-                    options_str = "\n".join(
-                        [f"{key}. {options[key]}" for key in sorted(options.keys())]
-                    )
-                    question += f"\nOptions:\n{options_str}"
+        if output_data_format == "Alpaca":
+            return {
+                "instruction": question,
+                "input": "",
+                "output": answer,
+            }

-                if output_data_format == "Alpaca":
-                    flat_results.append(
-                        {
-                            "instruction": question,
-                            "input": "",
-                            "output": answer,
-                        }
-                    )
-                elif output_data_format == "Sharegpt":
-                    flat_results.append(
-                        {
-                            "conversations": [
-                                {"from": "human", "value": question},
-                                {"from": "gpt", "value": answer},
-                            ]
-                        }
-                    )
-                elif output_data_format == "ChatML":
-                    flat_results.append(
-                        {
-                            "messages": [
-                                {"role": "user", "content": question},
-                                {"role": "assistant", "content": answer},
-                            ]
-                        }
-                    )
-                else:
-                    raise ValueError(
-                        f"Unknown output data format: {output_data_format}"
-                    )
-        return flat_results
+        if output_data_format == "Sharegpt":
+            return {
+                "conversations": [
+                    {"from": "human", "value": question},
+                    {"from": "gpt", "value": answer},
+                ]
+            }
+        if output_data_format == "ChatML":
+            return {
+                "messages": [
+                    {"role": "user", "content": question},
+                    {"role": "assistant", "content": answer},
+                ]
+            }
+        raise ValueError(f"Unknown output data format: {output_data_format}")
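Net effect of this refactor: generate() returns a flat list of QA dicts instead of a merged dict, and format_generation_results() formats one QA dict at a time rather than flattening a nested list itself. Callers therefore own the loop, roughly as in this sketch (the generator instance and batch are assumed to come from the surrounding pipeline):

async def collect_formatted(generator, batch, output_data_format: str = "ChatML") -> list[dict]:
    qa_pairs = await generator.generate(batch)  # list[dict] under the new contract
    # format_generation_results is a @staticmethod, so calling it via the instance works
    return [generator.format_generation_results(qa, output_data_format) for qa in qa_pairs]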