Skip to content

Commit

Permalink
[lang] Migrate TensorType expansion for TextureOpExpression from Pyth…
Browse files Browse the repository at this point in the history
…on code to Frontend IR (#6968)

Issue: #5819

### Brief Summary
  • Loading branch information
jim19930609 authored Jan 9, 2023
1 parent ecc9664 commit 734b483
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 18 deletions.
24 changes: 13 additions & 11 deletions python/taichi/lang/_texture.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@
def _get_entries(mat):
if isinstance(mat, Matrix):
return mat.entries
assert isinstance(mat, Expr) and mat.is_tensor()
return impl.get_runtime().prog.current_ast_builder().expand_exprs(
[mat.ptr])
return [mat]


class TextureSampler:
Expand All @@ -24,9 +22,10 @@ def __init__(self, ptr_expr, num_dims) -> None:

@taichi_scope
def sample_lod(self, uv, lod):
ast_builder = impl.get_runtime().prog.current_ast_builder()
args_group = make_expr_group(*_get_entries(uv), lod)
v = _ti_core.make_texture_op_expr(_ti_core.TextureOpType.kSampleLod,
self.ptr_expr, args_group)
v = ast_builder.make_texture_op_expr(_ti_core.TextureOpType.kSampleLod,
self.ptr_expr, args_group)
r = impl.call_internal("composite_extract_0",
v,
with_runtime_context=False)
Expand All @@ -43,9 +42,10 @@ def sample_lod(self, uv, lod):

@taichi_scope
def fetch(self, index, lod):
ast_builder = impl.get_runtime().prog.current_ast_builder()
args_group = make_expr_group(*_get_entries(index), lod)
v = _ti_core.make_texture_op_expr(_ti_core.TextureOpType.kFetchTexel,
self.ptr_expr, args_group)
v = ast_builder.make_texture_op_expr(
_ti_core.TextureOpType.kFetchTexel, self.ptr_expr, args_group)
r = impl.call_internal("composite_extract_0",
v,
with_runtime_context=False)
Expand All @@ -69,9 +69,10 @@ def __init__(self, ptr_expr, num_dims) -> None:

@taichi_scope
def load(self, index):
ast_builder = impl.get_runtime().prog.current_ast_builder()
args_group = make_expr_group(*_get_entries(index))
v = _ti_core.make_texture_op_expr(_ti_core.TextureOpType.kLoad,
self.ptr_expr, args_group)
v = ast_builder.make_texture_op_expr(_ti_core.TextureOpType.kLoad,
self.ptr_expr, args_group)
r = impl.call_internal("composite_extract_0",
v,
with_runtime_context=False)
Expand All @@ -88,11 +89,12 @@ def load(self, index):

@taichi_scope
def store(self, index, value):
ast_builder = impl.get_runtime().prog.current_ast_builder()
args_group = make_expr_group(*_get_entries(index),
*_get_entries(value))
impl.expr_init(
_ti_core.make_texture_op_expr(_ti_core.TextureOpType.kStore,
self.ptr_expr, args_group))
ast_builder.make_texture_op_expr(_ti_core.TextureOpType.kStore,
self.ptr_expr, args_group))

@property
@taichi_scope
Expand Down
14 changes: 14 additions & 0 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,12 @@ void SNodeOpExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

TextureOpExpression::TextureOpExpression(TextureOpType op,
Expr texture_ptr,
const ExprGroup &args)
: op(op), texture_ptr(texture_ptr), args(args) {
}

void TextureOpExpression::type_check(CompileConfig *config) {
TI_ASSERT(texture_ptr.is<TexturePtrExpression>());
auto ptr = texture_ptr.cast<TexturePtrExpression>();
Expand Down Expand Up @@ -1651,6 +1657,14 @@ void ASTBuilder::pop_scope() {
loop_state_stack_.pop_back();
}

Expr ASTBuilder::make_texture_op_expr(const TextureOpType &op,
const Expr &texture_ptr,
const ExprGroup &args) {
ExprGroup expanded_args;
expanded_args.exprs = this->expand_exprs(args.exprs);
return Expr::make<TextureOpExpression>(op, texture_ptr, expanded_args);
}

Stmt *flatten_lvalue(Expr expr, Expression::FlattenContext *ctx) {
expr->flatten(ctx);
return expr->get_flattened_stmt();
Expand Down
8 changes: 4 additions & 4 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -726,9 +726,7 @@ class TextureOpExpression : public Expression {

explicit TextureOpExpression(TextureOpType op,
Expr texture_ptr,
const ExprGroup &args)
: op(op), texture_ptr(texture_ptr), args(args) {
}
const ExprGroup &args);

void type_check(CompileConfig *config) override;

Expand Down Expand Up @@ -985,7 +983,9 @@ class ASTBuilder {
void insert_expr_stmt(const Expr &val);
void insert_snode_activate(SNode *snode, const ExprGroup &expr_group);
void insert_snode_deactivate(SNode *snode, const ExprGroup &expr_group);

Expr make_texture_op_expr(const TextureOpType &op,
const Expr &texture_ptr,
const ExprGroup &args);
/*
* This function allocates the space for a new item (a struct or a scalar)
* in the Dynamic SNode, and assigns values to the elements inside it.
Expand Down
4 changes: 1 addition & 3 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ void export_lang(py::module &m) {
.def("insert_expr_stmt", &ASTBuilder::insert_expr_stmt)
.def("insert_thread_idx_expr", &ASTBuilder::insert_thread_idx_expr)
.def("insert_patch_idx_expr", &ASTBuilder::insert_patch_idx_expr)
.def("make_texture_op_expr", &ASTBuilder::make_texture_op_expr)
.def("expand_exprs", &ASTBuilder::expand_exprs)
.def("mesh_index_conversion", &ASTBuilder::mesh_index_conversion)
.def("expr_subscript", &ASTBuilder::expr_subscript)
Expand Down Expand Up @@ -955,9 +956,6 @@ void export_lang(py::module &m) {
texture.value(texture_op_type_name(TextureOpType(t)).c_str(),
TextureOpType(t));
texture.export_values();
m.def("make_texture_op_expr",
Expr::make<TextureOpExpression, const TextureOpType &, const Expr &,
const ExprGroup &>);

auto &&bin = py::enum_<BinaryOpType>(m, "BinaryOpType", py::arithmetic());
for (int t = 0; t <= (int)BinaryOpType::undefined; t++)
Expand Down

0 comments on commit 734b483

Please sign in to comment.