Commit eb8498c

Update transformer_engine/pytorch/triton/permutation.py
Co-authored-by: Teddy Do <tdophung@nvidia.com>
Signed-off-by: xiaoxi-wangfj <690912414@qq.com>
1 parent a1f0662 commit eb8498c

File tree

1 file changed: +10 −1 lines changed


transformer_engine/pytorch/triton/permutation.py

Lines changed: 10 additions & 1 deletion
```diff
@@ -454,7 +454,16 @@ def permute_with_mask_map(
             (num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda"
         )
     else:
-        permuted_scale = torch.empty(
+        alloc = torch.zeros if pad_offsets is not None else torch.empty
+        output = alloc((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda")
+        permuted_probs = (
+            alloc((num_out_tokens,), dtype=probs.dtype, device="cuda") if probs is not None else None
+        )
+        permuted_scale = (
+            alloc((num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda")
+            if scale is not None
+            else None
+        )
             (num_out_tokens, scale_hidden_dim), dtype=scale.dtype, device="cuda"
         )
     else:
```
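The core idea of this change is to pick the allocator once, up front: `torch.zeros` when `pad_offsets` is provided (so output rows the permutation kernel never writes are deterministically zero) and the cheaper uninitialized `torch.empty` otherwise. A minimal standalone sketch of that pattern follows; the function name and signature here are illustrative, not the actual `permute_with_mask_map` API, and it allocates on CPU for portability rather than hard-coding `device="cuda"` as the diff does:

```python
import torch


def allocate_output(num_out_tokens, hidden_size, dtype,
                    pad_offsets=None, device="cpu"):
    """Allocate the permuted-output buffer.

    When pad_offsets is set, some rows of the output are padding slots
    that the permutation kernel will never write, so the buffer must be
    zero-initialized. Without padding, every row is overwritten, and
    skipping initialization with torch.empty is cheaper.
    """
    alloc = torch.zeros if pad_offsets is not None else torch.empty
    return alloc((num_out_tokens, hidden_size), dtype=dtype, device=device)


# Padded case: rows left untouched by the kernel are guaranteed zero.
padded = allocate_output(4, 8, torch.float32, pad_offsets=[2])
assert torch.all(padded == 0)

# Unpadded case: same shape, but contents are uninitialized memory.
unpadded = allocate_output(4, 8, torch.float32)
assert unpadded.shape == (4, 8)
```

Selecting the allocator via a single `alloc = ... if ... else ...` expression, as the commit does, keeps the three buffer allocations (`output`, `permuted_probs`, `permuted_scale`) uniform instead of repeating the padding condition at each call site.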
