Skip to content

Commit 697b52c

Browse files
adamantboyksivaman
andauthored
Fix overflow of padding/unpadding kernel (#2548)
Signed-off-by: fuyue.lj <fuyue.lj@antgroup.com> Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
1 parent 26c82db commit 697b52c

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

transformer_engine/common/util/padding.cu

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,15 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP
9494
#pragma unroll
9595
for (int i2 = 0; i2 < nvec; ++i2) {
9696
const int row = tile_row + i1 * nvec + i2;
97+
size_t row_offset = static_cast<size_t>(row) * row_length;
9798
const int col = tile_col + j1 * nvec;
9899
Vec local_input;
99100
Vec local_output;
100101
local_input.clear();
101102
if (row < num_rows) {
102103
for (int j2 = 0; j2 < nvec; ++j2) {
103104
if (col + j2 < row_length) {
104-
local_input.data.elt[j2] = input[row * row_length + col + j2];
105+
local_input.data.elt[j2] = input[row_offset + col + j2];
105106
}
106107
}
107108
}
@@ -112,14 +113,14 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP
112113
if (row < num_rows) {
113114
for (int j2 = 0; j2 < nvec; ++j2) {
114115
if (col + j2 < row_length) {
115-
output[row * row_length + col + j2] = local_output.data.elt[j2];
116+
output[row_offset + col + j2] = local_output.data.elt[j2];
116117
}
117118
}
118119
} else if (row < padded_num_rows) {
119120
// padding
120121
for (int j2 = 0; j2 < nvec; ++j2) {
121122
if (col + j2 < row_length) {
122-
output[row * row_length + col + j2] = local_zero;
123+
output[row_offset + col + j2] = local_zero;
123124
}
124125
}
125126
}
@@ -178,14 +179,15 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult
178179
#pragma unroll
179180
for (int i2 = 0; i2 < nvec; ++i2) {
180181
const int row = tile_row + i1 * nvec + i2;
182+
size_t row_offset = static_cast<size_t>(row) * row_length;
181183
const int col = tile_col + j1 * nvec;
182184
Vec local_input;
183185
Vec local_output;
184186
local_input.clear();
185187
if (row < num_rows) {
186188
for (int j2 = 0; j2 < nvec; ++j2) {
187189
if (col + j2 < row_length) {
188-
local_input.data.elt[j2] = input[row * row_length + col + j2];
190+
local_input.data.elt[j2] = input[row_offset + col + j2];
189191
}
190192
}
191193
}
@@ -196,7 +198,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult
196198
if (row < num_rows) {
197199
for (int j2 = 0; j2 < nvec; ++j2) {
198200
if (col + j2 < row_length) {
199-
output[row * row_length + col + j2] = local_output.data.elt[j2];
201+
output[row_offset + col + j2] = local_output.data.elt[j2];
200202
}
201203
}
202204
}

0 commit comments

Comments
 (0)