Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] Remove variable() of Matrix/Struct and empty() of Matrix/StructType #3531

Merged
merged 5 commits into from
Nov 17, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@ def expr_init(rhs):
if rhs is None:
return Expr(_ti_core.expr_alloca())
if is_taichi_class(rhs):
if rhs.local_tensor_proxy is not None:
return rhs
return rhs.variable()
return rhs
if isinstance(rhs, list):
return [expr_init(e) for e in rhs]
if isinstance(rhs, tuple):
Expand Down Expand Up @@ -160,7 +158,7 @@ def subscript(value, *_indices):
for e in value.get_field_members()
])
if isinstance(value, StructField):
return ti.Struct(
return ti.lang.struct.IntermediateStruct(
{k: subscript(v, *_indices)
for k, v in value.items})
return Expr(_ti_core.subscript(_var, indices_expr_group))
Expand Down
38 changes: 13 additions & 25 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,10 @@ def subscript(self, *indices):
return ti.local_subscript_with_offset(self.local_tensor_proxy,
(i, j), (self.n, self.m))
# ptr.is_global_ptr() will check whether it's an element in the field (which is different from ptr.is_global_var()).
if isinstance(self.entries[0],
ti.Expr) and self.entries[0].ptr.is_global_ptr(
) and ti.current_cfg().dynamic_index:
if ti.current_cfg().dynamic_index and isinstance(
self.entries[0], expr.Expr) and not ti_core.is_custom_type(
self.entries[0].ptr.get_ret_type(
)) and self.entries[0].ptr.is_global_ptr():
# TODO: Add API to query whether AOS or SOA
return ti.global_subscript_with_offset(self.entries[0], (i, j),
(self.n, self.m), True)
Expand Down Expand Up @@ -427,12 +428,6 @@ def copy(self):
ret.entries = copy.copy(self.entries)
return ret

@taichi_scope
def variable(self):
ret = self.copy()
ret.entries = [impl.expr_init(e) for e in ret.entries]
return ret

@taichi_scope
def cast(self, dtype):
"""Cast the matrix element data type.
Expand Down Expand Up @@ -1345,29 +1340,22 @@ def __call__(self, *args):
mat = self.cast(Matrix(entries, dt=self.dtype))
return mat

def cast(self, mat, in_place=False):
if not in_place:
mat = mat.copy()
def cast(self, mat):
# sanity check shape
if self.m != mat.m or self.n != mat.n:
raise TaichiSyntaxError(
f"Incompatible arguments for the custom vector/matrix type: ({self.n}, {self.m}), ({mat.n}, {mat.m})"
)
if in_python_scope():
mat.entries = [
int(x) if self.dtype in ti.integer_types else x
for x in mat.entries
]
else:
# only performs casting in Taichi scope
mat.entries = [cast(x, self.dtype) for x in mat.entries]
return mat
return Matrix([[
int(mat(i, j)) if self.dtype in ti.integer_types else float(
mat(i, j)) for j in range(self.m)
] for i in range(self.n)])
return Matrix([[cast(mat(i, j), self.dtype) for j in range(self.m)]
for i in range(self.n)])

def empty(self):
"""
Create an empty instance of the given compound type.
"""
return Matrix.empty(self.n, self.m)
def scalar_filled(self, value):
return Matrix([[value for _ in range(self.m)] for _ in range(self.n)])

def field(self, **kwargs):
return Matrix.field(self.n, self.m, dtype=self.dtype, **kwargs)
Expand Down
151 changes: 59 additions & 92 deletions python/taichi/lang/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,8 @@ def __init__(self, *args, **kwargs):
v = Matrix(v)
if isinstance(v, dict):
v = Struct(v)
self.entries[k] = v
self.entries[k] = v if in_python_scope() else impl.expr_init(v)
self.register_members()
self.local_tensor_proxy = None
self.any_array_access = None

@property
def keys(self):
Expand Down Expand Up @@ -114,95 +112,67 @@ def setter(self, value):

def element_wise_unary(self, foo):
_taichi_skip_traceback = 1
ret = self.empty_copy()
entries = {}
for k, v in self.items:
if isinstance(v, expr.Expr):
ret.entries[k] = foo(v)
if is_taichi_class(v):
entries[k] = v.element_wise_unary(foo)
else:
ret.entries[k] = v.element_wise_unary(foo)
return ret
entries[k] = foo(v)
return Struct(entries)

def element_wise_binary(self, foo, other):
_taichi_skip_traceback = 1
ret = self.empty_copy()
if isinstance(other, (dict)):
other = Struct(other)
if isinstance(other, Struct):
if self.entries.keys() != other.entries.keys():
raise TypeError(
f"Member mismatch between structs {self.keys}, {other.keys}"
)
for k, v in self.items:
if isinstance(v, expr.Expr):
ret.entries[k] = foo(v, other.entries[k])
else:
ret.entries[k] = v.element_wise_binary(
foo, other.entries[k])
else: # assumed to be scalar
for k, v in self.items:
if isinstance(v, expr.Expr):
ret.entries[k] = foo(v, other)
else:
ret.entries[k] = v.element_wise_binary(foo, other)
return ret
other = self.broadcast_copy(other)
entries = {}
for k, v in self.items:
if is_taichi_class(v):
entries[k] = v.element_wise_binary(foo, other.entries[k])
else:
entries[k] = foo(v, other.entries[k])
return Struct(entries)

def broadcast_copy(self, other):
if isinstance(other, dict):
other = Struct(other)
if not isinstance(other, Struct):
ret = self.empty_copy()
for k, v in ret.items:
if isinstance(v, (Matrix, Struct)):
ret.entries[k] = v.broadcast_copy(other)
entries = {}
for k, v in self.items:
if is_taichi_class(v):
entries[k] = v.broadcast_copy(other)
else:
ret.entries[k] = other
other = ret
entries[k] = other
other = Struct(entries)
if self.entries.keys() != other.entries.keys():
raise TypeError(
f"Member mismatch between structs {self.keys}, {other.keys}")
return other

def element_wise_writeback_binary(self, foo, other):
ret = self.empty_copy()
if isinstance(other, (dict)):
other = Struct(other)
if is_taichi_class(other):
other = other.variable()
if foo.__name__ == 'assign' and not isinstance(other, Struct):
if foo.__name__ == 'assign' and not isinstance(other, (dict, Struct)):
raise TaichiSyntaxError(
'cannot assign scalar expr to '
f'taichi class {type(self)}, maybe you want to use `a.fill(b)` instead?'
)
if isinstance(other, Struct):
if self.entries.keys() != other.entries.keys():
raise TypeError(
f"Member mismatch between structs {self.keys}, {other.keys}"
)
for k, v in self.items:
if isinstance(v, expr.Expr):
ret.entries[k] = foo(v, other.entries[k])
else:
ret.entries[k] = v.element_wise_binary(
foo, other.entries[k])
else: # assumed to be scalar
for k, v in self.items:
if isinstance(v, expr.Expr):
ret.entries[k] = foo(v, other)
else:
ret.entries[k] = v.element_wise_binary(foo, other)
return ret
other = self.broadcast_copy(other)
entries = {}
for k, v in self.items:
if is_taichi_class(v):
entries[k] = v.element_wise_binary(foo, other.entries[k])
else:
entries[k] = foo(v, other.entries[k])
return self if foo.__name__ == 'assign' else Struct(entries)

def element_wise_ternary(self, foo, other, extra):
ret = self.empty_copy()
other = self.broadcast_copy(other)
extra = self.broadcast_copy(extra)
entries = {}
for k, v in self.items:
if isinstance(v, expr.Expr):
ret.entries[k] = foo(v, other.entries[k], extra.entries[k])
if is_taichi_class(v):
entries[k] = v.element_wise_ternary(foo, other.entries[k],
extra.entries[k])
else:
ret.entries[k] = v.element_wise_ternary(
foo, other.entries[k], extra.entries[k])
return ret
entries[k] = foo(v, other.entries[k], extra.entries[k])
return Struct(entries)

@taichi_scope
def fill(self, val):
Expand Down Expand Up @@ -231,17 +201,6 @@ def copy(self):
ret.entries = copy.copy(self.entries)
return ret

@taichi_scope
def variable(self):
ret = self.copy()
ret.entries = {
k: impl.expr_init(v) if isinstance(v,
(numbers.Number,
expr.Expr)) else v.variable()
for k, v in ret.items
}
return ret

def __len__(self):
"""Get the number of entries in a custom struct"""
return len(self.entries)
Expand Down Expand Up @@ -343,6 +302,18 @@ def field(cls,
return StructField(field_dict, name=name)


class IntermediateStruct(Struct):
"""The Struct type class for compiler internal use only.

Args:
entries (Dict[str, Union[Expr, Matrix, Struct]]): keys and values for struct members.
"""
def __init__(self, entries):
assert isinstance(entries, dict)
self.entries = entries
self.register_members()


class StructField(Field):
"""Taichi struct field with SNode implementation.
Instead of directly contraining Expr entries, the StructField object
Expand Down Expand Up @@ -542,41 +513,37 @@ def __call__(self, *args, **kwargs):
if isinstance(args[0], (numbers.Number, expr.Expr)):
entries = self.scalar_filled(args[0])
else:
# fill a single vector or matrix
# initialize struct members by dictionary
entries = Struct(args[0])
struct = self.cast(entries)
return struct

def cast(self, struct, in_place=False):
if not in_place:
struct = struct.copy()
def cast(self, struct):
# sanity check members
if self.members.keys() != struct.entries.keys():
raise TaichiSyntaxError(
"Incompatible arguments for custom struct members!")
entries = {}
for k, dtype in self.members.items():
if isinstance(dtype, CompoundType):
struct.entries[k] = dtype.cast(struct.entries[k])
entries[k] = dtype.cast(struct.entries[k])
else:
if in_python_scope():
v = struct.entries[k]
struct.entries[k] = int(
entries[k] = int(
v) if dtype in ti.integer_types else float(v)
else:
struct.entries[k] = cast(struct.entries[k], dtype)
return struct
entries[k] = cast(struct.entries[k], dtype)
return Struct(entries)

def empty(self):
"""
Create an empty instance of the given compound type.
Nested structs and matrices need to be recursively handled.
"""
struct = Struct.empty(self.members.keys())
def scalar_filled(self, value):
strongoier marked this conversation as resolved.
Show resolved Hide resolved
entries = {}
for k, dtype in self.members.items():
if isinstance(dtype, CompoundType):
struct.entries[k] = dtype.empty()
return struct
entries[k] = dtype.scalar_filled(value)
else:
entries[k] = value
return Struct(entries)

def field(self, **kwargs):
return Struct.field(self.members, **kwargs)
23 changes: 6 additions & 17 deletions python/taichi/lang/types.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,16 @@
import taichi.lang.matrix
import taichi


class CompoundType:
def empty(self):
"""
Create an empty instance of the given compound type.
"""
raise NotImplementedError
pass

def scalar_filled(self, value):
instance = self.empty()
return instance.broadcast_copy(value)

def field(self, **kwargs):
raise NotImplementedError
def matrix(n, m, dtype):
return taichi.lang.matrix.MatrixType(n, m, dtype)


def matrix(m, n, dtype=None):
return taichi.lang.matrix.MatrixType(m, n, dtype=dtype)


def vector(m, dtype=None):
return taichi.lang.matrix.MatrixType(m, 1, dtype=dtype)
def vector(n, dtype):
return taichi.lang.matrix.MatrixType(n, 1, dtype)


def struct(**kwargs):
Expand Down
4 changes: 4 additions & 0 deletions taichi/ir/type_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ inline PrimitiveTypeID get_primitive_data_type() {
}
}

inline bool is_custom_type(DataType dt) {
return dt->is<CustomIntType>() || dt->is<CustomFloatType>();
}

inline bool is_real(DataType dt) {
return dt->is_primitive(PrimitiveTypeID::f16) ||
dt->is_primitive(PrimitiveTypeID::f32) ||
Expand Down
1 change: 1 addition & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,7 @@ void export_lang(py::module &m) {
#include "taichi/inc/data_type.inc.h"
#undef PER_TYPE

m.def("is_custom_type", is_custom_type);
m.def("is_integral", is_integral);
m.def("is_signed", is_signed);
m.def("is_real", is_real);
Expand Down