diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py
index ac2ee2f9e40..1af54661722 100644
--- a/trl/trainer/grpo_config.py
+++ b/trl/trainer/grpo_config.py
@@ -214,6 +214,17 @@ class GRPOConfig(TrainingArguments):
         default=0.04,
         metadata={"help": "KL coefficient."},
     )
+
+    epsilon_low: float = field(
+        default=0.2,
+        metadata={"help": "Lower clipping bound for the importance ratio; the ratio is clipped below at 1 - epsilon_low."},
+    )
+    epsilon_high: float = field(
+        default=0.28,
+        metadata={
+            "help": "Upper clipping bound for the importance ratio; the ratio is clipped above at 1 + epsilon_high. Following the DAPO paper results, this is generally set higher than epsilon_low."
+        },
+    )
     reward_weights: Optional[list[float]] = field(
         default=None,
         metadata={
@@ -256,4 +267,4 @@ class GRPOConfig(TrainingArguments):
     limit_video_per_prompt: int = field(
         default=0,
         metadata={"help": "Limit the number of videos per prompt for vllm generation."},
-    )
\ No newline at end of file
+    )
diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py
index d0e2efca84f..8950e9a2eac 100644
--- a/trl/trainer/qwen_grpo_trainer.py
+++ b/trl/trainer/qwen_grpo_trainer.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import random
 import textwrap
 import warnings
 from collections import defaultdict
@@ -19,6 +20,7 @@
 from typing import Any, Callable, Optional, Sized, Union
 from unittest.mock import patch
 
+import numpy as np
 import torch
 import torch.utils.data
 import transformers
@@ -65,6 +67,7 @@
 
 if is_vllm_available():
     from vllm import LLM, SamplingParams
+    from vllm.sampling_params import GuidedDecodingParams
 
 if is_wandb_available():
     import wandb
@@ -130,6 +133,109 @@ def __len__(self):
         return self.num_samples * self.repeat_count
 
 
+class SSRBuffer:
+    """
+    Selective Sample Replay manager. Maintains a buffer of high-advantage samples to replay during training.
+    """
+
+    def __init__(self, alpha: float = 2.0, total_buffer_size: int = 1000, persist_steps: int = 1000):
+        """
+        Args:
+            alpha: float, prioritization intensity, must be > 0.
+                Values close to 0 flatten the prioritization towards uniform sampling,
+                alpha = 1 means prioritization linearly proportional to the absolute advantage,
+                alpha > 1 means stronger prioritization of high-advantage samples.
+            total_buffer_size: int, maximum size of the buffer. When the buffer is full, the lowest-priority samples will be discarded.
+            persist_steps: int, number of steps an example lives in the buffer. After this many steps, the example will be discarded.
+        """
+
+        if alpha <= 0:
+            raise ValueError("alpha must be greater than 0")
+        self.alpha = alpha
+
+        if total_buffer_size <= 0:
+            raise ValueError("total_buffer_size must be greater than 0")
+        self.total_buffer_size = total_buffer_size
+
+        if persist_steps <= 0:
+            raise ValueError("persist_steps must be greater than 0")
+        self.persist_steps = persist_steps
+
+        self.buffer = []
+
+        # element of buffer format:
+        # {
+        #     "example": dict,    - the example to be replayed
+        #     "advantage": float, - the advantage observed at the last training step
+        #     "ttl": int,         - time to live, number of steps before the example will be discarded from the buffer
+        # }
+
+    def add_example(self, example: dict, advantage: float) -> None:
+        """
+        Add an example to the buffer.
+        """
+        # NOTE: We don't check if the buffer is full here. We'll do it at the end of each training step.
+        buffer_element = {"example": example, "advantage": advantage, "ttl": self.persist_steps}
+        self.buffer.append(buffer_element)
+
+    @property
+    def buffer_size(self) -> int:
+        """
+        Number of examples in the buffer.
+ """ + return len(self.buffer) + + def draw_example(self) -> dict: + """ + Returns an example from the buffer. The probabilty of drawing an example j is: + abs(advantage_j)**(self.alpha) / sum(abs(advantage_i)**(self.alpha) for i in range(len(self.buffer))) + + Raises a ValueError if the buffer is empty, otherwise, pops an example from the buffer and returns it. + """ + + if self.buffer_size == 0: + raise ValueError("Buffer is empty") + + values = [] + for buffer_element in self.buffer: + values.append(abs(buffer_element["advantage"]) ** self.alpha) + + total = sum(values) + probabilities = [value / total for value in values] + + # check that the probabilities sum to 1, with some tolerance + if not np.isclose(sum(probabilities), 1.0, atol=1e-6): + raise ValueError(f"Probabilities do not sum to 1, but instead sum to {sum(probabilities)}") + + # choose the index of the example to draw + index = np.random.choice(range(len(self.buffer)), p=probabilities) + + # pop the example from the buffer + buffer_element = self.buffer.pop(index) + + return buffer_element["example"] + + def step(self) -> None: + """ + Handles reducing ttl's on objects in the buffer and removes objects that have expired. + + It is to be called once at the end of training step. + """ + # decrement the ttl of each buffer element + for buffer_element in self.buffer: + buffer_element["ttl"] -= 1 + + # remove buffer elements that have expired + self.buffer = [b for b in self.buffer if b["ttl"] > 0] + + # if the buffer is too big, discard the oldest examples + if len(self.buffer) > self.buffer_size: + # Sort by absolute advantage (priority), ascending + self.buffer.sort(key=lambda x: abs(x["advantage"])) + # Keep only the top 'buffer_size' elements (highest priority) + self.buffer = self.buffer[-self.buffer_size :] + + class QwenGRPOTrainer(Trainer): """ Trainer for the Group Relative Policy Optimization (GRPO) method. 
This algorithm was initially proposed in the @@ -230,6 +336,7 @@ def __init__( shuffle_dataset: bool = True, image_pad_id: int = 151655, inputs_to_log: list[str] = [], + guided_regex: Optional[str] = None, ): # Args if args is None: @@ -299,12 +406,17 @@ def __init__( if args.reward_weights is not None: if len(args.reward_weights) != len(reward_funcs): raise ValueError( - f"Number of reward weights ({len(len(args.reward_weights))}) must match number of reward " + f"Number of reward weights ({len(args.reward_weights)}) must match number of reward " f"functions ({len(reward_funcs)})" ) - self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) + # Validate types (float or callable) + for weight in args.reward_weights: + if not isinstance(weight, float) and not callable(weight): + raise TypeError(f"Reward weights must be floats or callables, but found {type(weight)}") + self.reward_weights_config = args.reward_weights # Store the original list/config else: - self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32) + # Default to list of 1.0 floats + self.reward_weights_config = torch.ones(len(reward_funcs), dtype=torch.float32) # Reward processing class if reward_processing_classes is None: @@ -352,6 +464,34 @@ def data_collator(features): # No data collation is needed in GRPO self._metrics = defaultdict(list) self.log_completions = args.log_completions + # intialize epsilon + self.epsilon_low = args.epsilon_low + self.epsilon_high = args.epsilon_high + + # TODO: make these configurable args + self.use_ssr_buffer = True + self.ssr_alpha = 2.0 + self.ssr_total_buffer_size = 10000 + self.ssr_persist_steps = 10000000 + # size at which the probability ramp reaches max_ssr_use_prob + self.ssr_ramp_up_size = 1000 + # if the buffer is smaller than this, we don't use it. Instead, draw from the dataset. This helps ensure we only select the best quality examples from the buffer on average. + self.min_ssr_buffer_size = 50 + # the maximum probability of using the SSR buffer on each step + self.max_ssr_use_prob = 0.65 + + if not 0 <= self.max_ssr_use_prob <= 1: + raise ValueError("max_ssr_use_prob must be between 0 and 1") + + if self.use_ssr_buffer: + self.ssr_buffer = SSRBuffer( + alpha=self.ssr_alpha, + total_buffer_size=self.ssr_total_buffer_size, + persist_steps=self.ssr_persist_steps, + ) + else: + self.ssr_buffer = None + super().__init__( model=model, args=args, @@ -363,6 +503,9 @@ def data_collator(features): # No data collation is needed in GRPO optimizers=optimizers, ) + # we need this line to avoid serializing the reward weights on checkpoint save. The reward schedules are not serializable. + self.args.reward_weights = None + # Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations num_processes = self.accelerator.num_processes global_batch_size = args.per_device_train_batch_size * num_processes @@ -441,11 +584,22 @@ def data_collator(features): # No data collation is needed in GRPO enable_prefix_caching=True, max_model_len=self.args.vllm_max_model_len, # Setting this to 1 as we only have one image per prompt for now. Setting it longer requires more resources, which is wasteful until we need it. 
-                limit_mm_per_prompt={"image": self.args.limit_image_per_prompt, "video": self.args.limit_video_per_prompt},
+                limit_mm_per_prompt={
+                    "image": self.args.limit_image_per_prompt,
+                    "video": self.args.limit_video_per_prompt,
+                },
             )
