diff --git a/applications/ColossalChat/coati/models/loss.py b/applications/ColossalChat/coati/models/loss.py index 927dfd5a89b6..04cb6ef24f95 100755 --- a/applications/ColossalChat/coati/models/loss.py +++ b/applications/ColossalChat/coati/models/loss.py @@ -207,9 +207,14 @@ def forward( ) -> torch.Tensor: chosen_logp = chosen_logp.to(dtype=torch.float32) reject_logp = reject_logp.to(dtype=torch.float32) - chosen_odds = chosen_logp - torch.log(-torch.exp(chosen_logp) + 1.0001) + # Clamp log-probabilities to avoid numerical instability in log-odds computation. + # log_odds = log(p) - log(1 - p) = logp - log1p(-exp(logp)) + eps = 1e-6 + chosen_logp = torch.clamp(chosen_logp, max=-eps) + reject_logp = torch.clamp(reject_logp, max=-eps) + chosen_odds = chosen_logp - torch.log1p(-torch.exp(chosen_logp)) chosen_odds_masked = torch.sum(chosen_odds * chosen_loss_mask.float()) / torch.sum(chosen_loss_mask) - reject_odds = reject_logp - torch.log(-torch.exp(reject_logp) + 1.0001) + reject_odds = reject_logp - torch.log1p(-torch.exp(reject_logp)) reject_odds_masked = torch.sum(reject_odds * reject_loss_mask.float()) / torch.sum(reject_loss_mask) log_odds_ratio = chosen_odds_masked - reject_odds_masked ratio = torch.log(torch.nn.functional.sigmoid(log_odds_ratio))