From 2291420685bc461af6d3ae90e4f67d0eea5d0b0b Mon Sep 17 00:00:00 2001 From: Maifee Ul Asad Date: Tue, 20 Jan 2026 20:07:46 +0600 Subject: [PATCH 01/10] feat: multi gpu support for matmul and reduce; --- starter_kit.cu | 313 ++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 309 insertions(+), 4 deletions(-) diff --git a/starter_kit.cu b/starter_kit.cu index 95d22b3..b2be6e5 100644 --- a/starter_kit.cu +++ b/starter_kit.cu @@ -10,8 +10,10 @@ #include // `std::printf` #include // `std::rand` #include // `std::memset` +#include // `std::function` #include // `std::runtime_error` #include // `std::thread::hardware_concurrency()` +#include // `std::vector` /* * Include the SIMD intrinsics for the target architecture. @@ -53,28 +55,31 @@ #include #include #include +#include #if defined(STARTER_KIT_VOLTA) #include #endif +namespace cg = cooperative_groups; #endif /* * If we are only testing the raw kernels, we don't need to link to PyBind. * That accelerates the build process and simplifies the configs. */ -#if !defined(STARTER_KIT_TEST) +#if !defined(STARTER_KIT_TEST) && !defined(NVCC_DEVICE_COMPILE) #include // `array_t` #include #include namespace py = pybind11; -#endif // !defined(STARTER_KIT_TEST) +#endif // !defined(STARTER_KIT_TEST) && !defined(NVCC_DEVICE_COMPILE) using cell_idx_t = std::uint32_t; enum class backend_t { openmp_k, cuda_k, + cuda_multigpu_k, }; /** @@ -309,12 +314,276 @@ __global__ void cuda_matmul_kernel( matrix_c[row * stride_c + col] = cell_c; } +/** + * @brief Multi-GPU reduction using cooperative groups. + * + * This function performs reduction across multiple GPUs by: + * 1. Detecting available CUDA devices + * 2. Partitioning the input array across devices + * 3. Computing partial reductions on each device + * 4. Aggregating results on the host + * + * @tparam scalar_type The data type of the array elements (e.g., float, double). + * + * @param data A pointer to the input array of elements of type `scalar_type`. + * @param length The number of elements in the input array. + * + * @return reduce_type The result of the reduction operation across all GPUs. + */ +template +reduce_type cuda_reduce_multigpu(scalar_type const* data, std::size_t length) noexcept(false) { + int num_devices = 0; + cudaError_t error = cudaGetDeviceCount(&num_devices); + if (error != cudaSuccess || num_devices == 0) + throw std::runtime_error("No CUDA devices available"); + + // If only one GPU, fall back to single GPU implementation + if (num_devices == 1) + return cuda_reduce(data, length); + + // Calculate work partition for each GPU + std::size_t chunk_size = (length + num_devices - 1) / num_devices; + std::vector> partial_results(num_devices); + std::vector streams(num_devices); + std::vector device_ptrs(num_devices); + std::vector*> result_ptrs(num_devices); + + // Launch reduction on each GPU + for (int dev = 0; dev < num_devices; ++dev) { + cudaSetDevice(dev); + cudaStreamCreate(&streams[dev]); + + std::size_t offset = dev * chunk_size; + std::size_t current_length = std::min(chunk_size, length - offset); + + if (current_length == 0) { + partial_results[dev] = 0; + continue; + } + + // Allocate device memory + cudaMalloc(&device_ptrs[dev], current_length * sizeof(scalar_type)); + cudaMalloc(&result_ptrs[dev], sizeof(reduce_type)); + + // Copy data to device asynchronously + cudaMemcpyAsync(device_ptrs[dev], data + offset, current_length * sizeof(scalar_type), cudaMemcpyHostToDevice, + streams[dev]); + + // Perform reduction using Thrust on each device + cudaStreamSynchronize(streams[dev]); + thrust::device_ptr dev_ptr = thrust::device_pointer_cast(device_ptrs[dev]); + reduce_type result = + thrust::reduce(thrust::cuda::par.on(streams[dev]), dev_ptr, dev_ptr + current_length, + reduce_type(0)); + + // Copy result back to host + partial_results[dev] = result; + } + + // Synchronize all devices + for (int dev = 0; dev < num_devices; ++dev) { + cudaSetDevice(dev); + cudaStreamSynchronize(streams[dev]); + } + + // Cleanup + for (int dev = 0; dev < num_devices; ++dev) { + cudaSetDevice(dev); + if (device_ptrs[dev]) + cudaFree(device_ptrs[dev]); + if (result_ptrs[dev]) + cudaFree(result_ptrs[dev]); + cudaStreamDestroy(streams[dev]); + } + + // Aggregate results on host + reduce_type final_result = 0; + for (int dev = 0; dev < num_devices; ++dev) + final_result += partial_results[dev]; + + return final_result; +} + +/** + * @brief Multi-GPU matrix multiplication kernel. + * + * This function performs matrix multiplication across multiple GPUs by: + * 1. Detecting available CUDA devices + * 2. Partitioning matrix A's rows across devices + * 3. Broadcasting matrix B to all devices + * 4. Computing partial results on each device + * 5. Gathering results back to host + * + * Uses peer-to-peer access when available for efficient data transfer. + * + * @tparam scalar_type The data type of the matrix elements (e.g., float, double). + * @tparam tile_size The size of the tiles used for shared memory. + * + * @param matrix_a Pointer to the input matrix A, stored in row-major order. + * @param matrix_b Pointer to the input matrix B, stored in row-major order. + * @param matrix_c Pointer to the output matrix C, stored in row-major order. + * @param num_rows_a The number of rows in matrix A. + * @param num_cols_b The number of columns in matrix B. + * @param num_cols_a The number of columns in matrix A, and the number of rows in matrix B. + * @param stride_a The stride (leading dimension) of matrix A. + * @param stride_b The stride (leading dimension) of matrix B. + * @param stride_c The stride (leading dimension) of matrix C. + */ +template +void cuda_matmul_multigpu(scalar_type const* matrix_a, scalar_type const* matrix_b, + matmul_type* matrix_c, cell_idx_t num_rows_a, cell_idx_t num_cols_b, + cell_idx_t num_cols_a, cell_idx_t stride_a, cell_idx_t stride_b, + cell_idx_t stride_c) noexcept(false) { + + int num_devices = 0; + cudaError_t error = cudaGetDeviceCount(&num_devices); + if (error != cudaSuccess || num_devices == 0) + throw std::runtime_error("No CUDA devices available"); + + // If only one GPU, we could fall back to single GPU, but let's use the multi-GPU path anyway + if (num_devices == 1) { + cudaSetDevice(0); + + // Allocate memory on device + size_t pitch_a, pitch_b; + scalar_type *dev_a = nullptr, *dev_b = nullptr; + matmul_type* dev_c = nullptr; + + cudaMallocPitch(&dev_a, &pitch_a, num_cols_a * sizeof(scalar_type), num_rows_a); + cudaMallocPitch(&dev_b, &pitch_b, num_cols_b * sizeof(scalar_type), num_cols_a); + cudaMalloc(&dev_c, num_rows_a * num_cols_b * sizeof(matmul_type)); + + cudaMemcpy2D(dev_a, pitch_a, matrix_a, stride_a * sizeof(scalar_type), num_cols_a * sizeof(scalar_type), + num_rows_a, cudaMemcpyHostToDevice); + cudaMemcpy2D(dev_b, pitch_b, matrix_b, stride_b * sizeof(scalar_type), num_cols_b * sizeof(scalar_type), + num_cols_a, cudaMemcpyHostToDevice); + + dim3 block(tile_size, tile_size); + dim3 grid((num_cols_b + tile_size - 1) / tile_size, (num_rows_a + tile_size - 1) / tile_size); + + cuda_matmul_kernel<<>>( + dev_a, dev_b, dev_c, num_rows_a, num_cols_b, num_cols_a, pitch_a / sizeof(scalar_type), + pitch_b / sizeof(scalar_type), num_cols_b); + + cudaMemcpy(matrix_c, dev_c, num_rows_a * num_cols_b * sizeof(matmul_type), + cudaMemcpyDeviceToHost); + + cudaFree(dev_a); + cudaFree(dev_b); + cudaFree(dev_c); + return; + } + + // Enable peer access between GPUs + for (int i = 0; i < num_devices; ++i) { + cudaSetDevice(i); + for (int j = 0; j < num_devices; ++j) { + if (i != j) { + int can_access = 0; + cudaDeviceCanAccessPeer(&can_access, i, j); + if (can_access) { + cudaDeviceEnablePeerAccess(j, 0); + } + } + } + } + + // Partition rows of matrix A across GPUs + cell_idx_t rows_per_device = (num_rows_a + num_devices - 1) / num_devices; + + std::vector streams(num_devices); + std::vector dev_a_ptrs(num_devices); + std::vector dev_b_ptrs(num_devices); + std::vector*> dev_c_ptrs(num_devices); + std::vector pitches_a(num_devices); + std::vector pitches_b(num_devices); + std::vector actual_rows(num_devices); + + // Allocate memory and copy data to each GPU + for (int dev = 0; dev < num_devices; ++dev) { + cudaSetDevice(dev); + cudaStreamCreate(&streams[dev]); + + cell_idx_t row_start = dev * rows_per_device; + cell_idx_t row_end = std::min(row_start + rows_per_device, num_rows_a); + actual_rows[dev] = row_end - row_start; + + if (actual_rows[dev] == 0) + continue; + + // Allocate pitched memory for matrix A partition + cudaMallocPitch(&dev_a_ptrs[dev], &pitches_a[dev], num_cols_a * sizeof(scalar_type), actual_rows[dev]); + + // Allocate pitched memory for matrix B (full matrix on each GPU) + cudaMallocPitch(&dev_b_ptrs[dev], &pitches_b[dev], num_cols_b * sizeof(scalar_type), num_cols_a); + + // Allocate memory for result matrix partition + cudaMalloc(&dev_c_ptrs[dev], actual_rows[dev] * num_cols_b * sizeof(matmul_type)); + + // Copy matrix A partition asynchronously + cudaMemcpy2DAsync(dev_a_ptrs[dev], pitches_a[dev], matrix_a + row_start * stride_a, + stride_a * sizeof(scalar_type), num_cols_a * sizeof(scalar_type), actual_rows[dev], + cudaMemcpyHostToDevice, streams[dev]); + + // Copy full matrix B asynchronously + cudaMemcpy2DAsync(dev_b_ptrs[dev], pitches_b[dev], matrix_b, stride_b * sizeof(scalar_type), + num_cols_b * sizeof(scalar_type), num_cols_a, cudaMemcpyHostToDevice, streams[dev]); + } + + // Launch kernels on each GPU + for (int dev = 0; dev < num_devices; ++dev) { + if (actual_rows[dev] == 0) + continue; + + cudaSetDevice(dev); + cudaStreamSynchronize(streams[dev]); + + dim3 block(tile_size, tile_size); + dim3 grid((num_cols_b + tile_size - 1) / tile_size, (actual_rows[dev] + tile_size - 1) / tile_size); + + cuda_matmul_kernel<<>>( + dev_a_ptrs[dev], dev_b_ptrs[dev], dev_c_ptrs[dev], actual_rows[dev], num_cols_b, num_cols_a, + pitches_a[dev] / sizeof(scalar_type), pitches_b[dev] / sizeof(scalar_type), num_cols_b); + } + + // Copy results back to host + for (int dev = 0; dev < num_devices; ++dev) { + if (actual_rows[dev] == 0) + continue; + + cudaSetDevice(dev); + cudaStreamSynchronize(streams[dev]); + + cell_idx_t row_start = dev * rows_per_device; + cudaMemcpyAsync(matrix_c + row_start * stride_c, dev_c_ptrs[dev], + actual_rows[dev] * num_cols_b * sizeof(matmul_type), cudaMemcpyDeviceToHost, + streams[dev]); + } + + // Cleanup + for (int dev = 0; dev < num_devices; ++dev) { + cudaSetDevice(dev); + cudaStreamSynchronize(streams[dev]); + + if (dev_a_ptrs[dev]) + cudaFree(dev_a_ptrs[dev]); + if (dev_b_ptrs[dev]) + cudaFree(dev_b_ptrs[dev]); + if (dev_c_ptrs[dev]) + cudaFree(dev_c_ptrs[dev]); + cudaStreamDestroy(streams[dev]); + } + + // Reset device + cudaSetDevice(0); +} + #endif // defined(__NVCC__) #pragma endregion CUDA #pragma region Python bindings -#if !defined(STARTER_KIT_TEST) +#if !defined(STARTER_KIT_TEST) && !defined(NVCC_DEVICE_COMPILE) /** * @brief Router function, that unpacks Python buffers into C++ pointers and calls the appropriate @@ -335,6 +604,12 @@ static py::object python_reduce_typed(py::buffer_info const& buf) noexcept(false result = cuda_reduce(ptr, buf.size); #else throw std::runtime_error("CUDA backend not available"); +#endif + } else if constexpr (backend_kind == backend_t::cuda_multigpu_k) { +#if defined(__NVCC__) + result = cuda_reduce_multigpu(ptr, buf.size); +#else + throw std::runtime_error("CUDA backend not available"); #endif } else { throw std::runtime_error("Unsupported backend"); @@ -534,6 +809,20 @@ static py::array python_matmul_typed(py::buffer_info const& buffer_a, py::buffer cudaFree(ptr_b_cuda); cudaFree(ptr_c_cuda); +#else + throw std::runtime_error("CUDA backend not available"); +#endif + } else if constexpr (backend_kind == backend_t::cuda_multigpu_k) { +#if defined(__NVCC__) + // Call multi-GPU matmul implementation + switch (tile_size) { + case 4: cuda_matmul_multigpu(ptr_a, ptr_b, ptr_c, num_rows_a, num_cols_b, num_cols_a, stride_a, stride_b, stride_c); break; + case 8: cuda_matmul_multigpu(ptr_a, ptr_b, ptr_c, num_rows_a, num_cols_b, num_cols_a, stride_a, stride_b, stride_c); break; + case 16: cuda_matmul_multigpu(ptr_a, ptr_b, ptr_c, num_rows_a, num_cols_b, num_cols_a, stride_a, stride_b, stride_c); break; + case 32: cuda_matmul_multigpu(ptr_a, ptr_b, ptr_c, num_rows_a, num_cols_b, num_cols_a, stride_a, stride_b, stride_c); break; + case 64: cuda_matmul_multigpu(ptr_a, ptr_b, ptr_c, num_rows_a, num_cols_b, num_cols_a, stride_a, stride_b, stride_c); break; + default: throw std::runtime_error("Unsupported tile size - choose from 4, 8, 16, 32, and 64"); + } #else throw std::runtime_error("CUDA backend not available"); #endif @@ -588,6 +877,18 @@ PYBIND11_MODULE(starter_kit, m) { #endif }); + m.def("get_cuda_device_count", []() -> int { +#if defined(__NVCC__) + int device_count = 0; + cudaError_t error = cudaGetDeviceCount(&device_count); + if (error != cudaSuccess) + return 0; + return device_count; +#else + return 0; +#endif + }); + m.def("log_cuda_devices", []() { #if defined(__NVCC__) int device_count; @@ -618,9 +919,13 @@ PYBIND11_MODULE(starter_kit, m) { m.def("reduce_cuda", &python_reduce); m.def("matmul_cuda", &python_matmul, py::arg("a"), py::arg("b"), py::kw_only(), py::arg("tile_size") = 16); + + m.def("reduce_cuda_multigpu", &python_reduce); + m.def("matmul_cuda_multigpu", &python_matmul, py::arg("a"), py::arg("b"), py::kw_only(), + py::arg("tile_size") = 16); } -#endif // !defined(STARTER_KIT_TEST) +#endif // !defined(STARTER_KIT_TEST) && !defined(NVCC_DEVICE_COMPILE) #pragma endregion Python bindings #if defined(STARTER_KIT_TEST) From e60c74929e463066122ac2127f905982675e57fd Mon Sep 17 00:00:00 2001 From: Maifee Ul Asad Date: Tue, 20 Jan 2026 20:08:27 +0600 Subject: [PATCH 02/10] deps: fixed deps for building; - as of 2026.01.20 --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7f3ad8f..63c494f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=42", "wheel", "pybind11", "numpy"] +requires = ["setuptools>=42", "wheel", "pybind11>=2.10,<2.13", "numpy"] build-backend = "setuptools.build_meta" [project] @@ -13,7 +13,7 @@ authors = [ { name = "Ash Vardanian", email = "1983160+ashvardanian@users.noreply.github.com" }, ] urls = { Homepage = "https://github.com/ashvardanian/PyBindToGPUs" } -dependencies = ["pybind11", "numpy", "numba"] +dependencies = ["pybind11>=2.10,<2.13", "numpy", "numba"] [project.optional-dependencies] cpu = ["pytest", "pytest-repeat", "pytest-benchmark"] From 8daafc4b91b618e4cd85bd27c88932808d714d0e Mon Sep 17 00:00:00 2001 From: Maifee Ul Asad Date: Tue, 20 Jan 2026 20:09:09 +0600 Subject: [PATCH 03/10] feat: benchmark for multi gpu for matmul and reduce; --- bench.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/bench.py b/bench.py index 9203501..40d162a 100644 --- a/bench.py +++ b/bench.py @@ -15,7 +15,16 @@ import pytest from starter_kit_baseline import reduce as reduce_baseline, matmul as matmul_baseline -from starter_kit import reduce_openmp, reduce_cuda, matmul_openmp, matmul_cuda, supports_cuda +from starter_kit import ( + reduce_openmp, + reduce_cuda, + reduce_cuda_multigpu, + matmul_openmp, + matmul_cuda, + matmul_cuda_multigpu, + supports_cuda, + get_cuda_device_count, +) # Build lists of (name, kernel_function) for reduction and matrix multiplication. REDUCTION_KERNELS = [ @@ -24,6 +33,8 @@ ] if supports_cuda(): REDUCTION_KERNELS.append(("cuda", reduce_cuda)) +if get_cuda_device_count() > 1: + REDUCTION_KERNELS.append(("cuda_multigpu", reduce_cuda_multigpu)) MATMUL_KERNELS = [ ("baseline", matmul_baseline), @@ -31,6 +42,8 @@ ] if supports_cuda(): MATMUL_KERNELS.append(("cuda", matmul_cuda)) +if get_cuda_device_count() > 1: + MATMUL_KERNELS.append(("cuda_multigpu", matmul_cuda_multigpu)) @pytest.mark.parametrize("dtype", [np.float32, np.int32]) From dacc0834c2d50c5fb617da7bfa5576d55c825fb4 Mon Sep 17 00:00:00 2001 From: Maifee Ul Asad Date: Tue, 20 Jan 2026 20:09:15 +0600 Subject: [PATCH 04/10] feat: test for multi gpu for matmul and reduce; --- test.py | 86 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 85 insertions(+), 1 deletion(-) diff --git a/test.py b/test.py index 2a27e75..63690f2 100644 --- a/test.py +++ b/test.py @@ -19,9 +19,20 @@ import numpy as np from starter_kit_baseline import matmul as matmul_baseline, reduce as reduce_baseline -from starter_kit import supports_cuda, reduce_openmp, reduce_cuda, matmul_openmp, matmul_cuda +from starter_kit import ( + supports_cuda, + get_cuda_device_count, + reduce_openmp, + reduce_cuda, + reduce_cuda_multigpu, + matmul_openmp, + matmul_cuda, + matmul_cuda_multigpu, +) backends = ["openmp", "cuda"] if supports_cuda() else ["openmp"] +multigpu_backends = ["cuda_multigpu"] if get_cuda_device_count() > 1 else [] +all_backends = backends + multigpu_backends @pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int64, np.uint64]) @pytest.mark.parametrize("backend", backends) @@ -56,6 +67,41 @@ def test_reduce(dtype, backend): np.testing.assert_allclose(result, expected_result, rtol=1e-2) +@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int64, np.uint64]) +@pytest.mark.parametrize("size", [1024, 8192, 65536]) +@pytest.mark.parametrize("backend", multigpu_backends) +def test_reduce_multigpu(dtype, size, backend): + """ + Test the multi-GPU reduction operation for different data types and sizes. + + This test verifies that multi-GPU reduction produces the same results + as the baseline implementation for various data sizes. + + Parameters: + dtype (np.dtype): The data type for the array elements. + size (int): The size of the array to reduce. + backend (str): The backend to test ('cuda_multigpu'). + + Raises: + AssertionError: If the results differ by more than the acceptable tolerance. + """ + if not multigpu_backends: + pytest.skip("Multi-GPU not available (requires 2+ GPUs)") + + # Generate random data + data = (np.random.rand(size) * 100).astype(dtype) + + # Get the expected result from the baseline implementation + expected_result = reduce_baseline(data) + + # Get the result from the multi-GPU implementation + if backend == "cuda_multigpu": + result = reduce_cuda_multigpu(data) + + # Compare the results + np.testing.assert_allclose(result, expected_result, rtol=1e-2) + + @pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int64, np.uint64]) @pytest.mark.parametrize("tile_size", [4, 8, 16, 32]) @pytest.mark.parametrize("backend", backends) @@ -93,5 +139,43 @@ def test_matmul(dtype, tile_size, backend): np.testing.assert_allclose(result, expected_result, rtol=1e-2) +@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.int64, np.uint64]) +@pytest.mark.parametrize("size", [128, 256, 512]) +@pytest.mark.parametrize("tile_size", [8, 16, 32]) +@pytest.mark.parametrize("backend", multigpu_backends) +def test_matmul_multigpu(dtype, size, tile_size, backend): + """ + Test the multi-GPU matrix multiplication operation for different sizes and tile sizes. + + This test verifies that multi-GPU matrix multiplication produces the same results + as the baseline implementation for various matrix sizes. + + Parameters: + dtype (np.dtype): The data type for the matrix elements. + size (int): The dimension of the square matrices. + tile_size (int): The tile size to be used for the multiplication kernel. + backend (str): The backend to test ('cuda_multigpu'). + + Raises: + AssertionError: If the output matrices differ by more than the acceptable tolerance. + """ + if not multigpu_backends: + pytest.skip("Multi-GPU not available (requires 2+ GPUs)") + + # Generate random matrices + a = (np.random.rand(size, size) * 100).astype(dtype) + b = (np.random.rand(size, size) * 100).astype(dtype) + + # Get the expected result from the baseline implementation + expected_result = matmul_baseline(a, b) + + # Get the result from the multi-GPU implementation + if backend == "cuda_multigpu": + result = matmul_cuda_multigpu(a, b, tile_size=tile_size) + + # Compare the results + np.testing.assert_allclose(result, expected_result, rtol=1e-2) + + if __name__ == "__main__": pytest.main() From a0a33a2236033349f8e80b1fbafc0ad3a8269f1a Mon Sep 17 00:00:00 2001 From: Maifee Ul Asad Date: Tue, 20 Jan 2026 20:23:00 +0600 Subject: [PATCH 05/10] setup: updated for dual t4 gpu; --- setup.py | 56 ++++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 42 insertions(+), 14 deletions(-) diff --git a/setup.py b/setup.py index 821a07a..a29913e 100644 --- a/setup.py +++ b/setup.py @@ -29,29 +29,44 @@ def build_extensions(self): super().build_extension(ext) def build_cuda_extension(self, ext): - # Compile CUDA source files + # Step 1: Compile CUDA kernels with NVCC (device code only) for source in ext.sources: if source.endswith(".cu"): self.compile_cuda(source) - # Compile non-CUDA source files - objects = [] + # Step 2: Compile host code (including PyBind11 bindings) with GCC + # Treat .cu file as C++ for host compilation + host_objects = [] for source in ext.sources: - if not source.endswith(".cu"): + if source.endswith(".cu"): obj = self.compiler.compile( [source], output_dir=self.build_temp, + extra_preargs=["-x", "c++"], extra_postargs=[ "-fPIC", "-std=c++17", "-fdiagnostics-color=always", + "-D__CUDACC__", # Tell the code that CUDA is available ], ) - objects.extend(obj) + host_objects.extend(obj) + else: + obj = self.compiler.compile( + [source], + output_dir=self.build_temp, + extra_postargs=[ + "-fPIC", + "-std=c++17", + "-fdiagnostics-color=always", + ], + ) + host_objects.extend(obj) - # Link all object files + # Link all object files (host + device) + all_objects = host_objects + [os.path.join(self.build_temp, "starter_kit.o")] self.compiler.link_shared_object( - objects + [os.path.join(self.build_temp, "starter_kit.o")], + all_objects, self.get_ext_fullpath(ext.name), libraries=ext.libraries, library_dirs=ext.library_dirs, @@ -103,16 +118,26 @@ def build_gcc_extension(self, ext): ) def compile_cuda(self, source): - # Compile CUDA source file using NVCC + # Compile CUDA device code only using NVCC ext = self.extensions[0] output_dir = self.build_temp os.makedirs(output_dir, exist_ok=True) - include_dirs = self.compiler.include_dirs + ext.include_dirs - include_dirs = " ".join(f"-I{dir}" for dir in include_dirs) + + # Only include CUDA-related headers for device compilation + cuda_include_dirs = [ + "/usr/local/cuda/include/", + "/usr/include/cuda/", + "cccl/cub/", + "cccl/libcudacxx/include", + "cccl/thrust/", + ] + # Filter to only existing directories + cuda_include_dirs = [d for d in cuda_include_dirs if os.path.exists(d)] + cuda_include_dirs_str = " ".join(f"-I{dir}" for dir in cuda_include_dirs) output_file = os.path.join(output_dir, "starter_kit.o") # Let's try inferring the compute capability from the GPU - arch_code = "90" + arch_code = "75" # Default to Turing (T4 GPU) try: import pycuda.driver as cuda import pycuda.autoinit @@ -120,13 +145,16 @@ def compile_cuda(self, source): device = cuda.Device(0) # Get the default device major, minor = device.compute_capability() arch_code = f"{major}{minor}" - except ImportError: + except (ImportError, Exception): pass + # Compile device code only with nvcc - add define to skip host-only code cmd = ( - f"nvcc -c {source} -o {output_file} -std=c++17 " + f"nvcc -dc {source} -o {output_file} -std=c++17 " f"-gencode=arch=compute_{arch_code},code=sm_{arch_code} " - f"-Xcompiler -fPIC {include_dirs} -O3 -g" + f"--expt-relaxed-constexpr --expt-extended-lambda " + f"-D__CUDACC_RELAXED_CONSTEXPR__ -DNVCC_DEVICE_COMPILE " + f"-Xcompiler -fPIC,-Wno-psabi {cuda_include_dirs_str} -O3 -g" ) if os.system(cmd) != 0: raise RuntimeError(f"nvcc compilation of {source} failed") From 90ce8e854596062642ac5a818de24cf59c1d7e5c Mon Sep 17 00:00:00 2001 From: Maifee Ul Asad Date: Tue, 20 Jan 2026 20:52:50 +0600 Subject: [PATCH 06/10] debug: added more logging for subprocess installation; --- setup.py | 40 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 36 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index a29913e..6056682 100644 --- a/setup.py +++ b/setup.py @@ -119,6 +119,7 @@ def build_gcc_extension(self, ext): def compile_cuda(self, source): # Compile CUDA device code only using NVCC + import subprocess ext = self.extensions[0] output_dir = self.build_temp os.makedirs(output_dir, exist_ok=True) @@ -133,6 +134,14 @@ def compile_cuda(self, source): ] # Filter to only existing directories cuda_include_dirs = [d for d in cuda_include_dirs if os.path.exists(d)] + print(f"\n{'='*70}") + print(f"CUDA Include Directories Found:") + for d in cuda_include_dirs: + print(f" - {d}") + if not cuda_include_dirs: + print(" ⚠ WARNING: No CUDA include directories found!") + print(f"{'='*70}\n") + cuda_include_dirs_str = " ".join(f"-I{dir}" for dir in cuda_include_dirs) output_file = os.path.join(output_dir, "starter_kit.o") @@ -145,8 +154,9 @@ def compile_cuda(self, source): device = cuda.Device(0) # Get the default device major, minor = device.compute_capability() arch_code = f"{major}{minor}" - except (ImportError, Exception): - pass + print(f"Detected GPU Compute Capability: {arch_code}") + except (ImportError, Exception) as e: + print(f"Could not detect GPU, using default arch {arch_code}: {e}") # Compile device code only with nvcc - add define to skip host-only code cmd = ( @@ -156,8 +166,30 @@ def compile_cuda(self, source): f"-D__CUDACC_RELAXED_CONSTEXPR__ -DNVCC_DEVICE_COMPILE " f"-Xcompiler -fPIC,-Wno-psabi {cuda_include_dirs_str} -O3 -g" ) - if os.system(cmd) != 0: - raise RuntimeError(f"nvcc compilation of {source} failed") + + print(f"\n{'='*70}") + print(f"NVCC Command:") + print(f"{cmd}") + print(f"{'='*70}\n") + + # Use subprocess to capture output + result = subprocess.run(cmd, shell=True, capture_output=True, text=True) + + if result.returncode != 0: + print(f"\n{'='*70}") + print(f"NVCC COMPILATION FAILED!") + print(f"{'='*70}") + print(f"STDOUT:\n{result.stdout}") + print(f"{'='*70}") + print(f"STDERR:\n{result.stderr}") + print(f"{'='*70}\n") + raise RuntimeError(f"nvcc compilation of {source} failed with exit code {result.returncode}") + else: + print(f"- NVCC compilation successful") + if result.stdout: + print(f"STDOUT: {result.stdout}") + if result.stderr: + print(f"STDERR: {result.stderr}") __version__ = open("VERSION", "r").read().strip() From ce59af7ce3f5563509931e17981f7b6cd715b11b Mon Sep 17 00:00:00 2001 From: Maifee Ul Asad Date: Tue, 20 Jan 2026 21:01:28 +0600 Subject: [PATCH 07/10] fix: include directories in compiler options for BuildExt --- setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.py b/setup.py index 6056682..6b79e52 100644 --- a/setup.py +++ b/setup.py @@ -42,6 +42,7 @@ def build_cuda_extension(self, ext): obj = self.compiler.compile( [source], output_dir=self.build_temp, + include_dirs=ext.include_dirs, extra_preargs=["-x", "c++"], extra_postargs=[ "-fPIC", @@ -55,6 +56,7 @@ def build_cuda_extension(self, ext): obj = self.compiler.compile( [source], output_dir=self.build_temp, + include_dirs=ext.include_dirs, extra_postargs=[ "-fPIC", "-std=c++17", From 96be14577be583393cdfacb8088a9f81b1e20a95 Mon Sep 17 00:00:00 2001 From: Maifee Ul Asad Date: Tue, 20 Jan 2026 21:07:56 +0600 Subject: [PATCH 08/10] fix: compile both device and host code with NVCC in BuildExt; --- setup.py | 41 ++++++++++++++--------------------------- 1 file changed, 14 insertions(+), 27 deletions(-) diff --git a/setup.py b/setup.py index 6b79e52..0b78055 100644 --- a/setup.py +++ b/setup.py @@ -29,30 +29,17 @@ def build_extensions(self): super().build_extension(ext) def build_cuda_extension(self, ext): - # Step 1: Compile CUDA kernels with NVCC (device code only) + # Compile everything with NVCC (both device and host code) + cuda_objects = [] + other_objects = [] + for source in ext.sources: if source.endswith(".cu"): + # Use NVCC to compile everything (device + host code) self.compile_cuda(source) - - # Step 2: Compile host code (including PyBind11 bindings) with GCC - # Treat .cu file as C++ for host compilation - host_objects = [] - for source in ext.sources: - if source.endswith(".cu"): - obj = self.compiler.compile( - [source], - output_dir=self.build_temp, - include_dirs=ext.include_dirs, - extra_preargs=["-x", "c++"], - extra_postargs=[ - "-fPIC", - "-std=c++17", - "-fdiagnostics-color=always", - "-D__CUDACC__", # Tell the code that CUDA is available - ], - ) - host_objects.extend(obj) + cuda_objects.append(os.path.join(self.build_temp, "starter_kit.o")) else: + # Compile non-CUDA files with GCC obj = self.compiler.compile( [source], output_dir=self.build_temp, @@ -63,10 +50,10 @@ def build_cuda_extension(self, ext): "-fdiagnostics-color=always", ], ) - host_objects.extend(obj) + other_objects.extend(obj) - # Link all object files (host + device) - all_objects = host_objects + [os.path.join(self.build_temp, "starter_kit.o")] + # Link all object files + all_objects = cuda_objects + other_objects self.compiler.link_shared_object( all_objects, self.get_ext_fullpath(ext.name), @@ -160,12 +147,12 @@ def compile_cuda(self, source): except (ImportError, Exception) as e: print(f"Could not detect GPU, using default arch {arch_code}: {e}") - # Compile device code only with nvcc - add define to skip host-only code + # Compile both device and host code with nvcc (no -dc flag) cmd = ( - f"nvcc -dc {source} -o {output_file} -std=c++17 " + f"nvcc -c {source} -o {output_file} -std=c++17 " f"-gencode=arch=compute_{arch_code},code=sm_{arch_code} " f"--expt-relaxed-constexpr --expt-extended-lambda " - f"-D__CUDACC_RELAXED_CONSTEXPR__ -DNVCC_DEVICE_COMPILE " + f"-D__CUDACC_RELAXED_CONSTEXPR__ " f"-Xcompiler -fPIC,-Wno-psabi {cuda_include_dirs_str} -O3 -g" ) @@ -227,7 +214,7 @@ def compile_cuda(self, source): ], # libraries=[python_lib_name.replace(".a", "")] - + (["cudart", "cuda", "cublas"] if enable_cuda else []) + + (["cudart", "cublas"] if enable_cuda else []) + (["gomp"] if enable_openmp else []), # extra_link_args=[f"-Wl,-rpath,{python_lib_dir}"] From 388228eadb383d4474246f1915f8763953991b43 Mon Sep 17 00:00:00 2001 From: Maifee Ul Asad Date: Tue, 20 Jan 2026 21:14:12 +0600 Subject: [PATCH 09/10] fix: include all relevant directories for NVCC compilation in BuildExt; --- setup.py | 21 +++++++-------------- 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/setup.py b/setup.py index 0b78055..d808a36 100644 --- a/setup.py +++ b/setup.py @@ -113,25 +113,18 @@ def compile_cuda(self, source): output_dir = self.build_temp os.makedirs(output_dir, exist_ok=True) - # Only include CUDA-related headers for device compilation - cuda_include_dirs = [ - "/usr/local/cuda/include/", - "/usr/include/cuda/", - "cccl/cub/", - "cccl/libcudacxx/include", - "cccl/thrust/", - ] + # Include all directories: CUDA headers, PyBind11, NumPy, Python, CCCL # Filter to only existing directories - cuda_include_dirs = [d for d in cuda_include_dirs if os.path.exists(d)] + include_dirs = [d for d in ext.include_dirs if os.path.exists(d)] print(f"\n{'='*70}") - print(f"CUDA Include Directories Found:") - for d in cuda_include_dirs: + print(f"Include Directories for NVCC:") + for d in include_dirs: print(f" - {d}") - if not cuda_include_dirs: - print(" ⚠ WARNING: No CUDA include directories found!") + if not include_dirs: + print(" * WARNING: No include directories found!") print(f"{'='*70}\n") - cuda_include_dirs_str = " ".join(f"-I{dir}" for dir in cuda_include_dirs) + cuda_include_dirs_str = " ".join(f"-I{dir}" for dir in include_dirs) output_file = os.path.join(output_dir, "starter_kit.o") # Let's try inferring the compute capability from the GPU From e25199985d16b0fdc16943ff63605612f65cbb4d Mon Sep 17 00:00:00 2001 From: Maifee Ul Asad Date: Tue, 20 Jan 2026 21:20:57 +0600 Subject: [PATCH 10/10] docs: update for multi gpu kernel; --- README.md | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 5cfc741..f8442f7 100644 --- a/README.md +++ b/README.md @@ -11,11 +11,13 @@ This project provides a pre-configured environment for such workflows...: 3. including [CCCL](https://github.com/NVIDIA/cccl) libraries, like Thrust, and CUB, to simplify the code. As an example, the repository implements, tests, and benchmarks only 2 operations - array accumulation and matrix multiplication. -The baseline Python + Numba implementations are placed in `starter_kit_baseline.py`, and the optimized CUDA nd OpenMP implementations are placed in `starter_kit.cu`. +The baseline Python + Numba implementations are placed in `starter_kit_baseline.py`, and the optimized CUDA and OpenMP implementations are placed in `starter_kit.cu`. If no CUDA-capable device is found, the file will be treated as a CPU-only C++ implementation. If VSCode is used, the `tasks.json` file is configured with debuggers for both CPU and GPU code, both in Python and C++. The `.clang-format` is configured with LLVM base style, adjusted for wider screens, allowing 120 characters per line. +**Multi-GPU Support**: The repository now includes multi-GPU implementations for both reduction and matrix multiplication operations, utilizing CUDA cooperative groups and efficient device partitioning strategies. + ## Installation I'd recommend forking the repository for your own projects, but you can also clone it directly: @@ -51,6 +53,35 @@ The project is designed to be as simple as possible, with the following workflow 2. Implement your baseline algorithm in `starter_kit_baseline.py`. 3. Implement your optimized algorithm in `starter_kit.cu`. +## Multi-GPU Features + +The starter kit now includes multi-GPU implementations: + +- **Multi-GPU Reduction**: Partitions data across available GPUs, performs parallel reductions, and aggregates results +- **Multi-GPU Matrix Multiplication**: Distributes matrix rows across GPUs using row-wise partitioning with peer-to-peer access when available +- **Automatic Detection**: Falls back to single-GPU or CPU when multiple GPUs are not available +- **Cooperative Groups**: Uses CUDA cooperative groups for efficient inter-block synchronization +- **Tested & Benchmarked**: Comprehensive test suite and performance benchmarks included + +Usage: +```python +import numpy as np +from starter_kit import reduce_cuda_multigpu, matmul_cuda_multigpu, get_cuda_device_count + +# Check available GPUs +num_gpus = get_cuda_device_count() +print(f"Available GPUs: {num_gpus}") + +# Multi-GPU reduction +data = np.random.rand(1_000_000).astype(np.float32) +result = reduce_cuda_multigpu(data) + +# Multi-GPU matrix multiplication +a = np.random.rand(1024, 1024).astype(np.float32) +b = np.random.rand(1024, 1024).astype(np.float32) +c = matmul_cuda_multigpu(a, b, tile_size=16) +``` + ## Reading Materials Beginner GPGPU: