Merged
4 changes: 2 additions & 2 deletions fast_llm/functional/entropy_loss.py
@@ -121,7 +121,7 @@ def _fused_reverse_kl_base(
     # Compute loss terms: student_probs * log_ratio, then sum over vocab
     # This is equivalent to kl_div(..., log_target=True) but more memory efficient
     log_ratio = predicted_log_probability - target_log_probability
-    per_sample_loss = (predicted_probability * log_ratio).sum(dim=-1)
+    per_sample_loss = (predicted_probability * log_ratio).sum(dim=-1, keepdim=True)
Collaborator
I don't see how this changes anything; keepdim is the same as unsqueeze?

@oleksost (Contributor Author), Jan 26, 2026
The issue is with the reported loss; the gradients are fine.

When we later multiply the loss by the mask, `per_sample_loss = per_sample_loss * loss_mask.unsqueeze(-1)`, `per_sample_loss` has shape `[N]` and the mask has shape `[N, 1]`. Torch broadcasting left-pads the missing dimensions with 1, so this is equivalent to `per_sample_loss.unsqueeze(0) * loss_mask.unsqueeze(-1)`, which produces an `N x N` tensor.

Taking the average of that `N x N` tensor is then much larger than taking the average of the intended `per_sample_loss * loss_mask`.
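
As a toy illustration of the broadcasting described above (shapes and values are made up for illustration, not taken from the repo):

```python
import torch

# Hypothetical shapes: 4 samples, with the mask zeroing out one of them.
per_sample_loss = torch.tensor([0.1, 0.2, 0.3, 0.4])  # shape [N]
loss_mask = torch.tensor([1.0, 1.0, 0.0, 1.0])        # shape [N]

# Without keepdim: [N] * [N, 1] broadcasts to [N, N].
bad = per_sample_loss * loss_mask.unsqueeze(-1)
# With keepdim=True the loss is already [N, 1], so the product stays [N, 1].
good = per_sample_loss.unsqueeze(-1) * loss_mask.unsqueeze(-1)

print(bad.shape, good.shape)    # torch.Size([4, 4]) torch.Size([4, 1])
print(bad.mean(), good.mean())  # the two means differ
```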

     if group is not None:
         all_reduce(per_sample_loss, op=ReduceOp.SUM, group=group)

@@ -130,7 +130,7 @@ def _fused_reverse_kl_base(
     else:
         # Gradient: d/d(logits) KL(q||p) = q * (log(q/p) - E_q[log(q/p)])
         # where E_q[log(q/p)] is the expected log ratio under the student distribution
-        grad = (log_ratio - per_sample_loss.unsqueeze(-1)) * predicted_probability * grad_output
+        grad = (log_ratio - per_sample_loss) * predicted_probability * grad_output

     return per_sample_loss, grad
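
For reference, a minimal standalone sketch (not the repo's code; the names mirror the diff above) that checks the analytical gradient q * (log(q/p) - E_q[log(q/p)]) against autograd, with the per-sample loss kept as [N, 1] so the grad line broadcasts correctly:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, V = 4, 16
logits = torch.randn(N, V, dtype=torch.float64, requires_grad=True)
target_log_probability = F.log_softmax(torch.randn(N, V, dtype=torch.float64), dim=-1)

# Forward: reverse KL(q || p) per sample, where q = softmax(logits) is the student.
predicted_log_probability = F.log_softmax(logits, dim=-1)
predicted_probability = predicted_log_probability.exp()
log_ratio = predicted_log_probability - target_log_probability
per_sample_loss = (predicted_probability * log_ratio).sum(dim=-1, keepdim=True)  # shape [N, 1]

# Autograd gradient of the mean loss with respect to the logits.
per_sample_loss.mean().backward()

# Analytical gradient; the grad_output for a mean reduction is 1 / N.
with torch.no_grad():
    grad = (log_ratio - per_sample_loss) * predicted_probability / N
print(torch.allclose(logits.grad, grad))  # True
```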

3 changes: 1 addition & 2 deletions tests/functional/test_entropy_loss.py
@@ -82,14 +82,13 @@ def test_entropy_loss(num_columns, grad_output, logits_scale_factor, loss_maskin
     out_torch, grad_torch = entropy_loss_forward_backward(**kwargs, implementation=EntropyLossImplementation.torch)
     out_fused, grad_fused = entropy_loss_forward_backward(**kwargs, implementation=EntropyLossImplementation.fused)

-    # TODO: Why is the error so high with loss masking for reverse KL?
     _compare_entropy_loss_outputs(
         out_fused,
         out_torch,
         grad_output is not None,
         grad_fused,
         grad_torch,
-        loss_min_threshold=2e-4 if entropy_loss_type == EntropyLossType.reverse_kl and loss_masking else 5e-6,
+        loss_min_threshold=5e-6,
     )

     if entropy_loss_type != EntropyLossType.cross_entropy or not torch.cuda.is_available():