+
+            if guided_regex is not None:
+                guided_decoding_params = GuidedDecodingParams(regex=guided_regex)
+
+            else:
+                guided_decoding_params = None
+
             self.sampling_params = SamplingParams(
                 temperature=args.temperature,
                 max_tokens=self.max_completion_length,
+                guided_decoding=guided_decoding_params,
             )
 
             self._last_loaded_step = 0  # tag to avoid useless loading during grad accumulation
@@ -576,6 +730,65 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
         if not self.env:
             raise ValueError("No environment provided. Only supporting envs now. ")
 
+        if self.use_ssr_buffer:
+            # each process needs to know if we are using the buffer. Process 0 decides and then broadcasts the decision to all processes
+            if self.accelerator.process_index == 0:
+                print(f"buffer size: {self.ssr_buffer.buffer_size}")
+                # step the buffer, needs to happen at each training step
+                self.ssr_buffer.step()
+
+                buffer_size = self.ssr_buffer.buffer_size
+                if buffer_size >= self.min_ssr_buffer_size:
+                    # Calculate dynamic probability based on buffer size relative to the ramp-up size
+                    size_range = self.ssr_ramp_up_size - self.min_ssr_buffer_size
+                    if size_range > 0:  # Avoid division by zero
+                        # Calculate the ramp progress, ensuring it doesn't exceed 1
+                        prob_ramp = min(1.0, (buffer_size - self.min_ssr_buffer_size) / size_range)
+                        current_ssr_use_prob = prob_ramp * self.max_ssr_use_prob
+                    else:  # If min and ramp-up size are the same, use max probability if buffer is at least min size
+                        current_ssr_use_prob = (
+                            self.max_ssr_use_prob if buffer_size >= self.min_ssr_buffer_size else 0.0
+                        )
+
+                else:
+                    # If buffer is smaller than min size, probability is 0
+                    current_ssr_use_prob = 0.0
+
+                print(f"Current SSR use probability: {current_ssr_use_prob:.4f}")
+                should_use_buffer = random.random() < current_ssr_use_prob and buffer_size >= self.min_ssr_buffer_size
+
+                should_use_buffer_list = [should_use_buffer for _ in range(self.accelerator.num_processes)]
+            else:
+                should_use_buffer_list = [None for _ in range(self.accelerator.num_processes)]
+                # Non-zero processes also need the flag, although they don't decide it
+                buffer_size = None  # Placeholder for non-zero processes
+
+            broadcast_object_list(should_use_buffer_list, from_process=0)
+            should_use_buffer = should_use_buffer_list[0]  # All processes now know if buffer is used
+
+            if should_use_buffer:
+                # process 0 will draw from the buffer, the other processes will hang out
+                if self.accelerator.process_index == 0:
+                    # put the current example in the buffer with a small advantage so we avoid "throwing it away"
+                    self.ssr_buffer.add_example(inputs[0], 0.01)
+
+                    print("Drawing from buffer")
+                    example_from_buffer = self.ssr_buffer.draw_example()
+                    local_inputs = [deepcopy(example_from_buffer) for _ in range(len(inputs))]
+                    print(f"local_inputs after drawing from buffer length: {len(local_inputs)=}")
+                else:
+                    local_inputs = [None for _ in range(len(inputs))]
+
+                # broadcast the inputs (from the buffer) to all processes
+                broadcast_object_list(local_inputs, from_process=0)
+                inputs = local_inputs
+        else:
+            # if we are not using the SSR buffer, we just use the inputs passed into the function and signal that this example is not from the buffer
+            should_use_buffer = False
+            buffer_size = 0  # Ensure buffer_size is defined even if not using SSR buffer
+
+        print(f"should_use_buffer: {should_use_buffer}")
+
         # TODO: This is a hack that we should probably fix.
         # without this, each gpu receives different inputs, screwing up the advantage computation.
         # Simple synchronization of inputs across processes
