Skip to content

Commit

Permalink
Merge 23f7d49 into fdee031
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing authored Sep 19, 2024
2 parents fdee031 + 23f7d49 commit 6b997fc
Show file tree
Hide file tree
Showing 29 changed files with 1,576 additions and 1,637 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DeltaArrays = "10b0fc19-5ccc-4427-889b-d75dd6306188"
EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
KeywordDispatch = "5888135b-5456-5c80-a1b6-c91ef8180460"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377"
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
Expand All @@ -25,7 +27,6 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
ITensorNetworks = "2919e153-833c-4bdc-8836-1ea460a35fc7"
ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
Expand Down
10 changes: 8 additions & 2 deletions ext/TenetAdaptExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@ Adapt.adapt_structure(to, x::Tensor) = Tensor(adapt(to, parent(x)), inds(x))
Adapt.adapt_structure(to, x::TensorNetwork) = TensorNetwork(adapt.(Ref(to), tensors(x)))

Adapt.adapt_structure(to, x::Quantum) = Quantum(adapt(to, TensorNetwork(x)), x.sites)
Adapt.adapt_structure(to, x::Product) = Product(adapt(to, Quantum(x)))
Adapt.adapt_structure(to, x::Chain) = Chain(adapt(to, Quantum(x)), boundary(x))
Adapt.adapt_structure(to, x::Ansatz) = Ansatz(adapt(to, Quantum(x)), lattice(x))

Adapt.adapt_structure(to, x::Product) = Product(adapt(to, Ansatz(x)))
Adapt.adapt_structure(to, x::Dense) = Dense(adapt(to, Ansatz(x)))
Adapt.adapt_structure(to, x::MPS) = MPS(adapt(to, Ansatz(x)), form(x))
Adapt.adapt_structure(to, x::MPO) = MPO(adapt(to, Ansatz(x)), form(x))
Adapt.adapt_structure(to, x::PEPS) = PEPS(adapt(to, Ansatz(x)), form(x))
Adapt.adapt_structure(to, x::PEPO) = PEPO(adapt(to, Ansatz(x)), form(x))

end
39 changes: 25 additions & 14 deletions ext/TenetChainRulesCoreExt/frules.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,39 @@
using Tenet: AbstractTensorNetwork, AbstractQuantum

# `Tensor` constructor
ChainRulesCore.frule((_, Δ, _), T::Type{<:Tensor}, data, inds) = T(data, inds), T(Δ, inds)

# `TensorNetwork` constructor
ChainRulesCore.frule((_, Δ), ::Type{TensorNetwork}, tensors) = TensorNetwork(tensors), TensorNetworkTangent(Δ)

# `Quantum` constructor
function ChainRulesCore.frule((_, ẋ, _), ::Type{Quantum}, x::TensorNetwork, sites)
return Quantum(x, sites), Tangent{Quantum}(; tn=ẋ, sites=NoTangent())
end

# `Ansatz` constructor
function ChainRulesCore.frule((_, ẋ), ::Type{Ansatz}, x::Quantum, lattice)
return Ansatz(x, lattice), Tangent{Ansatz}(; tn=ẋ, lattice=NoTangent())
end

# `AbstractAnsatz`-subtype constructors
ChainRulesCore.frule((_, ẋ), ::Type{Product}, x::Ansatz) = Product(x), Tangent{Product}(; tn=ẋ)
ChainRulesCore.frule((_, ẋ), ::Type{Dense}, x::Ansatz) = Dense(x, form), Tangent{Dense}(; tn=ẋ)
ChainRulesCore.frule((_, ẋ), ::Type{MPS}, x::Ansatz, form) = MPS(x, form), Tangent{MPS}(; tn=ẋ, lattice=NoTangent())
ChainRulesCore.frule((_, ẋ), ::Type{MPO}, x::Ansatz, form) = MPO(x, form), Tangent{MPO}(; tn=ẋ, lattice=NoTangent())
function ChainRulesCore.frule((_, ẋ), ::Type{PEPS}, x::Ansatz, form)
return PEPS(x, form), Tangent{PEPS}(; tn=ẋ, lattice=NoTangent())
end

# `Base.conj` methods
ChainRulesCore.frule((_, Δ), ::typeof(Base.conj), tn::Tensor) = conj(tn), conj(Δ)

ChainRulesCore.frule((_, Δ), ::typeof(Base.conj), tn::TensorNetwork) = conj(tn), conj(Δ)
ChainRulesCore.frule((_, Δ), ::typeof(Base.conj), tn::AbstractTensorNetwork) = conj(tn), conj(Δ)

