From 03fdb9fde0fe93514b56ce09483f6cfe13e36517 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Mon, 5 Jan 2026 11:42:53 -0800 Subject: [PATCH 1/2] Fix long compile time in padding.cu Signed-off-by: Jeremy Berchtold --- transformer_engine/common/util/padding.cu | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/transformer_engine/common/util/padding.cu b/transformer_engine/common/util/padding.cu index 1859d8a5cb..ebf44c80ae 100644 --- a/transformer_engine/common/util/padding.cu +++ b/transformer_engine/common/util/padding.cu @@ -94,7 +94,6 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP #pragma unroll for (int i2 = 0; i2 < nvec; ++i2) { const int row = tile_row + i1 * nvec + i2; - size_t row_offset = static_cast<size_t>(row) * row_length; const int col = tile_col + j1 * nvec; Vec local_input; Vec local_output; @@ -102,7 +101,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP if (row < num_rows) { for (int j2 = 0; j2 < nvec; ++j2) { if (col + j2 < row_length) { - local_input.data.elt[j2] = input[row_offset + col + j2]; + local_input.data.elt[j2] = input[static_cast<size_t>(row)*row_length + col + j2]; } } } @@ -113,14 +112,14 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP if (row < num_rows) { for (int j2 = 0; j2 < nvec; ++j2) { if (col + j2 < row_length) { - output[row_offset + col + j2] = local_output.data.elt[j2]; + output[static_cast<size_t>(row)*row_length + col + j2] = local_output.data.elt[j2]; } } } else if (row < padded_num_rows) { // padding for (int j2 = 0; j2 < nvec; ++j2) { if (col + j2 < row_length) { - output[row_offset + col + j2] = local_zero; + output[static_cast<size_t>(row)*row_length + col + j2] = local_zero; } } } @@ -179,7 +178,6 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult #pragma unroll for (int i2 = 0; i2 < nvec; ++i2) { const int row = tile_row + i1 * nvec + i2; - size_t row_offset = 
static_cast<size_t>(row) * row_length; const int col = tile_col + j1 * nvec; Vec local_input; Vec local_output; @@ -187,7 +185,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult if (row < num_rows) { for (int j2 = 0; j2 < nvec; ++j2) { if (col + j2 < row_length) { - local_input.data.elt[j2] = input[row_offset + col + j2]; + local_input.data.elt[j2] = input[static_cast<size_t>(row)*row_length + col + j2]; } } } @@ -198,7 +196,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult if (row < num_rows) { for (int j2 = 0; j2 < nvec; ++j2) { if (col + j2 < row_length) { - output[row_offset + col + j2] = local_output.data.elt[j2]; + output[static_cast<size_t>(row)*row_length + col + j2] = local_output.data.elt[j2]; } } } From dc6452d3f41b0b283ef619910a33d1e228035947 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 5 Jan 2026 19:53:38 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/util/padding.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/transformer_engine/common/util/padding.cu b/transformer_engine/common/util/padding.cu index ebf44c80ae..8359238289 100644 --- a/transformer_engine/common/util/padding.cu +++ b/transformer_engine/common/util/padding.cu @@ -101,7 +101,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP if (row < num_rows) { for (int j2 = 0; j2 < nvec; ++j2) { if (col + j2 < row_length) { - local_input.data.elt[j2] = input[static_cast<size_t>(row)*row_length + col + j2]; + local_input.data.elt[j2] = input[static_cast<size_t>(row) * row_length + col + j2]; } } } @@ -112,14 +112,14 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP if (row < num_rows) { for (int j2 = 0; j2 < nvec; ++j2) { if (col + j2 < row_length) { - output[static_cast<size_t>(row)*row_length + col 
+ j2] = local_output.data.elt[j2]; + output[static_cast<size_t>(row) * row_length + col + j2] = local_output.data.elt[j2]; } } } else if (row < padded_num_rows) { // padding for (int j2 = 0; j2 < nvec; ++j2) { if (col + j2 < row_length) { - output[static_cast<size_t>(row)*row_length + col + j2] = local_zero; + output[static_cast<size_t>(row) * row_length + col + j2] = local_zero; } } } @@ -185,7 +185,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult if (row < num_rows) { for (int j2 = 0; j2 < nvec; ++j2) { if (col + j2 < row_length) { - local_input.data.elt[j2] = input[static_cast<size_t>(row)*row_length + col + j2]; + local_input.data.elt[j2] = input[static_cast<size_t>(row) * row_length + col + j2]; } } } @@ -196,7 +196,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult if (row < num_rows) { for (int j2 = 0; j2 < nvec; ++j2) { if (col + j2 < row_length) { - output[static_cast<size_t>(row)*row_length + col + j2] = local_output.data.elt[j2]; + output[static_cast<size_t>(row) * row_length + col + j2] = local_output.data.elt[j2]; } } }