30 changes: 28 additions & 2 deletions areal/api/cli_args.py
@@ -148,7 +148,20 @@ def new(self, **kwargs):
args.update(kwargs)
return GenerationHyperparameters(**args)


@dataclass
class PRMRewardHyperparameters:
reward_shaping_alpha: float = field(
default=0.02,
metadata={"help": "reward shaping alpha"},
)
use_clip: bool = field(
default=True,
metadata={"help": "Whether to use clip mechanism."},
)
use_delta: bool = field(
default=True,
metadata={"help": "Whether to use delta mechanism."},
)
# Train Engine Configs


@@ -577,7 +590,7 @@ class SGLangConfig:
max_lora_rank: int | None = None
lora_target_modules: List[str] | None = None
lora_paths: List[str] | None = None
max_loaded_loras: int = 1
# max_loaded_loras: int = 1
max_loras_per_batch: int = 1
lora_backend: str = "triton"
# logging
@@ -1118,6 +1131,19 @@ class GRPOConfig(BaseExperimentConfig):
actor: PPOActorConfig = field(default_factory=PPOActorConfig)
ref: PPOActorConfig = field(default_factory=PPOActorConfig)

@dataclass
class PRMConfig(BaseExperimentConfig):
async_training: bool = field(default=True)
prm_path: str = field(default="")
gconfig: GenerationHyperparameters = field(
default_factory=GenerationHyperparameters
)
prmconfig: PRMRewardHyperparameters = field(
Collaborator commented: Looks like we can just inherit GRPOConfig and add two new fields, prm_path and reward_shaping_alpha? BTW, if you refer to reward scaling, you can use actor.reward_scaling rather than creating a new field.

default_factory=PRMRewardHyperparameters
)
rollout: InferenceEngineConfig = field(default_factory=InferenceEngineConfig)
actor: PPOActorConfig = field(default_factory=PPOActorConfig)
ref: PPOActorConfig = field(default_factory=PPOActorConfig)
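
A minimal sketch of the reviewer's suggestion above, as it might look in areal/api/cli_args.py, assuming PRMConfig can simply subclass GRPOConfig (which already provides gconfig, rollout, actor, and ref); the field names mirror this PR and are illustrative:

from dataclasses import dataclass, field

@dataclass
class PRMConfig(GRPOConfig):
    # Path to the process reward model checkpoint; empty string means unset.
    prm_path: str = field(default="")
    # PRM reward-shaping coefficient; for plain reward scaling,
    # actor.reward_scaling could be reused instead, as the reviewer notes.
    reward_shaping_alpha: float = field(default=0.02)
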

@dataclass
class PPOConfig(GRPOConfig):
2 changes: 1 addition & 1 deletion areal/api/reward_api.py
@@ -103,7 +103,7 @@ def _recreate_executor(cls, executor_key, max_workers):
return cls._executors[executor_key]
return None

async def __call__(self, *args, **kwargs) -> float:
async def __call__(self, *args, **kwargs):
last_exception = None

for attempt in range(self.max_retries + 1):
281 changes: 280 additions & 1 deletion areal/engine/ppo/actor.py
@@ -1,5 +1,6 @@
import functools
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple
import warnings

import torch

@@ -287,6 +288,224 @@ def compute_advantages(self, *args, **kwargs) -> None:
def ppo_update(self, *args, **kwargs) -> List[Dict[str, float]]:
return self.actor.ppo_update(*args, **kwargs)

class FSDPPPOActorDense(FSDPPPOActor):

def __init__(self, config: PPOActorConfig):
super().__init__(config)
self.actor = DensePPOActor(config, self)

class DensePPOActor(PPOActor):
def __init__(self, config: PPOActorConfig, engine: TrainEngine):
super().__init__(config, engine)

def compute_advantages(self, data: Dict[str, Any]) -> None:
bs = data["input_ids"].shape[0]
max_seqlen = data["input_ids"].shape[1]
batch_indices = torch.arange(
bs, device=data["input_ids"].device, dtype=torch.long
)

# Reward Penalty on length
if self.config.overlong_reward_penalty:

overlong_tokens = self.config.overlong_tokens
overlong_penalty_factor = self.config.overlong_penalty_factor

data = reward_overlong_penalty(
data,
overlong_tokens=overlong_tokens,
overlong_penalty_factor=overlong_penalty_factor,
max_response_length=self.config.max_new_tokens,
)

# Reward Scaling
reward_score = data["rewards"]
reward_score = (reward_score + self.reward_bias) * self.reward_scaling
reward_score = torch.clip(
reward_score, max=self.reward_clip, min=-self.reward_clip
)
if self.reward_norm:
reward_score = self.reward_norm(reward_score)

loss_mask = data["loss_mask"].float()
loss_mask = torch.roll(loss_mask, shifts=-1, dims=-1)
# Apply the mask to log probabilities.
if not self.config.use_decoupled_loss and self.config.recompute_logprob:
# Overwrite logprobs produced by the inference engine
old_logp = data["logprobs"] = data["prox_logp"]
else:
old_logp = torch.roll(data["logprobs"], shifts=-1, dims=-1)
if not self.config.use_decoupled_loss:
# prox_logp not available; fall back to the logp from the inference engine
data["prox_logp"] = old_logp
ref_logp = data.get("ref_logp", torch.zeros_like(old_logp))
ref_logp *= loss_mask
old_logp *= loss_mask

# Compute KL-regularized rewards.
attn_mask = data["attention_mask"]
seqlens = attn_mask.sum(-1).long()
seq_no_eos_mask = seqlens == attn_mask.shape[1]
rewards = -self.kl_ctl * (old_logp - ref_logp)
kl_rewards = rewards.clone()
# KL rewards at the next token after EOS are zero.
rewards[batch_indices, seqlens - 1] = 0
indices = torch.clip(seqlens - 2, min=0)
# print(f"reward_score: {reward_score.shape}, {reward_score}")
# print(f"rewards before: {rewards.shape}, {rewards}")
if self.mask_no_eos_with_zero:
rewards[batch_indices, :] += torch.where(
seq_no_eos_mask.unsqueeze(-1), 0, reward_score
)
else:
rewards[batch_indices, :] += reward_score
# print(f"rewards after: {rewards}")
# Compute GAE.
if "values" not in data:
values = torch.zeros_like(rewards)
else:
values = data["values"]
advantages_reversed = [
torch.zeros(bs, dtype=torch.float32, device=values.device)
]
lastgaelam = 0
nextvalues = values[:, max_seqlen - 1] * seq_no_eos_mask
for t in reversed(range(max_seqlen - 1)):
delta = rewards[:, t] + self.discount * nextvalues - values[:, t]
newgaelam = delta + self.discount * self.gae_lambda * lastgaelam

# Skip tokens that do not contribute to the loss
mask = loss_mask[:, t]
nextvalues = nextvalues * (1 - mask) + values[:, t] * mask
lastgaelam = lastgaelam * (1 - mask) + newgaelam * mask
advantages_reversed.append(lastgaelam)

advantages = torch.stack(advantages_reversed[::-1], dim=1)
data["returns"] = advantages + values

# Optionally perform advantage normalization.
if self.adv_norm is not None:
advantages = self.adv_norm(advantages, loss_mask)

# Store data in the dict.
data["advantages"] = advantages
data["kl_rewards"] = kl_rewards
data["tot_rewards"] = rewards
data["loss_mask"] = loss_mask
# because we have rolled old_logp by -1
data["logprobs"] = old_logp

def ppo_update(self, data: Dict[str, Any]) -> List[Dict[str, float]]:

if self.dynamic_sampling and len(data["rewards"]) % self.group_size == 0:
data, sampling_stat = dynamic_sampling_dense_reward(data, self.group_size)

attn_mask = data["attention_mask"]
loss_mask = data["loss_mask"]
reward_score = data["rewards"]
seqlens = attn_mask.sum(-1)

all_stats = []
########## Logging code starts ##########
result_denominators = {
"correct_n_seqs": (reward_score[:, -1] > 0).bool(),
"incorrect_n_seqs": (reward_score[:, -1] <= 0).bool(),
}
if self.config.log_agent_stats:
assert (
"begin_of_trajectory" in data
), "'begin_of_trajectory' is expected to log agent statistics"
assert (
len(self.config.log_agent_stats_keys) > 0
), "`log_agent_stats_keys` should not be empty when log_agent_stats=True"
agent_denominator = (data["begin_of_trajectory"] > 0).bool()
result_denominators["agent"] = agent_denominator
global_denominators = dict(
n_seqs=torch.ones_like(reward_score[:, 0], dtype=torch.bool),
n_tokens=torch.ones_like(loss_mask, dtype=torch.bool),
n_valid_tokens=loss_mask.bool(),
**result_denominators,
)
stats_tracker.denominator(**global_denominators)
stats_tracker.stat(
correct_seq_len=seqlens.float(), denominator="correct_n_seqs"
)
stats_tracker.stat(
incorrect_seq_len=seqlens.float(), denominator="incorrect_n_seqs"
)

stats = dict(
advantages=data["advantages"],
kl_rewards=data["kl_rewards"],
final_reward=data["tot_rewards"],
)
stats_tracker.stat(**stats, denominator="n_valid_tokens")

prompt_lens = data["attention_mask"].sum(-1) - data["loss_mask"].sum(-1)
seq_stats = dict(
no_eos_ratios=(seqlens == attn_mask.shape[-1]).float(),
task_reward=reward_score[:, -2].float(),
prompt_len=prompt_lens.float(),
seq_len=seqlens.float(),
)
stats_tracker.stat(**seq_stats, denominator="n_seqs")
scalars = dict(
mask_no_eos_with_zero=self.config.mask_no_eos_with_zero,
eps_clip=self.config.eps_clip,
)
if self.config.c_clip is not None:
scalars["c_clip"] = self.config.c_clip
scalars["use_dual_clip"] = 1
else:
scalars["use_dual_clip"] = 0
if self.config.behav_imp_weight_cap is not None:
scalars["behav_imp_weight_cap"] = self.config.behav_imp_weight_cap
stats_tracker.scalar(**scalars)

if self.config.log_agent_stats:
stats_tracker.stat(
**{k: data[k].float() for k in self.config.log_agent_stats_keys},
denominator="agent",
)

global_stats = stats_tracker.export(
reduce_group=self.engine.data_parallel_group
)
for k in global_denominators:
keys = list(global_stats.keys())
for k2 in keys:
if k2.endswith(k):
global_stats.pop(k2)
########## Logging code ends ##########

for key in ["rewards", "tot_rewards", "kl_rewards", "versions"]:
data.pop(key, None)
# NOTE: calling engine.train() is critical to enabling gradient checkpointing
self.engine.train()
mb_inputs = split_padded_tensor_dict_into_mb_list(
data,
mb_spec=MicroBatchSpec(n_mbs=self.config.ppo_n_minibatches),
)
for mb in mb_inputs.mbs:
train_stat = self.engine.train_batch(
mb,
loss_fn=functools.partial(
grpo_loss_fn,
temperature=self.temperature,
eps_clip=self.config.eps_clip,
eps_clip_higher=self.config.eps_clip_higher,
c_clip=self.config.c_clip,
behav_imp_weight_cap=self.config.behav_imp_weight_cap,
),
loss_weight_fn=lambda x: x["loss_mask"].count_nonzero(),
)
stats_tracker.scalar(**train_stat)
all_stats.append(
stats_tracker.export(reduce_group=self.engine.data_parallel_group)
)
all_stats[0].update(global_stats)
return all_stats


def grpo_loss_fn(
logits: torch.Tensor,
@@ -364,3 +583,63 @@ def grpo_loss_fn(
denominator="clipped_tokens",
)
return loss

def dynamic_sampling_dense_reward(
data: Dict[str, Any], group_size: int
) -> Tuple[Dict[str, Any], Dict[str, int]]:
"""Filter samples by group when all rewards in a group are equal.

Assumes samples of the same group are adjacent in the batch.

Returns a new dict containing only kept samples (mask applied on batch dim
for all tensor values whose first dimension equals batch size), and a small
stats dict.
"""
rewards = data["rewards"]
if not torch.is_tensor(rewards):
raise TypeError("data['rewards'] must be a torch.Tensor")
batch_size = rewards.shape[0]

if group_size <= 0:
warnings.warn("group_size <= 0; returning original data")
return data, dict(n_group_kept=0, n_group_filtered=0)

if batch_size % group_size != 0:
warnings.warn(
"The group size is not divisible by the batch size. Return the original data"
)
return data, dict(
n_group_kept=batch_size // max(group_size, 1), n_group_filtered=0
)

# Calculate number of groups (must be divisible)
num_groups = batch_size // group_size

# Reshape rewards to (num_groups, group_size) for group-wise operations
rewards_reshaped = rewards.view(num_groups, group_size * rewards.shape[1])

# Check if all elements in each group are equal to the first element
all_equal = (rewards_reshaped == rewards_reshaped[:, 0:1]).all(dim=1)

# Create mask for groups to keep (where not all rewards are equal)
valid_groups = ~all_equal

# Expand the group mask to individual samples
mask = valid_groups.repeat_interleave(group_size)

# If every group is filtered out, return the original data (there is no gradient signal in this case)
if not mask.any():
return data, dict(n_group_kept=0, n_group_filtered=num_groups)

n_group_kept = int(valid_groups.sum().item())
n_group_filtered = int(num_groups - n_group_kept)

# Apply mask row-wise across tensors that share the same batch dimension
filtered: Dict[str, Any] = {}
for k, v in data.items():
if torch.is_tensor(v) and v.shape[:1] == (batch_size,):
filtered[k] = v[mask]
else:
# keep untouched (e.g., scalars, metadata); caller should ensure consistency
filtered[k] = v
return filtered, dict(n_group_kept=n_group_kept, n_group_filtered=n_group_filtered)
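
A brief usage sketch of the group filter above (hypothetical values; assumes the function is importable as areal.engine.ppo.actor.dynamic_sampling_dense_reward):

import torch

from areal.engine.ppo.actor import dynamic_sampling_dense_reward

# Two groups of two samples, three dense reward entries each.
# Group 0 has identical rewards everywhere and is filtered out; group 1 is kept.
data = {
    "rewards": torch.tensor([[1.0, 1.0, 1.0],
                             [1.0, 1.0, 1.0],
                             [0.0, 1.0, 0.0],
                             [1.0, 0.0, 1.0]]),
    "input_ids": torch.zeros(4, 3, dtype=torch.long),
}
filtered, stats = dynamic_sampling_dense_reward(data, group_size=2)
print(stats)                      # {'n_group_kept': 1, 'n_group_filtered': 1}
print(filtered["rewards"].shape)  # torch.Size([2, 3])
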
2 changes: 1 addition & 1 deletion areal/utils/launcher.py
@@ -11,7 +11,7 @@

logger = logging.getLogger("Launcher Utils")

LOCAL_CACHE_DIR = "/tmp/areal"
LOCAL_CACHE_DIR = "/data/yl/AReaL/tmp/areal"
PYTORCH_KERNEL_CACHE_PATH = (
f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels/"
)
10 changes: 5 additions & 5 deletions areal/utils/stats_logger.py
@@ -20,11 +20,11 @@
class StatsLogger:

def __init__(self, config: BaseExperimentConfig, ft_spec: FinetuneSpec):
if not isinstance(config, StatsLoggerConfig):
raise ValueError(
"Passing config.stats_logger as the config is deprecated. "
"Please pass the full config instead."
)
# if not isinstance(config, StatsLoggerConfig):
# raise ValueError(
# "Passing config.stats_logger as the config is deprecated. "
# "Please pass the full config instead."
# )
self.exp_config = config
self.config = config.stats_logger
self.ft_spec = ft_spec