Skip to content
This repository has been archived by the owner on Jan 30, 2023. It is now read-only.

Commit

Permalink
Fix bug in Components.contract() (Trac #32355)
Browse files Browse the repository at this point in the history
  • Loading branch information
egourgoulhon committed Aug 12, 2021
1 parent e0cbf2c commit 26ba8f7
Showing 1 changed file with 52 additions and 19 deletions.
71 changes: 52 additions & 19 deletions src/sage/tensor/modules/comp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2123,6 +2123,34 @@ def contract(self, *args):
sage: a.contract(0, a, 0) == b.trace(0,1)
True
TESTS:
Check that :trac:`32355` is fixed::
sage: from sage.tensor.modules.comp import CompFullyAntiSym
sage: a = CompFullyAntiSym(QQ, V.basis(), 2)
sage: a[0,1] = 1
sage: b = CompFullyAntiSym(QQ, V.basis(), 2)
sage: b[0,1], b[0,2] = 2, 3
sage: a.contract(0, 1, b, 0, 1)
4
sage: a.contract(0, 1, b, 1, 0)
-4
sage: a.contract(1, 0, b, 0, 1)
-4
sage: a.contract(1, 0, b, 1, 0)
4
sage: Parallelism().set('tensor', nproc=2) # same tests with parallelization
sage: a.contract(0, 1, b, 0, 1)
4
sage: a.contract(0, 1, b, 1, 0)
-4
sage: a.contract(1, 0, b, 0, 1)
-4
sage: a.contract(1, 0, b, 1, 0)
4
sage: Parallelism().set('tensor', nproc=1) # switch off parallelization
"""
#
# Treatment of the input
Expand Down Expand Up @@ -2164,31 +2192,34 @@ def contract(self, *args):
# ncontr indices and call the method index_generator() on it:
comp_for_contr = Components(self._ring, self._frame, ncontr,
start_index=self._sindex)
res = 0

# Pairs of indices tuples for the contraction:
ind_pairs = []
for ind_s in comp_for_contr.index_generator():
ind_o = [None for i in range(ncontr)]
for pos_s, pos_o in contractions:
ind_o[pos_o] = ind_s[pos_s]
ind_pairs.append((ind_s, ind_o))

if Parallelism().get('tensor') != 1:
# parallel contraction to scalar
# parallel computation

# parallel multiplication
@parallel(p_iter='multiprocessing',ncpus=Parallelism().get('tensor'))
def compprod(a,b):
@parallel(p_iter='multiprocessing', ncpus=Parallelism().get('tensor'))
def compprod(a, b):
return a*b

# parallel list of inputs
partial = list(compprod([(other[[ind]],self[[ind]]) for ind in
comp_for_contr.index_generator()
]))
res = sum(map(itemgetter(1),partial))
partial = list(compprod([(self[[ind_s]], other[[ind_o]])
for ind_s, ind_o in ind_pairs]))
res = sum(map(itemgetter(1), partial))
else:
# sequential
# sequential computation
res = 0
for ind in comp_for_contr.index_generator():
res += self[[ind]] * other[[ind]]
for ind_s, ind_o in ind_pairs:
res += self[[ind_s]] * other[[ind_o]]

return res


#
# Positions of self and other indices in the result
# (None = the position is involved in a contraction and therefore
Expand Down Expand Up @@ -2300,16 +2331,18 @@ def compprod(a,b):
nproc = Parallelism().get('tensor')
lol = lambda lst, sz: [lst[i:i+sz] for i in range(0, len(lst), sz)]
ind_list = [ind for ind in res.non_redundant_index_generator()]
ind_step = max(1,int(len(ind_list)/nproc/2))
local_list = lol(ind_list,ind_step)
ind_step = max(1, int(len(ind_list)/nproc/2))
local_list = lol(ind_list, ind_step)

listParalInput = []
for ind_part in local_list:
listParalInput.append((self,other,ind_part,rev_s,rev_o,shift_o,contractions,comp_for_contr))
listParalInput.append((self, other, ind_part, rev_s, rev_o,
shift_o, contractions, comp_for_contr))

# definition of the parallel function
@parallel(p_iter='multiprocessing',ncpus=nproc)
def make_Contraction(this,other,local_list,rev_s,rev_o,shift_o,contractions,comp_for_contr):
@parallel(p_iter='multiprocessing', ncpus=nproc)
def make_Contraction(this, other, local_list, rev_s, rev_o,
shift_o,contractions, comp_for_contr):
local_res = []
for ind in local_list:
ind_s = [None for i in range(this._nid)] # initialization
Expand All @@ -2334,7 +2367,7 @@ def make_Contraction(this,other,local_list,rev_s,rev_o,shift_o,contractions,comp
for jj in val:
res[[jj[0]]] = jj[1]
else:
# sequential
# sequential computation
for ind in res.non_redundant_index_generator():
ind_s = [None for i in range(self._nid)] # initialization
ind_o = [None for i in range(other._nid)] # initialization
Expand Down

0 comments on commit 26ba8f7

Please sign in to comment.