Skip to content

Commit

Permalink
[lang] Stop broadcasting scalar cond in select statements (taichi-dev…
Browse files Browse the repository at this point in the history
…#7344)

Issue: taichi-dev#7240 

### Brief Summary
1. Stop broadcasting `cond` when it is a scalar in a `select` statement.
2. Since `config.real_matrix_scalarize` is currently set to `True`, we
need to scalarize the `select`. I added another code path to do this,
which seems ugly.
3. We can see the expected behavior when `config.real_matrix_scalarize`
is unset.
4. Add python and c++ unit test for select stmts.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and quadpixels committed May 13, 2023
1 parent 0d3c3ff commit f2a7d38
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 9 deletions.
16 changes: 9 additions & 7 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -421,8 +421,7 @@ static std::tuple<Expr, Expr, Expr> unify_ternaryop_operands(const Expr &e1,
return std::tuple(e1, e2, e3);
}

return std::tuple(to_broadcast_tensor(e1, target_dtype),
to_broadcast_tensor(e2, target_dtype),
return std::tuple(e1, to_broadcast_tensor(e2, target_dtype),
to_broadcast_tensor(e3, target_dtype));
}

Expand Down Expand Up @@ -450,19 +449,22 @@ void TernaryOpExpression::type_check(const CompileConfig *config) {
op2->ret_type->to_string(), op3->ret_type->to_string()));
};

if (op1_type->is<TensorType>() && op2_type->is<TensorType>() &&
op3_type->is<TensorType>()) {
if (op2_type->is<TensorType>() && op3_type->is<TensorType>()) {
// valid
is_tensor = true;
if (op1_type->cast<TensorType>()->get_shape() !=
op2_type->cast<TensorType>()->get_shape()) {
if (op1_type->is<TensorType>() &&
op1_type->cast<TensorType>()->get_shape() !=
op2_type->cast<TensorType>()->get_shape()) {
is_valid = false;
}
if (op2_type->cast<TensorType>()->get_shape() !=
op3_type->cast<TensorType>()->get_shape()) {
is_valid = false;
}
op1_type = op1_type->cast<TensorType>()->get_element_type();

if (op1_type->is<TensorType>()) {
op1_type = op1_type->cast<TensorType>()->get_element_type();
}
op2_type = op2_type->cast<TensorType>()->get_element_type();
op3_type = op3_type->cast<TensorType>()->get_element_type();

Expand Down
43 changes: 41 additions & 2 deletions taichi/transforms/scalarize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,8 +443,7 @@ class Scalarize : public BasicStmtVisitor {
auto cond_dtype = stmt->op1->ret_type;
auto op2_dtype = stmt->op2->ret_type;
auto op3_dtype = stmt->op3->ret_type;
if (cond_dtype->is<TensorType>() || op2_dtype->is<TensorType>() ||
op3_dtype->is<TensorType>()) {
if (cond_dtype->is<TensorType>()) {
// Make sure broadcasting has been correctly applied by
// TernaryOpExpression::type_check().
TI_ASSERT(cond_dtype->is<TensorType>() && op2_dtype->is<TensorType>() &&
Expand Down Expand Up @@ -490,6 +489,46 @@ class Scalarize : public BasicStmtVisitor {
immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get());
delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt));

delayed_modifier_.erase(stmt);
} else if (cond_dtype->is<PrimitiveType>() &&
(op2_dtype->is<TensorType>() || op3_dtype->is<TensorType>())) {
TI_ASSERT(cond_dtype->is<PrimitiveType>() &&
op2_dtype->is<TensorType>() && op3_dtype->is<TensorType>());
TI_ASSERT(op2_dtype.get_shape() == op3_dtype.get_shape());
// Scalarization for LoadStmt should have already replaced all operands
// to MatrixInitStmt.
TI_ASSERT(stmt->op2->is<MatrixInitStmt>());
TI_ASSERT(stmt->op3->is<MatrixInitStmt>());

Stmt *cond_val = stmt->op1;

auto op2_matrix_init_stmt = stmt->op2->cast<MatrixInitStmt>();
std::vector<Stmt *> op2_vals = op2_matrix_init_stmt->values;

auto op3_matrix_init_stmt = stmt->op3->cast<MatrixInitStmt>();
std::vector<Stmt *> op3_vals = op3_matrix_init_stmt->values;

TI_ASSERT(op2_vals.size() == op3_vals.size());

size_t num_elements = op2_vals.size();
auto primitive_type = stmt->ret_type.get_element_type();
std::vector<Stmt *> matrix_init_values;
for (size_t i = 0; i < num_elements; i++) {
auto ternary_stmt = std::make_unique<TernaryOpStmt>(
stmt->op_type, cond_val, op2_vals[i], op3_vals[i]);
matrix_init_values.push_back(ternary_stmt.get());
ternary_stmt->ret_type = primitive_type;

delayed_modifier_.insert_before(stmt, std::move(ternary_stmt));
}

auto matrix_init_stmt =
std::make_unique<MatrixInitStmt>(matrix_init_values);
matrix_init_stmt->ret_type = stmt->ret_type;

immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get());
delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt));

delayed_modifier_.erase(stmt);
}
}
Expand Down
35 changes: 35 additions & 0 deletions tests/cpp/ir/frontend_type_inference_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,41 @@ TEST(FrontendTypeInference, TernaryOp) {
EXPECT_EQ(ternary_f32->ret_type, PrimitiveType::f32);
}

