From 3f1ec6ace59e138f6f0357650bf6455dfc66716e Mon Sep 17 00:00:00 2001 From: Vprov Date: Fri, 28 Feb 2025 11:42:23 -0800 Subject: [PATCH 01/13] Initial DPO update for the finetuning python client --- src/together/constants.py | 8 ++ src/together/resources/finetune.py | 28 ++++++ src/together/types/__init__.py | 2 + src/together/types/finetune.py | 12 +++ src/together/utils/files.py | 145 ++++++++++++++++++++++++++++- 5 files changed, 192 insertions(+), 3 deletions(-) diff --git a/src/together/constants.py b/src/together/constants.py index c64af326..7e35aad6 100644 --- a/src/together/constants.py +++ b/src/together/constants.py @@ -39,12 +39,20 @@ class DatasetFormat(enum.Enum): GENERAL = "general" CONVERSATION = "conversation" INSTRUCTION = "instruction" + PREFERENCE = "preference" + PREFERENCE_OPENAI = "preference_openai" JSONL_REQUIRED_COLUMNS_MAP = { DatasetFormat.GENERAL: ["text"], DatasetFormat.CONVERSATION: ["messages"], DatasetFormat.INSTRUCTION: ["prompt", "completion"], + DatasetFormat.PREFERENCE: ["chosen", "rejected"], + DatasetFormat.PREFERENCE_OPENAI: [ + "input", + "preferred_output", + "non_preferred_output", + ], } REQUIRED_COLUMNS_MESSAGE = ["role", "content"] POSSIBLE_ROLES_CONVERSATION = ["system", "user", "assistant"] diff --git a/src/together/resources/finetune.py b/src/together/resources/finetune.py index b58cdae2..242ac6eb 100644 --- a/src/together/resources/finetune.py +++ b/src/together/resources/finetune.py @@ -22,6 +22,7 @@ TrainingType, FinetuneLRScheduler, FinetuneLinearLRSchedulerArgs, + DPOTrainingMethodType, ) from together.types.finetune import DownloadCheckpointType from together.utils import log_warn_once, normalize_key @@ -52,6 +53,8 @@ def createFinetuneRequest( wandb_project_name: str | None = None, wandb_name: str | None = None, train_on_inputs: bool | Literal["auto"] = "auto", + training_method: str = "sft", + dpo_beta: float | None = None, ) -> FinetuneRequest: if batch_size == "max": log_warn_once( @@ -105,6 +108,11 @@ def createFinetuneRequest( lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio), ) + if training_method == "dpo": + training_method_args = DPOTrainingMethodType(dpo_beta=dpo_beta) + else: + training_method_args = None + finetune_request = FinetuneRequest( model=model, training_file=training_file, @@ -125,6 +133,8 @@ def createFinetuneRequest( wandb_project_name=wandb_project_name, wandb_name=wandb_name, train_on_inputs=train_on_inputs, + training_method=training_method, + training_method_args=training_method_args, ) return finetune_request @@ -162,6 +172,8 @@ def create( verbose: bool = False, model_limits: FinetuneTrainingLimits | None = None, train_on_inputs: bool | Literal["auto"] = "auto", + training_method: str = "sft", + dpo_beta: float = 0.1, ) -> FinetuneResponse: """ Method to initiate a fine-tuning job @@ -207,6 +219,9 @@ def create( For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields (Instruction format), inputs will be masked. Defaults to "auto". + training_method (str, optional): Training method. Defaults to "sft". + Supported methods: "sft", "dpo". + dpo_beta (float, optional): DPO beta parameter. Defaults to 0.1. Returns: FinetuneResponse: Object containing information about fine-tuning job. @@ -244,6 +259,8 @@ def create( wandb_project_name=wandb_project_name, wandb_name=wandb_name, train_on_inputs=train_on_inputs, + training_method=training_method, + dpo_beta=dpo_beta, ) if verbose: @@ -253,6 +270,8 @@ def create( ) parameter_payload = finetune_request.model_dump(exclude_none=True) + # Print the request payload before sending + print(f"Request payload: {parameter_payload}") response, _, _ = requestor.request( options=TogetherRequest( method="POST", @@ -261,6 +280,8 @@ def create( ), stream=False, ) + # Print the response before processing + print(f"Response: {response}") assert isinstance(response, TogetherResponse) @@ -503,6 +524,8 @@ async def create( verbose: bool = False, model_limits: FinetuneTrainingLimits | None = None, train_on_inputs: bool | Literal["auto"] = "auto", + training_method: str = "sft", + dpo_beta: float = 0.1, ) -> FinetuneResponse: """ Async method to initiate a fine-tuning job @@ -548,6 +571,9 @@ async def create( For datasets with the "messages" field (conversational format) or "prompt" and "completion" fields (Instruction format), inputs will be masked. Defaults to "auto". + training_method (str, optional): Training method. Defaults to "sft". + Supported methods: "sft", "dpo". + dpo_beta (float, optional): DPO beta parameter. Defaults to 0.1. Returns: FinetuneResponse: Object containing information about fine-tuning job. @@ -585,6 +611,8 @@ async def create( wandb_project_name=wandb_project_name, wandb_name=wandb_name, train_on_inputs=train_on_inputs, + training_method=training_method, + dpo_beta=dpo_beta, ) if verbose: diff --git a/src/together/types/__init__.py b/src/together/types/__init__.py index c3100cd1..b47d3484 100644 --- a/src/together/types/__init__.py +++ b/src/together/types/__init__.py @@ -31,6 +31,7 @@ FileType, ) from together.types.finetune import ( + DPOTrainingMethodType, FinetuneDownloadResult, FinetuneLinearLRSchedulerArgs, FinetuneList, @@ -79,6 +80,7 @@ "TrainingType", "FullTrainingType", "LoRATrainingType", + "DPOTrainingMethodType", "RerankRequest", "RerankResponse", "FinetuneTrainingLimits", diff --git a/src/together/types/finetune.py b/src/together/types/finetune.py index 05bc8c42..5d0e43f8 100644 --- a/src/together/types/finetune.py +++ b/src/together/types/finetune.py @@ -135,6 +135,14 @@ class LoRATrainingType(TrainingType): type: str = "Lora" +class DPOTrainingMethodType(BaseModel): + """ + Training method type for DPO training + """ + + dpo_beta: float + + class FinetuneRequest(BaseModel): """ Fine-tune request type @@ -178,6 +186,10 @@ class FinetuneRequest(BaseModel): training_type: FullTrainingType | LoRATrainingType | None = None # train on inputs train_on_inputs: StrictBool | Literal["auto"] = "auto" + # training method + training_method: str = "sft" + # DPO params + training_method_args: DPOTrainingMethodType | None = None class FinetuneResponse(BaseModel): diff --git a/src/together/utils/files.py b/src/together/utils/files.py index cc39fca0..eeb1753a 100644 --- a/src/together/utils/files.py +++ b/src/together/utils/files.py @@ -4,7 +4,7 @@ import os from pathlib import Path from traceback import format_exc -from typing import Any, Dict +from typing import Any, Dict, List from pyarrow import ArrowInvalid, parquet @@ -96,6 +96,123 @@ def check_file( return report_dict +def _has_weights(messages: List[Dict[str, str | bool]]) -> bool: + """Check if any message in the conversation has a weight parameter. + + Args: + messages (List[Dict[str, str]]): List of messages to check. + + Returns: + bool: True if any message has a weight parameter, False otherwise. + """ + return any("weight" in message for message in messages) + + +def validate_and_filter_messages( + messages: List[Dict[str, str | bool]] +) -> tuple[List[Dict[str, str | bool]], bool]: + """Validate and filter the messages column.""" + if not isinstance(messages, list): + raise ValueError( + "The dataset is malformed, the `messages` column must be a list." + ) + if len(messages) == 0: + raise ValueError( + "The dataset is malformed, the `messages` column must not be empty." + ) + + has_weights = False + # Check for weights in messages + if _has_weights(messages): + has_weights = True + + filtered_messages = [] + for message in messages: + if any(column not in message for column in REQUIRED_COLUMNS_MESSAGE): + raise ValueError( + "The dataset is malformed. " + "Each message in the messages column must have " + f"{REQUIRED_COLUMNS_MESSAGE} columns." + ) + for column in REQUIRED_COLUMNS_MESSAGE: + if not isinstance(message[column], str): + raise ValueError( + f"The dataset is malformed, the column `{column}` must be of the string type." + ) + + if has_weights and "weight" in message: + weight = message["weight"] + if not isinstance(weight, int): + raise ValueError("Weight must be an integer") + if weight not in {0, 1}: + raise ValueError("Weight must be either 0 or 1") + if message["role"] not in POSSIBLE_ROLES_CONVERSATION: + raise ValueError( + f"Invalid role {message['role']} in conversation, possible roles: " + f"{', '.join(POSSIBLE_ROLES_CONVERSATION)}" + ) + filtered_messages.append( + {column: message[column] for column in REQUIRED_COLUMNS_MESSAGE} + ) + + return filtered_messages, has_weights + + +def validate_preference_openai(example: Dict[str, Any]) -> Dict[str, Any]: + """Validate the OpenAI preference dataset format. + + Args: + example (dict): Input entry to be checked. + + Raises: + ValueError: If the dataset format is invalid. + + Returns: + Dict[str, Any]: The validated example. + """ + if not isinstance(example["input"], dict): + raise ValueError( + "The dataset is malformed, the `input` field must be a dictionary." + ) + + if "messages" not in example["input"]: + raise ValueError( + "The dataset is malformed, the `input` dictionary must contain a `messages` field." + ) + + example["input"]["messages"], _ = validate_and_filter_messages( + example["input"]["messages"] + ) + + if not isinstance(example["preferred_output"], list): + raise ValueError( + "The dataset is malformed, the `preferred_output` field must be a list." + ) + + if not isinstance(example["non_preferred_output"], list): + raise ValueError( + "The dataset is malformed, the `non_preferred_output` field must be a list." + ) + + if len(example["preferred_output"]) != 1: + raise ValueError( + "The dataset is malformed, the `preferred_output` list must contain exactly one message." + ) + + if len(example["non_preferred_output"]) != 1: + raise ValueError( + "The dataset is malformed, the `non_preferred_output` list must contain exactly one message." + ) + + example["preferred_output"], _ = validate_and_filter_messages( + example["preferred_output"] + ) + example["non_preferred_output"], _ = validate_and_filter_messages( + example["non_preferred_output"] + ) + return example + + def _check_jsonl(file: Path) -> Dict[str, Any]: report_dict: Dict[str, Any] = {} # Check that the file is UTF-8 encoded. If not report where the error occurs. @@ -164,8 +281,30 @@ def _check_jsonl(file: Path) -> Dict[str, Any]: line_number=idx + 1, error_source="format", ) - - if current_format == DatasetFormat.CONVERSATION: + if current_format == DatasetFormat.PREFERENCE_OPENAI: + validate_preference_openai(json_line) + elif current_format == DatasetFormat.PREFERENCE: + for column in JSONL_REQUIRED_COLUMNS_MAP[current_format]: + if not isinstance(json_line[column], list): + raise InvalidFileFormatError( + message=f"The dataset is malformed, the column `{column}` must be a list.", + line_number=idx + 1, + error_source="key_value", + ) + if len(json_line[column]) == 0: + raise InvalidFileFormatError( + message=f"The dataset is malformed, the column `{column}` must not be empty.", + line_number=idx + 1, + error_source="key_value", + ) + validate_and_filter_messages(json_line[column]) + if not json_line[column][-1].get("role") == "assistant": + raise InvalidFileFormatError( + message=f"The last message in {column} must be from an assistant", + line_number=idx + 1, + error_source="key_value", + ) + elif current_format == DatasetFormat.CONVERSATION: message_column = JSONL_REQUIRED_COLUMNS_MAP[ DatasetFormat.CONVERSATION ][0] From ee7e02d9bfc6ad9ac6ecf0db804b05128947fe1b Mon Sep 17 00:00:00 2001 From: Vprov Date: Mon, 3 Mar 2025 04:12:34 -0800 Subject: [PATCH 02/13] Add dpo to cli, fix typing mismatch with API --- src/together/cli/api/finetune.py | 16 ++++++++++++++++ src/together/resources/finetune.py | 28 ++++++++++++++++------------ src/together/types/__init__.py | 6 ++++-- src/together/types/finetune.py | 29 +++++++++++++++++++++++------ 4 files changed, 59 insertions(+), 20 deletions(-) diff --git a/src/together/cli/api/finetune.py b/src/together/cli/api/finetune.py index 7bc02744..679c8fe1 100644 --- a/src/together/cli/api/finetune.py +++ b/src/together/cli/api/finetune.py @@ -104,6 +104,18 @@ def fine_tuning(ctx: click.Context) -> None: default="all-linear", help="Trainable modules for LoRA adapters. For example, 'all-linear', 'q_proj,v_proj'", ) +@click.option( + "--training-method", + type=click.Choice(["sft", "dpo"]), + default="sft", + help="Training method to use. Options: sft (supervised fine-tuning), dpo (direct preference optimization)", +) +@click.option( + "--dpo-beta", + type=float, + default=0.1, + help="Beta parameter for DPO training (only used when training-method is 'dpo')", +) @click.option( "--suffix", type=str, default=None, help="Suffix for the fine-tuned model name" ) @@ -152,6 +164,8 @@ def create( wandb_name: str, confirm: bool, train_on_inputs: bool | Literal["auto"], + training_method: str, + dpo_beta: float, ) -> None: """Start fine-tuning""" client: Together = ctx.obj @@ -180,6 +194,8 @@ def create( wandb_project_name=wandb_project_name, wandb_name=wandb_name, train_on_inputs=train_on_inputs, + training_method=training_method, + dpo_beta=dpo_beta, ) model_limits: FinetuneTrainingLimits = client.fine_tuning.get_model_limits( diff --git a/src/together/resources/finetune.py b/src/together/resources/finetune.py index 242ac6eb..41b9d29d 100644 --- a/src/together/resources/finetune.py +++ b/src/together/resources/finetune.py @@ -1,7 +1,7 @@ from __future__ import annotations from pathlib import Path -from typing import Literal +from typing import Literal, Union from rich import print as rprint @@ -22,7 +22,8 @@ TrainingType, FinetuneLRScheduler, FinetuneLinearLRSchedulerArgs, - DPOTrainingMethodType, + TrainingMethodDPO, + TrainingMethodSFT, ) from together.types.finetune import DownloadCheckpointType from together.utils import log_warn_once, normalize_key @@ -108,10 +109,13 @@ def createFinetuneRequest( lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio), ) + training_method_cls: Union[TrainingMethodSFT, TrainingMethodDPO] = ( + TrainingMethodSFT() + ) if training_method == "dpo": - training_method_args = DPOTrainingMethodType(dpo_beta=dpo_beta) - else: - training_method_args = None + training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta) + + print("\n TRAINING METHOD at CREATE FINE TUNE REQUEST", training_method) finetune_request = FinetuneRequest( model=model, @@ -133,8 +137,7 @@ def createFinetuneRequest( wandb_project_name=wandb_project_name, wandb_name=wandb_name, train_on_inputs=train_on_inputs, - training_method=training_method, - training_method_args=training_method_args, + training_method=training_method_cls, ) return finetune_request @@ -173,7 +176,7 @@ def create( model_limits: FinetuneTrainingLimits | None = None, train_on_inputs: bool | Literal["auto"] = "auto", training_method: str = "sft", - dpo_beta: float = 0.1, + dpo_beta: float | None = None, ) -> FinetuneResponse: """ Method to initiate a fine-tuning job @@ -221,7 +224,7 @@ def create( Defaults to "auto". training_method (str, optional): Training method. Defaults to "sft". Supported methods: "sft", "dpo". - dpo_beta (float, optional): DPO beta parameter. Defaults to 0.1. + dpo_beta (float, optional): DPO beta parameter. Defaults to None. Returns: FinetuneResponse: Object containing information about fine-tuning job. @@ -233,7 +236,7 @@ def create( if model_limits is None: model_limits = self.get_model_limits(model=model) - + print("\n DPO BETA at CREATE FINE TUNE REQUEST", dpo_beta) finetune_request = createFinetuneRequest( model_limits=model_limits, training_file=training_file, @@ -268,6 +271,7 @@ def create( "Submitting a fine-tuning job with the following parameters:", finetune_request, ) + print("\n FINETUNE REQUEST before dump", finetune_request) parameter_payload = finetune_request.model_dump(exclude_none=True) # Print the request payload before sending @@ -525,7 +529,7 @@ async def create( model_limits: FinetuneTrainingLimits | None = None, train_on_inputs: bool | Literal["auto"] = "auto", training_method: str = "sft", - dpo_beta: float = 0.1, + dpo_beta: float | None = None, ) -> FinetuneResponse: """ Async method to initiate a fine-tuning job @@ -573,7 +577,7 @@ async def create( Defaults to "auto". training_method (str, optional): Training method. Defaults to "sft". Supported methods: "sft", "dpo". - dpo_beta (float, optional): DPO beta parameter. Defaults to 0.1. + dpo_beta (float, optional): DPO beta parameter. Defaults to None. Returns: FinetuneResponse: Object containing information about fine-tuning job. diff --git a/src/together/types/__init__.py b/src/together/types/__init__.py index b47d3484..4fb3135e 100644 --- a/src/together/types/__init__.py +++ b/src/together/types/__init__.py @@ -31,7 +31,8 @@ FileType, ) from together.types.finetune import ( - DPOTrainingMethodType, + TrainingMethodDPO, + TrainingMethodSFT, FinetuneDownloadResult, FinetuneLinearLRSchedulerArgs, FinetuneList, @@ -80,7 +81,8 @@ "TrainingType", "FullTrainingType", "LoRATrainingType", - "DPOTrainingMethodType", + "TrainingMethodDPO", + "TrainingMethodSFT", "RerankRequest", "RerankResponse", "FinetuneTrainingLimits", diff --git a/src/together/types/finetune.py b/src/together/types/finetune.py index 5d0e43f8..69b22d7a 100644 --- a/src/together/types/finetune.py +++ b/src/together/types/finetune.py @@ -1,7 +1,7 @@ from __future__ import annotations from enum import Enum -from typing import List, Literal +from typing import List, Literal, Union from pydantic import StrictBool, Field, validator, field_validator @@ -135,12 +135,29 @@ class LoRATrainingType(TrainingType): type: str = "Lora" -class DPOTrainingMethodType(BaseModel): +class TrainingMethod(BaseModel): + """ + Training method type + """ + + method: str + + +class TrainingMethodSFT(TrainingMethod): + """ + Training method type for SFT training + """ + + method: str = "sft" + + +class TrainingMethodDPO(TrainingMethod): """ Training method type for DPO training """ - dpo_beta: float + method: str = "dpo" + dpo_beta: float | None = None class FinetuneRequest(BaseModel): @@ -187,9 +204,9 @@ class FinetuneRequest(BaseModel): # train on inputs train_on_inputs: StrictBool | Literal["auto"] = "auto" # training method - training_method: str = "sft" - # DPO params - training_method_args: DPOTrainingMethodType | None = None + training_method: Union[TrainingMethodSFT, TrainingMethodDPO] = Field( + default_factory=TrainingMethodSFT + ) class FinetuneResponse(BaseModel): From fdbdc8ef6d5f16385cf830bfc8f4841fdb1b7acc Mon Sep 17 00:00:00 2001 From: Vprov Date: Mon, 3 Mar 2025 04:51:30 -0800 Subject: [PATCH 03/13] Remove prints --- src/together/resources/finetune.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/together/resources/finetune.py b/src/together/resources/finetune.py index 41b9d29d..624446de 100644 --- a/src/together/resources/finetune.py +++ b/src/together/resources/finetune.py @@ -115,8 +115,6 @@ def createFinetuneRequest( if training_method == "dpo": training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta) - print("\n TRAINING METHOD at CREATE FINE TUNE REQUEST", training_method) - finetune_request = FinetuneRequest( model=model, training_file=training_file, @@ -236,7 +234,6 @@ def create( if model_limits is None: model_limits = self.get_model_limits(model=model) - print("\n DPO BETA at CREATE FINE TUNE REQUEST", dpo_beta) finetune_request = createFinetuneRequest( model_limits=model_limits, training_file=training_file, @@ -271,11 +268,8 @@ def create( "Submitting a fine-tuning job with the following parameters:", finetune_request, ) - print("\n FINETUNE REQUEST before dump", finetune_request) parameter_payload = finetune_request.model_dump(exclude_none=True) - # Print the request payload before sending - print(f"Request payload: {parameter_payload}") response, _, _ = requestor.request( options=TogetherRequest( method="POST", @@ -284,9 +278,6 @@ def create( ), stream=False, ) - # Print the response before processing - print(f"Response: {response}") - assert isinstance(response, TogetherResponse) return FinetuneResponse(**response.data) From ee470fc03ec50d77c997046e48f5ebfa3df2e263 Mon Sep 17 00:00:00 2001 From: Vprov Date: Tue, 4 Mar 2025 06:26:37 -0800 Subject: [PATCH 04/13] Add check that the prompt is the same for the PREFERENCE dataset format --- src/together/utils/files.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/together/utils/files.py b/src/together/utils/files.py index eeb1753a..52808948 100644 --- a/src/together/utils/files.py +++ b/src/together/utils/files.py @@ -304,6 +304,31 @@ def _check_jsonl(file: Path) -> Dict[str, Any]: line_number=idx + 1, error_source="key_value", ) + # Check that all messages except the last one are the same for "chosen" and "rejected" + chosen_messages = json_line["chosen"] + rejected_messages = json_line["rejected"] + + if len(chosen_messages) != len(rejected_messages): + raise InvalidFileFormatError( + message="The 'chosen' and 'rejected' lists must have the same number of messages.", + line_number=idx + 1, + error_source="key_value", + ) + + # Count discrepancies between messages using a generator + discrepancies = sum( + 1 + for i in range(len(chosen_messages) - 1) + if chosen_messages[i] != rejected_messages[i] + ) + + if discrepancies > 1: + raise InvalidFileFormatError( + message=f"Found {discrepancies} different messages between 'chosen' and 'rejected'. " + "Only the last message should differ.", + line_number=idx + 1, + error_source="key_value", + ) elif current_format == DatasetFormat.CONVERSATION: message_column = JSONL_REQUIRED_COLUMNS_MAP[ DatasetFormat.CONVERSATION From 8f50eb8cbaebed7b03c8946e2bf3ddd47c11e594 Mon Sep 17 00:00:00 2001 From: Ivan Provilkov Date: Wed, 5 Mar 2025 17:53:25 +0000 Subject: [PATCH 05/13] Update src/together/cli/api/finetune.py Co-authored-by: Max Ryabinin --- src/together/cli/api/finetune.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/together/cli/api/finetune.py b/src/together/cli/api/finetune.py index 679c8fe1..59060b9f 100644 --- a/src/together/cli/api/finetune.py +++ b/src/together/cli/api/finetune.py @@ -108,7 +108,7 @@ def fine_tuning(ctx: click.Context) -> None: "--training-method", type=click.Choice(["sft", "dpo"]), default="sft", - help="Training method to use. Options: sft (supervised fine-tuning), dpo (direct preference optimization)", + help="Training method to use. Options: sft (supervised fine-tuning), dpo (Direct Preference Optimization)", ) @click.option( "--dpo-beta", From b322858ad8f8add2fb7848bd319d021cea93761b Mon Sep 17 00:00:00 2001 From: Ivan Provilkov Date: Wed, 5 Mar 2025 17:53:34 +0000 Subject: [PATCH 06/13] Update src/together/cli/api/finetune.py Co-authored-by: Max Ryabinin --- src/together/cli/api/finetune.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/together/cli/api/finetune.py b/src/together/cli/api/finetune.py index 59060b9f..9c81d9dd 100644 --- a/src/together/cli/api/finetune.py +++ b/src/together/cli/api/finetune.py @@ -114,7 +114,7 @@ def fine_tuning(ctx: click.Context) -> None: "--dpo-beta", type=float, default=0.1, - help="Beta parameter for DPO training (only used when training-method is 'dpo')", + help="Beta parameter for DPO training (only used when '--training-method' is 'dpo')", ) @click.option( "--suffix", type=str, default=None, help="Suffix for the fine-tuned model name" From 24d209e14f02ddf3b986621a8422f462007c9c5a Mon Sep 17 00:00:00 2001 From: Vprov Date: Wed, 5 Mar 2025 10:18:10 -0800 Subject: [PATCH 07/13] Update validation function; Use correct errors --- src/together/utils/files.py | 193 ++++++++++++++++-------------------- 1 file changed, 84 insertions(+), 109 deletions(-) diff --git a/src/together/utils/files.py b/src/together/utils/files.py index 52808948..f9c889d0 100644 --- a/src/together/utils/files.py +++ b/src/together/utils/files.py @@ -108,17 +108,21 @@ def _has_weights(messages: List[Dict[str, str | bool]]) -> bool: return any("weight" in message for message in messages) -def validate_and_filter_messages( - messages: List[Dict[str, str | bool]] +def validate_messages( + messages: List[Dict[str, str | bool]], idx: int = 0 ) -> tuple[List[Dict[str, str | bool]], bool]: - """Validate and filter the messages column.""" + """Validate the messages column.""" if not isinstance(messages, list): - raise ValueError( - "The dataset is malformed, the `messages` column must be a list." + raise InvalidFileFormatError( + message="The dataset is malformed, the `messages` column must be a list.", + line_number=idx + 1, + error_source="key_value", ) - if len(messages) == 0: - raise ValueError( - "The dataset is malformed, the `messages` column must not be empty." + if not messages: + raise InvalidFileFormatError( + message="The dataset is malformed, the `messages` column must not be empty.", + line_number=idx + 1, + error_source="key_value", ) has_weights = False @@ -126,89 +130,120 @@ def validate_and_filter_messages( if _has_weights(messages): has_weights = True - filtered_messages = [] + previous_role = None for message in messages: if any(column not in message for column in REQUIRED_COLUMNS_MESSAGE): - raise ValueError( - "The dataset is malformed. " + raise InvalidFileFormatError( + message="The dataset is malformed. " "Each message in the messages column must have " - f"{REQUIRED_COLUMNS_MESSAGE} columns." + f"{REQUIRED_COLUMNS_MESSAGE} columns.", + line_number=idx + 1, + error_source="key_value", ) for column in REQUIRED_COLUMNS_MESSAGE: if not isinstance(message[column], str): - raise ValueError( - f"The dataset is malformed, the column `{column}` must be of the string type." + raise InvalidFileFormatError( + message=f"The dataset is malformed, the column `{column}` must be of the string type.", + line_number=idx + 1, + error_source="key_value", ) if has_weights and "weight" in message: weight = message["weight"] if not isinstance(weight, int): - raise ValueError("Weight must be an integer") + raise InvalidFileFormatError( + message="Weight must be an integer", + line_number=idx + 1, + error_source="key_value", + ) if weight not in {0, 1}: - raise ValueError("Weight must be either 0 or 1") + raise InvalidFileFormatError( + message="Weight must be either 0 or 1", + line_number=idx + 1, + error_source="key_value", + ) if message["role"] not in POSSIBLE_ROLES_CONVERSATION: - raise ValueError( - f"Invalid role {message['role']} in conversation, possible roles: " - f"{', '.join(POSSIBLE_ROLES_CONVERSATION)}" + raise InvalidFileFormatError( + message=f"Invalid role {message['role']} in conversation, possible roles: " + f"{', '.join(POSSIBLE_ROLES_CONVERSATION)}", + line_number=idx + 1, + error_source="key_value", ) - filtered_messages.append( - {column: message[column] for column in REQUIRED_COLUMNS_MESSAGE} - ) - return filtered_messages, has_weights + if previous_role == message["role"]: + raise InvalidFileFormatError( + message=f"Invalid role turns on line {idx + 1} of the input file. " + "`user` and `assistant` roles must alternate user/assistant/user/assistant/...", + line_number=idx + 1, + error_source="key_value", + ) + previous_role = message["role"] + + return messages, has_weights -def validate_preference_openai(example: Dict[str, Any]) -> Dict[str, Any]: +def validate_preference_openai(example: Dict[str, Any], idx: int = 0) -> Dict[str, Any]: """Validate the OpenAI preference dataset format. Args: example (dict): Input entry to be checked. + idx (int): Line number in the file. Raises: - ValueError: If the dataset format is invalid. + InvalidFileFormatError: If the dataset format is invalid. Returns: Dict[str, Any]: The validated example. """ if not isinstance(example["input"], dict): - raise ValueError( - "The dataset is malformed, the `input` field must be a dictionary." + raise InvalidFileFormatError( + message="The dataset is malformed, the `input` field must be a dictionary.", + line_number=idx + 1, + error_source="key_value", ) if "messages" not in example["input"]: - raise ValueError( - "The dataset is malformed, the `input` dictionary must contain a `messages` field." + raise InvalidFileFormatError( + message="The dataset is malformed, the `input` dictionary must contain a `messages` field.", + line_number=idx + 1, + error_source="key_value", ) - example["input"]["messages"], _ = validate_and_filter_messages( - example["input"]["messages"] + example["input"]["messages"], _ = validate_messages( + example["input"]["messages"], idx ) if not isinstance(example["preferred_output"], list): - raise ValueError( - "The dataset is malformed, the `preferred_output` field must be a list." + raise InvalidFileFormatError( + message="The dataset is malformed, the `preferred_output` field must be a list.", + line_number=idx + 1, + error_source="key_value", ) if not isinstance(example["non_preferred_output"], list): - raise ValueError( - "The dataset is malformed, the `non_preferred_output` field must be a list." + raise InvalidFileFormatError( + message="The dataset is malformed, the `non_preferred_output` field must be a list.", + line_number=idx + 1, + error_source="key_value", ) if len(example["preferred_output"]) != 1: - raise ValueError( - "The dataset is malformed, the `preferred_output` list must contain exactly one message." + raise InvalidFileFormatError( + message="The dataset is malformed, the `preferred_output` list must contain exactly one message.", + line_number=idx + 1, + error_source="key_value", ) if len(example["non_preferred_output"]) != 1: - raise ValueError( - "The dataset is malformed, the `non_preferred_output` list must contain exactly one message." + raise InvalidFileFormatError( + message="The dataset is malformed, the `non_preferred_output` list must contain exactly one message.", + line_number=idx + 1, + error_source="key_value", ) - example["preferred_output"], _ = validate_and_filter_messages( - example["preferred_output"] - ) - example["non_preferred_output"], _ = validate_and_filter_messages( - example["non_preferred_output"] + example["preferred_output"], _ = validate_messages(example["preferred_output"], idx) + example["non_preferred_output"], _ = validate_messages( + example["non_preferred_output"], idx ) return example @@ -282,7 +317,7 @@ def _check_jsonl(file: Path) -> Dict[str, Any]: error_source="format", ) if current_format == DatasetFormat.PREFERENCE_OPENAI: - validate_preference_openai(json_line) + validate_preference_openai(json_line, idx) elif current_format == DatasetFormat.PREFERENCE: for column in JSONL_REQUIRED_COLUMNS_MAP[current_format]: if not isinstance(json_line[column], list): @@ -297,7 +332,7 @@ def _check_jsonl(file: Path) -> Dict[str, Any]: line_number=idx + 1, error_source="key_value", ) - validate_and_filter_messages(json_line[column]) + validate_messages(json_line[column], idx) if not json_line[column][-1].get("role") == "assistant": raise InvalidFileFormatError( message=f"The last message in {column} must be from an assistant", @@ -333,69 +368,9 @@ def _check_jsonl(file: Path) -> Dict[str, Any]: message_column = JSONL_REQUIRED_COLUMNS_MAP[ DatasetFormat.CONVERSATION ][0] - if not isinstance(json_line[message_column], list): - raise InvalidFileFormatError( - message=f"Invalid format on line {idx + 1} of the input file. " - f"Expected a list of messages. Found {type(json_line[message_column])}", - line_number=idx + 1, - error_source="key_value", - ) - - if len(json_line[message_column]) == 0: - raise InvalidFileFormatError( - message=f"Invalid format on line {idx + 1} of the input file. " - f"Expected a non-empty list of messages. Found empty list", - line_number=idx + 1, - error_source="key_value", - ) - - for turn_id, turn in enumerate(json_line[message_column]): - if not isinstance(turn, dict): - raise InvalidFileFormatError( - message=f"Invalid format on line {idx + 1} of the input file. " - f"Expected a dictionary in the {turn_id + 1} turn. Found {type(turn)}", - line_number=idx + 1, - error_source="key_value", - ) - - previous_role = None - for turn in json_line[message_column]: - for column in REQUIRED_COLUMNS_MESSAGE: - if column not in turn: - raise InvalidFileFormatError( - message=f"Field `{column}` is missing for a turn `{turn}` on line {idx + 1} " - "of the the input file.", - line_number=idx + 1, - error_source="key_value", - ) - else: - if not isinstance(turn[column], str): - raise InvalidFileFormatError( - message=f"Invalid format on line {idx + 1} in the column {column} for turn `{turn}` " - f"of the input file. Expected string. Found {type(turn[column])}", - line_number=idx + 1, - error_source="text_field", - ) - role = turn["role"] - - if role not in POSSIBLE_ROLES_CONVERSATION: - raise InvalidFileFormatError( - message=f"Found invalid role `{role}` in the messages on the line {idx + 1}. " - f"Possible roles in the conversation are: {POSSIBLE_ROLES_CONVERSATION}", - line_number=idx + 1, - error_source="key_value", - ) - - if previous_role == role: - raise InvalidFileFormatError( - message=f"Invalid role turns on line {idx + 1} of the input file. " - "`user` and `assistant` roles must alternate user/assistant/user/assistant/...", - line_number=idx + 1, - error_source="key_value", - ) - - previous_role = role - + messages, has_weights = validate_messages( + json_line[message_column], idx + ) else: for column in JSONL_REQUIRED_COLUMNS_MAP[current_format]: if not isinstance(json_line[column], str): From f1f8d9ac08f727949f457d9ffb325676edfb9384 Mon Sep 17 00:00:00 2001 From: Vprov Date: Wed, 5 Mar 2025 11:04:57 -0800 Subject: [PATCH 08/13] Update error messages --- src/together/utils/files.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/src/together/utils/files.py b/src/together/utils/files.py index f9c889d0..e7cb6380 100644 --- a/src/together/utils/files.py +++ b/src/together/utils/files.py @@ -114,13 +114,15 @@ def validate_messages( """Validate the messages column.""" if not isinstance(messages, list): raise InvalidFileFormatError( - message="The dataset is malformed, the `messages` column must be a list.", + message=f"Invalid format on line {idx + 1} of the input file. " + f"Expected a list of messages. Found {type(messages)}", line_number=idx + 1, error_source="key_value", ) if not messages: raise InvalidFileFormatError( - message="The dataset is malformed, the `messages` column must not be empty.", + message=f"Invalid format on line {idx + 1} of the input file. " + f"Expected a non-empty list of messages. Found empty list", line_number=idx + 1, error_source="key_value", ) @@ -132,21 +134,29 @@ def validate_messages( previous_role = None for message in messages: - if any(column not in message for column in REQUIRED_COLUMNS_MESSAGE): + if not isinstance(message, dict): raise InvalidFileFormatError( - message="The dataset is malformed. " - "Each message in the messages column must have " - f"{REQUIRED_COLUMNS_MESSAGE} columns.", + message=f"Invalid format on line {idx + 1} of the input file. " + f"Expected a dictionary in the messages list. Found {type(message)}", line_number=idx + 1, error_source="key_value", ) for column in REQUIRED_COLUMNS_MESSAGE: - if not isinstance(message[column], str): + if column not in message: raise InvalidFileFormatError( - message=f"The dataset is malformed, the column `{column}` must be of the string type.", + message=f"Field `{column}` is missing for a turn `{message}` on line {idx + 1} " + "of the the input file.", line_number=idx + 1, error_source="key_value", ) + else: + if not isinstance(message[column], str): + raise InvalidFileFormatError( + message=f"Invalid format on line {idx + 1} in the column {column} for turn `{message}` " + f"of the input file. Expected string. Found {type(message[column])}", + line_number=idx + 1, + error_source="text_field", + ) if has_weights and "weight" in message: weight = message["weight"] @@ -164,8 +174,8 @@ def validate_messages( ) if message["role"] not in POSSIBLE_ROLES_CONVERSATION: raise InvalidFileFormatError( - message=f"Invalid role {message['role']} in conversation, possible roles: " - f"{', '.join(POSSIBLE_ROLES_CONVERSATION)}", + message=f"Found invalid role `{message['role']}` in the messages on the line {idx + 1}. " + f"Possible roles in the conversation are: {POSSIBLE_ROLES_CONVERSATION}", line_number=idx + 1, error_source="key_value", ) From 5320437e35a216ad697576d75c7cf04516601fe4 Mon Sep 17 00:00:00 2001 From: Vprov Date: Wed, 5 Mar 2025 11:25:44 -0800 Subject: [PATCH 09/13] Remove PREFERENCE dataset support --- src/together/utils/files.py | 46 ------------------------------------- 1 file changed, 46 deletions(-) diff --git a/src/together/utils/files.py b/src/together/utils/files.py index e7cb6380..eeef6498 100644 --- a/src/together/utils/files.py +++ b/src/together/utils/files.py @@ -328,52 +328,6 @@ def _check_jsonl(file: Path) -> Dict[str, Any]: ) if current_format == DatasetFormat.PREFERENCE_OPENAI: validate_preference_openai(json_line, idx) - elif current_format == DatasetFormat.PREFERENCE: - for column in JSONL_REQUIRED_COLUMNS_MAP[current_format]: - if not isinstance(json_line[column], list): - raise InvalidFileFormatError( - message=f"The dataset is malformed, the column `{column}` must be a list.", - line_number=idx + 1, - error_source="key_value", - ) - if len(json_line[column]) == 0: - raise InvalidFileFormatError( - message=f"The dataset is malformed, the column `{column}` must not be empty.", - line_number=idx + 1, - error_source="key_value", - ) - validate_messages(json_line[column], idx) - if not json_line[column][-1].get("role") == "assistant": - raise InvalidFileFormatError( - message=f"The last message in {column} must be from an assistant", - line_number=idx + 1, - error_source="key_value", - ) - # Check that all messages except the last one are the same for "chosen" and "rejected" - chosen_messages = json_line["chosen"] - rejected_messages = json_line["rejected"] - - if len(chosen_messages) != len(rejected_messages): - raise InvalidFileFormatError( - message="The 'chosen' and 'rejected' lists must have the same number of messages.", - line_number=idx + 1, - error_source="key_value", - ) - - # Count discrepancies between messages using a generator - discrepancies = sum( - 1 - for i in range(len(chosen_messages) - 1) - if chosen_messages[i] != rejected_messages[i] - ) - - if discrepancies > 1: - raise InvalidFileFormatError( - message=f"Found {discrepancies} different messages between 'chosen' and 'rejected'. " - "Only the last message should differ.", - line_number=idx + 1, - error_source="key_value", - ) elif current_format == DatasetFormat.CONVERSATION: message_column = JSONL_REQUIRED_COLUMNS_MAP[ DatasetFormat.CONVERSATION From fbd17a6be1dba0a1f4de50d838e04c44a571dbbc Mon Sep 17 00:00:00 2001 From: Vprov Date: Wed, 5 Mar 2025 11:51:56 -0800 Subject: [PATCH 10/13] Remove Preference support from constants; Add unit tests for PREFERENCE_OPENAI --- src/together/constants.py | 2 - tests/unit/test_files_checks.py | 210 +++++++++++++++++++++++++++----- 2 files changed, 181 insertions(+), 31 deletions(-) diff --git a/src/together/constants.py b/src/together/constants.py index 7e35aad6..99e27a4a 100644 --- a/src/together/constants.py +++ b/src/together/constants.py @@ -39,7 +39,6 @@ class DatasetFormat(enum.Enum): GENERAL = "general" CONVERSATION = "conversation" INSTRUCTION = "instruction" - PREFERENCE = "preference" PREFERENCE_OPENAI = "preference_openai" @@ -47,7 +46,6 @@ class DatasetFormat(enum.Enum): DatasetFormat.GENERAL: ["text"], DatasetFormat.CONVERSATION: ["messages"], DatasetFormat.INSTRUCTION: ["prompt", "completion"], - DatasetFormat.PREFERENCE: ["chosen", "rejected"], DatasetFormat.PREFERENCE_OPENAI: [ "input", "preferred_output", diff --git a/tests/unit/test_files_checks.py b/tests/unit/test_files_checks.py index 37c698d2..0705651d 100644 --- a/tests/unit/test_files_checks.py +++ b/tests/unit/test_files_checks.py @@ -5,6 +5,54 @@ from together.constants import MIN_SAMPLES from together.utils.files import check_file +_TEST_PREFERENCE_OPENAI_CONTENT = [ + { + "input": { + "messages": [ + {"role": "user", "content": "Hi there, I have a question."}, + {"role": "assistant", "content": "Hello, how is your day going?"}, + { + "role": "user", + "content": "Hello, can you tell me how cold San Francisco is today?", + }, + ], + }, + "preferred_output": [ + { + "role": "assistant", + "content": "Today in San Francisco, it is not quite cold as expected. Morning clouds will give away " + "to sunshine, with a high near 68°F (20°C) and a low around 57°F (14°C).", + } + ], + "non_preferred_output": [ + { + "role": "assistant", + "content": "It is not particularly cold in San Francisco today.", + } + ], + }, + { + "input": { + "messages": [ + { + "role": "user", + "content": "What's the best way to learn programming?", + }, + ], + }, + "preferred_output": [ + { + "role": "assistant", + "content": "The best way to learn programming is through consistent practice, working on real projects, " + "and breaking down complex problems into smaller parts. Start with a beginner-friendly language like Python.", + } + ], + "non_preferred_output": [ + {"role": "assistant", "content": "Just read some books and you'll be fine."} + ], + }, +] + def test_check_jsonl_valid_general(tmp_path: Path): # Create a valid JSONL file @@ -80,45 +128,149 @@ def test_check_jsonl_valid_conversational_single_turn(tmp_path: Path): def test_check_jsonl_valid_conversational_multiple_turns(tmp_path: Path): # Create a valid JSONL file with conversational format and multiple user-assistant turn pairs file = tmp_path / "valid_conversational_multiple_turns.jsonl" - content = [ + content = _TEST_PREFERENCE_OPENAI_CONTENT + with file.open("w") as f: + f.write("\n".join(json.dumps(item) for item in content)) + + report = check_file(file) + + assert report["is_check_passed"] + assert report["utf8"] + assert report["num_samples"] == len(content) + assert report["has_min_samples"] + + +def test_check_jsonl_valid_preference_openai(tmp_path: Path): + file = tmp_path / "valid_preference_openai.jsonl" + content = _TEST_PREFERENCE_OPENAI_CONTENT + with file.open("w") as f: + f.write("\n".join(json.dumps(item) for item in content)) + + report = check_file(file) + + assert report["is_check_passed"] + assert report["utf8"] + assert report["num_samples"] == len(content) + assert report["has_min_samples"] + + +def test_check_jsonl_invalid_preference_openai_missing_fields(tmp_path: Path): + # Test all required fields in OpenAI preference format + required_fields = [ + ("input", "Missing input field"), + ("preferred_output", "Missing preferred_output field"), + ("non_preferred_output", "Missing non_preferred_output field"), + ] + + for field_to_remove, description in required_fields: + file = tmp_path / f"invalid_preference_openai_missing_{field_to_remove}.jsonl" + content = [item.copy() for item in _TEST_PREFERENCE_OPENAI_CONTENT] + + # Remove the specified field from the first item + del content[0][field_to_remove] + + with file.open("w") as f: + f.write("\n".join(json.dumps(item) for item in content)) + + report = check_file(file) + + assert not report["is_check_passed"], f"Test should fail when {description}" + + +def test_check_jsonl_invalid_preference_openai_structural_issues(tmp_path: Path): + # Test various structural issues in OpenAI preference format + test_cases = [ { - "messages": [ - {"role": "user", "content": "Is it going to rain today?"}, + "name": "empty_messages", + "modifier": lambda item: item.update({"input": {"messages": []}}), + "description": "Empty messages array", + }, + { + "name": "missing_role_preferred", + "modifier": lambda item: item.update( + {"preferred_output": [{"content": "Missing role field"}]} + ), + "description": "Missing role in preferred_output", + }, + { + "name": "missing_role_non_preferred", + "modifier": lambda item: item.update( + {"non_preferred_output": [{"content": "Missing role field"}]} + ), + "description": "Missing role in non_preferred_output", + }, + { + "name": "wrong_output_format_preferred", + "modifier": lambda item: item.update( + {"preferred_output": "Not an array but a string"} + ), + "description": "Wrong format for preferred_output", + }, + { + "name": "wrong_output_format_non_preferred", + "modifier": lambda item: item.update( + {"non_preferred_output": "Not an array but a string"} + ), + "description": "Wrong format for non_preferred_output", + }, + { + "name": "missing_content", + "modifier": lambda item: item.update( + {"input": {"messages": [{"role": "user"}]}} + ), + "description": "Missing content in messages", + }, + { + "name": "multiple_preferred_outputs", + "modifier": lambda item: item.update( { - "role": "assistant", - "content": "Yes, expect showers in the afternoon.", - }, - {"role": "user", "content": "What is the weather like in Tokyo?"}, - {"role": "assistant", "content": "It is sunny with a chance of rain."}, - ] + "preferred_output": [ + {"role": "assistant", "content": "First response"}, + {"role": "assistant", "content": "Second response"}, + ] + } + ), + "description": "Multiple messages in preferred_output", }, { - "messages": [ - {"role": "user", "content": "Who won the game last night?"}, - {"role": "assistant", "content": "The home team won by two points."}, - {"role": "user", "content": "What is the weather like in Amsterdam?"}, - {"role": "assistant", "content": "It is cloudy with a chance of snow."}, - ] + "name": "multiple_non_preferred_outputs", + "modifier": lambda item: item.update( + { + "non_preferred_output": [ + {"role": "assistant", "content": "First response"}, + {"role": "assistant", "content": "Second response"}, + ] + } + ), + "description": "Multiple messages in non_preferred_output", }, { - "messages": [ - {"role": "system", "content": "You are a kind AI"}, - {"role": "user", "content": "Who won the game last night?"}, - {"role": "assistant", "content": "The home team won by two points."}, - {"role": "user", "content": "What is the weather like in Amsterdam?"}, - {"role": "assistant", "content": "It is cloudy with a chance of snow."}, - ] + "name": "empty_preferred_output", + "modifier": lambda item: item.update({"preferred_output": []}), + "description": "Empty preferred_output array", + }, + { + "name": "empty_non_preferred_output", + "modifier": lambda item: item.update({"non_preferred_output": []}), + "description": "Empty non_preferred_output array", }, ] - with file.open("w") as f: - f.write("\n".join(json.dumps(item) for item in content)) - report = check_file(file) + for test_case in test_cases: + file = tmp_path / f"invalid_preference_openai_{test_case['name']}.jsonl" + content = [item.copy() for item in _TEST_PREFERENCE_OPENAI_CONTENT] - assert report["is_check_passed"] - assert report["utf8"] - assert report["num_samples"] == len(content) - assert report["has_min_samples"] + # Apply the modification to the first item + test_case["modifier"](content[0]) + + with file.open("w") as f: + f.write("\n".join(json.dumps(item) for item in content)) + + report = check_file(file) + + assert not report[ + "is_check_passed" + ], f"Test should fail with {test_case['description']}" def test_check_jsonl_empty_file(tmp_path: Path): From bf0b180a5ab4d12d41a5da913aa6b6c90ea9d3ae Mon Sep 17 00:00:00 2001 From: Vprov Date: Tue, 11 Mar 2025 05:50:43 -0700 Subject: [PATCH 11/13] Add type checks and style improvements --- src/together/resources/finetune.py | 16 ++++++++++++---- src/together/types/finetune.py | 8 ++++---- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/together/resources/finetune.py b/src/together/resources/finetune.py index 624446de..0db3e98b 100644 --- a/src/together/resources/finetune.py +++ b/src/together/resources/finetune.py @@ -1,7 +1,7 @@ from __future__ import annotations from pathlib import Path -from typing import Literal, Union +from typing import Literal from rich import print as rprint @@ -57,6 +57,7 @@ def createFinetuneRequest( training_method: str = "sft", dpo_beta: float | None = None, ) -> FinetuneRequest: + if batch_size == "max": log_warn_once( "Starting from together>=1.3.0, " @@ -104,14 +105,21 @@ def createFinetuneRequest( if weight_decay is not None and (weight_decay < 0): raise ValueError("Weight decay should be non-negative") + AVAILABLE_TRAINING_METHODS = { + TrainingMethodSFT().method, + TrainingMethodDPO().method, + } + if training_method not in AVAILABLE_TRAINING_METHODS: + raise ValueError( + f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}" + ) + lrScheduler = FinetuneLRScheduler( lr_scheduler_type="linear", lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio), ) - training_method_cls: Union[TrainingMethodSFT, TrainingMethodDPO] = ( - TrainingMethodSFT() - ) + training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT() if training_method == "dpo": training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta) diff --git a/src/together/types/finetune.py b/src/together/types/finetune.py index 69b22d7a..3d68cca7 100644 --- a/src/together/types/finetune.py +++ b/src/together/types/finetune.py @@ -1,7 +1,7 @@ from __future__ import annotations from enum import Enum -from typing import List, Literal, Union +from typing import List, Literal from pydantic import StrictBool, Field, validator, field_validator @@ -148,7 +148,7 @@ class TrainingMethodSFT(TrainingMethod): Training method type for SFT training """ - method: str = "sft" + method: Literal["sft"] = "sft" class TrainingMethodDPO(TrainingMethod): @@ -156,7 +156,7 @@ class TrainingMethodDPO(TrainingMethod): Training method type for DPO training """ - method: str = "dpo" + method: Literal["dpo"] = "dpo" dpo_beta: float | None = None @@ -204,7 +204,7 @@ class FinetuneRequest(BaseModel): # train on inputs train_on_inputs: StrictBool | Literal["auto"] = "auto" # training method - training_method: Union[TrainingMethodSFT, TrainingMethodDPO] = Field( + training_method: TrainingMethodSFT | TrainingMethodDPO = Field( default_factory=TrainingMethodSFT ) From 7357926dc3cedd4e2d33cdb2b197979283d27785 Mon Sep 17 00:00:00 2001 From: Vprov Date: Tue, 11 Mar 2025 07:12:07 -0700 Subject: [PATCH 12/13] Move tests to another file; Add more test cases for openai format --- src/together/utils/files.py | 95 ++++----- tests/unit/test_files_checks.py | 210 +++---------------- tests/unit/test_preference_openai.py | 288 +++++++++++++++++++++++++++ 3 files changed, 351 insertions(+), 242 deletions(-) create mode 100644 tests/unit/test_preference_openai.py diff --git a/src/together/utils/files.py b/src/together/utils/files.py index eeef6498..fcdeb2f3 100644 --- a/src/together/utils/files.py +++ b/src/together/utils/files.py @@ -96,21 +96,9 @@ def check_file( return report_dict -def _has_weights(messages: List[Dict[str, str | bool]]) -> bool: - """Check if any message in the conversation has a weight parameter. - - Args: - messages (List[Dict[str, str]]): List of messages to check. - - Returns: - bool: True if any message has a weight parameter, False otherwise. - """ - return any("weight" in message for message in messages) - - def validate_messages( - messages: List[Dict[str, str | bool]], idx: int = 0 -) -> tuple[List[Dict[str, str | bool]], bool]: + messages: List[Dict[str, str | bool]], idx: int +) -> None: """Validate the messages column.""" if not isinstance(messages, list): raise InvalidFileFormatError( @@ -127,10 +115,7 @@ def validate_messages( error_source="key_value", ) - has_weights = False - # Check for weights in messages - if _has_weights(messages): - has_weights = True + has_weights = any("weight" in message for message in messages) previous_role = None for message in messages: @@ -189,10 +174,8 @@ def validate_messages( ) previous_role = message["role"] - return messages, has_weights - -def validate_preference_openai(example: Dict[str, Any], idx: int = 0) -> Dict[str, Any]: +def validate_preference_openai(example: Dict[str, Any], idx: int = 0) -> None: """Validate the OpenAI preference dataset format. Args: @@ -201,9 +184,6 @@ def validate_preference_openai(example: Dict[str, Any], idx: int = 0) -> Dict[st Raises: InvalidFileFormatError: If the dataset format is invalid. - - Returns: - Dict[str, Any]: The validated example. """ if not isinstance(example["input"], dict): raise InvalidFileFormatError( @@ -219,43 +199,38 @@ def validate_preference_openai(example: Dict[str, Any], idx: int = 0) -> Dict[st error_source="key_value", ) - example["input"]["messages"], _ = validate_messages( - example["input"]["messages"], idx - ) - - if not isinstance(example["preferred_output"], list): - raise InvalidFileFormatError( - message="The dataset is malformed, the `preferred_output` field must be a list.", - line_number=idx + 1, - error_source="key_value", - ) - - if not isinstance(example["non_preferred_output"], list): - raise InvalidFileFormatError( - message="The dataset is malformed, the `non_preferred_output` field must be a list.", - line_number=idx + 1, - error_source="key_value", - ) + validate_messages(example["input"]["messages"], idx) - if len(example["preferred_output"]) != 1: - raise InvalidFileFormatError( - message="The dataset is malformed, the `preferred_output` list must contain exactly one message.", - line_number=idx + 1, - error_source="key_value", - ) + for output_field in ["preferred_output", "non_preferred_output"]: + if not isinstance(example[output_field], list): + raise InvalidFileFormatError( + message=f"The dataset is malformed, the `{output_field}` field must be a list.", + line_number=idx + 1, + error_source="key_value", + ) - if len(example["non_preferred_output"]) != 1: - raise InvalidFileFormatError( - message="The dataset is malformed, the `non_preferred_output` list must contain exactly one message.", - line_number=idx + 1, - error_source="key_value", - ) + if len(example[output_field]) != 1: + raise InvalidFileFormatError( + message=f"The dataset is malformed, the `{output_field}` list must contain exactly one message.", + line_number=idx + 1, + error_source="key_value", + ) + if "role" not in example[output_field][0]: + raise InvalidFileFormatError( + message=f"The dataset is malformed, the `{output_field}` message is missing the `role` field.", + line_number=idx + 1, + error_source="key_value", + ) + elif example[output_field][0]["role"] != "assistant": + raise InvalidFileFormatError( + message=f"The dataset is malformed, the `{output_field}` must contain an assistant message.", + line_number=idx + 1, + error_source="key_value", + ) + - example["preferred_output"], _ = validate_messages(example["preferred_output"], idx) - example["non_preferred_output"], _ = validate_messages( - example["non_preferred_output"], idx - ) - return example + validate_messages(example["preferred_output"], idx) + validate_messages(example["non_preferred_output"], idx) def _check_jsonl(file: Path) -> Dict[str, Any]: @@ -332,9 +307,7 @@ def _check_jsonl(file: Path) -> Dict[str, Any]: message_column = JSONL_REQUIRED_COLUMNS_MAP[ DatasetFormat.CONVERSATION ][0] - messages, has_weights = validate_messages( - json_line[message_column], idx - ) + validate_messages(json_line[message_column], idx) else: for column in JSONL_REQUIRED_COLUMNS_MAP[current_format]: if not isinstance(json_line[column], str): diff --git a/tests/unit/test_files_checks.py b/tests/unit/test_files_checks.py index 0705651d..37c698d2 100644 --- a/tests/unit/test_files_checks.py +++ b/tests/unit/test_files_checks.py @@ -5,54 +5,6 @@ from together.constants import MIN_SAMPLES from together.utils.files import check_file -_TEST_PREFERENCE_OPENAI_CONTENT = [ - { - "input": { - "messages": [ - {"role": "user", "content": "Hi there, I have a question."}, - {"role": "assistant", "content": "Hello, how is your day going?"}, - { - "role": "user", - "content": "Hello, can you tell me how cold San Francisco is today?", - }, - ], - }, - "preferred_output": [ - { - "role": "assistant", - "content": "Today in San Francisco, it is not quite cold as expected. Morning clouds will give away " - "to sunshine, with a high near 68°F (20°C) and a low around 57°F (14°C).", - } - ], - "non_preferred_output": [ - { - "role": "assistant", - "content": "It is not particularly cold in San Francisco today.", - } - ], - }, - { - "input": { - "messages": [ - { - "role": "user", - "content": "What's the best way to learn programming?", - }, - ], - }, - "preferred_output": [ - { - "role": "assistant", - "content": "The best way to learn programming is through consistent practice, working on real projects, " - "and breaking down complex problems into smaller parts. Start with a beginner-friendly language like Python.", - } - ], - "non_preferred_output": [ - {"role": "assistant", "content": "Just read some books and you'll be fine."} - ], - }, -] - def test_check_jsonl_valid_general(tmp_path: Path): # Create a valid JSONL file @@ -128,149 +80,45 @@ def test_check_jsonl_valid_conversational_single_turn(tmp_path: Path): def test_check_jsonl_valid_conversational_multiple_turns(tmp_path: Path): # Create a valid JSONL file with conversational format and multiple user-assistant turn pairs file = tmp_path / "valid_conversational_multiple_turns.jsonl" - content = _TEST_PREFERENCE_OPENAI_CONTENT - with file.open("w") as f: - f.write("\n".join(json.dumps(item) for item in content)) - - report = check_file(file) - - assert report["is_check_passed"] - assert report["utf8"] - assert report["num_samples"] == len(content) - assert report["has_min_samples"] - - -def test_check_jsonl_valid_preference_openai(tmp_path: Path): - file = tmp_path / "valid_preference_openai.jsonl" - content = _TEST_PREFERENCE_OPENAI_CONTENT - with file.open("w") as f: - f.write("\n".join(json.dumps(item) for item in content)) - - report = check_file(file) - - assert report["is_check_passed"] - assert report["utf8"] - assert report["num_samples"] == len(content) - assert report["has_min_samples"] - - -def test_check_jsonl_invalid_preference_openai_missing_fields(tmp_path: Path): - # Test all required fields in OpenAI preference format - required_fields = [ - ("input", "Missing input field"), - ("preferred_output", "Missing preferred_output field"), - ("non_preferred_output", "Missing non_preferred_output field"), - ] - - for field_to_remove, description in required_fields: - file = tmp_path / f"invalid_preference_openai_missing_{field_to_remove}.jsonl" - content = [item.copy() for item in _TEST_PREFERENCE_OPENAI_CONTENT] - - # Remove the specified field from the first item - del content[0][field_to_remove] - - with file.open("w") as f: - f.write("\n".join(json.dumps(item) for item in content)) - - report = check_file(file) - - assert not report["is_check_passed"], f"Test should fail when {description}" - - -def test_check_jsonl_invalid_preference_openai_structural_issues(tmp_path: Path): - # Test various structural issues in OpenAI preference format - test_cases = [ - { - "name": "empty_messages", - "modifier": lambda item: item.update({"input": {"messages": []}}), - "description": "Empty messages array", - }, - { - "name": "missing_role_preferred", - "modifier": lambda item: item.update( - {"preferred_output": [{"content": "Missing role field"}]} - ), - "description": "Missing role in preferred_output", - }, - { - "name": "missing_role_non_preferred", - "modifier": lambda item: item.update( - {"non_preferred_output": [{"content": "Missing role field"}]} - ), - "description": "Missing role in non_preferred_output", - }, - { - "name": "wrong_output_format_preferred", - "modifier": lambda item: item.update( - {"preferred_output": "Not an array but a string"} - ), - "description": "Wrong format for preferred_output", - }, - { - "name": "wrong_output_format_non_preferred", - "modifier": lambda item: item.update( - {"non_preferred_output": "Not an array but a string"} - ), - "description": "Wrong format for non_preferred_output", - }, - { - "name": "missing_content", - "modifier": lambda item: item.update( - {"input": {"messages": [{"role": "user"}]}} - ), - "description": "Missing content in messages", - }, - { - "name": "multiple_preferred_outputs", - "modifier": lambda item: item.update( - { - "preferred_output": [ - {"role": "assistant", "content": "First response"}, - {"role": "assistant", "content": "Second response"}, - ] - } - ), - "description": "Multiple messages in preferred_output", - }, + content = [ { - "name": "multiple_non_preferred_outputs", - "modifier": lambda item: item.update( + "messages": [ + {"role": "user", "content": "Is it going to rain today?"}, { - "non_preferred_output": [ - {"role": "assistant", "content": "First response"}, - {"role": "assistant", "content": "Second response"}, - ] - } - ), - "description": "Multiple messages in non_preferred_output", + "role": "assistant", + "content": "Yes, expect showers in the afternoon.", + }, + {"role": "user", "content": "What is the weather like in Tokyo?"}, + {"role": "assistant", "content": "It is sunny with a chance of rain."}, + ] }, { - "name": "empty_preferred_output", - "modifier": lambda item: item.update({"preferred_output": []}), - "description": "Empty preferred_output array", + "messages": [ + {"role": "user", "content": "Who won the game last night?"}, + {"role": "assistant", "content": "The home team won by two points."}, + {"role": "user", "content": "What is the weather like in Amsterdam?"}, + {"role": "assistant", "content": "It is cloudy with a chance of snow."}, + ] }, { - "name": "empty_non_preferred_output", - "modifier": lambda item: item.update({"non_preferred_output": []}), - "description": "Empty non_preferred_output array", + "messages": [ + {"role": "system", "content": "You are a kind AI"}, + {"role": "user", "content": "Who won the game last night?"}, + {"role": "assistant", "content": "The home team won by two points."}, + {"role": "user", "content": "What is the weather like in Amsterdam?"}, + {"role": "assistant", "content": "It is cloudy with a chance of snow."}, + ] }, ] + with file.open("w") as f: + f.write("\n".join(json.dumps(item) for item in content)) - for test_case in test_cases: - file = tmp_path / f"invalid_preference_openai_{test_case['name']}.jsonl" - content = [item.copy() for item in _TEST_PREFERENCE_OPENAI_CONTENT] - - # Apply the modification to the first item - test_case["modifier"](content[0]) - - with file.open("w") as f: - f.write("\n".join(json.dumps(item) for item in content)) - - report = check_file(file) + report = check_file(file) - assert not report[ - "is_check_passed" - ], f"Test should fail with {test_case['description']}" + assert report["is_check_passed"] + assert report["utf8"] + assert report["num_samples"] == len(content) + assert report["has_min_samples"] def test_check_jsonl_empty_file(tmp_path: Path): diff --git a/tests/unit/test_preference_openai.py b/tests/unit/test_preference_openai.py new file mode 100644 index 00000000..8f4e2d5d --- /dev/null +++ b/tests/unit/test_preference_openai.py @@ -0,0 +1,288 @@ +import json +import pytest +from pathlib import Path + +from together.constants import MIN_SAMPLES +from together.utils.files import check_file + +# Test data for preference OpenAI format +_TEST_PREFERENCE_OPENAI_CONTENT = [ + { + "input": { + "messages": [ + {"role": "user", "content": "Hi there, I have a question."}, + {"role": "assistant", "content": "Hello, how is your day going?"}, + { + "role": "user", + "content": "Hello, can you tell me how cold San Francisco is today?", + }, + ], + }, + "preferred_output": [ + { + "role": "assistant", + "content": "Today in San Francisco, it is not quite cold as expected. Morning clouds will give away " + "to sunshine, with a high near 68°F (20°C) and a low around 57°F (14°C).", + } + ], + "non_preferred_output": [ + { + "role": "assistant", + "content": "It is not particularly cold in San Francisco today.", + } + ], + }, + { + "input": { + "messages": [ + { + "role": "user", + "content": "What's the best way to learn programming?", + }, + ], + }, + "preferred_output": [ + { + "role": "assistant", + "content": "The best way to learn programming is through consistent practice, working on real projects, " + "and breaking down complex problems into smaller parts. Start with a beginner-friendly language like Python.", + } + ], + "non_preferred_output": [ + {"role": "assistant", "content": "Just read some books and you'll be fine."} + ], + }, +] + + +def test_check_jsonl_valid_preference_openai(tmp_path: Path): + """Test valid preference OpenAI format.""" + file = tmp_path / "valid_preference_openai.jsonl" + content = _TEST_PREFERENCE_OPENAI_CONTENT + with file.open("w") as f: + f.write("\n".join(json.dumps(item) for item in content)) + + report = check_file(file) + + assert report["is_check_passed"] + assert report["utf8"] + assert report["num_samples"] == len(content) + assert report["has_min_samples"] + + +# Define test cases for missing fields +MISSING_FIELDS_TEST_CASES = [ + pytest.param("input", "Missing input field", id="missing_input"), + pytest.param("preferred_output", "Missing preferred_output field", id="missing_preferred_output"), + pytest.param("non_preferred_output", "Missing non_preferred_output field", id="missing_non_preferred_output"), +] + + +@pytest.mark.parametrize("field_to_remove, description", MISSING_FIELDS_TEST_CASES) +def test_check_jsonl_invalid_preference_openai_missing_fields(tmp_path: Path, field_to_remove, description): + """Test missing required fields in OpenAI preference format.""" + file = tmp_path / f"invalid_preference_openai_missing_{field_to_remove}.jsonl" + content = [item.copy() for item in _TEST_PREFERENCE_OPENAI_CONTENT] + + # Remove the specified field from the first item + del content[0][field_to_remove] + + with file.open("w") as f: + f.write("\n".join(json.dumps(item) for item in content)) + + report = check_file(file) + + assert not report["is_check_passed"], f"Test should fail when {description}" + + +# Define test cases for structural issues +STRUCTURAL_ISSUE_TEST_CASES = [ + pytest.param( + "empty_messages", + lambda item: item.update({"input": {"messages": []}}), + "Empty messages array", + id="empty_messages" + ), + pytest.param( + "missing_role_preferred", + lambda item: item.update( + {"preferred_output": [{"content": "Missing role field"}]} + ), + "Missing role in preferred_output", + id="missing_role_preferred" + ), + pytest.param( + "missing_role_non_preferred", + lambda item: item.update( + {"non_preferred_output": [{"content": "Missing role field"}]} + ), + "Missing role in non_preferred_output", + id="missing_role_non_preferred" + ), + pytest.param( + "missing_content_preferred", + lambda item: item.update( + {"preferred_output": [{"role": "assistant"}]} + ), + "Missing content in preferred_output", + id="missing_content_preferred" + ), + pytest.param( + "missing_content_non_preferred", + lambda item: item.update( + {"non_preferred_output": [{"role": "assistant"}]} + ), + "Missing content in non_preferred_output", + id="missing_content_non_preferred" + ), + pytest.param( + "wrong_output_format_preferred", + lambda item: item.update( + {"preferred_output": "Not an array but a string"} + ), + "Wrong format for preferred_output", + id="wrong_output_format_preferred" + ), + pytest.param( + "wrong_output_format_non_preferred", + lambda item: item.update( + {"non_preferred_output": "Not an array but a string"} + ), + "Wrong format for non_preferred_output", + id="wrong_output_format_non_preferred" + ), + pytest.param( + "missing_content", + lambda item: item.update( + {"input": {"messages": [{"role": "user"}]}} + ), + "Missing content in messages", + id="missing_content" + ), + pytest.param( + "multiple_preferred_outputs", + lambda item: item.update( + { + "preferred_output": [ + {"role": "assistant", "content": "First response"}, + {"role": "assistant", "content": "Second response"}, + ] + } + ), + "Multiple messages in preferred_output", + id="multiple_preferred_outputs" + ), + pytest.param( + "multiple_non_preferred_outputs", + lambda item: item.update( + { + "non_preferred_output": [ + {"role": "assistant", "content": "First response"}, + {"role": "assistant", "content": "Second response"}, + ] + } + ), + "Multiple messages in non_preferred_output", + id="multiple_non_preferred_outputs" + ), + pytest.param( + "empty_preferred_output", + lambda item: item.update({"preferred_output": []}), + "Empty preferred_output array", + id="empty_preferred_output" + ), + pytest.param( + "empty_non_preferred_output", + lambda item: item.update({"non_preferred_output": []}), + "Empty non_preferred_output array", + id="empty_non_preferred_output" + ), + pytest.param( + "non_string_content_in_messages", + lambda item: item.update({"input": {"messages": [{"role": "user", "content": 123}]}}), + "Non-string content in messages", + id="non_string_content_in_messages" + ), + pytest.param( + "invalid_role_in_messages", + lambda item: item.update({"input": {"messages": [{"role": "invalid_role", "content": "Hello"}]}}), + "Invalid role in messages", + id="invalid_role_in_messages" + ), + pytest.param( + "non_alternating_roles", + lambda item: item.update({"input": {"messages": [ + {"role": "user", "content": "Hello"}, + {"role": "user", "content": "How are you?"} + ]}}), + "Non-alternating roles in messages", + id="non_alternating_roles" + ), + pytest.param( + "invalid_weight_type", + lambda item: item.update({"input": {"messages": [ + {"role": "user", "content": "Hello", "weight": "not_an_integer"} + ]}}), + "Invalid weight type", + id="invalid_weight_type" + ), + pytest.param( + "invalid_weight_value", + lambda item: item.update({"input": {"messages": [ + {"role": "user", "content": "Hello", "weight": 2} + ]}}), + "Invalid weight value", + id="invalid_weight_value" + ), + pytest.param( + "non_dict_message", + lambda item: item.update({"input": {"messages": [ + "Not a dictionary" + ]}}), + "Non-dictionary message", + id="non_dict_message" + ), + pytest.param( + "non_dict_input", + lambda item: item.update({"input": "Not a dictionary"}), + "Non-dictionary input", + id="non_dict_input" + ), + pytest.param( + "missing_messages_in_input", + lambda item: item.update({"input": {}}), + "Missing messages in input", + id="missing_messages_in_input" + ), + pytest.param( + "non_assistant_role_in_preferred", + lambda item: item.update({"preferred_output": [{"role": "user", "content": "This should be assistant"}]}), + "Non-assistant role in preferred output", + id="non_assistant_role_in_preferred" + ), + pytest.param( + "non_assistant_role_in_non_preferred", + lambda item: item.update({"non_preferred_output": [{"role": "user", "content": "This should be assistant"}]}), + "Non-assistant role in non-preferred output", + id="non_assistant_role_in_non_preferred" + ), +] + + +@pytest.mark.parametrize("name, modifier, description", STRUCTURAL_ISSUE_TEST_CASES) +def test_check_jsonl_invalid_preference_openai_structural_issues( + tmp_path: Path, name, modifier, description +): + """Test various structural issues in OpenAI preference format.""" + file = tmp_path / f"invalid_preference_openai_{name}.jsonl" + content = [item.copy() for item in _TEST_PREFERENCE_OPENAI_CONTENT] + + # Apply the modification to the first item + modifier(content[0]) + + with file.open("w") as f: + f.write("\n".join(json.dumps(item) for item in content)) + + report = check_file(file) + + assert not report["is_check_passed"], f"Test should fail with {description}" From 5f00d958ac3d4251dcbc9a68e0c7914b53f32b75 Mon Sep 17 00:00:00 2001 From: Vprov Date: Tue, 11 Mar 2025 08:46:11 -0700 Subject: [PATCH 13/13] Small style changes --- src/together/resources/finetune.py | 10 +- src/together/utils/files.py | 5 +- tests/unit/test_preference_openai.py | 144 ++++++++++++++++----------- 3 files changed, 91 insertions(+), 68 deletions(-) diff --git a/src/together/resources/finetune.py b/src/together/resources/finetune.py index 0db3e98b..52c2960d 100644 --- a/src/together/resources/finetune.py +++ b/src/together/resources/finetune.py @@ -29,6 +29,12 @@ from together.utils import log_warn_once, normalize_key +AVAILABLE_TRAINING_METHODS = { + TrainingMethodSFT().method, + TrainingMethodDPO().method, +} + + def createFinetuneRequest( model_limits: FinetuneTrainingLimits, training_file: str, @@ -105,10 +111,6 @@ def createFinetuneRequest( if weight_decay is not None and (weight_decay < 0): raise ValueError("Weight decay should be non-negative") - AVAILABLE_TRAINING_METHODS = { - TrainingMethodSFT().method, - TrainingMethodDPO().method, - } if training_method not in AVAILABLE_TRAINING_METHODS: raise ValueError( f"training_method must be one of {', '.join(AVAILABLE_TRAINING_METHODS)}" diff --git a/src/together/utils/files.py b/src/together/utils/files.py index fcdeb2f3..e1e1d4ed 100644 --- a/src/together/utils/files.py +++ b/src/together/utils/files.py @@ -96,9 +96,7 @@ def check_file( return report_dict -def validate_messages( - messages: List[Dict[str, str | bool]], idx: int -) -> None: +def validate_messages(messages: List[Dict[str, str | bool]], idx: int) -> None: """Validate the messages column.""" if not isinstance(messages, list): raise InvalidFileFormatError( @@ -227,7 +225,6 @@ def validate_preference_openai(example: Dict[str, Any], idx: int = 0) -> None: line_number=idx + 1, error_source="key_value", ) - validate_messages(example["preferred_output"], idx) validate_messages(example["non_preferred_output"], idx) diff --git a/tests/unit/test_preference_openai.py b/tests/unit/test_preference_openai.py index 8f4e2d5d..3781c830 100644 --- a/tests/unit/test_preference_openai.py +++ b/tests/unit/test_preference_openai.py @@ -5,7 +5,7 @@ from together.constants import MIN_SAMPLES from together.utils.files import check_file -# Test data for preference OpenAI format + _TEST_PREFERENCE_OPENAI_CONTENT = [ { "input": { @@ -70,16 +70,25 @@ def test_check_jsonl_valid_preference_openai(tmp_path: Path): assert report["has_min_samples"] -# Define test cases for missing fields MISSING_FIELDS_TEST_CASES = [ pytest.param("input", "Missing input field", id="missing_input"), - pytest.param("preferred_output", "Missing preferred_output field", id="missing_preferred_output"), - pytest.param("non_preferred_output", "Missing non_preferred_output field", id="missing_non_preferred_output"), + pytest.param( + "preferred_output", + "Missing preferred_output field", + id="missing_preferred_output", + ), + pytest.param( + "non_preferred_output", + "Missing non_preferred_output field", + id="missing_non_preferred_output", + ), ] @pytest.mark.parametrize("field_to_remove, description", MISSING_FIELDS_TEST_CASES) -def test_check_jsonl_invalid_preference_openai_missing_fields(tmp_path: Path, field_to_remove, description): +def test_check_jsonl_invalid_preference_openai_missing_fields( + tmp_path: Path, field_to_remove, description +): """Test missing required fields in OpenAI preference format.""" file = tmp_path / f"invalid_preference_openai_missing_{field_to_remove}.jsonl" content = [item.copy() for item in _TEST_PREFERENCE_OPENAI_CONTENT] @@ -95,13 +104,12 @@ def test_check_jsonl_invalid_preference_openai_missing_fields(tmp_path: Path, fi assert not report["is_check_passed"], f"Test should fail when {description}" -# Define test cases for structural issues STRUCTURAL_ISSUE_TEST_CASES = [ pytest.param( "empty_messages", lambda item: item.update({"input": {"messages": []}}), "Empty messages array", - id="empty_messages" + id="empty_messages", ), pytest.param( "missing_role_preferred", @@ -109,7 +117,7 @@ def test_check_jsonl_invalid_preference_openai_missing_fields(tmp_path: Path, fi {"preferred_output": [{"content": "Missing role field"}]} ), "Missing role in preferred_output", - id="missing_role_preferred" + id="missing_role_preferred", ), pytest.param( "missing_role_non_preferred", @@ -117,47 +125,37 @@ def test_check_jsonl_invalid_preference_openai_missing_fields(tmp_path: Path, fi {"non_preferred_output": [{"content": "Missing role field"}]} ), "Missing role in non_preferred_output", - id="missing_role_non_preferred" + id="missing_role_non_preferred", ), pytest.param( "missing_content_preferred", - lambda item: item.update( - {"preferred_output": [{"role": "assistant"}]} - ), + lambda item: item.update({"preferred_output": [{"role": "assistant"}]}), "Missing content in preferred_output", - id="missing_content_preferred" + id="missing_content_preferred", ), pytest.param( "missing_content_non_preferred", - lambda item: item.update( - {"non_preferred_output": [{"role": "assistant"}]} - ), + lambda item: item.update({"non_preferred_output": [{"role": "assistant"}]}), "Missing content in non_preferred_output", - id="missing_content_non_preferred" + id="missing_content_non_preferred", ), pytest.param( "wrong_output_format_preferred", - lambda item: item.update( - {"preferred_output": "Not an array but a string"} - ), + lambda item: item.update({"preferred_output": "Not an array but a string"}), "Wrong format for preferred_output", - id="wrong_output_format_preferred" + id="wrong_output_format_preferred", ), pytest.param( "wrong_output_format_non_preferred", - lambda item: item.update( - {"non_preferred_output": "Not an array but a string"} - ), + lambda item: item.update({"non_preferred_output": "Not an array but a string"}), "Wrong format for non_preferred_output", - id="wrong_output_format_non_preferred" + id="wrong_output_format_non_preferred", ), pytest.param( "missing_content", - lambda item: item.update( - {"input": {"messages": [{"role": "user"}]}} - ), + lambda item: item.update({"input": {"messages": [{"role": "user"}]}}), "Missing content in messages", - id="missing_content" + id="missing_content", ), pytest.param( "multiple_preferred_outputs", @@ -170,7 +168,7 @@ def test_check_jsonl_invalid_preference_openai_missing_fields(tmp_path: Path, fi } ), "Multiple messages in preferred_output", - id="multiple_preferred_outputs" + id="multiple_preferred_outputs", ), pytest.param( "multiple_non_preferred_outputs", @@ -183,88 +181,114 @@ def test_check_jsonl_invalid_preference_openai_missing_fields(tmp_path: Path, fi } ), "Multiple messages in non_preferred_output", - id="multiple_non_preferred_outputs" + id="multiple_non_preferred_outputs", ), pytest.param( "empty_preferred_output", lambda item: item.update({"preferred_output": []}), "Empty preferred_output array", - id="empty_preferred_output" + id="empty_preferred_output", ), pytest.param( "empty_non_preferred_output", lambda item: item.update({"non_preferred_output": []}), "Empty non_preferred_output array", - id="empty_non_preferred_output" + id="empty_non_preferred_output", ), pytest.param( "non_string_content_in_messages", - lambda item: item.update({"input": {"messages": [{"role": "user", "content": 123}]}}), + lambda item: item.update( + {"input": {"messages": [{"role": "user", "content": 123}]}} + ), "Non-string content in messages", - id="non_string_content_in_messages" + id="non_string_content_in_messages", ), pytest.param( "invalid_role_in_messages", - lambda item: item.update({"input": {"messages": [{"role": "invalid_role", "content": "Hello"}]}}), + lambda item: item.update( + {"input": {"messages": [{"role": "invalid_role", "content": "Hello"}]}} + ), "Invalid role in messages", - id="invalid_role_in_messages" + id="invalid_role_in_messages", ), pytest.param( "non_alternating_roles", - lambda item: item.update({"input": {"messages": [ - {"role": "user", "content": "Hello"}, - {"role": "user", "content": "How are you?"} - ]}}), + lambda item: item.update( + { + "input": { + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "user", "content": "How are you?"}, + ] + } + } + ), "Non-alternating roles in messages", - id="non_alternating_roles" + id="non_alternating_roles", ), pytest.param( "invalid_weight_type", - lambda item: item.update({"input": {"messages": [ - {"role": "user", "content": "Hello", "weight": "not_an_integer"} - ]}}), + lambda item: item.update( + { + "input": { + "messages": [ + {"role": "user", "content": "Hello", "weight": "not_an_integer"} + ] + } + } + ), "Invalid weight type", - id="invalid_weight_type" + id="invalid_weight_type", ), pytest.param( "invalid_weight_value", - lambda item: item.update({"input": {"messages": [ - {"role": "user", "content": "Hello", "weight": 2} - ]}}), + lambda item: item.update( + {"input": {"messages": [{"role": "user", "content": "Hello", "weight": 2}]}} + ), "Invalid weight value", - id="invalid_weight_value" + id="invalid_weight_value", ), pytest.param( "non_dict_message", - lambda item: item.update({"input": {"messages": [ - "Not a dictionary" - ]}}), + lambda item: item.update({"input": {"messages": ["Not a dictionary"]}}), "Non-dictionary message", - id="non_dict_message" + id="non_dict_message", ), pytest.param( "non_dict_input", lambda item: item.update({"input": "Not a dictionary"}), "Non-dictionary input", - id="non_dict_input" + id="non_dict_input", ), pytest.param( "missing_messages_in_input", lambda item: item.update({"input": {}}), "Missing messages in input", - id="missing_messages_in_input" + id="missing_messages_in_input", ), pytest.param( "non_assistant_role_in_preferred", - lambda item: item.update({"preferred_output": [{"role": "user", "content": "This should be assistant"}]}), + lambda item: item.update( + { + "preferred_output": [ + {"role": "user", "content": "This should be assistant"} + ] + } + ), "Non-assistant role in preferred output", - id="non_assistant_role_in_preferred" + id="non_assistant_role_in_preferred", ), pytest.param( "non_assistant_role_in_non_preferred", - lambda item: item.update({"non_preferred_output": [{"role": "user", "content": "This should be assistant"}]}), + lambda item: item.update( + { + "non_preferred_output": [ + {"role": "user", "content": "This should be assistant"} + ] + } + ), "Non-assistant role in non-preferred output", - id="non_assistant_role_in_non_preferred" + id="non_assistant_role_in_non_preferred", ), ]