Skip to content

Commit

Permalink
Enable local tensors for element_wise_unary + element_wise_ternary + …
Browse files Browse the repository at this point in the history
…matmul
  • Loading branch information
strongoier committed Nov 11, 2021
1 parent 326c392 commit b52f4da
Showing 1 changed file with 6 additions and 12 deletions.
18 changes: 6 additions & 12 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,9 @@ def broadcast_copy(self, other):
return other

def element_wise_ternary(self, foo, other, extra):
ret = self.empty_copy()
other = self.broadcast_copy(other)
extra = self.broadcast_copy(extra)
for i in range(self.n * self.m):
ret.entries[i] = foo(self.entries[i], other.entries[i],
extra.entries[i])
return ret
return Matrix([[foo(self(i, j), other(i, j), extra(i, j)) for j in range(self.m)] for i in range(self.n)])

def element_wise_writeback_binary(self, foo, other):
ret = self.empty_copy()
Expand All @@ -210,10 +206,7 @@ def element_wise_writeback_binary(self, foo, other):

def element_wise_unary(self, foo):
_taichi_skip_traceback = 1
ret = self.empty_copy()
for i in range(self.n * self.m):
ret.entries[i] = foo(self.entries[i])
return ret
return Matrix([[foo(self(i, j)) for j in range(self.m)] for i in range(self.n)])

def __matmul__(self, other):
"""Matrix-matrix or matrix-vector multiply.
Expand All @@ -229,14 +222,15 @@ def __matmul__(self, other):
assert isinstance(other, Matrix), "rhs of `@` is not a matrix / vector"
assert self.m == other.n, f"Dimension mismatch between shapes ({self.n}, {self.m}), ({other.n}, {other.m})"
del _taichi_skip_traceback
ret = Matrix.empty(self.n, other.m)
entries = []
for i in range(self.n):
entries.append([])
for j in range(other.m):
acc = self(i, 0) * other(0, j)
for k in range(1, other.n):
acc = acc + self(i, k) * other(k, j)
ret.entries[i * other.m + j] = acc
return ret
entries[i].append(acc)
return Matrix(entries)

def linearize_entry_id(self, *args):
assert 1 <= len(args) <= 2
Expand Down

0 comments on commit b52f4da

Please sign in to comment.