-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support conversion to/from ITensors.jl and ITensorNetworks.jl (#200)
* Init extension * Support conversion from `ITensor` to `Tensor` * Support conversion from `Vector{ITensor}` to `TensorNetwork` * Fix call to `id` * Support conversion from `ITensorNetwork` to `Quantum`, `TensorNetwork` * Lower compat bound of KrylovKit to v0.7 Required by ITensorNetworks * Support conversion from `Tensor`, `TensorNetwork` to `ITensor`, `Vector{ITensor}` * Rename `Tensor`, `TensorNetwork`, `Quantum` methods to `Base.convert` methods * Support conversion from `AbstractTensorNetwork` to `ITensorNetwork` * Fix some typos and problems * Implement `sites` method for recovering site from `Symbol` * Support conversion from `AbstractQuantum` to `ITensorNetwork` * Comment future idea * Test `ITensors`, `ITensorNetworks` integration * Fix `Vector{ITensor}` to `TensorNetwork` conversion * Fix tests * Defer ITensorNetworks instantiation to test time * Add Pkg package to test dependencies
- Loading branch information
Showing
8 changed files
with
166 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
module TenetITensorNetworksExt | ||
|
||
using Tenet | ||
using ITensorNetworks: ITensorNetworks, ITensorNetwork, ITensor, Index, siteinds, plev, vertices, rename_vertices | ||
const ITensors = ITensorNetworks.ITensors | ||
const DataGraphs = ITensorNetworks.DataGraphs | ||
const TenetITensorsExt = Base.get_extension(Tenet, :TenetITensorsExt) | ||
|
||
Base.convert(::Type{TensorNetwork}, tn::ITensorNetwork) = TensorNetwork([convert(Tensor, tn[v]) for v in vertices(tn)]) | ||
|
||
function Base.convert(::Type{ITensorNetwork}, tn::Tenet.AbstractTensorNetwork; inds=Dict{Symbol,Index}()) | ||
return ITensorNetwork(convert(Vector{ITensor}, tn; inds)) | ||
end | ||
|
||
function Base.convert(::Type{Quantum}, tn::ITensorNetwork) | ||
sitedict = Dict( | ||
map(pairs(DataGraphs.vertex_data(siteinds(tn)))) do (loc, index) | ||
index = only(index) | ||
primelevel = plev(index) | ||
@assert primelevel ∈ (0, 1) | ||
|
||
# NOTE ITensors' Index's tag only has space for 16 characters | ||
tag = ITensors.id(index) | ||
Site(loc; dual=Bool(primelevel)) => TenetITensorsExt.symbolize(index) | ||
end, | ||
) | ||
return Quantum(convert(TensorNetwork, tn), sitedict) | ||
end | ||
|
||
function Base.convert(::Type{ITensorNetwork}, tn::Tenet.AbstractQuantum) | ||
itn = @invoke convert(ITensorNetwork, tn::Tenet.AbstractTensorNetwork) | ||
|
||
return rename_vertices(itn) do v | ||
itensor = itn[v] | ||
indices = map(x -> Symbol(replace(x, "\"" => "")), string.(ITensors.tags.(ITensors.inds(itensor)))) | ||
tensor = only(tensors(tn; contains=indices)) | ||
physical_index = only(inds(tn; set=:physical) ∩ inds(tensor)) | ||
return sites(tn; at=physical_index).id | ||
end | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
module TenetITensorsExt | ||
|
||
using Tenet | ||
using ITensors: ITensors, ITensor, Index | ||
|
||
function symbolize(index::Index) | ||
tag = string(ITensors.id(index)) | ||
|
||
# NOTE ITensors' Index's tag only has space for 16 characters | ||
return Symbol(length(tag) > 16 ? tag[(end - 16 + 1):end] : tag) | ||
end | ||
|
||
function tagize(index::Symbol) | ||
tag = string(index) | ||
|
||
# NOTE ITensors' Index's tag only has space for 16 characters | ||
return length(tag) > 16 ? tag[(end - 16 + 1):end] : tag | ||
end | ||
|
||
# TODO customize index names | ||
function Base.convert(::Type{Tensor}, tensor::ITensor) | ||
array = ITensors.array(tensor) | ||
is = map(symbolize, ITensors.inds(tensor)) | ||
return Tensor(array, is) | ||
end | ||
|
||
function Base.convert(::Type{ITensor}, tensor::Tensor; inds=Dict{Symbol,Index}()) | ||
indices = map(Tenet.inds(tensor)) do i | ||
haskey(inds, i) ? inds[i] : Index(size(tensor, i), tagize(i)) | ||
end | ||
return ITensor(parent(tensor), indices) | ||
end | ||
|
||
Base.convert(::Type{TensorNetwork}, tn::Vector{ITensor}) = TensorNetwork(map(t -> convert(Tensor, t), tn)) | ||
|
||
function Base.convert(::Type{Vector{ITensor}}, tn::Tenet.AbstractTensorNetwork; inds=Dict{Symbol,Index}()) | ||
indices = merge(inds, Dict( | ||
map(Iterators.filter(!Base.Fix1(haskey, inds), Tenet.inds(tn))) do i | ||
i => Index(size(tn, i), tagize(i)) | ||
end, | ||
)) | ||
return map(tensors(tn)) do tensor | ||
ITensor(parent(tensor), map(i -> indices[i], Tenet.inds(tensor))) | ||
end | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# breaks in instantiation on Julia 1.9 | ||
using Pkg | ||
Pkg.add("ITensorNetworks") | ||
|
||
@testset "ITensorNetworks" begin | ||
using ITensors: ITensors, ITensor, Index, array | ||
using ITensorNetworks: ITensorNetwork, vertices | ||
|
||
i = Index(2, "i") | ||
j = Index(3, "j") | ||
k = Index(4, "k") | ||
l = Index(5, "l") | ||
m = Index(6, "m") | ||
|
||
a = ITensor(rand(2, 3), i, j) | ||
b = ITensor(rand(3, 4, 5), j, k, l) | ||
c = ITensor(rand(5, 6), l, m) | ||
itn = ITensorNetwork([a, b, c]) | ||
|
||
tn = convert(TensorNetwork, itn) | ||
@test tn isa TensorNetwork | ||
@test issetequal(arrays(tn), array.([a, b, c])) | ||
|
||
itn = convert(ITensorNetwork, tn) | ||
@test itn isa ITensorNetwork | ||
@test issetequal(map(v -> array(itn[v]), vertices(itn)), array.([a, b, c])) | ||
|
||
# TODO test Quantum | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
@testset "ITensors" begin | ||
using ITensors: ITensors, ITensor, Index, array | ||
|
||
i = Index(2, "i") | ||
j = Index(3, "j") | ||
k = Index(4, "k") | ||
|
||
itensor = ITensor(rand(2, 3, 4), i, j, k) | ||
|
||
tensor = convert(Tensor, itensor) | ||
@test tensor isa Tensor | ||
@test size(tensor) == (2, 3, 4) | ||
@test parent(tensor) == array(itensor) | ||
|
||
tensor = Tensor(rand(2, 3, 4), (:i, :j, :k)) | ||
itensor = convert(ITensor, tensor) | ||
@test itensor isa ITensor | ||
@test size(itensor) == (2, 3, 4) | ||
@test array(itensor) == parent(tensor) | ||
@test all( | ||
splat(==), | ||
zip(map(x -> replace(x, "\"" => ""), string.(ITensors.tags.(ITensors.inds(itensor)))), ["i", "j", "k"]), | ||
) | ||
|
||
tn = rand(TensorNetwork, 4, 3) | ||
itensors = convert(Vector{ITensor}, tn) | ||
@test itensors isa Vector{ITensor} | ||
|
||
tnr = convert(TensorNetwork, itensors) | ||
@test tnr isa TensorNetwork | ||
@test issetequal(arrays(tn), arrays(tnr)) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters