Skip to content

Commit

Permalink
[Lang] MatrixType refactor: Support dot/cross/outer_product (taichi-d…
Browse files Browse the repository at this point in the history
…ev#6545)

Issue: taichi-dev#5819

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and quadpixels committed May 13, 2023
1 parent c830166 commit 17918c5
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 73 deletions.
39 changes: 2 additions & 37 deletions python/taichi/_funcs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import math

from taichi.lang import impl, matrix, ops
from taichi.lang import impl, ops
from taichi.lang.impl import expr_init, get_runtime, grouped, static
from taichi.lang.kernel_impl import func, pyfunc
from taichi.lang.kernel_impl import func
from taichi.lang.matrix import Matrix, Vector
from taichi.types import f32, f64
from taichi.types.annotations import template
Expand Down Expand Up @@ -49,41 +49,6 @@ def randn(dt=None):
return _randn(dt)


@pyfunc
def _matrix_cross3d(self, other):
return matrix.Matrix([
self[1] * other[2] - self[2] * other[1],
self[2] * other[0] - self[0] * other[2],
self[0] * other[1] - self[1] * other[0],
])


@pyfunc
def _matrix_cross2d(self, other):
return self[0] * other[1] - self[1] * other[0]


@pyfunc
def _vector_outer_product(self, other):
"""Perform the outer product with the input Vector.
Args:
other (:class:`~taichi.lang.matrix.Vector`): The input Vector to perform the outer product.
Returns:
:class:`~taichi.lang.matrix.Matrix`: The outer product result (Matrix) of the two Vectors.
"""
impl.static(
impl.static_assert(self.m == 1 and isinstance(self, Vector),
"lhs for outer_product is not a vector"))
impl.static(
impl.static_assert(other.m == 1 and isinstance(other, Vector),
"rhs for outer_product is not a vector"))
return matrix.Matrix([[self[i] * other[j] for j in range(other.n)]
for i in range(self.n)])


@func
def polar_decompose2d(A, dt):
"""Perform polar decomposition (A=UP) for 2x2 matrix.
Expand Down
30 changes: 6 additions & 24 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1376,19 +1376,8 @@ def dot(self, other):
>>> v1.dot(v2)
26
"""
impl.static(
impl.static_assert(self.m == 1, "lhs for dot is not a vector"))
impl.static(
impl.static_assert(other.m == 1, "rhs for dot is not a vector"))
return (self * other).sum()

def _cross3d(self, other):
from taichi._funcs import _matrix_cross3d # pylint: disable=C0415
return _matrix_cross3d(self, other)

def _cross2d(self, other):
from taichi._funcs import _matrix_cross2d # pylint: disable=C0415
return _matrix_cross2d(self, other)
from taichi.lang import matrix_ops # pylint: disable=C0415
return matrix_ops.dot(self, other)

def cross(self, other):
"""Performs the cross product with the input vector (1-D Matrix).
Expand All @@ -1407,14 +1396,8 @@ def cross(self, other):
Returns:
:class:`~taichi.Matrix`: The cross product of the two Vectors.
"""
if self.n == 3 and self.m == 1 and other.n == 3 and other.m == 1:
return self._cross3d(other)

if self.n == 2 and self.m == 1 and other.n == 2 and other.m == 1:
return self._cross2d(other)

raise ValueError(
"Cross product is only supported between pairs of 2D/3D vectors")
from taichi.lang import matrix_ops # pylint: disable=C0415
return matrix_ops.cross(self, other)

def outer_product(self, other):
"""Performs the outer product with the input Vector (1-D Matrix).
Expand All @@ -1429,9 +1412,8 @@ def outer_product(self, other):
Returns:
:class:`~taichi.Matrix`: The outer product of the two Vectors.
"""
from taichi._funcs import \
_vector_outer_product # pylint: disable=C0415
return _vector_outer_product(self, other)
from taichi.lang import matrix_ops # pylint: disable=C0415
return matrix_ops.outer_product(self, other)


class Vector(Matrix):
Expand Down
39 changes: 37 additions & 2 deletions python/taichi/lang/matrix_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _reduce(mat, fun: template()):
arg_at(
0,
foreach(
Or(assert_vector,
Or(assert_vector(),
assert_list,
msg="Cols/rows must be a list of lists, or a list of vectors")),
same_shapes))
Expand Down Expand Up @@ -176,7 +176,7 @@ def norm_inv(mat, eps=0.0):
return ops_mod.rsqrt(norm_sqr(mat) + eps)


@preconditions(arg_at(0, assert_vector))
@preconditions(arg_at(0, assert_vector()))
@pyfunc
def normalized(vec, eps=0.0):
invlen = 1 / (norm(vec) + eps)
Expand Down Expand Up @@ -268,3 +268,38 @@ def matmul(x, y):
if static(len(shape_x) == 1 and len(shape_y) == 2):
return _matmul_helper(transpose(y), x)
return _matmul_helper(x, y)


@preconditions(arg_at(0, assert_vector("lhs for dot is not a vector")),
arg_at(1, assert_vector("rhs for dot is not a vector")))
@pyfunc
def dot(vec_x, vec_y):
return sum(vec_x * vec_y)


@preconditions(arg_at(0, assert_vector("lhs for cross is not a vector")),
arg_at(1, assert_vector("rhs for cross is not a vector")),
same_shapes, arg_at(0, dim_lt(0, 4)))
@pyfunc
def cross(vec_x, vec_y):
shape = static(vec_x.get_shape())
if static(shape[0] == 2):
return vec_x[0] * vec_y[1] - vec_x[1] * vec_y[0]
if static(shape[0] == 3):
return Vector([
vec_x[1] * vec_y[2] - vec_x[2] * vec_y[1],
vec_x[2] * vec_y[0] - vec_x[0] * vec_y[2],
vec_x[0] * vec_y[1] - vec_x[1] * vec_y[0]
])
return None


@preconditions(
arg_at(0, assert_vector("lhs for outer_product is not a vector")),
arg_at(1, assert_vector("rhs for outer_product is not a vector")))
@pyfunc
def outer_product(vec_x, vec_y):
shape_x = static(vec_x.get_shape())
shape_y = static(vec_y.get_shape())
return Matrix([[vec_x[i] * vec_y[j] for j in static(range(shape_y[0]))]
for i in static(range(shape_x[0]))])
15 changes: 9 additions & 6 deletions python/taichi/lang/matrix_ops_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,14 @@ def assert_tensor(m, msg='not tensor type: {}'):

# TODO(zhanlue): rearrange to more generic checker functions
# for example: "assert_is_instance(args, indices=[], instances=[], logic='or')"
def assert_vector(v, msg='not a vector: {}'):
if (isinstance(v, Expr) or isinstance(v, Matrix)) and len(
v.get_shape()) == 1:
return True, None
raise TaichiCompilationError(msg.format(type(v)))
def assert_vector(msg='expected a vector, got {}'):
def check(v):
if (isinstance(v, Expr) or isinstance(v, Matrix)) and len(
v.get_shape()) == 1:
return True, None
return False, msg.format(type(v))

return check


def assert_list(x, msg='not a list: {}'):
Expand All @@ -90,7 +93,7 @@ def assert_list(x, msg='not a list: {}'):
raise TaichiCompilationError(msg.format(type(x)))


def same_shapes(xs):
def same_shapes(*xs):
shapes = [x.get_shape() for x in xs]
if len(set(shapes)) != 1:
return False, f'required shapes to be the same, got shapes {shapes}'
Expand Down
35 changes: 31 additions & 4 deletions tests/python/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ def init():
assert b[None][j] == j


@test_utils.test()
def test_basic_utils():
def _test_basic_utils():
a = ti.Vector.field(3, dtype=ti.f32)
b = ti.Vector.field(2, dtype=ti.f32)
abT = ti.Matrix.field(3, 2, dtype=ti.f32)
Expand Down Expand Up @@ -70,7 +69,16 @@ def init():


@test_utils.test()
def test_cross():
def test_basic_utils():
_test_basic_utils()


@test_utils.test(real_matrix=True, real_matrix_scalarize=True)
def test_basic_utils_real_matrix_scalarize():
_test_basic_utils()


def _test_cross():
a = ti.Vector.field(3, dtype=ti.f32)
b = ti.Vector.field(3, dtype=ti.f32)
c = ti.Vector.field(3, dtype=ti.f32)
Expand Down Expand Up @@ -99,7 +107,16 @@ def init():


@test_utils.test()
def test_dot():
def test_cross():
_test_cross()


@test_utils.test(real_matrix=True, real_matrix_scalarize=True)
def test_cross_real_matrix_scalarize():
_test_cross()


def _test_dot():
a = ti.Vector.field(3, dtype=ti.f32)
b = ti.Vector.field(3, dtype=ti.f32)
c = ti.field(dtype=ti.f32)
Expand All @@ -125,6 +142,16 @@ def init():
assert c2[None] == 14.0


@test_utils.test()
def test_dot():
_test_dot()


@test_utils.test(real_matrix=True, real_matrix_scalarize=True)
def test_dot_real_matrix_scalarize():
_test_dot()


@test_utils.test()
def test_transpose():
dim = 3
Expand Down

0 comments on commit 17918c5

Please sign in to comment.