2 changes: 2 additions & 0 deletions examples/filter/filter.sh
@@ -0,0 +1,2 @@
python3 -m graphgen.run \
--config_file examples/filter/filter_config.yaml
116 changes: 116 additions & 0 deletions examples/filter/filter_config.yaml
@@ -0,0 +1,116 @@
global_params:
working_dir: cache
  graph_backend: networkx # graph database backend; supported: kuzu, networkx
  kv_backend: json_kv # key-value store backend; supported: rocksdb, json_kv

nodes:
- id: read_files # the id must be unique within the pipeline and can be referenced by other steps
op_name: read
type: source
dependencies: []
params:
input_path:
- examples/input_examples/jsonl_demo.jsonl # input file path; supports json, jsonl, txt, and pdf. See examples/input_examples for samples

- id: chunk_documents
op_name: chunk
type: map_batch
dependencies:
- read_files
execution_params:
replicas: 4
params:
chunk_size: 1024 # chunk size for text splitting
chunk_overlap: 100 # chunk overlap for text splitting

- id: build_kg
op_name: build_kg
type: map_batch
dependencies:
- chunk_documents
execution_params:
replicas: 1
batch_size: 128

- id: quiz
op_name: quiz
type: aggregate
dependencies:
- build_kg
execution_params:
replicas: 1
batch_size: 128
params:
quiz_samples: 2 # number of quiz samples to generate
concurrency_limit: 200

- id: judge
op_name: judge
type: map_batch
dependencies:
- quiz
execution_params:
replicas: 1
batch_size: 128

- id: partition
op_name: partition
type: aggregate
dependencies:
- judge
params:
method: ece # ece is a custom partition method based on comprehension loss
method_params:
max_units_per_community: 20 # max nodes and edges per community
min_units_per_community: 5 # min nodes and edges per community
max_tokens_per_community: 10240 # max tokens per community
unit_sampling: max_loss # unit sampling strategy; supported: random, max_loss, min_loss

- id: generate
op_name: generate
type: map_batch
dependencies:
- partition
execution_params:
replicas: 1
batch_size: 128
save_output: true
params:
method: aggregated # atomic, aggregated, multi_hop, cot, vqa
data_format: ChatML # Alpaca, Sharegpt, ChatML

- id: evaluate
op_name: evaluate
type: map_batch
dependencies:
- generate
execution_params:
replicas: 1
batch_size: 128
save_output: true
params:
target: qa
metrics:
- length
- mtld
# - reward_score
# - uni_score
mtld_params:
threshold: 0.7

- id: filter
op_name: filter
type: filter
dependencies:
- evaluate
execution_params:
replicas: 1
batch_size: 128
save_output: true
params:
method: range
method_params:
metric: mtld
min_val: 300
max_val: 400
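
Editor's note: the `dependencies` lists above wire the nodes into a DAG that the engine resolves before execution. As a quick sanity check (not part of this PR), the following sketch loads the config and derives a valid execution order with Kahn's algorithm; it assumes only PyYAML and the node layout shown above.

from collections import deque

import yaml


def topo_order(config_path: str) -> list:
    # Parse the pipeline config and Kahn-sort node ids by their dependencies.
    with open(config_path, encoding="utf-8") as f:
        nodes = yaml.safe_load(f)["nodes"]
    indegree = {n["id"]: len(n.get("dependencies") or []) for n in nodes}
    children = {n["id"]: [] for n in nodes}
    for n in nodes:
        for dep in n.get("dependencies") or []:
            children[dep].append(n["id"])
    queue = deque(nid for nid, deg in indegree.items() if deg == 0)
    order = []
    while queue:
        nid = queue.popleft()
        order.append(nid)
        for child in children[nid]:
            indegree[child] -= 1
            if indegree[child] == 0:
                queue.append(child)
    return order


# Expected for the config above: ['read_files', 'chunk_documents', 'build_kg',
# 'quiz', 'judge', 'partition', 'generate', 'evaluate', 'filter']
print(topo_order("examples/filter/filter_config.yaml"))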

1 change: 1 addition & 0 deletions graphgen/bases/__init__.py
@@ -1,5 +1,6 @@
from .base_evaluator import BaseKGEvaluator, BaseQAEvaluator, BaseTripleEvaluator
from .base_extractor import BaseExtractor
from .base_filter import BaseValueFilter
from .base_generator import BaseGenerator
from .base_kg_builder import BaseKGBuilder
from .base_llm_wrapper import BaseLLMWrapper
30 changes: 30 additions & 0 deletions graphgen/bases/base_filter.py
@@ -0,0 +1,30 @@
from abc import ABC, abstractmethod
from typing import Any, Union

import numpy as np


class BaseFilter(ABC):
@abstractmethod
def filter(self, data: Any) -> bool:
"""
Filter the data and return True if it passes the filter, False otherwise.
"""
raise NotImplementedError


class BaseValueFilter(BaseFilter, ABC):
@abstractmethod
def filter(self, data: Union[int, float, np.number]) -> bool:
"""
Filter the numeric value and return True if it passes the filter, False otherwise.
"""
raise NotImplementedError

@property
@abstractmethod
def filter_type(self) -> str:
"""
Return the type of filter, e.g. "greater_than" or "less_than".
"""
raise NotImplementedError
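
Editor's note: for orientation, here is a hypothetical concrete subclass (not part of this PR) showing what the abstract interface requires; the name `ThresholdFilter` and its semantics are invented for illustration.

from typing import Union

import numpy as np

from graphgen.bases import BaseValueFilter


class ThresholdFilter(BaseValueFilter):
    """Hypothetical example subclass: keeps values at or above a threshold."""

    def __init__(self, threshold: float):
        self.threshold = threshold

    def filter(self, data: Union[int, float, np.number]) -> bool:
        # Passes iff the value meets the threshold.
        return float(data) >= self.threshold

    @property
    def filter_type(self) -> str:
        return "greater_than"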
96 changes: 39 additions & 57 deletions graphgen/engine.py
@@ -2,7 +2,6 @@
import logging
import os
from collections import defaultdict, deque
-from functools import wraps
from typing import Any, Callable, Dict, List, Set

import ray
@@ -103,7 +102,6 @@ def _scan_storage_requirements(self) -> tuple[set[str], set[str]]:
kv_namespaces = set()
graph_namespaces = set()

-        # TODO: Temporarily hard-coded; node storage will be centrally managed later.
for node in self.config.nodes:
op_name = node.op_name
if self._function_needs_param(op_name, "kv_backend"):
@@ -232,62 +230,38 @@ def _filter_kwargs(

input_ds = self._get_input_dataset(node, initial_ds)

-        if inspect.isclass(op_handler):
-            execution_params = node.execution_params or {}
-            replicas = execution_params.get("replicas", 1)
-            batch_size = (
-                int(execution_params.get("batch_size"))
-                if "batch_size" in execution_params
-                else "default"
-            )
-            compute_resources = execution_params.get("compute_resources", {})
-
-            if node.type == "aggregate":
-                self.datasets[node.id] = input_ds.repartition(1).map_batches(
-                    op_handler,
-                    compute=ray.data.ActorPoolStrategy(min_size=1, max_size=1),
-                    batch_size=None,  # aggregate processes the whole dataset at once
-                    num_gpus=compute_resources.get("num_gpus", 0)
-                    if compute_resources
-                    else 0,
-                    fn_constructor_kwargs=node_params,
-                    batch_format="pandas",
-                )
-            else:
-                # others like map, filter, flatmap, map_batch let actors process data inside batches
-                self.datasets[node.id] = input_ds.map_batches(
-                    op_handler,
-                    compute=ray.data.ActorPoolStrategy(min_size=1, max_size=replicas),
-                    batch_size=batch_size,
-                    num_gpus=compute_resources.get("num_gpus", 0)
-                    if compute_resources
-                    else 0,
-                    fn_constructor_kwargs=node_params,
-                    batch_format="pandas",
-                )
-
-        else:
-
-            @wraps(op_handler)
-            def func_wrapper(row_or_batch: Dict[str, Any]) -> Dict[str, Any]:
-                return op_handler(row_or_batch, **node_params)
-
-            if node.type == "map":
-                self.datasets[node.id] = input_ds.map(func_wrapper)
-            elif node.type == "filter":
-                self.datasets[node.id] = input_ds.filter(func_wrapper)
-            elif node.type == "flatmap":
-                self.datasets[node.id] = input_ds.flat_map(func_wrapper)
-            elif node.type == "aggregate":
-                self.datasets[node.id] = input_ds.repartition(1).map_batches(
-                    func_wrapper, batch_format="default"
-                )
-            elif node.type == "map_batch":
-                self.datasets[node.id] = input_ds.map_batches(func_wrapper)
-            else:
-                raise ValueError(
-                    f"Unsupported node type {node.type} for node {node.id}"
-                )
+        # if inspect.isclass(op_handler):
+        execution_params = node.execution_params or {}
+        replicas = execution_params.get("replicas", 1)
+        batch_size = (
+            int(execution_params.get("batch_size"))
+            if "batch_size" in execution_params
+            else "default"
+        )
+        compute_resources = execution_params.get("compute_resources", {})
+
+        if node.type == "aggregate":
+            self.datasets[node.id] = input_ds.repartition(1).map_batches(
+                op_handler,
+                compute=ray.data.ActorPoolStrategy(min_size=1, max_size=1),
+                batch_size=None,  # aggregate processes the whole dataset at once
+                num_gpus=compute_resources.get("num_gpus", 0)
+                if compute_resources
+                else 0,
+                fn_constructor_kwargs=node_params,
+                batch_format="pandas",
+            )
+        else:
+            # others like map, filter, flatmap, map_batch let actors process data inside batches
+            self.datasets[node.id] = input_ds.map_batches(
+                op_handler,
+                compute=ray.data.ActorPoolStrategy(min_size=1, max_size=replicas),
+                batch_size=batch_size,
+                num_gpus=compute_resources.get("num_gpus", 0)
+                if compute_resources
+                else 0,
+                fn_constructor_kwargs=node_params,
+                batch_format="pandas",
+            )

def execute(
self, initial_ds: ray.data.Dataset, output_dir: str
@@ -315,6 +289,14 @@ def execute(
logger.info("Node %s output saved to %s", node.id, node_output_path)

# ray will lazy read the dataset
-            self.datasets[node.id] = ray.data.read_json(node_output_path)
+            if os.path.exists(node_output_path) and os.listdir(node_output_path):
+                self.datasets[node.id] = ray.data.read_json(node_output_path)
+            else:
+                self.datasets[node.id] = ray.data.from_items([])
+                logger.warning(
+                    "Node %s output path %s is empty. Created an empty dataset.",
+                    node.id,
+                    node_output_path,
+                )

return self.datasets
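
Editor's note on the unified branch above: Ray Data's `map_batches` accepts a callable class, constructs it once per actor in the `ActorPoolStrategy` pool using `fn_constructor_kwargs`, and then invokes it once per batch. A minimal, self-contained sketch of that pattern (standard Ray Data API, not code from this PR; assumes a Ray version where `ray.data.range` emits an `id` column):

import pandas as pd
import ray


class AddOne:
    # Constructed once per actor; kwargs come from fn_constructor_kwargs.
    def __init__(self, column: str):
        self.column = column

    # Called once per batch; with batch_format="pandas" each batch is a DataFrame.
    def __call__(self, batch: pd.DataFrame) -> pd.DataFrame:
        batch[self.column] = batch[self.column] + 1
        return batch


ds = ray.data.range(8)  # a dataset with one column, "id": 0..7
out = ds.map_batches(
    AddOne,
    compute=ray.data.ActorPoolStrategy(min_size=1, max_size=2),
    batch_size=4,
    fn_constructor_kwargs={"column": "id"},
    batch_format="pandas",
)
print(out.take_all())

The `execute` change above guards `ray.data.read_json` for the same reason the filter node exists: when a filter drops every row, the node's output directory is empty and reading it would fail, so an empty dataset is substituted and a warning logged instead.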
1 change: 1 addition & 0 deletions graphgen/models/__init__.py
@@ -6,6 +6,7 @@
StructureEvaluator,
UniEvaluator,
)
from .filter import RangeFilter
from .generator import (
AggregatedGenerator,
AtomicGenerator,
1 change: 1 addition & 0 deletions graphgen/models/filter/__init__.py
@@ -0,0 +1 @@
from .range_filter import RangeFilter
40 changes: 40 additions & 0 deletions graphgen/models/filter/range_filter.py
@@ -0,0 +1,40 @@
from typing import Union

import numpy as np

from graphgen.bases import BaseValueFilter


class RangeFilter(BaseValueFilter):
"""
Keeps values within the range [min_val, max_val]; each bound can be inclusive or exclusive.
"""

def __init__(
self,
min_val: float,
max_val: float,
left_inclusive: bool = True,
right_inclusive: bool = True,
):
self.min_val = min_val
self.max_val = max_val
self.left_inclusive = left_inclusive
self.right_inclusive = right_inclusive

def filter(self, data: Union[int, float, np.number]) -> bool:
value = float(data)
if self.left_inclusive and self.right_inclusive:
return self.min_val <= value <= self.max_val
if self.left_inclusive and not self.right_inclusive:
return self.min_val <= value < self.max_val
if not self.left_inclusive and self.right_inclusive:
return self.min_val < value <= self.max_val
return self.min_val < value < self.max_val

@property
def filter_type(self) -> str:
return "range"

def __repr__(self) -> str:
return f"RangeFilter({self.min_val}, {self.max_val})"
3 changes: 2 additions & 1 deletion graphgen/operators/__init__.py
@@ -2,14 +2,14 @@
from .chunk import ChunkService
from .evaluate import EvaluateService
from .extract import ExtractService
from .filter import FilterService
from .generate import GenerateService
from .judge import JudgeService
from .partition import PartitionService
from .quiz import QuizService
from .read import read
from .search import SearchService


operators = {
"read": read,
"chunk": ChunkService,
@@ -21,4 +21,5 @@
"partition": PartitionService,
"generate": GenerateService,
"evaluate": EvaluateService,
"filter": FilterService,
}
1 change: 1 addition & 0 deletions graphgen/operators/filter/__init__.py
@@ -0,0 +1 @@
from .filter_service import FilterService
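
Editor's note: the body of `filter_service.py` is not included in this section. Purely as an illustration of how such a service could plug a `RangeFilter` into the pandas-batch pipeline, here is a hypothetical sketch; the class name, constructor signature, and row semantics are assumptions, not the PR's actual code:

from typing import Optional

import pandas as pd

from graphgen.models import RangeFilter


class FilterServiceSketch:
    """Hypothetical illustration only; the PR's real FilterService is not shown here."""

    def __init__(self, method: str = "range", method_params: Optional[dict] = None):
        params = dict(method_params or {})
        metric = params.pop("metric")  # e.g. "mtld", per the example config
        if method != "range":
            raise ValueError(f"Unsupported filter method: {method}")
        self.metric = metric
        self.value_filter = RangeFilter(**params)  # min_val / max_val from config

    def __call__(self, batch: pd.DataFrame) -> pd.DataFrame:
        # Keep only the rows whose metric value falls inside the range.
        mask = batch[self.metric].map(self.value_filter.filter)
        return batch[mask]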