From 180a0cb300a47ac94dafb37633f415a6cbb4813c Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Thu, 17 Apr 2025 18:28:04 +0000 Subject: [PATCH 01/17] add clipping - separate low and high values as suggested in DAPO. --- trl/trainer/grpo_config.py | 11 ++++++++++- trl/trainer/qwen_grpo_trainer.py | 14 ++++++++++++-- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index ac2ee2f9e40..a92aa1ceea6 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -214,6 +214,15 @@ class GRPOConfig(TrainingArguments): default=0.04, metadata={"help": "KL coefficient."}, ) + + epsilon_low: float = field( + default=0.2, + metadata={"help": "Value to clip at 1-epsilon."}, + ) + epsilon_high: float = field( + default=0.28, + metadata={"help": "Value to clip at 1+epsilon. Generally set this higher than epsilon_low based on DAPO paper results."}, + ) reward_weights: Optional[list[float]] = field( default=None, metadata={ @@ -256,4 +265,4 @@ class GRPOConfig(TrainingArguments): limit_video_per_prompt: int = field( default=0, metadata={"help": "Limit the number of videos per prompt for vllm generation."}, - ) \ No newline at end of file + ) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index d0e2efca84f..33925e72be6 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -352,6 +352,10 @@ def data_collator(features): # No data collation is needed in GRPO self._metrics = defaultdict(list) self.log_completions = args.log_completions + # intialize epsilon + self.epsilon_low = args.epsilon_low + self.epsilon_high = args.epsilon_high + super().__init__( model=model, args=args, @@ -706,7 +710,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s new_pixel_values = new_pixel_values.to(device) new_image_grid_thw = new_image_grid_thw.to(device) pixel_values = torch.cat([pixel_values, new_pixel_values], dim=0) - image_grid_thw = torch.cat([image_grid_thw, new_image_grid_thw], dim=0) + image_grid_thw = torch.cat([image_grid_thw, new_image_grid_thw], dim=0) else: raise ValueError("Attempted to generate with HF. Only supporting vllm now.") @@ -912,7 +916,13 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N # x - x.detach() allows for preserving gradients from x advantages = inputs["advantages"] - per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) + coef_1 = torch.exp(per_token_logps - per_token_logps.detach()) + coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) + + per_token_loss1 = coef_1 * advantages.unsqueeze(1) + per_token_loss2 = coef_2 * advantages.unsqueeze(1) + per_token_loss = torch.min(per_token_loss1, per_token_loss2) + per_token_loss = -(per_token_loss - self.beta * per_token_kl) loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() From c3e61477cb39465a25f119a8008f1360b9e4184e Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Thu, 17 Apr 2025 18:23:57 -0700 Subject: [PATCH 02/17] buffer class written. working on init --- trl/trainer/qwen_grpo_trainer.py | 113 +++++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index 33925e72be6..10ec0d17b03 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -19,6 +19,7 @@ from typing import Any, Callable, Optional, Sized, Union from unittest.mock import patch +import numpy as np import torch import torch.utils.data import transformers @@ -129,6 +130,112 @@ def __iter__(self): def __len__(self): return self.num_samples * self.repeat_count +class SSRBuffer: + ''' + Selective Sample Replay manager. Maintains a buffer of high entropy samples for training. + ''' + + def __init__(self, alpha: float = 2.0, total_buffer_size: int = 1000, persist_steps: int = 1000): + ''' + Args: + alpha: float, handles prioritization intensity>=0. + alpha = 0 means no prioritization, + alpha = 1 means prioritization linearly proportional to advantage, + alpha > 1 means more prioritization for high entropy samples. + total_buffer_size: int, maximum size of the buffer. After the buffer is full, the oldest samples will be discarded. + persist_steps: int, number of steps an example lives in the buffer. After this many steps, the example will be discarded. + ''' + + if alpha <= 0: + raise ValueError("alpha must be greater than 0") + self.alpha = alpha + + if total_buffer_size <= 0: + raise ValueError("total_buffer_size must be greater than 0") + self.total_buffer_size = total_buffer_size + + if persist_steps <= 0: + raise ValueError("persist_steps must be greater than 0") + self.persist_steps = persist_steps + + self.buffer = [] + + # element of buffer format: + # { + # "example": dict, - the example to be replayed + # "advantage": float, - the advantage observed last training step + # "ttl": int, - time to live, number of steps before the example will be discarded from the buffer + # } + + def add_example(self, example: dict, advantage: float) -> None: + ''' + Add an example to the buffer. + ''' + # NOTE: We don't check if the buffer is full here. We'll do it at the end of each training step. + buffer_element = { + "example": example, + "advantage": advantage, + "ttl": self.persist_steps + } + self.buffer.append(buffer_element) + + @property + def buffer_size(self) -> int: + ''' + Number of examples in the buffer. + ''' + return len(self.buffer) + + def draw_example(self) -> dict: + ''' + Returns an example from the buffer. The probabilty of drawing an example j is: + abs(advantage_j)**(self.alpha) / sum(abs(advantage_i)**(self.alpha) for i in range(len(self.buffer))) + + Raises a ValueError if the buffer is empty, otherwise, pops an example from the buffer and returns it. + ''' + + if self.buffer_size == 0: + raise ValueError("Buffer is empty") + + values = [] + for buffer_element in self.buffer: + values.append(abs(buffer_element["advantage"])**self.alpha) + + total = sum(values) + probabilities = [value / total for value in values] + + # check that the probabilities sum to 1, with some tolerance + if not np.isclose(sum(probabilities), 1.0, atol=1e-6): + raise ValueError(f"Probabilities do not sum to 1, but instead sum to {sum(probabilities)}") + + # choose the index of the example to draw + index = np.random.choice(range(len(self.buffer)), p=probabilities) + + # pop the example from the buffer + buffer_element = self.buffer.pop(index) + + return buffer_element['example'] + + def step(self) -> None: + ''' + Handles reducing ttl's on objects in the buffer and removes objects that have expired. + + It is to be called once at the end of training step. + ''' + # decrement the ttl of each buffer element + for buffer_element in self.buffer: + buffer_element['ttl'] -= 1 + + # remove buffer elements that have expired + self.buffer = [b for b in self.buffer if b['ttl'] > 0] + + # if the buffer is too big, discard the oldest examples + if len(self.buffer) > self.buffer_size: + # Sort by absolute advantage (priority), ascending + self.buffer.sort(key=lambda x: abs(x['advantage'])) + # Keep only the top 'buffer_size' elements (highest priority) + self.buffer = self.buffer[-self.buffer_size:] + class QwenGRPOTrainer(Trainer): """ @@ -355,6 +462,12 @@ def data_collator(features): # No data collation is needed in GRPO # intialize epsilon self.epsilon_low = args.epsilon_low self.epsilon_high = args.epsilon_high + + # TODO: make these config args + self.use_ssr_buffer = True + self. + + self.ssr_buffer = super().__init__( model=model, From 2d0837a527b398d1543e719d092a42746bd7038a Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Thu, 17 Apr 2025 20:22:14 -0700 Subject: [PATCH 03/17] ssr POC ready to test --- trl/trainer/qwen_grpo_trainer.py | 102 +++++++++++++++++++++++++------ 1 file changed, 84 insertions(+), 18 deletions(-) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index 10ec0d17b03..021037d01bf 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import random import textwrap import warnings from collections import defaultdict @@ -462,12 +463,23 @@ def data_collator(features): # No data collation is needed in GRPO # intialize epsilon self.epsilon_low = args.epsilon_low self.epsilon_high = args.epsilon_high - - # TODO: make these config args + + + # TODO: make these configurable args self.use_ssr_buffer = True - self. - - self.ssr_buffer = + self.ssr_alpha = 2.0 + self.ssr_total_buffer_size = 1000 + self.ssr_persist_steps = 1000 + # the probability of using the SSR buffer on each step + self.ssr_use_prob = 0.5 + + if not 0 <= self.ssr_use_prob <= 1: + raise ValueError("ssr_use_prob must be between 0 and 1") + + if self.use_ssr_buffer: + self.ssr_buffer = SSRBuffer(alpha=self.ssr_alpha, total_buffer_size=self.ssr_total_buffer_size, persist_steps=self.ssr_persist_steps) + else: + self.ssr_buffer = None super().__init__( model=model, @@ -687,7 +699,12 @@ def _move_model_to_vllm(self): vlm_model = self.vlm.llm_engine.model_executor.driver_worker.model_runner.model vlm_model.load_weights(state_dict.items()) - def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]: + def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]], should_use_buffer: bool = False) -> dict[str, Union[torch.Tensor, Any]]: + ''' + should_use_buffer - signals if the input comes from the SSR buffer + ''' + + device = self.accelerator.device if not self.env: @@ -904,19 +921,19 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s # # DEBUG: Verify prompt consistency across completions in each group # TODO: remove this probably? - # if self.accelerator.is_main_process: - # all_prompts = gather_object(prompts_text) + if self.accelerator.is_main_process: + all_prompts = gather_object(prompts_text) - # if not len(all_prompts) == self.num_generations: - # raise ValueError( - # f"We should have one prompt per generation, but we have {len(all_prompts)} prompts and {self.num_generations} generations" - # ) - # if not len(set(all_prompts)) == 1: - # raise ValueError(f"All prompts should be the same. {all_prompts=}") - # print("PASSED PROMPT CONSISTENCY CHECK") + if not len(all_prompts) == self.num_generations: + raise ValueError( + f"We should have one prompt per generation, but we have {len(all_prompts)} prompts and {self.num_generations} generations" + ) + if not len(set(all_prompts)) == 1: + raise ValueError(f"All prompts should be the same. {all_prompts=}") + print("PASSED PROMPT CONSISTENCY CHECK") - # # Add synchronization point to prevent processes from getting out of sync - # self.accelerator.wait_for_everyone() + # Add synchronization point to prevent processes from getting out of sync + self.accelerator.wait_for_everyone() # Apply weights to each reward function's output and sum rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1) @@ -930,6 +947,20 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0) advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4) + # if we are using the SSR buffer, we need to populate it with the current batch of examples + # we only add examples to the buffer if they are not coming from the buffer + if self.use_ssr_buffer and not should_use_buffer and self.accelerator.process_index == 0: + # if the average absolute advantage is positive, we add that example to the buffer with the average advantage + average_abs_advantage = torch.abs(advantages).mean().item() + if average_abs_advantage > 0: + print(f"Adding {inputs[0]} to the SSR buffer with advantage {average_abs_advantage}") + + # add the example to the buffer with the average advantage + self.ssr_buffer.add_example(inputs[0], average_abs_advantage) + + self.accelerator.wait_for_everyone() + + # Slice to keep only the local part of the data process_slice = slice( self.accelerator.process_index * len(conversations), @@ -987,6 +1018,9 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s if wandb.run is not None and self.accelerator.is_main_process: wandb.log({"completions": wandb.Table(dataframe=df)}) + self._metrics["buffer_size"].append(self.ssr_buffer.buffer_size if self.use_ssr_buffer else 0) + self._metrics["buffer_usage"].append(float(should_use_buffer)) + return { "prompt_ids": prompt_ids, "prompt_mask": prompt_mask, @@ -1055,7 +1089,39 @@ def prediction_step( prediction_loss_only, ignore_keys: Optional[list[str]] = None, ): - inputs = self._prepare_inputs(inputs) + if self.use_ssr_buffer: + # each process needs to know if we are using the buffer. Process 0 decides and then broadcasts the decision to all processes + if self.accelerator.process_index == 0: + # step the buffer, needs to happen at each training step + self.ssr_buffer.step() + + should_use_buffer = random.random() < self.ssr_use_prob and self.ssr_buffer.buffer_size > 0 + should_use_buffer_list = [should_use_buffer for _ in range(self.accelerator.num_processes)] + else: + should_use_buffer_list = [None for _ in range(self.accelerator.num_processes)] + + broadcast_object_list(should_use_buffer_list, from_process=0) + should_use_buffer = should_use_buffer_list[0] + + if should_use_buffer: + # process 0 will draw from the buffer, the other processes will hang out + if self.accelerator.process_index == 0: + print("Drawing from buffer") + example_from_buffer = self.ssr_buffer.draw_example() + local_inputs = [deepcopy(example_from_buffer) for _ in range(len(inputs))] + print(f"local_inputs after drawing from buffer length: {len(local_inputs)=}") + else: + local_inputs = [None for _ in range(len(inputs))] + + # broadcast the inputs (from the buffer) to all processes + broadcast_object_list(local_inputs, from_process=0) + inputs = local_inputs + else: + # if we are not using the SSR buffer, we just use the inputs passed into the function and signal that this example is not from the buffer + should_use_buffer = False + + + inputs = self._prepare_inputs(inputs, should_use_buffer=should_use_buffer) with torch.no_grad(): with self.compute_loss_context_manager(): loss = self.compute_loss(model, inputs) From a77f062c4239e629c286fb1a7dbb3dd6856f054d Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Thu, 17 Apr 2025 21:34:50 -0700 Subject: [PATCH 04/17] draft of ssr good enough to run --- trl/trainer/qwen_grpo_trainer.py | 121 ++++++++++++++++++------------- 1 file changed, 71 insertions(+), 50 deletions(-) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index 021037d01bf..b296a7455a4 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -469,7 +469,9 @@ def data_collator(features): # No data collation is needed in GRPO self.use_ssr_buffer = True self.ssr_alpha = 2.0 self.ssr_total_buffer_size = 1000 - self.ssr_persist_steps = 1000 + self.ssr_persist_steps = 10000 + # if the buffer is smaller than this, we don't use it. Instead, draw from the dataset. This helps ensure we only select the best quality examples from the buffer on average. + self.min_ssr_buffer_size = 100 # the probability of using the SSR buffer on each step self.ssr_use_prob = 0.5 @@ -699,10 +701,7 @@ def _move_model_to_vllm(self): vlm_model = self.vlm.llm_engine.model_executor.driver_worker.model_runner.model vlm_model.load_weights(state_dict.items()) - def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]], should_use_buffer: bool = False) -> dict[str, Union[torch.Tensor, Any]]: - ''' - should_use_buffer - signals if the input comes from the SSR buffer - ''' + def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]: device = self.accelerator.device @@ -710,6 +709,42 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]], should_us if not self.env: raise ValueError("No environment provided. Only supporting envs now. ") + if self.use_ssr_buffer: + # each process needs to know if we are using the buffer. Process 0 decides and then broadcasts the decision to all processes + if self.accelerator.process_index == 0: + print(f"buffer size: {self.ssr_buffer.buffer_size}") + # step the buffer, needs to happen at each training step + self.ssr_buffer.step() + + should_use_buffer = random.random() < self.ssr_use_prob and self.ssr_buffer.buffer_size > self.min_ssr_buffer_size + should_use_buffer_list = [should_use_buffer for _ in range(self.accelerator.num_processes)] + else: + should_use_buffer_list = [None for _ in range(self.accelerator.num_processes)] + + broadcast_object_list(should_use_buffer_list, from_process=0) + should_use_buffer = should_use_buffer_list[0] + + if should_use_buffer: + # process 0 will draw from the buffer, the other processes will hang out + if self.accelerator.process_index == 0: + print("Drawing from buffer") + example_from_buffer = self.ssr_buffer.draw_example() + local_inputs = [deepcopy(example_from_buffer) for _ in range(len(inputs))] + print(f"local_inputs after drawing from buffer length: {len(local_inputs)=}") + else: + local_inputs = [None for _ in range(len(inputs))] + + # broadcast the inputs (from the buffer) to all processes + broadcast_object_list(local_inputs, from_process=0) + inputs = local_inputs + else: + # if we are not using the SSR buffer, we just use the inputs passed into the function and signal that this example is not from the buffer + should_use_buffer = False + + print(f"should_use_buffer: {should_use_buffer}") + + + # TODO: This is a hack that we should probably fix. # without this, each gpu receives different inputs, screwing up the advantage computation. # Simple synchronization of inputs across processes @@ -844,6 +879,8 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]], should_us else: raise ValueError("Attempted to generate with HF. Only supporting vllm now.") + print("Finished with generation") + # Concatenate prompt_mask with completion_mask for logit computation attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C) @@ -870,6 +907,8 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]], should_us logits_to_keep, ) + print("Finished with ref logits") + # Decode the generated completions completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) if is_conversational(inputs[0]): @@ -915,38 +954,50 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]], should_us output_reward_func = reward_func(prompts=conversations, completions=completions, **reward_kwargs) rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) + print("Finished with rewards per function") # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the # completions may be distributed across processes rewards_per_func = gather(rewards_per_func) + print(f"Finished with gathering rewards per function, {rewards_per_func=}, {self.accelerator.process_index=}") + + self.accelerator.wait_for_everyone() + print(f"Finished with waiting for everyone, {self.accelerator.process_index=}") # # DEBUG: Verify prompt consistency across completions in each group # TODO: remove this probably? - if self.accelerator.is_main_process: - all_prompts = gather_object(prompts_text) + # if self.accelerator.is_main_process: + # all_prompts = gather_object(prompts_text) - if not len(all_prompts) == self.num_generations: - raise ValueError( - f"We should have one prompt per generation, but we have {len(all_prompts)} prompts and {self.num_generations} generations" - ) - if not len(set(all_prompts)) == 1: - raise ValueError(f"All prompts should be the same. {all_prompts=}") - print("PASSED PROMPT CONSISTENCY CHECK") + # if not len(all_prompts) == self.num_generations: + # raise ValueError( + # f"We should have one prompt per generation, but we have {len(all_prompts)} prompts and {self.num_generations} generations" + # ) + # if not len(set(all_prompts)) == 1: + # raise ValueError(f"All prompts should be the same. {all_prompts=}") + # print("PASSED PROMPT CONSISTENCY CHECK") + + # # Add synchronization point to prevent processes from getting out of sync + # self.accelerator.wait_for_everyone() - # Add synchronization point to prevent processes from getting out of sync - self.accelerator.wait_for_everyone() # Apply weights to each reward function's output and sum rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1) - + print(f"Finished with reward weighting, {rewards=} {self.accelerator.process_index=}") # Compute grouped-wise rewards mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1) std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1) + print(f"Finished with grouped rewards, {mean_grouped_rewards=} {self.accelerator.process_index=}") + print(f"Finished with grouped rewards std, {std_grouped_rewards=} {self.accelerator.process_index=}") + # Normalize the rewards to compute the advantages mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0) std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0) advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4) + + print("Finished with advantages") + # if we are using the SSR buffer, we need to populate it with the current batch of examples # we only add examples to the buffer if they are not coming from the buffer if self.use_ssr_buffer and not should_use_buffer and self.accelerator.process_index == 0: @@ -958,6 +1009,8 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]], should_us # add the example to the buffer with the average advantage self.ssr_buffer.add_example(inputs[0], average_abs_advantage) + print("Finished with repopulating SSR buffer") + self.accelerator.wait_for_everyone() @@ -1089,39 +1142,7 @@ def prediction_step( prediction_loss_only, ignore_keys: Optional[list[str]] = None, ): - if self.use_ssr_buffer: - # each process needs to know if we are using the buffer. Process 0 decides and then broadcasts the decision to all processes - if self.accelerator.process_index == 0: - # step the buffer, needs to happen at each training step - self.ssr_buffer.step() - - should_use_buffer = random.random() < self.ssr_use_prob and self.ssr_buffer.buffer_size > 0 - should_use_buffer_list = [should_use_buffer for _ in range(self.accelerator.num_processes)] - else: - should_use_buffer_list = [None for _ in range(self.accelerator.num_processes)] - - broadcast_object_list(should_use_buffer_list, from_process=0) - should_use_buffer = should_use_buffer_list[0] - - if should_use_buffer: - # process 0 will draw from the buffer, the other processes will hang out - if self.accelerator.process_index == 0: - print("Drawing from buffer") - example_from_buffer = self.ssr_buffer.draw_example() - local_inputs = [deepcopy(example_from_buffer) for _ in range(len(inputs))] - print(f"local_inputs after drawing from buffer length: {len(local_inputs)=}") - else: - local_inputs = [None for _ in range(len(inputs))] - - # broadcast the inputs (from the buffer) to all processes - broadcast_object_list(local_inputs, from_process=0) - inputs = local_inputs - else: - # if we are not using the SSR buffer, we just use the inputs passed into the function and signal that this example is not from the buffer - should_use_buffer = False - - - inputs = self._prepare_inputs(inputs, should_use_buffer=should_use_buffer) + inputs = self._prepare_inputs(inputs) with torch.no_grad(): with self.compute_loss_context_manager(): loss = self.compute_loss(model, inputs) From 9d081810d58f54c8a0d6707deacbf6b95cc37cc2 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Fri, 18 Apr 2025 13:21:55 -0700 Subject: [PATCH 05/17] allow examples from the buffer to be added back --- trl/trainer/qwen_grpo_trainer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index b296a7455a4..49f139ba8a9 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -468,8 +468,8 @@ def data_collator(features): # No data collation is needed in GRPO # TODO: make these configurable args self.use_ssr_buffer = True self.ssr_alpha = 2.0 - self.ssr_total_buffer_size = 1000 - self.ssr_persist_steps = 10000 + self.ssr_total_buffer_size = 10000 + self.ssr_persist_steps = 1000 # if the buffer is smaller than this, we don't use it. Instead, draw from the dataset. This helps ensure we only select the best quality examples from the buffer on average. self.min_ssr_buffer_size = 100 # the probability of using the SSR buffer on each step @@ -999,9 +999,9 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s print("Finished with advantages") # if we are using the SSR buffer, we need to populate it with the current batch of examples - # we only add examples to the buffer if they are not coming from the buffer - if self.use_ssr_buffer and not should_use_buffer and self.accelerator.process_index == 0: - # if the average absolute advantage is positive, we add that example to the buffer with the average advantage + # we DO allow an example from the buffer to be re-added after it is popped + if self.use_ssr_buffer and self.accelerator.process_index == 0: + # if the average absolute advantage is greater than 0, we add that example to the buffer with the average advantage average_abs_advantage = torch.abs(advantages).mean().item() if average_abs_advantage > 0: print(f"Adding {inputs[0]} to the SSR buffer with advantage {average_abs_advantage}") From 482bec30195bf0bb378425d56b49935887aec269 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Fri, 18 Apr 2025 17:59:09 -0700 Subject: [PATCH 06/17] make buffer ttl longer and decrease min buffer to get results faster --- trl/trainer/qwen_grpo_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index 49f139ba8a9..c0cdaa943b9 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -469,9 +469,9 @@ def data_collator(features): # No data collation is needed in GRPO self.use_ssr_buffer = True self.ssr_alpha = 2.0 self.ssr_total_buffer_size = 10000 - self.ssr_persist_steps = 1000 + self.ssr_persist_steps = 100000 # if the buffer is smaller than this, we don't use it. Instead, draw from the dataset. This helps ensure we only select the best quality examples from the buffer on average. - self.min_ssr_buffer_size = 100 + self.min_ssr_buffer_size = 50 # the probability of using the SSR buffer on each step self.ssr_use_prob = 0.5 From 6e33e43d4b5e8940f4c96b9b709af749c9c1f6eb Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Sun, 20 Apr 2025 16:38:44 -0700 Subject: [PATCH 07/17] when using example from the buffer, add the current example to the buffer to avoid throwing it away --- trl/trainer/qwen_grpo_trainer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index c0cdaa943b9..017f53c1365 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -727,6 +727,10 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s if should_use_buffer: # process 0 will draw from the buffer, the other processes will hang out if self.accelerator.process_index == 0: + + # put the current example in the buffer with a small advantage so we avoid "throwing it away" + self.ssr_buffer.add_example(inputs[0], 0.01) + print("Drawing from buffer") example_from_buffer = self.ssr_buffer.draw_example() local_inputs = [deepcopy(example_from_buffer) for _ in range(len(inputs))] From e8ef1428ab592d26d4ed8a162abf62a518479e26 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Tue, 22 Apr 2025 19:54:00 +0000 Subject: [PATCH 08/17] prototype reward schedule --- trl/trainer/qwen_grpo_trainer.py | 38 +++++++++++++++++++++++++++----- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index 017f53c1365..af5d2e134fc 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -407,12 +407,17 @@ def __init__( if args.reward_weights is not None: if len(args.reward_weights) != len(reward_funcs): raise ValueError( - f"Number of reward weights ({len(len(args.reward_weights))}) must match number of reward " + f"Number of reward weights ({len(args.reward_weights)}) must match number of reward " f"functions ({len(reward_funcs)})" ) - self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) + # Validate types (float or callable) + for weight in args.reward_weights: + if not isinstance(weight, float) and not callable(weight): + raise TypeError(f"Reward weights must be floats or callables, but found {type(weight)}") + self.reward_weights_config = args.reward_weights # Store the original list/config else: - self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32) + # Default to list of 1.0 floats + self.reward_weights_config = torch.ones(len(reward_funcs), dtype=torch.float32) # Reward processing class if reward_processing_classes is None: @@ -984,8 +989,31 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s # self.accelerator.wait_for_everyone() - # Apply weights to each reward function's output and sum - rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1) + # Calculate current weights based on schedule/config + current_step_weights = [] + current_global_step = self.state.global_step + for weight_config in self.reward_weights_config: + if callable(weight_config): + # Call the schedule function with the current step + current_weight = weight_config(current_global_step) + if not 0.0 <= current_weight <= 1.0: + warnings.warn(f"Reward weight schedule returned {current_weight} at step {current_global_step}. Clamping to [0, 1].") + current_weight = max(0.0, min(1.0, current_weight)) + current_step_weights.append(current_weight) + else: + # Use the fixed float weight + current_step_weights.append(weight_config) + + current_step_weights_tensor = torch.tensor(current_step_weights, dtype=torch.float32, device=device) + + # Log the calculated weights for this step + for i, weight in enumerate(current_step_weights): + reward_func = self.reward_funcs[i] + reward_func_name = reward_func.__name__ + self._metrics[f"reward_weights/{reward_func_name}"].append(weight) + + # Apply calculated weights to each reward function's output and sum + rewards = (rewards_per_func * current_step_weights_tensor.to(device).unsqueeze(0)).sum(dim=1) print(f"Finished with reward weighting, {rewards=} {self.accelerator.process_index=}") # Compute grouped-wise rewards mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1) From cce5f8d34c3cd43ba3463fbf04ff846f1341f3ae Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Tue, 22 Apr 2025 20:20:35 +0000 Subject: [PATCH 09/17] Refactor SSR buffer usage probability logic to use a dynamic approach based on buffer size. Rename `ssr_use_prob` to `max_ssr_use_prob` for clarity and ensure proper validation. Update related calculations and print statements for better debugging. --- trl/trainer/qwen_grpo_trainer.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index af5d2e134fc..e4f87b4a9f9 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -478,10 +478,10 @@ def data_collator(features): # No data collation is needed in GRPO # if the buffer is smaller than this, we don't use it. Instead, draw from the dataset. This helps ensure we only select the best quality examples from the buffer on average. self.min_ssr_buffer_size = 50 # the probability of using the SSR buffer on each step - self.ssr_use_prob = 0.5 + self.max_ssr_use_prob = 0.9 - if not 0 <= self.ssr_use_prob <= 1: - raise ValueError("ssr_use_prob must be between 0 and 1") + if not 0 <= self.max_ssr_use_prob <= 1: + raise ValueError("max_ssr_use_prob must be between 0 and 1") if self.use_ssr_buffer: self.ssr_buffer = SSRBuffer(alpha=self.ssr_alpha, total_buffer_size=self.ssr_total_buffer_size, persist_steps=self.ssr_persist_steps) @@ -721,13 +721,30 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s # step the buffer, needs to happen at each training step self.ssr_buffer.step() - should_use_buffer = random.random() < self.ssr_use_prob and self.ssr_buffer.buffer_size > self.min_ssr_buffer_size + buffer_size = self.ssr_buffer.buffer_size + if buffer_size >= self.min_ssr_buffer_size: + # Calculate dynamic probability based on buffer size + size_range = self.ssr_total_buffer_size - self.min_ssr_buffer_size + if size_range > 0: # Avoid division by zero if min and total size are the same + prob_ramp = (buffer_size - self.min_ssr_buffer_size) / size_range + current_ssr_use_prob = min(self.max_ssr_use_prob, max(0.0, prob_ramp)) + else: # If min and total size are the same, use max probability if buffer is at least min size + current_ssr_use_prob = self.max_ssr_use_prob + else: + # If buffer is smaller than min size, probability is 0 + current_ssr_use_prob = 0.0 + + print(f"Current SSR use probability: {current_ssr_use_prob:.4f}") + should_use_buffer = random.random() < current_ssr_use_prob and buffer_size >= self.min_ssr_buffer_size + should_use_buffer_list = [should_use_buffer for _ in range(self.accelerator.num_processes)] else: should_use_buffer_list = [None for _ in range(self.accelerator.num_processes)] + # Non-zero processes also need the flag, although they don't decide it + buffer_size = None # Placeholder for non-zero processes broadcast_object_list(should_use_buffer_list, from_process=0) - should_use_buffer = should_use_buffer_list[0] + should_use_buffer = should_use_buffer_list[0] # All processes now know if buffer is used if should_use_buffer: # process 0 will draw from the buffer, the other processes will hang out @@ -749,6 +766,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s else: # if we are not using the SSR buffer, we just use the inputs passed into the function and signal that this example is not from the buffer should_use_buffer = False + buffer_size = 0 # Ensure buffer_size is defined even if not using SSR buffer print(f"should_use_buffer: {should_use_buffer}") From 16420a4db51b9c55b39e2751f6c2aadf97b174dc Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Tue, 22 Apr 2025 15:04:04 -0700 Subject: [PATCH 10/17] Add ssr_ramp_up_size parameter and update SSR buffer usage probability calculation to use ramp-up size for dynamic probability determination. --- trl/trainer/qwen_grpo_trainer.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index e4f87b4a9f9..787f3393985 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -475,6 +475,8 @@ def data_collator(features): # No data collation is needed in GRPO self.ssr_alpha = 2.0 self.ssr_total_buffer_size = 10000 self.ssr_persist_steps = 100000 + # size at which the probability ramp reaches max_ssr_use_prob + self.ssr_ramp_up_size = 500 # if the buffer is smaller than this, we don't use it. Instead, draw from the dataset. This helps ensure we only select the best quality examples from the buffer on average. self.min_ssr_buffer_size = 50 # the probability of using the SSR buffer on each step @@ -723,13 +725,15 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s buffer_size = self.ssr_buffer.buffer_size if buffer_size >= self.min_ssr_buffer_size: - # Calculate dynamic probability based on buffer size - size_range = self.ssr_total_buffer_size - self.min_ssr_buffer_size - if size_range > 0: # Avoid division by zero if min and total size are the same - prob_ramp = (buffer_size - self.min_ssr_buffer_size) / size_range - current_ssr_use_prob = min(self.max_ssr_use_prob, max(0.0, prob_ramp)) - else: # If min and total size are the same, use max probability if buffer is at least min size - current_ssr_use_prob = self.max_ssr_use_prob + # Calculate dynamic probability based on buffer size relative to the ramp-up size + size_range = self.ssr_ramp_up_size - self.min_ssr_buffer_size + if size_range > 0: # Avoid division by zero + # Calculate the ramp progress, ensuring it doesn't exceed 1 + prob_ramp = min(1.0, (buffer_size - self.min_ssr_buffer_size) / size_range) + current_ssr_use_prob = prob_ramp * self.max_ssr_use_prob + else: # If min and ramp-up size are the same, use max probability if buffer is at least min size + current_ssr_use_prob = self.max_ssr_use_prob if buffer_size >= self.min_ssr_buffer_size else 0.0 + else: # If buffer is smaller than min size, probability is 0 current_ssr_use_prob = 0.0 From 8a7bc33b4a0237eac4ba113347d1597bf928b861 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Tue, 22 Apr 2025 16:54:27 -0700 Subject: [PATCH 11/17] fix save issue --- trl/trainer/qwen_grpo_trainer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index 787f3393985..f39768fd687 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -419,6 +419,7 @@ def __init__( # Default to list of 1.0 floats self.reward_weights_config = torch.ones(len(reward_funcs), dtype=torch.float32) + # Reward processing class if reward_processing_classes is None: reward_processing_classes = [None] * len(reward_funcs) @@ -501,6 +502,9 @@ def data_collator(features): # No data collation is needed in GRPO optimizers=optimizers, ) + # we need this line to avoid serializing the reward weights on checkpoint save. The reward schedules are not serializable. + self.args.reward_weights = None + # Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations num_processes = self.accelerator.num_processes global_batch_size = args.per_device_train_batch_size * num_processes From 07cc194b4251e676042f08d8e5dcaa094e75dece Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Wed, 23 Apr 2025 09:48:02 -0700 Subject: [PATCH 12/17] allow weights > 1 and <0 --- trl/trainer/qwen_grpo_trainer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index f39768fd687..5f36d291942 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -1022,9 +1022,6 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s if callable(weight_config): # Call the schedule function with the current step current_weight = weight_config(current_global_step) - if not 0.0 <= current_weight <= 1.0: - warnings.warn(f"Reward weight schedule returned {current_weight} at step {current_global_step}. Clamping to [0, 1].") - current_weight = max(0.0, min(1.0, current_weight)) current_step_weights.append(current_weight) else: # Use the fixed float weight From d344fb971c2e89f79d9466dfccec45a9f6ae5089 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Wed, 23 Apr 2025 10:30:48 -0700 Subject: [PATCH 13/17] reduce max buffer usage prob --- trl/trainer/qwen_grpo_trainer.py | 113 +++++++++++++++---------------- 1 file changed, 54 insertions(+), 59 deletions(-) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index 5f36d291942..4d91c5332de 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -131,13 +131,14 @@ def __iter__(self): def __len__(self): return self.num_samples * self.repeat_count + class SSRBuffer: - ''' + """ Selective Sample Replay manager. Maintains a buffer of high entropy samples for training. - ''' + """ def __init__(self, alpha: float = 2.0, total_buffer_size: int = 1000, persist_steps: int = 1000): - ''' + """ Args: alpha: float, handles prioritization intensity>=0. alpha = 0 means no prioritization, @@ -145,7 +146,7 @@ def __init__(self, alpha: float = 2.0, total_buffer_size: int = 1000, persist_st alpha > 1 means more prioritization for high entropy samples. total_buffer_size: int, maximum size of the buffer. After the buffer is full, the oldest samples will be discarded. persist_steps: int, number of steps an example lives in the buffer. After this many steps, the example will be discarded. - ''' + """ if alpha <= 0: raise ValueError("alpha must be greater than 0") @@ -169,38 +170,34 @@ def __init__(self, alpha: float = 2.0, total_buffer_size: int = 1000, persist_st # } def add_example(self, example: dict, advantage: float) -> None: - ''' + """ Add an example to the buffer. - ''' + """ # NOTE: We don't check if the buffer is full here. We'll do it at the end of each training step. - buffer_element = { - "example": example, - "advantage": advantage, - "ttl": self.persist_steps - } + buffer_element = {"example": example, "advantage": advantage, "ttl": self.persist_steps} self.buffer.append(buffer_element) @property def buffer_size(self) -> int: - ''' + """ Number of examples in the buffer. - ''' + """ return len(self.buffer) def draw_example(self) -> dict: - ''' + """ Returns an example from the buffer. The probabilty of drawing an example j is: abs(advantage_j)**(self.alpha) / sum(abs(advantage_i)**(self.alpha) for i in range(len(self.buffer))) - + Raises a ValueError if the buffer is empty, otherwise, pops an example from the buffer and returns it. - ''' + """ if self.buffer_size == 0: raise ValueError("Buffer is empty") values = [] for buffer_element in self.buffer: - values.append(abs(buffer_element["advantage"])**self.alpha) + values.append(abs(buffer_element["advantage"]) ** self.alpha) total = sum(values) probabilities = [value / total for value in values] @@ -215,27 +212,27 @@ def draw_example(self) -> dict: # pop the example from the buffer buffer_element = self.buffer.pop(index) - return buffer_element['example'] + return buffer_element["example"] def step(self) -> None: - ''' + """ Handles reducing ttl's on objects in the buffer and removes objects that have expired. - + It is to be called once at the end of training step. - ''' + """ # decrement the ttl of each buffer element for buffer_element in self.buffer: - buffer_element['ttl'] -= 1 + buffer_element["ttl"] -= 1 # remove buffer elements that have expired - self.buffer = [b for b in self.buffer if b['ttl'] > 0] + self.buffer = [b for b in self.buffer if b["ttl"] > 0] # if the buffer is too big, discard the oldest examples if len(self.buffer) > self.buffer_size: # Sort by absolute advantage (priority), ascending - self.buffer.sort(key=lambda x: abs(x['advantage'])) + self.buffer.sort(key=lambda x: abs(x["advantage"])) # Keep only the top 'buffer_size' elements (highest priority) - self.buffer = self.buffer[-self.buffer_size:] + self.buffer = self.buffer[-self.buffer_size :] class QwenGRPOTrainer(Trainer): @@ -414,12 +411,11 @@ def __init__( for weight in args.reward_weights: if not isinstance(weight, float) and not callable(weight): raise TypeError(f"Reward weights must be floats or callables, but found {type(weight)}") - self.reward_weights_config = args.reward_weights # Store the original list/config + self.reward_weights_config = args.reward_weights # Store the original list/config else: # Default to list of 1.0 floats self.reward_weights_config = torch.ones(len(reward_funcs), dtype=torch.float32) - # Reward processing class if reward_processing_classes is None: reward_processing_classes = [None] * len(reward_funcs) @@ -470,24 +466,27 @@ def data_collator(features): # No data collation is needed in GRPO self.epsilon_low = args.epsilon_low self.epsilon_high = args.epsilon_high - # TODO: make these configurable args self.use_ssr_buffer = True self.ssr_alpha = 2.0 self.ssr_total_buffer_size = 10000 - self.ssr_persist_steps = 100000 + self.ssr_persist_steps = 10000000 # size at which the probability ramp reaches max_ssr_use_prob - self.ssr_ramp_up_size = 500 + self.ssr_ramp_up_size = 1000 # if the buffer is smaller than this, we don't use it. Instead, draw from the dataset. This helps ensure we only select the best quality examples from the buffer on average. self.min_ssr_buffer_size = 50 - # the probability of using the SSR buffer on each step - self.max_ssr_use_prob = 0.9 + # the maximum probability of using the SSR buffer on each step + self.max_ssr_use_prob = 0.75 if not 0 <= self.max_ssr_use_prob <= 1: raise ValueError("max_ssr_use_prob must be between 0 and 1") if self.use_ssr_buffer: - self.ssr_buffer = SSRBuffer(alpha=self.ssr_alpha, total_buffer_size=self.ssr_total_buffer_size, persist_steps=self.ssr_persist_steps) + self.ssr_buffer = SSRBuffer( + alpha=self.ssr_alpha, + total_buffer_size=self.ssr_total_buffer_size, + persist_steps=self.ssr_persist_steps, + ) else: self.ssr_buffer = None @@ -583,7 +582,10 @@ def data_collator(features): # No data collation is needed in GRPO enable_prefix_caching=True, max_model_len=self.args.vllm_max_model_len, # Setting this to 1 as we only have one image per prompt for now. Setting it longer requires more resources, which is wasteful until we need it. - limit_mm_per_prompt={"image": self.args.limit_image_per_prompt, "video": self.args.limit_video_per_prompt}, + limit_mm_per_prompt={ + "image": self.args.limit_image_per_prompt, + "video": self.args.limit_video_per_prompt, + }, ) self.sampling_params = SamplingParams( temperature=args.temperature, @@ -713,8 +715,6 @@ def _move_model_to_vllm(self): vlm_model.load_weights(state_dict.items()) def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]: - - device = self.accelerator.device if not self.env: @@ -731,15 +731,17 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s if buffer_size >= self.min_ssr_buffer_size: # Calculate dynamic probability based on buffer size relative to the ramp-up size size_range = self.ssr_ramp_up_size - self.min_ssr_buffer_size - if size_range > 0: # Avoid division by zero + if size_range > 0: # Avoid division by zero # Calculate the ramp progress, ensuring it doesn't exceed 1 prob_ramp = min(1.0, (buffer_size - self.min_ssr_buffer_size) / size_range) current_ssr_use_prob = prob_ramp * self.max_ssr_use_prob - else: # If min and ramp-up size are the same, use max probability if buffer is at least min size - current_ssr_use_prob = self.max_ssr_use_prob if buffer_size >= self.min_ssr_buffer_size else 0.0 + else: # If min and ramp-up size are the same, use max probability if buffer is at least min size + current_ssr_use_prob = ( + self.max_ssr_use_prob if buffer_size >= self.min_ssr_buffer_size else 0.0 + ) else: - # If buffer is smaller than min size, probability is 0 + # If buffer is smaller than min size, probability is 0 current_ssr_use_prob = 0.0 print(f"Current SSR use probability: {current_ssr_use_prob:.4f}") @@ -749,15 +751,14 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s else: should_use_buffer_list = [None for _ in range(self.accelerator.num_processes)] # Non-zero processes also need the flag, although they don't decide it - buffer_size = None # Placeholder for non-zero processes + buffer_size = None # Placeholder for non-zero processes broadcast_object_list(should_use_buffer_list, from_process=0) - should_use_buffer = should_use_buffer_list[0] # All processes now know if buffer is used + should_use_buffer = should_use_buffer_list[0] # All processes now know if buffer is used if should_use_buffer: # process 0 will draw from the buffer, the other processes will hang out if self.accelerator.process_index == 0: - # put the current example in the buffer with a small advantage so we avoid "throwing it away" self.ssr_buffer.add_example(inputs[0], 0.01) @@ -774,12 +775,10 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s else: # if we are not using the SSR buffer, we just use the inputs passed into the function and signal that this example is not from the buffer should_use_buffer = False - buffer_size = 0 # Ensure buffer_size is defined even if not using SSR buffer + buffer_size = 0 # Ensure buffer_size is defined even if not using SSR buffer print(f"should_use_buffer: {should_use_buffer}") - - # TODO: This is a hack that we should probably fix. # without this, each gpu receives different inputs, screwing up the advantage computation. # Simple synchronization of inputs across processes @@ -837,10 +836,9 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s sampling_params=self.sampling_params, ) - completion_ids = generated_output['ids'] - completion_messages = generated_output.get('messages', None) - completion_mask = generated_output.get('mask', None) - + completion_ids = generated_output["ids"] + completion_messages = generated_output.get("messages", None) + completion_mask = generated_output.get("mask", None) else: completion_ids = [None] * len(all_env_inputs) @@ -880,7 +878,9 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s print("No completion mask provided. Computing mask based on EOS positions.") # Fallback: compute mask based on EOS positions if not provided eos_idx = torch.tensor([len(ids) - 1 for ids in completion_ids], device=device) - sequence_indices = torch.arange(completion_ids.size(1), device=device).expand(completion_ids.size(0), -1) + sequence_indices = torch.arange(completion_ids.size(1), device=device).expand( + completion_ids.size(0), -1 + ) completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) @@ -897,9 +897,9 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s if len(new_images) > 0: # use the processor to get pixel_values and image_grid_thw for the new images new_images_info = self.processing_class( - text='', + text="", images=new_images, - return_tensors='pt', + return_tensors="pt", padding=True, ) new_pixel_values = new_images_info["pixel_values"] @@ -1014,7 +1014,6 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s # # Add synchronization point to prevent processes from getting out of sync # self.accelerator.wait_for_everyone() - # Calculate current weights based on schedule/config current_step_weights = [] current_global_step = self.state.global_step @@ -1050,7 +1049,6 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0) advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4) - print("Finished with advantages") # if we are using the SSR buffer, we need to populate it with the current batch of examples @@ -1068,7 +1066,6 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s self.accelerator.wait_for_everyone() - # Slice to keep only the local part of the data process_slice = slice( self.accelerator.process_index * len(conversations), @@ -1097,9 +1094,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s # For logging inputs_data_to_log = { - key: gather_object( - [i[key] for i in inputs if key in i] - ) for key in self.inputs_to_log + key: gather_object([i[key] for i in inputs if key in i]) for key in self.inputs_to_log } # if the value is torch.Tensor, convert it to a list for key, value in inputs_data_to_log.items(): From d3960848ff4280edf9c53754ea7255fe80aca5a9 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Wed, 23 Apr 2025 10:34:12 -0700 Subject: [PATCH 14/17] reduce it a bit more to 65 --- trl/trainer/qwen_grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index 4d91c5332de..3716a80f0ab 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -476,7 +476,7 @@ def data_collator(features): # No data collation is needed in GRPO # if the buffer is smaller than this, we don't use it. Instead, draw from the dataset. This helps ensure we only select the best quality examples from the buffer on average. self.min_ssr_buffer_size = 50 # the maximum probability of using the SSR buffer on each step - self.max_ssr_use_prob = 0.75 + self.max_ssr_use_prob = 0.65 if not 0 <= self.max_ssr_use_prob <= 1: raise ValueError("max_ssr_use_prob must be between 0 and 1") From 483c9a71cd9aec245d83af2125ab2ed5a20419c6 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Fri, 25 Apr 2025 21:11:54 -0700 Subject: [PATCH 15/17] support guided decoding, but its too slow for our work --- trl/trainer/grpo_config.py | 4 +++- trl/trainer/qwen_grpo_trainer.py | 10 ++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index a92aa1ceea6..1af54661722 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -221,7 +221,9 @@ class GRPOConfig(TrainingArguments): ) epsilon_high: float = field( default=0.28, - metadata={"help": "Value to clip at 1+epsilon. Generally set this higher than epsilon_low based on DAPO paper results."}, + metadata={ + "help": "Value to clip at 1+epsilon. Generally set this higher than epsilon_low based on DAPO paper results." + }, ) reward_weights: Optional[list[float]] = field( default=None, diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index 3716a80f0ab..4cad3aad782 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -67,6 +67,7 @@ if is_vllm_available(): from vllm import LLM, SamplingParams + from vllm.sampling_params import GuidedDecodingParams if is_wandb_available(): import wandb @@ -335,6 +336,7 @@ def __init__( shuffle_dataset: bool = True, image_pad_id: int = 151655, inputs_to_log: list[str] = [], + guided_regex: Optional[str] = None, ): # Args if args is None: @@ -587,9 +589,17 @@ def data_collator(features): # No data collation is needed in GRPO "video": self.args.limit_video_per_prompt, }, ) + + if guided_regex is not None: + guided_decoding_params = GuidedDecodingParams(regex=guided_regex) + + else: + guided_decoding_params = None + self.sampling_params = SamplingParams( temperature=args.temperature, max_tokens=self.max_completion_length, + guided_decoding=guided_decoding_params, ) self._last_loaded_step = 0 # tag to avoid useless loading during grad accumulation From 0cba6766da1201714c0205c58094aef61a42afa3 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Sat, 26 Apr 2025 13:36:31 -0700 Subject: [PATCH 16/17] log observed advantages --- trl/trainer/qwen_grpo_trainer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index 4cad3aad782..12c5bff23cc 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -1066,6 +1066,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s if self.use_ssr_buffer and self.accelerator.process_index == 0: # if the average absolute advantage is greater than 0, we add that example to the buffer with the average advantage average_abs_advantage = torch.abs(advantages).mean().item() + self._metrics["avg_abs_advantage"].append(average_abs_advantage) if average_abs_advantage > 0: print(f"Adding {inputs[0]} to the SSR buffer with advantage {average_abs_advantage}") @@ -1187,8 +1188,11 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() # Log the metrics - completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() + completion_lengths_tensor = self.accelerator.gather_for_metrics(completion_mask.sum(1)) + completion_length = completion_lengths_tensor.float().mean().item() self._metrics["completion_length"].append(completion_length) + max_completion_length = completion_lengths_tensor.float().max().item() + self._metrics["max_completion_length"].append(max_completion_length) mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item()) From 388865fc1c61502709a271afb0dbe29d2722b4b2 Mon Sep 17 00:00:00 2001 From: Sunil Kumar Date: Sat, 26 Apr 2025 22:38:44 -0700 Subject: [PATCH 17/17] new metric --- trl/trainer/qwen_grpo_trainer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py index 12c5bff23cc..8950e9a2eac 100644 --- a/trl/trainer/qwen_grpo_trainer.py +++ b/trl/trainer/qwen_grpo_trainer.py @@ -1067,7 +1067,12 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s # if the average absolute advantage is greater than 0, we add that example to the buffer with the average advantage average_abs_advantage = torch.abs(advantages).mean().item() self._metrics["avg_abs_advantage"].append(average_abs_advantage) - if average_abs_advantage > 0: + + # record if the given example provided training signal. This is true if the average absolute advantage is greater than 0 + provided_training_signal = average_abs_advantage > 0 + self._metrics["provided_training_signal"].append(provided_training_signal) + + if provided_training_signal: print(f"Adding {inputs[0]} to the SSR buffer with advantage {average_abs_advantage}") # add the example to the buffer with the average advantage