@@ -633,10 +846,9 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
                 sampling_params=self.sampling_params,
             )
 
-            completion_ids = generated_output['ids']
-            completion_messages = generated_output.get('messages', None)
-            completion_mask = generated_output.get('mask', None)
-
+            completion_ids = generated_output["ids"]
+            completion_messages = generated_output.get("messages", None)
+            completion_mask = generated_output.get("mask", None)
         else:
             completion_ids = [None] * len(all_env_inputs)
@@ -676,7 +888,9 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
             print("No completion mask provided. Computing mask based on EOS positions.")
             # Fallback: compute mask based on EOS positions if not provided
             eos_idx = torch.tensor([len(ids) - 1 for ids in completion_ids], device=device)
-            sequence_indices = torch.arange(completion_ids.size(1), device=device).expand(completion_ids.size(0), -1)
+            sequence_indices = torch.arange(completion_ids.size(1), device=device).expand(
+                completion_ids.size(0), -1
+            )
             completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
 
         prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
@@ -693,9 +907,9 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
             if len(new_images) > 0:
                 # use the processor to get pixel_values and image_grid_thw for the new images
                 new_images_info = self.processing_class(
-                    text='',
+                    text="",
                     images=new_images,
-                    return_tensors='pt',
+                    return_tensors="pt",
                     padding=True,
                 )
                 new_pixel_values = new_images_info["pixel_values"]
