@@ -94,14 +94,15 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP
9494#pragma unroll
9595 for (int i2 = 0 ; i2 < nvec; ++i2) {
9696 const int row = tile_row + i1 * nvec + i2;
97+ size_t row_offset = static_cast <size_t >(row) * row_length;
9798 const int col = tile_col + j1 * nvec;
9899 Vec local_input;
99100 Vec local_output;
100101 local_input.clear ();
101102 if (row < num_rows) {
102103 for (int j2 = 0 ; j2 < nvec; ++j2) {
103104 if (col + j2 < row_length) {
104- local_input.data .elt [j2] = input[row * row_length + col + j2];
105+ local_input.data .elt [j2] = input[row_offset + col + j2];
105106 }
106107 }
107108 }
@@ -112,14 +113,14 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP
112113 if (row < num_rows) {
113114 for (int j2 = 0 ; j2 < nvec; ++j2) {
114115 if (col + j2 < row_length) {
115- output[row * row_length + col + j2] = local_output.data .elt [j2];
116+ output[row_offset + col + j2] = local_output.data .elt [j2];
116117 }
117118 }
118119 } else if (row < padded_num_rows) {
119120 // padding
120121 for (int j2 = 0 ; j2 < nvec; ++j2) {
121122 if (col + j2 < row_length) {
122- output[row * row_length + col + j2] = local_zero;
123+ output[row_offset + col + j2] = local_zero;
123124 }
124125 }
125126 }
@@ -178,14 +179,15 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult
178179#pragma unroll
179180 for (int i2 = 0 ; i2 < nvec; ++i2) {
180181 const int row = tile_row + i1 * nvec + i2;
182+ size_t row_offset = static_cast <size_t >(row) * row_length;
181183 const int col = tile_col + j1 * nvec;
182184 Vec local_input;
183185 Vec local_output;
184186 local_input.clear ();
185187 if (row < num_rows) {
186188 for (int j2 = 0 ; j2 < nvec; ++j2) {
187189 if (col + j2 < row_length) {
188- local_input.data .elt [j2] = input[row * row_length + col + j2];
190+ local_input.data .elt [j2] = input[row_offset + col + j2];
189191 }
190192 }
191193 }
@@ -196,7 +198,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult
196198 if (row < num_rows) {
197199 for (int j2 = 0 ; j2 < nvec; ++j2) {
198200 if (col + j2 < row_length) {
199- output[row * row_length + col + j2] = local_output.data .elt [j2];
201+ output[row_offset + col + j2] = local_output.data .elt [j2];
200202 }
201203 }
202204 }
0 commit comments