From d49bb938030e90a13a7edfc32d155a5c265243b9 Mon Sep 17 00:00:00 2001
From: Tom
Date: Thu, 29 Jan 2026 19:18:58 +0100
Subject: [PATCH] Make generated kernels thread-safe using std::atomic

---
 utils/codegen_tl1.py | 13 ++++++++++---
 utils/codegen_tl2.py | 14 ++++++++++----
 2 files changed, 21 insertions(+), 6 deletions(-)

diff --git a/utils/codegen_tl1.py b/utils/codegen_tl1.py
index 4c2e7dd3f..a8876fffd 100644
--- a/utils/codegen_tl1.py
+++ b/utils/codegen_tl1.py
@@ -5,10 +5,11 @@ def gen_ctor_code():
     kernel_code = "\n\
 #include \"ggml-bitnet.h\"\n\
+#include <atomic>\n\
 #define GGML_BITNET_MAX_NODES 8192\n\
 static bool initialized = false;\n\
 static bitnet_tensor_extra * bitnet_tensor_extras = nullptr;\n\
-static size_t bitnet_tensor_extras_index = 0;\n\
+static std::atomic<size_t> bitnet_tensor_extras_index{{0}};\n\
 static void * aligned_malloc(size_t size) {{\n\
 #if defined(_WIN32)\n\
     return _aligned_malloc(size, 64);\n\
 #else\n\
@@ -355,14 +356,20 @@ def gen_transform_code(kernel_shape):
     float * i2_scales = (float * )(qweights + k * m / 4);\n\
     scales[0] = (bitnet_float_type) i2_scales[0];\n\
 \n\
-    tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index;\n\
-    bitnet_tensor_extras[bitnet_tensor_extras_index++] = {\n\
+    size_t current_index = bitnet_tensor_extras_index.fetch_add(1);\n\
+    if (current_index >= GGML_BITNET_MAX_NODES) {\n\
+        fprintf(stderr, \"ggml_bitnet_transform_tensor: GGML_BITNET_MAX_NODES reached (%d)\\n\", GGML_BITNET_MAX_NODES);\n\
+        return;\n\
+    }\n\
+\n\
+    tensor->extra = bitnet_tensor_extras + current_index;\n\
+    bitnet_tensor_extras[current_index] = {\n\
         /* .lut_scales_size = */ lut_scales_size,\n\
         /* .BK = */ BK,\n\
         /* .n_tile_num = */ n_tile_num,\n\
         /* .qweights = */ qweights,\n\
         /* .scales = */ scales\n\
     };\n\
 }\n"])
 
     return kernel_code
diff --git a/utils/codegen_tl2.py b/utils/codegen_tl2.py
index 4d9408123..2b9e3fa75 100644
--- a/utils/codegen_tl2.py
+++ b/utils/codegen_tl2.py
@@ -6,11 +6,13 @@ def gen_ctor_code():
     kernel_code = "\n\
 #include \"ggml-bitnet.h\"\n\
 #include <immintrin.h>\n\
+#include <atomic>\n\
+#include <cstdio>\n\
 #include <string.h>\n\
 #define GGML_BITNET_MAX_NODES 8192\n\
 static bool initialized = false;\n\
 static bitnet_tensor_extra * bitnet_tensor_extras = nullptr;\n\
-static size_t bitnet_tensor_extras_index = 0;\n\
+static std::atomic<size_t> bitnet_tensor_extras_index{0};\n\
 static void * aligned_malloc(size_t size) {\n\
 #if defined(_WIN32)\n\
     return _aligned_malloc(size, 64);\n\
@@ -661,14 +663,20 @@ def gen_transform_code(kernel_shapes):
     float * i2_scales = (float * )(qweights + nbytes);\n\
     scales[0] = (bitnet_float_type) i2_scales[0];\n\
 \n\
-    tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index;\n\
-    bitnet_tensor_extras[bitnet_tensor_extras_index++] = {\n\
+    size_t current_index = bitnet_tensor_extras_index.fetch_add(1);\n\
+    if (current_index >= GGML_BITNET_MAX_NODES) {\n\
+        fprintf(stderr, \"ggml_bitnet_transform_tensor: GGML_BITNET_MAX_NODES reached (%d)\\n\", GGML_BITNET_MAX_NODES);\n\
+        return;\n\
+    }\n\
+\n\
+    tensor->extra = bitnet_tensor_extras + current_index;\n\
+    bitnet_tensor_extras[current_index] = {\n\
         /* .lut_scales_size = */ lut_scales_size,\n\
         /* .BK = */ BK,\n\
         /* .n_tile_num = */ n_tile_num,\n\
         /* .qweights = */ qweights,\n\
         /* .scales = */ scales\n\
     };\n\
 }\n"])
 
     return kernel_code

Note: brace escaping follows each generated string's own convention. In codegen_tl1.py the ctor string is later passed through str.format() (its context lines already use "{{"), so the atomic initializer must be written "{{0}}" there; in codegen_tl2.py the ctor string is emitted verbatim (context lines use single braces), so "{0}" is correct as-is. The transform-tensor tail in both files is an unformatted string segment (its closing "}" line is a single brace), so the added if-block and the struct initializer keep single braces rather than doubled ones.
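
For reference, the patch replaces a racy "index++" with an atomic claim-then-check. Below is a minimal standalone C++ sketch of that pattern as the generated kernels now use it; the names (claim_slot, MAX_NODES, next_index) are illustrative stand-ins, not identifiers from the repo.

// Claim-then-check slot allocation, sketching the pattern from the patch.
// All names here are hypothetical; only the technique matches the diff above.
#include <atomic>
#include <cstdio>

static const int MAX_NODES = 8192;
static std::atomic<size_t> next_index{0};

// fetch_add returns the pre-increment value atomically, so no two callers
// can ever receive the same slot; claims past the end of the table are
// rejected after the fact, before the table is touched.
static bool claim_slot(size_t * out) {
    size_t idx = next_index.fetch_add(1);
    if (idx >= MAX_NODES) {
        fprintf(stderr, "claim_slot: MAX_NODES reached (%d)\n", MAX_NODES);
        return false;
    }
    *out = idx;
    return true;
}

int main() {
    size_t a, b;
    if (claim_slot(&a) && claim_slot(&b)) {
        printf("claimed slots %zu and %zu\n", a, b);
    }
    return 0;
}

Because the bounds check happens after the increment, the counter can drift past MAX_NODES under contention, but every over-limit claim bails out before writing to the table, which is all ggml_bitnet_transform_tensor needs.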