diff --git a/src/pass/data_parallel.cc b/src/pass/data_parallel.cc index 5176ad63..7c807a05 100644 --- a/src/pass/data_parallel.cc +++ b/src/pass/data_parallel.cc @@ -266,14 +266,15 @@ struct DataParallel { // we should add a allreduce op after it. static Op op_allreduce = Op::Get("mnm.op._allreduce"); auto input_var = mnm::ir::MakeVar("allreduce_in", {}); + auto rank_list = MakeConstant(TupleValue::make(Array({}))); #if defined MNM_USE_NCCL && NCCL_VERSION_CODE >= 21000 bp_ell->vars[p2 - 1] = input_var; bp_ell->exprs[p2 - 1] = Tuple({bp_ell->vars[i]}); // Here we name the var as 'g'(global gradient), to help us identify it easier. bp_ell->vars[p2] = mnm::ir::MakeVar("g", {}); - bp_ell->exprs[p2] = - Call(op_allreduce, {bp_ell->vars[p2 - 1], MakeConstant(StringValue::make("avg"))}); + bp_ell->exprs[p2] = Call(op_allreduce, {bp_ell->vars[p2 - 1], + MakeConstant(StringValue::make("avg")), rank_list}); var_var_map.insert({bp_ell->vars[i], bp_ell->vars[p2]}); p2 -= 2; #else @@ -282,7 +283,8 @@ struct DataParallel { bp_ell->exprs[p2 - 2] = Tuple({bp_ell->vars[i]}); bp_ell->vars[p2 - 1] = mnm::ir::MakeVar("g_sum", {}); bp_ell->exprs[p2 - 1] = - Call(op_allreduce, {bp_ell->vars[p2 - 2], MakeConstant(StringValue::make("sum"))}); + Call(op_allreduce, + {bp_ell->vars[p2 - 2], MakeConstant(StringValue::make("sum")), rank_list}); bp_ell->vars[p2] = mnm::ir::MakeVar("g", {}); auto tt = bp_ell->vars[i]->checked_type().as(); if (tt->dtype.code() == kDLFloat) { diff --git a/tests/python/pass/test_pass_data_parallel_schedule.py b/tests/python/pass/test_pass_data_parallel_schedule.py index 5c156a86..2195e75e 100644 --- a/tests/python/pass/test_pass_data_parallel_schedule.py +++ b/tests/python/pass/test_pass_data_parallel_schedule.py @@ -66,8 +66,8 @@ def fifo_expected(self): let %v2 = mnm.op.atan(%v); let %v3 = (%v1,); let %v4 = (%v2,); - let %v5 = mnm.op._allreduce(%v3, str"sum"); - let %v6 = mnm.op._allreduce(%v4, str"sum"); + let %v5 = mnm.op._allreduce(%v3, str"sum", TupleValue([])); + let %v6 = mnm.op._allreduce(%v4, str"sum", TupleValue([])); let %v7 = mnm.op.multiply(%v5, %c); let %v8 = mnm.op.multiply(%v6, %c); let %v9 = (%v7, %v8); @@ -87,9 +87,9 @@ def fifo_expected(self): x_2b = builder.make_tuple((x_1b,)) x_2a1 = builder.const("sum") - x_3a = builder.call("_allreduce", [x_2a, x_2a1]) + x_3a = builder.call("_allreduce", [x_2a, x_2a1, builder.const([])]) x_2a2 = builder.const("sum") - x_3b = builder.call("_allreduce", [x_2b, x_2a2]) + x_3b = builder.call("_allreduce", [x_2b, x_2a2, builder.const([])]) x_4a = builder.call("multiply", [x_3a, c]) x_4b = builder.call("multiply", [x_3b, c]) @@ -135,10 +135,10 @@ def fifo_expected(self): let %v3 = mnm.op.atan(%v1); let %v4 = (%v2,); let %v5 = mnm.op.atan(%v3); - let %v6 = mnm.op._allreduce(%v4, str"sum"); + let %v6 = mnm.op._allreduce(%v4, str"sum", TupleValue([])); let %v7 = mnm.op.atan(%v5); let %v8 = (%v7,); - let %v9 = mnm.op._allreduce(%v8, str"sum"); + let %v9 = mnm.op._allreduce(%v8, str"sum", TupleValue([])); let %v10 = mnm.op.multiply(%v6, %c); let %v11 = mnm.op.multiply(%v9, %c); let %v12 = (%v11, %v10); @@ -162,13 +162,13 @@ def version1(): x_3a = builder.call("atan", [x_2a]) x_2b1 = builder.const("sum") - x_3b = builder.call("_allreduce", [x_2b, x_2b1]) + x_3b = builder.call("_allreduce", [x_2b, x_2b1, builder.const([])]) x_4a = builder.call("atan", [x_3a]) # branch b is delayed since mul depdends on allreduce x_5a = builder.make_tuple((x_4a,)) x_5a1 = builder.const("sum") - x_6a = builder.call("_allreduce", [x_5a, x_5a1]) + x_6a = builder.call("_allreduce", [x_5a, x_5a1, builder.const([])]) # now launch update ops, branch b first x_4b = builder.call("multiply", [x_3b, c]) @@ -194,14 +194,14 @@ def version2(): x_2a = builder.call("atan", [x_1a]) x_2b1 = builder.const("sum") - x_3b = builder.call("_allreduce", [x_2b, x_2b1]) + x_3b = builder.call("_allreduce", [x_2b, x_2b1, builder.const([])]) x_3a = builder.call("atan", [x_2a]) # branch b is delayed since mul depdends on allreduce x_4a = builder.call("atan", [x_3a]) x_5a = builder.make_tuple((x_4a,)) x_5a1 = builder.const("sum") - x_6a = builder.call("_allreduce", [x_5a, x_5a1]) + x_6a = builder.call("_allreduce", [x_5a, x_5a1, builder.const([])]) # now launch update ops, branch b first x_4b = builder.call("multiply", [x_3b, c]) @@ -241,7 +241,7 @@ def fifo_expected(self): let %v = mnm.op.atan(%x); let %v1 = (%v,); let %v2 = mnm.op.atan(%v); - let %v3 = mnm.op._allreduce(%v1, str"sum"); + let %v3 = mnm.op._allreduce(%v1, str"sum", TupleValue([])); let %v4 = mnm.op.atan(%v2); let %v5 = mnm.op.atan(%v3); let %v6 = mnm.op.multiply(%v5, %v4); @@ -259,7 +259,7 @@ def version1(): a1_b = builder.call("atan", [a0]) a1_aii = builder.const("sum") - a1_a = builder.call("_allreduce", [a1_ai, a1_aii]) + a1_a = builder.call("_allreduce", [a1_ai, a1_aii, builder.const([])]) a2_b = builder.call("atan", [a1_b]) a2_a = builder.call("atan", [a1_a]) @@ -278,7 +278,7 @@ def version2(): a2_b = builder.call("atan", [a1_b]) a1_aii = builder.const("sum") - a1_a = builder.call("_allreduce", [a1_ai, a1_aii]) + a1_a = builder.call("_allreduce", [a1_ai, a1_aii, builder.const([])]) a2_a = builder.call("atan", [a1_a]) @@ -316,7 +316,7 @@ def fifo_expected(self): let %x_0 = mnm.op.atan(%x); let %x_1 = (%x_0,); let %x_2 = mnm.op.atan(%x_0); - let %x_3 = mnm.op._allreduce(%x_1, str"sum"); + let %x_3 = mnm.op._allreduce(%x_1, str"sum", TupleValue([])); let %x_4 = mnm.op.atan(%x_2); let %x_5 = mnm.op.atan(%x_4); let %x_6 = mnm.op.atan(%x_5); @@ -338,7 +338,7 @@ def version1(): a1_b = builder.call("atan", [a0]) a1_aii = builder.const("sum") - a1_a = builder.call("_allreduce", [a1_ai, a1_aii]) + a1_a = builder.call("_allreduce", [a1_ai, a1_aii, builder.const([])]) a2_b = builder.call("atan", [a1_b]) a3_b = builder.call("atan", [a2_b]) @@ -362,7 +362,7 @@ def version2(): a2_b = builder.call("atan", [a1_b]) a1_aii = builder.const("sum") - a1_a = builder.call("_allreduce", [a1_ai, a1_aii]) + a1_a = builder.call("_allreduce", [a1_ai, a1_aii, builder.const([])]) a3_b = builder.call("atan", [a2_b]) a4_b = builder.call("atan", [a3_b]) @@ -502,14 +502,14 @@ def fifo_expected(self): let %v = mnm.op.multiply(%x, %c); let %v1 = (%v,); let %v2 = mnm.op.atan(%v); - let %v3 = mnm.op._allreduce(%v1, str"sum"); + let %v3 = mnm.op._allreduce(%v1, str"sum", TupleValue([])); let %v4 = mnm.op.atan(%v2); let %v5 = mnm.op.atan(%v4); let %v6 = mnm.op.relu(%v3); let %v7 = mnm.op.multiply(%v6, %v5); let %v8 = (%v7,); let %v9 = mnm.op.atan(%v7); - let %v10 = mnm.op._allreduce(%v8, str"sum"); + let %v10 = mnm.op._allreduce(%v8, str"sum", TupleValue([])); let %v11 = mnm.op.atan(%v9); let %v12 = mnm.op.atan(%v11); let %v13 = mnm.op.relu(%v10); @@ -525,7 +525,7 @@ def version1(builder, input0, input1): a1_b = builder.call("atan", [a0]) a1_aii = builder.const("sum") - a1_a = builder.call("_allreduce", [a1_ai, a1_aii]) + a1_a = builder.call("_allreduce", [a1_ai, a1_aii, builder.const([])]) a2_b = builder.call("atan", [a1_b]) # a2_a is delayed after a3_b @@ -542,7 +542,7 @@ def version2(builder, input0, input1): a2_b = builder.call("atan", [a1_b]) a1_aii = builder.const("sum") - a1_a = builder.call("_allreduce", [a1_ai, a1_aii]) + a1_a = builder.call("_allreduce", [a1_ai, a1_aii, builder.const([])]) # a2_a is delayed after a3_b a3_b = builder.call("atan", [a2_b]) diff --git a/tests/python/pass/test_pass_enforce_sync.py b/tests/python/pass/test_pass_enforce_sync.py index f8aed745..a5ad72c4 100644 --- a/tests/python/pass/test_pass_enforce_sync.py +++ b/tests/python/pass/test_pass_enforce_sync.py @@ -47,7 +47,7 @@ def expected(): let %add_event_comp = mnm.op.add_event(int64(1), int64(1)); let %set_stream_comp1 = mnm.op.set_stream(int64(0), int64(1)); let %wait_for_comp = mnm.op.wait_event(int64(1), int64(4)); - let %a3 = mnm.op._allreduce(%a2, str"sum"); + let %a3 = mnm.op._allreduce(%a2, str"sum", TupleValue([])); let %add_event_comm = mnm.op.add_event(int64(2), int64(4)); let %set_stream_comp2 = mnm.op.set_stream(int64(0), int64(1)); let %wait_for_comm = mnm.op.wait_event(int64(2), int64(1)); @@ -63,7 +63,8 @@ def expected(): builder.add_event(1, comp_stream) builder.set_stream(0, comm_stream) builder.wait_event(1, comm_stream) - x_2 = builder.call("_allreduce", [x_2, mnm.ir.const("sum")]) + r = mnm.ir.const([]) + x_2 = builder.call("_allreduce", [x_2, mnm.ir.const("sum"), mnm.ir.const([])]) builder.add_event(2, comm_stream) builder.set_stream(0, comp_stream) builder.wait_event(2, comp_stream) @@ -117,7 +118,7 @@ def expected(): let %add_event_comp = mnm.op.add_event(int64(1), int64(1)); let %set_stream_comm = mnm.op.set_stream(int64(0), int64(4)); let %wait_for_comp = mnm.op.wait_event(int64(1), int64(4)); - let %a4 = mnm.op._allreduce(%a3, str"sum"); + let %a4 = mnm.op._allreduce(%a3, str"sum", TupleValue([])); let %add_event_comm = mnm.op.add_event(int64(3), int64(4)); let %set_stream_comp1 = mnm.op.set_stream(int64(0), int64(1)); let %wait_for_comm = mnm.op.wait_event(int64(3), int64(1)); @@ -127,7 +128,7 @@ def expected(): let %add_event_comp1 = mnm.op.add_event(int64(2), int64(1)); let %set_stream_comm1 = mnm.op.set_stream(int64(0), int64(4)); let %wait_for_comp1 = mnm.op.wait_event(int64(2), int64(4)); - let %a8 = mnm.op._allreduce(%a7, str"sum"); + let %a8 = mnm.op._allreduce(%a7, str"sum", TupleValue([])); let %add_event_comm1 = mnm.op.add_event(int64(4), int64(4)); let %set_stream_comp2 = mnm.op.set_stream(int64(0), int64(1)); let %wait_for_comm1 = mnm.op.wait_event(int64(4), int64(1)); @@ -148,7 +149,7 @@ def expected(): builder.add_event(1, comp_stream) builder.set_stream(0, comm_stream) builder.wait_event(1, comm_stream) - x_3 = builder.call("_allreduce", [x_3i, mnm.ir.const("sum")]) + x_3 = builder.call("_allreduce", [x_3i, mnm.ir.const("sum"), mnm.ir.const([])]) builder.add_event(2, comm_stream) builder.set_stream(0, comp_stream) builder.wait_event(2, comp_stream) @@ -159,7 +160,7 @@ def expected(): builder.add_event(3, comp_stream) builder.set_stream(0, comm_stream) builder.wait_event(3, comm_stream) - x_6 = builder.call("_allreduce", [x_6i, mnm.ir.const("sum")]) + x_6 = builder.call("_allreduce", [x_6i, mnm.ir.const("sum"), mnm.ir.const([])]) builder.add_event(4, comm_stream) builder.set_stream(0, comp_stream) builder.wait_event(4, comp_stream) @@ -380,7 +381,7 @@ def expected(): builder.add_event(1, comp_stream) builder.set_stream(0, comm_stream) builder.wait_event(1, comm_stream) - x_3 = builder.call("_allreduce", [x_2, mnm.ir.const("sum")]) + x_3 = builder.call("_allreduce", [x_2, mnm.ir.const("sum"), mnm.ir.const([])]) builder.add_event(2, comm_stream) builder.set_stream(0, comp_stream) builder.wait_event(2, comp_stream) @@ -432,7 +433,7 @@ def expected(): let %add_event_comp = mnm.op.add_event(int64(1), int64(1)); let %set_stream_comm = mnm.op.set_stream(int64(0), int64(4)); let %wait_for_comp = mnm.op.wait_event(int64(1), int64(4)); - let %a3 = mnm.op._allreduce(%a2, str"sum"); + let %a3 = mnm.op._allreduce(%a2, str"sum", TupleValue([])); let %add_event_comm = mnm.op.add_event(int64(2), int64(4)); let %set_stream_comp1 = mnm.op.set_stream(int64(0), int64(1)); let %wait_for_comm = mnm.op.wait_event(int64(2), int64(1)); @@ -451,7 +452,7 @@ def expected(): builder.add_event(1, comp_stream) builder.set_stream(0, comm_stream) builder.wait_event(1, comm_stream) - x_3 = builder.call("_allreduce", [x_2, mnm.ir.const("sum")]) + x_3 = builder.call("_allreduce", [x_2, mnm.ir.const("sum"), mnm.ir.const([])]) builder.add_event(4, comm_stream) builder.set_stream(0, comp_stream) builder.wait_event(4, comp_stream)