Skip to content

Commit

Permalink
[refactor] Simplify ndarray arg declaration
Browse files Browse the repository at this point in the history
ghstack-source-id: 1f69c90e1e4fe433fd2f73464f370a4bde79d630
Pull Request resolved: #8076
  • Loading branch information
Ailing Zhang authored and Taichi Gardener committed May 26, 2023
1 parent 99d52c4 commit d56f7b1
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 11 deletions.
3 changes: 1 addition & 2 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,9 +625,8 @@ def transform_as_kernel():
kernel_arguments.decl_ndarray_arg(
to_taichi_type(ctx.arg_features[i][0]),
ctx.arg_features[i][1],
ctx.arg_features[i][2],
ctx.func.arguments[i].name,
ctx.arg_features[i][3],
ctx.arg_features[i][2],
),
)
elif isinstance(ctx.func.arguments[i].annotation, texture_type.TextureType):
Expand Down
8 changes: 1 addition & 7 deletions python/taichi/lang/kernel_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,7 @@ def decl_sparse_matrix(dtype, name):
return SparseMatrixProxy(_ti_core.make_arg_load_expr(arg_id, ptr_type, False), value_type)


def decl_ndarray_arg(dtype, ndim, element_shape, name, needs_grad):
# TODO: use element_type from runtime ndarray once we support tensortype hashing
dtype = cook_dtype(dtype)
element_dim = len(element_shape)
element_type = (
_ti_core.get_type_factory_instance().get_tensor_type(element_shape, dtype) if element_dim != 0 else dtype
)
def decl_ndarray_arg(element_type, ndim, name, needs_grad):
    """Declare an ndarray parameter on the callable currently being compiled.

    Args:
        element_type: Element type passed through to the runtime parameter
            declaration (scalar dtype or tensor type).
        ndim: Number of ndarray dimensions, forwarded unchanged.
        name: Parameter name as it appears in the kernel signature.
        needs_grad: Flag forwarded to both the parameter declaration and the
            external-tensor expression.

    Returns:
        AnyArray: wrapper around the external-tensor expression that
        represents this parameter inside the kernel body.
    """
    callable_being_compiled = impl.get_runtime().compiling_callable
    param_id = callable_being_compiled.insert_ndarray_param(element_type, ndim, name, needs_grad)
    external_tensor = _ti_core.make_external_tensor_expr(element_type, ndim, param_id, needs_grad)
    return AnyArray(external_tensor)

Expand Down
10 changes: 8 additions & 2 deletions python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def extract_arg(arg, anno):
if isinstance(arg, taichi.lang._ndarray.Ndarray):
anno.check_matched(arg.get_type())
needs_grad = (arg.grad is not None) if anno.needs_grad is None else anno.needs_grad
return arg.dtype, len(arg.shape), arg.element_shape, needs_grad
return arg.element_type, len(arg.shape), needs_grad
# external arrays
shape = getattr(arg, "shape", None)
if shape is None:
Expand Down Expand Up @@ -427,7 +427,13 @@ def extract_arg(arg, anno):
f"but the argument has {len(shape)} dimensions"
)
needs_grad = getattr(arg, "requires_grad", False) if anno.needs_grad is None else anno.needs_grad
return to_taichi_type(arg.dtype), len(shape) - len(element_shape), element_shape, needs_grad
dtype = to_taichi_type(arg.dtype)
element_type = (
_ti_core.get_type_factory_instance().get_tensor_type(element_shape, dtype)
if len(element_shape) != 0
else arg.dtype
)
return element_type, len(shape) - len(element_shape), needs_grad
if isinstance(anno, sparse_matrix_builder):
return arg.dtype
# Use '#' as a placeholder because other kinds of arguments are not involved in template instantiation
Expand Down
7 changes: 7 additions & 0 deletions taichi/ir/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ DataType PrimitiveType::get(PrimitiveTypeID t) {
std::size_t DataType::hash() const {
if (auto primitive = ptr_->cast<PrimitiveType>()) {
return (std::size_t)primitive->type;
} else if (auto tensor_type = ptr_->cast<TensorType>()) {
std::size_t ret = 0;
auto tensor_shape = tensor_type->get_shape();
for (int i = 0; i < tensor_shape.size(); i++) {
ret += (i + 1) * 107 + tensor_shape[i];
}
return ret + DataType(tensor_type->get_element_type()).hash();
} else if (auto pointer = ptr_->cast<PointerType>()) {
return 10007 + DataType(pointer->get_pointee_type()).hash();
} else {
Expand Down

0 comments on commit d56f7b1

Please sign in to comment.