From 5a09e3ddc131d08a4a53a4f733049bb1739736db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= <15837247+mofeing@users.noreply.github.com> Date: Fri, 5 Jul 2024 16:19:31 +0200 Subject: [PATCH] Merge Qrochet.jl into Tenet.jl (#156) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Transfer code from `Qrochet.jl` Co-authored-by: Jofre Vallès Muns <61060572+jofrevalles@users.noreply.github.com> Co-authored-by: Todorbsc * Update feature list * Refactor directory organization --------- Co-authored-by: Jofre Vallès Muns <61060572+jofrevalles@users.noreply.github.com> Co-authored-by: Todorbsc --- Project.toml | 6 + README.md | 12 +- docs/make.jl | 4 + docs/src/ansatz/chain.md | 58 ++ docs/src/ansatz/product.md | 1 + docs/src/quantum.md | 39 ++ examples/Project.toml | 9 + examples/dagger.jl | 70 +++ examples/distributed.jl | 102 ++++ ext/TenetAdaptExt.jl | 5 +- ext/TenetChainRulesCoreExt/frules.jl | 12 + ext/TenetChainRulesCoreExt/non_diff.jl | 11 + ext/TenetChainRulesCoreExt/projectors.jl | 9 + ext/TenetChainRulesCoreExt/rrules.jl | 34 ++ ext/TenetChainRulesTestUtilsExt.jl | 8 + ext/TenetQuacExt.jl | 50 ++ ext/TenetYaoExt.jl | 51 ++ src/Ansatz/Ansatz.jl | 108 ++++ src/Ansatz/Chain.jl | 743 +++++++++++++++++++++++ src/Ansatz/Dense.jl | 38 ++ src/Ansatz/Grid.jl | 176 ++++++ src/Ansatz/Product.jl | 83 +++ src/Helpers.jl | 6 + src/Quantum.jl | 308 ++++++++++ src/Site.jl | 43 ++ src/Tenet.jl | 30 + test/Chain_test.jl | 390 ++++++++++++ test/Product_test.jl | 29 + test/Project.toml | 1 + test/Quantum_test.jl | 45 ++ test/Site_test.jl | 63 ++ test/integration/ChainRules_test.jl | 33 + test/integration/Quac_test.jl | 23 + test/runtests.jl | 5 + 34 files changed, 2594 insertions(+), 11 deletions(-) create mode 100644 docs/src/ansatz/chain.md create mode 100644 docs/src/ansatz/product.md create mode 100644 docs/src/quantum.md create mode 100644 examples/dagger.jl create mode 100644 examples/distributed.jl create mode 100644 ext/TenetQuacExt.jl create mode 100644 ext/TenetYaoExt.jl create mode 100644 src/Ansatz/Ansatz.jl create mode 100644 src/Ansatz/Chain.jl create mode 100644 src/Ansatz/Dense.jl create mode 100644 src/Ansatz/Grid.jl create mode 100644 src/Ansatz/Product.jl create mode 100644 src/Quantum.jl create mode 100644 src/Site.jl create mode 100644 test/Chain_test.jl create mode 100644 test/Product_test.jl create mode 100644 test/Quantum_test.jl create mode 100644 test/Site_test.jl create mode 100644 test/integration/Quac_test.jl diff --git a/Project.toml b/Project.toml index abdfaefc..81a0bce3 100644 --- a/Project.toml +++ b/Project.toml @@ -27,6 +27,8 @@ GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" +Quac = "b9105292-1415-45cf-bff1-d6ccf71e6143" +Yao = "5872b779-8223-5990-8dd0-5abbb0748c8c" [extensions] TenetAdaptExt = "Adapt" @@ -37,6 +39,8 @@ TenetDaggerExt = "Dagger" TenetFiniteDifferencesExt = "FiniteDifferences" TenetGraphMakieExt = ["GraphMakie", "Makie"] TenetReactantExt = "Reactant" +TenetQuacExt = "Quac" +TenetYaoExt = "Yao" [compat] AbstractTrees = "0.4" @@ -55,9 +59,11 @@ LinearAlgebra = "1.9" Makie = "0.18,0.19,0.20, 0.21" Muscle = "0.2" OMEinsum = "0.7, 0.8" +Quac = "0.3" Random = "1.9" Reactant = "0.1" ScopedValues = "1" SparseArrays = "1.9" UUIDs = "1.9" +Yao = "0.8, 0.9" julia = "1.9" diff --git a/README.md b/README.md index d5eace1b..0c28492c 100644 --- a/README.md +++ b/README.md @@ -7,9 +7,6 @@ [![Documentation: stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://bsc-quantic.github.io/Tenet.jl/) [![Documentation: dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://bsc-quantic.github.io/Tenet.jl/dev/) -> [!IMPORTANT] -> The code for quantum tensor networks has been moved to the new [`Qrochet`](https://github.com/bsc-quantic/Qrochet.jl) library. - A Julia library for **Ten**sor **Net**works. `Tenet` can be executed both at local environments and on large supercomputers. Its goals are, - **Expressiveness** _Simple to use._ 👶 @@ -22,14 +19,9 @@ A Julia library for **Ten**sor **Net**works. `Tenet` can be executed both at loc - Tensor Network slicing/cuttings - Automatic Differentiation of TN contraction - Distributed contraction -- Local Tensor Network transformations - - Hyperindex converter - - Rank simplification - - Diagonal reduction - - Anti-diagonal gauging - - Column reduction - - Split simplification +- Local Tensor Network transformations/simplifications - 2D & 3D visualization of large networks, powered by [`Makie`](https://github.com/MakieOrg/Makie.jl) +- Quantum Tensor Networks ## Preview diff --git a/docs/make.jl b/docs/make.jl index d9eef607..8bbe5ad7 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -24,6 +24,10 @@ makedocs(; "Tensor Networks" => "tensor-network.md", "Contraction" => "contraction.md", "Transformations" => "transformations.md", + "Quantum" => [ + "Introduction" => "quantum.md", + "Ansatzes" => ["`Product` ansatz" => "ansatz/product.md", "`Chain` ansatz" => "ansatz/chain.md"], + ], "Visualization" => "visualization.md", "Alternatives" => "alternatives.md", "References" => "references.md", diff --git a/docs/src/ansatz/chain.md b/docs/src/ansatz/chain.md new file mode 100644 index 00000000..37d32cfd --- /dev/null +++ b/docs/src/ansatz/chain.md @@ -0,0 +1,58 @@ +# Matrix Product States (MPS) + +Matrix Product States (MPS) are a Quantum Tensor Network ansatz whose tensors are laid out in a 1D chain. +Due to this, these networks are also known as _Tensor Trains_ in other mathematical fields. +Depending on the boundary conditions, the chains can be open or closed (i.e. periodic boundary conditions). + +```@setup viz +using Makie +Makie.inline!(true) +set_theme!(resolution=(800,200)) + +using CairoMakie + +using Tenet +using NetworkLayout +``` + +```@example viz +fig = Figure() # hide + +tn_open = rand(MatrixProduct{State,Open}, n=10, χ=4) # hide +tn_periodic = rand(MatrixProduct{State,Periodic}, n=10, χ=4) # hide + +plot!(fig[1,1], tn_open, layout=Spring(iterations=1000, C=0.5, seed=100)) # hide +plot!(fig[1,2], tn_periodic, layout=Spring(iterations=1000, C=0.5, seed=100)) # hide + +Label(fig[1,1, Bottom()], "Open") # hide +Label(fig[1,2, Bottom()], "Periodic") # hide + +fig # hide +``` + +## Matrix Product Operators (MPO) + +Matrix Product Operators (MPO) are the operator version of [Matrix Product State (MPS)](#matrix-product-states-mps). +The major difference between them is that MPOs have 2 indices per site (1 input and 1 output) while MPSs only have 1 index per site (i.e. an output). + +```@example viz +fig = Figure() # hide + +tn_open = rand(MatrixProduct{Operator,Open}, n=10, χ=4) # hide +tn_periodic = rand(MatrixProduct{Operator,Periodic}, n=10, χ=4) # hide + +plot!(fig[1,1], tn_open, layout=Spring(iterations=1000, C=0.5, seed=100)) # hide +plot!(fig[1,2], tn_periodic, layout=Spring(iterations=1000, C=0.5, seed=100)) # hide + +Label(fig[1,1, Bottom()], "Open") # hide +Label(fig[1,2, Bottom()], "Periodic") # hide + +fig # hide +``` + +In `Tenet`, the generic `MatrixProduct` ansatz implements this topology. Type variables are used to address their functionality (`State` or `Operator`) and their boundary conditions (`Open` or `Periodic`). + +```@docs +MatrixProduct +MatrixProduct(::Any) +``` diff --git a/docs/src/ansatz/product.md b/docs/src/ansatz/product.md new file mode 100644 index 00000000..28ba05e8 --- /dev/null +++ b/docs/src/ansatz/product.md @@ -0,0 +1 @@ +# `Product` ansatz diff --git a/docs/src/quantum.md b/docs/src/quantum.md new file mode 100644 index 00000000..0c8dd030 --- /dev/null +++ b/docs/src/quantum.md @@ -0,0 +1,39 @@ +# `Quantum` Tensor Networks + +```@docs +Quantum +Tenet.TensorNetwork(::Quantum) +Base.adjoint(::Quantum) +sites +nsites +``` + +## Queries + +```@docs +Tenet.inds(::Quantum; kwargs...) +Tenet.tensors(::Quantum; kwargs...) +``` + +## Connecting `Quantum` Tensor Networks + +```@docs +inputs +outputs +lanes +ninputs +noutputs +nlanes +``` + +```@docs +Socket +socket(::Quantum) +Scalar +State +Operator +``` + +```@docs +Base.merge(::Quantum, ::Quantum...) +``` diff --git a/examples/Project.toml b/examples/Project.toml index 415c0ec1..af994de6 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -4,8 +4,17 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de" +ClusterManagers = "34f1f09b-3a8b-5176-ab39-66d58a4d544e" +Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54" +Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" +KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" Tenet = "85d41934-b9cd-44e1-8730-56d86f15f3ec" +TimespanLogging = "a526e669-04d3-4846-9525-c66122c55f63" +Yao = "5872b779-8223-5990-8dd0-5abbb0748c8c" diff --git a/examples/dagger.jl b/examples/dagger.jl new file mode 100644 index 00000000..32920ac6 --- /dev/null +++ b/examples/dagger.jl @@ -0,0 +1,70 @@ +using Tenet +using Yao: Yao +using EinExprs +using AbstractTrees +using Distributed +using Dagger +using TimespanLogging +using KaHyPar + +m = 10 +circuit = Yao.EasyBuild.rand_google53(m); +H = Quantum(circuit) +ψ = Product(fill([1, 0], Yao.nqubits(circuit))) +qtn = merge(Quantum(ψ), H, Quantum(ψ)') +tn = Tenet.TensorNetwork(qtn) + +contract_smaller_dims = 20 +target_size = 24 + +Tenet.transform!(tn, Tenet.ContractSimplification()) +path = einexpr( + tn; + optimizer=HyPar(; + parts=2, + imbalance=0.41, + edge_scaler=(ind_size) -> 10 * Int(round(log2(ind_size))), + vertex_scaler=(prod_size) -> 100 * Int(round(exp2(prod_size))), + ), +); + +max_dims_path = @show maximum(ndims, Branches(path)) +flops_path = @show mapreduce(flops, +, Branches(path)) +@show log10(flops_path) + +grouppath = deepcopy(path); +function recursiveforeach!(f, expr) + f(expr) + return foreach(arg -> recursiveforeach!(f, arg), args(expr)) +end +sizedict = merge(Iterators.map(i -> i.size, Leaves(path))...); +recursiveforeach!(grouppath) do expr + merge!(expr.size, sizedict) + if all(<(contract_smaller_dims) ∘ ndims, expr.args) + empty!(expr.args) + end +end + +max_dims_grouppath = maximum(ndims, Branches(grouppath)) +flops_grouppath = mapreduce(flops, +, Branches(grouppath)) +targetinds = findslices(SizeScorer(), grouppath; size=2^(target_size)); + +subexprs = map(Leaves(grouppath)) do expr + only(EinExprs.select(path, tuple(head(expr)...))) +end + +addprocs(3) +@everywhere using Dagger, Tenet + +disttn = Tenet.TensorNetwork( + map(subexprs) do subexpr + Tensor( + distribute( # data + parent(Tenet.contract(tn; path=subexpr)), + Blocks([i ∈ targetinds ? 1 : 2 for i in head(subexpr)]...), + ), + head(subexpr), # inds + ) + end, +) +@show Tenet.contract(disttn; path=grouppath) diff --git a/examples/distributed.jl b/examples/distributed.jl new file mode 100644 index 00000000..5f184170 --- /dev/null +++ b/examples/distributed.jl @@ -0,0 +1,102 @@ +using Yao: Yao +using Tenet +using EinExprs +using KaHyPar +using Random +using Distributed +using ClusterManagers +using AbstractTrees + +n = 64 +depth = 6 + +circuit = Yao.chain(n) + +for _ in 1:depth + perm = randperm(n) + + for (i, j) in Iterators.partition(perm, 2) + push!(circuit, Yao.put((i, j) => Yao.EasyBuild.FSimGate(2π * rand(), 2π * rand()))) + # push!(circuit, Yao.control(n, i, j => Yao.phase(2π * rand()))) + end +end + +H = Quantum(circuit) +ψ = zeros(Product, n) + +tn = TensorNetwork(merge(Quantum(ψ), H, Quantum(ψ)')) +transform!(tn, Tenet.ContractSimplification()) + +path = einexpr( + tn; + optimizer=HyPar(; + parts=2, + imbalance=0.41, + edge_scaler=(ind_size) -> 10 * Int(round(log2(ind_size))), + vertex_scaler=(prod_size) -> 100 * Int(round(exp2(prod_size))), + ), +) + +@show maximum(ndims, Branches(path)) +@show maximum(length, Branches(path)) * sizeof(eltype(tn)) / 1024^3 + +@show log10(mapreduce(flops, +, Branches(path))) + +cutinds = findslices(SizeScorer(), path; size=2^24) +cuttings = [[i => dim for dim in 1:size(tn, i)] for i in cutinds] + +# mock sliced path - valid for all slices +proj_inds = first.(cuttings) +slice_path = view(path.path, proj_inds...) + +expr = Tenet.codegen(Val(:outplace), slice_path) + +manager = SlurmManager(2 * 112 - 1) +addprocs(manager; cpus_per_task=1, exeflags="--project=$(Base.active_project())") +# @everywhere using LinearAlgebra +# @everywhere LinearAlgebra.BLAS.set_num_threads(2) + +@everywhere using Tenet, EinExprs, IterTools, LinearAlgebra, Reactant, AbstractTrees +@everywhere tn = $tn +@everywhere slice_path = $slice_path +@everywhere cuttings = $cuttings +@everywhere expr = $expr + +partial_results = map(enumerate(workers())) do (i, worker) + Distributed.@spawnat worker begin + # interleaved chunking without instantiation + it = takenth(Iterators.drop(Iterators.product(cuttings...), i - 1), nworkers()) + + f = @eval $expr + mock_slice = view(tn, first(it)...) + tensors′ = [ + Tensor(Reactant.ConcreteRArray(copy(parent(mock_slice[head(leaf)...]))), inds(mock_slice[head(leaf)...])) for leaf in Leaves(slice_path) + ] + g = Reactant.compile(f, Tuple(tensors′)) + + # local reduction of chunk + accumulator = zero(eltype(tn)) + + for proj_inds in it + slice = view(tn, proj_inds...) + tensors′ = [ + Tensor( + Reactant.ConcreteRArray(copy(parent(mock_slice[head(leaf)...]))), + inds(mock_slice[head(leaf)...]), + ) for leaf in Leaves(slice_path) + ] + res = only(g(tensors′...)) + + # avoid OOM due to garbage accumulation + GC.gc() + + accumulator += res + end + + return accumulator + end +end + +@show result = sum(Distributed.fetch.(partial_results)) + +rmprocs(workers()) diff --git a/ext/TenetAdaptExt.jl b/ext/TenetAdaptExt.jl index 5d6710a4..6facff44 100644 --- a/ext/TenetAdaptExt.jl +++ b/ext/TenetAdaptExt.jl @@ -4,7 +4,10 @@ using Tenet using Adapt Adapt.adapt_structure(to, x::Tensor) = Tensor(adapt(to, parent(x)), inds(x)) - Adapt.adapt_structure(to, x::TensorNetwork) = TensorNetwork(adapt.(Ref(to), tensors(x))) +Adapt.adapt_structure(to, x::Quantum) = Quantum(adapt(to, TensorNetwork(x)), x.sites) +Adapt.adapt_structure(to, x::Product) = Product(adapt(to, Quantum(x))) +Adapt.adapt_structure(to, x::Chain) = Chain(adapt(to, Quantum(x)), boundary(x)) + end diff --git a/ext/TenetChainRulesCoreExt/frules.jl b/ext/TenetChainRulesCoreExt/frules.jl index c3a2dc9b..941a723b 100644 --- a/ext/TenetChainRulesCoreExt/frules.jl +++ b/ext/TenetChainRulesCoreExt/frules.jl @@ -22,3 +22,15 @@ function ChainRulesCore.frule((_, ȧ, ḃ), ::typeof(contract), a::Tensor, b::T ċ = contract(ȧ, b; kwargs...) + contract(a, ḃ; kwargs...) return c, ċ end + +function ChainRulesCore.frule((_, ẋ, _), ::Type{Quantum}, x::TensorNetwork, sites) + y = Quantum(x, sites) + ẏ = Tangent{Quantum}(; tn=ẋ) + return y, ẏ +end + +ChainRulesCore.frule((_, ẋ), ::Type{T}, x::Quantum) where {T<:Ansatz} = T(x), Tangent{T}(; super=ẋ) + +function ChainRulesCore.frule((_, ẋ, _), ::Type{T}, x::Quantum, boundary) where {T<:Ansatz} + return T(x, boundary), Tangent{T}(; super=ẋ, boundary=NoTangent()) +end diff --git a/ext/TenetChainRulesCoreExt/non_diff.jl b/ext/TenetChainRulesCoreExt/non_diff.jl index f2d72861..2d1129c6 100644 --- a/ext/TenetChainRulesCoreExt/non_diff.jl +++ b/ext/TenetChainRulesCoreExt/non_diff.jl @@ -9,3 +9,14 @@ # TODO maybe we need to convert this into a frule/rrule? such that the tangents change their indices too @non_differentiable Base.replace!(::TensorNetwork, ::Pair{Symbol,Symbol}...) + +@non_differentiable Tenet.currindex() +@non_differentiable Tenet.nextindex() + +# WARN type-piracy +@non_differentiable Base.setdiff(::Vector{Symbol}, ::Base.ValueIterator) + +@non_differentiable Tenet.inputs(::Quantum) +@non_differentiable Tenet.ninputs(::Quantum) +@non_differentiable Tenet.outputs(::Quantum) +@non_differentiable Tenet.noutputs(::Quantum) diff --git a/ext/TenetChainRulesCoreExt/projectors.jl b/ext/TenetChainRulesCoreExt/projectors.jl index ad6cbe04..acd488fa 100644 --- a/ext/TenetChainRulesCoreExt/projectors.jl +++ b/ext/TenetChainRulesCoreExt/projectors.jl @@ -32,3 +32,12 @@ function (projector::ProjectTo{TensorNetwork})(dx) ) end (projector::ProjectTo{TensorNetwork})(dx::Vector{<:Tensor}) = projector(TensorNetwork(dx)) + +ChainRulesCore.ProjectTo(x::Quantum) = ProjectTo{Quantum}(; tn=ProjectTo(TensorNetwork(x)), sites=x.sites) +(projector::ProjectTo{Quantum})(Δ) = Quantum(projector.tn(Δ), projector.sites) + +ChainRulesCore.ProjectTo(x::T) where {T<:Ansatz} = ProjectTo{T}(; super=ProjectTo(Quantum(x))) +(projector::ProjectTo{T})(Δ::Union{T,Tangent{T}}) where {T<:Ansatz} = T(projector.super(Δ.super), Δ.boundary) + +# NOTE edge case: `Product` has no `boundary`. should it? +(projector::ProjectTo{T})(Δ::Union{T,Tangent{T}}) where {T<:Product} = T(projector.super(Δ.super)) diff --git a/ext/TenetChainRulesCoreExt/rrules.jl b/ext/TenetChainRulesCoreExt/rrules.jl index 92d35477..9992c877 100644 --- a/ext/TenetChainRulesCoreExt/rrules.jl +++ b/ext/TenetChainRulesCoreExt/rrules.jl @@ -92,3 +92,37 @@ function ChainRulesCore.rrule(::typeof(contract), a::Tensor, b::Tensor; kwargs.. return c, contract_pullback end + +Quantum_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent()) +Quantum_pullback(ȳ::AbstractArray) = (NoTangent(), ȳ, NoTangent()) +Quantum_pullback(ȳ::AbstractThunk) = Quantum_pullback(unthunk(ȳ)) +ChainRulesCore.rrule(::Type{Quantum}, x::TensorNetwork, sites) = Quantum(x, sites), Quantum_pullback + +Ansatz_pullback(ȳ) = (NoTangent(), ȳ.super) +Ansatz_pullback(ȳ::AbstractThunk) = Ansatz_pullback(unthunk(ȳ)) +function ChainRulesCore.rrule(::Type{T}, x::Quantum) where {T<:Ansatz} + y = T(x) + return y, Ansatz_pullback +end + +Ansatz_boundary_pullback(ȳ) = (NoTangent(), ȳ.super, NoTangent()) +Ansatz_boundary_pullback(ȳ::AbstractThunk) = Ansatz_boundary_pullback(unthunk(ȳ)) +function ChainRulesCore.rrule(::Type{T}, x::Quantum, boundary) where {T<:Ansatz} + return T(x, boundary), Ansatz_boundary_pullback +end + +Ansatz_from_arrays_pullback(ȳ) = (NoTangent(), NoTangent(), NoTangent(), parent.(tensors(ȳ.super.tn))) +Ansatz_from_arrays_pullback(ȳ::AbstractThunk) = Ansatz_from_arrays_pullback(unthunk(ȳ)) +function ChainRulesCore.rrule( + ::Type{T}, socket::Tenet.Socket, boundary::Tenet.Boundary, arrays; kwargs... +) where {T<:Ansatz} + y = T(socket, boundary, arrays; kwargs...) + return y, Ansatz_from_arrays_pullback +end + +copy_pullback(ȳ) = (NoTangent(), ȳ) +copy_pullback(ȳ::AbstractThunk) = unthunk(ȳ) +function ChainRulesCore.rrule(::typeof(copy), x::Quantum) + y = copy(x) + return y, copy_pullback +end diff --git a/ext/TenetChainRulesTestUtilsExt.jl b/ext/TenetChainRulesTestUtilsExt.jl index 4d64ddd8..8fdef1b0 100644 --- a/ext/TenetChainRulesTestUtilsExt.jl +++ b/ext/TenetChainRulesTestUtilsExt.jl @@ -19,4 +19,12 @@ function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::TensorNetwork) return TensorNetworkTangent(Tensor[ProjectTo(tensor)(rand_tangent.(Ref(rng), tensor)) for tensor in tensors(x)]) end +function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::Quantum) + return Tangent{Quantum}(; tn=rand_tangent(rng, TensorNetwork(x)), sites=NoTangent()) +end + +# WARN type-piracy +# NOTE used in `Quantum` constructor +ChainRulesTestUtils.rand_tangent(::AbstractRNG, x::Dict{<:Site,Symbol}) = NoTangent() + end diff --git a/ext/TenetQuacExt.jl b/ext/TenetQuacExt.jl new file mode 100644 index 00000000..17b973f6 --- /dev/null +++ b/ext/TenetQuacExt.jl @@ -0,0 +1,50 @@ +module TenetQuacExt + +using Tenet +using Tenet +using Quac: Gate, Circuit, lanes, arraytype, Swap + +function Tenet.Dense(gate::Gate) + return Tenet.Dense( + Operator(), arraytype(gate)(gate); sites=Site[Site.(lanes(gate))..., Site.(lanes(gate); dual=true)...] + ) +end + +Tenet.evolve!(qtn::Ansatz, gate::Gate; kwargs...) = evolve!(qtn, Tenet.Dense(gate); kwargs...) + +function Tenet.Quantum(circuit::Circuit) + n = lanes(circuit) + wire = [[Tenet.nextindex()] for _ in 1:n] + tensors = Tensor[] + + for gate in circuit + G = arraytype(gate) + array = G(gate) + + if gate isa Swap + (a, b) = lanes(gate) + wire[a], wire[b] = wire[b], wire[a] + continue + end + + inds = (x -> collect(Iterators.flatten(zip(x...))))( + map(lanes(gate)) do l + from, to = last(wire[l]), Tenet.nextindex() + push!(wire[l], to) + (from, to) + end, + ) + + tensor = Tensor(array, tuple(inds...)) + push!(tensors, tensor) + end + + sites = merge( + Dict([Site(site; dual=true) => first(index) for (site, index) in enumerate(wire)]), + Dict([Site(site; dual=false) => last(index) for (site, index) in enumerate(wire)]), + ) + + return Quantum(Tenet.TensorNetwork(tensors), sites) +end + +end diff --git a/ext/TenetYaoExt.jl b/ext/TenetYaoExt.jl new file mode 100644 index 00000000..4f4bce06 --- /dev/null +++ b/ext/TenetYaoExt.jl @@ -0,0 +1,51 @@ +module TenetYaoExt + +using Tenet +using Yao + +function flatten_circuit(x) + if any(i -> i isa ChainBlock, subblocks(x)) + flatten_circuit(Yao.Optimise.eliminate_nested(x)) + else + x + end +end + +function Tenet.Quantum(circuit::AbstractBlock) + @assert nlevel(circuit) == 2 "Only support 2-level qubits" + + n = nqubits(circuit) + wire = [[Tenet.nextindex()] for _ in 1:n] + tensors = Tensor[] + + for gate in flatten_circuit(circuit) + if gate isa Swap + (a, b) = occupied_locs(gate) + wire[a], wire[b] = wire[b], wire[a] + continue + end + + operator = content(gate) + array = reshape(mat(operator), fill(2, 2 * nqubits(operator))...) + + inds = (x -> collect(Iterators.flatten(zip(x...))))( + map(occupied_locs(gate)) do l + from, to = last(wire[l]), Tenet.nextindex() + push!(wire[l], to) + (from, to) + end, + ) + + tensor = Tensor(array, tuple(inds...)) + push!(tensors, tensor) + end + + sites = merge( + Dict([Site(site; dual=true) => first(index) for (site, index) in enumerate(wire)]), + Dict([Site(site; dual=false) => last(index) for (site, index) in enumerate(wire)]), + ) + + return Quantum(Tenet.TensorNetwork(tensors), sites) +end + +end diff --git a/src/Ansatz/Ansatz.jl b/src/Ansatz/Ansatz.jl new file mode 100644 index 00000000..0c47cd71 --- /dev/null +++ b/src/Ansatz/Ansatz.jl @@ -0,0 +1,108 @@ +using LinearAlgebra + +""" + Ansatz + +[`Quantum`](@ref) Tensor Network with a predefined structure. + +# Notes + + - Any subtype must define `super::Quantum` field or specialize the `Quantum` method. +""" +abstract type Ansatz end + +Quantum(@nospecialize tn::Ansatz) = tn.super + +Base.:(==)(a::Ansatz, b::Ansatz) = Quantum(a) == Quantum(b) +Base.isapprox(a::Ansatz, b::Ansatz; kwargs...) = isapprox(Quantum(a), Quantum(b); kwargs...) + +# TODO forward `Quantum` methods +for f in [ + :(Tenet.TensorNetwork), + :ninputs, + :noutputs, + :inputs, + :outputs, + :nsites, + :nlanes, + :socket, + :(Tenet.arrays), + :(Base.collect), +] + @eval $f(@nospecialize tn::Ansatz) = $f(Quantum(tn)) +end + +abstract type Boundary end +struct Open <: Boundary end +struct Periodic <: Boundary end + +function boundary end + +alias(::A) where {A} = string(A) +function Base.summary(io::IO, tn::A) where {A<:Ansatz} + return print(io, "$(alias(tn)) (inputs=$(ninputs(tn)), outputs=$(noutputs(tn)))") +end +Base.show(io::IO, tn::A) where {A<:Ansatz} = summary(io, tn) + +sites(tn::Ansatz; kwargs...) = sites(Quantum(tn); kwargs...) + +function Tenet.inds(tn::Ansatz; kwargs...) + if keys(kwargs) === (:bond,) + inds(tn, Val(:bond), kwargs[:bond]...) + else + inds(Quantum(tn); kwargs...) + end +end + +function Tenet.inds(tn::Ansatz, ::Val{:bond}, site1::Site, site2::Site) + @assert site1 ∈ sites(tn) "Site $site1 not found" + @assert site2 ∈ sites(tn) "Site $site2 not found" + @assert site1 != site2 "Sites must be different" + + tensor1 = tensors(tn; at=site1) + tensor2 = tensors(tn; at=site2) + + isdisjoint(inds(tensor1), inds(tensor2)) && return nothing + return only(inds(tensor1) ∩ inds(tensor2)) +end + +function Tenet.tensors(tn::Ansatz; kwargs...) + if keys(kwargs) === (:between,) + tensors(tn, Val(:between), kwargs[:between]...) + else + tensors(Quantum(tn); kwargs...) + end +end + +function Tenet.tensors(tn::Ansatz, ::Val{:between}, site1::Site, site2::Site) + @assert site1 ∈ sites(tn) "Site $site1 not found" + @assert site2 ∈ sites(tn) "Site $site2 not found" + @assert site1 != site2 "Sites must be different" + + tensor1 = tensors(tn; at=site1) + tensor2 = tensors(tn; at=site2) + + isdisjoint(inds(tensor1), inds(tensor2)) && return nothing + + return TensorNetwork(tn)[only(inds(tensor1) ∩ inds(tensor2))] +end + +struct MissingSchmidtCoefficientsException <: Base.Exception + bond::NTuple{2,Site} +end + +MissingSchmidtCoefficientsException(bond::Vector{<:Site}) = MissingSchmidtCoefficientsException(tuple(bond...)) + +function Base.showerror(io::IO, e::MissingSchmidtCoefficientsException) + return print(io, "Can't access the spectrum on bond $(e.bond)") +end + +function LinearAlgebra.norm(ψ::Ansatz, p::Real=2; kwargs...) + p == 2 || throw(ArgumentError("only L2-norm is implemented yet")) + + return LinearAlgebra.norm2(ψ; kwargs...) +end + +function LinearAlgebra.norm2(ψ::Ansatz; kwargs...) + return abs(sqrt(only(contract(merge(TensorNetwork(ψ), TensorNetwork(ψ')); kwargs...)))) +end diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl new file mode 100644 index 00000000..0b5b73ba --- /dev/null +++ b/src/Ansatz/Chain.jl @@ -0,0 +1,743 @@ +using LinearAlgebra +using Random +using Muscle: gramschmidt! + +struct Chain <: Ansatz + super::Quantum + boundary::Boundary +end + +Base.copy(tn::Chain) = Chain(copy(Quantum(tn)), boundary(tn)) + +Base.similar(tn::Chain) = Chain(similar(Quantum(tn)), boundary(tn)) +Base.zero(tn::Chain) = Chain(zero(Quantum(tn)), boundary(tn)) + +boundary(tn::Chain) = tn.boundary + +MPS(arrays) = Chain(State(), Open(), arrays) +pMPS(arrays) = Chain(State(), Periodic(), arrays) +MPO(arrays) = Chain(Operator(), Open(), arrays) +pMPO(arrays) = Chain(Operator(), Periodic(), arrays) + +alias(tn::Chain) = alias(socket(tn), boundary(tn), tn) +alias(::State, ::Open, ::Chain) = "MPS" +alias(::State, ::Periodic, ::Chain) = "pMPS" +alias(::Operator, ::Open, ::Chain) = "MPO" +alias(::Operator, ::Periodic, ::Chain) = "pMPO" + +function Chain(tn::TensorNetwork, sites, args...; kwargs...) + return Chain(Quantum(tn, sites), args...; kwargs...) +end + +defaultorder(::Type{Chain}, ::State) = (:o, :l, :r) +defaultorder(::Type{Chain}, ::Operator) = (:o, :i, :l, :r) + +function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray}; order=defaultorder(Chain, State())) + @assert all(==(3) ∘ ndims, arrays) "All arrays must have 3 dimensions" + issetequal(order, defaultorder(Chain, State())) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(State())))")) + + n = length(arrays) + symbols = [nextindex() for _ in 1:(2n)] + + _tensors = map(enumerate(arrays)) do (i, array) + inds = map(order) do dir + if dir == :o + symbols[i] + elseif dir == :r + symbols[n + mod1(i, n)] + elseif dir == :l + symbols[n + mod1(i - 1, n)] + else + throw(ArgumentError("Invalid direction: $dir")) + end + end + Tensor(array, inds) + end + + sitemap = Dict(Site(i) => symbols[i] for i in 1:n) + + return Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) +end + +function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray}; order=defaultorder(Chain, State())) + @assert ndims(arrays[1]) == 2 "First array must have 2 dimensions" + @assert all(==(3) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have 3 dimensions" + @assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions" + issetequal(order, defaultorder(Chain, State())) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, State())))")) + + n = length(arrays) + symbols = [nextindex() for _ in 1:(2n)] + + _tensors = map(enumerate(arrays)) do (i, array) + _order = if i == 1 + filter(x -> x != :l, order) + elseif i == n + filter(x -> x != :r, order) + else + order + end + + inds = map(_order) do dir + if dir == :o + symbols[i] + elseif dir == :r + symbols[n + mod1(i, n)] + elseif dir == :l + symbols[n + mod1(i - 1, n)] + else + throw(ArgumentError("Invalid direction: $dir")) + end + end + Tensor(array, inds) + end + + sitemap = Dict(Site(i) => symbols[i] for i in 1:n) + + return Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) +end + +function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray}; order=defaultorder(Chain, Operator())) + @assert all(==(4) ∘ ndims, arrays) "All arrays must have 4 dimensions" + issetequal(order, defaultorder(Chain, Operator())) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, Operator())))")) + + n = length(arrays) + symbols = [nextindex() for _ in 1:(3n)] + + _tensors = map(enumerate(arrays)) do (i, array) + inds = map(order) do dir + if dir == :o + symbols[i] + elseif dir == :i + symbols[i + n] + elseif dir == :l + symbols[2n + mod1(i - 1, n)] + elseif dir == :r + symbols[2n + mod1(i, n)] + else + throw(ArgumentError("Invalid direction: $dir")) + end + end + Tensor(array, inds) + end + + sitemap = Dict(Site(i) => symbols[i] for i in 1:n) + merge!(sitemap, Dict(Site(i; dual=true) => symbols[i + n] for i in 1:n)) + + return Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) +end + +function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray}; order=defaultorder(Chain, Operator())) + @assert ndims(arrays[1]) == 3 "First array must have 3 dimensions" + @assert all(==(4) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have 4 dimensions" + @assert ndims(arrays[end]) == 3 "Last array must have 3 dimensions" + issetequal(order, defaultorder(Chain, Operator())) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, Operator())))")) + + n = length(arrays) + symbols = [nextindex() for _ in 1:(3n - 1)] + + _tensors = map(enumerate(arrays)) do (i, array) + _order = if i == 1 + filter(x -> x != :l, order) + elseif i == n + filter(x -> x != :r, order) + else + order + end + + inds = map(_order) do dir + if dir == :o + symbols[i] + elseif dir == :i + symbols[i + n] + elseif dir == :l + symbols[2n + mod1(i - 1, n)] + elseif dir == :r + symbols[2n + mod1(i, n)] + else + throw(ArgumentError("Invalid direction: $dir")) + end + end + Tensor(array, inds) + end + + sitemap = Dict(Site(i) => symbols[i] for i in 1:n) + merge!(sitemap, Dict(Site(i; dual=true) => symbols[i + n] for i in 1:n)) + + return Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) +end + +function Base.convert(::Type{Chain}, qtn::Product) + arrs::Vector{Array} = arrays(TensorNetwork(qtn)) + arrs[1] = reshape(arrs[1], size(arrs[1])..., 1) + arrs[end] = reshape(arrs[end], size(arrs[end])..., 1) + map!(@view(arrs[2:(end - 1)]), @view(arrs[2:(end - 1)])) do arr + reshape(arr, size(arr)..., 1, 1) + end + + return Chain(socket(qtn), Open(), arrs) +end + +leftsite(tn::Chain, site::Site) = leftsite(boundary(tn), tn, site) +function leftsite(::Open, tn::Chain, site::Site) + return id(site) ∈ range(2, nlanes(tn)) ? Site(id(site) - 1; dual=isdual(site)) : nothing +end +leftsite(::Periodic, tn::Chain, site::Site) = Site(mod1(id(site) - 1, nlanes(tn)); dual=isdual(site)) + +rightsite(tn::Chain, site::Site) = rightsite(boundary(tn), tn, site) +function rightsite(::Open, tn::Chain, site::Site) + return id(site) ∈ range(1, nlanes(tn) - 1) ? Site(id(site) + 1; dual=isdual(site)) : nothing +end +rightsite(::Periodic, tn::Chain, site::Site) = Site(mod1(id(site) + 1, nlanes(tn)); dual=isdual(site)) + +leftindex(tn::Chain, site::Site) = leftindex(boundary(tn), tn, site) +leftindex(::Open, tn::Chain, site::Site) = site == site"1" ? nothing : leftindex(Periodic(), tn, site) +leftindex(::Periodic, tn::Chain, site::Site) = inds(tn; bond=(site, leftsite(tn, site))) + +rightindex(tn::Chain, site::Site) = rightindex(boundary(tn), tn, site) +function rightindex(::Open, tn::Chain, site::Site) + return site == Site(nlanes(tn); dual=isdual(site)) ? nothing : rightindex(Periodic(), tn, site) +end +rightindex(::Periodic, tn::Chain, site::Site) = inds(tn; bond=(site, rightsite(tn, site))) + +Base.adjoint(chain::Chain) = Chain(adjoint(Quantum(chain)), boundary(chain)) + +struct ChainSampler{B<:Boundary,S<:Socket,NT<:NamedTuple} <: Random.Sampler{Chain} + parameters::NT + + ChainSampler{B,S}(; kwargs...) where {B,S} = new{B,S,typeof(values(kwargs))}(values(kwargs)) +end + +function Base.rand(A::Type{<:Chain}, B::Type{<:Boundary}, S::Type{<:Socket}; kwargs...) + return rand(Random.default_rng(), A, B, S; kwargs...) +end + +function Base.rand(rng::AbstractRNG, ::Type{A}, ::Type{B}, ::Type{S}; kwargs...) where {A<:Chain,B<:Boundary,S<:Socket} + return rand(rng, ChainSampler{B,S}(; kwargs...), B, S) +end + +function Base.rand(rng::Random.AbstractRNG, sampler::ChainSampler, ::Type{Open}, ::Type{State}) + n = sampler.parameters.n + χ = sampler.parameters.χ + p = get(sampler.parameters, :p, 2) + T = get(sampler.parameters, :eltype, Float64) + + arrays::Vector{AbstractArray{T,N} where {N}} = map(1:n) do i + χl, χr = let after_mid = i > n ÷ 2, i = (n + 1 - abs(2i - n - 1)) ÷ 2 + χl = min(χ, p^(i - 1)) + χr = min(χ, p^i) + + # swap bond dims after mid and handle midpoint for odd-length MPS + (isodd(n) && i == n ÷ 2 + 1) ? (χl, χl) : (after_mid ? (χr, χl) : (χl, χr)) + end + + # fix for first site + i == 1 && ((χl, χr) = (χr, 1)) + + # orthogonalize by Gram-Schmidt algorithm + A = gramschmidt!(rand(rng, T, χl, χr * p)) + + A = reshape(A, χl, χr, p) + permutedims(A, (3, 1, 2)) + end + + # reshape boundary sites + arrays[1] = reshape(arrays[1], p, p) + arrays[n] = reshape(arrays[n], p, p) + + # normalize state + arrays[1] ./= sqrt(p) + + return Chain(State(), Open(), arrays) +end + +# TODO let choose the orthogonality center +# TODO different input/output physical dims +function Base.rand(rng::Random.AbstractRNG, sampler::ChainSampler, ::Type{Open}, ::Type{Operator}) + n = sampler.parameters.n + χ = sampler.parameters.χ + p = get(sampler.parameters, :p, 2) + T = get(sampler.parameters, :eltype, Float64) + + ip = op = p + + arrays::Vector{AbstractArray{T,N} where {N}} = map(1:n) do i + χl, χr = let after_mid = i > n ÷ 2, i = (n + 1 - abs(2i - n - 1)) ÷ 2 + χl = min(χ, ip^(i - 1) * op^(i - 1)) + χr = min(χ, ip^i * op^i) + + # swap bond dims after mid and handle midpoint for odd-length MPS + (isodd(n) && i == n ÷ 2 + 1) ? (χl, χl) : (after_mid ? (χr, χl) : (χl, χr)) + end + + shape = if i == 1 + (χr, ip, op) + elseif i == n + (χl, ip, op) + else + (χl, χr, ip, op) + end + + # orthogonalize by Gram-Schmidt algorithm + A = gramschmidt!(rand(rng, T, shape[1], prod(shape[2:end]))) + A = reshape(A, shape) + + (i == 1 || i == n) ? permutedims(A, (2, 3, 1)) : permutedims(A, (3, 4, 1, 2)) + end + + # normalize + ζ = min(χ, ip * op) + arrays[1] ./= sqrt(ζ) + + return Chain(Operator(), Open(), arrays) +end + +Tenet.contract(tn::Chain, query::Symbol, args...; kwargs...) = contract!(copy(tn), Val(query), args...; kwargs...) +Tenet.contract!(tn::Chain, query::Symbol, args...; kwargs...) = contract!(tn, Val(query), args...; kwargs...) + +""" + Tenet.contract!(tn::Chain, ::Val{:between}, site1::Site, site2::Site; direction::Symbol = :left, delete_Λ = true) + +For a given [`Chain`](@ref) tensor network, contracts the singular values Λ between two sites `site1` and `site2`. +The `direction` keyword argument specifies the direction of the contraction, and the `delete_Λ` keyword argument +specifies whether to delete the singular values tensor after the contraction. +""" +function Tenet.contract!(tn::Chain, ::Val{:between}, site1::Site, site2::Site; direction::Symbol=:left, delete_Λ=true) + Λᵢ = tensors(tn; between=(site1, site2)) + Λᵢ === nothing && return tn + + if direction === :right + Γᵢ₊₁ = tensors(tn; at=site2) + replace!(TensorNetwork(tn), Γᵢ₊₁ => contract(Γᵢ₊₁, Λᵢ; dims=())) + elseif direction === :left + Γᵢ = tensors(tn; at=site1) + replace!(TensorNetwork(tn), Γᵢ => contract(Λᵢ, Γᵢ; dims=())) + else + throw(ArgumentError("Unknown direction=:$direction")) + end + + delete_Λ && delete!(TensorNetwork(tn), Λᵢ) + + return tn +end + +canonize_site(tn::Chain, args...; kwargs...) = canonize_site!(deepcopy(tn), args...; kwargs...) +canonize_site!(tn::Chain, args...; kwargs...) = canonize_site!(boundary(tn), tn, args...; kwargs...) + +# NOTE: in method == :svd the spectral weights are stored in a vector connected to the now virtual hyperindex! +function canonize_site!(::Open, tn::Chain, site::Site; direction::Symbol, method=:qr) + left_inds = Symbol[] + right_inds = Symbol[] + + virtualind = if direction === :left + site == Site(1) && throw(ArgumentError("Cannot right-canonize left-most tensor")) + push!(right_inds, leftindex(tn, site)) + + site == Site(nsites(tn)) || push!(left_inds, rightindex(tn, site)) + push!(left_inds, Quantum(tn)[site]) + + only(right_inds) + elseif direction === :right + site == Site(nsites(tn)) && throw(ArgumentError("Cannot left-canonize right-most tensor")) + push!(right_inds, rightindex(tn, site)) + + site == Site(1) || push!(left_inds, leftindex(tn, site)) + push!(left_inds, Quantum(tn)[site]) + + only(right_inds) + else + throw(ArgumentError("Unknown direction=:$direction")) + end + + tmpind = gensym(:tmp) + if method === :svd + svd!(TensorNetwork(tn); left_inds, right_inds, virtualind=tmpind) + elseif method === :qr + qr!(TensorNetwork(tn); left_inds, right_inds, virtualind=tmpind) + else + throw(ArgumentError("Unknown factorization method=:$method")) + end + + contract!(TensorNetwork(tn), virtualind) + replace!(TensorNetwork(tn), tmpind => virtualind) + + return tn +end + +truncate(tn::Chain, args...; kwargs...) = truncate!(deepcopy(tn), args...; kwargs...) + +""" + truncate!(qtn::Chain, bond; threshold::Union{Nothing,Real} = nothing, maxdim::Union{Nothing,Int} = nothing) + +Truncate the dimension of the virtual `bond`` of the [`Chain`](@ref) Tensor Network by keeping only the `maxdim` largest Schmidt coefficients or those larger than`threshold`. + +# Notes + + - Either `threshold` or `maxdim` must be provided. If both are provided, `maxdim` is used. + - The bond must contain the Schmidt coefficients, i.e. a site canonization must be performed before calling `truncate!`. +""" +function truncate!(qtn::Chain, bond; threshold::Union{Nothing,Real}=nothing, maxdim::Union{Nothing,Int}=nothing) + # TODO replace for tensors(; between) + vind = rightindex(qtn, bond[1]) + if vind != leftindex(qtn, bond[2]) + throw(ArgumentError("Invalid bond $bond")) + end + + if vind ∉ inds(qtn; set=:hyper) + throw(MissingSchmidtCoefficientsException(bond)) + end + + tensor = TensorNetwork(qtn)[vind] + spectrum = parent(tensor) + + extent = collect( + if !isnothing(maxdim) + 1:min(size(TensorNetwork(qtn), vind), maxdim) + else + 1:size(TensorNetwork(qtn), vind) + end, + ) + + # remove 0s from spectrum + if isnothing(threshold) + threshold = 1e-16 + end + + filter!(extent) do i + abs(spectrum[i]) > threshold + end + + slice!(TensorNetwork(qtn), vind, extent) + + return qtn +end + +function isleftcanonical(qtn::Chain, site; atol::Real=1e-12) + right_ind = rightindex(qtn, site) + tensor = tensors(qtn; at=site) + + # we are at right-most site, we need to add an extra dummy dimension to the tensor + if isnothing(right_ind) + right_ind = gensym(:dummy) + tensor = Tensor(reshape(parent(tensor), size(tensor)..., 1), (inds(tensor)..., right_ind)) + end + + # TODO is replace(conj(A)...) copying too much? + contracted = contract(tensor, replace(conj(tensor), right_ind => gensym(:new_ind))) + n = size(tensor, right_ind) + identity_matrix = Matrix(I, n, n) + + return isapprox(contracted, identity_matrix; atol) +end + +function isrightcanonical(qtn::Chain, site; atol::Real=1e-12) + left_ind = leftindex(qtn, site) + tensor = tensors(qtn; at=site) + + # we are at left-most site, we need to add an extra dummy dimension to the tensor + if isnothing(left_ind) + left_ind = gensym(:dummy) + tensor = Tensor(reshape(parent(tensor), 1, size(tensor)...), (left_ind, inds(tensor)...)) + end + + #TODO is replace(conj(A)...) copying too much? + contracted = contract(tensor, replace(conj(tensor), left_ind => gensym(:new_ind))) + n = size(tensor, left_ind) + identity_matrix = Matrix(I, n, n) + + return isapprox(contracted, identity_matrix; atol) +end + +canonize(tn::Chain, args...; kwargs...) = canonize!(copy(tn), args...; kwargs...) +canonize!(tn::Chain, args...; kwargs...) = canonize!(boundary(tn), tn, args...; kwargs...) + +""" +canonize(boundary::Boundary, tn::Chain) + +Transform a `Chain` tensor network into the canonical form (Vidal form), that is, +we have the singular values matrix Λᵢ between each tensor Γᵢ₋₁ and Γᵢ. +""" +function canonize!(::Open, tn::Chain) + Λ = Tensor[] + + # right-to-left QR sweep, get right-canonical tensors + for i in nsites(tn):-1:2 + canonize_site!(tn, Site(i); direction=:left, method=:qr) + end + + # left-to-right SVD sweep, get left-canonical tensors and singular values without reversing + for i in 1:(nsites(tn) - 1) + canonize_site!(tn, Site(i); direction=:right, method=:svd) + + # extract the singular values and contract them with the next tensor + Λᵢ = pop!(TensorNetwork(tn), tensors(tn; between=(Site(i), Site(i + 1)))) + Aᵢ₊₁ = tensors(tn; at=Site(i + 1)) + replace!(TensorNetwork(tn), Aᵢ₊₁ => contract(Aᵢ₊₁, Λᵢ; dims=())) + push!(Λ, Λᵢ) + end + + for i in 2:nsites(tn) # tensors at i in "A" form, need to contract (Λᵢ)⁻¹ with A to get Γᵢ + Λᵢ = Λ[i - 1] # singular values start between site 1 and 2 + A = tensors(tn; at=Site(i)) + Γᵢ = contract(A, Tensor(diag(pinv(Diagonal(parent(Λᵢ)); atol=1e-64)), inds(Λᵢ)); dims=()) + replace!(TensorNetwork(tn), A => Γᵢ) + push!(TensorNetwork(tn), Λᵢ) + end + + return tn +end + +mixed_canonize(tn::Chain, args...; kwargs...) = mixed_canonize!(deepcopy(tn), args...; kwargs...) +mixed_canonize!(tn::Chain, args...; kwargs...) = mixed_canonize!(boundary(tn), tn, args...; kwargs...) + +""" + mixed_canonize!(boundary::Boundary, tn::Chain, center::Site) + +Transform a `Chain` tensor network into the mixed-canonical form, that is, +for i < center the tensors are left-canonical and for i >= center the tensors are right-canonical, +and in the center there is a matrix with singular values. +""" +function mixed_canonize!(::Open, tn::Chain, center::Site) # TODO: center could be a range of sites + # left-to-right QR sweep (left-canonical tensors) + for i in 1:(id(center) - 1) + canonize_site!(tn, Site(i); direction=:right, method=:qr) + end + + # right-to-left QR sweep (right-canonical tensors) + for i in nsites(tn):-1:(id(center) + 1) + canonize_site!(tn, Site(i); direction=:left, method=:qr) + end + + # center SVD sweep to get singular values + canonize_site!(tn, center; direction=:left, method=:svd) + + return tn +end + +""" + LinearAlgebra.normalize!(tn::Chain, center::Site) + +Normalizes the input [`Chain`](@ref) tensor network by transforming it +to mixed-canonized form with the given center site. +""" +function LinearAlgebra.normalize!(tn::Chain, root::Site; p::Real=2) + mixed_canonize!(tn, root) + normalize!(tensors(tn; between=(Site(id(root) - 1), root)), p) + return tn +end + +""" + evolve!(qtn::Chain, gate) + +Applies a local operator `gate` to the [`Chain`](@ref) tensor network. +""" +function evolve!(qtn::Chain, gate::Dense; threshold=nothing, maxdim=nothing, iscanonical=false, renormalize=false) + # check gate is a valid operator + if !(socket(gate) isa Operator) + throw(ArgumentError("Gate must be an operator, but got $(socket(gate))")) + end + + # TODO refactor out to `islane`? + if !issetequal(adjoint.(inputs(gate)), outputs(gate)) + throw(ArgumentError("Gate inputs ($(inputs(gate))) and outputs ($(outputs(gate))) must be the same")) + end + + # TODO refactor out to `canconnect`? + if adjoint.(inputs(gate)) ⊈ outputs(qtn) + throw(ArgumentError("Gate inputs ($(inputs(gate))) must be a subset of the TN sites ($(sites(qtn)))")) + end + + if nlanes(gate) == 1 + evolve_1site!(qtn, gate) + elseif nlanes(gate) == 2 + # check gate sites are contiguous + # TODO refactor this out? + gate_inputs = sort!(map(id, inputs(gate))) + range = UnitRange(extrema(gate_inputs)...) + + range != gate_inputs && throw(ArgumentError("Gate lanes must be contiguous")) + + # TODO check correctly for periodic boundary conditions + evolve_2site!(qtn, gate; threshold, maxdim, iscanonical, renormalize) + else + # TODO generalize for more than 2 lanes + throw(ArgumentError("Invalid number of lanes $(nlanes(gate)), maximum is 2")) + end + + return qtn +end + +function evolve_1site!(qtn::Chain, gate::Dense) + # shallow copy to avoid problems if errors in mid execution + gate = copy(gate) + + contracting_index = gensym(:tmp) + targetsite = only(inputs(gate))' + + # reindex contracting index + replace!(TensorNetwork(qtn), inds(qtn; at=targetsite) => contracting_index) + replace!(TensorNetwork(gate), inds(gate; at=targetsite') => contracting_index) + + # reindex output of gate to match TN sitemap + replace!(TensorNetwork(gate), inds(gate; at=only(outputs(gate))) => inds(qtn; at=targetsite)) + + # contract gate with TN + merge!(TensorNetwork(qtn), TensorNetwork(gate)) + return contract!(TensorNetwork(qtn), contracting_index) +end + +# TODO: Maybe rename iscanonical kwarg ? +function evolve_2site!(qtn::Chain, gate::Dense; threshold, maxdim, iscanonical=false, renormalize=false) + # shallow copy to avoid problems if errors in mid execution + gate = copy(gate) + + bond = sitel, siter = minmax(outputs(gate)...) + left_inds::Vector{Symbol} = !isnothing(leftindex(qtn, sitel)) ? [leftindex(qtn, sitel)] : Symbol[] + right_inds::Vector{Symbol} = !isnothing(rightindex(qtn, siter)) ? [rightindex(qtn, siter)] : Symbol[] + + virtualind::Symbol = inds(qtn; bond=bond) + + iscanonical ? contract_2sitewf!(qtn, bond) : contract!(TensorNetwork(qtn), virtualind) + + # reindex contracting index + contracting_inds = [gensym(:tmp) for _ in inputs(gate)] + replace!( + TensorNetwork(qtn), + map(zip(inputs(gate), contracting_inds)) do (site, contracting_index) + inds(qtn; at=site') => contracting_index + end, + ) + replace!( + TensorNetwork(gate), + map(zip(inputs(gate), contracting_inds)) do (site, contracting_index) + inds(gate; at=site) => contracting_index + end, + ) + + # reindex output of gate to match TN sitemap + for site in outputs(gate) + if inds(qtn; at=site) != inds(gate; at=site) + replace!(TensorNetwork(gate), inds(gate; at=site) => inds(qtn; at=site)) + end + end + + # contract physical inds + merge!(TensorNetwork(qtn), TensorNetwork(gate)) + contract!(TensorNetwork(qtn), contracting_inds) + + # decompose using SVD + push!(left_inds, inds(qtn; at=sitel)) + push!(right_inds, inds(qtn; at=siter)) + + if iscanonical + unpack_2sitewf!(qtn, bond, left_inds, right_inds, virtualind) + else + svd!(TensorNetwork(qtn); left_inds, right_inds, virtualind) + end + # truncate virtual index + if any(!isnothing, [threshold, maxdim]) + truncate!(qtn, bond; threshold, maxdim) + + # renormalize the bond + if renormalize && iscanonical + λ = tensors(qtn; between=bond) + replace!(TensorNetwork(qtn), λ => normalize(λ)) + elseif renormalize && !iscanonical + normalize!(qtn, bond[1]) + end + end + + return qtn +end + +""" + contract_2sitewf!(ψ::Chain, bond) + +For a given [`Chain`](@ref) in the canonical form, creates the two-site wave function θ with Λᵢ₋₁Γᵢ₋₁ΛᵢΓᵢΛᵢ₊₁, +where i is the `bond`, and replaces the Γᵢ₋₁ΛᵢΓᵢ tensors with θ. +""" +function contract_2sitewf!(ψ::Chain, bond) + # TODO Check if ψ is in canonical form + + sitel, siter = bond # TODO Check if bond is valid + (0 < id(sitel) < nsites(ψ) || 0 < id(siter) < nsites(ψ)) || + throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))")) + + Λᵢ₋₁ = id(sitel) == 1 ? nothing : tensors(ψ; between=(Site(id(sitel) - 1), sitel)) + Λᵢ₊₁ = id(sitel) == nsites(ψ) - 1 ? nothing : tensors(ψ; between=(siter, Site(id(siter) + 1))) + + !isnothing(Λᵢ₋₁) && contract!(ψ, :between, Site(id(sitel) - 1), sitel; direction=:right, delete_Λ=false) + !isnothing(Λᵢ₊₁) && contract!(ψ, :between, siter, Site(id(siter) + 1); direction=:left, delete_Λ=false) + + contract!(TensorNetwork(ψ), inds(ψ; bond=bond)) + + return ψ +end + +""" + unpack_2sitewf!(ψ::Chain, bond) + +For a given [`Chain`](@ref) that contains a two-site wave function θ in a bond, it decomposes θ into the canonical +form: Γᵢ₋₁ΛᵢΓᵢ, where i is the `bond`. +""" +function unpack_2sitewf!(ψ::Chain, bond, left_inds, right_inds, virtualind) + # TODO Check if ψ is in canonical form + + sitel, siter = bond # TODO Check if bond is valid + (0 < id(sitel) < nsites(ψ) || 0 < id(site_r) < nsites(ψ)) || + throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))")) + + Λᵢ₋₁ = id(sitel) == 1 ? nothing : tensors(ψ; between=(Site(id(sitel) - 1), sitel)) + Λᵢ₊₁ = id(siter) == nsites(ψ) ? nothing : tensors(ψ; between=(siter, Site(id(siter) + 1))) + + # do svd of the θ tensor + θ = tensors(ψ; at=sitel) + U, s, Vt = svd(θ; left_inds, right_inds, virtualind) + + # contract with the inverse of Λᵢ and Λᵢ₊₂ + Γᵢ₋₁ = + isnothing(Λᵢ₋₁) ? U : contract(U, Tensor(diag(pinv(Diagonal(parent(Λᵢ₋₁)); atol=1e-32)), inds(Λᵢ₋₁)); dims=()) + Γᵢ = + isnothing(Λᵢ₊₁) ? Vt : contract(Tensor(diag(pinv(Diagonal(parent(Λᵢ₊₁)); atol=1e-32)), inds(Λᵢ₊₁)), Vt; dims=()) + + delete!(TensorNetwork(ψ), θ) + + push!(TensorNetwork(ψ), Γᵢ₋₁) + push!(TensorNetwork(ψ), s) + push!(TensorNetwork(ψ), Γᵢ) + + return ψ +end + +function expect(ψ::Chain, observables) + # contract observable with TN + ϕ = copy(ψ) + for observable in observables + evolve!(ϕ, observable) + end + + # contract evolved TN with adjoint of original TN + tn = merge!(TensorNetwork(ϕ), TensorNetwork(ψ')) + + return contract(tn) +end + +overlap(a::Chain, b::Chain) = overlap(socket(a), a, socket(b), b) + +# TODO fix optimal path +function overlap(::State, a::Chain, ::State, b::Chain) + @assert issetequal(sites(a), sites(b)) "Ansatzes must have the same sites" + + b = copy(b) + b = @reindex! outputs(a) => outputs(b) + + tn = merge(TensorNetwork(a), TensorNetwork(b')) + return contract(tn) +end + +# TODO optimize +overlap(a::Product, b::Chain) = contract(TensorNetwork(merge(Quantum(a), Quantum(b)'))) +overlap(a::Chain, b::Product) = contract(TensorNetwork(merge(Quantum(a), Quantum(b)'))) diff --git a/src/Ansatz/Dense.jl b/src/Ansatz/Dense.jl new file mode 100644 index 00000000..c92fc969 --- /dev/null +++ b/src/Ansatz/Dense.jl @@ -0,0 +1,38 @@ +struct Dense <: Ansatz + super::Quantum +end + +function Dense(::State, array::AbstractArray; sites=Site.(1:ndims(array))) + @assert ndims(array) > 0 + @assert all(>(1), size(array)) + + symbols = [nextindex() for _ in 1:ndims(array)] + sitemap = Dict{Site,Symbol}( + map(sites, 1:ndims(array)) do site, i + site => symbols[i] + end, + ) + + tensor = Tensor(array, symbols) + + tn = TensorNetwork([tensor]) + qtn = Quantum(tn, sitemap) + return Dense(qtn) +end + +function Dense(::Operator, array::AbstractArray; sites) + @assert ndims(array) > 0 + @assert all(>(1), size(array)) + @assert length(sites) == ndims(array) + + tensor_inds = [nextindex() for _ in 1:ndims(array)] + tensor = Tensor(array, tensor_inds) + tn = TensorNetwork([tensor]) + + sitemap = Dict{Site,Symbol}(map(splat(Pair), zip(sites, tensor_inds))) + qtn = Quantum(tn, sitemap) + + return Dense(qtn) +end + +Base.copy(qtn::Dense) = Dense(copy(Quantum(qtn))) diff --git a/src/Ansatz/Grid.jl b/src/Ansatz/Grid.jl new file mode 100644 index 00000000..46fab3c1 --- /dev/null +++ b/src/Ansatz/Grid.jl @@ -0,0 +1,176 @@ +struct Grid <: Ansatz + super::Quantum + boundary::Boundary +end + +Base.copy(tn::Grid) = Grid(copy(Quantum(tn)), boundary(tn)) + +boundary(tn::Grid) = tn.boundary + +PEPS(arrays) = Grid(State(), Open(), arrays) +pPEPS(arrays) = Grid(State(), Periodic(), arrays) +PEPO(arrays) = Grid(Operator(), Open(), arrays) +pPEPO(arrays) = Grid(Operator(), Periodic(), arrays) + +alias(tn::Grid) = alias(socket(tn), boundary(tn), tn) +alias(::State, ::Open, ::Grid) = "PEPS" +alias(::State, ::Periodic, ::Grid) = "pPEPS" +alias(::Operator, ::Open, ::Grid) = "PEPO" +alias(::Operator, ::Periodic, ::Grid) = "pPEPO" + +function Grid(::State, ::Periodic, arrays::Matrix{<:AbstractArray}) + @assert all(==(4) ∘ ndims, arrays) "All arrays must have 4 dimensions" + + m, n = size(arrays) + pinds = map(_ -> nextindex(), arrays) + hvinds = map(_ -> nextindex(), arrays) + vvinds = map(_ -> nextindex(), arrays) + + _tensors = map(eachindex(IndexCartesian(), arrays)) do I + i, j = Tuple(I) + + array = arrays[i, j] + pind = pinds[i, j] + up, down = hvinds[i, j], hvinds[mod1(i + 1, m), j] + left, right = vvinds[i, j], vvinds[i, mod1(j + 1, n)] + + # TODO customize order + Tensor(array, [pind, up, down, left, right]) + end + + sitemap = Dict(Site(i, j) => pinds[i, j] for i in 1:m, j in 1:n) + + return Grid(Quantum(TensorNetwork(_tensors), sitemap), Periodic()) +end + +function Grid(::State, ::Open, arrays::Matrix{<:AbstractArray}) + m, n = size(arrays) + + predicate = all(eachindex(arrays)) do I + i, j = Tuple(I) + array = arrays[i, j] + + N = ndims(array) - 1 + (i == 1 || i == m) && (N -= 1) + (j == 1 || j == n) && (N -= 1) + + N > 0 + end + + if !predicate + throw(DimensionMismatch()) + end + + pinds = map(_ -> nextindex(), arrays) + vvinds = [nextindex() for _ in 1:(m - 1), _ in 1:n] + hvinds = [nextindex() for _ in 1:m, _ in 1:(n - 1)] + + _tensors = map(eachindex(IndexCartesian(), arrays)) do I + i, j = Tuple(I) + + array = arrays[i, j] + pind = pinds[i, j] + up = i == 1 ? missing : vvinds[i - 1, j] + down = i == m ? missing : vvinds[i, j] + left = j == 1 ? missing : hvinds[i, j - 1] + right = j == n ? missing : hvinds[i, j] + + # TODO customize order + Tensor(array, collect(skipmissing([pind, up, down, left, right]))) + end + + sitemap = Dict(Site(i, j) => pinds[i, j] for i in 1:m, j in 1:n) + + return Grid(Quantum(TensorNetwork(_tensors), sitemap), Open()) +end + +function Grid(::Operator, ::Periodic, arrays::Matrix{<:AbstractArray}) + @assert all(==(4) ∘ ndims, arrays) "All arrays must have 4 dimensions" + + m, n = size(arrays) + ipinds = map(_ -> nextindex(), arrays) + opinds = map(_ -> nextindex(), arrays) + hvinds = map(_ -> nextindex(), arrays) + vvinds = map(_ -> nextindex(), arrays) + + _tensors = map(eachindex(IndexCartesian(), arrays)) do I + i, j = Tuple(I) + + array = arrays[i, j] + ipind, opind = ipinds[i, j], opinds[i, j] + up, down = hvinds[i, j], hvinds[mod1(i + 1, m), j] + left, right = vvinds[i, j], vvinds[i, mod1(j + 1, n)] + + # TODO customize order + Tensor(array, [ipind, opind, up, down, left, right]) + end + + sitemap = Dict( + flatten([ + (Site(i, j; dual=true) => ipinds[i, j] for i in 1:m, j in 1:n), + (Site(i, j) => opinds[i, j] for i in 1:m, j in 1:n), + ]), + ) + + return Grid(Quantum(TensorNetwork(_tensors), sitemap), Periodic()) +end + +function Grid(::Operator, ::Open, arrays::Matrix{<:AbstractArray}) + m, n = size(arrays) + + predicate = all(eachindex(IndexCartesian(), arrays)) do I + i, j = Tuple(I) + array = arrays[i, j] + + N = ndims(array) - 2 + (i == 1 || i == m) && (N -= 1) + (j == 1 || j == n) && (N -= 1) + + N > 0 + end + + if !predicate + throw(DimensionMismatch()) + end + + ipinds = map(_ -> nextindex(), arrays) + opinds = map(_ -> nextindex(), arrays) + vvinds = [nextindex() for _ in 1:(m - 1), _ in 1:n] + hvinds = [nextindex() for _ in 1:m, _ in 1:(n - 1)] + + _tensors = map(eachindex(IndexCartesian(), arrays)) do I + i, j = Tuple(I) + + array = arrays[i, j] + ipind = ipinds[i, j] + opind = opinds[i, j] + up = i == 1 ? missing : vvinds[i - 1, j] + down = i == m ? missing : vvinds[i, j] + left = j == 1 ? missing : hvinds[i, j - 1] + right = j == n ? missing : hvinds[i, j] + + # TODO customize order + Tensor(array, collect(skipmissing([ipind, opind, up, down, left, right]))) + end + + sitemap = Dict( + flatten([ + (Site(i, j; dual=true) => ipinds[i, j] for i in 1:m, j in 1:n), + (Site(i, j) => opinds[i, j] for i in 1:m, j in 1:n), + ]), + ) + + return Grid(Quantum(TensorNetwork(_tensors), sitemap), Open()) +end + +function LinearAlgebra.transpose!(qtn::Grid) + old = Quantum(qtn).sites + new = Dict(Site(reverse(id(site)); dual=isdual(site)) => ind for (site, ind) in old) + + empty!(old) + merge!(old, new) + + return qtn +end + +Base.transpose(qtn::Grid) = LinearAlgebra.transpose!(copy(qtn)) diff --git a/src/Ansatz/Product.jl b/src/Ansatz/Product.jl new file mode 100644 index 00000000..13955df9 --- /dev/null +++ b/src/Ansatz/Product.jl @@ -0,0 +1,83 @@ +using LinearAlgebra + +struct Product <: Ansatz + super::Quantum +end + +Base.copy(x::Product) = Product(copy(Quantum(x))) + +Base.similar(x::Product) = Product(similar(Quantum(x))) +Base.zero(x::Product) = Product(zero(Quantum(x))) + +function Product(tn::TensorNetwork, sites) + @assert isempty(inds(tn; set=:inner)) "Product ansatz must not have inner indices" + return Product(Quantum(tn, sites)) +end + +Product(arrays::Vector{<:AbstractVector}) = Product(State(), Open(), arrays) +Product(arrays::Vector{<:AbstractMatrix}) = Product(Operator(), Open(), arrays) + +function Product(::State, ::Open, arrays) + symbols = [nextindex() for _ in 1:length(arrays)] + _tensors = map(enumerate(arrays)) do (i, array) + Tensor(array, [symbols[i]]) + end + + sitemap = Dict(Site(i) => symbols[i] for i in 1:length(arrays)) + + return Product(TensorNetwork(_tensors), sitemap) +end + +function Product(::Operator, ::Open, arrays) + n = length(arrays) + symbols = [nextindex() for _ in 1:(2 * length(arrays))] + _tensors = map(enumerate(arrays)) do (i, array) + Tensor(array, [symbols[i + n], symbols[i]]) + end + + sitemap = merge!(Dict(Site(i; dual=true) => symbols[i] for i in 1:n), Dict(Site(i) => symbols[i + n] for i in 1:n)) + + return Product(TensorNetwork(_tensors), sitemap) +end + +function Base.zeros(::Type{Product}, n::Integer; p::Int=2, eltype=Bool) + return Product(State(), Open(), fill(append!([one(eltype)], collect(Iterators.repeated(zero(eltype), p - 1))), n)) +end + +function Base.ones(::Type{Product}, n::Integer; p::Int=2, eltype=Bool) + return Product( + State(), Open(), fill(append!([zero(eltype), one(eltype)], collect(Iterators.repeated(zero(eltype), p - 2))), n) + ) +end + +LinearAlgebra.norm(tn::Product, p::Real=2) = LinearAlgebra.norm(socket(tn), tn, p) +function LinearAlgebra.norm(::Union{State,Operator}, tn::Product, p::Real) + return mapreduce(*, tensors(tn)) do tensor + norm(tensor, p) + end^(1//p) +end + +LinearAlgebra.opnorm(tn::Product, p::Real=2) = LinearAlgebra.opnorm(socket(tn), tn, p) +function LinearAlgebra.opnorm(::Operator, tn::Product, p::Real) + return mapreduce(*, tensors(tn)) do tensor + opnorm(parent(tensor), p) + end^(1//p) +end + +LinearAlgebra.normalize!(tn::Product, p::Real=2) = LinearAlgebra.normalize!(socket(tn), tn, p) +function LinearAlgebra.normalize!(::Union{State,Operator}, tn::Product, p::Real) + for tensor in tensors(tn) + normalize!(tensor, p) + end + return tn +end + +overlap(a::Product, b::Product) = overlap(socket(a), a, socket(b), b) + +function overlap(::State, a::Product, ::State, b::Product) + @assert issetequal(sites(a), sites(b)) "Ansatzes must have the same sites" + + mapreduce(*, zip(tensors(a), tensors(b))) do (ta, tb) + dot(parent(ta), conj(parent(tb))) + end +end diff --git a/src/Helpers.jl b/src/Helpers.jl index 3c61e314..0f91ead4 100644 --- a/src/Helpers.jl +++ b/src/Helpers.jl @@ -64,3 +64,9 @@ function nonunique(x) nonuniqueindexes = setdiff(1:length(x), uniqueindexes) return unique(x[nonuniqueindexes]) end + +const __indexcounter::Threads.Atomic{Int} = Threads.Atomic{Int}(1) + +currindex() = letter(__indexcounter[]) +nextindex() = (__indexcounter.value >= 135000) ? resetindex() : letter(Threads.atomic_add!(__indexcounter, 1)) +resetindex() = letter(Threads.atomic_xchg!(__indexcounter, 1)) diff --git a/src/Quantum.jl b/src/Quantum.jl new file mode 100644 index 00000000..d7b88dee --- /dev/null +++ b/src/Quantum.jl @@ -0,0 +1,308 @@ +""" + Quantum + +Tensor Network with a notion of "causality". This leads to the notion of sites and directionality (input/output). + +# Notes + + - Indices are referenced by `Site`s. +""" +struct Quantum + tn::TensorNetwork + + # WARN keep them synchronized + sites::Dict{Site,Symbol} + # sitetensors::Dict{Site,Tensor} + + function Quantum(tn::TensorNetwork, sites) + for (_, index) in sites + if !haskey(tn.indexmap, index) + error("Index $index not found in TensorNetwork") + elseif index ∉ inds(tn; set=:open) + error("Index $index must be open") + end + end + + # sitetensors = map(sites) do (site, index) + # site => tn[index] + # end |> Dict{Site,Tensor} + + return new(tn, sites) + end +end + +Quantum(qtn::Quantum) = qtn + +""" + TensorNetwork(q::Quantum) + +Returns the underlying `TensorNetwork` of a [`Quantum`](@ref) Tensor Network. +""" +Tenet.TensorNetwork(q::Quantum) = q.tn + +Base.copy(q::Quantum) = Quantum(copy(TensorNetwork(q)), copy(q.sites)) + +Base.similar(q::Quantum) = Quantum(similar(TensorNetwork(q)), copy(q.sites)) +Base.zero(q::Quantum) = Quantum(zero(TensorNetwork(q)), copy(q.sites)) + +Base.:(==)(a::Quantum, b::Quantum) = a.tn == b.tn && a.sites == b.sites +Base.isapprox(a::Quantum, b::Quantum; kwargs...) = isapprox(a.tn, b.tn; kwargs...) && a.sites == b.sites + +""" + adjoint(q::Quantum) + +Returns the adjoint of a [`Quantum`](@ref) Tensor Network; i.e. the conjugate Tensor Network with the inputs and outputs swapped. +""" +function Base.adjoint(qtn::Quantum) + sites = Dict{Site,Symbol}( + Iterators.map(qtn.sites) do (site, index) + site' => index + end, + ) + + tn = conj(TensorNetwork(qtn)) + + # rename inner indices + physical_inds = values(sites) + virtual_inds = setdiff(inds(tn), physical_inds) + replace!(tn, map(virtual_inds) do i + i => Symbol(i, "'") + end...) + + return Quantum(tn, sites) +end + +""" + ninputs(q::Quantum) + +Returns the number of input sites of a [`Quantum`](@ref) Tensor Network. +""" +ninputs(q::Quantum) = count(isdual, keys(q.sites)) + +""" + noutputs(q::Quantum) + +Returns the number of output sites of a [`Quantum`](@ref) Tensor Network. +""" +noutputs(q::Quantum) = count(!isdual, keys(q.sites)) + +""" + inputs(q::Quantum) + +Returns the input sites of a [`Quantum`](@ref) Tensor Network. +""" +inputs(q::Quantum) = sort!(collect(filter(isdual, keys(q.sites)))) + +""" + outputs(q::Quantum) + +Returns the output sites of a [`Quantum`](@ref) Tensor Network. +""" +outputs(q::Quantum) = sort!(collect(filter(!isdual, keys(q.sites)))) + +Base.summary(io::IO, q::Quantum) = print(io, "$(length(q.tn.tensormap))-tensors Quantum") +Base.show(io::IO, q::Quantum) = print(io, "Quantum (inputs=$(ninputs(q)), outputs=$(noutputs(q)))") + +""" + sites(q::Quantum) + +Returns the sites of a [`Quantum`](@ref) Tensor Network. +""" +function sites(tn::Quantum; kwargs...) + if isempty(kwargs) + collect(keys(tn.sites)) + elseif keys(kwargs) === (:at,) + findfirst(i -> i === kwargs[:at], tn.sites) + else + throw(MethodError(sites, (Quantum,), kwargs)) + end +end + +""" + nsites(q::Quantum) + +Returns the number of sites of a [`Quantum`](@ref) Tensor Network. +""" +nsites(tn::Quantum) = length(tn.sites) + +""" + lanes(q::Quantum) + +Returns the lanes of a [`Quantum`](@ref) Tensor Network. +""" +lanes(tn::Quantum) = unique( + Iterators.map(Iterators.flatten([inputs(tn), outputs(tn)])) do site + isdual(site) ? site' : site + end, +) + +""" + nlanes(q::Quantum) + +Returns the number of lanes of a [`Quantum`](@ref) Tensor Network. +""" +nlanes(tn::Quantum) = length(lanes(tn)) + +""" + getindex(q::Quantum, site::Site) + +Returns the index associated with a site in a [`Quantum`](@ref) Tensor Network. +""" +Base.getindex(q::Quantum, site::Site) = inds(q; at=site) + +""" + Socket + +Abstract type representing the socket of a [`Quantum`](@ref) Tensor Network. +""" +abstract type Socket end + +""" + Scalar <: Socket + +Socket representing a scalar; i.e. a Tensor Network with no open sites. +""" +struct Scalar <: Socket end + +""" + State <: Socket + +Socket representing a state; i.e. a Tensor Network with only input sites (or only output sites if `dual = true`). +""" +Base.@kwdef struct State <: Socket + dual::Bool = false +end + +""" + Operator <: Socket + +Socket representing an operator; i.e. a Tensor Network with both input and output sites. +""" +struct Operator <: Socket end + +""" + socket(q::Quantum) + +Returns the socket of a [`Quantum`](@ref) Tensor Network; i.e. whether it is a [`Scalar`](@ref), [`State`](@ref) or [`Operator`](@ref). +""" +function socket(q::Quantum) + _sites = sites(q) + if isempty(_sites) + Scalar() + elseif all(!isdual, _sites) + State() + elseif all(isdual, _sites) + State(; dual=true) + else + Operator() + end +end + +# forward `TensorNetwork` methods +for f in [:(Tenet.arrays), :(Base.collect)] + @eval $f(@nospecialize tn::Quantum) = $f(TensorNetwork(tn)) +end + +""" + inds(tn::Quantum, set::Symbol = :all, args...; kwargs...) + +Options: + + - `:at`: index at a site +""" +function Tenet.inds(tn::Quantum; kwargs...) + if keys(kwargs) === (:at,) + inds(tn, Val(:at), kwargs[:at]) + else + inds(TensorNetwork(tn); kwargs...) + end +end + +Tenet.inds(tn::Quantum, ::Val{:at}, site::Site) = tn.sites[site] + +""" + tensors(tn::Quantum, query::Symbol, args...; kwargs...) + +Options: + + - `:at`: tensor at a site +""" +function Tenet.tensors(tn::Quantum; kwargs...) + if keys(kwargs) === (:at,) + tensors(tn, Val(:at), kwargs[:at]) + else + tensors(TensorNetwork(tn); kwargs...) + end +end + +Tenet.tensors(tn::Quantum, ::Val{:at}, site::Site) = only(tensors(tn; intersects=inds(tn; at=site))) + +function reindex!(a::Quantum, ioa, b::Quantum, iob) + ioa ∈ [:inputs, :outputs] || error("Invalid argument: :$ioa") + + sitesb = if iob === :inputs + inputs(b) + elseif iob === :outputs + outputs(b) + else + error("Invalid argument: :$iob") + end + + replacements = map(sitesb) do site + inds(b; at=site) => inds(a; at=ioa != iob ? site' : site) + end + + if issetequal(first.(replacements), last.(replacements)) + return b + end + + replace!(TensorNetwork(b), replacements...) + + for site in sitesb + b.sites[site] = inds(a; at=ioa != iob ? site' : site) + end + + return b +end + +""" + @reindex! a => b + +Reindexes the input/output sites of a [`Quantum`](@ref) Tensor Network `b` to match the input/output sites of another [`Quantum`](@ref) Tensor Network `a`. +""" +macro reindex!(expr) + @assert Meta.isexpr(expr, :call) && expr.args[1] == :(=>) + Base.remove_linenums!(expr) + a, b = expr.args[2:end] + + @assert Meta.isexpr(a, :call) + @assert Meta.isexpr(b, :call) + ioa, ida = a.args + iob, idb = b.args + return :((reindex!(Quantum($(esc(ida))), $(Meta.quot(ioa)), Quantum($(esc(idb))), $(Meta.quot(iob)))); $(esc(idb))) +end + +""" + merge(a::Quantum, b::Quantum...) + +Merges multiple [`Quantum`](@ref) Tensor Networks into a single one by connecting input/output sites. +""" +Base.merge(a::Quantum, others::Quantum...) = foldl(merge, others; init=a) +function Base.merge(a::Quantum, b::Quantum) + @assert issetequal(outputs(a), map(adjoint, inputs(b))) "Outputs of $a must match inputs of $b" + + @reindex! outputs(a) => inputs(b) + tn = merge(TensorNetwork(a), TensorNetwork(b)) + + sites = Dict{Site,Symbol}() + + for site in inputs(a) + sites[site] = inds(a; at=site) + end + + for site in outputs(b) + sites[site] = inds(b; at=site) + end + + return Quantum(tn, sites) +end diff --git a/src/Site.jl b/src/Site.jl new file mode 100644 index 00000000..4e865c93 --- /dev/null +++ b/src/Site.jl @@ -0,0 +1,43 @@ +# TODO Should we store here some information about quantum numbers? +""" + Site(id[, dual = false]) + site"i,j,..." + +Represents a physical index. +""" +struct Site{N} + id::NTuple{N,Int} + dual::Bool + + Site(id::NTuple{N,Int}; dual=false) where {N} = new{N}(id, dual) +end + +Site(id::Int; kwargs...) = Site((id,); kwargs...) +Site(id::Vararg{Int,N}; kwargs...) where {N} = Site(id; kwargs...) + +id(site::Site{1}) = only(site.id) +id(site::Site) = site.id + +Base.CartesianIndex(site::Site) = CartesianIndex(id(site)) + +isdual(site::Site) = site.dual +Base.show(io::IO, site::Site) = print(io, "$(id(site))$(site.dual ? "'" : "")") +Base.adjoint(site::Site) = Site(id(site); dual=!site.dual) +Base.isless(a::Site, b::Site) = id(a) < id(b) + +macro site_str(str) + m = match(r"^(\d+,)*\d+('?)$", str) + if isnothing(m) + error("Invalid site string: $str") + end + + id = tuple(map(eachmatch(r"(\d+)", str)) do match + parse(Int, only(match.captures)) + end...) + + dual = endswith(str, "'") + + return :(Site($id; dual=$dual)) +end + +Base.zero(x::Dict{Site,Symbol}) = x diff --git a/src/Tenet.jl b/src/Tenet.jl index c2cdfdce..95709c95 100644 --- a/src/Tenet.jl +++ b/src/Tenet.jl @@ -18,6 +18,36 @@ export transform, transform! include("Compiler.jl") +include("Site.jl") +export Site, @site_str, isdual + +include("Quantum.jl") +export Quantum, ninputs, noutputs, inputs, outputs, sites, nsites + +include("Ansatz/Ansatz.jl") +export Ansatz +export socket, Scalar, State, Operator +export boundary, Open, Periodic + +include("Ansatz/Product.jl") +export Product + +include("Ansatz/Dense.jl") +export Dense + +include("Ansatz/Chain.jl") +export Chain +export MPS, pMPS, MPO, pMPO +export leftindex, rightindex, isleftcanonical, isrightcanonical +export canonize_site, canonize_site!, truncate! +export canonize, canonize!, mixed_canonize, mixed_canonize! + +include("Ansatz/Grid.jl") +export Grid +export PEPS, pPEPS, PEPO, pPEPO + +export evolve!, expect, overlap + # reexports from EinExprs export einexpr, inds diff --git a/test/Chain_test.jl b/test/Chain_test.jl new file mode 100644 index 00000000..a3b3d294 --- /dev/null +++ b/test/Chain_test.jl @@ -0,0 +1,390 @@ +@testset "Chain ansatz" begin + @testset "Periodic boundary" begin + @testset "State" begin + qtn = Chain(State(), Periodic(), [rand(2, 4, 4) for _ in 1:3]) + @test socket(qtn) == State() + @test ninputs(qtn) == 0 + @test noutputs(qtn) == 3 + @test issetequal(sites(qtn), [site"1", site"2", site"3"]) + @test boundary(qtn) == Periodic() + @test leftindex(qtn, site"1") == rightindex(qtn, site"3") != nothing + + arrays = [rand(2, 1, 4), rand(2, 4, 3), rand(2, 3, 1)] + qtn = Chain(State(), Periodic(), arrays) # Default order (:o, :l, :r) + + @test size(tensors(qtn; at=Site(1))) == (2, 1, 4) + @test size(tensors(qtn; at=Site(2))) == (2, 4, 3) + @test size(tensors(qtn; at=Site(3))) == (2, 3, 1) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) + + arrays = [permutedims(array, (3, 1, 2)) for array in arrays] # now we have (:r, :o, :l) + qtn = Chain(State(), Periodic(), arrays; order=[:r, :o, :l]) + + @test size(tensors(qtn; at=Site(1))) == (4, 2, 1) + @test size(tensors(qtn; at=Site(2))) == (3, 2, 4) + @test size(tensors(qtn; at=Site(3))) == (1, 2, 3) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) + + for i in 1:nsites(qtn) + @test size(TensorNetwork(qtn), inds(qtn; at=Site(i))) == 2 + end + end + + @testset "Operator" begin + qtn = Chain(Operator(), Periodic(), [rand(2, 2, 4, 4) for _ in 1:3]) + @test socket(qtn) == Operator() + @test ninputs(qtn) == 3 + @test noutputs(qtn) == 3 + @test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"]) + @test boundary(qtn) == Periodic() + @test leftindex(qtn, site"1") == rightindex(qtn, site"3") != nothing + + arrays = [rand(2, 4, 1, 3), rand(2, 4, 3, 6), rand(2, 4, 6, 1)] # Default order (:o, :i, :l, :r) + qtn = Chain(Operator(), Periodic(), arrays) + + @test size(tensors(qtn; at=Site(1))) == (2, 4, 1, 3) + @test size(tensors(qtn; at=Site(2))) == (2, 4, 3, 6) + @test size(tensors(qtn; at=Site(3))) == (2, 4, 6, 1) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) + + for i in 1:length(arrays) + @test size(TensorNetwork(qtn), inds(qtn; at=Site(i))) == 2 + @test size(TensorNetwork(qtn), inds(qtn; at=Site(i; dual=true))) == 4 + end + + arrays = [permutedims(array, (4, 1, 3, 2)) for array in arrays] # now we have (:r, :o, :l, :i) + qtn = Chain(Operator(), Periodic(), arrays; order=[:r, :o, :l, :i]) + + @test size(tensors(qtn; at=Site(1))) == (3, 2, 1, 4) + @test size(tensors(qtn; at=Site(2))) == (6, 2, 3, 4) + @test size(tensors(qtn; at=Site(3))) == (1, 2, 6, 4) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) !== nothing + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing + + for i in 1:length(arrays) + @test size(TensorNetwork(qtn), inds(qtn; at=Site(i))) == 2 + @test size(TensorNetwork(qtn), inds(qtn; at=Site(i; dual=true))) == 4 + end + end + end + + @testset "Open boundary" begin + @testset "State" begin + qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2)]) + @test socket(qtn) == State() + @test ninputs(qtn) == 0 + @test noutputs(qtn) == 3 + @test issetequal(sites(qtn), [site"1", site"2", site"3"]) + @test boundary(qtn) == Open() + @test leftindex(qtn, site"1") == rightindex(qtn, site"3") == nothing + + arrays = [rand(2, 1), rand(2, 1, 3), rand(2, 3)] + qtn = Chain(State(), Open(), arrays) # Default order (:o, :l, :r) + + @test size(tensors(qtn; at=Site(1))) == (2, 1) + @test size(tensors(qtn; at=Site(2))) == (2, 1, 3) + @test size(tensors(qtn; at=Site(3))) == (2, 3) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) + + arrays = [permutedims(arrays[1], (2, 1)), permutedims(arrays[2], (3, 1, 2)), permutedims(arrays[3], (1, 2))] # now we have (:r, :o, :l) + qtn = Chain(State(), Open(), arrays; order=[:r, :o, :l]) + + @test size(tensors(qtn; at=Site(1))) == (1, 2) + @test size(tensors(qtn; at=Site(2))) == (3, 2, 1) + @test size(tensors(qtn; at=Site(3))) == (2, 3) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing + + for i in 1:nsites(qtn) + @test size(TensorNetwork(qtn), inds(qtn; at=Site(i))) == 2 + end + end + @testset "Operator" begin + qtn = Chain(Operator(), Open(), [rand(2, 2, 4), rand(2, 2, 4, 4), rand(2, 2, 4)]) + @test socket(qtn) == Operator() + @test ninputs(qtn) == 3 + @test noutputs(qtn) == 3 + @test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"]) + @test boundary(qtn) == Open() + @test leftindex(qtn, site"1") == rightindex(qtn, site"3") == nothing + + arrays = [rand(2, 4, 1), rand(2, 4, 1, 3), rand(2, 4, 3)] # Default order (:o :i, :l, :r) + qtn = Chain(Operator(), Open(), arrays) + + @test size(tensors(qtn; at=Site(1))) == (2, 4, 1) + @test size(tensors(qtn; at=Site(2))) == (2, 4, 1, 3) + @test size(tensors(qtn; at=Site(3))) == (2, 4, 3) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing + + for i in 1:length(arrays) + @test size(TensorNetwork(qtn), inds(qtn; at=Site(i))) == 2 + @test size(TensorNetwork(qtn), inds(qtn; at=Site(i; dual=true))) == 4 + end + + arrays = [ + permutedims(arrays[1], (3, 1, 2)), + permutedims(arrays[2], (4, 1, 3, 2)), + permutedims(arrays[3], (1, 3, 2)), + ] # now we have (:r, :o, :l, :i) + qtn = Chain(Operator(), Open(), arrays; order=[:r, :o, :l, :i]) + + @test size(tensors(qtn; at=Site(1))) == (1, 2, 4) + @test size(tensors(qtn; at=Site(2))) == (3, 2, 1, 4) + @test size(tensors(qtn; at=Site(3))) == (2, 3, 4) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing + + for i in 1:length(arrays) + @test size(TensorNetwork(qtn), inds(qtn; at=Site(i))) == 2 + @test size(TensorNetwork(qtn), inds(qtn; at=Site(i; dual=true))) == 4 + end + end + end + + @testset "Site" begin + using Tenet: leftsite, rightsite + qtn = Chain(State(), Periodic(), [rand(2, 4, 4) for _ in 1:3]) + + @test leftsite(qtn, Site(1)) == Site(3) + @test leftsite(qtn, Site(2)) == Site(1) + @test leftsite(qtn, Site(3)) == Site(2) + + @test rightsite(qtn, Site(1)) == Site(2) + @test rightsite(qtn, Site(2)) == Site(3) + @test rightsite(qtn, Site(3)) == Site(1) + + qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2)]) + + @test isnothing(leftsite(qtn, Site(1))) + @test isnothing(rightsite(qtn, Site(3))) + + @test leftsite(qtn, Site(2)) == Site(1) + @test leftsite(qtn, Site(3)) == Site(2) + + @test rightsite(qtn, Site(2)) == Site(3) + @test rightsite(qtn, Site(1)) == Site(2) + end + + @testset "truncate" begin + qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2)]) + canonize_site!(qtn, Site(2); direction=:right, method=:svd) + + @test_throws Tenet.MissingSchmidtCoefficientsException truncate!(qtn, [Site(1), Site(2)]; maxdim=1) + # @test_throws ArgumentError truncate!(qtn, [Site(2), Site(3)]) + + truncated = Tenet.truncate(qtn, [Site(2), Site(3)]; maxdim=1) + @test size(TensorNetwork(truncated), rightindex(truncated, Site(2))) == 1 + @test size(TensorNetwork(truncated), leftindex(truncated, Site(3))) == 1 + + singular_values = tensors(qtn; between=(Site(2), Site(3))) + truncated = Tenet.truncate(qtn, [Site(2), Site(3)]; threshold=singular_values[2] + 0.1) + @test size(TensorNetwork(truncated), rightindex(truncated, Site(2))) == 1 + @test size(TensorNetwork(truncated), leftindex(truncated, Site(3))) == 1 + end + + @testset "rand" begin + using LinearAlgebra: norm + + @testset "State" begin + n = 8 + χ = 10 + + qtn = rand(Chain, Open, State; n, p=2, χ) + @test socket(qtn) == State() + @test ninputs(qtn) == 0 + @test noutputs(qtn) == n + @test issetequal(sites(qtn), map(Site, 1:n)) + @test boundary(qtn) == Open() + @test isapprox(norm(qtn), 1.0) + @test maximum(last, size(TensorNetwork(qtn))) <= χ + end + + @testset "Operator" begin + n = 8 + χ = 10 + + qtn = rand(Chain, Open, Operator; n, p=2, χ) + @test socket(qtn) == Operator() + @test ninputs(qtn) == n + @test noutputs(qtn) == n + @test issetequal(sites(qtn), vcat(map(Site, 1:n), map(adjoint ∘ Site, 1:n))) + @test boundary(qtn) == Open() + @test isapprox(norm(qtn), 1.0) + @test maximum(last, size(TensorNetwork(qtn))) <= χ + end + end + + @testset "Canonization" begin + using Tenet + + @testset "contract" begin + qtn = rand(Chain, Open, State; n=5, p=2, χ=20) + let canonized = canonize(qtn) + @test_throws ArgumentError contract!(canonized, :between, Site(1), Site(2); direction=:dummy) + end + + canonized = canonize(qtn) + + for i in 1:4 + contract_some = contract(canonized, :between, Site(i), Site(i + 1)) + Bᵢ = tensors(contract_some; at=Site(i)) + + @test isapprox(contract(TensorNetwork(contract_some)), contract(TensorNetwork(qtn))) + @test_throws MethodError tensors(contract_some, :between, Site(i), Site(i + 1)) + + @test isrightcanonical(contract_some, Site(i)) + @test isleftcanonical( + contract(canonized, :between, Site(i), Site(i + 1); direction=:right), Site(i + 1) + ) + + Γᵢ = tensors(canonized; at=Site(i)) + Λᵢ₊₁ = tensors(canonized; between=(Site(i), Site(i + 1))) + @test Bᵢ ≈ contract(Γᵢ, Λᵢ₊₁; dims=()) + end + end + + @testset "canonize_site" begin + qtn = Chain(State(), Open(), [rand(4, 4), rand(4, 4, 4), rand(4, 4)]) + + @test_throws ArgumentError canonize_site!(qtn, Site(1); direction=:left) + @test_throws ArgumentError canonize_site!(qtn, Site(3); direction=:right) + + for method in [:qr, :svd] + canonized = canonize_site(qtn, site"1"; direction=:right, method=method) + @test isleftcanonical(canonized, site"1") + @test isapprox( + contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(TensorNetwork(qtn)) + ) + + canonized = canonize_site(qtn, site"2"; direction=:right, method=method) + @test isleftcanonical(canonized, site"2") + @test isapprox( + contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(TensorNetwork(qtn)) + ) + + canonized = canonize_site(qtn, site"2"; direction=:left, method=method) + @test isrightcanonical(canonized, site"2") + @test isapprox( + contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(TensorNetwork(qtn)) + ) + + canonized = canonize_site(qtn, site"3"; direction=:left, method=method) + @test isrightcanonical(canonized, site"3") + @test isapprox( + contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(TensorNetwork(qtn)) + ) + end + + # Ensure that svd creates a new tensor + @test length(tensors(canonize_site(qtn, Site(2); direction=:left, method=:svd))) == 4 + end + + @testset "canonize" begin + using Tenet: isleftcanonical, isrightcanonical + + qtn = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) + canonized = canonize(qtn) + + @test length(tensors(canonized)) == 9 # 5 tensors + 4 singular values vectors + @test isapprox( + contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(TensorNetwork(qtn)) + ) + @test isapprox(norm(qtn), norm(canonized)) + + # Extract the singular values between each adjacent pair of sites in the canonized chain + Λ = [tensors(canonized; between=(Site(i), Site(i + 1))) for i in 1:4] + @test map(λ -> sum(abs2, λ), Λ) ≈ ones(length(Λ)) * norm(canonized)^2 + + for i in 1:5 + canonized = canonize(qtn) + + if i == 1 + @test isleftcanonical(canonized, Site(i)) + elseif i == 5 # in the limits of the chain, we get the norm of the state + contract!(canonized, :between, Site(i - 1), Site(i); direction=:right) + tensor = tensors(canonized; at=Site(i)) + replace!(TensorNetwork(canonized), tensor => tensor / norm(canonized)) + @test isleftcanonical(canonized, Site(i)) + else + contract!(canonized, :between, Site(i - 1), Site(i); direction=:right) + @test isleftcanonical(canonized, Site(i)) + end + end + + for i in 1:5 + canonized = canonize(qtn) + + if i == 1 # in the limits of the chain, we get the norm of the state + contract!(canonized, :between, Site(i), Site(i + 1); direction=:left) + tensor = tensors(canonized; at=Site(i)) + replace!(TensorNetwork(canonized), tensor => tensor / norm(canonized)) + @test isrightcanonical(canonized, Site(i)) + elseif i == 5 + @test isrightcanonical(canonized, Site(i)) + else + contract!(canonized, :between, Site(i), Site(i + 1); direction=:left) + @test isrightcanonical(canonized, Site(i)) + end + end + end + + @testset "mixed_canonize" begin + qtn = Chain(State(), Open(), [rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) + canonized = mixed_canonize(qtn, Site(3)) + + @test length(tensors(canonized)) == length(tensors(qtn)) + 1 + + @test isleftcanonical(canonized, Site(1)) + @test isleftcanonical(canonized, Site(2)) + @test isrightcanonical(canonized, Site(3)) + @test isrightcanonical(canonized, Site(4)) + @test isrightcanonical(canonized, Site(5)) + + @test isapprox( + contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(TensorNetwork(qtn)) + ) + end + end + + @test begin + qtn = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) + normalize!(qtn, Site(3)) + isapprox(norm(qtn), 1.0) + end + + @testset "adjoint" begin + qtn = rand(Chain, Open, State; n=5, p=2, χ=10) + adjoint_qtn = adjoint(qtn) + + for i in 1:nsites(qtn) + i < nsites(qtn) && + @test rightindex(adjoint_qtn, Site(i; dual=true)) == Symbol(String(rightindex(qtn, Site(i))) * "'") + i > 1 && @test leftindex(adjoint_qtn, Site(i; dual=true)) == Symbol(String(leftindex(qtn, Site(i))) * "'") + end + + @test isapprox(contract(TensorNetwork(qtn)), contract(TensorNetwork(adjoint_qtn))) + end + + # TODO test `evolve!` methods +end diff --git a/test/Product_test.jl b/test/Product_test.jl new file mode 100644 index 00000000..d255a421 --- /dev/null +++ b/test/Product_test.jl @@ -0,0 +1,29 @@ +@testset "Product ansatz" begin + using LinearAlgebra + + # TODO test `Product` with `Scalar` socket + + qtn = Product([rand(2) for _ in 1:3]) + @test socket(qtn) == State() + @test ninputs(qtn) == 0 + @test noutputs(qtn) == 3 + @test norm(qtn) isa Number + @test begin + normalize!(qtn) + norm(qtn) ≈ 1 + end + + # conversion to `Quantum` + @test Quantum(qtn) isa Quantum + + qtn = Product([rand(2, 2) for _ in 1:3]) + @test socket(qtn) == Operator() + @test ninputs(qtn) == 3 + @test noutputs(qtn) == 3 + @test norm(qtn) isa Number + @test opnorm(qtn) isa Number + @test begin + normalize!(qtn) + norm(qtn) ≈ 1 + end +end diff --git a/test/Project.toml b/test/Project.toml index 94c8cda0..a068fdcd 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -13,6 +13,7 @@ Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" NetworkLayout = "46757867-2c16-5918-afeb-47bfcb05e46a" OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922" Permutations = "2ae35dd2-176d-5d53-8349-f30d82d94d4f" +Quac = "b9105292-1415-45cf-bff1-d6ccf71e6143" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" diff --git a/test/Quantum_test.jl b/test/Quantum_test.jl new file mode 100644 index 00000000..673869c5 --- /dev/null +++ b/test/Quantum_test.jl @@ -0,0 +1,45 @@ +@testset "Quantum" begin + using Tenet + + _tensors = Tensor[Tensor(zeros(2), [:i])] + tn = TensorNetwork(_tensors) + qtn = Quantum(tn, Dict(site"1" => :i)) + @test ninputs(qtn) == 0 + @test noutputs(qtn) == 1 + @test issetequal(sites(qtn), [site"1"]) + @test socket(qtn) == State(; dual=false) + + # forwarded methods to `TensorNetwork` + @test TensorNetwork(qtn) == tn + @test tensors(qtn) == _tensors + + _tensors = Tensor[Tensor(zeros(2), [:i])] + tn = TensorNetwork(_tensors) + qtn = Quantum(tn, Dict(site"1'" => :i)) + @test ninputs(qtn) == 1 + @test noutputs(qtn) == 0 + @test issetequal(sites(qtn), [site"1'"]) + @test socket(qtn) == State(; dual=true) + + _tensors = Tensor[Tensor(zeros(2, 2), [:i, :j])] + tn = TensorNetwork(_tensors) + qtn = Quantum(tn, Dict(site"1" => :i, site"1'" => :j)) + @test ninputs(qtn) == 1 + @test noutputs(qtn) == 1 + @test issetequal(sites(qtn), [site"1", site"1'"]) + @test socket(qtn) == Operator() + + _tensors = Tensor[Tensor(fill(0))] + tn = TensorNetwork(_tensors) + qtn = Quantum(tn, Dict()) + @test ninputs(qtn) == 0 + @test noutputs(qtn) == 0 + @test isempty(sites(qtn)) + @test socket(qtn) == Scalar() + + # detect errors + _tensors = Tensor[Tensor(zeros(2), [:i]), Tensor(zeros(2), [:i])] + tn = TensorNetwork(_tensors) + @test_throws ErrorException Quantum(tn, Dict(site"1" => :j)) + @test_throws ErrorException Quantum(tn, Dict(site"1" => :i)) +end diff --git a/test/Site_test.jl b/test/Site_test.jl new file mode 100644 index 00000000..c5b8137e --- /dev/null +++ b/test/Site_test.jl @@ -0,0 +1,63 @@ +@testset "Site" begin + using Tenet: id + + s = Site(1) + @test id(s) == 1 + @test CartesianIndex(s) == CartesianIndex(1) + @test isdual(s) == false + + s = Site(1; dual=true) + @test id(s) == 1 + @test CartesianIndex(s) == CartesianIndex(1) + @test isdual(s) == true + + s = Site(1, 2) + @test id(s) == (1, 2) + @test CartesianIndex(s) == CartesianIndex((1, 2)) + @test isdual(s) == false + + s = Site(1, 2; dual=true) + @test id(s) == (1, 2) + @test CartesianIndex(s) == CartesianIndex((1, 2)) + @test isdual(s) == true + + s = site"1" + @test id(s) == 1 + @test CartesianIndex(s) == CartesianIndex(1) + @test isdual(s) == false + + s = site"1'" + @test id(s) == 1 + @test CartesianIndex(s) == CartesianIndex(1) + @test isdual(s) == true + + s = site"1,2" + @test id(s) == (1, 2) + @test CartesianIndex(s) == CartesianIndex((1, 2)) + @test isdual(s) == false + + s = site"1,2'" + @test id(s) == (1, 2) + @test CartesianIndex(s) == CartesianIndex((1, 2)) + @test isdual(s) == true + + s = adjoint(site"1") + @test id(s) == 1 + @test CartesianIndex(s) == CartesianIndex(1) + @test isdual(s) == true + + s = adjoint(site"1'") + @test id(s) == 1 + @test CartesianIndex(s) == CartesianIndex(1) + @test isdual(s) == false + + s = adjoint(site"1,2") + @test id(s) == (1, 2) + @test CartesianIndex(s) == CartesianIndex((1, 2)) + @test isdual(s) == true + + s = adjoint(site"1,2'") + @test id(s) == (1, 2) + @test CartesianIndex(s) == CartesianIndex((1, 2)) + @test isdual(s) == false +end diff --git a/test/integration/ChainRules_test.jl b/test/integration/ChainRules_test.jl index b0166124..79b76386 100644 --- a/test/integration/ChainRules_test.jl +++ b/test/integration/ChainRules_test.jl @@ -183,4 +183,37 @@ end end end + + @testset "Quantum" begin + test_frule(Quantum, TensorNetwork([Tensor(ones(2), [:i])]), Dict{Site,Symbol}(site"1" => :i)) + test_rrule(Quantum, TensorNetwork([Tensor(ones(2), [:i])]), Dict{Site,Symbol}(site"1" => :i)) + end + + @testset "Ansatz" begin + @testset "Product" begin + tn = TensorNetwork([Tensor(ones(2), [:i]), Tensor(ones(2), [:j]), Tensor(ones(2), [:k])]) + qtn = Quantum(tn, Dict([site"1" => :i, site"2" => :j, site"3" => :k])) + + test_frule(Product, qtn) + test_rrule(Product, qtn) + end + + @testset "Chain" begin + tn = Chain(State(), Open(), [ones(2, 2), ones(2, 2, 2), ones(2, 2)]) + # test_frule(Chain, Quantum(tn), Open()) + test_rrule(Chain, Quantum(tn), Open()) + + tn = Chain(State(), Periodic(), [ones(2, 2, 2), ones(2, 2, 2), ones(2, 2, 2)]) + # test_frule(Chain, Quantum(tn), Periodic()) + test_rrule(Chain, Quantum(tn), Periodic()) + + tn = Chain(Operator(), Open(), [ones(2, 2, 2), ones(2, 2, 2, 2), ones(2, 2, 2)]) + # test_frule(Chain, Quantum(tn), Open()) + test_rrule(Chain, Quantum(tn), Open()) + + tn = Chain(Operator(), Periodic(), [ones(2, 2, 2, 2), ones(2, 2, 2, 2), ones(2, 2, 2, 2)]) + # test_frule(Chain, Quantum(tn), Periodic()) + test_rrule(Chain, Quantum(tn), Periodic()) + end + end end diff --git a/test/integration/Quac_test.jl b/test/integration/Quac_test.jl new file mode 100644 index 00000000..63aa0e87 --- /dev/null +++ b/test/integration/Quac_test.jl @@ -0,0 +1,23 @@ +@testset "Quac" begin + using Quac + + @testset "QFT" begin + n = 3 + qftcirc = Quac.Algorithms.QFT(n) + qftqtn = Quantum(qftcirc) + + # correct number of inputs and outputs + @test ninputs(qftqtn) == n + @test noutputs(qftqtn) == n + @test socket(qftqtn) == Operator() + + # all open indices are sites + siteinds = getindex.((qftqtn,), sites(qftqtn)) + @test issetequal(inds(TensorNetwork(qftqtn); set=:open), siteinds) + + # all inner indices are not sites + # TODO too strict condition. remove? + notsiteinds = setdiff(inds(TensorNetwork(qftqtn)), siteinds) + @test_skip issetequal(inds(TensorNetwork(qftqtn); set=:inner), notsiteinds) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index d16a31fe..02730b2f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,6 +8,10 @@ using OMEinsum include("Numerics_test.jl") include("TensorNetwork_test.jl") include("Transformations_test.jl") + include("Site_test.jl") + include("Quantum_test.jl") + include("Product_test.jl") + include("Chain_test.jl") end # CI hangs on these tests for some unknown reason on Julia 1.9 @@ -17,6 +21,7 @@ if VERSION >= v"1.10" # include("integration/BlockArray_test.jl") include("integration/Dagger_test.jl") include("integration/Makie_test.jl") + include("integration/Quac_test.jl") end end