Skip to content

Commit

Permalink
[Lang] Remove disable_local_tensor in most cases (#3524)
Browse files Browse the repository at this point in the history
* [Lang] Remove disable_local_tensor in most cases

* Auto Format

Co-authored-by: Taichi Gardener <taichigardener@gmail.com>
  • Loading branch information
strongoier and taichi-gardener authored Nov 16, 2021
1 parent 21f581e commit 536fb20
Showing 1 changed file with 30 additions and 48 deletions.
78 changes: 30 additions & 48 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,13 +479,11 @@ def inverse(self):
"""
assert self.n == self.m, 'Only square matrices are invertible'
if self.n == 1:
return Matrix([1 / self(0, 0)], disable_local_tensor=True)
return Matrix([1 / self(0, 0)])
if self.n == 2:
inv_det = impl.expr_init(1.0 / self.determinant())
# Discussion: https://github.com/taichi-dev/taichi/pull/943#issuecomment-626344323
return inv_det * Matrix([[self(1, 1), -self(0, 1)],
[-self(1, 0), self(0, 0)]],
disable_local_tensor=True).variable()
inv_determinant = impl.expr_init(1.0 / self.determinant())
return inv_determinant * Matrix([[self(
1, 1), -self(0, 1)], [-self(1, 0), self(0, 0)]])
if self.n == 3:
n = 3
inv_determinant = impl.expr_init(1.0 / self.determinant())
Expand All @@ -496,10 +494,10 @@ def E(x, y):

for i in range(n):
for j in range(n):
entries[j][i] = impl.expr_init(
inv_determinant * (E(i + 1, j + 1) * E(i + 2, j + 2) -
E(i + 2, j + 1) * E(i + 1, j + 2)))
return Matrix(entries, disable_local_tensor=True)
entries[j][i] = inv_determinant * (
E(i + 1, j + 1) * E(i + 2, j + 2) -
E(i + 2, j + 1) * E(i + 1, j + 2))
return Matrix(entries)
if self.n == 4:
n = 4
inv_determinant = impl.expr_init(1.0 / self.determinant())
Expand All @@ -510,18 +508,15 @@ def E(x, y):

for i in range(n):
for j in range(n):
entries[j][i] = impl.expr_init(
inv_determinant * (-1)**(i + j) *
((E(i + 1, j + 1) *
(E(i + 2, j + 2) * E(i + 3, j + 3) -
E(i + 3, j + 2) * E(i + 2, j + 3)) -
E(i + 2, j + 1) *
(E(i + 1, j + 2) * E(i + 3, j + 3) -
E(i + 3, j + 2) * E(i + 1, j + 3)) +
E(i + 3, j + 1) *
(E(i + 1, j + 2) * E(i + 2, j + 3) -
E(i + 2, j + 2) * E(i + 1, j + 3)))))
return Matrix(entries, disable_local_tensor=True)
entries[j][i] = inv_determinant * (-1)**(i + j) * ((
E(i + 1, j + 1) *
(E(i + 2, j + 2) * E(i + 3, j + 3) -
E(i + 3, j + 2) * E(i + 2, j + 3)) - E(i + 2, j + 1) *
(E(i + 1, j + 2) * E(i + 3, j + 3) -
E(i + 3, j + 2) * E(i + 1, j + 3)) + E(i + 3, j + 1) *
(E(i + 1, j + 2) * E(i + 2, j + 3) -
E(i + 2, j + 2) * E(i + 1, j + 3))))
return Matrix(entries)
raise Exception(
"Inversions of matrices with sizes >= 5 are not supported")

Expand Down Expand Up @@ -567,10 +562,8 @@ def transpose(self):
Get the transpose of a matrix.
"""
ret = Matrix([[self[i, j] for i in range(self.n)]
for j in range(self.m)],
disable_local_tensor=True)
return ret
return Matrix([[self[i, j] for i in range(self.n)]
for j in range(self.m)])

@taichi_scope
def determinant(a):
Expand Down Expand Up @@ -790,10 +783,8 @@ def zero(dt, n, m=None):
"""
if m is None:
return Vector([ti.cast(0, dt) for _ in range(n)],
disable_local_tensor=True)
return Matrix([[ti.cast(0, dt) for _ in range(m)] for _ in range(n)],
disable_local_tensor=True)
return Vector([ti.cast(0, dt) for _ in range(n)])
return Matrix([[ti.cast(0, dt) for _ in range(m)] for _ in range(n)])

@staticmethod
@taichi_scope
Expand All @@ -810,10 +801,8 @@ def one(dt, n, m=None):
"""
if m is None:
return Vector([ti.cast(1, dt) for _ in range(n)],
disable_local_tensor=True)
return Matrix([[ti.cast(1, dt) for _ in range(m)] for _ in range(n)],
disable_local_tensor=True)
return Vector([ti.cast(1, dt) for _ in range(n)])
return Matrix([[ti.cast(1, dt) for _ in range(m)] for _ in range(n)])

@staticmethod
@taichi_scope
Expand All @@ -832,8 +821,7 @@ def unit(n, i, dt=None):
if dt is None:
dt = int
assert 0 <= i < n
return Matrix([ti.cast(int(j == i), dt) for j in range(n)],
disable_local_tensor=True)
return Vector([ti.cast(int(j == i), dt) for j in range(n)])

@staticmethod
@taichi_scope
Expand All @@ -849,8 +837,7 @@ def identity(dt, n):
"""
return Matrix([[ti.cast(int(i == j), dt) for j in range(n)]
for i in range(n)],
disable_local_tensor=True)
for i in range(n)])

@staticmethod
def rotation2d(alpha):
Expand Down Expand Up @@ -1107,18 +1094,15 @@ def dot(self, other):

@kern_mod.pyfunc
def _cross3d(self, other):
ret = Matrix([
return Matrix([
self[1] * other[2] - self[2] * other[1],
self[2] * other[0] - self[0] * other[2],
self[0] * other[1] - self[1] * other[0],
],
disable_local_tensor=True)
return ret
])

@kern_mod.pyfunc
def _cross2d(self, other):
ret = self[0] * other[1] - self[1] * other[0]
return ret
return self[0] * other[1] - self[1] * other[0]

def cross(self, other):
"""Perform the cross product with the input Vector (1-D Matrix).
Expand Down Expand Up @@ -1156,10 +1140,8 @@ def outer_product(self, other):
impl.static(
impl.static_assert(other.m == 1,
"rhs for outer_product is not a vector"))
ret = Matrix([[self[i] * other[j] for j in range(other.n)]
for i in range(self.n)],
disable_local_tensor=True)
return ret
return Matrix([[self[i] * other[j] for j in range(other.n)]
for i in range(self.n)])


def Vector(n, dt=None, **kwargs):
Expand Down

0 comments on commit 536fb20

Please sign in to comment.