From c13ceca883b8e450cddb7374a7969228cd89ac6c Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 8 Jan 2026 01:23:13 -0500 Subject: [PATCH 1/6] Try to do SVD truncation on GPU with _ind_intersect --- .../MatrixAlgebraKitAMDGPUExt.jl | 11 ++++++++--- .../MatrixAlgebraKitCUDAExt.jl | 11 ++++++++--- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index befb4e0b..595312eb 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -167,9 +167,14 @@ function MatrixAlgebraKit._mul_herm!(C::StridedROCMatrix{T}, A::StridedROCMatrix return C end -# TODO: intersect on GPU arrays is not working -MatrixAlgebraKit._ind_intersect(A::ROCVector{Int}, B::AbstractVector) = MatrixAlgebraKit._ind_intersect(collect(A), B) -MatrixAlgebraKit._ind_intersect(A::AbstractVector, B::ROCVector{Int}) = MatrixAlgebraKit._ind_intersect(A, collect(B)) +MatrixAlgebraKit._ind_intersect(A::ROCVector{Int}, B) = MatrixAlgebraKit._ind_intersect(B, A) +function MatrixAlgebraKit._ind_intersect(A::UnitRange, B::ROCVector{Int}) + sortedB = sort(B) + firstB = findfirst(≥(first(A)), B) + lastB = findlast(≤(last(A)), B) + # ONLY works if the indices in B are contiguous!!! + return B[firstB:lastB] +end MatrixAlgebraKit._ind_intersect(A::ROCVector{Int}, B::ROCVector{Int}) = MatrixAlgebraKit._ind_intersect(collect(A), collect(B)) end diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index 432f176a..a0225a04 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -191,9 +191,14 @@ function MatrixAlgebraKit._mul_herm!(C::StridedCuMatrix{T}, A::StridedCuMatrix{T return C end -# TODO: intersect on GPU arrays is not working -MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::AbstractVector) = MatrixAlgebraKit._ind_intersect(collect(A), B) -MatrixAlgebraKit._ind_intersect(A::AbstractVector, B::CuVector{Int}) = MatrixAlgebraKit._ind_intersect(A, collect(B)) +MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B) = MatrixAlgebraKit._ind_intersect(B, A) +function MatrixAlgebraKit._ind_intersect(A::UnitRange, B::CuVector{Int}) + sortedB = sort(B) + firstB = findfirst(≥(first(A)), B) + lastB = findlast(≤(last(A)), B) + # ONLY works if the indices in B are contiguous!!! + return B[firstB:lastB] +end MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::CuVector{Int}) = MatrixAlgebraKit._ind_intersect(collect(A), collect(B)) end From fa9fa80c9ef1ee42bb5ad5c945fa176af3d1617e Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 8 Jan 2026 06:32:36 -0500 Subject: [PATCH 2/6] Use fixer GPUArrays branch --- Project.toml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index cc694527..0a932246 100644 --- a/Project.toml +++ b/Project.toml @@ -30,6 +30,7 @@ ChainRulesTestUtils = "1" CUDA = "5" GenericLinearAlgebra = "0.3.19" GenericSchur = "0.5.6" +GPUArrays = "11.3.3" JET = "0.9, 0.10" LinearAlgebra = "1" Mooncake = "0.4.183" @@ -46,6 +47,7 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -56,5 +58,8 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", +test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "GPUArrays", "ChainRulesTestUtils", "Random", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Mooncake"] + +[sources] +GPUArrays = {url="https://github.com/JuliaGPU/GPUArrays.jl", rev="ksh/fix12"} From aa8e07dbe5042ed33e6d5eb9402b41a7ac3561b7 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 8 Jan 2026 12:16:13 +0100 Subject: [PATCH 3/6] ind_intersect via filter --- .../MatrixAlgebraKitAMDGPUExt.jl | 12 +++--------- .../MatrixAlgebraKitCUDAExt.jl | 12 +++--------- src/implementations/truncation.jl | 8 ++++++++ 3 files changed, 14 insertions(+), 18 deletions(-) diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index 595312eb..0ca43183 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -167,14 +167,8 @@ function MatrixAlgebraKit._mul_herm!(C::StridedROCMatrix{T}, A::StridedROCMatrix return C end -MatrixAlgebraKit._ind_intersect(A::ROCVector{Int}, B) = MatrixAlgebraKit._ind_intersect(B, A) -function MatrixAlgebraKit._ind_intersect(A::UnitRange, B::ROCVector{Int}) - sortedB = sort(B) - firstB = findfirst(≥(first(A)), B) - lastB = findlast(≤(last(A)), B) - # ONLY works if the indices in B are contiguous!!! - return B[firstB:lastB] -end -MatrixAlgebraKit._ind_intersect(A::ROCVector{Int}, B::ROCVector{Int}) = MatrixAlgebraKit._ind_intersect(collect(A), collect(B)) +# TODO: intersect doesn't work on GPU +MatrixAlgebraKit._ind_intersect(A::ROCVector{Int}, B::ROCVector{Int}) = + MatrixAlgebraKit._ind_intersect(collect(A), collect(B)) end diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index a0225a04..8bb09db1 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -191,14 +191,8 @@ function MatrixAlgebraKit._mul_herm!(C::StridedCuMatrix{T}, A::StridedCuMatrix{T return C end -MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B) = MatrixAlgebraKit._ind_intersect(B, A) -function MatrixAlgebraKit._ind_intersect(A::UnitRange, B::CuVector{Int}) - sortedB = sort(B) - firstB = findfirst(≥(first(A)), B) - lastB = findlast(≤(last(A)), B) - # ONLY works if the indices in B are contiguous!!! - return B[firstB:lastB] -end -MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::CuVector{Int}) = MatrixAlgebraKit._ind_intersect(collect(A), collect(B)) +# TODO: intersect doesn't work on GPU +MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::CuVector{Int}) = + MatrixAlgebraKit._ind_intersect(collect(A), collect(B)) end diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index be730bed..d18e0cba 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -126,6 +126,14 @@ function _ind_intersect(A::AbstractVector{Bool}, B::AbstractVector) end _ind_intersect(A::AbstractVector, B::AbstractVector{Bool}) = _ind_intersect(B, A) _ind_intersect(A::AbstractVector{Bool}, B::AbstractVector{Bool}) = A .& B + +# when one of the ind selections is a unitrange, filter is more efficient than intersect +# since we know both selections only contain unique entries +# (This is also more GPU-friendly!) +_ind_intersect(A::AbstractVector{Int}, B::AbstractUnitRange{Int}) = filter(in(B), A) +_ind_intersect(A::AbstractUnitRange{Int}, B::AbstractVector{Int}) = _ind_intersect(B, A) + +# when all else fails, call intersect _ind_intersect(A, B) = intersect(A, B) # Truncation error From a138765b4ebcbfdbe9cf73533d713e4f1042251c Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 8 Jan 2026 14:19:41 +0100 Subject: [PATCH 4/6] Fix ambiguity --- src/implementations/truncation.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index d18e0cba..768f7674 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -130,6 +130,7 @@ _ind_intersect(A::AbstractVector{Bool}, B::AbstractVector{Bool}) = A .& B # when one of the ind selections is a unitrange, filter is more efficient than intersect # since we know both selections only contain unique entries # (This is also more GPU-friendly!) +_ind_intersect(A::AbstractUnitRange{Int}, B::AbstractUnitRange{Int}) = filter(in(B), A) _ind_intersect(A::AbstractVector{Int}, B::AbstractUnitRange{Int}) = filter(in(B), A) _ind_intersect(A::AbstractUnitRange{Int}, B::AbstractVector{Int}) = _ind_intersect(B, A) From 076e943ae67bf4eb54fb7de48aa1a6ac62b04ab9 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Thu, 8 Jan 2026 14:36:37 +0100 Subject: [PATCH 5/6] Update src/implementations/truncation.jl Co-authored-by: Lukas Devos --- src/implementations/truncation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/implementations/truncation.jl b/src/implementations/truncation.jl index 768f7674..945e0772 100644 --- a/src/implementations/truncation.jl +++ b/src/implementations/truncation.jl @@ -130,7 +130,7 @@ _ind_intersect(A::AbstractVector{Bool}, B::AbstractVector{Bool}) = A .& B # when one of the ind selections is a unitrange, filter is more efficient than intersect # since we know both selections only contain unique entries # (This is also more GPU-friendly!) -_ind_intersect(A::AbstractUnitRange{Int}, B::AbstractUnitRange{Int}) = filter(in(B), A) +_ind_intersect(A::AbstractUnitRange{Int}, B::AbstractUnitRange{Int}) = intersect(A, B) _ind_intersect(A::AbstractVector{Int}, B::AbstractUnitRange{Int}) = filter(in(B), A) _ind_intersect(A::AbstractUnitRange{Int}, B::AbstractVector{Int}) = _ind_intersect(B, A) From e561986eecd7204817345c10a95663381532a68e Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 8 Jan 2026 17:00:24 +0100 Subject: [PATCH 6/6] Revert "Use fixer GPUArrays branch" This reverts commit fa9fa80c9ef1ee42bb5ad5c945fa176af3d1617e. --- Project.toml | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 0a932246..cc694527 100644 --- a/Project.toml +++ b/Project.toml @@ -30,7 +30,6 @@ ChainRulesTestUtils = "1" CUDA = "5" GenericLinearAlgebra = "0.3.19" GenericSchur = "0.5.6" -GPUArrays = "11.3.3" JET = "0.9, 0.10" LinearAlgebra = "1" Mooncake = "0.4.183" @@ -47,7 +46,6 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -58,8 +56,5 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "GPUArrays", +test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "Random", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "GenericLinearAlgebra", "GenericSchur", "Mooncake"] - -[sources] -GPUArrays = {url="https://github.com/JuliaGPU/GPUArrays.jl", rev="ksh/fix12"}