Skip to content

Commit

Permalink
Add unit test for TensorNetwork module (#152)
Browse files Browse the repository at this point in the history
* Fix variable name in neighbors function

* Add unit tests

* Rename tensors variable and function call in neighbors

* Format code

* Add unit tests

* Fix unit tests

* Fix unit test

* Add unit test for complex SVD

* Add unit tests

* Fix LU unit test

* Add permutation tensor to LU decomposition

* Add LU unit test

* Format code

* Improve Base.similar test sentence

* Improve decompositions unit tests

* Format code

* Add and fix contract unit tests

* Add replace tensor by TN unit test

* Fix julia format

---------

Co-authored-by: Jofre <jofrevalles99@gmail.com>
  • Loading branch information
Todorbsc and jofrevalles authored Jun 26, 2024
1 parent 25c1c9a commit b1ea8bb
Show file tree
Hide file tree
Showing 3 changed files with 315 additions and 42 deletions.
10 changes: 5 additions & 5 deletions src/TensorNetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,15 +211,15 @@ end
function neighbors(tn::TensorNetwork, tensor::Tensor; open::Bool=true)
@assert tensor tn "Tensor not found in TensorNetwork"
tensors = mapreduce(, inds(tensor)) do index
tensors(tn; intersects=index)
Tenet.tensors(tn; intersects=index)
end
open && filter!(x -> x !== tensor, tensors)
return tensors
end

function neighbors(tn::TensorNetwork, i::Symbol; open::Bool=true)
@assert i tn "Index $i not found in TensorNetwork"
tensors = mapreduce(inds, , tensors(tn; intersects=i))
tensors = mapreduce(inds, , Tenet.tensors(tn; intersects=i))
# open && filter!(x -> x !== i, tensors)
return tensors
end
Expand Down Expand Up @@ -392,7 +392,7 @@ end

function Base.replace!(tn::TensorNetwork, old_new::Pair{<:Tensor,<:TensorNetwork})
old, new = old_new
issetequal(inds(new; set=:open), inds(old)) || throw(ArgumentError("indices don't match match"))
issetequal(inds(new; set=:open), inds(old)) || throw(ArgumentError("indices don't match"))

# rename internal indices so there is no accidental hyperedge
replace!(new, [index => Symbol(uuid4()) for index in filter((inds(tn)), inds(new; set=:inner))]...)
Expand Down Expand Up @@ -622,7 +622,7 @@ end

function LinearAlgebra.lu!(tn::TensorNetwork; left_inds=Symbol[], right_inds=Symbol[], kwargs...)
tensor = tn[left_inds right_inds...]
L, U = lu(tensor; left_inds, right_inds, kwargs...)
replace!(tn, tensor => TensorNetwork([L, U]))
L, U, P = lu(tensor; left_inds, right_inds, kwargs...)
replace!(tn, tensor => TensorNetwork([P, L, U]))
return tn
end
2 changes: 1 addition & 1 deletion test/Numerics_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@
@test_throws ArgumentError lu(tensor, right_inds=(:i, :j, :k, :l))

# throw if chosen virtual index already present
@test_throws ArgumentError qr(tensor, left_inds=(:i,), virtualind=:j)
@test_throws ArgumentError lu(tensor, left_inds=(:i, :j), virtualind=(:j, :k))

L, U, P = lu(tensor; left_inds=[:i, :j], virtualind=vidx)
@test inds(L) == [:x, :y]
Expand Down
Loading

0 comments on commit b1ea8bb

Please sign in to comment.