# `Base.merge` methods
ChainRulesCore.frule((_, ȧ, ḃ), ::typeof(Base.merge), a::TensorNetwork, b::TensorNetwork) = merge(a, b), merge(ȧ, ḃ)
function ChainRulesCore.frule((_, ȧ, ḃ), ::typeof(Base.merge), a::AbstractTensorNetwork, b::AbstractTensorNetwork)
return merge(a, b), merge(ȧ, ḃ)
end

# `contract` methods
function ChainRulesCore.frule((_, ẋ), ::typeof(contract), x::Tensor; kwargs...)
Expand All @@ -22,15 +45,3 @@ function ChainRulesCore.frule((_, ȧ, ḃ), ::typeof(contract), a::Tensor, b::T
= contract(ȧ, b; kwargs...) + contract(a, ḃ; kwargs...)
return c, ċ
end

function ChainRulesCore.frule((_, ẋ, _), ::Type{Quantum}, x::TensorNetwork, sites)
y = Quantum(x, sites)
= Tangent{Quantum}(; tn=ẋ)
return y, ẏ
end

ChainRulesCore.frule((_, ẋ), ::Type{T}, x::Quantum) where {T<:Ansatz} = T(x), Tangent{T}(; super=ẋ)

function ChainRulesCore.frule((_, ẋ, _), ::Type{T}, x::Quantum, boundary) where {T<:Ansatz}
return T(x, boundary), Tangent{T}(; super=ẋ, boundary=NoTangent())
end
7 changes: 2 additions & 5 deletions ext/TenetChainRulesCoreExt/projectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,5 @@ end
ChainRulesCore.ProjectTo(x::Quantum) = ProjectTo{Quantum}(; tn=ProjectTo(TensorNetwork(x)), sites=x.sites)
(projector::ProjectTo{Quantum})(Δ) = Quantum(projector.tn(Δ), projector.sites)

ChainRulesCore.ProjectTo(x::T) where {T<:Ansatz} = ProjectTo{T}(; super=ProjectTo(Quantum(x)))
(projector::ProjectTo{T})(Δ::Union{T,Tangent{T}}) where {T<:Ansatz} = T(projector.super.super), Δ.boundary)

# NOTE edge case: `Product` has no `boundary`. should it?
(projector::ProjectTo{T})(Δ::Union{T,Tangent{T}}) where {T<:Product} = T(projector.super.super))
ChainRulesCore.ProjectTo(x::Ansatz) = ProjectTo{Ansatz}(; tn=ProjectTo(Quantum(x)), lattice=x.lattice)
(projector::ProjectTo{Ansatz})(Δ) = Ansatz(projector.tn(Δ), Δ.lattice)
59 changes: 32 additions & 27 deletions ext/TenetChainRulesCoreExt/rrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,38 @@ TensorNetwork_pullback(Δ::TensorNetworkTangent) = (NoTangent(), tensors(Δ))
TensorNetwork_pullback::AbstractThunk) = TensorNetwork_pullback(unthunk(Δ))
ChainRulesCore.rrule(::Type{TensorNetwork}, tensors) = TensorNetwork(tensors), TensorNetwork_pullback

# `Quantum` constructor
Quantum_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent())
Quantum_pullback(ȳ::AbstractArray) = (NoTangent(), ȳ, NoTangent())
Quantum_pullback(ȳ::AbstractThunk) = Quantum_pullback(unthunk(ȳ))
ChainRulesCore.rrule(::Type{Quantum}, x::TensorNetwork, sites) = Quantum(x, sites), Quantum_pullback

# `Ansatz` constructor
Ansatz_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent())
Ansatz_pullback(ȳ::AbstractThunk) = Ansatz_pullback(unthunk(ȳ))
ChainRulesCore.rrule(::Type{Ansatz}, x::Quantum, lattice) = Ansatz(x, lattice), Ansatz_pullback

# `AbstractAnsatz`-subtype constructors
Product_pullback(ȳ) = (NoTangent(), ȳ.tn)
Product_pullback(ȳ::AbstractThunk) = Product_pullback(unthunk(ȳ))
ChainRulesCore.rrule(::Type{Product}, x::Ansatz) = Product(x), Product_pullback

Dense_pullback(ȳ) = (NoTangent(), ȳ.tn)
Dense_pullback(ȳ::AbstractThunk) = Dense_pullback(unthunk(ȳ))
ChainRulesCore.rrule(::Type{Dense}, x::Ansatz) = Dense(x), Dense_pullback

MPS_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent())
MPS_pullback(ȳ::AbstractThunk) = MPS_pullback(unthunk(ȳ))
ChainRulesCore.rrule(::Type{MPS}, x::Ansatz, form) = MPS(x, form), MPS_pullback