@@ -706,10 +920,12 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
                 new_pixel_values = new_pixel_values.to(device)
                 new_image_grid_thw = new_image_grid_thw.to(device)
                 pixel_values = torch.cat([pixel_values, new_pixel_values], dim=0)
-                image_grid_thw = torch.cat([image_grid_thw, new_image_grid_thw], dim=0) 
+                image_grid_thw = torch.cat([image_grid_thw, new_image_grid_thw], dim=0)
         else:
             raise ValueError("Attempted to generate with HF. Only supporting vllm now.")
 
+        print("Finished with generation")
+
         # Concatenate prompt_mask with completion_mask for logit computation
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B*G, P+C)
@@ -736,6 +952,8 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
                 logits_to_keep,
             )
 
+        print("Finished with ref logits")
+
         # Decode the generated completions
         completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
         if is_conversational(inputs[0]):
@@ -781,9 +999,14 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
                 output_reward_func = reward_func(prompts=conversations, completions=completions, **reward_kwargs)
                 rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
 
+        print("Finished with rewards per function")
         # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
         # completions may be distributed across processes
         rewards_per_func = gather(rewards_per_func)
+        print(f"Finished with gathering rewards per function, {rewards_per_func=}, {self.accelerator.process_index=}")
+
+        self.accelerator.wait_for_everyone()
+        print(f"Finished with waiting for everyone, {self.accelerator.process_index=}")
 
         # # DEBUG: Verify prompt consistency across completions in each group
         # TODO: remove this probably?
@@ -801,18 +1024,64 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
         #     # Add synchronization point to prevent processes from getting out of sync
         #     self.accelerator.wait_for_everyone()
 
