diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index 20bbb0533c841..356cc33342dfb 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -30,7 +30,7 @@ namespace irpass { void re_id(IRNode *root); void flag_access(IRNode *root); void eliminate_immutable_local_vars(IRNode *root); -void scalarize(IRNode *root, const CompileConfig &config); +void scalarize(IRNode *root); void lower_matrix_ptr(IRNode *root); bool die(IRNode *root); bool simplify(IRNode *root, const CompileConfig &config); diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 76a33b6c22dc8..d2c3c191a387b 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -56,7 +56,7 @@ void compile_to_offloads(IRNode *ir, print("Immutable local vars eliminated"); if (config.real_matrix_scalarize) { - irpass::scalarize(ir, config); + irpass::scalarize(ir); // Remove redundant MatrixInitStmt inserted during scalarization irpass::die(ir); @@ -342,7 +342,7 @@ void compile_function(IRNode *ir, } if (config.real_matrix_scalarize) { - irpass::scalarize(ir, config); + irpass::scalarize(ir); // Remove redundant MatrixInitStmt inserted during scalarization irpass::die(ir); diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp index 9c925b187c488..02755190ce10e 100644 --- a/taichi/transforms/scalarize.cpp +++ b/taichi/transforms/scalarize.cpp @@ -740,7 +740,7 @@ class ExtractLocalPointers : public BasicStmtVisitor { namespace irpass { -void scalarize(IRNode *root, const CompileConfig &config) { +void scalarize(IRNode *root) { TI_AUTO_PROF; Scalarize scalarize_pass(root); auto scalarizable_allocas = GatherScalarizableLocalPointers::run(root); diff --git a/tests/cpp/transforms/scalarize_test.cpp b/tests/cpp/transforms/scalarize_test.cpp index 3ac66de281054..058320a053406 100644 --- a/tests/cpp/transforms/scalarize_test.cpp +++ b/tests/cpp/transforms/scalarize_test.cpp @@ -44,7 +44,7 @@ TEST(Scalarize, ScalarizeGlobalStore) { block->push_back(dest_stmt, matrix_init_stmt); - irpass::scalarize(block.get(), test_prog.prog()->this_thread_config()); + irpass::scalarize(block.get()); irpass::lower_matrix_ptr(block.get()); irpass::die(block.get()); @@ -100,7 +100,7 @@ TEST(Scalarize, ScalarizeGlobalLoad) { // Without this GlobalStoreStmt, nothing survives irpass::die() block->push_back(src_stmt, load_stmt); - irpass::scalarize(block.get(), test_prog.prog()->this_thread_config()); + irpass::scalarize(block.get()); irpass::lower_matrix_ptr(block.get()); irpass::die(block.get()); @@ -160,7 +160,7 @@ TEST(Scalarize, ScalarizeLocalStore) { // LocalStoreStmt survives irpass::die() block->push_back(dest_stmt, matrix_init_stmt); - irpass::scalarize(block.get(), test_prog.prog()->this_thread_config()); + irpass::scalarize(block.get()); irpass::die(block.get()); EXPECT_EQ(block->size(), 2 /*const*/ + 4 /*alloca*/ + 4 /*store*/); @@ -207,7 +207,7 @@ TEST(Scalarize, ScalarizeLocalLoad) { // Without this GlobalStoreStmt, nothing survives irpass::die() block->push_back(src_stmt, load_stmt); - irpass::scalarize(block.get(), test_prog.prog()->this_thread_config()); + irpass::scalarize(block.get()); irpass::die(block.get()); EXPECT_EQ(block->size(), 4 /*alloca*/ + 4 /*load*/ + 4 /*store*/);