feat: support synthesizing masked fill_in_blank QA pairs #173
superfarther wants to merge 2 commits into InternScience:main
Summary of Changes
Hello @superfarther, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request expands the system's question generation capabilities by adding a method for creating masked fill-in-the-blank QA pairs. It adds a new generator that rephrases graph data and masks key entities, alongside a dedicated partitioner for extracting graph triples. An example pipeline is included to demonstrate the new feature.
Code Review
This pull request introduces support for synthesizing masked fill-in-the-blank QA pairs. It adds a new TriplePartitioner to break down the graph into triples and a MaskedFillInBlankGenerator for generating the QA pairs. The overall approach is sound and the changes are well-structured. I've identified one critical issue that could lead to a runtime error, along with several high-severity issues concerning violations of base class contracts and the use of global state, which should be addressed. I've also included some medium-severity suggestions to enhance code quality and maintainability.
    mask_pattern = re.compile(re.escape(mask_node_name), re.IGNORECASE)
    masked_context = mask_pattern.sub("___", context)
    # For accuracy, extract the actual replaced text from the context as the ground truth (keeping the original case)
    gth = re.search(mask_pattern, context).group(0)
The call to re.search(mask_pattern, context) can return None if the mask_node_name is not found in the context. This would cause a crash with an AttributeError when .group(0) is called. You should add a check to handle this case gracefully, for example by logging a warning and returning an empty list.
Suggested change
-    gth = re.search(mask_pattern, context).group(0)
+    match = re.search(mask_pattern, context)
+    if not match:
+        logger.warning(
+            "Could not find mask_node_name '%s' in the rephrased context. Context: %s",
+            mask_node_name,
+            context,
+        )
+        return []
+    gth = match.group(0)
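The guarded lookup can be tried in isolation. The following is a minimal sketch, not the PR's code: the function name `mask_entity` and the `question`/`answer` keys are hypothetical, chosen only to make the example self-contained.

```python
import logging
import re

logger = logging.getLogger(__name__)


def mask_entity(context: str, mask_node_name: str) -> list[dict]:
    """Mask every occurrence of mask_node_name; return [] if it never occurs.

    Hypothetical helper: the output keys are illustrative assumptions.
    """
    mask_pattern = re.compile(re.escape(mask_node_name), re.IGNORECASE)
    match = mask_pattern.search(context)
    if match is None:
        # The rephrased text may no longer contain the entity name verbatim.
        logger.warning("Could not find mask_node_name %r in the context", mask_node_name)
        return []
    # Keep the original casing of the first occurrence as the ground truth.
    ground_truth = match.group(0)
    masked_context = mask_pattern.sub("___", context)
    return [{"question": masked_context, "answer": ground_truth}]
```

Searching once and reusing the match object avoids calling `.group(0)` on `None` when the entity name is absent.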
    from graphgen.templates import AGGREGATED_GENERATION_PROMPT
    from graphgen.utils import detect_main_language, logger


    random.seed(42)
Setting a global random seed with random.seed(42) is generally discouraged as it affects the entire application's random number generation, which can lead to unexpected behavior in other parts of the code. For reproducibility, it's better to create a local random.Random instance within your class, for example in the __init__ method, and use that for random operations like random.choice on line 103.
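A minimal sketch of the local-generator approach, assuming a hypothetical class `SeededChooser` that stands in for whichever class needs reproducible choices (the class and method names are not from the PR):

```python
import random


class SeededChooser:
    """Hypothetical stand-in for a class that needs reproducible random choices."""

    def __init__(self, seed: int = 42) -> None:
        # A private generator: seeding it leaves the module-level random state alone.
        self._rng = random.Random(seed)

    def pick(self, candidates: list[str]) -> str:
        # Equivalent of random.choice, drawn from the private generator.
        return self._rng.choice(candidates)

    def shuffled(self, items: list[str]) -> list[str]:
        # Shuffle a copy so the caller's list is not mutated.
        result = list(items)
        self._rng.shuffle(result)
        return result
```

Two instances built with the same seed produce the same sequence of choices, while other code that relies on the global `random` functions is unaffected.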
    @staticmethod
    def parse_response(response: str) -> dict:
        pass
The parse_response method is defined as an abstract method in the BaseGenerator class but is implemented with pass here. Additionally, the return type hint dict is incompatible with the base class's list[dict]. Since this method is not used in the overridden generate method, it should either be implemented correctly or raise NotImplementedError to adhere to the abstract base class contract.
Suggested change
-    @staticmethod
-    def parse_response(response: str) -> dict:
-        pass
+    @staticmethod
+    def parse_response(response: str) -> list[dict]:
+        raise NotImplementedError("This method is not used in MaskedFillInBlankGenerator as it overrides the `generate` method.")
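To illustrate why raising is preferable to `pass`, here is a small sketch with hypothetical stand-in classes. `BaseGeneratorSketch` is not the real `BaseGenerator`, and whether the real base declares the method as a static method is an assumption.

```python
from abc import ABC, abstractmethod


class BaseGeneratorSketch(ABC):
    """Hypothetical stand-in for BaseGenerator; only the method under discussion."""

    @staticmethod
    @abstractmethod
    def parse_response(response: str) -> list[dict]:
        ...


class MaskedSketch(BaseGeneratorSketch):
    """Hypothetical subclass that does not use parse_response."""

    @staticmethod
    def parse_response(response: str) -> list[dict]:
        # Fail loudly instead of silently returning None, which would
        # contradict the declared list[dict] return type.
        raise NotImplementedError("parse_response is unused; generate() is overridden")
```

The subclass still satisfies the abstract base class and can be instantiated, but an accidental call surfaces immediately rather than propagating `None` to callers that expect a list.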
    from graphgen.bases import BaseGraphStorage, BasePartitioner
    from graphgen.bases.datatypes import Community


    random.seed(42)
Setting a global random seed with random.seed(42) is generally discouraged as it affects the entire application's random number generation. This can lead to unexpected behavior in other parts of the code. For reproducibility, it's better to create a local random.Random instance within your class, for example in the __init__ method, and use that for random operations like random.shuffle.
    def partition(
        self,
        g: BaseGraphStorage,
        **kwargs: Any,
    ) -> Iterable[Community]:
        nodes = [n[0] for n in g.get_all_nodes()]
        random.shuffle(nodes)

        visited_nodes: Set[str] = set()
        used_edges: Set[frozenset[str]] = set()

        for seed in nodes:
            if seed in visited_nodes:
                continue

            # start BFS in a connected component
            queue = deque([seed])
            visited_nodes.add(seed)

            while queue:
                u = queue.popleft()

                for v in g.get_neighbors(u):
                    edge_key = frozenset((u, v))

                    # if this edge has not been used, a new triple has been found
                    if edge_key not in used_edges:
                        used_edges.add(edge_key)

                        # use the edge name to ensure the uniqueness of the ID
                        u_sorted, v_sorted = sorted((u, v))
                        yield Community(
                            id=f"{u_sorted}-{v_sorted}",
                            nodes=[u_sorted, v_sorted],
                            edges=[(u_sorted, v_sorted)]
                        )

                    # continue to BFS
                    if v not in visited_nodes:
                        visited_nodes.add(v)
                        queue.append(v)
The partition method in the base class BasePartitioner is type-hinted to return a List[Community]. This implementation returns an Iterable[Community] because it uses yield, which violates the Liskov Substitution Principle and the base class contract. To fix this, you should change this function to build and return a list of communities instead of using yield.
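A list-returning variant might look like the sketch below. It is an illustration under stated assumptions, not the PR's implementation: the graph is a plain adjacency dict rather than a `BaseGraphStorage`, `TripleCommunity` is a hypothetical stand-in for `Community`, and the node shuffle is omitted.

```python
from collections import deque
from dataclasses import dataclass, field


@dataclass
class TripleCommunity:
    """Hypothetical stand-in for graphgen's Community type."""
    id: str
    nodes: list[str] = field(default_factory=list)
    edges: list[tuple[str, str]] = field(default_factory=list)


def partition_triples(adjacency: dict[str, list[str]]) -> list[TripleCommunity]:
    """Return one community per undirected edge, discovered by BFS."""
    communities: list[TripleCommunity] = []
    visited: set[str] = set()
    used_edges: set[frozenset[str]] = set()

    for seed in adjacency:
        if seed in visited:
            continue
        # Start a BFS in a connected component that has not been seen yet.
        queue = deque([seed])
        visited.add(seed)
        while queue:
            u = queue.popleft()
            for v in adjacency.get(u, []):
                edge_key = frozenset((u, v))
                if edge_key not in used_edges:
                    used_edges.add(edge_key)
                    # Sort the endpoints so the ID does not depend on direction.
                    a, b = sorted((u, v))
                    communities.append(
                        TripleCommunity(id=f"{a}-{b}", nodes=[a, b], edges=[(a, b)])
                    )
                if v not in visited:
                    visited.add(v)
                    queue.append(v)
    return communities
```

Collecting into a list trades the lazy evaluation of `yield` for a return type that matches a `List[Community]` contract, so callers can use `len()` or indexing on the result.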
    dependencies:
      - build_kg
    params:
      method: triple

    if output_data_format == "QA_pairs":
        return {
            "question": question,
    # TODO: configure add_context
    # if add_context:
    #     original_ids = [
    #         node["source_id"].split("<SEP>")[0] for node in _process_nodes
    #     ] + [edge[2]["source_id"].split("<SEP>")[0] for edge in _process_edges]
    #     original_ids = list(set(original_ids))
    #     original_text = await text_chunks_storage.get_by_ids(original_ids)
    #     original_text = "\n".join(
    #         [
    #             f"{index + 1}. {text['content']}"
    #             for index, text in enumerate(original_text)
    #         ]
    #     )
    if not context:
        return []

    nodes, edge = batch
The variable edge is used to unpack the second item from batch, but based on the assertion on line 100 (len(edge) == 1), it appears to be a list of edges. For clarity and to avoid confusion, consider renaming it to edges here and on line 100.
Suggested change
-    nodes, edge = batch
+    nodes, edges = batch
This PR adds support for synthesizing masked fill_in_blank QA pairs.