Skip to content

Commit 36f2dfd

Browse files
fix ce loss calculation when some tokens are ignored (#2476)
* fix ce loss with ignore idx Signed-off-by: ykarnati <ykarnati@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: ykarnati <ykarnati@nvidia.com> * remove fix comments Signed-off-by: ykarnati <ykarnati@nvidia.com> * fallback divisor to 1 Signed-off-by: ykarnati <ykarnati@nvidia.com> * have arg for n_rows and n_non_ignore Signed-off-by: ykarnati <ykarnati@nvidia.com> * fuse n_non_ignore to softmax kernel Signed-off-by: ykarnati <ykarnati@nvidia.com> * fix incorrect arg Signed-off-by: ykarnati <ykarnati@nvidia.com> --------- Signed-off-by: ykarnati <ykarnati@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 8c9f7c2 commit 36f2dfd

File tree

3 files changed

+37
-5
lines changed

3 files changed

+37
-5
lines changed

tests/pytorch/test_parallel_cross_entropy.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def one_iteration_test(
8989
# Check that loss and grad input match
9090
tols = dtype_tols(dtype)
9191
test_loss = test_loss.to(dtype=torch.float64, device="cpu")
92-
ref_loss = test_loss.to(dtype=torch.float64, device="cpu")
92+
ref_loss = ref_loss.to(dtype=torch.float64, device="cpu")
9393
ref_loss = ref_loss.reshape(test_loss.size())
9494
test_grad_input = self.input_test.grad.to(dtype=torch.float64, device="cpu")
9595
ref_grad_input = self.input_ref.grad.to(dtype=torch.float64, device="cpu")
@@ -154,3 +154,16 @@ def test_ignore_idx(self):
154154
reduce_loss=False,
155155
ignore_idx=True,
156156
)
157+
158+
def test_ignore_idx_reduced_loss(self):
159+
"""Test ignore_idx with reduce_loss=True"""
160+
self.generate_iters(5)
161+
self.generate_infra(True, 0) # reduce_loss=True
162+
for i in range(self.iters):
163+
self.one_iteration_test(
164+
dtype=torch.float32,
165+
swap_dim=random.choice([True, False]),
166+
label_smoothing=0,
167+
reduce_loss=True,
168+
ignore_idx=True,
169+
)

transformer_engine/common/triton/cross_entropy.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ def online_softmax_kernel(
1818
m_d_X_y_stride,
1919
rank,
2020
n_cols,
21+
ignore_idx,
22+
n_non_ignore,
2123
BLOCK_SIZE: tl.constexpr,
2224
):
2325
"""
@@ -32,6 +34,8 @@ def online_softmax_kernel(
3234
m_d_X_y_stride (int): The stride of the m/d/X_y tensor.
3335
rank (int): The rank of this device in the TP group.
3436
n_cols (int): The number of columns in the input tensor.
37+
ignore_idx (int): The index to ignore for loss calculation.
38+
n_non_ignore: The number of non-ignored elements in the batch.
3539
BLOCK_SIZE (int): The block size for Triton operations.
3640
"""
3741

@@ -44,6 +48,9 @@ def online_softmax_kernel(
4448
Y_ptr += program_id * Y_stride
4549
y = tl.load(Y_ptr)
4650

51+
if y != ignore_idx:
52+
tl.atomic_add(n_non_ignore, 1)
53+
4754
vocab_start_idx = rank * n_cols
4855
vocab_end_idx = (rank + 1) * n_cols
4956
if y >= vocab_start_idx:
@@ -89,6 +96,7 @@ def cross_entropy_kernel(
8996
world_size,
9097
ignore_idx,
9198
n_cols,
99+
n_rows,
92100
n_non_ignore,
93101
reduce_loss: tl.constexpr,
94102
label_smoothing: tl.constexpr,
@@ -110,12 +118,14 @@ def cross_entropy_kernel(
110118
world_size (int): The size of world involved in this distributed loss calculation.
111119
ignore_idx (int): Tokens to be ignored for loss and gradient calculation.
112120
n_cols (int): The number of columns in the input tensor.
113-
n_non_ignore (int): The number of non-ignored elements in the batch.
121+
n_rows (int): The number of rows in the batch (B * SQ), used for buffer indexing.
122+
n_non_ignore: The number of non-ignored elements in the batch.
114123
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
115124
BLOCK_SIZE (int): The block size for Triton operations.
116125
"""
117126

118127
program_id = tl.program_id(0).to(tl.int64)
128+
n_non_ignore = tl.load(n_non_ignore)
119129

120130
# locate the start index
121131
X_ptr += program_id * X_stride
@@ -140,7 +150,7 @@ def cross_entropy_kernel(
140150
ori_X_y = tl.load(m_d_X_y_ptr + (2 * m_d_X_y_stride))
141151

142152
for i in range(1, world_size):
143-
offset = i * 3 * n_non_ignore * m_d_X_y_stride
153+
offset = i * 3 * n_rows * m_d_X_y_stride
144154
access_ptr = m_d_X_y_ptr + offset
145155
m_new = tl.load(access_ptr)
146156
d_new = tl.load(access_ptr + m_d_X_y_stride)

transformer_engine/pytorch/triton/cross_entropy.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ def cross_entropy_forward(
4646
# tensor to hold this rank's m/d/X_y values
4747
m_d_X_y = torch.zeros(n_rows * 3, dtype=torch.float32, device=_input.device)
4848

49+
n_non_ignore = torch.zeros(1, dtype=torch.int64, device=_input.device)
50+
4951
# ensure _input and target are contiguous in the last dimension
5052
if _input.stride(-1) != 1:
5153
_input = _input.contiguous()
@@ -63,10 +65,14 @@ def cross_entropy_forward(
6365
m_d_X_y_stride=m_d_X_y.stride(-1),
6466
rank=rank,
6567
n_cols=V,
68+
ignore_idx=ignore_idx,
69+
n_non_ignore=n_non_ignore,
6670
BLOCK_SIZE=BLOCK_SIZE,
6771
num_warps=32,
6872
)
6973

74+
n_non_ignore = torch.clamp(n_non_ignore, min=1)
75+
7076
world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group)
7177

7278
if world_size > 1:
@@ -90,14 +96,17 @@ def cross_entropy_forward(
9096
world_size=world_size,
9197
ignore_idx=ignore_idx,
9298
n_cols=V,
93-
n_non_ignore=n_rows,
99+
n_rows=n_rows,
100+
n_non_ignore=n_non_ignore,
94101
reduce_loss=reduce_loss,
95102
label_smoothing=label_smoothing,
96103
BLOCK_SIZE=BLOCK_SIZE,
97104
num_warps=32,
98105
)
99106

100-
loss = torch.reshape(loss_1d, (B, SQ)) if not reduce_loss else (torch.sum(loss_1d) / n_rows)
107+
loss = (
108+
torch.reshape(loss_1d, (B, SQ)) if not reduce_loss else (torch.sum(loss_1d) / n_non_ignore)
109+
)
101110

102111
return loss, _input
103112

0 commit comments

Comments (0)