diff --git a/CHANGELOG.md b/CHANGELOG.md index 7e30f6e26..57aaf27b3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ All notable changes to this project will be documented in this file. This project adheres to [Semantic Versioning](http://semver.org/). ## Unreleased +- Changed HIP Python bindings from pyhip-interface to the official hip-python ## [1.0.0] - 2024-04-04 - HIP backend to support tuning HIP kernels on AMD GPUs diff --git a/INSTALL.rst b/INSTALL.rst index 985b48c07..4575f15fa 100644 --- a/INSTALL.rst +++ b/INSTALL.rst @@ -124,31 +124,26 @@ Or you could install Kernel Tuner and PyOpenCL together if you haven't done so a If this fails, please see the PyOpenCL installation guide (https://wiki.tiker.net/PyOpenCL/Installation) -HIP and PyHIP +HIP and HIP Python ------------- -Before we can install PyHIP, you'll need to have the HIP runtime and compiler installed on your system. +Before we can install HIP Python, you'll need to have the HIP runtime and compiler installed on your system. The HIP compiler is included as part of the ROCm software stack. Here is AMD's installation guide: * `ROCm Documentation: HIP Installation Guide `__ -After you've installed HIP, you will need to install PyHIP. Run the following command in your terminal to install: +After you've installed HIP, you will need to install HIP Python. -.. code-block:: bash - - pip install pyhip-interface +First, identify the first three digits of the version number of your ROCm™ installation. +Then install the HIP Python package(s) as follows: -Alternatively, you can install PyHIP from the source code. First, clone the repository from GitHub: - -.. code-block:: bash +.. code-block:: shell - git clone https://github.com/jatinx/PyHIP - -Then, navigate to the repository directory and run the following command to install: - -.. code-block:: bash + python3 -m pip install -i https://test.pypi.org/simple hip-python~=$rocm_version + # If you want to install the CUDA Python interoperability package too, run: + python3 -m pip install -i https://test.pypi.org/simple hip-python-as-cuda~=$rocm_version - python setup.py install +For other installation options, check `hip-python on GitHub `_ Installing the git version -------------------------- @@ -171,7 +166,7 @@ The runtime dependencies are: - `cuda`: install pycuda along with kernel_tuner - `opencl`: install pycuda along with kernel_tuner -- `hip`: install pyhip along with kernel_tuner +- `hip`: install HIP Python along with kernel_tuner - `tutorial`: install packages required to run the guides These can be installed by appending e.g. ``-E cuda -E opencl -E hip``. 
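Since INSTALL.rst now points users at the hip-python wheels pinned to their ROCm version, a quick post-install sanity check can save debugging time. This is a minimal sketch, not part of the documented install steps; it assumes hip-python exposes `hipRuntimeGetVersion` with the usual `(error, value)` return tuple that the rest of this change relies on:

```python
# Sketch: verify that the HIP Python bindings import and the ROCm runtime is reachable.
# Assumes hip-python is installed and hipRuntimeGetVersion returns (err, value).
from hip import hip

err, version = hip.hipRuntimeGetVersion()
if err != hip.hipError_t.hipSuccess:
    raise RuntimeError(f"HIP runtime not usable: {err}")
print(f"HIP runtime version: {version}")
```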
diff --git a/README.md b/README.md index 2c5f8f02b..557bae395 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,7 @@ What Kernel Tuner does: ## Installation -- First, make sure you have your [CUDA](https://kerneltuner.github.io/kernel_tuner/stable/install.html#cuda-and-pycuda), [OpenCL](https://kerneltuner.github.io/kernel_tuner/stable/install.html#opencl-and-pyopencl), or [HIP](https://kerneltuner.github.io/kernel_tuner/stable/install.html#hip-and-pyhipl) compiler installed +- First, make sure you have your [CUDA](https://kerneltuner.github.io/kernel_tuner/stable/install.html#cuda-and-pycuda), [OpenCL](https://kerneltuner.github.io/kernel_tuner/stable/install.html#opencl-and-pyopencl), or [HIP](https://kerneltuner.github.io/kernel_tuner/stable/install.html#hip-and-hip-python) compiler installed - Then type: `pip install kernel_tuner[cuda]`, `pip install kernel_tuner[opencl]`, or `pip install kernel_tuner[hip]` - or why not all of them: `pip install kernel_tuner[cuda,opencl,hip]` diff --git a/doc/source/backends.rst b/doc/source/backends.rst index 551da1fed..d132a9014 100644 --- a/doc/source/backends.rst +++ b/doc/source/backends.rst @@ -58,7 +58,7 @@ used to compile the kernels. :header: Feature, PyCUDA, CuPy, CUDA-Python, HIP :widths: auto - Python package, "pycuda", "cupy", "cuda-python", "pyhip-interface" + Python package, "pycuda", "cupy", "cuda-python", "hip-python" Selected with lang=, "CUDA", "CUPY", "NVCUDA", "HIP" Compiler used, "nvcc", "nvrtc", "nvrtc", "hiprtc" diff --git a/doc/source/design.rst b/doc/source/design.rst index a4078a2e1..aaa640467 100644 --- a/doc/source/design.rst +++ b/doc/source/design.rst @@ -49,7 +49,7 @@ building blocks for implementing runners. The observers are explained in :ref:`observers`. At the bottom, the backends are shown. -PyCUDA, CuPy, cuda-python, PyOpenCL and PyHIP are for tuning either CUDA, OpenCL, or HIP kernels. +PyCUDA, CuPy, cuda-python, PyOpenCL and HIP Python are for tuning either CUDA, OpenCL, or HIP kernels. The CompilerFunctions implementation can call any compiler, typically NVCC or GCC is used. There is limited support for tuning Fortran kernels. 
This backend was created not just to be able to tune C diff --git a/examples/hip/test_vector_add.py b/examples/hip/test_vector_add.py index 6c342632e..2988b1ae9 100644 --- a/examples/hip/test_vector_add.py +++ b/examples/hip/test_vector_add.py @@ -5,11 +5,11 @@ from kernel_tuner import run_kernel import pytest -#Check pyhip is installed and if a HIP capable device is present, if not skip the test +# Check HIP Python is installed and a HIP-capable device is present; if not, skip the test try: - from pyhip import hip, hiprtc + from hip import hip, hiprtc except ImportError: - pytest.skip("PyHIP not installed or PYTHONPATH does not includes PyHIP") + pytest.skip("HIP Python not installed or PYTHONPATH does not include HIP Python") hip = None hiprtc = None diff --git a/examples/hip/vector_add.py b/examples/hip/vector_add.py index 7e2810711..19fde6b6d 100644 --- a/examples/hip/vector_add.py +++ b/examples/hip/vector_add.py @@ -30,8 +30,8 @@ def tune(): tune_params = OrderedDict() tune_params["block_size_x"] = [128+64*i for i in range(15)] - results, env = tune_kernel("vector_add", kernel_string, size, args, tune_params, lang="HIP", - cache="vector_add_cache.json", log=logging.DEBUG) + results, env = tune_kernel("vector_add", kernel_string, size, args, tune_params, lang="HIP", + log=logging.DEBUG) # Store the metadata of this run store_metadata_file("vector_add-metadata.json") @@ -40,4 +40,4 @@ def tune(): if __name__ == "__main__": - tune() \ No newline at end of file + tune() diff --git a/kernel_tuner/backends/compiler.py b/kernel_tuner/backends/compiler.py index 154f501ba..730710489 100644 --- a/kernel_tuner/backends/compiler.py +++ b/kernel_tuner/backends/compiler.py @@ -26,6 +26,17 @@ except ImportError: cp = None +try: + from hip import hip +except ImportError: + hip = None + +try: + from hip._util.types import DeviceArray +except ImportError: + Pointer = Exception # using Exception here as a type that will never be among kernel arguments + DeviceArray = Exception + def is_cupy_array(array): """Check if something is a cupy array. @@ -145,9 +156,9 @@ def ready_argument_list(self, arguments): ctype_args = [None for _ in arguments] for i, arg in enumerate(arguments): - if not (isinstance(arg, (np.ndarray, np.number)) or is_cupy_array(arg)): - raise TypeError(f"Argument is not numpy or cupy ndarray or numpy scalar but a {type(arg)}") - dtype_str = str(arg.dtype) + if not (isinstance(arg, (np.ndarray, np.number, DeviceArray)) or is_cupy_array(arg)): + raise TypeError(f"Argument is not a numpy ndarray, numpy scalar, cupy ndarray, or HIP Python DeviceArray, but a {type(arg)}") + dtype_str = arg.typestr if isinstance(arg, DeviceArray) else str(arg.dtype) if isinstance(arg, np.ndarray): if dtype_str in dtype_map.keys(): # In numpy <= 1.15, ndarray.ctypes.data_as does not itself keep a reference # to the underlying array, so we need to store a reference to arg.copy() # in the Argument object manually to avoid it being deleted. # (This changed in numpy > 1.15.) 
# data_ctypes = data.ctypes.data_as(C.POINTER(dtype_map[dtype_str])) data_ctypes = arg.ctypes.data_as(C.POINTER(dtype_map[dtype_str])) + numpy_arg = arg else: raise TypeError("unknown dtype for ndarray") elif isinstance(arg, np.generic): data_ctypes = dtype_map[dtype_str](arg) + numpy_arg = arg elif is_cupy_array(arg): data_ctypes = C.c_void_p(arg.data.ptr) - ctype_args[i] = Argument(numpy=arg, ctypes=data_ctypes) + numpy_arg = arg + elif isinstance(arg, DeviceArray): + data_ctypes = arg.as_c_void_p() + numpy_arg = None + + ctype_args[i] = Argument(numpy=numpy_arg, ctypes=data_ctypes) return ctype_args def compile(self, kernel_instance): @@ -380,6 +398,12 @@ def memcpy_dtoh(self, dest, src): :param src: An Argument for some memory allocation :type src: Argument """ + # If src.numpy is None, it means we're dealing with a HIP Python DeviceArray + if src.numpy is None: + # Skip memory copies for HIP Python DeviceArray + # This is because DeviceArray manages its own memory and doesn't need + # explicit copies like numpy arrays do + return if isinstance(dest, np.ndarray) and is_cupy_array(src.numpy): # Implicit conversion to a NumPy array is not allowed. value = src.numpy.get() @@ -397,6 +421,12 @@ def memcpy_htod(self, dest, src): :param src: A numpy or cupy array containing the source data :type src: np.ndarray or cupy.ndarray """ + # If dest.numpy is None, it means we're dealing with a HIP Python DeviceArray + if dest.numpy is None: + # Skip memory copies for HIP Python DeviceArray + # This is because DeviceArray manages its own memory and doesn't need + # explicit copies like numpy arrays do + return if isinstance(dest.numpy, np.ndarray) and is_cupy_array(src): # Implicit conversion to a NumPy array is not allowed. value = src.get() diff --git a/kernel_tuner/backends/hip.py b/kernel_tuner/backends/hip.py index 7d3adb90a..1a0b7ce73 100644 --- a/kernel_tuner/backends/hip.py +++ b/kernel_tuner/backends/hip.py @@ -10,7 +10,7 @@ from kernel_tuner.observers.hip import HipRuntimeObserver try: - from pyhip import hip, hiprtc + from hip import hip, hiprtc except (ImportError, RuntimeError): hip = None hiprtc = None @@ -32,6 +32,18 @@ hipSuccess = 0 + +def hip_check(call_result): + """Helper function to check return values of HIP calls.""" + err = call_result[0] + result = call_result[1:] + if len(result) == 1: + result = result[0] + if isinstance(err, hip.hipError_t) and err != hip.hipError_t.hipSuccess: + raise RuntimeError(str(err)) + return result + + class HipFunctions(GPUBackend): """Class that groups the HIP functions on maintains state about the device.""" @@ -50,16 +62,20 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None :type iterations: int """ if not hip or not hiprtc: - raise ImportError("Unable to import PyHIP, make sure PYTHONPATH includes PyHIP, or check https://kerneltuner.github.io/kernel_tuner/stable/install.html#hip-and-pyhip.") + raise ImportError( + "Unable to import HIP Python, check https://kerneltuner.github.io/kernel_tuner/stable/install.html#hip-and-hip-python." 
+ ) # embedded in try block to be able to generate documentation - # and run tests without pyhip installed + # and run tests without HIP Python installed logging.debug("HipFunction instantiated") - self.hipProps = hip.hipGetDeviceProperties(device) + # Get device properties + props = hip.hipDeviceProp_t() + hip_check(hip.hipGetDeviceProperties(props, device)) - self.name = self.hipProps._name.decode('utf-8') - self.max_threads = self.hipProps.maxThreadsPerBlock + self.name = props.name.decode("utf-8") + self.max_threads = props.maxThreadsPerBlock self.device = device self.compiler_options = compiler_options or [] self.iterations = iterations @@ -70,59 +86,60 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None env["compiler_options"] = compiler_options self.env = env - # create a stream and events - self.stream = hip.hipStreamCreate() - self.start = hip.hipEventCreate() - self.end = hip.hipEventCreate() + # Create stream and events + self.stream = hip_check(hip.hipStreamCreate()) + self.start = hip_check(hip.hipEventCreate()) + self.end = hip_check(hip.hipEventCreate()) - # default dynamically allocated shared memory size, can be overwritten using smem_args + # Default dynamically allocated shared memory size self.smem_size = 0 self.current_module = None - # setup observers + # Setup observers self.observers = observers or [] self.observers.append(HipRuntimeObserver(self)) for obs in self.observers: obs.register_device(self) - def ready_argument_list(self, arguments): """Ready argument list to be passed to the HIP function. - :param arguments: List of arguments to be passed to the HIP function. The order should match the argument list on the HIP function. Allowed values are np.ndarray, and/or np.int32, np.float32, and so on. :type arguments: list(numpy objects) - - :returns: Ctypes structure of arguments to be passed to the HIP function. - :rtype: ctypes structure + :returns: List of arguments to be passed to the HIP function. + :rtype: list """ logging.debug("HipFunction ready_argument_list called") + prepared_args = [] - ctype_args = [] - data_ctypes = None for arg in arguments: dtype_str = str(arg.dtype) - # Allocate space on device for array and convert to ctypes + + # Handle numpy arrays if isinstance(arg, np.ndarray): if dtype_str in dtype_map.keys(): - device_ptr = hip.hipMalloc(arg.nbytes) - data_ctypes = arg.ctypes.data_as(ctypes.POINTER(dtype_map[dtype_str])) - hip.hipMemcpy_htod(device_ptr, data_ctypes, arg.nbytes) - # may be part of run_kernel, return allocations here instead - ctype_args.append(device_ptr) + # Allocate device memory + device_ptr = hip_check(hip.hipMalloc(arg.nbytes)) + + # Copy data to device using hipMemcpy + hip_check(hip.hipMemcpy(device_ptr, arg, arg.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice)) + + prepared_args.append(device_ptr) else: - raise TypeError("unknown dtype for ndarray") - # Convert valid non-array arguments to ctypes + raise TypeError(f"Unknown dtype {dtype_str} for ndarray") + + # Handle numpy scalar types elif isinstance(arg, np.generic): - data_ctypes = dtype_map[dtype_str](arg) - ctype_args.append(data_ctypes) + # Convert numpy scalar to corresponding ctypes + ctype_arg = dtype_map[dtype_str](arg) + prepared_args.append(ctype_arg) + else: raise ValueError(f"Invalid argument type {type(arg)}, {arg}") - return ctype_args - + return prepared_args def compile(self, kernel_instance): """Call the HIP compiler to compile the kernel, return the function. 
@@ -131,39 +148,52 @@ def compile(self, kernel_instance): in the parameter space. :type kernel_instance: kernel_tuner.core.KernelInstance - :returns: An ctypes function that can be called directly. - :rtype: ctypes._FuncPtr + :returns: A HIP kernel function that can be called. + :rtype: hipFunction_t """ logging.debug("HipFunction compile called") - #Format and create program + # Format kernel string kernel_string = kernel_instance.kernel_string kernel_name = kernel_instance.name if 'extern "C"' not in kernel_string: kernel_string = 'extern "C" {\n' + kernel_string + "\n}" - kernel_ptr = hiprtc.hiprtcCreateProgram(kernel_string, kernel_name, [], []) - try: - #Compile based on device (Not yet tested for non-AMD devices) - plat = hip.hipGetPlatformName() - if plat == "amd": - options_list = [f'--offload-arch={self.hipProps.gcnArchName}'] - options_list.extend(self.compiler_options) - hiprtc.hiprtcCompileProgram(kernel_ptr, options_list) - else: - options_list = [] - options_list.extend(self.compiler_options) - hiprtc.hiprtcCompileProgram(kernel_ptr, options_list) + # Create program + prog = hip_check(hiprtc.hiprtcCreateProgram(kernel_string.encode(), kernel_name.encode(), 0, [], [])) - #Get module and kernel from compiled kernel string - code = hiprtc.hiprtcGetCode(kernel_ptr) - module = hip.hipModuleLoadData(code) + try: + # Get device properties + props = hip.hipDeviceProp_t() + hip_check(hip.hipGetDeviceProperties(props, 0)) + + # Setup compilation options + arch = props.gcnArchName + cflags = [b"--offload-arch=" + arch] + cflags.extend([opt.encode() if isinstance(opt, str) else opt for opt in self.compiler_options]) + + # Compile program + (err,) = hiprtc.hiprtcCompileProgram(prog, len(cflags), cflags) + if err != hiprtc.hiprtcResult.HIPRTC_SUCCESS: + # Get compilation log if there's an error + log_size = hip_check(hiprtc.hiprtcGetProgramLogSize(prog)) + log = bytearray(log_size) + hip_check(hiprtc.hiprtcGetProgramLog(prog, log)) + raise RuntimeError(log.decode()) + + # Get compiled code + code_size = hip_check(hiprtc.hiprtcGetCodeSize(prog)) + code = bytearray(code_size) + hip_check(hiprtc.hiprtcGetCode(prog, code)) + + # Load module and get function + module = hip_check(hip.hipModuleLoadData(code)) self.current_module = module - kernel = hip.hipModuleGetFunction(module, kernel_name) + kernel = hip_check(hip.hipModuleGetFunction(module, kernel_name.encode())) except Exception as e: - log = hiprtc.hiprtcGetProgramLog(kernel_ptr) - print(log) + # Cleanup + hip_check(hiprtc.hiprtcDestroyProgram(prog.createRef())) raise e return kernel @@ -172,37 +202,42 @@ def start_event(self): """Records the event that marks the start of a measurement.""" logging.debug("HipFunction start_event called") - hip.hipEventRecord(self.start, self.stream) + hip_check(hip.hipEventRecord(self.start, self.stream)) def stop_event(self): """Records the event that marks the end of a measurement.""" logging.debug("HipFunction stop_event called") - hip.hipEventRecord(self.end, self.stream) + hip_check(hip.hipEventRecord(self.end, self.stream)) def kernel_finished(self): """Returns True if the kernel has finished, False otherwise.""" logging.debug("HipFunction kernel_finished called") - # Query the status of the event - return hip.hipEventQuery(self.end) + # ROCm HIP returns (hipError_t, bool) for hipEventQuery + status = hip.hipEventQuery(self.end) + if status[0] == hip.hipError_t.hipSuccess: + return True + elif status[0] == hip.hipError_t.hipErrorNotReady: + return False + else: + hip_check(status) def 
synchronize(self): """Halts execution until device has finished its tasks.""" logging.debug("HipFunction synchronize called") - hip.hipDeviceSynchronize() + hip_check(hip.hipDeviceSynchronize()) def run_kernel(self, func, gpu_args, threads, grid, stream=None): """Runs the HIP kernel passed as 'func'. :param func: A HIP kernel compiled for this specific kernel configuration - :type func: ctypes pionter + :type func: hipFunction_t - :param gpu_args: A ctypes structure of arguments to the kernel, order should match the - order in the code. Allowed values are either variables in global memory - or single values passed by value. - :type gpu_args: ctypes structure + :param gpu_args: List of arguments to pass to the kernel. Can be DeviceArray + objects or ctypes values + :type gpu_args: list :param threads: A tuple listing the number of threads in each dimension of the thread block @@ -217,40 +252,38 @@ def run_kernel(self, func, gpu_args, threads, grid, stream=None): if stream is None: stream = self.stream - # Determine the types of the fields in the structure - field_types = [type(x) for x in gpu_args] - - # Define a new ctypes structure with the inferred layout - class ArgListStructure(ctypes.Structure): - _fields_ = [(f'field{i}', t) for i, t in enumerate(field_types)] - def __getitem__(self, key): - return getattr(self, self._fields_[key][0]) - - ctype_args = ArgListStructure(*gpu_args) - - hip.hipModuleLaunchKernel(func, - grid[0], grid[1], grid[2], - threads[0], threads[1], threads[2], - self.smem_size, - stream, - ctype_args) + # Create dim3 objects for grid and block dimensions + grid_dim = hip.dim3(x=grid[0], y=grid[1], z=grid[2]) + block_dim = hip.dim3(x=threads[0], y=threads[1], z=threads[2]) + + # Launch kernel with the arguments + hip_check( + hip.hipModuleLaunchKernel( + func, + *grid_dim, + *block_dim, + sharedMemBytes=self.smem_size, + stream=stream, + kernelParams=None, + extra=tuple(gpu_args), + ) + ) def memset(self, allocation, value, size): """Set the memory in allocation to the value in value. - :param allocation: A GPU memory allocation unit - :type allocation: ctypes ptr + :param allocation: A GPU memory allocation (DeviceArray) + :type allocation: DeviceArray or int :param value: The value to set the memory to - :type value: a single 8-bit unsigned int + :type value: int (8-bit unsigned) :param size: The size of to the allocation unit in bytes :type size: int - """ logging.debug("HipFunction memset called") - hip.hipMemset(allocation, value, size) + hip_check(hip.hipMemset(allocation, value, size)) def memcpy_dtoh(self, dest, src): """Perform a device to host memory copy. @@ -259,30 +292,24 @@ def memcpy_dtoh(self, dest, src): :type dest: numpy.ndarray :param src: A GPU memory allocation unit - :type src: ctypes ptr + :type src: DeviceArray or int """ logging.debug("HipFunction memcpy_dtoh called") - # Format arguments to correct type and perform memory copy - dtype_str = str(dest.dtype) - dest_c = dest.ctypes.data_as(ctypes.POINTER(dtype_map[dtype_str])) - hip.hipMemcpy_dtoh(dest_c, src, dest.nbytes) + hip_check(hip.hipMemcpy(dest, src, dest.nbytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost)) def memcpy_htod(self, dest, src): """Perform a host to device memory copy. 
:param dest: A GPU memory allocation unit - :type dest: ctypes ptr + :type dest: DeviceArray or int - :param src: A numpy array in host memory to store the data + :param src: A numpy array in host memory to copy from :type src: numpy.ndarray """ logging.debug("HipFunction memcpy_htod called") - # Format arguments to correct type and perform memory copy - dtype_str = str(src.dtype) - src_c = src.ctypes.data_as(ctypes.POINTER(dtype_map[dtype_str])) - hip.hipMemcpy_htod(dest, src_c, src.nbytes) + hip_check(hip.hipMemcpy(dest, src, src.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice)) def copy_constant_memory_args(self, cmem_args): """Adds constant memory arguments to the most recently compiled module. @@ -292,19 +319,17 @@ def copy_constant_memory_args(self, cmem_args): string key is used to name the constant memory symbol to which the value needs to be copied. Similar to regular arguments, these need to be numpy objects, such as numpy.ndarray or numpy.int32, and so on. - :type cmem_args: dict( string: numpy.ndarray, ... ) + :type cmem_args: dict(string: numpy.ndarray, ...) """ logging.debug("HipFunction copy_constant_memory_args called") # Iterate over dictionary - for k, v in cmem_args.items(): - #Get symbol pointer - symbol_ptr, _ = hip.hipModuleGetGlobal(self.current_module, k) - - #Format arguments and perform memory copy - dtype_str = str(v.dtype) - v_c = v.ctypes.data_as(ctypes.POINTER(dtype_map[dtype_str])) - hip.hipMemcpy_htod(symbol_ptr, v_c, v.nbytes) + for symbol_name, data in cmem_args.items(): + # Get symbol pointer and size using hipModuleGetGlobal + dptr, _ = hip_check(hip.hipModuleGetGlobal(self.current_module, symbol_name.encode())) + + # Copy data to the global memory location + hip_check(hip.hipMemcpy(dptr, data, data.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice)) def copy_shared_memory_args(self, smem_args): """Add shared memory arguments to the kernel.""" diff --git a/kernel_tuner/core.py b/kernel_tuner/core.py index 4323c411c..655779337 100644 --- a/kernel_tuner/core.py +++ b/kernel_tuner/core.py @@ -29,6 +29,11 @@ except ImportError: torch = util.TorchPlaceHolder() +try: + from hip._util.types import DeviceArray +except ImportError: + DeviceArray = Exception # using Exception here as a type that will never be among kernel arguments + _KernelInstance = namedtuple( "_KernelInstance", [ @@ -495,7 +500,7 @@ def check_kernel_output( should_sync = [answer[i] is not None for i, arg in enumerate(instance.arguments)] else: - should_sync = [isinstance(arg, (np.ndarray, cp.ndarray, torch.Tensor)) for arg in instance.arguments] + should_sync = [isinstance(arg, (np.ndarray, cp.ndarray, torch.Tensor, DeviceArray)) for arg in instance.arguments] # re-copy original contents of output arguments to GPU memory, to overwrite any changes # by earlier kernel runs @@ -659,7 +664,7 @@ def compile_kernel(self, instance, verbose): f"skipping config {util.get_instance_string(instance.params)} reason: too much shared memory used" ) else: - logging.debug("compile_kernel failed due to error: " + str(e)) + print("compile_kernel failed due to error: " + error_message) print("Error while compiling:", instance.name) raise e return func diff --git a/kernel_tuner/observers/hip.py b/kernel_tuner/observers/hip.py index 9ee339209..c536cf965 100644 --- a/kernel_tuner/observers/hip.py +++ b/kernel_tuner/observers/hip.py @@ -3,7 +3,7 @@ from kernel_tuner.observers.observer import BenchmarkObserver try: - from pyhip import hip, hiprtc + from hip import hip, hiprtc except (ImportError, RuntimeError): hip = 
None hiprtc = None @@ -14,7 +14,7 @@ class HipRuntimeObserver(BenchmarkObserver): def __init__(self, dev): if not hip or not hiprtc: - raise ImportError("Unable to import PyHIP, make sure PYTHONPATH includes PyHIP, or check https://kerneltuner.github.io/kernel_tuner/stable/install.html#hip-and-pyhip.") + raise ImportError("Unable to import HIP Python, or check https://kerneltuner.github.io/kernel_tuner/stable/install.html#hip-and-hip-python.") self.dev = dev self.stream = dev.stream diff --git a/kernel_tuner/util.py b/kernel_tuner/util.py index 0d2cef696..710b59e0d 100644 --- a/kernel_tuner/util.py +++ b/kernel_tuner/util.py @@ -1,4 +1,5 @@ """Module for kernel tuner utility functions.""" + import errno import json import logging @@ -102,6 +103,11 @@ class StopCriterionReached(Exception): "block_size_z", ] +try: + from hip._util.types import DeviceArray +except ImportError: + DeviceArray = Exception # using Exception here as a type that will never be among kernel arguments + def check_argument_type(dtype, kernel_argument): """Check if the numpy.dtype matches the type used in the code.""" @@ -127,57 +133,58 @@ def check_argument_type(dtype, kernel_argument): def check_argument_list(kernel_name, kernel_string, args): - """Raise an exception if a kernel arguments do not match host arguments.""" + """Raise an exception if kernel arguments do not match host arguments.""" kernel_arguments = list() collected_errors = list() + for iterator in re.finditer(kernel_name + "[ \n\t]*" + r"\(", kernel_string): kernel_start = iterator.end() kernel_end = kernel_string.find(")", kernel_start) if kernel_start != 0: kernel_arguments.append(kernel_string[kernel_start:kernel_end].split(",")) + for arguments_set, arguments in enumerate(kernel_arguments): collected_errors.append(list()) if len(arguments) != len(args): - collected_errors[arguments_set].append( - "Kernel and host argument lists do not match in size." - ) + collected_errors[arguments_set].append("Kernel and host argument lists do not match in size.") continue + for i, arg in enumerate(args): kernel_argument = arguments[i] - # Fix to deal with tunable arguments + # Handle tunable arguments if isinstance(arg, Tunable): continue - if not isinstance(arg, (np.ndarray, np.generic, cp.ndarray, torch.Tensor)): + # Handle numpy arrays and other array types + if not isinstance(arg, (np.ndarray, np.generic, cp.ndarray, torch.Tensor, DeviceArray)): raise TypeError( - "Argument at position " - + str(i) - + " of type: " - + str(type(arg)) - + " should be of type np.ndarray or numpy scalar" + f"Argument at position {i} of type: {type(arg)} should be of type " + "np.ndarray, numpy scalar, or HIP Python DeviceArray type" ) correct = True - if isinstance(arg, np.ndarray) and "*" not in kernel_argument: - correct = False # array is passed to non-pointer kernel argument + if isinstance(arg, np.ndarray): + if "*" not in kernel_argument: + correct = False + + if isinstance(arg, DeviceArray): + str_dtype = str(np.dtype(arg.typestr)) + else: + str_dtype = str(arg.dtype) - if correct and check_argument_type(str(arg.dtype), kernel_argument): + if correct and check_argument_type(str_dtype, kernel_argument): continue collected_errors[arguments_set].append( - "Argument at position " - + str(i) - + " of dtype: " - + str(arg.dtype) - + " does not match " - + kernel_argument - + "." + f"Argument at position {i} of dtype: {str_dtype} does not match {kernel_argument}." 
) + if not collected_errors[arguments_set]: # We assume that if there is a possible list of arguments that matches with the provided one # it is the right one return + for errors in collected_errors: warnings.warn(errors[0], UserWarning) @@ -186,10 +193,7 @@ def check_stop_criterion(to): """Checks if max_fevals is reached or time limit is exceeded.""" if "max_fevals" in to and len(to.unique_results) >= to.max_fevals: raise StopCriterionReached("max_fevals reached") - if "time_limit" in to and ( - ((time.perf_counter() - to.start_time) + (to.simulated_time * 1e-3)) - > to.time_limit - ): + if "time_limit" in to and (((time.perf_counter() - to.start_time) + (to.simulated_time * 1e-3)) > to.time_limit): raise StopCriterionReached("time limit exceeded") @@ -198,13 +202,7 @@ def check_tune_params_list(tune_params, observers, simulation_mode=False): forbidden_names = ("grid_size_x", "grid_size_y", "grid_size_z", "time") for name, param in tune_params.items(): if name in forbidden_names: - raise ValueError( - "Tune parameter " - + name - + " with value " - + str(param) - + " has a forbidden name!" - ) + raise ValueError("Tune parameter " + name + " with value " + str(param) + " has a forbidden name!") if any("nvml_" in param for param in tune_params): if not simulation_mode and (not observers or not any(isinstance(obs, NVMLObserver) for obs in observers)): raise ValueError("Tune parameters starting with nvml_ require an NVMLObserver!") @@ -243,6 +241,7 @@ def check_block_size_params_names_list(block_size_names, tune_params): UserWarning, ) + def check_restriction(restrict, params: dict) -> bool: """Check whether a configuration meets a search space restriction.""" # if it's a python-constraint, convert to function and execute @@ -256,8 +255,12 @@ def check_restriction(restrict, params: dict) -> bool: elif callable(restrict): return restrict(**params) # if it's a tuple, use only the parameters in the second argument to call the restriction - elif (isinstance(restrict, tuple) and len(restrict) == 2 - and callable(restrict[0]) and isinstance(restrict[1], (list, tuple))): + elif ( + isinstance(restrict, tuple) + and len(restrict) == 2 + and callable(restrict[0]) + and isinstance(restrict[1], (list, tuple)) + ): # unpack the tuple restrict, selected_params = restrict # look up the selected parameters and their value @@ -272,6 +275,7 @@ def check_restriction(restrict, params: dict) -> bool: else: raise ValueError(f"Unkown restriction type {type(restrict)} ({restrict})") + def check_restrictions(restrictions, params: dict, verbose: bool) -> bool: """Check whether a configuration meets the search space restrictions.""" if callable(restrictions): @@ -296,29 +300,45 @@ def check_restrictions(restrictions, params: dict, verbose: bool) -> bool: def convert_constraint_restriction(restrict: Constraint): """Convert the python-constraint to a function for backwards compatibility.""" if isinstance(restrict, FunctionConstraint): + def f_restrict(p): return restrict._func(*p) + elif isinstance(restrict, AllDifferentConstraint): + def f_restrict(p): return len(set(p)) == len(p) + elif isinstance(restrict, AllEqualConstraint): + def f_restrict(p): return all(x == p[0] for x in p) + elif isinstance(restrict, MaxProdConstraint): + def f_restrict(p): return np.prod(p) <= restrict._maxprod + elif isinstance(restrict, MinProdConstraint): + def f_restrict(p): return np.prod(p) >= restrict._minprod + elif isinstance(restrict, MaxSumConstraint): + def f_restrict(p): return sum(p) <= restrict._maxsum + elif 
isinstance(restrict, ExactSumConstraint): + def f_restrict(p): return sum(p) == restrict._exactsum + elif isinstance(restrict, MinSumConstraint): + def f_restrict(p): return sum(p) >= restrict._minsum + elif isinstance(restrict, (InSetConstraint, NotInSetConstraint, SomeInSetConstraint, SomeNotInSetConstraint)): raise NotImplementedError( f"Restriction of the type {type(restrict)} is explicitely not supported in backwards compatibility mode, because the behaviour is too complex. Please rewrite this constraint to a function to use it with this algorithm." @@ -343,9 +363,7 @@ def config_valid(config, tuning_options, max_threads): if not legal: return False block_size_names = tuning_options.get("block_size_names", None) - valid_thread_block_dimensions = check_thread_block_dimensions( - params, max_threads, block_size_names - ) + valid_thread_block_dimensions = check_thread_block_dimensions(params, max_threads, block_size_names) return valid_thread_block_dimensions @@ -372,9 +390,7 @@ def detect_language(kernel_string): def get_best_config(results, objective, objective_higher_is_better=False): """Returns the best configuration from a list of results according to some objective.""" func = max if objective_higher_is_better else min - ignore_val = ( - sys.float_info.max if not objective_higher_is_better else -sys.float_info.max - ) + ignore_val = sys.float_info.max if not objective_higher_is_better else -sys.float_info.max best_config = func( results, key=lambda x: x[objective] if isinstance(x[objective], float) else ignore_val, @@ -419,18 +435,10 @@ def get_dimension_divisor(divisor_list, default, params): if callable(divisor_list): return divisor_list(params) else: - return np.prod( - [int(eval(replace_param_occurrences(s, params))) for s in divisor_list] - ) + return np.prod([int(eval(replace_param_occurrences(s, params))) for s in divisor_list]) - divisors = [ - get_dimension_divisor(d, block_size_names[i], params) - for i, d in enumerate(grid_div) - ] - return tuple( - int(np.ceil(float(current_problem_size[i]) / float(d))) - for i, d in enumerate(divisors) - ) + divisors = [get_dimension_divisor(d, block_size_names[i], params) for i, d in enumerate(grid_div)] + return tuple(int(np.ceil(float(current_problem_size[i]) / float(d))) for i, d in enumerate(divisors)) def get_instance_string(params): @@ -492,9 +500,7 @@ def get_problem_size(problem_size, params): elif isinstance(s, (int, np.integer)): current_problem_size[i] = s else: - raise TypeError( - "Error: problem_size should only contain strings or integers" - ) + raise TypeError("Error: problem_size should only contain strings or integers") return current_problem_size @@ -569,11 +575,12 @@ def get_total_timings(results, env, overhead_time): return env -NVRTC_VALID_CC = np.array(['50', '52', '53', '60', '61', '62', '70', '72', '75', '80', '87', '89', '90', '90a']) +NVRTC_VALID_CC = np.array(["50", "52", "53", "60", "61", "62", "70", "72", "75", "80", "87", "89", "90", "90a"]) + def to_valid_nvrtc_gpu_arch_cc(compute_capability: str) -> str: """Returns a valid Compute Capability for NVRTC `--gpu-architecture=`, as per https://docs.nvidia.com/cuda/nvrtc/index.html#group__options.""" - return max(NVRTC_VALID_CC[NVRTC_VALID_CC<=compute_capability], default='52') + return max(NVRTC_VALID_CC[NVRTC_VALID_CC <= compute_capability], default="52") def print_config(config, tuning_options, runner): @@ -822,7 +829,9 @@ def has_kw_argument(func, name): return lambda answer, result_host, atol: v(answer, result_host) -def parse_restrictions(restrictions: 
list[str], tune_params: dict, monolithic = False, try_to_constraint = True) -> list[tuple[Union[Constraint, str], list[str]]]: +def parse_restrictions( + restrictions: list[str], tune_params: dict, monolithic=False, try_to_constraint=True +) -> list[tuple[Union[Constraint, str], list[str]]]: """Parses restrictions from a list of strings into compilable functions and constraints, or a single compilable function (if monolithic is True). Returns a list of tuples of (strings or constraints) and parameters.""" # rewrite the restrictions so variables are singled out regex_match_variable = r"([a-zA-Z_$][a-zA-Z_$0-9]*)" @@ -854,8 +863,8 @@ def to_multiple_restrictions(restrictions: list[str]) -> list[str]: split_restrictions.append(res) continue # find the indices of splittable comparators - comparators = ['<=', '>=', '>', '<'] - comparators_indices = [(m.start(0), m.end(0)) for m in re.finditer('|'.join(comparators), res)] + comparators = ["<=", ">=", ">", "<"] + comparators_indices = [(m.start(0), m.end(0)) for m in re.finditer("|".join(comparators), res)] if len(comparators_indices) <= 1: # this can't be split further split_restrictions.append(res) @@ -863,15 +872,19 @@ def to_multiple_restrictions(restrictions: list[str]) -> list[str]: # split the restrictions from the previous to the next comparator for index in range(len(comparators_indices)): temp_copy = res - prev_stop = comparators_indices[index-1][1] + 1 if index > 0 else 0 - next_stop = comparators_indices[index+1][0] if index < len(comparators_indices) - 1 else len(temp_copy) + prev_stop = comparators_indices[index - 1][1] + 1 if index > 0 else 0 + next_stop = ( + comparators_indices[index + 1][0] if index < len(comparators_indices) - 1 else len(temp_copy) + ) split_restrictions.append(temp_copy[prev_stop:next_stop].strip()) return split_restrictions - def to_numeric_constraint(restriction: str, params: list[str]) -> Optional[Union[MinSumConstraint, ExactSumConstraint, MaxSumConstraint, MaxProdConstraint]]: + def to_numeric_constraint( + restriction: str, params: list[str] + ) -> Optional[Union[MinSumConstraint, ExactSumConstraint, MaxSumConstraint, MaxProdConstraint]]: """Converts a restriction to a built-in numeric constraint if possible.""" - comparators = ['<=', '==', '>=', '>', '<'] - comparators_found = re.findall('|'.join(comparators), restriction) + comparators = ["<=", "==", ">=", ">", "<"] + comparators_found = re.findall("|".join(comparators), restriction) # check if there is exactly one comparator, if not, return None if len(comparators_found) != 1: return None @@ -897,19 +910,21 @@ def is_or_evals_to_number(s: str) -> Optional[Union[int, float]]: if (left_num is None and right_num is None) or (left_num is not None and right_num is not None): # left_num and right_num can't be both None or both a constant return None - number, variables, variables_on_left = (left_num, right.strip(), False) if left_num is not None else (right_num, left.strip(), True) + number, variables, variables_on_left = ( + (left_num, right.strip(), False) if left_num is not None else (right_num, left.strip(), True) + ) # if the number is an integer, we can map '>' to '>=' and '<' to '<=' by changing the number (does not work with floating points!) 
number_is_int = isinstance(number, int) if number_is_int: - if comparator == '<': + if comparator == "<": if variables_on_left: # (x < 2) == (x <= 2-1) number -= 1 else: # (2 < x) == (2+1 <= x) number += 1 - elif comparator == '>': + elif comparator == ">": if variables_on_left: # (x > 2) == (x >= 2+1) number += 1 @@ -918,8 +933,8 @@ def is_or_evals_to_number(s: str) -> Optional[Union[int, float]]: number -= 1 # check if an operator is applied on the variables, if not return - operators = [r'\*\*', r'\*', r'\+'] - operators_found = re.findall(str('|'.join(operators)), variables) + operators = [r"\*\*", r"\*", r"\+"] + operators_found = re.findall(str("|".join(operators)), variables) if len(operators_found) == 0: # no operators found, return only based on comparator if len(params) != 1 or variables not in params: @@ -927,12 +942,12 @@ def is_or_evals_to_number(s: str) -> Optional[Union[int, float]]: return None # map to a Constraint # if there are restrictions with a single variable, it will be used to prune the domain at the start - elif comparator == '==': + elif comparator == "==": return ExactSumConstraint(number) - elif comparator == '<=' or (comparator == '<' and number_is_int): - return MaxSumConstraint(number) if variables_on_left else MinSumConstraint(number) - elif comparator == '>=' or (comparator == '>' and number_is_int): - return MinSumConstraint(number) if variables_on_left else MaxSumConstraint(number) + elif comparator == "<=" or (comparator == "<" and number_is_int): + return MaxSumConstraint(number) if variables_on_left else MinSumConstraint(number) + elif comparator == ">=" or (comparator == ">" and number_is_int): + return MinSumConstraint(number) if variables_on_left else MaxSumConstraint(number) raise ValueError(f"Invalid comparator {comparator}") # check which operator is applied on the variables @@ -946,34 +961,36 @@ def is_or_evals_to_number(s: str) -> Optional[Union[int, float]]: # check if there are only pure, non-recurring variables (no operations or constants) in the restriction if len(splitted) == len(params) and all(s.strip() in params for s in splitted): # map to a Constraint - if operator == '**': + if operator == "**": # power operations are not (yet) supported, added to avoid matching the double asterisk return None - elif operator == '*': - if comparator == '<=' or (comparator == '<' and number_is_int): + elif operator == "*": + if comparator == "<=" or (comparator == "<" and number_is_int): return MaxProdConstraint(number) if variables_on_left else MinProdConstraint(number) - elif comparator == '>=' or (comparator == '>' and number_is_int): + elif comparator == ">=" or (comparator == ">" and number_is_int): return MinProdConstraint(number) if variables_on_left else MaxProdConstraint(number) - elif operator == '+': - if comparator == '==': + elif operator == "+": + if comparator == "==": return ExactSumConstraint(number) - elif comparator == '<=' or (comparator == '<' and number_is_int): + elif comparator == "<=" or (comparator == "<" and number_is_int): return MaxSumConstraint(number) if variables_on_left else MinSumConstraint(number) - elif comparator == '>=' or (comparator == '>' and number_is_int): + elif comparator == ">=" or (comparator == ">" and number_is_int): return MinSumConstraint(number) if variables_on_left else MaxSumConstraint(number) else: raise ValueError(f"Invalid operator {operator}") return None - def to_equality_constraint(restriction: str, params: list[str]) -> Optional[Union[AllEqualConstraint, AllDifferentConstraint]]: + def 
to_equality_constraint( + restriction: str, params: list[str] + ) -> Optional[Union[AllEqualConstraint, AllDifferentConstraint]]: """Converts a restriction to either an equality or inequality constraint on all the parameters if possible.""" # check if all parameters are involved if len(params) != len(tune_params): return None # find whether (in)equalities appear in this restriction - equalities_found = re.findall('==', restriction) - inequalities_found = re.findall('!=', restriction) + equalities_found = re.findall("==", restriction) + inequalities_found = re.findall("!=", restriction) # check if one of the two have been found, if none or both have been found, return None if not (len(equalities_found) > 0 ^ len(inequalities_found) > 0): return None @@ -984,9 +1001,9 @@ def to_equality_constraint(restriction: str, params: list[str]) -> Optional[Unio # check if there are only pure, non-recurring variables (no operations or constants) in the restriction if len(splitted) == len(params) and all(s.strip() in params for s in splitted): # map to a Constraint - if comparator == '==': + if comparator == "==": return AllEqualConstraint() - elif comparator == '!=': + elif comparator == "!=": return AllDifferentConstraint() return ValueError(f"Not possible: comparator should be '==' or '!=', is {comparator}") return None @@ -1005,7 +1022,12 @@ def to_equality_constraint(restriction: str, params: list[str]) -> Optional[Unio finalized_constraint = None if try_to_constraint and " or " not in res and " and " not in res: # if applicable, strip the outermost round brackets - while parsed_restriction[0] == '(' and parsed_restriction[-1] == ')' and '(' not in parsed_restriction[1:] and ')' not in parsed_restriction[:1]: + while ( + parsed_restriction[0] == "(" + and parsed_restriction[-1] == ")" + and "(" not in parsed_restriction[1:] + and ")" not in parsed_restriction[:1] + ): parsed_restriction = parsed_restriction[1:-1] # check if we can turn this into the built-in numeric comparison constraint finalized_constraint = to_numeric_constraint(parsed_restriction, params_used) @@ -1018,7 +1040,9 @@ def to_equality_constraint(restriction: str, params: list[str]) -> Optional[Unio parsed_restrictions.append((finalized_constraint, params_used)) else: # create one monolithic function - parsed_restrictions = ") and (".join([re.sub(regex_match_variable, replace_params, res) for res in restrictions]) + parsed_restrictions = ") and (".join( + [re.sub(regex_match_variable, replace_params, res) for res in restrictions] + ) # tidy up the code by removing the last suffix and unnecessary spaces parsed_restrictions = "(" + parsed_restrictions.strip() + ")" @@ -1027,12 +1051,19 @@ def to_equality_constraint(restriction: str, params: list[str]) -> Optional[Unio # provide a mapping of the parameter names to the index in the tuple received params_index = dict(zip(tune_params.keys(), range(len(tune_params.keys())))) - parsed_restrictions = [(f"def restrictions(*params): params_index = {params_index}; return {parsed_restrictions} \n", list(tune_params.keys()))] + parsed_restrictions = [ + ( + f"def restrictions(*params): params_index = {params_index}; return {parsed_restrictions} \n", + list(tune_params.keys()), + ) + ] return parsed_restrictions -def compile_restrictions(restrictions: list, tune_params: dict, monolithic = False, try_to_constraint = True) -> list[tuple[Union[str, Constraint, FunctionType], list[str]]]: +def compile_restrictions( + restrictions: list, tune_params: dict, monolithic=False, try_to_constraint=True +) -> 
list[tuple[Union[str, Constraint, FunctionType], list[str]]]: """Parses restrictions from a list of strings into a list of strings, Functions, or Constraints (if `try_to_constraint`) and parameters used, or a single Function if monolithic is true.""" # filter the restrictions to get only the strings restrictions_str, restrictions_ignore = [], [] @@ -1042,7 +1073,9 @@ def compile_restrictions(restrictions: list, tune_params: dict, monolithic = Fal return restrictions_ignore # parse the strings - parsed_restrictions = parse_restrictions(restrictions_str, tune_params, monolithic=monolithic, try_to_constraint=try_to_constraint) + parsed_restrictions = parse_restrictions( + restrictions_str, tune_params, monolithic=monolithic, try_to_constraint=try_to_constraint + ) # compile the parsed restrictions into a function compiled_restrictions: list[tuple] = list() @@ -1103,18 +1136,12 @@ def process_cache(cache, kernel_options, tuning_options, runner): # if file does not exist, create new cache if not os.path.isfile(cache): if tuning_options.simulation_mode: - raise ValueError( - f"Simulation mode requires an existing cachefile: file {cache} does not exist" - ) + raise ValueError(f"Simulation mode requires an existing cachefile: file {cache} does not exist") c = dict() c["device_name"] = runner.dev.name c["kernel_name"] = kernel_options.kernel_name - c["problem_size"] = ( - kernel_options.problem_size - if not callable(kernel_options.problem_size) - else "callable" - ) + c["problem_size"] = kernel_options.problem_size if not callable(kernel_options.problem_size) else "callable" c["tune_params_keys"] = list(tuning_options.tune_params.keys()) c["tune_params"] = tuning_options.tune_params c["objective"] = tuning_options.objective @@ -1139,34 +1166,19 @@ def process_cache(cache, kernel_options, tuning_options, runner): # check if it is safe to continue tuning from this cache if cached_data["device_name"] != runner.dev.name: - raise ValueError( - "Cannot load cache which contains results for different device" - ) + raise ValueError("Cannot load cache which contains results for different device") if cached_data["kernel_name"] != kernel_options.kernel_name: - raise ValueError( - "Cannot load cache which contains results for different kernel" - ) + raise ValueError("Cannot load cache which contains results for different kernel") if "problem_size" in cached_data and not callable(kernel_options.problem_size): # if problem_size is not iterable, compare directly if not hasattr(kernel_options.problem_size, "__iter__"): if cached_data["problem_size"] != kernel_options.problem_size: - raise ValueError( - "Cannot load cache which contains results for different problem_size" - ) + raise ValueError("Cannot load cache which contains results for different problem_size") # else (problem_size is iterable) # cache returns list, problem_size is likely a tuple. 
Therefore, the next check # checks the equality of all items in the list/tuples individually - elif not all( - [ - i == j - for i, j in zip( - cached_data["problem_size"], kernel_options.problem_size - ) - ] - ): - raise ValueError( - "Cannot load cache which contains results for different problem_size" - ) + elif not all([i == j for i, j in zip(cached_data["problem_size"], kernel_options.problem_size)]): + raise ValueError("Cannot load cache which contains results for different problem_size") if cached_data["tune_params_keys"] != list(tuning_options.tune_params.keys()): if all(key in tuning_options.tune_params for key in cached_data["tune_params_keys"]): raise ValueError( @@ -1203,6 +1215,7 @@ def correct_open_cache(cache, open_cache=True): return filestr + def read_cache(cache, open_cache=True): """Read the cachefile into a dictionary, if open_cache=True prepare the cachefile for appending.""" filestr = correct_open_cache(cache, open_cache) diff --git a/pyproject.toml b/pyproject.toml index b09198bec..48034bf15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,7 +80,7 @@ pynvml = { version = "^11.4.1", optional = true } # OpenCL pyopencl = { version = "*", optional = true } # Attention: if pyopencl is changed here, also change `session.install("pyopencl")` in the Noxfile # HIP -pyhip-interface = { version = "*", optional = true } +hip-python = { version = "*", optional = true } # Tutorial (for the notebooks used in the examples) jupyter = { version = "^1.0.0", optional = true } matplotlib = { version = "^3.5.0", optional = true } @@ -89,7 +89,7 @@ matplotlib = { version = "^3.5.0", optional = true } cuda = ["pycuda", "nvidia-ml-py", "pynvml"] opencl = ["pyopencl"] cuda_opencl = ["pycuda", "pyopencl"] -hip = ["pyhip-interface"] +hip = ["hip-python"] tutorial = ["jupyter", "matplotlib", "nvidia-ml-py"] # ATTENTION: if anything is changed here, run `poetry update` and `poetry export --with docs --without-hashes --format=requirements.txt --output doc/requirements.txt` diff --git a/test/context.py b/test/context.py index cccc3332a..d1cbcf3c3 100644 --- a/test/context.py +++ b/test/context.py @@ -53,15 +53,11 @@ except Exception: cuda_present = False -PYHIP_PATH = os.environ.get("PYHIP_PATH") # get the PYHIP_PATH environment variable try: - if PYHIP_PATH is not None: - sys.path.insert(0, PYHIP_PATH) - from pyhip import hip, hiprtc - - pyhip_present = True + from hip import hip + hip_present = True except ImportError: - pyhip_present = False + hip_present = False skip_if_no_pycuda = pytest.mark.skipif( not pycuda_present, reason="PyCuda not installed or no CUDA device detected" @@ -82,7 +78,7 @@ ) skip_if_no_openmp = pytest.mark.skipif(not openmp_present, reason="No OpenMP found") skip_if_no_openacc = pytest.mark.skipif(not openacc_present, reason="No nvc++ on PATH") -skip_if_no_pyhip = pytest.mark.skipif(not pyhip_present, reason="No PyHIP found") +skip_if_no_hip = pytest.mark.skipif(not hip_present, reason="No HIP Python found") def skip_backend(backend: str): @@ -100,3 +96,5 @@ def skip_backend(backend: str): pytest.skip("No gfortran on PATH") elif backend.upper() == "OPENACC" and not openacc_present: pytest.skip("No nvc++ on PATH") + elif backend.upper() == "HIP" and not hip_present: + pytest.skip("HIP Python not installed") diff --git a/test/test_file_utils.py b/test/test_file_utils.py index e84e00da4..622e06b44 100644 --- a/test/test_file_utils.py +++ b/test/test_file_utils.py @@ -1,10 +1,18 @@ import json import pytest +import ctypes from jsonschema import validate +import numpy as 
np +import warnings +try: + from hip import hip +except: + hip = None from kernel_tuner.file_utils import output_file_schema, store_metadata_file, store_output_file -from kernel_tuner.util import delete_temp_file +from kernel_tuner.util import delete_temp_file, check_argument_list +from .context import skip_if_no_hip from .test_runners import cache_filename, env, tune_kernel # noqa: F401 @@ -55,3 +63,33 @@ def test_store_metadata_file(): finally: # clean up delete_temp_file(filename) + +def hip_check(call_result): + err = call_result[0] + result = call_result[1:] + if len(result) == 1: + result = result[0] + if isinstance(err, hip.hipError_t) and err != hip.hipError_t.hipSuccess: + raise RuntimeError(str(err)) + return result + +@skip_if_no_hip +def test_check_argument_list_device_array(): + """Test check_argument_list with DeviceArray""" + float_kernel = """ + __global__ void simple_kernel(float* input) { + // kernel code + } + """ + host_array = np.ones((100,), dtype=np.float32) + num_bytes = host_array.size * host_array.itemsize + device_array = hip_check(hip.hipMalloc(num_bytes)) + device_array.configure( + typestr="float32", + shape=host_array.shape, + itemsize=host_array.itemsize + ) + + with warnings.catch_warnings(): + warnings.simplefilter("error") + check_argument_list("simple_kernel", float_kernel, [device_array]) diff --git a/test/test_hip_functions.py b/test/test_hip_functions.py index 3a665c254..e192223ed 100644 --- a/test/test_hip_functions.py +++ b/test/test_hip_functions.py @@ -1,5 +1,4 @@ import ctypes - import numpy as np import pytest @@ -7,14 +6,25 @@ from kernel_tuner.backends import hip as kt_hip from kernel_tuner.core import KernelInstance, KernelSource -from .context import skip_if_no_pyhip +from .context import skip_if_no_hip try: - from pyhip import hip, hiprtc + from hip import hip, hiprtc hip_present = True except ImportError: pass +def hip_check(call_result): + err = call_result[0] + result = call_result[1:] + if len(result) == 1: + result = result[0] + if isinstance(err, hip.hipError_t) and err != hip.hipError_t.hipSuccess: + raise RuntimeError(str(err)) + elif isinstance(err, hiprtc.hiprtcResult) and err != hiprtc.hiprtcResult.HIPRTC_SUCCESS: + raise RuntimeError(str(err)) + return result + @pytest.fixture def env(): kernel_string = """ @@ -38,9 +48,8 @@ def env(): return ["vector_add", kernel_string, size, args, tune_params] -@skip_if_no_pyhip +@skip_if_no_hip def test_ready_argument_list(): - size = 1000 a = np.int32(75) b = np.random.randn(size).astype(np.float32) @@ -53,14 +62,13 @@ def test_ready_argument_list(): gpu_args = dev.ready_argument_list(arguments) # ctypes have no equality defined, so indirect comparison for type and value - assert(isinstance(gpu_args[1], ctypes.c_int)) - assert(isinstance(gpu_args[3], ctypes.c_bool)) - assert(gpu_args[1] == a) - assert(gpu_args[3] == c) + assert isinstance(gpu_args[1], ctypes.c_int) + assert isinstance(gpu_args[3], ctypes.c_bool) + assert gpu_args[1].value == a + assert gpu_args[3].value == c -@skip_if_no_pyhip +@skip_if_no_hip def test_compile(): - kernel_string = """ __global__ void vector_add(float *c, float *a, float *b, int n) { int i = blockIdx.x * blockDim.x + threadIdx.x; @@ -79,12 +87,11 @@ def test_compile(): except Exception as e: pytest.fail("Did not expect any exception:" + str(e)) - -@skip_if_no_pyhip +@skip_if_no_hip def test_memset_and_memcpy_dtoh(): a = [1, 2, 3, 4] x = np.array(a).astype(np.int8) - x_d = hip.hipMalloc(x.nbytes) + x_d = hip_check(hip.hipMalloc(x.nbytes)) dev = 
kt_hip.HipFunctions() dev.memset(x_d, 4, x.nbytes) @@ -94,11 +101,11 @@ def test_memset_and_memcpy_dtoh(): assert all(output == np.full(4, 4)) -@skip_if_no_pyhip +@skip_if_no_hip def test_memcpy_htod(): a = [1, 2, 3, 4] x = np.array(a).astype(np.float32) - x_d = hip.hipMalloc(x.nbytes) + x_d = hip_check(hip.hipMalloc(x.nbytes)) output = np.empty(4, dtype=np.float32) dev = kt_hip.HipFunctions() @@ -107,7 +114,7 @@ def test_memcpy_htod(): assert all(output == x) -@skip_if_no_pyhip +@skip_if_no_hip def test_copy_constant_memory_args(): kernel_string = """ __constant__ float my_constant_data[100]; @@ -138,13 +145,13 @@ def test_copy_constant_memory_args(): dev.memcpy_dtoh(output, gpu_args[0]) - assert(my_constant_data == output).all() + assert (my_constant_data == output).all() -@skip_if_no_pyhip +@skip_if_no_hip def test_smem_args(env): result, _ = tune_kernel(*env, - smem_args=dict(size="block_size_x*4"), - verbose=True, lang="HIP") + smem_args=dict(size="block_size_x*4"), + verbose=True, lang="HIP") tune_params = env[-1] assert len(result) == len(tune_params["block_size_x"]) result, _ = tune_kernel( @@ -152,6 +159,4 @@ def test_smem_args(env): smem_args=dict(size=lambda p: p['block_size_x'] * 4), verbose=True, lang="HIP") tune_params = env[-1] - assert len(result) == len(tune_params["block_size_x"]) - - + assert len(result) == len(tune_params["block_size_x"]) \ No newline at end of file diff --git a/test/test_observers.py b/test/test_observers.py index 5f2242657..97928b477 100644 --- a/test/test_observers.py +++ b/test/test_observers.py @@ -10,7 +10,7 @@ skip_if_no_cupy, skip_if_no_opencl, skip_if_no_pycuda, - skip_if_no_pyhip, + skip_if_no_hip, skip_if_no_pynvml, ) from .test_hip_functions import env as env_hip # noqa: F401 @@ -68,7 +68,7 @@ def test_register_observer_opencl(env_opencl): assert err.errisinstance(NotImplementedError) assert "OpenCL" in str(err.value) -@skip_if_no_pyhip +@skip_if_no_hip def test_register_observer_hip(env_hip): with raises(NotImplementedError) as err: kernel_tuner.tune_kernel(*env_hip, observers=[RegisterObserver()], lang='HIP')
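For reference, a minimal end-to-end sketch of the hip-python calling convention adopted throughout this change: every API call returns a tuple whose first element is the error code (hence the `hip_check` wrapper copied from the backend), `hipMalloc` returns a `DeviceArray` that can be `configure()`d with array metadata so `check_argument_list` can match it against a kernel signature, and copies go through plain `hipMemcpy` with an explicit `hipMemcpyKind`. The `hipFree` call and the round-trip assertion are illustrative additions, not part of this diff:

```python
# Sketch of the hip-python patterns this change standardizes on (assumes hip-python is installed).
import numpy as np
from hip import hip

def hip_check(call_result):
    # Same helper as added to kernel_tuner/backends/hip.py: first element is the error code.
    err = call_result[0]
    result = call_result[1:]
    if len(result) == 1:
        result = result[0]
    if isinstance(err, hip.hipError_t) and err != hip.hipError_t.hipSuccess:
        raise RuntimeError(str(err))
    return result

host = np.arange(8, dtype=np.float32)
out = np.empty_like(host)

# hipMalloc returns a DeviceArray; configure() attaches dtype/shape metadata,
# mirroring test_check_argument_list_device_array in test_file_utils.py.
dev_buf = hip_check(hip.hipMalloc(host.nbytes))
dev_buf.configure(typestr="float32", shape=host.shape, itemsize=host.itemsize)

# Round-trip copy: host -> device -> host, using explicit hipMemcpyKind values.
hip_check(hip.hipMemcpy(dev_buf, host, host.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice))
hip_check(hip.hipMemcpy(out, dev_buf, out.nbytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost))
assert np.array_equal(out, host)

hip_check(hip.hipFree(dev_buf))
```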