93 changes: 41 additions & 52 deletions kernel_tuner/backends/nvcuda.py
@@ -5,18 +5,19 @@

from kernel_tuner.backends.backend import GPUBackend
from kernel_tuner.observers.nvcuda import CudaRuntimeObserver
from kernel_tuner.util import SkippableFailure, cuda_error_check, to_valid_nvrtc_gpu_arch_cc
from kernel_tuner.util import SkippableFailure
from kernel_tuner.utils.nvcuda import cuda_error_check, to_valid_nvrtc_gpu_arch_cc

# embedded in try block to be able to generate documentation
# and run tests without cuda-python installed
try:
from cuda import cuda, cudart, nvrtc
from cuda.bindings import driver, runtime, nvrtc
except ImportError:
cuda = None
driver = None


class CudaFunctions(GPUBackend):
"""Class that groups the Cuda functions on maintains state about the device."""
"""Class that groups the Cuda functions and it maintains state about the device."""

def __init__(self, device=0, iterations=7, compiler_options=None, observers=None):
"""Instantiate CudaFunctions object used for interacting with the CUDA device.
@@ -38,34 +39,30 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None
"""
self.allocations = []
self.texrefs = []
if not cuda:
if not driver:
raise ImportError(
"cuda-python not installed, install using 'pip install cuda-python', or check https://kerneltuner.github.io/kernel_tuner/stable/install.html#cuda-and-pycuda."
)

# initialize and select device
err = cuda.cuInit(0)
err = driver.cuInit(0)
cuda_error_check(err)
err, self.device = cuda.cuDeviceGet(device)
err, self.device = driver.cuDeviceGet(device)
cuda_error_check(err)
err, self.context = cuda.cuDevicePrimaryCtxRetain(device)
err, self.context = driver.cuDevicePrimaryCtxRetain(device)
cuda_error_check(err)
if CudaFunctions.last_selected_device != device:
err = cuda.cuCtxSetCurrent(self.context)
err = driver.cuCtxSetCurrent(self.context)
cuda_error_check(err)
CudaFunctions.last_selected_device = device

# compute capabilities and device properties
err, major = cudart.cudaDeviceGetAttribute(
cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor, device
)
err, major = runtime.cudaDeviceGetAttribute(runtime.cudaDeviceAttr.cudaDevAttrComputeCapabilityMajor, device)
cuda_error_check(err)
err, minor = cudart.cudaDeviceGetAttribute(
cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor, device
)
err, minor = runtime.cudaDeviceGetAttribute(runtime.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor, device)
cuda_error_check(err)
err, self.max_threads = cudart.cudaDeviceGetAttribute(
cudart.cudaDeviceAttr.cudaDevAttrMaxThreadsPerBlock, device
err, self.max_threads = runtime.cudaDeviceGetAttribute(
runtime.cudaDeviceAttr.cudaDevAttrMaxThreadsPerBlock, device
)
cuda_error_check(err)
self.cc = f"{major}{minor}"
@@ -78,11 +75,11 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None
self.compiler_options_bytes.append(str(option).encode("UTF-8"))

# create a stream and events
err, self.stream = cuda.cuStreamCreate(0)
err, self.stream = driver.cuStreamCreate(0)
cuda_error_check(err)
err, self.start = cuda.cuEventCreate(0)
err, self.start = driver.cuEventCreate(0)
cuda_error_check(err)
err, self.end = cuda.cuEventCreate(0)
err, self.end = driver.cuEventCreate(0)
cuda_error_check(err)

# default dynamically allocated shared memory size, can be overwritten using smem_args
@@ -95,11 +92,11 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None
observer.register_device(self)

# collect environment information
err, device_properties = cudart.cudaGetDeviceProperties(device)
err, device_properties = runtime.cudaGetDeviceProperties(device)
cuda_error_check(err)
env = dict()
env["device_name"] = device_properties.name.decode()
env["cuda_version"] = cuda.CUDA_VERSION
env["cuda_version"] = driver.CUDA_VERSION
env["compute_capability"] = self.cc
env["iterations"] = self.iterations
env["compiler_options"] = self.compiler_options
@@ -109,8 +106,8 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None

def __del__(self):
for device_memory in self.allocations:
if isinstance(device_memory, cuda.CUdeviceptr):
err = cuda.cuMemFree(device_memory)
if isinstance(device_memory, driver.CUdeviceptr):
err = driver.cuMemFree(device_memory)
cuda_error_check(err)

def ready_argument_list(self, arguments):
@@ -128,7 +125,7 @@ def ready_argument_list(self, arguments):
for arg in arguments:
# if arg is a numpy array copy it to device
if isinstance(arg, np.ndarray):
err, device_memory = cuda.cuMemAlloc(arg.nbytes)
err, device_memory = driver.cuMemAlloc(arg.nbytes)
cuda_error_check(err)
self.allocations.append(device_memory)
gpu_args.append(device_memory)
@@ -164,38 +161,30 @@ def compile(self, kernel_instance):
if not any(["--std=" in opt for opt in self.compiler_options]):
self.compiler_options.append("--std=c++11")
if not any([b"--gpu-architecture=" in opt or b"-arch" in opt for opt in compiler_options]):
compiler_options.append(
f"--gpu-architecture=compute_{to_valid_nvrtc_gpu_arch_cc(self.cc)}".encode("UTF-8")
)
compiler_options.append(f"--gpu-architecture=compute_{to_valid_nvrtc_gpu_arch_cc(self.cc)}".encode("UTF-8"))
if not any(["--gpu-architecture=" in opt or "-arch" in opt for opt in self.compiler_options]):
self.compiler_options.append(f"--gpu-architecture=compute_{to_valid_nvrtc_gpu_arch_cc(self.cc)}")

err, program = nvrtc.nvrtcCreateProgram(
str.encode(kernel_string), b"CUDAProgram", 0, [], []
)
err, program = nvrtc.nvrtcCreateProgram(str.encode(kernel_string), b"CUDAProgram", 0, [], [])
try:
cuda_error_check(err)
err = nvrtc.nvrtcCompileProgram(
program, len(compiler_options), compiler_options
)
err = nvrtc.nvrtcCompileProgram(program, len(compiler_options), compiler_options)
cuda_error_check(err)
err, size = nvrtc.nvrtcGetPTXSize(program)
cuda_error_check(err)
buff = b" " * size
err = nvrtc.nvrtcGetPTX(program, buff)
cuda_error_check(err)
err, self.current_module = cuda.cuModuleLoadData(np.char.array(buff))
if err == cuda.CUresult.CUDA_ERROR_INVALID_PTX:
err, self.current_module = driver.cuModuleLoadData(np.char.array(buff))
if err == driver.CUresult.CUDA_ERROR_INVALID_PTX:
raise SkippableFailure("uses too much shared data")
else:
cuda_error_check(err)
err, self.func = cuda.cuModuleGetFunction(
self.current_module, str.encode(kernel_name)
)
err, self.func = driver.cuModuleGetFunction(self.current_module, str.encode(kernel_name))
cuda_error_check(err)

# get the number of registers per thread used in this kernel
num_regs = cuda.cuFuncGetAttribute(cuda.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NUM_REGS, self.func)
num_regs = driver.cuFuncGetAttribute(driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NUM_REGS, self.func)
assert num_regs[0] == 0, f"Retrieving number of registers per thread unsuccessful: code {num_regs[0]}"
self.num_regs = num_regs[1]

@@ -210,26 +199,26 @@ def compile(self, kernel_instance):

def start_event(self):
"""Records the event that marks the start of a measurement."""
err = cudart.cudaEventRecord(self.start, self.stream)
err = runtime.cudaEventRecord(self.start, self.stream)
cuda_error_check(err)

def stop_event(self):
"""Records the event that marks the end of a measurement."""
err = cudart.cudaEventRecord(self.end, self.stream)
err = runtime.cudaEventRecord(self.end, self.stream)
cuda_error_check(err)

def kernel_finished(self):
"""Returns True if the kernel has finished, False otherwise."""
err = cudart.cudaEventQuery(self.end)
if err[0] == cudart.cudaError_t.cudaSuccess:
err = runtime.cudaEventQuery(self.end)
if err[0] == runtime.cudaError_t.cudaSuccess:
return True
else:
return False

@staticmethod
def synchronize():
"""Halts execution until device has finished its tasks."""
err = cudart.cudaDeviceSynchronize()
err = runtime.cudaDeviceSynchronize()
cuda_error_check(err)

def copy_constant_memory_args(self, cmem_args):
@@ -243,9 +232,9 @@ def copy_constant_memory_args(self, cmem_args):
:type cmem_args: dict( string: numpy.ndarray, ... )
"""
for k, v in cmem_args.items():
err, symbol, _ = cuda.cuModuleGetGlobal(self.current_module, str.encode(k))
err, symbol, _ = driver.cuModuleGetGlobal(self.current_module, str.encode(k))
cuda_error_check(err)
err = cuda.cuMemcpyHtoD(symbol, v, v.nbytes)
err = driver.cuMemcpyHtoD(symbol, v, v.nbytes)
cuda_error_check(err)

def copy_shared_memory_args(self, smem_args):
@@ -284,12 +273,12 @@ def run_kernel(self, func, gpu_args, threads, grid, stream=None):
stream = self.stream
arg_types = list()
for arg in gpu_args:
if isinstance(arg, cuda.CUdeviceptr):
if isinstance(arg, driver.CUdeviceptr):
arg_types.append(None)
else:
arg_types.append(np.ctypeslib.as_ctypes_type(arg.dtype))
kernel_args = (tuple(gpu_args), tuple(arg_types))
err = cuda.cuLaunchKernel(
err = driver.cuLaunchKernel(
func,
grid[0],
grid[1],
@@ -318,7 +307,7 @@ def memset(allocation, value, size):
:type size: int

"""
err = cudart.cudaMemset(allocation, value, size)
err = runtime.cudaMemset(allocation, value, size)
cuda_error_check(err)

@staticmethod
@@ -331,7 +320,7 @@ def memcpy_dtoh(dest, src):
:param src: A GPU memory allocation unit
:type src: cuda.CUdeviceptr
"""
err = cuda.cuMemcpyDtoH(dest, src, dest.nbytes)
err = driver.cuMemcpyDtoH(dest, src, dest.nbytes)
cuda_error_check(err)

@staticmethod
@@ -344,7 +333,7 @@ def memcpy_htod(dest, src):
:param src: A numpy array in host memory to store the data
:type src: numpy.ndarray
"""
err = cuda.cuMemcpyHtoD(dest, src, src.nbytes)
err = driver.cuMemcpyHtoD(dest, src, src.nbytes)
cuda_error_check(err)

units = {"time": "ms"}
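
Taken together, these backend changes replace every `cuda.*`/`cudart.*` call with its `driver.*`/`runtime.*` counterpart from the `cuda.bindings` namespace. For reference, below is a minimal sketch of the compile path, following the same NVRTC-to-driver sequence as `compile()` above; it assumes cuda-python is installed and a context has already been retained as in `__init__`, and the kernel source, name, and default compute capability are illustrative.

```python
import numpy as np
from cuda.bindings import driver, nvrtc

from kernel_tuner.utils.nvcuda import cuda_error_check, to_valid_nvrtc_gpu_arch_cc


def compile_to_func(kernel_string: str, kernel_name: str, cc: str = "80"):
    """Compile CUDA C to PTX with NVRTC and load it through the driver API."""
    # pick an NVRTC-supported architecture for this compute capability
    options = [f"--gpu-architecture=compute_{to_valid_nvrtc_gpu_arch_cc(cc)}".encode("UTF-8")]
    err, program = nvrtc.nvrtcCreateProgram(str.encode(kernel_string), b"CUDAProgram", 0, [], [])
    cuda_error_check(err)
    err = nvrtc.nvrtcCompileProgram(program, len(options), options)
    cuda_error_check(err)
    # allocate a buffer and fetch the generated PTX
    err, size = nvrtc.nvrtcGetPTXSize(program)
    cuda_error_check(err)
    buff = b" " * size
    err = nvrtc.nvrtcGetPTX(program, buff)
    cuda_error_check(err)
    # load the PTX as a module and look up the kernel entry point
    err, module = driver.cuModuleLoadData(np.char.array(buff))
    cuda_error_check(err)
    err, func = driver.cuModuleGetFunction(module, str.encode(kernel_name))
    cuda_error_check(err)
    return func
```
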
6 changes: 3 additions & 3 deletions kernel_tuner/observers/nvcuda.py
@@ -1,12 +1,12 @@
import numpy as np

try:
from cuda import cudart
from cuda.bindings import runtime
except ImportError:
runtime = None

from kernel_tuner.observers.observer import BenchmarkObserver
from kernel_tuner.util import cuda_error_check
from kernel_tuner.utils.nvcuda import cuda_error_check


class CudaRuntimeObserver(BenchmarkObserver):
@@ -21,7 +21,7 @@ def __init__(self, dev):

def after_finish(self):
# Time is measured in milliseconds
err, time = cudart.cudaEventElapsedTime(self.start, self.end)
err, time = runtime.cudaEventElapsedTime(self.start, self.end)
cuda_error_check(err)
self.times.append(time)

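
The observer change is the same rename: elapsed time still comes from the runtime event API, in milliseconds, once the end event recorded by the backend has completed. A standalone sketch of that event-timing pattern with the new bindings (stream and event management is simplified here; in Kernel Tuner the backend owns these handles, and `launch` stands in for a kernel launch):

```python
from cuda.bindings import runtime

from kernel_tuner.utils.nvcuda import cuda_error_check


def time_gpu_work(launch):
    """Return elapsed GPU time in milliseconds around launch(stream)."""
    err, stream = runtime.cudaStreamCreate()
    cuda_error_check(err)
    err, start = runtime.cudaEventCreate()
    cuda_error_check(err)
    err, end = runtime.cudaEventCreate()
    cuda_error_check(err)
    err = runtime.cudaEventRecord(start, stream)
    cuda_error_check(err)
    launch(stream)  # enqueue GPU work on the same stream
    err = runtime.cudaEventRecord(end, stream)
    cuda_error_check(err)
    err = runtime.cudaEventSynchronize(end)  # block until the end event completes
    cuda_error_check(err)
    err, ms = runtime.cudaEventElapsedTime(start, end)
    cuda_error_check(err)
    return ms
```
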
28 changes: 0 additions & 28 deletions kernel_tuner/util.py
@@ -38,10 +38,6 @@
import cupy as cp
except ImportError:
cp = np
try:
from cuda import cuda, cudart, nvrtc
except ImportError:
cuda = None

from kernel_tuner.observers.nvml import NVMLObserver

@@ -642,14 +638,6 @@ def get_total_timings(results, env, overhead_time):
return env


NVRTC_VALID_CC = np.array(["50", "52", "53", "60", "61", "62", "70", "72", "75", "80", "87", "89", "90", "90a"])


def to_valid_nvrtc_gpu_arch_cc(compute_capability: str) -> str:
"""Returns a valid Compute Capability for NVRTC `--gpu-architecture=`, as per https://docs.nvidia.com/cuda/nvrtc/index.html#group__options."""
return max(NVRTC_VALID_CC[NVRTC_VALID_CC <= compute_capability], default="52")


def print_config(config, tuning_options, runner):
"""Print the configuration string with tunable parameters and benchmark results."""
print_config_output(tuning_options.tune_params, config, runner.quiet, tuning_options.metrics, runner.units)
@@ -1315,19 +1303,3 @@ def dump_cache(obj: str, tuning_options):
if isinstance(tuning_options.cache, dict) and tuning_options.cachefile:
with open(tuning_options.cachefile, "a") as cachefile:
cachefile.write(obj)


def cuda_error_check(error):
"""Checking the status of CUDA calls using the NVIDIA cuda-python backend."""
if isinstance(error, cuda.CUresult):
if error != cuda.CUresult.CUDA_SUCCESS:
_, name = cuda.cuGetErrorName(error)
raise RuntimeError(f"CUDA error: {name.decode()}")
elif isinstance(error, cudart.cudaError_t):
if error != cudart.cudaError_t.cudaSuccess:
_, name = cudart.getErrorName(error)
raise RuntimeError(f"CUDART error: {name.decode()}")
elif isinstance(error, nvrtc.nvrtcResult):
if error != nvrtc.nvrtcResult.NVRTC_SUCCESS:
_, desc = nvrtc.nvrtcGetErrorString(error)
raise RuntimeError(f"NVRTC error: {desc.decode()}")
63 changes: 63 additions & 0 deletions kernel_tuner/utils/nvcuda.py
@@ -0,0 +1,63 @@
"""Module for kernel tuner cuda-python utility functions."""

import numpy as np

try:
from cuda.bindings import driver, runtime, nvrtc
except ImportError:
driver = runtime = nvrtc = None

NVRTC_VALID_CC = np.array(
[
"50",
"52",
"53",
"60",
"61",
"62",
"70",
"72",
"75",
"80",
"87",
"89",
"90",
"90a",
"100",
"100f",
"100a",
"101",
"101f",
"101a",
"103",
"103f",
"103a",
"120",
"120f",
"120a",
"121",
"121f",
"121a",
]
)


def cuda_error_check(error):
"""Checking the status of CUDA calls using the NVIDIA cuda-python backend."""
if isinstance(error, driver.CUresult):
if error != driver.CUresult.CUDA_SUCCESS:
_, name = driver.cuGetErrorName(error)
raise RuntimeError(f"CUDA Driver error: {name.decode()}")
elif isinstance(error, runtime.cudaError_t):
if error != runtime.cudaError_t.cudaSuccess:
_, name = runtime.cudaGetErrorName(error)
raise RuntimeError(f"CUDA Runtime error: {name.decode()}")
elif isinstance(error, nvrtc.nvrtcResult):
if error != nvrtc.nvrtcResult.NVRTC_SUCCESS:
_, desc = nvrtc.nvrtcGetErrorString(error)
raise RuntimeError(f"NVRTC error: {desc.decode()}")


def to_valid_nvrtc_gpu_arch_cc(compute_capability: str) -> str:
"""Returns a valid Compute Capability for NVRTC `--gpu-architecture=`, as per https://docs.nvidia.com/cuda/nvrtc/index.html#group__options."""
return max(NVRTC_VALID_CC[NVRTC_VALID_CC <= compute_capability], default="75")
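
Two behavioral notes on the new module: the valid-architecture table gains entries for newer architectures ("100" through "121a"), and the fallback default moves from "52" to "75". Since `NVRTC_VALID_CC` holds strings, the comparison is lexicographic; a usage sketch under that assumption (inputs are illustrative):

```python
from kernel_tuner.utils.nvcuda import to_valid_nvrtc_gpu_arch_cc

# listed capabilities pass through unchanged
assert to_valid_nvrtc_gpu_arch_cc("90") == "90"

# "86" is absent from NVRTC_VALID_CC; "87" and above sort higher,
# so the helper snaps down to "80"
assert to_valid_nvrtc_gpu_arch_cc("86") == "80"

# the backend uses the result to build the NVRTC architecture flag
flag = f"--gpu-architecture=compute_{to_valid_nvrtc_gpu_arch_cc('86')}"
```
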