From b52f4dad740e93224daa3f0c439765063b278512 Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Thu, 11 Nov 2021 20:33:22 +0800 Subject: [PATCH] Enable local tensors for element_wise_unary + element_wise_ternary + matmul --- python/taichi/lang/matrix.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 77696721a4e7e..ca302551b6731 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -180,13 +180,9 @@ def broadcast_copy(self, other): return other def element_wise_ternary(self, foo, other, extra): - ret = self.empty_copy() other = self.broadcast_copy(other) extra = self.broadcast_copy(extra) - for i in range(self.n * self.m): - ret.entries[i] = foo(self.entries[i], other.entries[i], - extra.entries[i]) - return ret + return Matrix([[foo(self(i, j), other(i, j), extra(i, j)) for j in range(self.m)] for i in range(self.n)]) def element_wise_writeback_binary(self, foo, other): ret = self.empty_copy() @@ -210,10 +206,7 @@ def element_wise_writeback_binary(self, foo, other): def element_wise_unary(self, foo): _taichi_skip_traceback = 1 - ret = self.empty_copy() - for i in range(self.n * self.m): - ret.entries[i] = foo(self.entries[i]) - return ret + return Matrix([[foo(self(i, j)) for j in range(self.m)] for i in range(self.n)]) def __matmul__(self, other): """Matrix-matrix or matrix-vector multiply. @@ -229,14 +222,15 @@ def __matmul__(self, other): assert isinstance(other, Matrix), "rhs of `@` is not a matrix / vector" assert self.m == other.n, f"Dimension mismatch between shapes ({self.n}, {self.m}), ({other.n}, {other.m})" del _taichi_skip_traceback - ret = Matrix.empty(self.n, other.m) + entries = [] for i in range(self.n): + entries.append([]) for j in range(other.m): acc = self(i, 0) * other(0, j) for k in range(1, other.n): acc = acc + self(i, k) * other(k, j) - ret.entries[i * other.m + j] = acc - return ret + entries[i].append(acc) + return Matrix(entries) def linearize_entry_id(self, *args): assert 1 <= len(args) <= 2