diff --git a/kernel_tuner/backends/pycuda.py b/kernel_tuner/backends/pycuda.py index 81d8de67..c8f3e689 100644 --- a/kernel_tuner/backends/pycuda.py +++ b/kernel_tuner/backends/pycuda.py @@ -180,6 +180,9 @@ def ready_argument_list(self, arguments): # pycuda does not support bool, convert to uint8 instead elif isinstance(arg, np.bool_): gpu_args.append(arg.astype(np.uint8)) + # pycuda does not support 16-bit formats, view them as uint16 + elif isinstance(arg, np.generic) and str(arg.dtype) in ("float16", "bfloat16"): + gpu_args.append(arg.view(np.uint16)) # if not an array, just pass argument along else: gpu_args.append(arg)