From f2a7d386d16ebda2311a41dc98b9d0dc8395b113 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=AD=94=E6=B3=95=E5=B0=91=E5=A5=B3=E8=B5=B5=E5=BF=97?= =?UTF-8?q?=E8=BE=89?= Date: Wed, 15 Feb 2023 17:16:25 +0800 Subject: [PATCH] [lang] Stop broadcasting scalar cond in select statements (#7344) Issue: #7240 ### Brief Summary 1. Stop broadcasting `cond` when it is a scalar in a `select` statement. 2. Since `config.real_matrix_scalarize` is currently set to `True`, we need to scalarize the `select`. I added another code path to do this, which seems ugly. 3. We can see the expected behavior when `config.real_matrix_scalarize` is unset. 4. Add python and c++ unit test for select stmts. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- taichi/ir/frontend_ir.cpp | 16 ++++--- taichi/transforms/scalarize.cpp | 43 ++++++++++++++++++- tests/cpp/ir/frontend_type_inference_test.cpp | 35 +++++++++++++++ tests/python/test_matrix.py | 17 ++++++++ 4 files changed, 102 insertions(+), 9 deletions(-) diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 2dfb9204a3a38..9e811e5f2eb21 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -421,8 +421,7 @@ static std::tuple unify_ternaryop_operands(const Expr &e1, return std::tuple(e1, e2, e3); } - return std::tuple(to_broadcast_tensor(e1, target_dtype), - to_broadcast_tensor(e2, target_dtype), + return std::tuple(e1, to_broadcast_tensor(e2, target_dtype), to_broadcast_tensor(e3, target_dtype)); } @@ -450,19 +449,22 @@ void TernaryOpExpression::type_check(const CompileConfig *config) { op2->ret_type->to_string(), op3->ret_type->to_string())); }; - if (op1_type->is() && op2_type->is() && - op3_type->is()) { + if (op2_type->is() && op3_type->is()) { // valid is_tensor = true; - if (op1_type->cast()->get_shape() != - op2_type->cast()->get_shape()) { + if (op1_type->is() && + op1_type->cast()->get_shape() != + op2_type->cast()->get_shape()) { is_valid = false; } if (op2_type->cast()->get_shape() != op3_type->cast()->get_shape()) { is_valid = false; } - op1_type = op1_type->cast()->get_element_type(); + + if (op1_type->is()) { + op1_type = op1_type->cast()->get_element_type(); + } op2_type = op2_type->cast()->get_element_type(); op3_type = op3_type->cast()->get_element_type(); diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp index 5306c3e51ca04..42305cf394f25 100644 --- a/taichi/transforms/scalarize.cpp +++ b/taichi/transforms/scalarize.cpp @@ -443,8 +443,7 @@ class Scalarize : public BasicStmtVisitor { auto cond_dtype = stmt->op1->ret_type; auto op2_dtype = stmt->op2->ret_type; auto op3_dtype = stmt->op3->ret_type; - if (cond_dtype->is() || op2_dtype->is() || - op3_dtype->is()) { + if (cond_dtype->is()) { // Make sure broadcasting has been correctly applied by // TernaryOpExpression::type_check(). TI_ASSERT(cond_dtype->is() && op2_dtype->is() && @@ -490,6 +489,46 @@ class Scalarize : public BasicStmtVisitor { immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get()); delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt)); + delayed_modifier_.erase(stmt); + } else if (cond_dtype->is() && + (op2_dtype->is() || op3_dtype->is())) { + TI_ASSERT(cond_dtype->is() && + op2_dtype->is() && op3_dtype->is()); + TI_ASSERT(op2_dtype.get_shape() == op3_dtype.get_shape()); + // Scalarization for LoadStmt should have already replaced all operands + // to MatrixInitStmt. + TI_ASSERT(stmt->op2->is()); + TI_ASSERT(stmt->op3->is()); + + Stmt *cond_val = stmt->op1; + + auto op2_matrix_init_stmt = stmt->op2->cast(); + std::vector op2_vals = op2_matrix_init_stmt->values; + + auto op3_matrix_init_stmt = stmt->op3->cast(); + std::vector op3_vals = op3_matrix_init_stmt->values; + + TI_ASSERT(op2_vals.size() == op3_vals.size()); + + size_t num_elements = op2_vals.size(); + auto primitive_type = stmt->ret_type.get_element_type(); + std::vector matrix_init_values; + for (size_t i = 0; i < num_elements; i++) { + auto ternary_stmt = std::make_unique( + stmt->op_type, cond_val, op2_vals[i], op3_vals[i]); + matrix_init_values.push_back(ternary_stmt.get()); + ternary_stmt->ret_type = primitive_type; + + delayed_modifier_.insert_before(stmt, std::move(ternary_stmt)); + } + + auto matrix_init_stmt = + std::make_unique(matrix_init_values); + matrix_init_stmt->ret_type = stmt->ret_type; + + immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get()); + delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt)); + delayed_modifier_.erase(stmt); } } diff --git a/tests/cpp/ir/frontend_type_inference_test.cpp b/tests/cpp/ir/frontend_type_inference_test.cpp index fe5a7bb284774..b81cc2d29e955 100644 --- a/tests/cpp/ir/frontend_type_inference_test.cpp +++ b/tests/cpp/ir/frontend_type_inference_test.cpp @@ -84,6 +84,41 @@ TEST(FrontendTypeInference, TernaryOp) { EXPECT_EQ(ternary_f32->ret_type, PrimitiveType::f32); } +TEST(FrontendTypeInference, TernaryOp_NoBroadcast) { + auto cond = value(42); + cond->type_check(nullptr); + EXPECT_EQ(cond->ret_type, PrimitiveType::i32); + + auto const_3 = Expr::make(3); + auto const_5 = Expr::make(5); + std::vector shape = {3, 1}; + + std::vector op2_element = {const_3, const_3, const_3}; + std::vector op3_element = {const_5, const_5, const_5}; + + auto op2 = + Expr::make(op2_element, shape, PrimitiveType::i32); + op2->type_check(nullptr); + auto op3 = + Expr::make(op2_element, shape, PrimitiveType::i32); + op3->type_check(nullptr); + + auto ternary = + Expr::make(TernaryOpType::select, cond, op2, op3); + ternary->type_check(nullptr); + + auto ternary_expr = ternary.cast(); + auto cond_ret_type = ternary_expr->op1->ret_type; + auto op2_ret_type = ternary_expr->op2->ret_type; + auto op3_ret_type = ternary_expr->op3->ret_type; + + EXPECT_TRUE(op2_ret_type->is() && + op2_ret_type->cast()->get_shape() == shape); + EXPECT_TRUE(op3_ret_type->is() && + op3_ret_type->cast()->get_shape() == shape); + EXPECT_EQ(cond_ret_type, PrimitiveType::i32); +} + TEST(FrontendTypeInference, GlobalPtr_Field) { auto prog = std::make_unique(Arch::x64); auto func = []() {}; diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py index 1cc9da49bdf9b..24e752bc52888 100644 --- a/tests/python/test_matrix.py +++ b/tests/python/test_matrix.py @@ -1030,6 +1030,23 @@ def test(): test() +@test_utils.test(debug=True) +def test_ternary_op_cond_is_scalar(): + @ti.kernel + def test(): + x = ti.Vector([3, 3, 3]) + y = ti.Vector([5, 5, 5]) + + for i in range(10): + z = ti.select(i % 2, x, y) + if i % 2 == 1: + assert z[0] == x[0] and z[1] == x[1] and z[2] == x[2] + else: + assert z[0] == y[0] and z[1] == y[1] and z[2] == y[2] + + test() + + @test_utils.test(debug=True) def test_fill_op(): @ti.kernel