File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed
Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -21,7 +21,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
2121 }
2222
2323#pragma unroll
24- for (int offset = 16 ; offset > 0 ; offset >>= 1 ) {
24+ for (int offset = WARP_SIZE/ 2 ; offset > 0 ; offset >>= 1 ) {
2525 const float val = __shfl_xor_sync (0xFFFFFFFF , maxval, offset, WARP_SIZE);
2626 const int col = __shfl_xor_sync (0xFFFFFFFF , argmax, offset, WARP_SIZE);
2727 if (val > maxval) {
@@ -50,7 +50,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
5050 argmax = shared_argmax[lane_id];
5151 }
5252#pragma unroll
53- for (int offset = 16 ; offset > 0 ; offset >>= 1 ) {
53+ for (int offset = WARP_SIZE/ 2 ; offset > 0 ; offset >>= 1 ) {
5454 const float val = __shfl_xor_sync (0xFFFFFFFF , maxval, offset, WARP_SIZE);
5555 const int col = __shfl_xor_sync (0xFFFFFFFF , argmax, offset, WARP_SIZE);
5656 if (val > maxval) {
You can’t perform that action at this time.
0 commit comments