MPO_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent())
MPO_pullback(ȳ::AbstractThunk) = MPO_pullback(unthunk(ȳ))
ChainRulesCore.rrule(::Type{MPO}, x::Ansatz, form) = MPO(x, form), MPO_pullback

PEPS_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent())
PEPS_pullback(ȳ::AbstractThunk) = PEPS_pullback(unthunk(ȳ))
ChainRulesCore.rrule(::Type{PEPS}, x::Ansatz, form) = PEPS(x, form), PEPS_pullback

# `Base.conj` methods
conj_pullback::Tensor) = (NoTangent(), conj(Δ))
conj_pullback::Tangent{Tensor}) = (NoTangent(), conj(Δ))
Expand Down Expand Up @@ -93,33 +125,6 @@ function ChainRulesCore.rrule(::typeof(contract), a::Tensor, b::Tensor; kwargs..
return c, contract_pullback
end

Quantum_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent())
Quantum_pullback(ȳ::AbstractArray) = (NoTangent(), ȳ, NoTangent())
Quantum_pullback(ȳ::AbstractThunk) = Quantum_pullback(unthunk(ȳ))
ChainRulesCore.rrule(::Type{Quantum}, x::TensorNetwork, sites) = Quantum(x, sites), Quantum_pullback

Ansatz_pullback(ȳ) = (NoTangent(), ȳ.super)
Ansatz_pullback(ȳ::AbstractThunk) = Ansatz_pullback(unthunk(ȳ))
function ChainRulesCore.rrule(::Type{T}, x::Quantum) where {T<:Ansatz}
y = T(x)
return y, Ansatz_pullback
end

Ansatz_boundary_pullback(ȳ) = (NoTangent(), ȳ.super, NoTangent())
Ansatz_boundary_pullback(ȳ::AbstractThunk) = Ansatz_boundary_pullback(unthunk(ȳ))
function ChainRulesCore.rrule(::Type{T}, x::Quantum, boundary) where {T<:Ansatz}
return T(x, boundary), Ansatz_boundary_pullback
end

Ansatz_from_arrays_pullback(ȳ) = (NoTangent(), NoTangent(), NoTangent(), parent.(tensors(ȳ.super.tn)))
Ansatz_from_arrays_pullback(ȳ::AbstractThunk) = Ansatz_from_arrays_pullback(unthunk(ȳ))
function ChainRulesCore.rrule(
::Type{T}, socket::Tenet.Socket, boundary::Tenet.Boundary, arrays; kwargs...
) where {T<:Ansatz}
y = T(socket, boundary, arrays; kwargs...)
return y, Ansatz_from_arrays_pullback
end

copy_pullback(ȳ) = (NoTangent(), ȳ)
copy_pullback(ȳ::AbstractThunk) = unthunk(ȳ)
function ChainRulesCore.rrule(::typeof(copy), x::Quantum)
Expand Down
31 changes: 23 additions & 8 deletions ext/TenetChainRulesTestUtilsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,42 @@ using Tenet
using ChainRulesCore
using ChainRulesTestUtils
using Random
using Graphs
using MetaGraphsNext

const TensorNetworkTangent = Base.get_extension(Tenet, :TenetChainRulesCoreExt).TensorNetworkTangent

function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::Vector{T}) where {T<:Tensor}
if isempty(x)
function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, tn::Vector{T}) where {T<:Tensor}
if isempty(tn)
return Vector{T}()
else
@invoke rand_tangent(rng::AbstractRNG, x::AbstractArray)
@invoke rand_tangent(rng::AbstractRNG, tn::AbstractArray)
end
end

function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::TensorNetwork)
return TensorNetworkTangent(Tensor[ProjectTo(tensor)(rand_tangent.(Ref(rng), tensor)) for tensor in tensors(x)])
end

function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::Quantum)
return Tangent{Quantum}(; tn=rand_tangent(rng, TensorNetwork(x)), sites=NoTangent())
function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, tn::Quantum)
return Tangent{Quantum}(; tn=rand_tangent(rng, TensorNetwork(tn)), sites=NoTangent())
end

# WARN type-piracy
# NOTE used in `Quantum` constructor
ChainRulesTestUtils.rand_tangent(::AbstractRNG, x::Dict{<:Site,Symbol}) = NoTangent()
# WARN type-piracy, used in `Quantum` constructor
ChainRulesTestUtils.rand_tangent(::AbstractRNG, tn::Dict{<:Site,Symbol}) = NoTangent()

function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, tn::Ansatz)
return Tangent{Ansatz}(; tn=rand_tangent(rng, Quantum(tn)), lattice=NoTangent())
end

# WARN not really type-piracy but almost, used in `Ansatz` constructor
ChainRulesTestUtils.rand_tangent(::AbstractRNG, tn::T) where {V,T<:MetaGraph{V,SimpleGraph{V},<:Site}} = NoTangent()

