Skip to content

Commit

Permalink
[Bug] Fix ret_type and cast_type of UnaryOpStmt in Scalarize (#7082)
Browse files Browse the repository at this point in the history
Issue: fix #6978

### Brief Summary

After #7068, The `taichi_ngp` example reports `RuntimeError:
[type_factory.cpp:promoted_type@222] Assertion failure:
a->is<TensorType>() && b->is<TensorType>()`. The root cause is that
`taichi_ngp` uses `atomic_sub` on a value with type different from the
destination, which will generate `neg` + `cast`. However, the `ret_type`
and `cast_type` of `UnaryOpStmt` produced by `Scalarize` aren't set
correctly. This PR fixes the problem.
  • Loading branch information
strongoier authored Jan 9, 2023
1 parent 734b483 commit b523daf
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions taichi/transforms/scalarize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,12 +168,12 @@ class Scalarize : public BasicStmtVisitor {

std::vector<Stmt *> matrix_init_values;
int num_elements = operand_tensor_type->get_num_elements();
auto primitive_type = operand_tensor_type->get_element_type();
auto primitive_type = stmt->ret_type.get_element_type();
for (size_t i = 0; i < num_elements; i++) {
auto unary_stmt = std::make_unique<UnaryOpStmt>(
stmt->op_type, operand_matrix_init_stmt->values[i]);
if (stmt->is_cast()) {
unary_stmt->cast_type = stmt->cast_type;
unary_stmt->cast_type = stmt->cast_type.get_element_type();
}
unary_stmt->ret_type = primitive_type;
matrix_init_values.push_back(unary_stmt.get());
Expand Down
6 changes: 3 additions & 3 deletions tests/python/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1048,8 +1048,8 @@ def test_atomic_op_scalarize():
@ti.func
def func(x: ti.template()):
x[0] = [1., 2., 3.]
tmp = ti.Vector([3., 2., 1.])
z = ti.atomic_add(x[0], tmp)
tmp = ti.Vector([3, 2, 1])
z = ti.atomic_sub(x[0], tmp)
assert z[0] == 1.
assert z[1] == 2.
assert z[2] == 3.
Expand All @@ -1062,7 +1062,7 @@ def func(x: ti.template()):
assert g[2] == 1.

def verify(x):
assert (x[0] == [4., 4., 4.]).all()
assert (x[0] == [-2., 0., 2.]).all()
assert (x[1] == [3., 3., 3.]).all()

field = ti.Vector.field(n=3, dtype=ti.f32, shape=10)
Expand Down

0 comments on commit b523daf

Please sign in to comment.