From 74c7355611fdc7917d6dc83ed4178c70b3f7d53a Mon Sep 17 00:00:00 2001 From: Mr-Neutr0n <64578610+Mr-Neutr0n@users.noreply.github.com> Date: Wed, 11 Feb 2026 23:50:01 +0530 Subject: [PATCH] fix: improve numerical stability of log-odds in ORPO loss --- applications/ColossalChat/coati/models/loss.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/applications/ColossalChat/coati/models/loss.py b/applications/ColossalChat/coati/models/loss.py index 927dfd5a89b6..04cb6ef24f95 100755 --- a/applications/ColossalChat/coati/models/loss.py +++ b/applications/ColossalChat/coati/models/loss.py @@ -207,9 +207,14 @@ def forward( ) -> torch.Tensor: chosen_logp = chosen_logp.to(dtype=torch.float32) reject_logp = reject_logp.to(dtype=torch.float32) - chosen_odds = chosen_logp - torch.log(-torch.exp(chosen_logp) + 1.0001) + # Clamp log-probabilities to avoid numerical instability in log-odds computation. + # log_odds = log(p) - log(1 - p) = logp - log1p(-exp(logp)) + eps = 1e-6 + chosen_logp = torch.clamp(chosen_logp, max=-eps) + reject_logp = torch.clamp(reject_logp, max=-eps) + chosen_odds = chosen_logp - torch.log1p(-torch.exp(chosen_logp)) chosen_odds_masked = torch.sum(chosen_odds * chosen_loss_mask.float()) / torch.sum(chosen_loss_mask) - reject_odds = reject_logp - torch.log(-torch.exp(reject_logp) + 1.0001) + reject_odds = reject_logp - torch.log1p(-torch.exp(reject_logp)) reject_odds_masked = torch.sum(reject_odds * reject_loss_mask.float()) / torch.sum(reject_loss_mask) log_odds_ratio = chosen_odds_masked - reject_odds_masked ratio = torch.log(torch.nn.functional.sigmoid(log_odds_ratio))