From b80ff89bbcb4875bb21c27f0b0271341a4545342 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 26 Jan 2026 14:24:31 +0100 Subject: [PATCH 1/5] Try enabling the complex tests for Mooncake --- test/mooncake.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/test/mooncake.jl b/test/mooncake.jl index 729a74d..1790ba4 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -12,8 +12,8 @@ is_primitive = false ( (Float64, Float64), (Float32, Float64), - #(ComplexF64, ComplexF64), - #(Float64, ComplexF64), + (ComplexF64, ComplexF64), + (Float64, ComplexF64), ) T = promote_type(T₁, T₂) atol = max(precision(T₁), precision(T₂)) @@ -37,8 +37,8 @@ end ( (Float64, Float64), (Float32, Float64), - #(ComplexF64, ComplexF64), - #(Float64, ComplexF64), + (ComplexF64, ComplexF64), + (Float64, ComplexF64), ) T = promote_type(T₁, T₂) atol = max(precision(T₁), precision(T₂)) @@ -60,8 +60,8 @@ end ( (Float64, Float64), (Float32, Float64), - #(ComplexF64, ComplexF64), - #(Float64, ComplexF64), + (ComplexF64, ComplexF64), + (Float64, ComplexF64), ) T = promote_type(T₁, T₂) atol = max(precision(T₁), precision(T₂)) From 95132cb87a878c18de22d01323df2366e0ac94dc Mon Sep 17 00:00:00 2001 From: Katharine Hyatt Date: Mon, 26 Jan 2026 15:09:22 +0100 Subject: [PATCH 2/5] StridedNative and StridedBlas don't work with complex tangents --- test/mooncake.jl | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/test/mooncake.jl b/test/mooncake.jl index 1790ba4..e64753b 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -28,8 +28,10 @@ is_primitive = false Mooncake.TestUtils.test_rule(rng, tensortrace!, C, A, p, q, false, α, β; atol, rtol, mode, is_primitive) Mooncake.TestUtils.test_rule(rng, tensortrace!, C, A, p, q, true, α, β; atol, rtol, mode, is_primitive) - Mooncake.TestUtils.test_rule(rng, tensortrace!, C, A, p, q, true, α, β, StridedBLAS(); atol, rtol, mode, is_primitive) - Mooncake.TestUtils.test_rule(rng, tensortrace!, C, A, p, q, false, α, β, StridedNative(); atol, rtol, mode, is_primitive) + if T <: Real + Mooncake.TestUtils.test_rule(rng, tensortrace!, C, A, p, q, true, α, β, StridedBLAS(); atol, rtol, mode, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensortrace!, C, A, p, q, false, α, β, StridedNative(); atol, rtol, mode, is_primitive) + end # tangents don't work nicely here end end @@ -51,8 +53,10 @@ end Mooncake.TestUtils.test_rule(rng, tensoradd!, C, A, pA, false, α, β; atol, rtol, mode, is_primitive) Mooncake.TestUtils.test_rule(rng, tensoradd!, C, A, pA, true, α, β; atol, rtol, mode, is_primitive) - Mooncake.TestUtils.test_rule(rng, tensoradd!, C, A, pA, false, α, β, StridedBLAS(); atol, rtol, mode, is_primitive) - Mooncake.TestUtils.test_rule(rng, tensoradd!, C, A, pA, true, α, β, StridedNative(); atol, rtol, mode, is_primitive) + if T <: Real + Mooncake.TestUtils.test_rule(rng, tensoradd!, C, A, pA, false, α, β, StridedBLAS(); atol, rtol, mode, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensoradd!, C, A, pA, true, α, β, StridedNative(); atol, rtol, mode, is_primitive) + end # tangents don't work nicely here end end @@ -81,16 +85,18 @@ end Mooncake.TestUtils.test_rule(rng, tensorcontract!, C, A, pA, false, B, pB, true, pAB, α, β; atol, rtol, mode, is_primitive) Mooncake.TestUtils.test_rule(rng, tensorcontract!, C, A, pA, true, B, pB, true, pAB, α, β; atol, rtol, mode, is_primitive) - Mooncake.TestUtils.test_rule( - rng, - tensorcontract!, C, A, pA, false, B, pB, false, pAB, α, β, StridedBLAS(); - atol, rtol, mode, is_primitive - ) - Mooncake.TestUtils.test_rule( - rng, - tensorcontract!, C, A, pA, true, B, pB, false, pAB, α, β, StridedNative(); - atol, rtol, mode, is_primitive - ) + if T <: Real + Mooncake.TestUtils.test_rule( + rng, + tensorcontract!, C, A, pA, false, B, pB, false, pAB, α, β, StridedBLAS(); + atol, rtol, mode, is_primitive + ) + Mooncake.TestUtils.test_rule( + rng, + tensorcontract!, C, A, pA, true, B, pB, false, pAB, α, β, StridedNative(); + atol, rtol, mode, is_primitive + ) + end # tangents don't work nicely here end end From 5ea75c38632c3d18545326f5b0634507493cb36d Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 26 Jan 2026 16:37:44 -0500 Subject: [PATCH 3/5] bump minimal Mooncake version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index deccc00..8ff7ca9 100644 --- a/Project.toml +++ b/Project.toml @@ -41,7 +41,7 @@ DynamicPolynomials = "0.5, 0.6" LRUCache = "1" LinearAlgebra = "1.6" Logging = "1.6" -Mooncake = "0.4.195" +Mooncake = "0.5" PackageExtensionCompat = "1" PrecompileTools = "1.1" Preferences = "1.4" From 740743bb224e7de7cb15a6e9e65d87475393b007 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 26 Jan 2026 17:04:17 -0500 Subject: [PATCH 4/5] remove use of `Mooncake._rdata` --- .../TensorOperationsMooncakeExt.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl b/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl index 8bff2f8..8e302ca 100644 --- a/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl +++ b/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl @@ -60,8 +60,8 @@ function Mooncake.rrule!!( function contract_pb(::NoRData) scale!(C, C_cache, One()) dC, dA, dB, Δα, Δβ = TensorOperations.tensorcontract_pullback!(dC, dA, dB, C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...) - dα = isnothing(Δα) ? NoRData() : Mooncake._rdata(Δα) - dβ = isnothing(Δβ) ? NoRData() : Mooncake._rdata(Δβ) + dα = isnothing(Δα) ? NoRData() : Δα + dβ = isnothing(Δβ) ? NoRData() : Δβ return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)... end return C_dC, contract_pb @@ -90,8 +90,8 @@ function Mooncake.rrule!!( function add_pb(::NoRData) scale!(C, C_cache, One()) dC, dA, Δα, Δβ = TensorOperations.tensoradd_pullback!(dC, dA, C, A, pA, conjA, α, β, ba...) - dα = isnothing(Δα) ? NoRData() : Mooncake._rdata(Δα) - dβ = isnothing(Δβ) ? NoRData() : Mooncake._rdata(Δβ) + dα = isnothing(Δα) ? NoRData() : Δα + dβ = isnothing(Δβ) ? NoRData() : Δβ return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)... end return C_dC, add_pb @@ -122,8 +122,8 @@ function Mooncake.rrule!!( function trace_pb(::NoRData) scale!(C, C_cache, One()) dC, dA, Δα, Δβ = TensorOperations.tensortrace_pullback!(dC, dA, C, A, p, q, conjA, α, β, ba...) - dα = isnothing(Δα) ? NoRData() : Mooncake._rdata(Δα) - dβ = isnothing(Δβ) ? NoRData() : Mooncake._rdata(Δβ) + dα = isnothing(Δα) ? NoRData() : Δα + dβ = isnothing(Δβ) ? NoRData() : Δβ return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)... end return C_dC, trace_pb From a9c60021bd25c8ac651d02ff0a93a088a55ad52b Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 26 Jan 2026 17:04:46 -0500 Subject: [PATCH 5/5] enable tests --- test/mooncake.jl | 34 ++++++++++++++-------------------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/test/mooncake.jl b/test/mooncake.jl index e64753b..1790ba4 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -28,10 +28,8 @@ is_primitive = false Mooncake.TestUtils.test_rule(rng, tensortrace!, C, A, p, q, false, α, β; atol, rtol, mode, is_primitive) Mooncake.TestUtils.test_rule(rng, tensortrace!, C, A, p, q, true, α, β; atol, rtol, mode, is_primitive) - if T <: Real - Mooncake.TestUtils.test_rule(rng, tensortrace!, C, A, p, q, true, α, β, StridedBLAS(); atol, rtol, mode, is_primitive) - Mooncake.TestUtils.test_rule(rng, tensortrace!, C, A, p, q, false, α, β, StridedNative(); atol, rtol, mode, is_primitive) - end # tangents don't work nicely here + Mooncake.TestUtils.test_rule(rng, tensortrace!, C, A, p, q, true, α, β, StridedBLAS(); atol, rtol, mode, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensortrace!, C, A, p, q, false, α, β, StridedNative(); atol, rtol, mode, is_primitive) end end @@ -53,10 +51,8 @@ end Mooncake.TestUtils.test_rule(rng, tensoradd!, C, A, pA, false, α, β; atol, rtol, mode, is_primitive) Mooncake.TestUtils.test_rule(rng, tensoradd!, C, A, pA, true, α, β; atol, rtol, mode, is_primitive) - if T <: Real - Mooncake.TestUtils.test_rule(rng, tensoradd!, C, A, pA, false, α, β, StridedBLAS(); atol, rtol, mode, is_primitive) - Mooncake.TestUtils.test_rule(rng, tensoradd!, C, A, pA, true, α, β, StridedNative(); atol, rtol, mode, is_primitive) - end # tangents don't work nicely here + Mooncake.TestUtils.test_rule(rng, tensoradd!, C, A, pA, false, α, β, StridedBLAS(); atol, rtol, mode, is_primitive) + Mooncake.TestUtils.test_rule(rng, tensoradd!, C, A, pA, true, α, β, StridedNative(); atol, rtol, mode, is_primitive) end end @@ -85,18 +81,16 @@ end Mooncake.TestUtils.test_rule(rng, tensorcontract!, C, A, pA, false, B, pB, true, pAB, α, β; atol, rtol, mode, is_primitive) Mooncake.TestUtils.test_rule(rng, tensorcontract!, C, A, pA, true, B, pB, true, pAB, α, β; atol, rtol, mode, is_primitive) - if T <: Real - Mooncake.TestUtils.test_rule( - rng, - tensorcontract!, C, A, pA, false, B, pB, false, pAB, α, β, StridedBLAS(); - atol, rtol, mode, is_primitive - ) - Mooncake.TestUtils.test_rule( - rng, - tensorcontract!, C, A, pA, true, B, pB, false, pAB, α, β, StridedNative(); - atol, rtol, mode, is_primitive - ) - end # tangents don't work nicely here + Mooncake.TestUtils.test_rule( + rng, + tensorcontract!, C, A, pA, false, B, pB, false, pAB, α, β, StridedBLAS(); + atol, rtol, mode, is_primitive + ) + Mooncake.TestUtils.test_rule( + rng, + tensorcontract!, C, A, pA, true, B, pB, false, pAB, α, β, StridedNative(); + atol, rtol, mode, is_primitive + ) end end