Skip to content

Commit 5806286

Browse files
authored
ggml : use WARP_SIZE/2 for argmax reduction offset (#18092)
1 parent 2973a65 commit 5806286

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

ggml/src/ggml-cuda/argmax.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
2121
}
2222

2323
#pragma unroll
24-
for (int offset = 16; offset > 0; offset >>= 1) {
24+
for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
2525
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
2626
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
2727
if (val > maxval) {
@@ -50,7 +50,7 @@ static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __rest
5050
argmax = shared_argmax[lane_id];
5151
}
5252
#pragma unroll
53-
for (int offset = 16; offset > 0; offset >>= 1) {
53+
for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
5454
const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
5555
const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
5656
if (val > maxval) {

0 commit comments

Comments
 (0)