[oneMKL] Interface variants of trsm! and trmm!#479
[oneMKL] Interface variants of trsm! and trmm!#479amontoison wants to merge 1 commit intoJuliaGPU:masterfrom
Conversation
|
@kballeda It seems that the new variants of T = Float32
alpha = rand(T)
beta = rand(T)
@testset "trmm!" begin
A = triu(rand(T, m, m))
B = rand(T, m, n)
dA = oneArray(A)
dB = oneArray(B)
# Test without beta
C = alpha * A * B
oneMKL.trmm!('L', 'U', 'N', 'N', alpha, dA, dB)
# Move to host and compare
h_C = Array(dB)
@test C ≈ h_C
# Test with beta
C = rand(T, m, n)
dC = oneArray(C)
oneMKL.trmm!('L', 'U', 'N', 'N', alpha, beta, dA, dB, dC) # <-- fail
h_C = Array(dC)
D = alpha * A * B + beta * C
@test D ≈ h_C
end
@testset "left trsm!" begin
A = triu(rand(T, m, m))
B = rand(T, m, n)
dA = oneArray(A)
dB = oneArray(B)
# Test without beta
C = alpha * (A \ B)
dC = copy(dB)
oneMKL.trsm!('L', 'U', 'N', 'N', alpha, dA, dC)
@test C ≈ Array(dC)
# Test with beta
C = rand(T, m, n)
dC = oneArray(C)
oneMKL.trsm!('L', 'U', 'N', 'N', alpha, beta, dA, dB, dC) # <-- fail
h_C = Array(dC)
D = alpha * (A \ B) + beta * C
@test D ≈ h_C
end
@testset "right trsm!" begin
A = rand(T, m, m)
B = triu(rand(T, m, m))
dA = oneArray(A)
dB = oneArray(B)
# Test without beta
C = alpha * (A / B)
dC = copy(dA)
oneMKL.trsm!('R', 'U', 'N', 'N', alpha, dB, dC)
@test C ≈ Array(dC)
# Test with beta
C = rand(T, m, m)
dC = oneArray(C)
oneMKL.trsm!('R', 'U', 'N', 'N', alpha, beta, dA, dB, dC) # <-- fail
h_C = Array(dC)
D = alpha * (A / B) + beta * C
@test D ≈ h_C
end |
Thanks for reporting! Let me check this at my end with C reproducer. |
No, the build over at JuliaPackaging/Yggdrasil#9552 fails probably because of bugs in the Intel libraries. I was going to wait for a new version of the base toolkit before investigating. I see 2025.0.0 has been released now, so we should probably try again. |
|
I will check how to update the wrappers for oneMKL. [1/4] Building CXX object CMakeFiles/oneapi_support.dir/src/sycl.cpp.o
FAILED: CMakeFiles/oneapi_support.dir/src/sycl.cpp.o
/home/alexis/.julia/scratchspaces/8f75cd03-7ff8-4ecb-9b8f-daf728133b1b/conda/bin/icpx -Doneapi_support_EXPORTS -fsycl -isystem /home/alexis/.julia/scratchspaces/8f75cd03-7ff8-4ecb-9b8f-daf728133b1b/conda/include -isystem /home/alexis/.julia/artifacts/4acaedf5204fc60d0f11bb5d32020fa91c5b3d10/include -std=gnu++17 -fPIC -MD -MT CMakeFiles/oneapi_support.dir/src/sycl.cpp.o -MF CMakeFiles/oneapi_support.dir/src/sycl.cpp.o.d -o CMakeFiles/oneapi_support.dir/src/sycl.cpp.o -c /home/alexis/Bureau/git/oneAPI.jl/deps/src/sycl.cpp
/home/alexis/Bureau/git/oneAPI.jl/deps/src/sycl.cpp:9:72: error: unknown type name 'pi_native_handle'; did you mean 'ur_native_handle_t'?
9 | auto sycl_platform = sycl::ext::oneapi::level_zero::make_platform((pi_native_handle) driver);
| ^~~~~~~~~~~~~~~~
| ur_native_handle_t
/home/alexis/.julia/scratchspaces/8f75cd03-7ff8-4ecb-9b8f-daf728133b1b/conda/bin/compiler/../../include/sycl/ur_api.h:400:19: note: 'ur_native_handle_t' declared here
400 | typedef uintptr_t ur_native_handle_t;
| ^
/home/alexis/Bureau/git/oneAPI.jl/deps/src/sycl.cpp:9:57: error: no member named 'make_platform' in namespace 'sycl::ext::oneapi::level_zero'
9 | auto sycl_platform = sycl::ext::oneapi::level_zero::make_platform((pi_native_handle) driver);
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
/home/alexis/Bureau/git/oneAPI.jl/deps/src/sycl.cpp:22:68: error: unknown type name 'pi_native_handle'; did you mean 'ur_native_handle_t'?
22 | sycl::ext::oneapi::level_zero::make_device(platform->val, (pi_native_handle) device);
| ^~~~~~~~~~~~~~~~
| ur_native_handle_t
/home/alexis/.julia/scratchspaces/8f75cd03-7ff8-4ecb-9b8f-daf728133b1b/conda/bin/compiler/../../include/sycl/ur_api.h:400:19: note: 'ur_native_handle_t' declared here
400 | typedef uintptr_t ur_native_handle_t;
| ^
/home/alexis/Bureau/git/oneAPI.jl/deps/src/sycl.cpp:22:9: error: no member named 'make_device' in namespace 'sycl::ext::oneapi::level_zero'; did you mean 'sycl::ext::oneapi::level_zero::detail::make_device'?
22 | sycl::ext::oneapi::level_zero::make_device(platform->val, (pi_native_handle) device);
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
| sycl::ext::oneapi::level_zero::detail::make_device
/home/alexis/.julia/scratchspaces/8f75cd03-7ff8-4ecb-9b8f-daf728133b1b/conda/include/sycl/ext/oneapi/backend/level_zero.hpp:44:22: note: 'sycl::ext::oneapi::level_zero::detail::make_device' declared here
44 | __SYCL_EXPORT device make_device(const platform &Platform,
| ^
/home/alexis/Bureau/git/oneAPI.jl/deps/src/sycl.cpp:40:68: error: unknown type name 'pi_native_handle'; did you mean 'ur_native_handle_t'?
40 | sycl::ext::oneapi::level_zero::make_context(sycl_devices, (pi_native_handle) context, keep_ownership);
| ^~~~~~~~~~~~~~~~
| ur_native_handle_t
/home/alexis/.julia/scratchspaces/8f75cd03-7ff8-4ecb-9b8f-daf728133b1b/conda/bin/compiler/../../include/sycl/ur_api.h:400:19: note: 'ur_native_handle_t' declared here
400 | typedef uintptr_t ur_native_handle_t;
| ^
/home/alexis/Bureau/git/oneAPI.jl/deps/src/sycl.cpp:40:40: error: no member named 'make_context' in namespace 'sycl::ext::oneapi::level_zero'
40 | sycl::ext::oneapi::level_zero::make_context(sycl_devices, (pi_native_handle) context, keep_ownership);
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
/home/alexis/Bureau/git/oneAPI.jl/deps/src/sycl.cpp:54:93: error: unknown type name 'pi_native_handle'; did you mean 'ur_native_handle_t'?
54 | auto sycl_queue = sycl::ext::oneapi::level_zero::make_queue(context->val, device->val, (pi_native_handle) queue, false, keep_ownership, {});
| ^~~~~~~~~~~~~~~~
| ur_native_handle_t
/home/alexis/.julia/scratchspaces/8f75cd03-7ff8-4ecb-9b8f-daf728133b1b/conda/bin/compiler/../../include/sycl/ur_api.h:400:19: note: 'ur_native_handle_t' declared here
400 | typedef uintptr_t ur_native_handle_t;
| ^
/home/alexis/Bureau/git/oneAPI.jl/deps/src/sycl.cpp:54:54: error: no member named 'make_queue' in namespace 'sycl::ext::oneapi::level_zero'
54 | auto sycl_queue = sycl::ext::oneapi::level_zero::make_queue(context->val, device->val, (pi_native_handle) queue, false, keep_ownership, {});
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
/home/alexis/Bureau/git/oneAPI.jl/deps/src/sycl.cpp:66:79: error: unknown type name 'pi_native_handle'; did you mean 'ur_native_handle_t'?
66 | auto sycl_event = sycl::ext::oneapi::level_zero::make_event(context->val, (pi_native_handle) event, keep_ownership);
| ^~~~~~~~~~~~~~~~
| ur_native_handle_t
/home/alexis/.julia/scratchspaces/8f75cd03-7ff8-4ecb-9b8f-daf728133b1b/conda/bin/compiler/../../include/sycl/ur_api.h:400:19: note: 'ur_native_handle_t' declared here
400 | typedef uintptr_t ur_native_handle_t;
| ^
/home/alexis/Bureau/git/oneAPI.jl/deps/src/sycl.cpp:66:53: error: no member named 'make_event' in namespace 'sycl::ext::oneapi::level_zero'
66 | auto sycl_event = sycl::ext::oneapi::level_zero::make_event(context->val, (pi_native_handle) event, keep_ownership);
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
10 errors generated.
[2/4] Building CXX object CMakeFiles/oneapi_support.dir/src/onemkl.cpp.o
ninja: build stopped: subcommand failed.
ERROR: LoadError: failed process: Process(`/home/alexis/.julia/artifacts/7e62c00e1f15f21da3a56196bac84e23e6d629c3/bin/ninja -C /tmp/jl_NRO6kE install`, ProcessExited(1)) [1] |
dce5866 to
08e7dfa
Compare
|
Doesn't seem fixed on v2025.0.0 |
1a8482f to
653050a
Compare
653050a to
211e6f1
Compare
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/lib/mkl/linalg.jl b/lib/mkl/linalg.jl
index db16da6..22f2ba0 100644
--- a/lib/mkl/linalg.jl
+++ b/lib/mkl/linalg.jl
@@ -5,7 +5,7 @@ using LinearAlgebra: Transpose, Adjoint, AdjOrTrans,
Hermitian, Symmetric,
LowerTriangular, UnitLowerTriangular,
UpperTriangular, UnitUpperTriangular,
- UpperOrLowerTriangular, MulAddMul, wrap
+ UpperOrLowerTriangular, MulAddMul, wrap
#
# BLAS 1
@@ -163,13 +163,13 @@ function LinearAlgebra.generic_matmatmul!(C::oneStridedMatrix, tA, tB, A::oneStr
GPUArrays.generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
end
-const AdjOrTransOroneMatrix{T} = Union{oneStridedMatrix{T}, AdjOrTrans{<:T,<:oneStridedMatrix}}
+const AdjOrTransOroneMatrix{T} = Union{oneStridedMatrix{T}, AdjOrTrans{<:T, <:oneStridedMatrix}}
function LinearAlgebra.generic_trimatmul!(
- C::oneStridedMatrix{T}, uplocA, isunitcA,
- tfunA::Function, A::oneStridedMatrix{T},
- triB::UpperOrLowerTriangular{T, <: AdjOrTransOroneMatrix{T}},
-) where {T<:onemklFloat}
+ C::oneStridedMatrix{T}, uplocA, isunitcA,
+ tfunA::Function, A::oneStridedMatrix{T},
+ triB::UpperOrLowerTriangular{T, <:AdjOrTransOroneMatrix{T}},
+ ) where {T <: onemklFloat}
uplocB = LinearAlgebra.uplo_char(triB)
isunitcB = LinearAlgebra.isunit_char(triB)
B = parent(triB)
@@ -206,7 +206,7 @@ LinearAlgebra.generic_trimatmul!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::F
trmm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, B, C)
LinearAlgebra.generic_mattrimul!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} =
trmm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, A, C)
-LinearAlgebra.generic_trimatdiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::AbstractMatrix{T}) where {T<:onemklFloat} =
+LinearAlgebra.generic_trimatdiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::AbstractMatrix{T}) where {T <: onemklFloat} =
trsm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, B, C)
-LinearAlgebra.generic_mattridiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::AbstractMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} =
+LinearAlgebra.generic_mattridiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::AbstractMatrix{T}, B::oneStridedMatrix{T}) where {T <: onemklFloat} =
trsm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, A, C)
diff --git a/lib/mkl/wrappers_blas.jl b/lib/mkl/wrappers_blas.jl
index e01ffd2..a986e0b 100644
--- a/lib/mkl/wrappers_blas.jl
+++ b/lib/mkl/wrappers_blas.jl
@@ -1140,73 +1140,91 @@ function trsm(side::Char,
end
for (mmname_variant, smname_variant, elty) in
- ((:onemklDtrmm_variant, :onemklDtrsm_variant, :Float64),
- (:onemklStrmm_variant, :onemklStrsm_variant, :Float32),
- (:onemklZtrmm_variant, :onemklZtrsm_variant, :ComplexF64),
- (:onemklCtrmm_variant, :onemklCtrsm_variant, :ComplexF32))
+ (
+ (:onemklDtrmm_variant, :onemklDtrsm_variant, :Float64),
+ (:onemklStrmm_variant, :onemklStrsm_variant, :Float32),
+ (:onemklZtrmm_variant, :onemklZtrsm_variant, :ComplexF64),
+ (:onemklCtrmm_variant, :onemklCtrsm_variant, :ComplexF32),
+ )
@eval begin
- function trmm!(side::Char,
- uplo::Char,
- transa::Char,
- diag::Char,
- alpha::Number,
- beta::Number,
- A::oneStridedMatrix{$elty},
- B::oneStridedMatrix{$elty},
- C::oneStridedMatrix{$elty})
+ function trmm!(
+ side::Char,
+ uplo::Char,
+ transa::Char,
+ diag::Char,
+ alpha::Number,
+ beta::Number,
+ A::oneStridedMatrix{$elty},
+ B::oneStridedMatrix{$elty},
+ C::oneStridedMatrix{$elty}
+ )
m, n = size(B)
mA, nA = size(A)
- if mA != nA throw(DimensionMismatch("A must be square")) end
- if nA != (side == 'L' ? m : n) throw(DimensionMismatch("trmm!")) end
- lda = max(1,stride(A,2))
- ldb = max(1,stride(B,2))
- ldc = max(1,stride(C,2))
+ if mA != nA
+ throw(DimensionMismatch("A must be square"))
+ end
+ if nA != (side == 'L' ? m : n)
+ throw(DimensionMismatch("trmm!"))
+ end
+ lda = max(1, stride(A, 2))
+ ldb = max(1, stride(B, 2))
+ ldc = max(1, stride(C, 2))
queue = global_queue(context(A), device())
$mmname_variant(sycl_queue(queue), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb, beta, C, ldc)
- B
+ return B
end
- function trsm!(side::Char,
- uplo::Char,
- transa::Char,
- diag::Char,
- alpha::Number,
- beta::Number,
- A::oneStridedMatrix{$elty},
- B::oneStridedMatrix{$elty},
- C::oneStridedMatrix{$elty})
+ function trsm!(
+ side::Char,
+ uplo::Char,
+ transa::Char,
+ diag::Char,
+ alpha::Number,
+ beta::Number,
+ A::oneStridedMatrix{$elty},
+ B::oneStridedMatrix{$elty},
+ C::oneStridedMatrix{$elty}
+ )
m, n = size(B)
mA, nA = size(A)
- if mA != nA throw(DimensionMismatch("A must be square")) end
- if nA != (side == 'L' ? m : n) throw(DimensionMismatch("trsm!")) end
- lda = max(1,stride(A,2))
- ldb = max(1,stride(B,2))
- ldc = max(1,stride(C,2))
+ if mA != nA
+ throw(DimensionMismatch("A must be square"))
+ end
+ if nA != (side == 'L' ? m : n)
+ throw(DimensionMismatch("trsm!"))
+ end
+ lda = max(1, stride(A, 2))
+ ldb = max(1, stride(B, 2))
+ ldc = max(1, stride(C, 2))
queue = global_queue(context(A), device())
$smname_variant(sycl_queue(queue), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb, beta, C, ldc)
- B
+ return B
end
end
end
-function trmm!(side::Char,
- uplo::Char,
- transa::Char,
- diag::Char,
- alpha::Number,
- A::oneStridedMatrix{T},
- B::oneStridedMatrix{T},
- C::oneStridedMatrix{T}) where T
- trmm!(side, uplo, transa, diag, alpha, zero(T), A, B, C)
-end
-function trsm!(side::Char,
- uplo::Char,
- transa::Char,
- diag::Char,
- alpha::Number,
- A::oneStridedMatrix{T},
- B::oneStridedMatrix{T},
- C::oneStridedMatrix{T}) where T
- trsm!(side, uplo, transa, diag, alpha, zero(T), A, B, C)
+function trmm!(
+ side::Char,
+ uplo::Char,
+ transa::Char,
+ diag::Char,
+ alpha::Number,
+ A::oneStridedMatrix{T},
+ B::oneStridedMatrix{T},
+ C::oneStridedMatrix{T}
+ ) where {T}
+ return trmm!(side, uplo, transa, diag, alpha, zero(T), A, B, C)
+end
+function trsm!(
+ side::Char,
+ uplo::Char,
+ transa::Char,
+ diag::Char,
+ alpha::Number,
+ A::oneStridedMatrix{T},
+ B::oneStridedMatrix{T},
+ C::oneStridedMatrix{T}
+ ) where {T}
+ return trsm!(side, uplo, transa, diag, alpha, zero(T), A, B, C)
end
## hemm
diff --git a/test/onemkl.jl b/test/onemkl.jl
index bbafaed..33f1ba7 100644
--- a/test/onemkl.jl
+++ b/test/onemkl.jl
@@ -662,13 +662,13 @@ end
h_C = Array(dB)
@test C ≈ h_C
- C = rand(T,m,n)
- dC = oneArray(C)
- beta = zero(T) # rand(T)
- oneMKL.trmm!('L','U','N','N',alpha,beta,dA,dB,dC)
- h_C = Array(dC)
- D = alpha*A*B + beta*C
- @test D ≈ h_C
+ C = rand(T, m, n)
+ dC = oneArray(C)
+ beta = zero(T) # rand(T)
+ oneMKL.trmm!('L', 'U', 'N', 'N', alpha, beta, dA, dB, dC)
+ h_C = Array(dC)
+ D = alpha * A * B + beta * C
+ @test D ≈ h_C
end
@testset "trmm" begin
@@ -693,13 +693,13 @@ end
oneMKL.trsm!('L','U','N','N',alpha,dA,dC)
@test C ≈ Array(dC)
- C = rand(T,m,n)
- dC = oneArray(C)
- beta = rand(T)
- oneMKL.trsm!('L','U','N','N',alpha,beta,dA,dB,dC)
- h_C = Array(dC)
- D = alpha*(A\B) + beta*C
- @test D ≈ h_C
+ C = rand(T, m, n)
+ dC = oneArray(C)
+ beta = rand(T)
+ oneMKL.trsm!('L', 'U', 'N', 'N', alpha, beta, dA, dB, dC)
+ h_C = Array(dC)
+ D = alpha * (A \ B) + beta * C
+ @test D ≈ h_C
end
@testset "left trsm" begin
@@ -742,13 +742,13 @@ end
oneMKL.trsm!('R','U','N','N',alpha,dB,dC)
@test C ≈ Array(dC)
- C = rand(T,m,m)
- dC = oneArray(C)
- beta = rand(T)
- oneMKL.trsm!('R','U','N','N',alpha,beta,dA,dB,dC)
- h_C = Array(dC)
- D = alpha*(A/B) + beta*C
- @test D ≈ h_C
+ C = rand(T, m, m)
+ dC = oneArray(C)
+ beta = rand(T)
+ oneMKL.trsm!('R', 'U', 'N', 'N', alpha, beta, dA, dB, dC)
+ h_C = Array(dC)
+ D = alpha * (A / B) + beta * C
+ @test D ≈ h_C
end
@testset "right trsm" begin |
|
CI failures still seem related. |
Interface variants of
trsm!andtrmm!with additional arguments.