Skip to content

Commit

Permalink
Auto Format
Browse files Browse the repository at this point in the history
  • Loading branch information
taichi-gardener committed Nov 11, 2021
1 parent b52f4da commit f58a12f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
13 changes: 9 additions & 4 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,20 +169,24 @@ def __init__(self,
def element_wise_binary(self, foo, other):
_taichi_skip_traceback = 1
other = self.broadcast_copy(other)
return Matrix([[foo(self(i, j), other(i, j)) for j in range(self.m)] for i in range(self.n)])
return Matrix([[foo(self(i, j), other(i, j)) for j in range(self.m)]
for i in range(self.n)])

def broadcast_copy(self, other):
if isinstance(other, (list, tuple)):
other = Matrix(other)
if not isinstance(other, Matrix):
other = Matrix([[other for _ in range(self.m)] for _ in range(self.n)])
other = Matrix([[other for _ in range(self.m)]
for _ in range(self.n)])
assert self.m == other.m and self.n == other.n, f"Dimension mismatch between shapes ({self.n}, {self.m}), ({other.n}, {other.m})"
return other

def element_wise_ternary(self, foo, other, extra):
other = self.broadcast_copy(other)
extra = self.broadcast_copy(extra)
return Matrix([[foo(self(i, j), other(i, j), extra(i, j)) for j in range(self.m)] for i in range(self.n)])
return Matrix([[
foo(self(i, j), other(i, j), extra(i, j)) for j in range(self.m)
] for i in range(self.n)])

def element_wise_writeback_binary(self, foo, other):
ret = self.empty_copy()
Expand All @@ -206,7 +210,8 @@ def element_wise_writeback_binary(self, foo, other):

def element_wise_unary(self, foo):
_taichi_skip_traceback = 1
return Matrix([[foo(self(i, j)) for j in range(self.m)] for i in range(self.n)])
return Matrix([[foo(self(i, j)) for j in range(self.m)]
for i in range(self.n)])

def __matmul__(self, other):
"""Matrix-matrix or matrix-vector multiply.
Expand Down
5 changes: 4 additions & 1 deletion taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,10 @@ class SNodeOpExpression : public Expression {
SNodeOpType op_type,
const ExprGroup &indices,
const Expr &value)
: snode(snode), op_type(op_type), indices(indices.loaded()), value(value) {
: snode(snode),
op_type(op_type),
indices(indices.loaded()),
value(value) {
}

void type_check() override;
Expand Down

0 comments on commit f58a12f

Please sign in to comment.