From 60ea7b76e2f8d272094559800054a58ffb3fbb4d Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 11 Jul 2024 17:31:08 -0400 Subject: [PATCH] api: fix tensor mul --- devito/types/basic.py | 2 +- devito/types/tensor.py | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/devito/types/basic.py b/devito/types/basic.py index a18c17d945..e2859ea07e 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -773,7 +773,7 @@ def _eval_matrix_mul(self, other): new_mat[i] = sum(vec) # Get new class and return product - newcls = self.classof_prod(other, new_mat) + newcls = self.classof_prod(other, other.cols) return newcls._new(self.rows, other.cols, new_mat, copy=False) diff --git a/devito/types/tensor.py b/devito/types/tensor.py index ae69ba8899..843bbd51d2 100644 --- a/devito/types/tensor.py +++ b/devito/types/tensor.py @@ -257,11 +257,8 @@ def new_from_mat(self, mat): func = tens_func(self) return func._new(self.rows, self.cols, mat) - def classof_prod(self, other, mat): - try: - is_mat = len(mat[0]) > 1 - except TypeError: - is_mat = False + def classof_prod(self, other, cols): + is_mat = cols > 1 is_time = (getattr(self, '_is_TimeDependent', False) or getattr(other, '_is_TimeDependent', False)) return mat_time_dict[(is_time, is_mat)]