Skip to content

Commit

Permalink
api: fix tensor mul
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jul 11, 2024
1 parent 0124871 commit 60ea7b7
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 6 deletions.
2 changes: 1 addition & 1 deletion devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,7 @@ def _eval_matrix_mul(self, other):
new_mat[i] = sum(vec)

# Get new class and return product
newcls = self.classof_prod(other, new_mat)
newcls = self.classof_prod(other, other.cols)
return newcls._new(self.rows, other.cols, new_mat, copy=False)


Expand Down
7 changes: 2 additions & 5 deletions devito/types/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,11 +257,8 @@ def new_from_mat(self, mat):
func = tens_func(self)
return func._new(self.rows, self.cols, mat)

def classof_prod(self, other, mat):
try:
is_mat = len(mat[0]) > 1
except TypeError:
is_mat = False
def classof_prod(self, other, cols):
is_mat = cols > 1
is_time = (getattr(self, '_is_TimeDependent', False) or
getattr(other, '_is_TimeDependent', False))
return mat_time_dict[(is_time, is_mat)]
Expand Down

0 comments on commit 60ea7b7

Please sign in to comment.