Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lang] Matrix operations for new local matrix representation #5861

Closed
wants to merge 30 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
4ce08c8
cherrypick Matrix repr support
AD1024 Aug 15, 2022
549a359
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 15, 2022
07a4dc1
matrix assign
AD1024 Aug 15, 2022
d2264f4
Merge branch 'matrix-repr' of github.com:AD1024/taichi into matrix-repr
AD1024 Aug 15, 2022
efca3f0
move checks to caller side
AD1024 Aug 15, 2022
82c9413
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 15, 2022
38ec750
use ==
AD1024 Aug 15, 2022
13159fd
merge and format
AD1024 Aug 15, 2022
9c91103
refine impl
AD1024 Aug 17, 2022
28f3e0a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 17, 2022
bf719a3
no long in use
AD1024 Aug 17, 2022
72e8f26
Merge branch 'matrix-repr' of github.com:AD1024/taichi into matrix-repr
AD1024 Aug 17, 2022
65199ea
add some comments
AD1024 Aug 17, 2022
dcb52ec
elementwise op
AD1024 Aug 16, 2022
0402a8a
add indexing p1
AD1024 Aug 15, 2022
3980ab9
pick fixes
AD1024 Aug 15, 2022
ca68b2c
fin indexing
AD1024 Aug 15, 2022
a29da69
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 15, 2022
57e42ac
implement unary op for tensor types
AD1024 Aug 17, 2022
f3ff3d2
format code
AD1024 Aug 17, 2022
39ed915
move expr unification to type check
AD1024 Aug 18, 2022
94f42cf
Merge branch 'matrix-indexing' into matrix-basic-impl
AD1024 Aug 18, 2022
c8c95fe
add some basic ops
AD1024 Aug 18, 2022
5b0b7c4
save
AD1024 Aug 19, 2022
b92daa7
reuse and modify some code to use new implementation; still a workaround
AD1024 Aug 20, 2022
4586fa9
fix typo bugs
AD1024 Aug 23, 2022
92ffb26
fix print on cuda
AD1024 Aug 23, 2022
ff2b24e
fix cfg pass
AD1024 Aug 23, 2022
f727118
oopsie
AD1024 Aug 23, 2022
dc5b593
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from taichi.lang.ast.symbol_resolver import ASTResolver
from taichi.lang.exception import TaichiSyntaxError, TaichiTypeError
from taichi.lang.field import Field
from taichi.lang.matrix import (Matrix, MatrixType, _PyScopeMatrixImpl,
from taichi.lang.matrix import (Matrix, MatrixType, Vector, _PyScopeMatrixImpl,
_TiScopeMatrixImpl)
from taichi.lang.snode import append
from taichi.lang.util import in_taichi_scope, is_taichi_class, to_taichi_type
Expand Down Expand Up @@ -488,6 +488,12 @@ def build_Call(ctx, node):
node.ptr = impl.ti_format(*args, **keywords)
return node.ptr

if (isinstance(node.func, ast.Attribute) and
(func == Matrix or func == Vector)
) and impl.current_cfg().real_matrix and in_taichi_scope():
node.ptr = matrix.make_matrix(*args, **keywords)
return node.ptr

if ASTTransformer.build_call_if_is_builtin(ctx, node, args, keywords):
return node.ptr

Expand Down
9 changes: 9 additions & 0 deletions python/taichi/lang/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ def __init__(self, *args, tb=None, dtype=None):
self.ptr.set_tb(self.tb)
self.ptr.type_check(impl.get_runtime().prog.config)

def __getitem__(self, *indices):
if not isinstance(indices, (list, tuple)):
indices = (indices, )

indices = make_expr_group(*indices)
return Expr(
impl.get_runtime().prog.current_ast_builder().expr_indexed_matrix(
self.ptr, indices))

def __hash__(self):
return self.ptr.get_raw_address()

Expand Down
8 changes: 8 additions & 0 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ def expr_init_local_tensor(shape, element_type, elements):
shape, element_type, elements)


@taichi_scope
def expr_init_matrix(shape, element_type, elements):
return get_runtime().prog.current_ast_builder().expr_alloca_matrix(
shape, element_type, elements)


@taichi_scope
def expr_init_shared_array(shape, element_type):
return get_runtime().prog.current_ast_builder().expr_alloca_shared_array(
Expand All @@ -48,6 +54,8 @@ def expr_init(rhs):
if isinstance(rhs, Matrix) and (hasattr(rhs, "_DIM")):
return Matrix(*rhs.to_list(), ndim=rhs.ndim)
if isinstance(rhs, Matrix):
if current_cfg().real_matrix:
return rhs
return Matrix(rhs.to_list(), ndim=rhs.ndim)
if isinstance(rhs, SharedArray):
return rhs
Expand Down
21 changes: 21 additions & 0 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,27 @@ def prop_setter(instance, value):
return cls


def make_matrix(arr, dt=None):
if len(arr) == 0:
return impl.expr_init(
impl.expr_init_local_tensor([0], ti_python_core.DataType_unknown,
impl.make_expr_group([])))
is_matrix = isinstance(arr[0], Iterable)
if dt is None:
dt = _make_entries_initializer(is_matrix).infer_dt(arr)
if not is_matrix:
return impl.expr_init(
impl.expr_init_local_tensor([len(arr)], dt,
impl.make_expr_group(
[expr.Expr(elt) for elt in arr])))
return impl.expr_init(
impl.expr_init_local_tensor([len(arr), len(arr[0])], dt,
impl.make_expr_group([
expr.Expr(elt) for row in arr
for elt in row
])))


class _MatrixBaseImpl:
def __init__(self, m, n, entries):
self.m = m
Expand Down
Loading