diff --git a/tuning/data/data_handlers.py b/tuning/data/data_handlers.py index e5a133f18..d033d01f4 100644 --- a/tuning/data/data_handlers.py +++ b/tuning/data/data_handlers.py @@ -17,6 +17,7 @@ # Standard from enum import Enum from typing import Any, Dict, List, Union +import json # import copy import logging @@ -183,6 +184,7 @@ def __wrap_jinja_rendering_with_exception_handling(render_template: callable, ** def apply_custom_jinja_template( element: Dict[str, str], formatted_text_column_name: str, + to_json: bool, template: str, **kwargs, ): @@ -193,6 +195,7 @@ def apply_custom_jinja_template( formatted_text_column_name: Name of the dataset column where formatted text is to be saved. If doesn't exist a new column will be created. + to_json: whether to cast the output as a json template: Template to format data with. Features of Dataset should be referred to by {{key}}. Returns: @@ -212,7 +215,16 @@ def render(): env = SandboxedEnvironment(undefined=StrictUndefined) jinja_template = env.from_string(template) template_kwargs = {**tokenizer.special_tokens_map, **element} - return jinja_template.render(element=element, **template_kwargs) + res = jinja_template.render(element=element, **template_kwargs) + if to_json: + try: + # this can easily fail if the individual values in the template are not already json encoded + res = json.loads(res) + except json.decoder.JSONDecodeError as e: + raise RuntimeError( + "Column data not in expected json format: %s" % (res) + ) from e + return res return { f"{formatted_text_column_name}": __wrap_jinja_rendering_with_exception_handling(