From fd5ad686a7b5a04730d692281f8a246b5ff5a51d Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Mon, 2 Feb 2026 17:30:32 +0000 Subject: [PATCH] Remove DPO (Direct Preference Optimization) feature --- src/MaxText/__init__.py | 1 - src/MaxText/gradient_accumulation.py | 6 +- .../input_pipeline/_grain_data_processing.py | 155 ++++-------------- .../input_pipeline/_hf_data_processing.py | 16 +- .../input_pipeline/_tfds_data_processing.py | 24 +-- src/MaxText/train.py | 56 +------ src/maxtext/common/metric_logger.py | 8 - src/maxtext/configs/base.yml | 9 +- src/maxtext/configs/post_train/dpo.yml | 31 ---- src/maxtext/configs/types.py | 7 +- src/maxtext/experimental/rl/grpo_trainer.py | 2 - .../trainers/post_train/dpo/dpo_utils.py | 150 ----------------- src/maxtext/utils/maxtext_utils.py | 10 +- src/maxtext/utils/train_utils.py | 34 ---- tests/end_to_end/tpu/test_dpo.sh | 36 ---- tests/unit/configs_test.py | 1 - tests/unit/sft_data_processing_test.py | 1 - 17 files changed, 56 insertions(+), 491 deletions(-) delete mode 100644 src/maxtext/configs/post_train/dpo.yml delete mode 100644 src/maxtext/trainers/post_train/dpo/dpo_utils.py delete mode 100644 tests/end_to_end/tpu/test_dpo.sh diff --git a/src/MaxText/__init__.py b/src/MaxText/__init__.py index 1eca5831b8..154531ed8a 100644 --- a/src/MaxText/__init__.py +++ b/src/MaxText/__init__.py @@ -31,7 +31,6 @@ from MaxText import pyconfig from MaxText.layers import models -from maxtext.trainers.post_train.dpo import dpo_utils from maxtext.utils import maxtext_utils from maxtext.utils import model_creation_utils from maxtext.utils.model_creation_utils import from_config diff --git a/src/MaxText/gradient_accumulation.py b/src/MaxText/gradient_accumulation.py index f2bf2ffdb0..44d1ed4338 100644 --- a/src/MaxText/gradient_accumulation.py +++ b/src/MaxText/gradient_accumulation.py @@ -30,7 +30,6 @@ def gradient_accumulation_loss_and_grad( params_shardings, data, dropout_rng, - extra_dpo_args, ): """ Calculates gradients using gradient accumulation. @@ -45,7 +44,7 @@ def gradient_accumulation_loss_and_grad( Args: _loss_fn: The loss function to differentiate. Its signature is expected - to be: `(model, config, data, dropout_rng, params, *extra_args, is_train=True)`. + to be: `(model, config, data, dropout_rng, params, is_train=True)`. config: Model and training configuration object. Must contain `gradient_accumulation_steps` and `shard_optimizer_over_data`. model: The model module. @@ -54,7 +53,6 @@ def gradient_accumulation_loss_and_grad( data: A PyTree of batched data. The leading dimension is assumed to be the total batch size (microbatch_size * num_accumulations). dropout_rng: JAX PRNGKey for dropout. - extra_dpo_args: A tuple of extra arguments to pass to the loss function. Returns: A tuple containing: @@ -91,7 +89,7 @@ def convert_to_bf16(param): def accumulate_gradient(acc_grad_and_loss, data): ga_params = acc_grad_and_loss["ga_params"] - (_, aux), cur_batch_gradient = grad_func(model, config, data, dropout_rng, ga_params, *extra_dpo_args, is_train=True) + (_, aux), cur_batch_gradient = grad_func(model, config, data, dropout_rng, ga_params, is_train=True) acc_grad_and_loss["loss"] += aux["total_loss"] acc_grad_and_loss["moe_lb_loss"] += aux["moe_lb_loss"] acc_grad_and_loss["mtp_loss"] += aux["mtp_loss"] diff --git a/src/MaxText/input_pipeline/_grain_data_processing.py b/src/MaxText/input_pipeline/_grain_data_processing.py index 61258ca493..14016a8cfc 100644 --- a/src/MaxText/input_pipeline/_grain_data_processing.py +++ b/src/MaxText/input_pipeline/_grain_data_processing.py @@ -303,57 +303,6 @@ def pretrain_preprocessing_pipeline( return dataset -def dpo_preprocessing_pipeline( - dataset, - config, - data_columns, - tokenize, - grain_worker_count, - grain_per_worker_buffer_size, -): - """Use grain to pre-process the dataset and return iterators for dpo fine-tuning""" - if config.grain_file_type == "arrayrecord": - dataset = dataset.map(_input_pipeline_utils.ParseFeatures(data_columns, tokenize)) - dataset = dataset.map(_input_pipeline_utils.NormalizeFeatures(data_columns, tokenize)) - tokenizer_model = tokenizer.build_tokenizer( - config.tokenizer_path, - config.tokenizer_type, - config.add_bos, - config.add_eos, - config.hf_access_token, - config.dataset_type, - ) - if tokenizer_model.pad_id is not None: - pad_id = tokenizer_model.pad_id - elif tokenizer_model.unk_id is not None: - pad_id = tokenizer_model.unk_id - else: - pad_id = -1 - - if tokenize: - dataset = dataset.map(_grain_tokenizer.TokenizeAndTrim(data_columns, config.max_target_length, tokenizer_model)) - - dataset = dataset.map(_input_pipeline_utils.PadOrTrimToMaxLength(config.max_target_length, pad_id)) - batch_size = config.global_batch_size_to_load // jax.process_count() - batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id) - dataset = dataset.batch(batch_size, batch_fn=batch_fn) - multiprocessing_options = ( - pick_performance_config( - ds=dataset, - ram_budget_mb=config.grain_ram_budget_mb, - max_workers=None, - max_buffer_size=None, - ).multiprocessing_options - if grain_worker_count == -1 - else grain.MultiprocessingOptions( - num_workers=grain_worker_count, - per_worker_buffer_size=grain_per_worker_buffer_size, - ) - ) - dataset = dataset.mp_prefetch(multiprocessing_options) - return dataset - - def make_grain_train_iterator( config: ml_collections.ConfigDict, global_mesh, @@ -378,24 +327,14 @@ def make_grain_train_iterator( grain_data_source_max_workers=config.grain_data_source_max_workers, mixture_config_path=config.grain_train_mixture_config_path, ) - if config.use_dpo: - train_dataloader = dpo_preprocessing_pipeline( - train_ds, - config, - data_columns=config.train_data_columns, - tokenize=config.tokenize_train_data, - grain_worker_count=config.grain_worker_count, - grain_per_worker_buffer_size=config.grain_per_worker_buffer_size, - ) - else: - train_dataloader = pretrain_preprocessing_pipeline( - train_ds, - config, - data_columns=config.train_data_columns, - tokenize=config.tokenize_train_data, - grain_worker_count=config.grain_worker_count, - grain_per_worker_buffer_size=config.grain_per_worker_buffer_size, - ) + train_dataloader = pretrain_preprocessing_pipeline( + train_ds, + config, + data_columns=config.train_data_columns, + tokenize=config.tokenize_train_data, + grain_worker_count=config.grain_worker_count, + grain_per_worker_buffer_size=config.grain_per_worker_buffer_size, + ) return multihost_dataloading.MultiHostDataLoadIterator( train_dataloader, global_mesh, @@ -415,24 +354,14 @@ def make_grain_train_iterator( grain_prefetch_buffer_size=config.grain_prefetch_buffer_size, grain_data_source_max_workers=config.grain_data_source_max_workers, ) - if config.use_dpo: - preprocessing_fn = functools.partial( - pretrain_preprocessing_pipeline, - config=config, - data_columns=config.train_data_columns, - tokenize=config.tokenize_train_data, - grain_worker_count=config.grain_worker_count, - grain_per_worker_buffer_size=config.grain_per_worker_buffer_size, - ) - else: - preprocessing_fn = functools.partial( - pretrain_preprocessing_pipeline, - config=config, - data_columns=config.train_data_columns, - tokenize=config.tokenize_train_data, - grain_worker_count=config.grain_worker_count, - grain_per_worker_buffer_size=config.grain_per_worker_buffer_size, - ) + preprocessing_fn = functools.partial( + pretrain_preprocessing_pipeline, + config=config, + data_columns=config.train_data_columns, + tokenize=config.tokenize_train_data, + grain_worker_count=config.grain_worker_count, + grain_per_worker_buffer_size=config.grain_per_worker_buffer_size, + ) if config.colocated_python_data_input: global_shape = (config.global_batch_size_to_load, config.max_target_length) return multihost_dataloading.RemoteIterator(get_ds_fn, preprocessing_fn, global_mesh, global_shape) @@ -475,24 +404,14 @@ def make_grain_eval_iterator( grain_prefetch_buffer_size=config.grain_prefetch_buffer_size_eval, grain_data_source_max_workers=config.grain_data_source_max_workers, ) - if config.use_dpo: - eval_dataloader = dpo_preprocessing_pipeline( - eval_ds, - config, - data_columns=config.eval_data_columns, - tokenize=config.tokenize_eval_data, - grain_worker_count=config.grain_worker_count_eval, - grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval, - ) - else: - eval_dataloader = pretrain_preprocessing_pipeline( - eval_ds, - config, - data_columns=config.eval_data_columns, - tokenize=config.tokenize_eval_data, - grain_worker_count=config.grain_worker_count_eval, - grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval, - ) + eval_dataloader = pretrain_preprocessing_pipeline( + eval_ds, + config, + data_columns=config.eval_data_columns, + tokenize=config.tokenize_eval_data, + grain_worker_count=config.grain_worker_count_eval, + grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval, + ) return multihost_dataloading.MultiHostDataLoadIterator( eval_dataloader, global_mesh, config.generate_padding_batch_eval ) @@ -509,23 +428,13 @@ def make_grain_eval_iterator( grain_prefetch_buffer_size=config.grain_prefetch_buffer_size_eval, grain_data_source_max_workers=config.grain_data_source_max_workers, ) - if config.use_dpo: - preprocessing_fn = functools.partial( - dpo_preprocessing_pipeline, - config=config, - data_columns=config.eval_data_columns, - tokenize=config.tokenize_eval_data, - grain_worker_count=config.grain_worker_count_eval, - grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval, - ) - else: - preprocessing_fn = functools.partial( - pretrain_preprocessing_pipeline, - config=config, - data_columns=config.eval_data_columns, - tokenize=config.tokenize_eval_data, - grain_worker_count=config.grain_worker_count_eval, - grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval, - ) + preprocessing_fn = functools.partial( + pretrain_preprocessing_pipeline, + config=config, + data_columns=config.eval_data_columns, + tokenize=config.tokenize_eval_data, + grain_worker_count=config.grain_worker_count_eval, + grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval, + ) global_shape = (config.global_batch_size_to_load, config.max_target_length) return multihost_dataloading.RemoteIterator(get_ds_fn, preprocessing_fn, global_mesh, global_shape) diff --git a/src/MaxText/input_pipeline/_hf_data_processing.py b/src/MaxText/input_pipeline/_hf_data_processing.py index 5c2fdac568..34980e8582 100644 --- a/src/MaxText/input_pipeline/_hf_data_processing.py +++ b/src/MaxText/input_pipeline/_hf_data_processing.py @@ -22,8 +22,6 @@ import grain.python as grain -import numpy as np - from MaxText.input_pipeline import _input_pipeline_utils from MaxText.input_pipeline import instruction_data_processing from MaxText import multihost_dataloading @@ -205,7 +203,6 @@ def preprocessing_pipeline( num_threads=1, drop_remainder=True, generate_padding_batch=False, - use_dpo=None, use_sft=None, use_tunix_gradient_accumulation=False, num_microbatches=1, @@ -312,19 +309,12 @@ def preprocessing_pipeline( ) ) data_column_names = ("inputs", "targets") - elif use_dpo: - - def lists2array(x): - """Convert lists/tuples to array""" - return jax.tree.map(np.asarray, x, is_leaf=lambda y: isinstance(y, (list, tuple))) - - operations.append(grain.MapOperation(lists2array)) else: assert len(data_column_names) == 1 operations.append(_input_pipeline_utils.HFNormalizeFeatures(data_column_names[0])) data_column_names = ("inputs", "targets") - if packing and not use_dpo: + if packing: length_struct = {col: max_target_length for col in data_column_names} max_segments = max_segments_per_seq if max_segments is not None and max_segments <= 0: @@ -341,7 +331,7 @@ def lists2array(x): operations.append(_input_pipeline_utils.PadOrTrimToMaxLength(max_target_length, pad_id)) operations.append(grain.Batch(batch_size=batch_size, drop_remainder=drop_remainder)) - if shift and not use_dpo: + if shift: operations.append(_input_pipeline_utils.ShiftData(ignored_ids=[pad_id, tokenizer.bos_token_id], axis=1)) # Since HuggingFace IterableDataset does not support access through index @@ -418,7 +408,6 @@ def make_hf_train_iterator( add_eos=config.add_eos, packing=config.packing, generate_padding_batch=config.generate_padding_batch_train, - use_dpo=config.use_dpo, use_sft=config.use_sft, use_tunix_gradient_accumulation=config.use_tunix_gradient_accumulation, num_microbatches=config.gradient_accumulation_steps, @@ -476,7 +465,6 @@ def make_hf_eval_iterator( add_eos=config.add_eos, packing=config.packing, generate_padding_batch=config.generate_padding_batch_eval, - use_dpo=config.use_dpo, use_sft=config.use_sft, num_microbatches=config.gradient_accumulation_steps, sft_train_on_completion_only=config.sft_train_on_completion_only, diff --git a/src/MaxText/input_pipeline/_tfds_data_processing.py b/src/MaxText/input_pipeline/_tfds_data_processing.py index f103acc628..c64327625c 100644 --- a/src/MaxText/input_pipeline/_tfds_data_processing.py +++ b/src/MaxText/input_pipeline/_tfds_data_processing.py @@ -92,19 +92,15 @@ def preprocessing_pipeline( shift: bool = True, drop_remainder: bool = True, prefetch_size=tf.data.experimental.AUTOTUNE, - use_dpo: bool = False, hf_access_token: str = "", ): """pipeline for preprocessing TFDS dataset.""" - if not use_dpo: - assert len(data_column_names) == 1 - dataset = dataset.map( - lambda x: _input_pipeline_utils.normalize_features(x, data_column_names[0]), num_parallel_calls=AUTOTUNE - ) - else: - dataset = dataset.map(lambda x: {col: x[col] for col in data_column_names}, num_parallel_calls=AUTOTUNE) + assert len(data_column_names) == 1 + dataset = dataset.map( + lambda x: _input_pipeline_utils.normalize_features(x, data_column_names[0]), num_parallel_calls=AUTOTUNE + ) - data_column_names = data_column_names if use_dpo else ("inputs", "targets") + data_column_names = ("inputs", "targets") tokenizer_model = _input_pipeline_utils.get_tokenizer(tokenizer_path, tokenizer_type, add_bos, add_eos, hf_access_token) if tokenizer_model.pad_id is not None: @@ -123,7 +119,7 @@ def preprocessing_pipeline( if max_target_length > 0: # in pre-training we can take upto max_length+1 because there would be truncation by # 1 token for both inputs and targets - extra_tokens = 1 if not use_dpo else 0 + extra_tokens = 1 dataset = dataset.map( lambda x: _input_pipeline_utils.truncate_to_max_allowable_length(x, max_target_length + extra_tokens), num_parallel_calls=AUTOTUNE, @@ -136,13 +132,13 @@ def preprocessing_pipeline( dataset = dataset.repeat(num_epochs) # Shift inputs for teacher-forced training - if shift and not use_dpo: + if shift: dataset = dataset.map( _input_pipeline_utils.shift_data_by_truncation, num_parallel_calls=tf.data.AUTOTUNE, deterministic=True ) # Perform greedy sequence packing and batching - if pack_examples and not use_dpo: + if pack_examples: dataset = sequence_packing.pack_dataset(dataset, max_target_length, pad_id) dataset = dataset.batch(global_batch_size // jax.process_count(), drop_remainder=drop_remainder) else: @@ -202,7 +198,6 @@ def make_tfds_train_iterator( add_eos=config.add_eos, num_epochs=config.num_epoch, pack_examples=config.packing, - use_dpo=config.use_dpo, hf_access_token=config.hf_access_token, ) return multihost_dataloading.MultiHostDataLoadIterator( @@ -227,7 +222,6 @@ def make_tfds_train_iterator( add_eos=config.add_eos, num_epochs=config.num_epoch, pack_examples=config.packing, - use_dpo=config.use_dpo, hf_access_token=config.hf_access_token, ) global_shape = (config.global_batch_size_to_load, config.max_target_length) @@ -265,7 +259,6 @@ def make_tfds_eval_iterator( add_bos=config.add_bos, add_eos=config.add_eos, pack_examples=config.packing, - use_dpo=config.use_dpo, hf_access_token=config.hf_access_token, ) return multihost_dataloading.MultiHostDataLoadIterator( @@ -292,7 +285,6 @@ def make_tfds_eval_iterator( add_bos=config.add_bos, add_eos=config.add_eos, pack_examples=config.packing, - use_dpo=config.use_dpo, hf_access_token=config.hf_access_token, ) return multihost_dataloading.RemoteIterator(get_ds_fn, preprocessing_fn, config, global_mesh) diff --git a/src/MaxText/train.py b/src/MaxText/train.py index c66472e685..8352609dac 100644 --- a/src/MaxText/train.py +++ b/src/MaxText/train.py @@ -59,7 +59,6 @@ from maxtext.common.gcloud_stub import cloud_diagnostics as _cloud_diag, is_decoupled from maxtext.common.gcloud_stub import vertex_tensorboard_modules from maxtext.common.metric_logger import MetricLogger, record_activation_metrics -from maxtext.trainers.post_train.dpo.dpo_utils import _merge_dpo_state, _split_dpo_state, dpo_loss_fn from maxtext.utils import exceptions from maxtext.utils import gcs_utils from maxtext.utils import max_logging @@ -238,47 +237,27 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat rng2: A new rng key that can be used in future calls. """ - reference_params, reference_params_sharding, extra_dpo_args, _loss_fn = ( - [], - [], - [], - loss_fn, - ) - if config.use_dpo: - state, reference_params = _split_dpo_state(state) - state_mesh_shardings, reference_params_sharding = _split_dpo_state(state_mesh_shardings) - extra_dpo_args = [reference_params] - _loss_fn = dpo_loss_fn - params = state.params if config.gradient_accumulation_steps > 1: loss, aux, raw_grads = gradient_accumulation_loss_and_grad( - _loss_fn, + loss_fn, config, model, params, params_shardings, data, dropout_rng, - extra_dpo_args, ) else: - if config.optimizer_memory_host_offload: - if config.use_dpo: - reference_params = jax.device_put( - reference_params, - max_utils.with_memory_kind(reference_params_sharding, "device"), - ) - extra_dpo_args = [reference_params] if config.shard_optimizer_over_data: params = jax.tree.map( functools.partial(sharding.maybe_shard_with_name, shard_mode=config.shard_mode), params, params_shardings, ) - grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) - (loss, aux), raw_grads = grad_func(model, config, data, dropout_rng, params, *extra_dpo_args, is_train=True) + grad_func = jax.value_and_grad(loss_fn, argnums=4, has_aux=True) + (loss, aux), raw_grads = grad_func(model, config, data, dropout_rng, params, is_train=True) raw_grads = jax.tree_util.tree_map( lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x, @@ -338,8 +317,6 @@ def move(path, value): scalar_metrics["learning/grad_norm"] = max_utils.l2norm_pytree(grads) scalar_metrics["learning/raw_grad_norm"] = max_utils.l2norm_pytree(raw_grads) scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(new_state.params) - if config.use_dpo: - scalar_metrics["learning/dpo_reward_accuracy"] = aux["reward_accuracy"] metrics = { "scalar": scalar_metrics, "scalars": {}, @@ -348,23 +325,14 @@ def move(path, value): if config.record_internal_nn_metrics: record_activation_metrics(metrics, intermediate_outputs, config) - if config.use_dpo: - new_state = _merge_dpo_state(new_state, reference_params) - return new_state, metrics def eval_step(model, config, state, data, dropout_rng): """eval_step no backprop and new state compared with train_step.""" - reference_params, extra_dpo_args, _loss_fn = [], [], loss_fn - if config.use_dpo: - state, reference_params = _split_dpo_state(state) - extra_dpo_args = [reference_params] - _loss_fn = dpo_loss_fn - - eval_loss_fn = functools.partial(_loss_fn, model, config, data, dropout_rng, is_train=False) - loss, aux = eval_loss_fn(state.params, *extra_dpo_args) + eval_loss_fn = functools.partial(loss_fn, model, config, data, dropout_rng, is_train=False) + loss, aux = eval_loss_fn(state.params) mtp_acceptance_rate = 0.0 if config.mtp_eval_target_module > 0: @@ -384,8 +352,6 @@ def eval_step(model, config, state, data, dropout_rng): "evaluation/mtp_acceptance_rate_percent": mtp_acceptance_rate, }, } - if config.use_dpo: - metrics["scalar"]["evaluation/dpo_reward_accuracy"] = aux["reward_accuracy"] return metrics @@ -406,12 +372,6 @@ def train_loop(config, recorder, state=None): state, ) = train_utils.setup_train_loop(config, recorder) - if config.use_dpo: - if "reference_params" not in state.params: - reference_params = jax.tree.map(jnp.copy, state.params["params"]) - state = _merge_dpo_state(state, reference_params) - state_mesh_shardings = _merge_dpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) - params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings) p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( @@ -461,8 +421,7 @@ def train_loop(config, recorder, state=None): step_time_delta = datetime.datetime.now() - last_step_completion last_step_completion = datetime.datetime.now() - state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] - checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step) + checkpointing.maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step) if config.dump_hlo and step == (config.dump_step if config.dump_step >= 0 else start_step): jax.block_until_ready(state) # Ensure compilation has finished. @@ -503,8 +462,7 @@ def train_loop(config, recorder, state=None): metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta) if config.save_checkpoint_on_completion: - state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] - checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator) + checkpointing.maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator) if checkpoint_manager is not None: # in case the last checkpoint_period checkpoint is still in progress checkpoint_manager.wait_until_finished() diff --git a/src/maxtext/common/metric_logger.py b/src/maxtext/common/metric_logger.py index da9777ef72..acbb6a86de 100644 --- a/src/maxtext/common/metric_logger.py +++ b/src/maxtext/common/metric_logger.py @@ -342,10 +342,6 @@ def record_eval_metrics(self, step, metrics=None, eval_step_count=None): self.cumulative_eval_metrics["scalar"]["eval/mtp_acceptance_rate_percent"] += float( metrics["scalar"].get("evaluation/mtp_acceptance_rate_percent", 0.0) ) - if self.config.use_dpo: - self.cumulative_eval_metrics["scalar"]["eval/dpo_reward_accuracy"] += float( - metrics["scalar"].get("evaluation/dpo_reward_accuracy", 0.0) - ) if eval_step_count: eval_loss = self.cumulative_eval_metrics["scalar"]["eval/total_loss"] / ( @@ -361,10 +357,6 @@ def record_eval_metrics(self, step, metrics=None, eval_step_count=None): self.cumulative_eval_metrics["scalar"]["eval/avg_mtp_acceptance_rate_percent"] = ( self.cumulative_eval_metrics["scalar"]["eval/mtp_acceptance_rate_percent"] / eval_step_count ) - if self.config.use_dpo: - self.cumulative_eval_metrics["scalar"]["eval/dpo_reward_accuracy"] = ( - self.cumulative_eval_metrics["scalar"]["eval/dpo_reward_accuracy"] / eval_step_count - ) self.write_metrics(self.cumulative_eval_metrics, step, is_training=False) diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index b152f6d081..48bac40d45 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -568,9 +568,9 @@ per_device_batch_size: 12.0 expansion_factor_real_data: -1.0 eval_per_device_batch_size: 0.0 max_corpus_chars: 10_000_000 -train_data_columns: ['text'] # for DPO dataset containing "chosen" and "rejected" +train_data_columns: ['text'] train_image_column: 'image' -eval_data_columns: ['text'] # for DPO dataset containing "chosen" and "rejected" +eval_data_columns: ['text'] eval_image_column: 'image' packing: True num_epoch: 1 @@ -595,11 +595,6 @@ per_device_batch_size_increment: 2.0 # There is no strict rule for this value, it only needs to be positive. global_rampup_samples: 500 -# direct preference optimization (DPO) -use_dpo: False -dpo_label_smoothing: 0.0 -dpo_beta: 0.1 - # Supervised Fine-Tuning (SFT) use_sft: False # sft_train_on_completion_only=False trains on both prompt and completion tokens; trains only on completion tokens otherwise diff --git a/src/maxtext/configs/post_train/dpo.yml b/src/maxtext/configs/post_train/dpo.yml deleted file mode 100644 index dbcdadb1ba..0000000000 --- a/src/maxtext/configs/post_train/dpo.yml +++ /dev/null @@ -1,31 +0,0 @@ -base_config: "base.yml" - -use_dpo: true -train_data_columns: ['chosen', 'rejected'] -eval_data_columns: ['chosen', 'rejected'] -base_output_directory: 'gs://maxtext-external/logs' - -per_device_batch_size: 2.0 -steps: 10 -max_target_length: 512 -eval_interval: 5 # test eval once, in the middle of 10 training steps -eval_steps: 2 - -# TFDS Pipeline ---------------------- -dataset_type: tfds -dataset_path: 'gs://maxtext-dataset/dpo/anthropic_rlhf' -dataset_name: 'tfds:1.0.0' -eval_dataset_name: 'tfds:1.0.0' -eval_split: 'test' - -# HF Pipeline ------------------------- -hf_eval_split: 'test' - -gradient_clipping_threshold: 10.0 -learning_rate: 5.0e-7 -dpo_label_smoothing: 0.0 -dpo_beta: 0.1 - -enable_goodput_recording: false -monitor_goodput: false -enable_checkpointing: true diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 6457b49041..a85b354931 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -248,7 +248,7 @@ class ProfilerType(str, Enum): "llama4-17b-16e", "llama4-17b-128e", "olmo3-7b", - 'olmo3-7b-pt', + "olmo3-7b-pt", "olmo3-32b", ] @@ -1021,11 +1021,8 @@ class GrainDataset(BaseModel): class FineTuning(BaseModel): - """Configuration for fine-tuning methods like DPO, SFT, and GRPO.""" + """Configuration for fine-tuning methods like SFT and GRPO.""" - use_dpo: bool = Field(False, description="If True, enables Direct Preference Optimization training.") - dpo_label_smoothing: float = Field(0.0, ge=0.0, le=1.0, description="Label smoothing for DPO.") - dpo_beta: float = Field(0.1, description="Beta parameter for DPO.") use_sft: bool = Field(False, description="If True, enables Supervised Fine-Tuning.") sft_train_on_completion_only: bool = Field( False, description="If True, trains only on the completion part of the text." diff --git a/src/maxtext/experimental/rl/grpo_trainer.py b/src/maxtext/experimental/rl/grpo_trainer.py index 5bcf29215c..bd03aeda2b 100644 --- a/src/maxtext/experimental/rl/grpo_trainer.py +++ b/src/maxtext/experimental/rl/grpo_trainer.py @@ -489,8 +489,6 @@ def eval_step(model, config, state, data, dropout_rng): "evaluation/moe_lb_loss": moe_lb_loss, }, } - if config.use_dpo: - metrics["scalar"]["evaluation/grpo_reward_accuracy"] = aux["reward_accuracy"] return metrics diff --git a/src/maxtext/trainers/post_train/dpo/dpo_utils.py b/src/maxtext/trainers/post_train/dpo/dpo_utils.py deleted file mode 100644 index 18ee7ec22a..0000000000 --- a/src/maxtext/trainers/post_train/dpo/dpo_utils.py +++ /dev/null @@ -1,150 +0,0 @@ -""" -Copyright 2025 Google LLC - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -"""DPO (Direct Preference Optimization) utilities.""" - -import jax -import jax.numpy as jnp - -from maxtext.utils import maxtext_utils - - -def _split_dpo_state(state): - """Split DPO state to separate reference parameters.""" - reference_params = state.params["reference_params"] - new_state = state.replace(params={k: v for k, v in state.params.items() if k != "reference_params"}) - return new_state, reference_params - - -def dpo_loss_fn(model, config, data, dropout_rng, params, reference_params, is_train=True): - """loss_fn for both train and eval. - - Args: - model: A model module - config: Config of parameters - data: Batch of data to apply to the model - dropout_rng: A key to use to generate rng for dropout - params: Model params - reference_params: Reference model params for DPO - is_train: True for train_step and False for eval_step - - Returns: - loss: average loss - aux: a dictionary including intermediate_outputs, total_loss, and total_weights - """ - # inputs, targets, segments, positions = apply_args - rng1, aqt_rng = jax.random.split(dropout_rng) - - # decimate proportion of data when per_device_batch_size<1 - if is_train: - for k, v in data.items(): - data[k] = v[: config.micro_batch_size_to_train_on, :] - - # for DPO we don't support packed sequence (they shouldn't be present in the first place) - data["chosen_segmentation"] = (data["chosen_segmentation"] == 1).astype(jnp.int32) - data["rejected_segmentation"] = (data["rejected_segmentation"] == 1).astype(jnp.int32) - data["chosen_position"] = data["chosen_position"] * (data["chosen_segmentation"] == 1) - data["rejected_position"] = data["rejected_position"] * (data["rejected_segmentation"] == 1) - - # concatenated model and reference model forward pass - inputs = jnp.concatenate([data["chosen"], data["rejected"]], 0) - inputs_position = jnp.concatenate([data["chosen_position"], data["rejected_position"]], 0) - inputs_segmentation = jnp.concatenate([data["chosen_segmentation"], data["rejected_segmentation"]], 0) - - logits, intermediate_outputs = model.apply( - params, - inputs, - inputs_position, - decoder_segment_ids=inputs_segmentation, - enable_dropout=config.enable_dropout if is_train else False, - rngs={"dropout": rng1, "params": aqt_rng}, - mutable="intermediates", - ) - ref_logits = model.apply( - {"params": reference_params}, - inputs, - inputs_position, - decoder_segment_ids=inputs_segmentation, - enable_dropout=False, - rngs={"dropout": rng1, "params": aqt_rng}, - ) - ref_logits = jax.lax.stop_gradient(ref_logits) - - # extract token ids, segmentation and logits for chosen and rejected sequences - chosen_ids = data["chosen"][..., 1:] - rejected_ids = data["rejected"][..., 1:] - chosen_segmentation = data["chosen_segmentation"][..., 1:] - rejected_segmentation = data["rejected_segmentation"][..., 1:] - n_logits = logits.shape[-3] // 2 # [B, S, E] - [batch, sequence, embedding/vocab] - chosen_logits, rejected_logits = logits[:n_logits, :, :], logits[n_logits:, :, :] # [B, S, E], [B, S, E] - # ^ [B, S, E], [B, S, E] - chosen_ref_logits, rejected_ref_logits = ref_logits[:n_logits, :, :], ref_logits[n_logits:, :, :] - - # common subsequence and padding mask - common_prefix_mask = jnp.cumsum(chosen_ids != rejected_ids, axis=-1) == 0 # [B, S] - valid_seq_mask = (chosen_segmentation != 0) & (rejected_segmentation != 0) & ~common_prefix_mask # [B, S] - - # compute logratios from the sequence-reduced observed token log-probability - chosen_logps_seq = jnp.take_along_axis( # [B, S] - jax.nn.log_softmax(chosen_logits[..., :-1, :], axis=-1), chosen_ids[..., None], axis=-1 - )[..., 0] - chosen_logps = jnp.sum(chosen_logps_seq * valid_seq_mask, axis=-1) # [B] - chosen_ref_logps_seq = jnp.take_along_axis( # [B, S] - jax.nn.log_softmax(chosen_ref_logits[..., :-1, :], axis=-1), chosen_ids[..., None], axis=-1 - )[..., 0] - chosen_ref_logps = jnp.sum(chosen_ref_logps_seq * valid_seq_mask, axis=-1) # [B] - chosen_logratios = chosen_logps - chosen_ref_logps # [B] - - rejected_logps_seq = jnp.take_along_axis( # [B, S] - jax.nn.log_softmax(rejected_logits[..., :-1, :], axis=-1), rejected_ids[..., None], axis=-1 - )[..., 0] - rejected_logps = jnp.sum(rejected_logps_seq * valid_seq_mask, axis=-1) # [B] - rejected_ref_logps_seq = jnp.take_along_axis( # [B, S] - jax.nn.log_softmax(rejected_ref_logits[..., :-1, :], axis=-1), rejected_ids[..., None], axis=-1 - )[..., 0] - rejected_ref_logps = jnp.sum(rejected_ref_logps_seq * valid_seq_mask, axis=-1) # [B] - rejected_logratios = rejected_logps - rejected_ref_logps # [B] - - # DPO loss from chosen and rejected logratios - LABEL_SMOOTHING, BETA = config.dpo_label_smoothing, config.dpo_beta - logratios_delta = BETA * (chosen_logratios - rejected_logratios) # [B] - losses = ( # [B] - -jax.nn.log_sigmoid(BETA * logratios_delta) * (1 - LABEL_SMOOTHING) - - jax.nn.log_sigmoid(-BETA * logratios_delta) * LABEL_SMOOTHING - ) - total_loss, total_weights = jnp.mean(losses), losses.shape[0] - loss = total_loss - - moe_lb_loss = 0.0 - if config.num_experts > 1: - nested_key = ("intermediates", "decoder", "layers", "moe_lb_loss") - total_moe_lb_loss = maxtext_utils.get_nested_value(intermediate_outputs, nested_key, 0.0) - moe_lb_loss = jnp.mean(jnp.array(total_moe_lb_loss)) - loss += moe_lb_loss - reward_accuracy = jnp.mean(chosen_logratios > rejected_logratios) - aux = { - "intermediate_outputs": intermediate_outputs, - "total_loss": total_loss, - "total_weights": total_weights, - "moe_lb_loss": moe_lb_loss, - "reward_accuracy": reward_accuracy, - } - return loss, aux - - -def _merge_dpo_state(state, reference_params): - """Merge reference parameters back into DPO state.""" - return state.replace(params=dict(state.params, reference_params=reference_params)) diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index c4bb32aae0..17bbddc317 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -805,15 +805,7 @@ def calculate_tflops_training_per_device(config, log=True): learnable_weight_tflops = learnable_weight_tflops * config.gradient_accumulation_steps attention_tflops = attention_tflops * config.gradient_accumulation_steps - # DPO includes one additional forward pass per gradient accumulation step - if config.use_dpo: - reference_model_tflops = learnable_weight_tflops / 3 # additional forward pass - reference_model_attention_tflops = attention_tflops / 3 - attention_tflops = attention_tflops + reference_model_attention_tflops - else: - reference_model_tflops = 0 - - total_tflops = learnable_weight_tflops + attention_tflops + reference_model_tflops + total_tflops = learnable_weight_tflops + attention_tflops if config.use_multimodal: # Add vision layers TFLOPs for multimodal models diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py index ecc66aa9c2..af34ea5b96 100644 --- a/src/maxtext/utils/train_utils.py +++ b/src/maxtext/utils/train_utils.py @@ -15,7 +15,6 @@ # pylint: disable=bare-except, consider-using-generator """ Utils that are only interesting for training in MaxText. """ -import os import jax import functools from flax.linen import partitioning as nn_partitioning @@ -25,7 +24,6 @@ from maxtext.common import checkpointing from maxtext.common.data_loader import create_dataloader from maxtext.common.goodput import GoodputEvent, maybe_record_goodput -from maxtext.trainers.post_train.dpo.dpo_utils import _merge_dpo_state from maxtext.utils import max_logging from maxtext.utils import max_utils from maxtext.utils import maxtext_utils @@ -250,38 +248,6 @@ def setup_train_loop(config, recorder, devices=None): state.params, state_mesh_shardings.params, model.mesh, logical_annotations.params ) - if config.use_dpo: - abstract_state, _, _ = maxtext_utils.get_abstract_state(model, tx, config, init_rng, mesh, is_training=True) - max_logging.log( - "Restoring reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'" - ) - try: - step0_restored, _ = checkpointing.load_state_if_possible( - checkpoint_manager, - data_iterator, - load_parameters_from_path="", - load_full_state_from_path="", - checkpoint_storage_concurrent_gb=config.checkpoint_storage_concurrent_gb, - abstract_unboxed_pre_state=abstract_state, - enable_single_replica_ckpt_restoring=False, - dataset_type=config.dataset_type, - step=0, - use_ocdbt=config.checkpoint_storage_use_ocdbt, - use_zarr3=config.checkpoint_storage_use_zarr3, - enable_orbax_v1=config.enable_orbax_v1, - checkpoint_conversion_fn=config.checkpoint_conversion_fn, - source_checkpoint_layout=config.source_checkpoint_layout, - ) - except FileNotFoundError: - step0_restored = None - if step0_restored is not None: - reference_params = step0_restored["items"].params["params"] - state = _merge_dpo_state(state, reference_params) - else: - max_logging.log( - "Could not restore reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'" - ) - return ( init_rng, checkpoint_manager, diff --git a/tests/end_to_end/tpu/test_dpo.sh b/tests/end_to_end/tpu/test_dpo.sh deleted file mode 100644 index 9e965778c3..0000000000 --- a/tests/end_to_end/tpu/test_dpo.sh +++ /dev/null @@ -1,36 +0,0 @@ -#!/bin/bash - -set -xe - -RUN_NAME=dpo_$(date +%Y-%m-%d-%H-%M-%S) - -# get latest converted Gemma2 2B checkpoint from internal GCS bucket -export GEMMA_2B_CKPT_PATH=$(gcloud storage ls gs://maxtext-gemma/gemma2/2b | sort -r | head -1) -LOGS="gs://maxtext-external/logs" - -# tfds pipeline -python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/dpo.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma \ - run_name="$RUN_NAME-tfds" model_name=gemma2-2b base_output_directory=${LOGS} \ - load_parameters_path=${GEMMA_2B_CKPT_PATH}/0/items \ - per_device_batch_size=0.5 allow_split_physical_axes=True \ - ici_data_parallelism=2 ici_tensor_parallelism=2 ici_fsdp_parallelism=1 - -# grain pipeline -mkdir -p /tmp/anthropic_rlhf || true -gcloud storage cp -r gs://maxtext-dataset/dpo/anthropic_rlhf/array_record /tmp/anthropic_rlhf -python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/dpo.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma \ - run_name="$RUN_NAME-grain" model_name=gemma2-2b base_output_directory=${LOGS} \ - load_parameters_path=${GEMMA_2B_CKPT_PATH}/0/items \ - dataset_type=grain grain_worker_count=16 \ - grain_train_files='/tmp/anthropic_rlhf/array_record/anthropic_rlhf_tfds-train.array_record*' \ - grain_eval_files='/tmp/anthropic_rlhf/array_record/anthropic_rlhf_tfds-test.array_record*' \ - per_device_batch_size=0.5 allow_split_physical_axes=True \ - ici_data_parallelism=2 ici_tensor_parallelism=2 ici_fsdp_parallelism=1 - -# hf pipeline -python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/dpo.yml tokenizer_path='google/gemma-2-2b-it' \ - run_name="$RUN_NAME-grain" model_name=gemma2-2b base_output_directory=${LOGS} \ - load_parameters_path=${GEMMA_2B_CKPT_PATH}/0/items \ - dataset_type=hf hf_access_token=$HF_TOKEN hf_path='Anthropic/hh-rlhf' \ - per_device_batch_size=0.5 allow_split_physical_axes=True ici_tensor_parallelism=2 \ - ici_data_parallelism=2 ici_tensor_parallelism=2 ici_fsdp_parallelism=1 diff --git a/tests/unit/configs_test.py b/tests/unit/configs_test.py index 09ceb03f44..46266fd9d0 100644 --- a/tests/unit/configs_test.py +++ b/tests/unit/configs_test.py @@ -113,7 +113,6 @@ def run_config_validation(config_file_path: str): BASE_CONFIGS = [ os.path.join(CONFIGS_DIR, "base.yml"), - os.path.join(CONFIGS_DIR, "post_train", "dpo.yml"), os.path.join(CONFIGS_DIR, "gpu/gpu_smoke_test.yml"), os.path.join(CONFIGS_DIR, "post_train", "rl.yml"), os.path.join(CONFIGS_DIR, "post_train", "rl_mt_jt.yml"), diff --git a/tests/unit/sft_data_processing_test.py b/tests/unit/sft_data_processing_test.py index 00c43cbc35..8e06ad6c5a 100644 --- a/tests/unit/sft_data_processing_test.py +++ b/tests/unit/sft_data_processing_test.py @@ -364,7 +364,6 @@ def get_data_iterator(self, train_ds, data_columns): add_eos=self.config.add_eos, packing=self.config.packing, generate_padding_batch=False, - use_dpo=self.config.use_dpo, use_sft=self.config.use_sft, sft_train_on_completion_only=self.config.sft_train_on_completion_only, grain_worker_count=0,