diff --git a/Project.toml b/Project.toml
index ff890157a..934bdb6ed 100644
--- a/Project.toml
+++ b/Project.toml
@@ -23,12 +23,14 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
 cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 
 [extensions]
 TensorKitAdaptExt = "Adapt"
 TensorKitCUDAExt = ["CUDA", "cuTENSOR"]
 TensorKitChainRulesCoreExt = "ChainRulesCore"
 TensorKitFiniteDifferencesExt = "FiniteDifferences"
+TensorKitMooncakeExt = "Mooncake"
 
 [compat]
 Adapt = "4"
@@ -43,6 +45,7 @@ GPUArrays = "11.3.1"
 LRUCache = "1.0.2"
 LinearAlgebra = "1"
 MatrixAlgebraKit = "0.6.2"
+Mooncake = "0.4.183"
 OhMyThreads = "0.8.0"
 Printf = "1"
 Random = "1"
@@ -70,6 +73,7 @@ Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
 GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
 TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
@@ -78,4 +82,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
 
 [targets]
-test = ["ArgParse", "Adapt", "Aqua", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote"]
+test = ["ArgParse", "Adapt", "Aqua", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake"]
diff --git a/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl
new file mode 100644
index 000000000..b35c73f4c
--- /dev/null
+++ b/ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl
@@ -0,0 +1,18 @@
+# Package extension with Mooncake reverse-mode rules for TensorKit tensor maps.
+module TensorKitMooncakeExt
+
+using Mooncake
+using Mooncake: @zero_derivative, DefaultCtx, ReverseMode, NoRData, CoDual, arrayify, primal
+using TensorKit
+using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize
+import TensorOperations as TO
+using VectorInterface: One, Zero
+using TupleTools
+
+
+include("utility.jl")
+include("tangent.jl")
+include("linalg.jl")
+include("tensoroperations.jl")
+
+end
diff --git a/ext/TensorKitMooncakeExt/linalg.jl b/ext/TensorKitMooncakeExt/linalg.jl
new file mode 100644
index 000000000..56533d227
--- /dev/null
+++ b/ext/TensorKitMooncakeExt/linalg.jl
@@ -0,0 +1,19 @@
+Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(norm), AbstractTensorMap, Real}
+
+# Reverse rule for `norm(t, p)`; only the case `p == 2` is implemented.
+function Mooncake.rrule!!(::CoDual{typeof(norm)}, tΔt::CoDual{<:AbstractTensorMap}, pdp::CoDual{<:Real})
+    t, Δt = arrayify(tΔt)
+    p = primal(pdp)
+    p == 2 || error("currently only implemented for p = 2")
+    n = norm(t, p)
+    # `p` is not differentiated, but a floating-point `p` still requires a zero rdata
+    # of matching type; `zero_rdata` gives `NoRData()` for integer `p`.
+    Δp = Mooncake.zero_rdata(p)
+    function norm_pullback(Δn)
+        # real part of Δn over the norm; `hypot` keeps the quotient finite when n == 0
+        x = (Δn' + Δn) / 2 / hypot(n, eps(one(n)))
+        add!(Δt, t, x)
+        return NoRData(), NoRData(), Δp
+    end
+    return CoDual(n, Mooncake.NoFData()), norm_pullback
+end
diff --git a/ext/TensorKitMooncakeExt/tangent.jl b/ext/TensorKitMooncakeExt/tangent.jl
new file mode 100644
index 000000000..761e626f0
--- /dev/null
+++ b/ext/TensorKitMooncakeExt/tangent.jl
@@ -0,0 +1,9 @@
+# Unpack a `TensorMap` CoDual into its primal and a tensor of the same type that is
+# constructed from the tangent's `data.data` field and the primal's space.
+function Mooncake.arrayify(A_dA::CoDual{<:TensorMap})
+    A = Mooncake.primal(A_dA)
+    dA_fw = Mooncake.tangent(A_dA)
+    data = dA_fw.data.data
+    dA = typeof(A)(data, A.space)
+    return A, dA
+end
diff --git a/ext/TensorKitMooncakeExt/tensoroperations.jl b/ext/TensorKitMooncakeExt/tensoroperations.jl
new file mode 100644
index 000000000..d663a3281
--- /dev/null
+++ b/ext/TensorKitMooncakeExt/tensoroperations.jl
@@ -0,0 +1,151 @@
+# Register in-place contraction of tensor maps as a Mooncake primitive so that the
+# hand-written reverse rule below is used.
+Mooncake.@is_primitive(
+    DefaultCtx,
+    ReverseMode,
+    Tuple{
+        typeof(TO.tensorcontract!),
+        AbstractTensorMap,
+        AbstractTensorMap, Index2Tuple, Bool,
+        AbstractTensorMap, Index2Tuple, Bool,
+        Index2Tuple,
+        Number, Number,
+        Vararg{Any},
+    }
+)
+
+# Reverse rule for `TO.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...)`.
+# The forward pass overwrites `C`, so a copy is kept and restored by the pullback.
+function Mooncake.rrule!!(
+        ::CoDual{typeof(TO.tensorcontract!)},
+        C_ΔC::CoDual{<:AbstractTensorMap},
+        A_ΔA::CoDual{<:AbstractTensorMap}, pA_ΔpA::CoDual{<:Index2Tuple}, conjA_ΔconjA::CoDual{Bool},
+        B_ΔB::CoDual{<:AbstractTensorMap}, pB_ΔpB::CoDual{<:Index2Tuple}, conjB_ΔconjB::CoDual{Bool},
+        pAB_ΔpAB::CoDual{<:Index2Tuple},
+        α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number},
+        ba_Δba::CoDual...,
+    )
+    # prepare arguments
+    (C, ΔC), (A, ΔA), (B, ΔB) = arrayify.((C_ΔC, A_ΔA, B_ΔB))
+    pA, pB, pAB = primal.((pA_ΔpA, pB_ΔpB, pAB_ΔpAB))
+    conjA, conjB = primal.((conjA_ΔconjA, conjB_ΔconjB))
+    α, β = primal.((α_Δα, β_Δβ))
+    ba = primal.(ba_Δba)
+
+    # primal call
+    C_cache = copy(C)
+    TO.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...)
+
+    function tensorcontract_pullback(::NoRData)
+        copy!(C, C_cache)
+
+        # Every contribution below reads the cotangent of the *output* `C`, so `ΔC`
+        # may only be rescaled by `conj(β)` after all of them have been computed.
+        ΔAr = tensorcontract_pullback_ΔA!(
+            ΔA, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba...
+        )
+        ΔBr = tensorcontract_pullback_ΔB!(
+            ΔB, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba...
+        )
+        Δαr = tensorcontract_pullback_Δα(
+            ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba...
+        )
+        Δβr = tensorcontract_pullback_Δβ(ΔC, C, β)
+        ΔCr = tensorcontract_pullback_ΔC!(ΔC, β)
+
+        return NoRData(), ΔCr,
+            ΔAr, NoRData(), NoRData(),
+            ΔBr, NoRData(), NoRData(),
+            NoRData(),
+            Δαr, Δβr,
+            map(ba_ -> NoRData(), ba)...
+    end
+
+    return C_ΔC, tensorcontract_pullback
+end
+
+# Turn the cotangent of the output `C` into that of the input `C` (in place).
+tensorcontract_pullback_ΔC!(ΔC, β) = (scale!(ΔC, conj(β)); NoRData())
+
+# Accumulate the contribution of the contraction into the cotangent `ΔA`.
+function tensorcontract_pullback_ΔA!(
+        ΔA, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba...
+    )
+    ipAB = invperm(linearize(pAB))
+    pΔC = _repartition(ipAB, TO.numout(pA))
+    ipA = _repartition(invperm(linearize(pA)), A)
+    conjΔC = conjA
+    conjB′ = conjA ? conjB : !conjB
+
+    tB = twist(
+        B,
+        TupleTools.vcat(
+            filter(x -> !isdual(space(B, x)), pB[1]),
+            filter(x -> isdual(space(B, x)), pB[2])
+        ); copy = false
+    )
+
+    # `One()` rather than `Zero()`: Mooncake pullbacks must increment the cotangent
+    # that is already stored in `ΔA` instead of overwriting it.
+    TO.tensorcontract!(
+        ΔA,
+        ΔC, pΔC, conjΔC,
+        tB, reverse(pB), conjB′,
+        ipA,
+        conjA ? α : conj(α), One(),
+        ba...
+    )
+
+    return NoRData()
+end
+
+# Accumulate the contribution of the contraction into the cotangent `ΔB`.
+function tensorcontract_pullback_ΔB!(
+        ΔB, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba...
+    )
+    ipAB = invperm(linearize(pAB))
+    pΔC = _repartition(ipAB, TO.numout(pA))
+    ipB = _repartition(invperm(linearize(pB)), B)
+    conjΔC = conjB
+    conjA′ = conjB ? conjA : !conjA
+
+    tA = twist(
+        A,
+        TupleTools.vcat(
+            filter(x -> isdual(space(A, x)), pA[1]),
+            filter(x -> !isdual(space(A, x)), pA[2])
+        ); copy = false
+    )
+
+    # `One()`: increment the cotangent already stored in `ΔB` rather than overwrite it.
+    TO.tensorcontract!(
+        ΔB,
+        tA, reverse(pA), conjA′,
+        ΔC, pΔC, conjΔC,
+        ipB,
+        conjB ? α : conj(α), One(), ba...
+    )
+
+    return NoRData()
+end
+
+# Cotangent of `α`: inner product of the contraction evaluated with unit scale and `ΔC`.
+function tensorcontract_pullback_Δα(
+        ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba...
+    )
+    Tdα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α)))
+    Tdα === NoRData && return NoRData()
+
+    AB = TO.tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
+    Δα = inner(AB, ΔC)
+    return Mooncake._rdata(Δα)
+end
+
+# Cotangent of `β`: inner product of the (restored) input `C` with `ΔC`.
+function tensorcontract_pullback_Δβ(ΔC, C, β)
+    Tdβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β)))
+    Tdβ === NoRData && return NoRData()
+
+    Δβ = inner(C, ΔC)
+    return Mooncake._rdata(Δβ)
+end
diff --git a/ext/TensorKitMooncakeExt/utility.jl b/ext/TensorKitMooncakeExt/utility.jl
new file mode 100644
index 000000000..ca2c79b54
--- /dev/null
+++ b/ext/TensorKitMooncakeExt/utility.jl
@@ -0,0 +1,33 @@
+# Whether a scalar argument carries a tangent: true for general numbers, false for
+# integers and the `One`/`Zero` singletons.
+_needs_tangent(x) = _needs_tangent(typeof(x))
+_needs_tangent(::Type{<:Number}) = true
+_needs_tangent(::Type{<:Integer}) = false
+_needs_tangent(::Type{<:Union{One, Zero}}) = false
+
+# IndexTuple utility
+# ------------------
+trivtuple(N) = ntuple(identity, N)
+
+# Split the (linearized) indices `p` into a pair of tuples holding the first `N₁` and
+# the remaining entries. `N₁` may also be given through the type parameter of an
+# `Index2Tuple` or as `TensorKit.numout` of a tensor map.
+Base.@constprop :aggressive function _repartition(p::IndexTuple, N₁::Int)
+    length(p) >= N₁ ||
+        throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)"))
+    return TupleTools.getindices(p, trivtuple(N₁)),
+        TupleTools.getindices(p, trivtuple(length(p) - N₁) .+ N₁)
+end
+Base.@constprop :aggressive function _repartition(p::Index2Tuple, N₁::Int)
+    return _repartition(linearize(p), N₁)
+end
+function _repartition(p::Union{IndexTuple, Index2Tuple}, ::Index2Tuple{N₁}) where {N₁}
+    return _repartition(p, N₁)
+end
+function _repartition(p::Union{IndexTuple, Index2Tuple}, t::AbstractTensorMap)
+    return _repartition(p, TensorKit.numout(t))
+end
+
+# Ignore derivatives
+# ------------------
+@zero_derivative DefaultCtx Tuple{typeof(TensorKit.fusionblockstructure), Any}
diff --git a/test/autodiff/ad.jl b/test/autodiff/chainrules.jl
similarity index 100%
rename from test/autodiff/ad.jl
rename to test/autodiff/chainrules.jl
diff --git a/test/autodiff/mooncake.jl b/test/autodiff/mooncake.jl
new file mode 100644
index 000000000..1cd74fa27
--- /dev/null
+++ b/test/autodiff/mooncake.jl
@@ -0,0 +1,119 @@
+using Test, TestExtras
+using TensorKit
+using TensorOperations
+using Mooncake
+using Random
+
+mode = Mooncake.ReverseMode
+rng = Random.default_rng()
+is_primitive = false
+
+# Random permutation of 1:N, split into a pair of tuples of lengths k and N - k.
+function randindextuple(N::Int, k::Int = rand(0:N))
+    @assert 0 ≤ k ≤ N
+    _p = randperm(N)
+    return (tuple(_p[1:k]...), tuple(_p[(k + 1):end]...))
+end
+
+const _repartition = @static if isdefined(Base, :get_extension)
+    Base.get_extension(TensorKit, :TensorKitMooncakeExt)._repartition
+else
+    TensorKit.TensorKitMooncakeExt._repartition
+end
+
+spacelist = (
+    (ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
+    (
+        Vect[Z2Irrep](0 => 1, 1 => 1),
+        Vect[Z2Irrep](0 => 1, 1 => 2)',
+        Vect[Z2Irrep](0 => 2, 1 => 2)',
+        Vect[Z2Irrep](0 => 2, 1 => 3),
+        Vect[Z2Irrep](0 => 2, 1 => 2),
+    ),
+    (
+        Vect[FermionParity](0 => 1, 1 => 1),
+        Vect[FermionParity](0 => 1, 1 => 2)',
+        Vect[FermionParity](0 => 2, 1 => 1)',
+        Vect[FermionParity](0 => 2, 1 => 3),
+        Vect[FermionParity](0 => 2, 1 => 2),
+    ),
+    (
+        Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1),
+        Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1),
+        Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)',
+        Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2),
+        Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)',
+    ),
+    (
+        Vect[SU2Irrep](0 => 2, 1 // 2 => 1),
+        Vect[SU2Irrep](0 => 1, 1 => 1),
+        Vect[SU2Irrep](1 // 2 => 1, 1 => 1)',
+        Vect[SU2Irrep](1 // 2 => 2),
+        Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)',
+    ),
+    (
+        Vect[FibonacciAnyon](:I => 2, :τ => 1),
+        Vect[FibonacciAnyon](:I => 1, :τ => 2)',
+        Vect[FibonacciAnyon](:I => 2, :τ => 2)',
+        Vect[FibonacciAnyon](:I => 2, :τ => 3),
+        Vect[FibonacciAnyon](:I => 2, :τ => 2),
+    ),
+)
+
+for V in spacelist
+    I = sectortype(eltype(V))
+    Istr = TensorKit.type_repr(I)
+
+    symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding
+    println("---------------------------------------")
+    println("Mooncake with symmetry: $Istr")
+    println("---------------------------------------")
+    eltypes = (Float64,) # no complex support yet
+    symmetricbraiding && @timedtestset "TensorOperations with scalartype $T" for T in eltypes
+        # `precision` is not defined in this file and would fall back to `Base.precision`,
+        # i.e. the number of significand bits (53 for Float64): a vacuous tolerance
+        atol = rtol = cbrt(eps(real(T)))
+
+        @timedtestset "tensorcontract!" begin
+            for _ in 1:5
+                d = 0
+                local V1, V2, V3
+                # retry a couple times to make sure there are at least some nonzero elements
+                for _ in 1:10
+                    k1 = rand(0:3)
+                    k2 = rand(0:2)
+                    k3 = rand(0:2)
+                    V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init = one(V[1]))
+                    V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init = one(V[1]))
+                    V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init = one(V[1]))
+                    d = min(dim(V1 ← V2), dim(V1' ← V2), dim(V2 ← V3), dim(V2' ← V3))
+                    d > 0 && break
+                end
+                ipA = randindextuple(length(V1) + length(V2))
+                pA = _repartition(invperm(linearize(ipA)), length(V1))
+                ipB = randindextuple(length(V2) + length(V3))
+                pB = _repartition(invperm(linearize(ipB)), length(V2))
+                pAB = randindextuple(length(V1) + length(V3))
+
+                α = randn(T)
+                β = randn(T)
+                V2_conj = prod(conj, V2; init = one(V[1]))
+
+                for conjA in (false, true), conjB in (false, true)
+                    A = randn(T, permute(V1 ← (conjA ? V2_conj : V2), ipA))
+                    B = randn(T, permute((conjB ? V2_conj : V2) ← V3, ipB))
+                    C = randn!(
+                        TensorOperations.tensoralloc_contract(
+                            T, A, pA, conjA, B, pB, conjB, pAB, Val(false)
+                        )
+                    )
+                    Mooncake.TestUtils.test_rule(
+                        rng, tensorcontract!, C, A, pA, conjA, B, pB, conjB, pAB, α, β;
+                        atol, rtol, mode, is_primitive
+                    )
+
+                end
+            end
+        end
+    end
+end