# WARN not really type-piracy but almost, used when testing `Ansatz`
function ChainRulesTestUtils.test_approx(
actual::G, expected::G, msg; kwargs...
) where {G<:MetaGraph{Int64,SimpleGraph{Int64},<:Site}}
return actual == expected
end

end
14 changes: 14 additions & 0 deletions ext/TenetFiniteDifferencesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,18 @@ function FiniteDifferences.to_vec(x::Dict{Vector{Symbol},Tensor})
return x_vec, Dict_from_vec
end

function FiniteDifferences.to_vec(x::Quantum)
x_vec, back = to_vec(TensorNetwork(x))
Quantum_from_vec(v) = Quantum(back(v), copy(x.sites))

return x_vec, Quantum_from_vec
end

function FiniteDifferences.to_vec(x::Ansatz)
x_vec, back = to_vec(Quantum(x))
Ansatz_from_vec(v) = Ansatz(back(v), copy(x.lattice))

return x_vec, Ansatz_from_vec
end

end
4 changes: 2 additions & 2 deletions ext/TenetGraphMakieExt.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
module TenetGraphMakieExt

using Tenet
using GraphMakie
using Graphs
using Makie
const Graphs = GraphMakie.Graphs
using Tenet
using Combinatorics: combinations

"""
Expand Down
38 changes: 28 additions & 10 deletions ext/TenetReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,23 @@ function Reactant.make_tracer(seen::IdDict, prev::Quantum, path::Tuple, mode::Re
return Quantum(tracetn, copy(prev.sites))
end

function Reactant.make_tracer(seen::IdDict, prev::Ansatz, path::Tuple, mode::Reactant.TraceMode; kwargs...)
tracetn = Reactant.make_tracer(seen, Quantum(prev), Reactant.append_path(path, :tn), mode; kwargs...)
return Ansatz(tracetn, copy(Tenet.lattice(prev)))
end

# TODO try rely on generic fallback for ansatzes
function Reactant.make_tracer(seen::IdDict, prev::Tenet.Product, path::Tuple, mode::Reactant.TraceMode; kwargs...)
tracequantum = Reactant.make_tracer(seen, Quantum(prev), Reactant.append_path(path, :super), mode; kwargs...)
tracequantum = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...)
return Tenet.Product(tracequantum)
end

# TODO try rely on generic fallback for ansatzes -> do it when refactoring to MPS/MPO
function Reactant.make_tracer(seen::IdDict, prev::Tenet.Chain, path::Tuple, mode::Reactant.TraceMode; kwargs...)
tracequantum = Reactant.make_tracer(seen, Quantum(prev), Reactant.append_path(path, :super), mode; kwargs...)
return Tenet.Chain(tracequantum, boundary(prev))
for A in (MPS, MPO)
@eval function Reactant.make_tracer(seen::IdDict, prev::$A, path::Tuple, mode::Reactant.TraceMode; kwargs...)
tracequantum = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...)
return $A(tracequantum, form(prev))
end
end

function Reactant.create_result(@nospecialize(tocopy::Tensor), @nospecialize(path), result_stores)
data = Reactant.create_result(parent(tocopy), Reactant.append_path(path, :data), result_stores)
return :($Tensor($data, $(inds(tocopy))))
Expand All @@ -59,10 +65,22 @@ function Reactant.create_result(tocopy::Quantum, @nospecialize(path), result_sto
return :($Quantum($tn, $(copy(tocopy.sites))))
end

# TODO try rely on generic fallback for ansatzes -> do it when refactoring to MPS/MPO
function Reactant.create_result(tocopy::Tenet.Chain, @nospecialize(path), result_stores)
qtn = Reactant.create_result(Quantum(tocopy), Reactant.append_path(path, :super), result_stores)
return :($(Tenet.Chain)($qtn, $(boundary(tocopy))))
function Reactant.create_result(tocopy::Ansatz, @nospecialize(path), result_stores)
tn = Reactant.create_result(Quantum(tocopy), Reactant.append_path(path, :tn), result_stores)
return :($Ansatz($tn, $(copy(Tenet.lattice(tocopy)))))
end

# TODO try rely on generic fallback for ansatzes
function Reactant.create_result(tocopy::Tenet.Product, @nospecialize(path), result_stores)
tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores)
return :($(Tenet.Product)($tn))
end

for A in (MPS, MPO)
@eval function Reactant.create_result(tocopy::$A, @nospecialize(path), result_stores)
tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores)
return :($A($tn, form(tocopy)))
end
end

function Reactant.push_val!(ad_inputs, x::TensorNetwork, path)
Expand Down
Loading

0 comments on commit 6b997fc

Please sign in to comment.