Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[lang] Stop broadcasting scalar cond in select statements #7344

Merged
merged 8 commits into from
Feb 15, 2023
16 changes: 9 additions & 7 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -421,8 +421,7 @@ static std::tuple<Expr, Expr, Expr> unify_ternaryop_operands(const Expr &e1,
return std::tuple(e1, e2, e3);
}

return std::tuple(to_broadcast_tensor(e1, target_dtype),
to_broadcast_tensor(e2, target_dtype),
return std::tuple(e1, to_broadcast_tensor(e2, target_dtype),
to_broadcast_tensor(e3, target_dtype));
}

Expand Down Expand Up @@ -450,19 +449,22 @@ void TernaryOpExpression::type_check(const CompileConfig *config) {
op2->ret_type->to_string(), op3->ret_type->to_string()));
};

if (op1_type->is<TensorType>() && op2_type->is<TensorType>() &&
op3_type->is<TensorType>()) {
if (op2_type->is<TensorType>() && op3_type->is<TensorType>()) {
// valid
is_tensor = true;
if (op1_type->cast<TensorType>()->get_shape() !=
op2_type->cast<TensorType>()->get_shape()) {
if (op1_type->is<TensorType>() &&
op1_type->cast<TensorType>()->get_shape() !=
op2_type->cast<TensorType>()->get_shape()) {
is_valid = false;
}
if (op2_type->cast<TensorType>()->get_shape() !=
op3_type->cast<TensorType>()->get_shape()) {
is_valid = false;
}
op1_type = op1_type->cast<TensorType>()->get_element_type();

if (op1_type->is<TensorType>()) {
op1_type = op1_type->cast<TensorType>()->get_element_type();
}
op2_type = op2_type->cast<TensorType>()->get_element_type();
op3_type = op3_type->cast<TensorType>()->get_element_type();

Expand Down
43 changes: 41 additions & 2 deletions taichi/transforms/scalarize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,8 +443,7 @@ class Scalarize : public BasicStmtVisitor {
auto cond_dtype = stmt->op1->ret_type;
auto op2_dtype = stmt->op2->ret_type;
auto op3_dtype = stmt->op3->ret_type;
if (cond_dtype->is<TensorType>() || op2_dtype->is<TensorType>() ||
op3_dtype->is<TensorType>()) {
if (cond_dtype->is<TensorType>()) {
// Make sure broadcasting has been correctly applied by
// TernaryOpExpression::type_check().
TI_ASSERT(cond_dtype->is<TensorType>() && op2_dtype->is<TensorType>() &&
Expand Down Expand Up @@ -490,6 +489,46 @@ class Scalarize : public BasicStmtVisitor {
immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get());
delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt));

delayed_modifier_.erase(stmt);
} else if (cond_dtype->is<PrimitiveType>() &&
(op2_dtype->is<TensorType>() || op3_dtype->is<TensorType>())) {
TI_ASSERT(cond_dtype->is<PrimitiveType>() &&
op2_dtype->is<TensorType>() && op3_dtype->is<TensorType>());
TI_ASSERT(op2_dtype.get_shape() == op3_dtype.get_shape());
// Scalarization for LoadStmt should have already replaced all operands
// with MatrixInitStmt.
TI_ASSERT(stmt->op2->is<MatrixInitStmt>());
TI_ASSERT(stmt->op3->is<MatrixInitStmt>());

Stmt *cond_val = stmt->op1;

auto op2_matrix_init_stmt = stmt->op2->cast<MatrixInitStmt>();
std::vector<Stmt *> op2_vals = op2_matrix_init_stmt->values;

auto op3_matrix_init_stmt = stmt->op3->cast<MatrixInitStmt>();
std::vector<Stmt *> op3_vals = op3_matrix_init_stmt->values;

TI_ASSERT(op2_vals.size() == op3_vals.size());

size_t num_elements = op2_vals.size();
auto primitive_type = stmt->ret_type.get_element_type();
std::vector<Stmt *> matrix_init_values;
for (size_t i = 0; i < num_elements; i++) {
auto ternary_stmt = std::make_unique<TernaryOpStmt>(
stmt->op_type, cond_val, op2_vals[i], op3_vals[i]);
matrix_init_values.push_back(ternary_stmt.get());
ternary_stmt->ret_type = primitive_type;

delayed_modifier_.insert_before(stmt, std::move(ternary_stmt));
}

auto matrix_init_stmt =
std::make_unique<MatrixInitStmt>(matrix_init_values);
matrix_init_stmt->ret_type = stmt->ret_type;

immediate_modifier_.replace_usages_with(stmt, matrix_init_stmt.get());
delayed_modifier_.insert_before(stmt, std::move(matrix_init_stmt));

delayed_modifier_.erase(stmt);
}
}
Expand Down
35 changes: 35 additions & 0 deletions tests/cpp/ir/frontend_type_inference_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,41 @@ TEST(FrontendTypeInference, TernaryOp) {
EXPECT_EQ(ternary_f32->ret_type, PrimitiveType::f32);
}

// Checks that type-checking a `select` with a scalar condition and tensor
// branches keeps the condition as a scalar: after `type_check`, op2 and op3
// must still be tensors of the original shape and op1 must still be i32.
TEST(FrontendTypeInference, TernaryOp_NoBroadcast) {
  auto cond = value<int32>(42);
  cond->type_check(nullptr);
  EXPECT_EQ(cond->ret_type, PrimitiveType::i32);

  auto const_3 = Expr::make<ConstExpression, int32>(3);
  auto const_5 = Expr::make<ConstExpression, int32>(5);
  std::vector<int> shape = {3, 1};

  std::vector<Expr> op2_element = {const_3, const_3, const_3};
  std::vector<Expr> op3_element = {const_5, const_5, const_5};

  auto op2 =
      Expr::make<MatrixExpression>(op2_element, shape, PrimitiveType::i32);
  op2->type_check(nullptr);
  // Build op3 from its own element list so the two branches are distinct
  // operands (previously op2_element was reused and op3_element was unused).
  auto op3 =
      Expr::make<MatrixExpression>(op3_element, shape, PrimitiveType::i32);
  op3->type_check(nullptr);

  auto ternary =
      Expr::make<TernaryOpExpression>(TernaryOpType::select, cond, op2, op3);
  ternary->type_check(nullptr);

  // Inspect the operands as stored on the expression after type checking.
  auto ternary_expr = ternary.cast<TernaryOpExpression>();
  auto cond_ret_type = ternary_expr->op1->ret_type;
  auto op2_ret_type = ternary_expr->op2->ret_type;
  auto op3_ret_type = ternary_expr->op3->ret_type;

  // The `is<TensorType>()` check guards the cast on the right-hand side.
  EXPECT_TRUE(op2_ret_type->is<TensorType>() &&
              op2_ret_type->cast<TensorType>()->get_shape() == shape);
  EXPECT_TRUE(op3_ret_type->is<TensorType>() &&
              op3_ret_type->cast<TensorType>()->get_shape() == shape);
  // The condition must not have been turned into a tensor.
  EXPECT_EQ(cond_ret_type, PrimitiveType::i32);
}

TEST(FrontendTypeInference, GlobalPtr_Field) {
auto prog = std::make_unique<Program>(Arch::x64);
auto func = []() {};
Expand Down
17 changes: 17 additions & 0 deletions tests/python/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,6 +1030,23 @@ def test():
test()


@test_utils.test(debug=True)
def test_ternary_op_cond_is_scalar():
    # ti.select is called with a scalar condition and two 3-element vectors;
    # the result must equal the first vector when the condition is 1 and the
    # second vector otherwise.
    @ti.kernel
    def test():
        when_true = ti.Vector([3, 3, 3])
        when_false = ti.Vector([5, 5, 5])

        for i in range(10):
            picked = ti.select(i % 2, when_true, when_false)
            if i % 2 == 1:
                # Odd index: condition is non-zero, expect the first operand.
                assert (picked[0] == when_true[0]
                        and picked[1] == when_true[1]
                        and picked[2] == when_true[2])
            else:
                # Even index: condition is zero, expect the second operand.
                assert (picked[0] == when_false[0]
                        and picked[1] == when_false[1]
                        and picked[2] == when_false[2])

    test()


@test_utils.test(debug=True)
def test_fill_op():
@ti.kernel
Expand Down