Skip to content

Commit

Permalink
Fix replace! when old an new indices intersect
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Aug 2, 2024
1 parent 59efdd9 commit 5995710
Showing 1 changed file with 32 additions and 6 deletions.
38 changes: 32 additions & 6 deletions src/TensorNetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -393,13 +393,39 @@ function Base.replace!(tn::TensorNetwork, pair::Pair{<:Tensor,<:Tensor})
end

function Base.replace!(tn::TensorNetwork, old_new::Pair{Symbol,Symbol}...)
first.(old_new) keys(tn.indexmap) ||
throw(ArgumentError("set of old indices must be a subset of current indices"))
isdisjoint(last.(old_new), keys(tn.indexmap)) ||
throw(ArgumentError("set of new indices must be disjoint to current indices"))
for pair in old_new
replace!(tn, pair)
from, to = first.(old_new), last.(old_new)
allinds = inds(tn)

# condition: from ⊆ allinds
from allinds || throw(ArgumentError("set of old indices must be a subset of current indices"))

# condition: from \ to ∩ allinds = ∅
isdisjoint(setdiff(to, from), allinds) || throw(
ArgumentError(
"new indices must be either a element of the old indices or not an element of the TensorNetwork's indices",
),
)

from′ = setdiff(from, to)
to′ = setdiff(to, from)

# no overlap so easy replacement
for (f, t) in zip(from′, to′)
replace!(tn, f => t)
end

# overlap between old and new indices => need a temporary name `replace!`
overlap = from to
if !isempty(overlap)
tmp = Dict([i => gensym(i) for i in overlap])

# replace old indices with temporary names
replace!(tn, pairs(tmp)...)

# replace temporary names with new indices
replace!(tn, [tmp[i] => i for i in Iterators.filter((overlap), to)]...)
end

return tn
end

Expand Down

0 comments on commit 5995710

Please sign in to comment.