diff --git a/kernel_tuner/core.py b/kernel_tuner/core.py index 04c141be..3086983e 100644 --- a/kernel_tuner/core.py +++ b/kernel_tuner/core.py @@ -815,9 +815,9 @@ def _default_verify_function(instance, answer, result_host, atol, verbose): if isinstance(answer[i], (np.ndarray, cp.ndarray)) and isinstance( arg, (np.ndarray, cp.ndarray) ): - if answer[i].dtype != arg.dtype: + if not np.can_cast(arg.dtype, answer[i].dtype): raise TypeError( - f"Element {i} of the expected results list is not of the same dtype as the kernel output: " + f"Element {i} of the expected results list has a dtype that is not compatible with the dtype of the kernel output: " + str(answer[i].dtype) + " != " + str(arg.dtype)