TEST(FrontendTypeInference, TernaryOp_NoBroadcast) {
auto cond = value<int32>(42);
cond->type_check(nullptr);
EXPECT_EQ(cond->ret_type, PrimitiveType::i32);

auto const_3 = Expr::make<ConstExpression, int32>(3);
auto const_5 = Expr::make<ConstExpression, int32>(5);
std::vector<int> shape = {3, 1};

std::vector<Expr> op2_element = {const_3, const_3, const_3};
std::vector<Expr> op3_element = {const_5, const_5, const_5};

auto op2 =
Expr::make<MatrixExpression>(op2_element, shape, PrimitiveType::i32);
op2->type_check(nullptr);
auto op3 =
Expr::make<MatrixExpression>(op2_element, shape, PrimitiveType::i32);
op3->type_check(nullptr);

auto ternary =
Expr::make<TernaryOpExpression>(TernaryOpType::select, cond, op2, op3);
ternary->type_check(nullptr);

auto ternary_expr = ternary.cast<TernaryOpExpression>();
auto cond_ret_type = ternary_expr->op1->ret_type;
auto op2_ret_type = ternary_expr->op2->ret_type;
auto op3_ret_type = ternary_expr->op3->ret_type;

EXPECT_TRUE(op2_ret_type->is<TensorType>() &&
op2_ret_type->cast<TensorType>()->get_shape() == shape);
EXPECT_TRUE(op3_ret_type->is<TensorType>() &&
op3_ret_type->cast<TensorType>()->get_shape() == shape);
EXPECT_EQ(cond_ret_type, PrimitiveType::i32);
}

TEST(FrontendTypeInference, GlobalPtr_Field) {
auto prog = std::make_unique<Program>(Arch::x64);
auto func = []() {};
Expand Down
17 changes: 17 additions & 0 deletions tests/python/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,6 +1030,23 @@ def test():
test()


@test_utils.test(debug=True)
def test_ternary_op_cond_is_scalar():
@ti.kernel
def test():
x = ti.Vector([3, 3, 3])
y = ti.Vector([5, 5, 5])

for i in range(10):
z = ti.select(i % 2, x, y)
if i % 2 == 1:
assert z[0] == x[0] and z[1] == x[1] and z[2] == x[2]
else:
assert z[0] == y[0] and z[1] == y[1] and z[2] == y[2]

test()


@test_utils.test(debug=True)
def test_fill_op():
@ti.kernel
Expand Down

0 comments on commit f2a7d38

Please sign in to comment.