diff --git a/examples/filter/filter.sh b/examples/filter/filter.sh
new file mode 100644
index 00000000..ca603d42
--- /dev/null
+++ b/examples/filter/filter.sh
@@ -0,0 +1,2 @@
+python3 -m graphgen.run \
+--config_file examples/filter/filter_config.yaml
diff --git a/examples/filter/filter_config.yaml b/examples/filter/filter_config.yaml
new file mode 100644
index 00000000..fbc17ece
--- /dev/null
+++ b/examples/filter/filter_config.yaml
@@ -0,0 +1,116 @@
+global_params:
+  working_dir: cache
+  graph_backend: networkx # graph database backend, support: kuzu, networkx
+  kv_backend: json_kv # key-value store backend, support: rocksdb, json_kv
+
+nodes:
+  - id: read_files # id is unique in the pipeline, and can be referenced by other steps
+    op_name: read
+    type: source
+    dependencies: []
+    params:
+      input_path:
+        - examples/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See examples/input_examples for examples
+
+  - id: chunk_documents
+    op_name: chunk
+    type: map_batch
+    dependencies:
+      - read_files
+    execution_params:
+      replicas: 4
+    params:
+      chunk_size: 1024 # chunk size for text splitting
+      chunk_overlap: 100 # chunk overlap for text splitting
+
+  - id: build_kg
+    op_name: build_kg
+    type: map_batch
+    dependencies:
+      - chunk_documents
+    execution_params:
+      replicas: 1
+      batch_size: 128
+
+  - id: quiz
+    op_name: quiz
+    type: aggregate
+    dependencies:
+      - build_kg
+    execution_params:
+      replicas: 1
+      batch_size: 128
+    params:
+      quiz_samples: 2 # number of quiz samples to generate
+      concurrency_limit: 200
+
+  - id: judge
+    op_name: judge
+    type: map_batch
+    dependencies:
+      - quiz
+    execution_params:
+      replicas: 1
+      batch_size: 128
+
+  - id: partition
+    op_name: partition
+    type: aggregate
+    dependencies:
+      - judge
+    params:
+      method: ece # ece is a custom partition method based on comprehension loss
+      method_params:
+        max_units_per_community: 20 # max nodes and edges per community
+        min_units_per_community: 5 # min nodes and edges per community
+        max_tokens_per_community: 10240 # max tokens per community
+        unit_sampling: max_loss # unit sampling strategy, support: random, max_loss, min_loss
+
+  - id: generate
+    op_name: generate
+    type: map_batch
+    dependencies:
+      - partition
+    execution_params:
+      replicas: 1
+      batch_size: 128
+      save_output: true
+    params:
+      method: aggregated # atomic, aggregated, multi_hop, cot, vqa
+      data_format: ChatML # Alpaca, Sharegpt, ChatML
+
+  - id: evaluate
+    op_name: evaluate
+    type: map_batch
+    dependencies:
+      - generate
+    execution_params:
+      replicas: 1
+      batch_size: 128
+      save_output: true
+    params:
+      target: qa
+      metrics:
+        - length
+        - mtld
+        # - reward_score
+        # - uni_score
+      mtld_params:
+        threshold: 0.7
+
+  - id: filter
+    op_name: filter
+    type: filter
+    dependencies:
+      - evaluate
+    execution_params:
+      replicas: 1
+      batch_size: 128
+      save_output: true
+    params:
+      method: range
+      method_params:
+        metric: mtld
+        min_val: 300
+        max_val: 400
+
diff --git a/graphgen/bases/__init__.py b/graphgen/bases/__init__.py
index ab143b44..65499cd4 100644
--- a/graphgen/bases/__init__.py
+++ b/graphgen/bases/__init__.py
@@ -1,5 +1,6 @@
 from .base_evaluator import BaseKGEvaluator, BaseQAEvaluator, BaseTripleEvaluator
 from .base_extractor import BaseExtractor
+from .base_filter import BaseValueFilter
 from .base_generator import BaseGenerator
 from .base_kg_builder import BaseKGBuilder
 from .base_llm_wrapper import BaseLLMWrapper
diff --git a/graphgen/bases/base_filter.py b/graphgen/bases/base_filter.py
new file mode 100644
index 00000000..e46983e9
--- /dev/null
+++ b/graphgen/bases/base_filter.py
@@ -0,0 +1,30 @@
+from abc import ABC, abstractmethod
+from typing import Any, Union
+
+import numpy as np
+
+
+class BaseFilter(ABC):
+    @abstractmethod
+    def filter(self, data: Any) -> bool:
+        """
+        Filter the data and return True if it passes the filter, False otherwise.
+        """
+        raise NotImplementedError
+
+
+class BaseValueFilter(BaseFilter, ABC):
+    @abstractmethod
+    def filter(self, data: Union[int, float, np.number]) -> bool:
+        """
+        Filter the numeric value and return True if it passes the filter, False otherwise.
+        """
+        raise NotImplementedError
+
+    @property
+    @abstractmethod
+    def filter_type(self) -> str:
+        """
+        Return the type of filter (e.g., "greater_than", "less_than", etc.)
+        """
+        raise NotImplementedError
diff --git a/graphgen/engine.py b/graphgen/engine.py
index d09eb106..54a18065 100644
--- a/graphgen/engine.py
+++ b/graphgen/engine.py
@@ -2,7 +2,6 @@
 import logging
 import os
 from collections import defaultdict, deque
-from functools import wraps
 from typing import Any, Callable, Dict, List, Set
 
 import ray
@@ -103,7 +102,6 @@ def _scan_storage_requirements(self) -> tuple[set[str], set[str]]:
         kv_namespaces = set()
         graph_namespaces = set()
 
-        # TODO: Temporarily hard-coded; node storage will be centrally managed later.
         for node in self.config.nodes:
             op_name = node.op_name
             if self._function_needs_param(op_name, "kv_backend"):
@@ -232,62 +230,38 @@ def _filter_kwargs(
 
         input_ds = self._get_input_dataset(node, initial_ds)
 
-        if inspect.isclass(op_handler):
-            execution_params = node.execution_params or {}
-            replicas = execution_params.get("replicas", 1)
-            batch_size = (
-                int(execution_params.get("batch_size"))
-                if "batch_size" in execution_params
-                else "default"
+        # if inspect.isclass(op_handler):
+        execution_params = node.execution_params or {}
+        replicas = execution_params.get("replicas", 1)
+        batch_size = (
+            int(execution_params.get("batch_size"))
+            if "batch_size" in execution_params
+            else "default"
+        )
+        compute_resources = execution_params.get("compute_resources", {})
+
+        if node.type == "aggregate":
+            self.datasets[node.id] = input_ds.repartition(1).map_batches(
+                op_handler,
+                compute=ray.data.ActorPoolStrategy(min_size=1, max_size=1),
+                batch_size=None,  # aggregate processes the whole dataset at once
+                num_gpus=compute_resources.get("num_gpus", 0)
+                if compute_resources
+                else 0,
+                fn_constructor_kwargs=node_params,
+                batch_format="pandas",
             )
-            compute_resources = execution_params.get("compute_resources", {})
-
-            if node.type == "aggregate":
-                self.datasets[node.id] = input_ds.repartition(1).map_batches(
-                    op_handler,
-                    compute=ray.data.ActorPoolStrategy(min_size=1, max_size=1),
-                    batch_size=None,  # aggregate processes the whole dataset at once
-                    num_gpus=compute_resources.get("num_gpus", 0)
-                    if compute_resources
-                    else 0,
-                    fn_constructor_kwargs=node_params,
-                    batch_format="pandas",
-                )
-            else:
-                # others like map, filter, flatmap, map_batch let actors process data inside batches
-                self.datasets[node.id] = input_ds.map_batches(
-                    op_handler,
-                    compute=ray.data.ActorPoolStrategy(min_size=1, max_size=replicas),
-                    batch_size=batch_size,
-                    num_gpus=compute_resources.get("num_gpus", 0)
-                    if compute_resources
-                    else 0,
-                    fn_constructor_kwargs=node_params,
-                    batch_format="pandas",
-                )
-        else:
-
-            @wraps(op_handler)
-            def func_wrapper(row_or_batch: Dict[str, Any]) -> Dict[str, Any]:
-                return op_handler(row_or_batch, **node_params)
-
-            if node.type == "map":
-                self.datasets[node.id] = input_ds.map(func_wrapper)
-            elif node.type == "filter":
-                self.datasets[node.id] = input_ds.filter(func_wrapper)
-            elif node.type == "flatmap":
-                self.datasets[node.id] = input_ds.flat_map(func_wrapper)
-            elif node.type == "aggregate":
-                self.datasets[node.id] = input_ds.repartition(1).map_batches(
-                    func_wrapper, batch_format="default"
-                )
-            elif node.type == "map_batch":
-                self.datasets[node.id] = input_ds.map_batches(func_wrapper)
-            else:
-                raise ValueError(
-                    f"Unsupported node type {node.type} for node {node.id}"
-                )
+        self.datasets[node.id] = input_ds.map_batches(
+            op_handler,
+            compute=ray.data.ActorPoolStrategy(min_size=1, max_size=replicas),
+            batch_size=batch_size,
+            num_gpus=compute_resources.get("num_gpus", 0)
+            if compute_resources
+            else 0,
+            fn_constructor_kwargs=node_params,
+            batch_format="pandas",
+        )
 
     def execute(
         self, initial_ds: ray.data.Dataset, output_dir: str
@@ -315,6 +289,14 @@ def execute(
                 logger.info("Node %s output saved to %s", node.id, node_output_path)
 
             # ray will lazy read the dataset
-            self.datasets[node.id] = ray.data.read_json(node_output_path)
+            if os.path.exists(node_output_path) and os.listdir(node_output_path):
+                self.datasets[node.id] = ray.data.read_json(node_output_path)
+            else:
+                self.datasets[node.id] = ray.data.from_items([])
+                logger.warning(
+                    "Node %s output path %s is empty. Created an empty dataset.",
+                    node.id,
+                    node_output_path,
+                )
 
         return self.datasets
diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py
index bb708c15..b75c757c 100644
--- a/graphgen/models/__init__.py
+++ b/graphgen/models/__init__.py
@@ -6,6 +6,7 @@
     StructureEvaluator,
     UniEvaluator,
 )
+from .filter import RangeFilter
 from .generator import (
     AggregatedGenerator,
     AtomicGenerator,
diff --git a/graphgen/models/filter/__init__.py b/graphgen/models/filter/__init__.py
new file mode 100644
index 00000000..7addf4b5
--- /dev/null
+++ b/graphgen/models/filter/__init__.py
@@ -0,0 +1 @@
+from .range_filter import RangeFilter
diff --git a/graphgen/models/filter/range_filter.py b/graphgen/models/filter/range_filter.py
new file mode 100644
index 00000000..185c19cf
--- /dev/null
+++ b/graphgen/models/filter/range_filter.py
@@ -0,0 +1,40 @@
+from typing import Union
+
+import numpy as np
+
+from graphgen.bases import BaseValueFilter
+
+
+class RangeFilter(BaseValueFilter):
+    """
+    keeps values within a specified range [min_val, max_val] (inclusive or exclusive)
+    """
+
+    def __init__(
+        self,
+        min_val: float,
+        max_val: float,
+        left_inclusive: bool = True,
+        right_inclusive: bool = True,
+    ):
+        self.min_val = min_val
+        self.max_val = max_val
+        self.left_inclusive = left_inclusive
+        self.right_inclusive = right_inclusive
+
+    def filter(self, data: Union[int, float, np.number]) -> bool:
+        value = float(data)
+        if self.left_inclusive and self.right_inclusive:
+            return self.min_val <= value <= self.max_val
+        if self.left_inclusive and not self.right_inclusive:
+            return self.min_val <= value < self.max_val
+        if not self.left_inclusive and self.right_inclusive:
+            return self.min_val < value <= self.max_val
+        return self.min_val < value < self.max_val
+
+    @property
+    def filter_type(self) -> str:
+        return "range"
+
+    def __repr__(self) -> str:
+        return f"RangeFilter({self.min_val}, {self.max_val})"
diff --git a/graphgen/operators/__init__.py b/graphgen/operators/__init__.py
index ab840cc5..54ebc42a 100644
--- a/graphgen/operators/__init__.py
+++ b/graphgen/operators/__init__.py
@@ -2,6 +2,7 @@
 from .chunk import ChunkService
 from .evaluate import EvaluateService
 from .extract import ExtractService
+from .filter import FilterService
 from .generate import GenerateService
 from .judge import JudgeService
 from .partition import PartitionService
@@ -9,7 +10,6 @@
 from .read import read
 from .search import SearchService
 
-
 operators = {
     "read": read,
     "chunk": ChunkService,
@@ -21,4 +21,5 @@
     "partition": PartitionService,
     "generate": GenerateService,
     "evaluate": EvaluateService,
+    "filter": FilterService,
 }
diff --git a/graphgen/operators/filter/__init__.py b/graphgen/operators/filter/__init__.py
new file mode 100644
index 00000000..67d3ce8a
--- /dev/null
+++ b/graphgen/operators/filter/__init__.py
@@ -0,0 +1 @@
+from .filter_service import FilterService
diff --git a/graphgen/operators/filter/filter_service.py b/graphgen/operators/filter/filter_service.py
new file mode 100644
index 00000000..8e277103
--- /dev/null
+++ b/graphgen/operators/filter/filter_service.py
@@ -0,0 +1,49 @@
+from typing import Tuple
+
+from graphgen.bases import BaseOperator
+from graphgen.utils import logger
+
+
+class FilterService(BaseOperator):
+    def __init__(
+        self, working_dir: str = "cache", kv_backend: str = "rocksdb", **filter_kwargs
+    ):
+        super().__init__(
+            working_dir=working_dir, kv_backend=kv_backend, op_name="filter"
+        )
+        method = filter_kwargs["method"]
+        method_params = filter_kwargs["method_params"]
+        self.metric = method_params["metric"]
+        if method == "range":
+            from graphgen.models import RangeFilter
+
+            self.filter_instance = RangeFilter(
+                min_val=method_params["min_val"],
+                max_val=method_params["max_val"],
+                left_inclusive=method_params.get("left_inclusive", True),
+                right_inclusive=method_params.get("right_inclusive", True),
+            )
+        else:
+            raise ValueError(f"Unsupported filter method: {method}")
+
+    def process(self, batch: list) -> Tuple[list, dict]:
+        """
+        Filter the items in the batch.
+        :return: A tuple of (results, meta_updates)
+            results: A list of filtered items.
+            meta_updates: empty as filtering does not create new items.
+        """
+        results = []
+        meta_updates = {}
+
+        for item in batch:
+            value = item["metrics"].get(self.metric)
+            if value is None:
+                logger.warning(
+                    f"Item {item} does not have metric {self.metric}. Skipping."
+                )
+                continue
+            if self.filter_instance.filter(value):
+                results.append(item)
+
+        return results, meta_updates
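
For reference, the snippet below is a small usage sketch, not part of the patch, showing how the new RangeFilter behaves with the same bounds as the filter node in examples/filter/filter_config.yaml. The sample items list is invented for illustration; it only mirrors the row shape that FilterService.process reads (a "metrics" dict keyed by metric name).

# Usage sketch (illustration only, not part of the diff above).
from graphgen.models import RangeFilter

# Same bounds as the `filter` node in examples/filter/filter_config.yaml.
mtld_filter = RangeFilter(min_val=300, max_val=400)

# Hypothetical evaluated items, shaped like the rows FilterService.process expects.
items = [
    {"metrics": {"mtld": 250.0}},  # dropped: below min_val
    {"metrics": {"mtld": 350.0}},  # kept: inside the range
    {"metrics": {"mtld": 400.0}},  # kept: the right bound is inclusive by default
]

kept = [item for item in items if mtld_filter.filter(item["metrics"]["mtld"])]
print(len(kept))  # 2

Passing left_inclusive=False or right_inclusive=False to RangeFilter switches the corresponding bound to a strict comparison.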
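
Because BaseValueFilter only requires filter() and the filter_type property, adding another strategy is mostly a matter of writing a new subclass. The ThresholdFilter below is a hypothetical sketch, not part of this patch, and FilterService.__init__ would still need a corresponding method branch before a config could select it.

# Hypothetical extension sketch; not part of this patch.
from typing import Union

import numpy as np

from graphgen.bases import BaseValueFilter


class ThresholdFilter(BaseValueFilter):
    """Keep values greater than or equal to a threshold."""

    def __init__(self, threshold: float):
        self.threshold = threshold

    def filter(self, data: Union[int, float, np.number]) -> bool:
        # Same contract as RangeFilter.filter: True means "keep the item".
        return float(data) >= self.threshold

    @property
    def filter_type(self) -> str:
        return "threshold"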