-        # Apply weights to each reward function's output and sum
-        rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)
+        # Calculate current weights based on schedule/config
+        current_step_weights = []
+        current_global_step = self.state.global_step
+        for weight_config in self.reward_weights_config:
+            if callable(weight_config):
+                # Call the schedule function with the current step
+                current_weight = weight_config(current_global_step)
+                current_step_weights.append(current_weight)
+            else:
+                # Use the fixed float weight
+                current_step_weights.append(weight_config)
+
+        current_step_weights_tensor = torch.tensor(current_step_weights, dtype=torch.float32, device=device)
+
+        # Log the calculated weights for this step
+        for i, weight in enumerate(current_step_weights):
+            reward_func = self.reward_funcs[i]
+            reward_func_name = reward_func.__name__
+            self._metrics[f"reward_weights/{reward_func_name}"].append(weight)
+        # Apply calculated weights to each reward function's output and sum
+        rewards = (rewards_per_func * current_step_weights_tensor.to(device).unsqueeze(0)).sum(dim=1)
+        print(f"Finished with reward weighting, {rewards=} {self.accelerator.process_index=}")
 
         # Compute grouped-wise rewards
         mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
         std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
+        print(f"Finished with grouped rewards, {mean_grouped_rewards=} {self.accelerator.process_index=}")
+        print(f"Finished with grouped rewards std, {std_grouped_rewards=} {self.accelerator.process_index=}")
+
         # Normalize the rewards to compute the advantages
         mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
         std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
         advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
+        print("Finished with advantages")
+
+        # if we are using the SSR buffer, we need to populate it with the current batch of examples
+        # we DO allow an example from the buffer to be re-added after it is popped
+        if self.use_ssr_buffer and self.accelerator.process_index == 0:
+            # if the average absolute advantage is greater than 0, we add that example to the buffer with the average advantage
+            average_abs_advantage = torch.abs(advantages).mean().item()
+            self._metrics["avg_abs_advantage"].append(average_abs_advantage)
+
+            # record if the given example provided training signal. This is true if the average absolute advantage is greater than 0
+            provided_training_signal = average_abs_advantage > 0
+            self._metrics["provided_training_signal"].append(provided_training_signal)
+
+            if provided_training_signal:
+                print(f"Adding {inputs[0]} to the SSR buffer with advantage {average_abs_advantage}")
+
+                # add the example to the buffer with the average advantage
+                self.ssr_buffer.add_example(inputs[0], average_abs_advantage)
+
+        print("Finished with repopulating SSR buffer")
+
+        self.accelerator.wait_for_everyone()
+
         # Slice to keep only the local part of the data
         process_slice = slice(
             self.accelerator.process_index * len(conversations),
@@ -841,9 +1110,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
 
         # For logging
         inputs_data_to_log = {
-            key: gather_object(
-                [i[key] for i in inputs if key in i]
-            ) for key in self.inputs_to_log
+            key: gather_object([i[key] for i in inputs if key in i]) for key in self.inputs_to_log
         }
         # if the value is torch.Tensor, convert it to a list
         for key, value in inputs_data_to_log.items():
@@ -870,6 +1137,9 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
             if wandb.run is not None and self.accelerator.is_main_process:
                 wandb.log({"completions": wandb.Table(dataframe=df)})
 
+        self._metrics["buffer_size"].append(self.ssr_buffer.buffer_size if self.use_ssr_buffer else 0)
+        self._metrics["buffer_usage"].append(float(should_use_buffer))
+
         return {
             "prompt_ids": prompt_ids,
             "prompt_mask": prompt_mask,
@@ -912,13 +1182,22 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
         # x - x.detach() allows for preserving gradients from x
         advantages = inputs["advantages"]
-        per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
+        coef_1 = torch.exp(per_token_logps - per_token_logps.detach())
+        coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
+
+        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
+        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
+        per_token_loss = torch.min(per_token_loss1, per_token_loss2)
+
         per_token_loss = -(per_token_loss - self.beta * per_token_kl)
         loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
 
         # Log the metrics
-        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
+        completion_lengths_tensor = self.accelerator.gather_for_metrics(completion_mask.sum(1))
+        completion_length = completion_lengths_tensor.float().mean().item()
         self._metrics["completion_length"].append(completion_length)
+        max_completion_length = completion_lengths_tensor.float().max().item()
+        self._metrics["max_completion_length"].append(max_completion_length)
 
         mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
         self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())