diff --git a/Project.toml b/Project.toml index deccc00..8ff7ca9 100644 --- a/Project.toml +++ b/Project.toml @@ -41,7 +41,7 @@ DynamicPolynomials = "0.5, 0.6" LRUCache = "1" LinearAlgebra = "1.6" Logging = "1.6" -Mooncake = "0.4.195" +Mooncake = "0.5" PackageExtensionCompat = "1" PrecompileTools = "1.1" Preferences = "1.4" diff --git a/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl b/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl index 8bff2f8..8e302ca 100644 --- a/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl +++ b/ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl @@ -60,8 +60,8 @@ function Mooncake.rrule!!( function contract_pb(::NoRData) scale!(C, C_cache, One()) dC, dA, dB, Δα, Δβ = TensorOperations.tensorcontract_pullback!(dC, dA, dB, C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...) - dα = isnothing(Δα) ? NoRData() : Mooncake._rdata(Δα) - dβ = isnothing(Δβ) ? NoRData() : Mooncake._rdata(Δβ) + dα = isnothing(Δα) ? NoRData() : Δα + dβ = isnothing(Δβ) ? NoRData() : Δβ return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)... end return C_dC, contract_pb @@ -90,8 +90,8 @@ function Mooncake.rrule!!( function add_pb(::NoRData) scale!(C, C_cache, One()) dC, dA, Δα, Δβ = TensorOperations.tensoradd_pullback!(dC, dA, C, A, pA, conjA, α, β, ba...) - dα = isnothing(Δα) ? NoRData() : Mooncake._rdata(Δα) - dβ = isnothing(Δβ) ? NoRData() : Mooncake._rdata(Δβ) + dα = isnothing(Δα) ? NoRData() : Δα + dβ = isnothing(Δβ) ? NoRData() : Δβ return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)... end return C_dC, add_pb @@ -122,8 +122,8 @@ function Mooncake.rrule!!( function trace_pb(::NoRData) scale!(C, C_cache, One()) dC, dA, Δα, Δβ = TensorOperations.tensortrace_pullback!(dC, dA, C, A, p, q, conjA, α, β, ba...) - dα = isnothing(Δα) ? NoRData() : Mooncake._rdata(Δα) - dβ = isnothing(Δβ) ? NoRData() : Mooncake._rdata(Δβ) + dα = isnothing(Δα) ? NoRData() : Δα + dβ = isnothing(Δβ) ? NoRData() : Δβ return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)... end return C_dC, trace_pb diff --git a/test/mooncake.jl b/test/mooncake.jl index 729a74d..1790ba4 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -12,8 +12,8 @@ is_primitive = false ( (Float64, Float64), (Float32, Float64), - #(ComplexF64, ComplexF64), - #(Float64, ComplexF64), + (ComplexF64, ComplexF64), + (Float64, ComplexF64), ) T = promote_type(T₁, T₂) atol = max(precision(T₁), precision(T₂)) @@ -37,8 +37,8 @@ end ( (Float64, Float64), (Float32, Float64), - #(ComplexF64, ComplexF64), - #(Float64, ComplexF64), + (ComplexF64, ComplexF64), + (Float64, ComplexF64), ) T = promote_type(T₁, T₂) atol = max(precision(T₁), precision(T₂)) @@ -60,8 +60,8 @@ end ( (Float64, Float64), (Float32, Float64), - #(ComplexF64, ComplexF64), - #(Float64, ComplexF64), + (ComplexF64, ComplexF64), + (Float64, ComplexF64), ) T = promote_type(T₁, T₂) atol = max(precision(T₁), precision(T₂))