Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lang] Remove disable_local_tensor in most cases #3524

Merged
merged 2 commits into from
Nov 16, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 30 additions & 48 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,13 +479,11 @@ def inverse(self):
"""
assert self.n == self.m, 'Only square matrices are invertible'
if self.n == 1:
return Matrix([1 / self(0, 0)], disable_local_tensor=True)
return Matrix([1 / self(0, 0)])
if self.n == 2:
inv_det = impl.expr_init(1.0 / self.determinant())
# Discussion: https://github.com/taichi-dev/taichi/pull/943#issuecomment-626344323
return inv_det * Matrix([[self(1, 1), -self(0, 1)],
[-self(1, 0), self(0, 0)]],
disable_local_tensor=True).variable()
inv_determinant = impl.expr_init(1.0 / self.determinant())
return inv_determinant * Matrix([[self(
1, 1), -self(0, 1)], [-self(1, 0), self(0, 0)]])
if self.n == 3:
n = 3
inv_determinant = impl.expr_init(1.0 / self.determinant())
Expand All @@ -496,10 +494,10 @@ def E(x, y):

for i in range(n):
for j in range(n):
entries[j][i] = impl.expr_init(
inv_determinant * (E(i + 1, j + 1) * E(i + 2, j + 2) -
E(i + 2, j + 1) * E(i + 1, j + 2)))
return Matrix(entries, disable_local_tensor=True)
entries[j][i] = inv_determinant * (
E(i + 1, j + 1) * E(i + 2, j + 2) -
E(i + 2, j + 1) * E(i + 1, j + 2))
return Matrix(entries)
if self.n == 4:
n = 4
inv_determinant = impl.expr_init(1.0 / self.determinant())
Expand All @@ -510,18 +508,15 @@ def E(x, y):

for i in range(n):
for j in range(n):
entries[j][i] = impl.expr_init(
inv_determinant * (-1)**(i + j) *
((E(i + 1, j + 1) *
(E(i + 2, j + 2) * E(i + 3, j + 3) -
E(i + 3, j + 2) * E(i + 2, j + 3)) -
E(i + 2, j + 1) *
(E(i + 1, j + 2) * E(i + 3, j + 3) -
E(i + 3, j + 2) * E(i + 1, j + 3)) +
E(i + 3, j + 1) *
(E(i + 1, j + 2) * E(i + 2, j + 3) -
E(i + 2, j + 2) * E(i + 1, j + 3)))))
return Matrix(entries, disable_local_tensor=True)
entries[j][i] = inv_determinant * (-1)**(i + j) * ((
E(i + 1, j + 1) *
(E(i + 2, j + 2) * E(i + 3, j + 3) -
E(i + 3, j + 2) * E(i + 2, j + 3)) - E(i + 2, j + 1) *
(E(i + 1, j + 2) * E(i + 3, j + 3) -
E(i + 3, j + 2) * E(i + 1, j + 3)) + E(i + 3, j + 1) *
(E(i + 1, j + 2) * E(i + 2, j + 3) -
E(i + 2, j + 2) * E(i + 1, j + 3))))
return Matrix(entries)
raise Exception(
"Inversions of matrices with sizes >= 5 are not supported")

Expand Down Expand Up @@ -567,10 +562,8 @@ def transpose(self):
Get the transpose of a matrix.

"""
ret = Matrix([[self[i, j] for i in range(self.n)]
for j in range(self.m)],
disable_local_tensor=True)
return ret
return Matrix([[self[i, j] for i in range(self.n)]
for j in range(self.m)])

@taichi_scope
def determinant(a):
Expand Down Expand Up @@ -790,10 +783,8 @@ def zero(dt, n, m=None):

"""
if m is None:
return Vector([ti.cast(0, dt) for _ in range(n)],
disable_local_tensor=True)
return Matrix([[ti.cast(0, dt) for _ in range(m)] for _ in range(n)],
disable_local_tensor=True)
return Vector([ti.cast(0, dt) for _ in range(n)])
return Matrix([[ti.cast(0, dt) for _ in range(m)] for _ in range(n)])

@staticmethod
@taichi_scope
Expand All @@ -810,10 +801,8 @@ def one(dt, n, m=None):

"""
if m is None:
return Vector([ti.cast(1, dt) for _ in range(n)],
disable_local_tensor=True)
return Matrix([[ti.cast(1, dt) for _ in range(m)] for _ in range(n)],
disable_local_tensor=True)
return Vector([ti.cast(1, dt) for _ in range(n)])
return Matrix([[ti.cast(1, dt) for _ in range(m)] for _ in range(n)])

@staticmethod
@taichi_scope
Expand All @@ -832,8 +821,7 @@ def unit(n, i, dt=None):
if dt is None:
dt = int
assert 0 <= i < n
return Matrix([ti.cast(int(j == i), dt) for j in range(n)],
disable_local_tensor=True)
return Vector([ti.cast(int(j == i), dt) for j in range(n)])

@staticmethod
@taichi_scope
Expand All @@ -849,8 +837,7 @@ def identity(dt, n):

"""
return Matrix([[ti.cast(int(i == j), dt) for j in range(n)]
for i in range(n)],
disable_local_tensor=True)
for i in range(n)])

@staticmethod
def rotation2d(alpha):
Expand Down Expand Up @@ -1107,18 +1094,15 @@ def dot(self, other):

@kern_mod.pyfunc
def _cross3d(self, other):
ret = Matrix([
return Matrix([
self[1] * other[2] - self[2] * other[1],
self[2] * other[0] - self[0] * other[2],
self[0] * other[1] - self[1] * other[0],
],
disable_local_tensor=True)
return ret
])

@kern_mod.pyfunc
def _cross2d(self, other):
ret = self[0] * other[1] - self[1] * other[0]
return ret
return self[0] * other[1] - self[1] * other[0]

def cross(self, other):
"""Perform the cross product with the input Vector (1-D Matrix).
Expand Down Expand Up @@ -1156,10 +1140,8 @@ def outer_product(self, other):
impl.static(
impl.static_assert(other.m == 1,
"rhs for outer_product is not a vector"))
ret = Matrix([[self[i] * other[j] for j in range(other.n)]
for i in range(self.n)],
disable_local_tensor=True)
return ret
return Matrix([[self[i] * other[j] for j in range(other.n)]
for i in range(self.n)])


def Vector(n, dt=None, **kwargs):
Expand Down