diff --git a/taichi/backends/cc/codegen_cc.cpp b/taichi/backends/cc/codegen_cc.cpp index f0fbf2150a43f..565d8e70926dc 100644 --- a/taichi/backends/cc/codegen_cc.cpp +++ b/taichi/backends/cc/codegen_cc.cpp @@ -533,6 +533,9 @@ class CCTransformer : public IRVisitor { void visit(AdStackAllocaStmt *stmt) override { TI_ASSERT(stmt->width() == 1); + TI_ASSERT_INFO( + stmt->max_size > 0, + "Adaptive autodiff stack's size should have been determined."); const auto &var_name = stmt->raw_name(); emit("Ti_u8 {}[{}];", var_name, stmt->size_in_bytes() + sizeof(uint32_t)); diff --git a/taichi/backends/metal/codegen_metal.cpp b/taichi/backends/metal/codegen_metal.cpp index d0c023a9cf56e..fa8403c791d4c 100644 --- a/taichi/backends/metal/codegen_metal.cpp +++ b/taichi/backends/metal/codegen_metal.cpp @@ -737,6 +737,9 @@ class KernelCodegenImpl : public IRVisitor { void visit(AdStackAllocaStmt *stmt) override { TI_ASSERT(stmt->width() == 1); + TI_ASSERT_INFO( + stmt->max_size > 0, + "Adaptive autodiff stack's size should have been determined."); const auto &var_name = stmt->raw_name(); emit("byte {}[{}];", var_name, stmt->size_in_bytes()); diff --git a/taichi/codegen/codegen_llvm.cpp b/taichi/codegen/codegen_llvm.cpp index 12cf050d2e3dc..7b6c6e5b6f5ed 100644 --- a/taichi/codegen/codegen_llvm.cpp +++ b/taichi/codegen/codegen_llvm.cpp @@ -1870,6 +1870,8 @@ void CodeGenLLVM::visit(InternalFuncStmt *stmt) { void CodeGenLLVM::visit(AdStackAllocaStmt *stmt) { TI_ASSERT(stmt->width() == 1); + TI_ASSERT_INFO(stmt->max_size > 0, + "Adaptive autodiff stack's size should have been determined."); auto type = llvm::ArrayType::get(llvm::Type::getInt8Ty(*llvm_context), stmt->size_in_bytes()); auto alloca = create_entry_block_alloca(type, sizeof(int64)); diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index c7dda4ff4d973..4852003e82158 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -1,6 +1,7 @@ #include "taichi/ir/control_flow_graph.h" #include +#include #include "taichi/ir/analysis.h" #include "taichi/ir/statements.h" @@ -879,4 +880,164 @@ std::unordered_set ControlFlowGraph::gather_loaded_snodes() { return snodes; } +void ControlFlowGraph::determine_ad_stack_size(int default_ad_stack_size) { + /** + * Determine all adaptive AD-stacks' necessary size using the Bellman-Ford + * algorithm. When there is a positive loop (#pushes > #pops in a loop) + * for an AD-stack, we cannot determine the size of the AD-stack, and + * |default_ad_stack_size| is used. The time complexity is + * O(num_statements + num_stacks * num_edges * num_nodes). + */ + const int num_nodes = size(); + + // max_increased_size[i][j] is the maximum number of (pushes - pops) of + // stack |i| among all prefixes of the CFGNode |j|. + std::unordered_map> max_increased_size; + + // increased_size[i][j] is the number of (pushes - pops) of stack |i| in + // the CFGNode |j|. + std::unordered_map> increased_size; + + std::unordered_map node_ids; + std::unordered_set all_stacks; + std::unordered_set indeterminable_stacks; + + for (int i = 0; i < num_nodes; i++) + node_ids[nodes[i].get()] = i; + + for (int i = 0; i < num_nodes; i++) { + for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) { + Stmt *stmt = nodes[i]->block->statements[j].get(); + if (auto *stack = stmt->cast()) { + all_stacks.insert(stack); + max_increased_size.insert( + std::make_pair(stack, std::vector(num_nodes, 0))); + increased_size.insert( + std::make_pair(stack, std::vector(num_nodes, 0))); + } + } + } + + // For each basic block we compute the increase of stack size. This is a + // pre-processing step for the next maximum stack size determining algorithm. + for (int i = 0; i < num_nodes; i++) { + for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) { + Stmt *stmt = nodes[i]->block->statements[j].get(); + if (auto *stack_push = stmt->cast()) { + auto *stack = stack_push->stack->as(); + if (stack->max_size == 0 /*adaptive*/) { + increased_size[stack][i]++; + if (increased_size[stack][i] > max_increased_size[stack][i]) { + max_increased_size[stack][i] = increased_size[stack][i]; + } + } + } else if (auto *stack_pop = stmt->cast()) { + auto *stack = stack_pop->stack->as(); + if (stack->max_size == 0 /*adaptive*/) { + increased_size[stack][i]--; + } + } + } + } + + // The maximum stack size determining algorithm -- run the Bellman-Ford + // algorithm on each AD-stack separately. + for (auto *stack : all_stacks) { + // The maximum size of |stack| among all control flows starting at the + // beginning of the IR. + int max_size = 0; + + // max_size_at_node_begin[j] is the maximum size of |stack| among + // all control flows starting at the beginning of the IR and ending at the + // beginning of the CFGNode |j|. Initialize this array to -1 to make sure + // that the first iteration of the Bellman-Ford algorithm fully updates + // this array. + std::vector max_size_at_node_begin(num_nodes, -1); + + // The queue for the Bellman-Ford algorithm. + std::queue to_visit; + + // An optimization for the Bellman-Ford algorithm. + std::vector in_queue(num_nodes); + + // An array for detecting positive loop in the Bellman-Ford algorithm. + std::vector times_pushed_in_queue(num_nodes, 0); + + max_size_at_node_begin[start_node] = 0; + to_visit.push(start_node); + in_queue[start_node] = true; + times_pushed_in_queue[start_node]++; + + bool has_positive_loop = false; + + // The Bellman-Ford algorithm. + while (!to_visit.empty()) { + int node_id = to_visit.front(); + to_visit.pop(); + in_queue[node_id] = false; + CFGNode *now = nodes[node_id].get(); + + // Inside this CFGNode -- update the answer |max_size| + const auto max_size_inside_this_node = max_increased_size[stack][node_id]; + const auto current_max_size = + max_size_at_node_begin[node_id] + max_size_inside_this_node; + if (current_max_size > max_size) { + max_size = current_max_size; + } + // At the end of this CFGNode -- update the state + // |max_size_at_node_begin| of other CFGNodes + const auto increase_in_this_node = increased_size[stack][node_id]; + const auto current_size = + max_size_at_node_begin[node_id] + increase_in_this_node; + for (auto *next_node : now->next) { + int next_node_id = node_ids[next_node]; + if (current_size > max_size_at_node_begin[next_node_id]) { + max_size_at_node_begin[next_node_id] = current_size; + if (!in_queue[next_node_id]) { + if (times_pushed_in_queue[next_node_id] <= num_nodes) { + to_visit.push(next_node_id); + in_queue[next_node_id] = true; + times_pushed_in_queue[next_node_id]++; + } else { + // A positive loop is found because a node is going to be pushed + // into the queue the (num_nodes + 1)-th time. + has_positive_loop = true; + break; + } + } + } + } + if (has_positive_loop) { + break; + } + } + + if (has_positive_loop) { + stack->max_size = default_ad_stack_size; + indeterminable_stacks.insert(stack); + } else { + // Since we use |max_size| == 0 for adaptive sizes, we do not want stacks + // with maximum capacity indeed equal to 0. + TI_WARN_IF(max_size == 0, + "Unused autodiff stack {} should have been eliminated.", + stack->name()); + stack->max_size = max_size; + } + } + + // Print a debug message if we have indeterminable AD-stacks' sizes. + if (!indeterminable_stacks.empty()) { + std::vector indeterminable_stacks_name; + indeterminable_stacks_name.reserve(indeterminable_stacks.size()); + for (auto &stack : indeterminable_stacks) { + indeterminable_stacks_name.push_back(stack->name()); + } + TI_DEBUG( + "Unable to determine the necessary size for autodiff stacks [{}]. " + "Use " + "configured size (CompileConfig::default_ad_stack_size) {} instead.", + fmt::join(indeterminable_stacks_name, ", "), default_ad_stack_size); + } +} + TLANG_NAMESPACE_END diff --git a/taichi/ir/control_flow_graph.h b/taichi/ir/control_flow_graph.h index 8bf6cd3773aea..09f5b3f9ecbb7 100644 --- a/taichi/ir/control_flow_graph.h +++ b/taichi/ir/control_flow_graph.h @@ -158,6 +158,13 @@ class ControlFlowGraph { * task. */ std::unordered_set gather_loaded_snodes(); + + /** + * Determine all adaptive AD-stacks' necessary size. + * @param default_ad_stack_size The default AD-stack's size when we are + * unable to determine some AD-stack's size. + */ + void determine_ad_stack_size(int default_ad_stack_size); }; TLANG_NAMESPACE_END diff --git a/taichi/ir/ir_builder.cpp b/taichi/ir/ir_builder.cpp index f3a06fd9ff180..b0332b9477bfc 100644 --- a/taichi/ir/ir_builder.cpp +++ b/taichi/ir/ir_builder.cpp @@ -418,4 +418,31 @@ ExternalPtrStmt *IRBuilder::create_external_ptr( return insert(Stmt::make_typed(ptr, indices)); } +AdStackAllocaStmt *IRBuilder::create_ad_stack(const DataType &dt, + std::size_t max_size) { + return insert(Stmt::make_typed(dt, max_size)); +} + +void IRBuilder::ad_stack_push(AdStackAllocaStmt *stack, Stmt *val) { + insert(Stmt::make_typed(stack, val)); +} + +void IRBuilder::ad_stack_pop(AdStackAllocaStmt *stack) { + insert(Stmt::make_typed(stack)); +} + +AdStackLoadTopStmt *IRBuilder::ad_stack_load_top(AdStackAllocaStmt *stack) { + return insert(Stmt::make_typed(stack)); +} + +AdStackLoadTopAdjStmt *IRBuilder::ad_stack_load_top_adjoint( + AdStackAllocaStmt *stack) { + return insert(Stmt::make_typed(stack)); +} + +void IRBuilder::ad_stack_accumulate_adjoint(AdStackAllocaStmt *stack, + Stmt *val) { + insert(Stmt::make_typed(stack, val)); +} + TLANG_NAMESPACE_END diff --git a/taichi/ir/ir_builder.h b/taichi/ir/ir_builder.h index 602b0ac382371..ca90b358a3760 100644 --- a/taichi/ir/ir_builder.h +++ b/taichi/ir/ir_builder.h @@ -248,6 +248,14 @@ class IRBuilder { } } + // Autodiff stack operations. + AdStackAllocaStmt *create_ad_stack(const DataType &dt, std::size_t max_size); + void ad_stack_push(AdStackAllocaStmt *stack, Stmt *val); + void ad_stack_pop(AdStackAllocaStmt *stack); + AdStackLoadTopStmt *ad_stack_load_top(AdStackAllocaStmt *stack); + AdStackLoadTopAdjStmt *ad_stack_load_top_adjoint(AdStackAllocaStmt *stack); + void ad_stack_accumulate_adjoint(AdStackAllocaStmt *stack, Stmt *val); + private: std::unique_ptr root_{nullptr}; InsertPoint insert_point_; diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index f54b11ab1c293..5de2b0c4bf32e 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -1222,7 +1222,7 @@ class InternalFuncStmt : public Stmt { class AdStackAllocaStmt : public Stmt { public: DataType dt; - std::size_t max_size; // TODO: 0 = adaptive + std::size_t max_size{0}; // 0 = adaptive AdStackAllocaStmt(const DataType &dt, std::size_t max_size) : dt(dt), max_size(max_size) { diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index 6b55a79dc4c67..5b6059916f23a 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -71,6 +71,13 @@ bool lower_access(IRNode *root, void auto_diff(IRNode *root, const CompileConfig &config, bool use_stack = false); +/** + * Determine all adaptive AD-stacks' size. This pass is idempotent, i.e., + * there are no side effects if called more than once or called when not needed. + * @return Whether the IR is modified, i.e., whether there exists adaptive + * AD-stacks before this pass. + */ +bool determine_ad_stack_size(IRNode *root, const CompileConfig &config); bool constant_fold(IRNode *root, const CompileConfig &config, const ConstantFoldPass::Args &args); @@ -124,6 +131,7 @@ void offload_to_executable(IRNode *ir, const CompileConfig &config, Kernel *kernel, bool verbose, + bool determine_ad_stack_size, bool lower_global_access, bool make_thread_local, bool make_block_local); diff --git a/taichi/program/async_engine.cpp b/taichi/program/async_engine.cpp index 2080149e447a4..7dcb4b37ad616 100644 --- a/taichi/program/async_engine.cpp +++ b/taichi/program/async_engine.cpp @@ -154,6 +154,7 @@ void ExecutionQueue::enqueue(const TaskLaunchRecord &ker) { auto ir = stmt; offload_to_executable( ir, config, kernel, /*verbose=*/false, + /*determine_ad_stack_size=*/true, /*lower_global_access=*/true, /*make_thread_local=*/true, /*make_block_local=*/ diff --git a/taichi/program/compile_config.cpp b/taichi/program/compile_config.cpp index 0a014dd2dd437..a06805f489f8b 100644 --- a/taichi/program/compile_config.cpp +++ b/taichi/program/compile_config.cpp @@ -45,8 +45,6 @@ CompileConfig::CompileConfig() { cpu_max_num_threads = std::thread::hardware_concurrency(); random_seed = 0; - ad_stack_size = 16; - // LLVM backend options: print_struct_llvm_ir = false; print_kernel_llvm_ir = false; diff --git a/taichi/program/compile_config.h b/taichi/program/compile_config.h index ba2a7784af40a..e42d4e26563b1 100644 --- a/taichi/program/compile_config.h +++ b/taichi/program/compile_config.h @@ -42,7 +42,10 @@ struct CompileConfig { int default_cpu_block_dim; int default_gpu_block_dim; int gpu_max_reg; - int ad_stack_size; + int ad_stack_size{0}; // 0 = adaptive + // The default size when the Taichi compiler is unable to automatically + // determine the autodiff stack size. + int default_ad_stack_size{32}; int saturating_grid_dim; int max_block_dim; diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 67bea20d2797b..367b7f2cb322c 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -145,6 +145,7 @@ void offload_to_executable(IRNode *ir, const CompileConfig &config, Kernel *kernel, bool verbose, + bool determine_ad_stack_size, bool lower_global_access, bool make_thread_local, bool make_block_local) { @@ -224,6 +225,11 @@ void offload_to_executable(IRNode *ir, irpass::full_simplify(ir, config, {lower_global_access, kernel->program}); print("Simplified IV"); + if (determine_ad_stack_size) { + irpass::determine_ad_stack_size(ir, config); + print("Autodiff stack size determined"); + } + if (is_extension_supported(config.arch, Extension::quant)) { irpass::optimize_bit_struct_stores(ir, config, amgr.get()); print("Bit struct stores optimized"); @@ -250,8 +256,10 @@ void compile_to_executable(IRNode *ir, compile_to_offloads(ir, config, kernel, verbose, vectorize, grad, ad_use_stack, start_from_ast); - offload_to_executable(ir, config, kernel, verbose, lower_global_access, - make_thread_local, make_block_local); + offload_to_executable(ir, config, kernel, verbose, + /*determine_ad_stack_size=*/grad && ad_use_stack, + lower_global_access, make_thread_local, + make_block_local); } void compile_inline_function(IRNode *ir, diff --git a/taichi/transforms/determine_ad_stack_size.cpp b/taichi/transforms/determine_ad_stack_size.cpp new file mode 100644 index 0000000000000..8083303e0be8b --- /dev/null +++ b/taichi/transforms/determine_ad_stack_size.cpp @@ -0,0 +1,33 @@ +#include "taichi/ir/analysis.h" +#include "taichi/ir/control_flow_graph.h" +#include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" +#include "taichi/ir/transforms.h" + +#include +#include + +namespace taichi { +namespace lang { + +namespace irpass { + +bool determine_ad_stack_size(IRNode *root, const CompileConfig &config) { + if (irpass::analysis::gather_statements(root, [&](Stmt *s) { + if (auto ad_stack = s->cast()) { + return ad_stack->max_size == 0; // adaptive + } + return false; + }).empty()) { + return false; // no AD-stacks with adaptive size + } + auto cfg = analysis::build_cfg(root); + cfg->simplify_graph(); + cfg->determine_ad_stack_size(config.default_ad_stack_size); + return true; +} + +} // namespace irpass + +} // namespace lang +} // namespace taichi diff --git a/tests/cpp/transforms/determine_ad_stack_size_test.cpp b/tests/cpp/transforms/determine_ad_stack_size_test.cpp new file mode 100644 index 0000000000000..69d14a06405fa --- /dev/null +++ b/tests/cpp/transforms/determine_ad_stack_size_test.cpp @@ -0,0 +1,187 @@ +#include "gtest/gtest.h" + +#include "taichi/ir/analysis.h" +#include "taichi/ir/statements.h" +#include "taichi/ir/ir_builder.h" +#include "taichi/ir/transforms.h" + +namespace taichi { +namespace lang { + +class DetermineAdStackSizeTest + : public ::testing::TestWithParam> { + protected: + void SetUp() override { + prog_ = std::make_unique(); + prog_->materialize_runtime(); + } + + std::unique_ptr prog_; +}; + +TEST_F(DetermineAdStackSizeTest, Basic) { + IRBuilder builder; + auto *stack = + builder.create_ad_stack(get_data_type(), 0 /*adaptive size*/); + builder.ad_stack_push(stack, builder.get_int32(1)); + builder.ad_stack_push(stack, builder.get_int32(2)); + builder.ad_stack_push(stack, builder.get_int32(3)); + builder.ad_stack_pop(stack); + builder.ad_stack_pop(stack); + builder.ad_stack_push(stack, builder.get_int32(4)); + builder.ad_stack_push(stack, builder.get_int32(5)); + builder.ad_stack_push(stack, builder.get_int32(6)); + // stack contains [1, 4, 5, 6] now + builder.ad_stack_pop(stack); + builder.ad_stack_pop(stack); + builder.ad_stack_push(stack, builder.get_int32(7)); + + auto *stack2 = + builder.create_ad_stack(get_data_type(), 0 /*adaptive size*/); + builder.ad_stack_push(stack2, builder.get_int32(8)); + + auto ir = builder.extract_ir(); + ASSERT_TRUE(ir->is()); + auto *ir_block = ir->as(); + irpass::type_check(ir_block, CompileConfig()); + + EXPECT_EQ(stack->max_size, 0); + EXPECT_EQ(stack2->max_size, 0); + irpass::determine_ad_stack_size(ir_block, CompileConfig()); + EXPECT_EQ(stack->max_size, 4); + EXPECT_EQ(stack2->max_size, 1); +} + +TEST_F(DetermineAdStackSizeTest, Loop) { + IRBuilder builder; + auto *stack = + builder.create_ad_stack(get_data_type(), 0 /*adaptive size*/); + auto *loop = builder.create_range_for(/*begin=*/builder.get_int32(0), + /*end=*/builder.get_int32(10)); + { + auto _ = builder.get_loop_guard(loop); + builder.ad_stack_push(stack, builder.get_int32(1)); + builder.ad_stack_pop(stack); + } + + auto ir = builder.extract_ir(); + ASSERT_TRUE(ir->is()); + auto *ir_block = ir->as(); + irpass::type_check(ir_block, CompileConfig()); + + EXPECT_EQ(stack->max_size, 0); + irpass::determine_ad_stack_size(ir_block, CompileConfig()); + EXPECT_EQ(stack->max_size, 1); +} + +TEST_F(DetermineAdStackSizeTest, LoopInfeasible) { + IRBuilder builder; + auto *stack = + builder.create_ad_stack(get_data_type(), 0 /*adaptive size*/); + auto *loop = builder.create_range_for(/*begin=*/builder.get_int32(0), + /*end=*/builder.get_int32(100)); + { + auto _ = builder.get_loop_guard(loop); + builder.ad_stack_push(stack, builder.get_int32(1)); + } + + auto ir = builder.extract_ir(); + ASSERT_TRUE(ir->is()); + auto *ir_block = ir->as(); + irpass::type_check(ir_block, CompileConfig()); + + CompileConfig config; + constexpr int kDefaultAdStackSize = 32; + config.default_ad_stack_size = kDefaultAdStackSize; + EXPECT_EQ(stack->max_size, 0); + // Should have a debug message here (unable to determine the necessary size + // for autodiff stacks). + irpass::determine_ad_stack_size(ir_block, config); + EXPECT_EQ(stack->max_size, kDefaultAdStackSize); +} + +TEST_P(DetermineAdStackSizeTest, If) { + constexpr int kCommonPushes = 1; + const int kTrueBranchPushes = std::get<0>(GetParam()); + const int kFalseBranchPushes = std::get<1>(GetParam()); + bool has_true_branch = (kTrueBranchPushes > 0); + bool has_false_branch = (kFalseBranchPushes > 0); + + IRBuilder builder; + auto *arg = builder.create_arg_load(0, get_data_type(), false); + auto *stack = + builder.create_ad_stack(get_data_type(), 0 /*adaptive size*/); + auto *if_stmt = builder.create_if(arg); + auto *one = builder.get_int32(1); + for (int i = 1; i <= kCommonPushes; i++) { + builder.ad_stack_push(stack, one); // Make sure the stack is not unused + } + if (has_true_branch) { + auto _ = builder.get_if_guard(if_stmt, true); + for (int i = 1; i <= kTrueBranchPushes; i++) { + builder.ad_stack_push(stack, one); + } + } + if (has_false_branch) { + auto _ = builder.get_if_guard(if_stmt, false); + for (int i = 1; i <= kFalseBranchPushes; i++) { + builder.ad_stack_push(stack, one); + } + } + + auto ir = builder.extract_ir(); + ASSERT_TRUE(ir->is()); + auto *ir_block = ir->as(); + irpass::type_check(ir_block, CompileConfig()); + EXPECT_EQ(irpass::analysis::count_statements(ir_block), + 4 /*arg_load, stack, if, one*/ + kCommonPushes + + has_true_branch * kTrueBranchPushes + + has_false_branch * kFalseBranchPushes); + + EXPECT_EQ(stack->max_size, 0); + irpass::determine_ad_stack_size(ir_block, CompileConfig()); + EXPECT_EQ(stack->max_size, + kCommonPushes + std::max(has_true_branch * kTrueBranchPushes, + has_false_branch * kFalseBranchPushes)); +} + +INSTANTIATE_TEST_SUITE_P( + Parameterized, + DetermineAdStackSizeTest, + testing::Combine(testing::Values(0, 3), testing::Values(0, 4)), + [](const testing::TestParamInfo + &info) { + return fmt::format("True{}_False{}", std::get<0>(info.param), + std::get<1>(info.param)); + }); + +TEST_F(DetermineAdStackSizeTest, EmptyNodes) { + IRBuilder builder; + auto *arg = builder.create_arg_load(0, get_data_type(), false); + auto *stack = + builder.create_ad_stack(get_data_type(), 0 /*adaptive size*/); + auto *one = builder.get_int32(1); + builder.ad_stack_push(stack, one); // stack contains [1] now + auto *if_stmt = builder.create_if(arg); + { + auto _ = builder.get_if_guard(if_stmt, true); + builder.get_int32(2); // avoid CFGNode being deleted + } + { + auto _ = builder.get_if_guard(if_stmt, false); + builder.get_int32(3); // avoid CFGNode being deleted + } + builder.ad_stack_push(stack, one); // stack contains [1, 1] now + + auto ir = builder.extract_ir(); + ASSERT_TRUE(ir->is()); + auto *ir_block = ir->as(); + irpass::type_check(ir_block, CompileConfig()); + + EXPECT_EQ(stack->max_size, 0); + irpass::determine_ad_stack_size(ir_block, CompileConfig()); + EXPECT_EQ(stack->max_size, 2); +} + +} // namespace lang +} // namespace taichi