Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lang] Enable definition of local matrices/vectors #5782

Merged
merged 46 commits into from
Aug 25, 2022
Merged
Show file tree
Hide file tree
Changes from 45 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
4ce08c8
cherrypick Matrix repr support
AD1024 Aug 15, 2022
549a359
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 15, 2022
07a4dc1
matrix assign
AD1024 Aug 15, 2022
d2264f4
Merge branch 'matrix-repr' of github.com:AD1024/taichi into matrix-repr
AD1024 Aug 15, 2022
efca3f0
move checks to caller side
AD1024 Aug 15, 2022
82c9413
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 15, 2022
38ec750
use ==
AD1024 Aug 15, 2022
13159fd
merge and format
AD1024 Aug 15, 2022
9c91103
refine impl
AD1024 Aug 17, 2022
28f3e0a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 17, 2022
bf719a3
no long in use
AD1024 Aug 17, 2022
72e8f26
Merge branch 'matrix-repr' of github.com:AD1024/taichi into matrix-repr
AD1024 Aug 17, 2022
65199ea
add some comments
AD1024 Aug 17, 2022
cbf1ea8
get rid of always-true condition
AD1024 Aug 23, 2022
4c8d6b7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 23, 2022
1a9df8c
save
AD1024 Aug 23, 2022
834699e
some fixes for print and matrix expr
AD1024 Aug 23, 2022
23d7bf7
fix codegen alloca size
AD1024 Aug 23, 2022
6a8a8cb
unsupport empty matrix
AD1024 Aug 24, 2022
08926ef
only check and cast elements
AD1024 Aug 24, 2022
1ae02aa
fmt Vectors to one line
AD1024 Aug 24, 2022
17412e8
lift duplicate part
AD1024 Aug 24, 2022
5421fe1
clean-up
AD1024 Aug 24, 2022
b9fd3a9
clean-up cse code
AD1024 Aug 24, 2022
43a456a
breaks ci; keep as original impl
AD1024 Aug 24, 2022
78ad14a
handle alloca
AD1024 Aug 24, 2022
40825f4
move checks to front
AD1024 Aug 24, 2022
88d01b6
Merge branch 'master' into matrix-repr
AD1024 Aug 24, 2022
f395d2a
reuse code
AD1024 Aug 24, 2022
2a6a8e6
Revert "clean-up cse code"
AD1024 Aug 24, 2022
7f8ca37
clean up together
AD1024 Aug 24, 2022
988abb3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 24, 2022
93b3a03
also checks for tlctx
AD1024 Aug 24, 2022
13d2efd
Merge branch 'matrix-repr' of github.com:AD1024/taichi into matrix-repr
AD1024 Aug 24, 2022
b2e101a
format
AD1024 Aug 24, 2022
6fcf070
fix codegen: allocate pointer to vector
AD1024 Aug 24, 2022
f43d2a8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 24, 2022
c613318
check real matrix when allocating memory
AD1024 Aug 24, 2022
41b6641
Merge branch 'matrix-repr' of github.com:AD1024/taichi into matrix-repr
AD1024 Aug 24, 2022
9fc758d
format and fix tc for variable holding matrix expression
AD1024 Aug 24, 2022
f635b7c
refactor: change to `make_local_matrix` which returns only an Expr; p…
AD1024 Aug 25, 2022
f82dc25
get rid of duplicated check
AD1024 Aug 25, 2022
bd68c23
save changes
AD1024 Aug 25, 2022
5d00c98
format
AD1024 Aug 25, 2022
3451397
also rename cxx part
AD1024 Aug 25, 2022
5ed0e93
Apply suggestions from code review
strongoier Aug 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from taichi.lang.ast.symbol_resolver import ASTResolver
from taichi.lang.exception import TaichiSyntaxError, TaichiTypeError
from taichi.lang.field import Field
from taichi.lang.matrix import (Matrix, MatrixType, _PyScopeMatrixImpl,
from taichi.lang.matrix import (Matrix, MatrixType, Vector, _PyScopeMatrixImpl,
_TiScopeMatrixImpl)
from taichi.lang.snode import append
from taichi.lang.util import in_taichi_scope, is_taichi_class, to_taichi_type
Expand Down Expand Up @@ -489,6 +489,12 @@ def build_Call(ctx, node):
node.ptr = impl.ti_format(*args, **keywords)
return node.ptr

if (isinstance(node.func, ast.Attribute) and
(func == Matrix
or func == Vector)) and impl.current_cfg().real_matrix:
node.ptr = matrix.make_matrix(*args, **keywords)
return node.ptr

if ASTTransformer.build_call_if_is_builtin(ctx, node, args, keywords):
return node.ptr

Expand Down
16 changes: 15 additions & 1 deletion python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from taichi.lang.field import Field, ScalarField
from taichi.lang.kernel_arguments import SparseMatrixProxy
from taichi.lang.matrix import (Matrix, MatrixField, MatrixNdarray, MatrixType,
_IntermediateMatrix, _MatrixFieldElement)
_IntermediateMatrix, _MatrixFieldElement,
make_matrix)
from taichi.lang.mesh import (ConvType, MeshElementFieldProxy, MeshInstance,
MeshRelationAccessProxy,
MeshReorderedMatrixFieldProxy,
Expand All @@ -36,6 +37,12 @@ def expr_init_local_tensor(shape, element_type, elements):
get_runtime().get_current_src_info())


