From 60ea7b76e2f8d272094559800054a58ffb3fbb4d Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 11 Jul 2024 17:31:08 -0400 Subject: [PATCH 1/2] api: fix tensor mul --- devito/types/basic.py | 2 +- devito/types/tensor.py | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/devito/types/basic.py b/devito/types/basic.py index a18c17d945..e2859ea07e 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -773,7 +773,7 @@ def _eval_matrix_mul(self, other): new_mat[i] = sum(vec) # Get new class and return product - newcls = self.classof_prod(other, new_mat) + newcls = self.classof_prod(other, other.cols) return newcls._new(self.rows, other.cols, new_mat, copy=False) diff --git a/devito/types/tensor.py b/devito/types/tensor.py index ae69ba8899..843bbd51d2 100644 --- a/devito/types/tensor.py +++ b/devito/types/tensor.py @@ -257,11 +257,8 @@ def new_from_mat(self, mat): func = tens_func(self) return func._new(self.rows, self.cols, mat) - def classof_prod(self, other, mat): - try: - is_mat = len(mat[0]) > 1 - except TypeError: - is_mat = False + def classof_prod(self, other, cols): + is_mat = cols > 1 is_time = (getattr(self, '_is_TimeDependent', False) or getattr(other, '_is_TimeDependent', False)) return mat_time_dict[(is_time, is_mat)] From de4837d77a257b5ff5b69b5026ad89c297fbfbe0 Mon Sep 17 00:00:00 2001 From: mloubout Date: Fri, 12 Jul 2024 10:57:04 -0400 Subject: [PATCH 2/2] mpi: fix no-grid function halo exchange --- devito/types/dense.py | 6 ++++-- tests/test_tensors.py | 8 ++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/devito/types/dense.py b/devito/types/dense.py index a4ea53b699..4ae37269cb 100644 --- a/devito/types/dense.py +++ b/devito/types/dense.py @@ -751,8 +751,10 @@ def _C_get_field(self, region, dim, side=None): def _halo_exchange(self): """Perform the halo exchange with the neighboring processes.""" - if not MPI.Is_initialized() or MPI.COMM_WORLD.size == 1 or \ - not configuration['mpi']: + if not MPI.Is_initialized() or \ + MPI.COMM_WORLD.size == 1 or \ + not configuration['mpi'] or \ + self.grid is None: # Nothing to do return if MPI.COMM_WORLD.size > 1 and self._distributor is None: diff --git a/tests/test_tensors.py b/tests/test_tensors.py index 815efa51ba..30ac7e578f 100644 --- a/tests/test_tensors.py +++ b/tests/test_tensors.py @@ -100,10 +100,10 @@ def test_tensor_matmul(func1, func2, out_type): @pytest.mark.parametrize('func1, func2, out_type', [ - (VectorFunction, TensorFunction, VectorFunction), - (VectorTimeFunction, TensorFunction, VectorTimeFunction), - (VectorFunction, TensorTimeFunction, VectorTimeFunction), - (VectorTimeFunction, TensorTimeFunction, VectorTimeFunction)]) + (VectorFunction, TensorFunction, TensorFunction), + (VectorTimeFunction, TensorFunction, TensorTimeFunction), + (VectorFunction, TensorTimeFunction, TensorTimeFunction), + (VectorTimeFunction, TensorTimeFunction, TensorTimeFunction)]) def test_tensor_matmul_T(func1, func2, out_type): grid = Grid(tuple([5]*3)) f1 = func1(name="f1", grid=grid)