@taichi_scope
def make_matrix_expr(shape, element_type, elements):
return get_runtime().prog.current_ast_builder().make_local_matrix(
strongoier marked this conversation as resolved.
Show resolved Hide resolved
shape, element_type, elements)


@taichi_scope
def expr_init_shared_array(shape, element_type):
return get_runtime().prog.current_ast_builder().expr_alloca_shared_array(
Expand All @@ -49,6 +56,13 @@ def expr_init(rhs):
if isinstance(rhs, Matrix) and (hasattr(rhs, "_DIM")):
return Matrix(*rhs.to_list(), ndim=rhs.ndim)
if isinstance(rhs, Matrix):
if current_cfg().real_matrix:
if rhs.ndim == 1:
entries = [rhs(i) for i in range(rhs.n)]
else:
entries = [[rhs(i, j) for j in range(rhs.m)]
for i in range(rhs.n)]
return make_matrix(entries)
return Matrix(rhs.to_list(), ndim=rhs.ndim)
if isinstance(rhs, SharedArray):
return rhs
Expand Down
15 changes: 15 additions & 0 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,21 @@ def prop_setter(instance, value):
return cls


def make_matrix(arr, dt=None):
assert len(arr) > 0, "Cannot create empty matrix"
is_matrix = isinstance(arr[0], Iterable)
if dt is None:
dt = _make_entries_initializer(is_matrix).infer_dt(arr)
if not is_matrix:
return impl.Expr(
impl.make_matrix_expr([len(arr)], dt,
[expr.Expr(elt).ptr for elt in arr]))
return impl.Expr(
impl.make_matrix_expr(
[len(arr), len(arr[0])], dt,
[expr.Expr(elt).ptr for row in arr for elt in row]))


class _MatrixBaseImpl:
def __init__(self, m, n, entries):
self.m = m
Expand Down
8 changes: 8 additions & 0 deletions taichi/analysis/gen_offline_cache_key.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,14 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor {
emit(expr->indices.exprs);
}

void visit(MatrixExpression *expr) override {
emit(ExprOpCode::MatrixExpression);
emit(expr->dt);
for (auto elt : expr->elements) {
emit(elt);
}
}

void visit(StrideExpression *expr) override {
emit(ExprOpCode::StrideExpression);
emit(expr->var);
Expand Down
1 change: 1 addition & 0 deletions taichi/analysis/offline_cache_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ static std::vector<std::uint8_t> get_offline_cache_key_of_compile_config(
serializer(config->demote_no_access_mesh_fors);
serializer(config->experimental_auto_mesh_local);
serializer(config->auto_mesh_local_default_occupacy);
serializer(config->real_matrix);
serializer.finalize();

return serializer.data;
Expand Down
57 changes: 47 additions & 10 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,11 @@ void TaskCodeGenLLVM::visit(Block *stmt_list) {
void TaskCodeGenLLVM::visit(AllocaStmt *stmt) {
if (stmt->ret_type->is<TensorType>()) {
auto tensor_type = stmt->ret_type->cast<TensorType>();
auto type = tlctx->get_data_type(tensor_type->get_element_type());
auto array_size = tlctx->get_constant(tensor_type->get_num_elements());
// Return type is [array_size x type]*.
auto type = kernel->program->config.real_matrix
? tlctx->get_data_type(tensor_type)
: tlctx->get_data_type(tensor_type->get_element_type());
// Return type is vector<tensor_type>* if use real matrix.
// otherwise the return type is [type * array_size]*
if (stmt->is_shared) {
size_t data_element_size = tlctx->get_type_size(
tlctx->get_data_type(tensor_type->get_element_type()));
Expand All @@ -148,7 +150,12 @@ void TaskCodeGenLLVM::visit(AllocaStmt *stmt) {
tlctx->get_data_type(tensor_type->get_element_type()), 0);
llvm_val[stmt] = builder->CreatePointerCast(ptr, ptr_type);
} else {
llvm_val[stmt] = create_entry_block_alloca(type, 0, array_size);
if (kernel->program->config.real_matrix)
llvm_val[stmt] =
create_entry_block_alloca(type, stmt->ret_type.is_pointer());
else
llvm_val[stmt] = create_entry_block_alloca(
type, 0, tlctx->get_constant(tensor_type->get_num_elements()));
}
} else {
TI_ASSERT(stmt->width() == 1);
Expand Down Expand Up @@ -688,6 +695,13 @@ llvm::Type *TaskCodeGenLLVM::llvm_type(DataType dt) {
return llvm::Type::getDoubleTy(*llvm_context);
} else if (dt->is_primitive(PrimitiveTypeID::f16)) {
return llvm::Type::getHalfTy(*llvm_context);
} else if (dt->is<TensorType>()) {
AD1024 marked this conversation as resolved.
Show resolved Hide resolved
TI_ASSERT_INFO(kernel->program->config.real_matrix,
"Real matrix not enabled but got TensorType");
auto tensor_type = dt->cast<TensorType>();
auto element_type = llvm_type(tensor_type->get_element_type());
return llvm::VectorType::get(element_type, tensor_type->get_num_elements(),
/*scalable=*/false);
} else {
TI_NOT_IMPLEMENTED;
}
Expand Down Expand Up @@ -796,16 +810,29 @@ void TaskCodeGenLLVM::visit(PrintStmt *stmt) {
TI_ASSERT(stmt->width() == 1);
std::vector<llvm::Value *> args;
std::string formats;
auto value_for_printf = [this](llvm::Value *to_print, DataType dtype) {
if (dtype->is_primitive(PrimitiveTypeID::f32) ||
dtype->is_primitive(PrimitiveTypeID::f16))
return this->builder->CreateFPExt(
to_print, this->tlctx->get_data_type(PrimitiveType::f64));
return to_print;
};
for (auto const &content : stmt->contents) {
if (std::holds_alternative<Stmt *>(content)) {
auto arg_stmt = std::get<Stmt *>(content);
auto value = llvm_val[arg_stmt];
if (arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f32) ||
arg_stmt->ret_type->is_primitive(PrimitiveTypeID::f16))
value = builder->CreateFPExt(value,
tlctx->get_data_type(PrimitiveType::f64));
args.push_back(value);
formats += data_type_format(arg_stmt->ret_type);
if (arg_stmt->ret_type->is<TensorType>()) {
auto dtype = arg_stmt->ret_type->cast<TensorType>();
auto elem_type = dtype->get_element_type();
for (int i = 0; i < dtype->get_num_elements(); ++i) {
auto elem_value = builder->CreateExtractElement(value, i);
args.push_back(value_for_printf(elem_value, elem_type));
}
formats += data_type_format(arg_stmt->ret_type);
} else {
args.push_back(value_for_printf(value, arg_stmt->ret_type));
formats += data_type_format(arg_stmt->ret_type);
}
} else {
auto arg_str = std::get<std::string>(content);
auto value = builder->CreateGlobalStringPtr(arg_str, "content_string");
Expand Down Expand Up @@ -2515,6 +2542,16 @@ void TaskCodeGenLLVM::visit(MeshPatchIndexStmt *stmt) {
llvm_val[stmt] = get_arg(2);
}

void TaskCodeGenLLVM::visit(MatrixInitStmt *stmt) {
auto type = tlctx->get_data_type(stmt->ret_type->as<TensorType>());
llvm::Value *vec = llvm::UndefValue::get(type);
for (int i = 0; i < stmt->values.size(); ++i) {
auto *elem = llvm_val[stmt->values[i]];
vec = builder->CreateInsertElement(vec, elem, i);
}
llvm_val[stmt] = vec;
}

void TaskCodeGenLLVM::eliminate_unused_functions() {
TaichiLLVMContext::eliminate_unused_functions(
module.get(), [&](std::string func_name) {
Expand Down
2 changes: 2 additions & 0 deletions taichi/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,8 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

void visit(ReferenceStmt *stmt) override;

void visit(MatrixInitStmt *stmt) override;

llvm::Value *create_xlogue(std::unique_ptr<Block> &block);

llvm::Value *create_mesh_xlogue(std::unique_ptr<Block> &block);
Expand Down
1 change: 1 addition & 0 deletions taichi/inc/expressions.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ PER_EXPRESSION(InternalFuncCallExpression)
PER_EXPRESSION(ExternalTensorExpression)
PER_EXPRESSION(GlobalVariableExpression)
PER_EXPRESSION(IndexExpression)
PER_EXPRESSION(MatrixExpression)
PER_EXPRESSION(StrideExpression)
PER_EXPRESSION(RangeAssumptionExpression)
PER_EXPRESSION(LoopUniqueExpression)
Expand Down
1 change: 1 addition & 0 deletions taichi/inc/statements.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ PER_STATEMENT(LoopUniqueStmt)
PER_STATEMENT(AssertStmt)
PER_STATEMENT(ExternalFuncCallStmt)
PER_STATEMENT(ExternalTensorShapeAlongAxisStmt)
PER_STATEMENT(MatrixInitStmt)

// Locals with reverse-mode autodiff
PER_STATEMENT(AdStackAllocaStmt)
Expand Down
7 changes: 7 additions & 0 deletions taichi/ir/expression_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,13 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter {
}
}

void visit(MatrixExpression *expr) override {
emit('[');
emit_vector(expr->elements);
emit(']');
emit(fmt::format(" (dt={})", expr->dt->to_string()));
}

void visit(IndexExpression *expr) override {
expr->var->accept(this);
emit('[');
Expand Down
26 changes: 26 additions & 0 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,26 @@ Stmt *make_tensor_access(Expression::FlattenContext *ctx,
return ctx->push_back<PtrOffsetStmt>(var->stmt, offset_stmt);
}

void MatrixExpression::type_check(CompileConfig *config) {
// TODO: typecheck matrix
for (auto &arg : elements) {
TI_ASSERT_TYPE_CHECKED(arg);
}
ret_type = dt;
}

void MatrixExpression::flatten(FlattenContext *ctx) {
// TODO: implement flatten
TI_ASSERT(this->dt->is<TensorType>());
std::vector<Stmt *> values;
for (auto &elt : elements) {
flatten_rvalue(elt, ctx);
values.push_back(elt->stmt);
}
stmt = ctx->push_back<MatrixInitStmt>(values);
stmt->ret_type = this->dt;
}

bool IndexExpression::is_field() const {
return var.is<GlobalVariableExpression>();
}
Expand Down Expand Up @@ -970,6 +990,12 @@ Expr ASTBuilder::expr_alloca() {
return var;
}

Expr ASTBuilder::make_matrix_expr(const std::vector<int> &shape,
const DataType &dt,
const std::vector<Expr> &elements) {
return Expr(std::make_shared<MatrixExpression>(elements, shape, dt));
}

Expr ASTBuilder::expr_alloca_local_tensor(const std::vector<int> &shape,
const DataType &element_type,
const ExprGroup &elements,
Expand Down
26 changes: 26 additions & 0 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,29 @@ class GlobalVariableExpression : public Expression {
TI_DEFINE_ACCEPT_FOR_EXPRESSION
};

/**
* Creating a local matrix;
* lowered from ti.Matrix with real_matrix=True
*/
class MatrixExpression : public Expression {
jim19930609 marked this conversation as resolved.
Show resolved Hide resolved
public:
std::vector<Expr> elements;
DataType dt;

MatrixExpression(const std::vector<Expr> &elements,
std::vector<int> shape,
DataType element_type)
: elements(elements) {
this->dt = DataType(TypeFactory::create_tensor_type(shape, element_type));
}

void type_check(CompileConfig *config) override;

void flatten(FlattenContext *ctx) override;

TI_DEFINE_ACCEPT_FOR_EXPRESSION
};

class IndexExpression : public Expression {
public:
// `var` is one of GlobalVariableExpression, ExternalTensorExpression,
Expand Down Expand Up @@ -865,6 +888,9 @@ class ASTBuilder {
const std::function<void(Expr)> &func);

Expr make_id_expr(const std::string &name);
Expr make_matrix_expr(const std::vector<int> &shape,
const DataType &dt,
const std::vector<Expr> &elements);
Expr insert_thread_idx_expr();
Expr insert_patch_idx_expr();
void create_kernel_exprgroup_return(const ExprGroup &group);
Expand Down
15 changes: 15 additions & 0 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -1807,5 +1807,20 @@ class MeshPatchIndexStmt : public Stmt {
TI_DEFINE_ACCEPT_AND_CLONE
};

/**
* Initialization of a local matrix
*/
class MatrixInitStmt : public Stmt {
jim19930609 marked this conversation as resolved.
Show resolved Hide resolved
public:
std::vector<Stmt *> values;

MatrixInitStmt(const std::vector<Stmt *> &values) : values(values) {
TI_STMT_REG_FIELDS;
}

TI_STMT_DEF_FIELDS(ret_type, values);
TI_DEFINE_ACCEPT_AND_CLONE
};

} // namespace lang
} // namespace taichi
32 changes: 32 additions & 0 deletions taichi/ir/type_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,36 @@ int data_type_size(DataType t) {
}
}

std::string tensor_type_format_helper(const std::vector<int> &shape,
std::string format_str,
int dim) {
std::string fmt = "[";
for (int i = 0; i < shape[dim]; ++i) {
if (dim != shape.size() - 1) {
fmt += tensor_type_format_helper(shape, format_str, dim + 1);
} else {
fmt += format_str;
}
if (i != shape[dim] - 1) {
fmt += ", ";
if (dim == 0 && dim != shape.size() - 1) {
fmt += "\n";
AD1024 marked this conversation as resolved.
Show resolved Hide resolved
}
}
}
fmt += "]";
return fmt;
}

std::string tensor_type_format(DataType t) {
TI_ASSERT(t->is<TensorType>());
auto tensor_type = t->as<TensorType>();
auto shape = tensor_type->get_shape();
auto element_type = tensor_type->get_element_type();
auto element_type_format = data_type_format(element_type);
return tensor_type_format_helper(shape, element_type_format, 0);
}

std::string data_type_format(DataType dt) {
if (dt->is_primitive(PrimitiveTypeID::i16)) {
return "%hd";
Expand All @@ -79,6 +109,8 @@ std::string data_type_format(DataType dt) {
// TaskCodeGenLLVM::visit(PrintStmt *stmt) and
// TaskCodeGenCUDA::visit(PrintStmt *stmt) for more details.
return "%f";
} else if (dt->is<TensorType>()) {
return tensor_type_format(dt);
} else {
TI_NOT_IMPLEMENTED
}
Expand Down
1 change: 1 addition & 0 deletions taichi/program/compile_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ CompileConfig::CompileConfig() {
detect_read_only = true;
ndarray_use_cached_allocator = true;
use_mesh = false;
real_matrix = false;

saturating_grid_dim = 0;
max_block_dim = 0;
Expand Down
Loading