Skip to content

Commit

Permalink
fix gpu test
Browse files Browse the repository at this point in the history
  • Loading branch information
Tonny-Gu committed Feb 9, 2022
1 parent 6b1bc9b commit b676af6
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 32 deletions.
8 changes: 5 additions & 3 deletions src/pass/data_parallel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -266,14 +266,15 @@ struct DataParallel {
// we should add an allreduce op after it.
static Op op_allreduce = Op::Get("mnm.op._allreduce");
auto input_var = mnm::ir::MakeVar("allreduce_in", {});
auto rank_list = MakeConstant(TupleValue::make(Array<Value>({})));

#if defined MNM_USE_NCCL && NCCL_VERSION_CODE >= 21000
bp_ell->vars[p2 - 1] = input_var;
bp_ell->exprs[p2 - 1] = Tuple({bp_ell->vars[i]});
// Here we name the var 'g' (global gradient) to make it easier to identify.
bp_ell->vars[p2] = mnm::ir::MakeVar("g", {});
bp_ell->exprs[p2] =
Call(op_allreduce, {bp_ell->vars[p2 - 1], MakeConstant(StringValue::make("avg"))});
bp_ell->exprs[p2] = Call(op_allreduce, {bp_ell->vars[p2 - 1],
MakeConstant(StringValue::make("avg")), rank_list});
var_var_map.insert({bp_ell->vars[i], bp_ell->vars[p2]});
p2 -= 2;
#else
Expand All @@ -282,7 +283,8 @@ struct DataParallel {
bp_ell->exprs[p2 - 2] = Tuple({bp_ell->vars[i]});
bp_ell->vars[p2 - 1] = mnm::ir::MakeVar("g_sum", {});
bp_ell->exprs[p2 - 1] =
Call(op_allreduce, {bp_ell->vars[p2 - 2], MakeConstant(StringValue::make("sum"))});
Call(op_allreduce,
{bp_ell->vars[p2 - 2], MakeConstant(StringValue::make("sum")), rank_list});
bp_ell->vars[p2] = mnm::ir::MakeVar("g", {});
auto tt = bp_ell->vars[i]->checked_type().as<TensorTypeNode>();
if (tt->dtype.code() == kDLFloat) {
Expand Down
40 changes: 20 additions & 20 deletions tests/python/pass/test_pass_data_parallel_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def fifo_expected(self):
let %v2 = mnm.op.atan(%v);
let %v3 = (%v1,);
let %v4 = (%v2,);
let %v5 = mnm.op._allreduce(%v3, str"sum");
let %v6 = mnm.op._allreduce(%v4, str"sum");
let %v5 = mnm.op._allreduce(%v3, str"sum", TupleValue([]));
let %v6 = mnm.op._allreduce(%v4, str"sum", TupleValue([]));
let %v7 = mnm.op.multiply(%v5, %c);
let %v8 = mnm.op.multiply(%v6, %c);
let %v9 = (%v7, %v8);
Expand All @@ -87,9 +87,9 @@ def fifo_expected(self):
x_2b = builder.make_tuple((x_1b,))

x_2a1 = builder.const("sum")
x_3a = builder.call("_allreduce", [x_2a, x_2a1])
x_3a = builder.call("_allreduce", [x_2a, x_2a1, builder.const([])])
x_2a2 = builder.const("sum")
x_3b = builder.call("_allreduce", [x_2b, x_2a2])
x_3b = builder.call("_allreduce", [x_2b, x_2a2, builder.const([])])

x_4a = builder.call("multiply", [x_3a, c])
x_4b = builder.call("multiply", [x_3b, c])
Expand Down Expand Up @@ -135,10 +135,10 @@ def fifo_expected(self):
let %v3 = mnm.op.atan(%v1);
let %v4 = (%v2,);
let %v5 = mnm.op.atan(%v3);
let %v6 = mnm.op._allreduce(%v4, str"sum");
let %v6 = mnm.op._allreduce(%v4, str"sum", TupleValue([]));
let %v7 = mnm.op.atan(%v5);
let %v8 = (%v7,);
let %v9 = mnm.op._allreduce(%v8, str"sum");
let %v9 = mnm.op._allreduce(%v8, str"sum", TupleValue([]));
let %v10 = mnm.op.multiply(%v6, %c);
let %v11 = mnm.op.multiply(%v9, %c);
let %v12 = (%v11, %v10);
Expand All @@ -162,13 +162,13 @@ def version1():

x_3a = builder.call("atan", [x_2a])
x_2b1 = builder.const("sum")
x_3b = builder.call("_allreduce", [x_2b, x_2b1])
x_3b = builder.call("_allreduce", [x_2b, x_2b1, builder.const([])])

x_4a = builder.call("atan", [x_3a])
# branch b is delayed since mul depends on allreduce
x_5a = builder.make_tuple((x_4a,))
x_5a1 = builder.const("sum")
x_6a = builder.call("_allreduce", [x_5a, x_5a1])
x_6a = builder.call("_allreduce", [x_5a, x_5a1, builder.const([])])

# now launch update ops, branch b first
x_4b = builder.call("multiply", [x_3b, c])
Expand All @@ -194,14 +194,14 @@ def version2():
x_2a = builder.call("atan", [x_1a])

x_2b1 = builder.const("sum")
x_3b = builder.call("_allreduce", [x_2b, x_2b1])
x_3b = builder.call("_allreduce", [x_2b, x_2b1, builder.const([])])
x_3a = builder.call("atan", [x_2a])

# branch b is delayed since mul depends on allreduce
x_4a = builder.call("atan", [x_3a])
x_5a = builder.make_tuple((x_4a,))
x_5a1 = builder.const("sum")
x_6a = builder.call("_allreduce", [x_5a, x_5a1])
x_6a = builder.call("_allreduce", [x_5a, x_5a1, builder.const([])])

# now launch update ops, branch b first
x_4b = builder.call("multiply", [x_3b, c])
Expand Down Expand Up @@ -241,7 +241,7 @@ def fifo_expected(self):
let %v = mnm.op.atan(%x);
let %v1 = (%v,);
let %v2 = mnm.op.atan(%v);
let %v3 = mnm.op._allreduce(%v1, str"sum");
let %v3 = mnm.op._allreduce(%v1, str"sum", TupleValue([]));
let %v4 = mnm.op.atan(%v2);
let %v5 = mnm.op.atan(%v3);
let %v6 = mnm.op.multiply(%v5, %v4);
Expand All @@ -259,7 +259,7 @@ def version1():
a1_b = builder.call("atan", [a0])

a1_aii = builder.const("sum")
a1_a = builder.call("_allreduce", [a1_ai, a1_aii])
a1_a = builder.call("_allreduce", [a1_ai, a1_aii, builder.const([])])
a2_b = builder.call("atan", [a1_b])

a2_a = builder.call("atan", [a1_a])
Expand All @@ -278,7 +278,7 @@ def version2():

a2_b = builder.call("atan", [a1_b])
a1_aii = builder.const("sum")
a1_a = builder.call("_allreduce", [a1_ai, a1_aii])
a1_a = builder.call("_allreduce", [a1_ai, a1_aii, builder.const([])])

a2_a = builder.call("atan", [a1_a])

Expand Down Expand Up @@ -316,7 +316,7 @@ def fifo_expected(self):
let %x_0 = mnm.op.atan(%x);
let %x_1 = (%x_0,);
let %x_2 = mnm.op.atan(%x_0);
let %x_3 = mnm.op._allreduce(%x_1, str"sum");
let %x_3 = mnm.op._allreduce(%x_1, str"sum", TupleValue([]));
let %x_4 = mnm.op.atan(%x_2);
let %x_5 = mnm.op.atan(%x_4);
let %x_6 = mnm.op.atan(%x_5);
Expand All @@ -338,7 +338,7 @@ def version1():
a1_b = builder.call("atan", [a0])

a1_aii = builder.const("sum")
a1_a = builder.call("_allreduce", [a1_ai, a1_aii])
a1_a = builder.call("_allreduce", [a1_ai, a1_aii, builder.const([])])
a2_b = builder.call("atan", [a1_b])

a3_b = builder.call("atan", [a2_b])
Expand All @@ -362,7 +362,7 @@ def version2():

a2_b = builder.call("atan", [a1_b])
a1_aii = builder.const("sum")
a1_a = builder.call("_allreduce", [a1_ai, a1_aii])
a1_a = builder.call("_allreduce", [a1_ai, a1_aii, builder.const([])])

a3_b = builder.call("atan", [a2_b])
a4_b = builder.call("atan", [a3_b])
Expand Down Expand Up @@ -502,14 +502,14 @@ def fifo_expected(self):
let %v = mnm.op.multiply(%x, %c);
let %v1 = (%v,);
let %v2 = mnm.op.atan(%v);
let %v3 = mnm.op._allreduce(%v1, str"sum");
let %v3 = mnm.op._allreduce(%v1, str"sum", TupleValue([]));
let %v4 = mnm.op.atan(%v2);
let %v5 = mnm.op.atan(%v4);
let %v6 = mnm.op.relu(%v3);
let %v7 = mnm.op.multiply(%v6, %v5);
let %v8 = (%v7,);
let %v9 = mnm.op.atan(%v7);
let %v10 = mnm.op._allreduce(%v8, str"sum");
let %v10 = mnm.op._allreduce(%v8, str"sum", TupleValue([]));
let %v11 = mnm.op.atan(%v9);
let %v12 = mnm.op.atan(%v11);
let %v13 = mnm.op.relu(%v10);
Expand All @@ -525,7 +525,7 @@ def version1(builder, input0, input1):
a1_b = builder.call("atan", [a0])

a1_aii = builder.const("sum")
a1_a = builder.call("_allreduce", [a1_ai, a1_aii])
a1_a = builder.call("_allreduce", [a1_ai, a1_aii, builder.const([])])
a2_b = builder.call("atan", [a1_b])

# a2_a is delayed after a3_b
Expand All @@ -542,7 +542,7 @@ def version2(builder, input0, input1):

a2_b = builder.call("atan", [a1_b])
a1_aii = builder.const("sum")
a1_a = builder.call("_allreduce", [a1_ai, a1_aii])
a1_a = builder.call("_allreduce", [a1_ai, a1_aii, builder.const([])])

# a2_a is delayed after a3_b
a3_b = builder.call("atan", [a2_b])
Expand Down
19 changes: 10 additions & 9 deletions tests/python/pass/test_pass_enforce_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def expected():
let %add_event_comp = mnm.op.add_event(int64(1), int64(1));
let %set_stream_comp1 = mnm.op.set_stream(int64(0), int64(1));
let %wait_for_comp = mnm.op.wait_event(int64(1), int64(4));
let %a3 = mnm.op._allreduce(%a2, str"sum");
let %a3 = mnm.op._allreduce(%a2, str"sum", TupleValue([]));
let %add_event_comm = mnm.op.add_event(int64(2), int64(4));
let %set_stream_comp2 = mnm.op.set_stream(int64(0), int64(1));
let %wait_for_comm = mnm.op.wait_event(int64(2), int64(1));
Expand All @@ -63,7 +63,8 @@ def expected():
builder.add_event(1, comp_stream)
builder.set_stream(0, comm_stream)
builder.wait_event(1, comm_stream)
x_2 = builder.call("_allreduce", [x_2, mnm.ir.const("sum")])
r = mnm.ir.const([])
x_2 = builder.call("_allreduce", [x_2, mnm.ir.const("sum"), mnm.ir.const([])])
builder.add_event(2, comm_stream)
builder.set_stream(0, comp_stream)
builder.wait_event(2, comp_stream)
Expand Down Expand Up @@ -117,7 +118,7 @@ def expected():
let %add_event_comp = mnm.op.add_event(int64(1), int64(1));
let %set_stream_comm = mnm.op.set_stream(int64(0), int64(4));
let %wait_for_comp = mnm.op.wait_event(int64(1), int64(4));
let %a4 = mnm.op._allreduce(%a3, str"sum");
let %a4 = mnm.op._allreduce(%a3, str"sum", TupleValue([]));
let %add_event_comm = mnm.op.add_event(int64(3), int64(4));
let %set_stream_comp1 = mnm.op.set_stream(int64(0), int64(1));
let %wait_for_comm = mnm.op.wait_event(int64(3), int64(1));
Expand All @@ -127,7 +128,7 @@ def expected():
let %add_event_comp1 = mnm.op.add_event(int64(2), int64(1));
let %set_stream_comm1 = mnm.op.set_stream(int64(0), int64(4));
let %wait_for_comp1 = mnm.op.wait_event(int64(2), int64(4));
let %a8 = mnm.op._allreduce(%a7, str"sum");
let %a8 = mnm.op._allreduce(%a7, str"sum", TupleValue([]));
let %add_event_comm1 = mnm.op.add_event(int64(4), int64(4));
let %set_stream_comp2 = mnm.op.set_stream(int64(0), int64(1));
let %wait_for_comm1 = mnm.op.wait_event(int64(4), int64(1));
Expand All @@ -148,7 +149,7 @@ def expected():
builder.add_event(1, comp_stream)
builder.set_stream(0, comm_stream)
builder.wait_event(1, comm_stream)
x_3 = builder.call("_allreduce", [x_3i, mnm.ir.const("sum")])
x_3 = builder.call("_allreduce", [x_3i, mnm.ir.const("sum"), mnm.ir.const([])])
builder.add_event(2, comm_stream)
builder.set_stream(0, comp_stream)
builder.wait_event(2, comp_stream)
Expand All @@ -159,7 +160,7 @@ def expected():
builder.add_event(3, comp_stream)
builder.set_stream(0, comm_stream)
builder.wait_event(3, comm_stream)
x_6 = builder.call("_allreduce", [x_6i, mnm.ir.const("sum")])
x_6 = builder.call("_allreduce", [x_6i, mnm.ir.const("sum"), mnm.ir.const([])])
builder.add_event(4, comm_stream)
builder.set_stream(0, comp_stream)
builder.wait_event(4, comp_stream)
Expand Down Expand Up @@ -380,7 +381,7 @@ def expected():
builder.add_event(1, comp_stream)
builder.set_stream(0, comm_stream)
builder.wait_event(1, comm_stream)
x_3 = builder.call("_allreduce", [x_2, mnm.ir.const("sum")])
x_3 = builder.call("_allreduce", [x_2, mnm.ir.const("sum"), mnm.ir.const([])])
builder.add_event(2, comm_stream)
builder.set_stream(0, comp_stream)
builder.wait_event(2, comp_stream)
Expand Down Expand Up @@ -432,7 +433,7 @@ def expected():
let %add_event_comp = mnm.op.add_event(int64(1), int64(1));
let %set_stream_comm = mnm.op.set_stream(int64(0), int64(4));
let %wait_for_comp = mnm.op.wait_event(int64(1), int64(4));
let %a3 = mnm.op._allreduce(%a2, str"sum");
let %a3 = mnm.op._allreduce(%a2, str"sum", TupleValue([]));
let %add_event_comm = mnm.op.add_event(int64(2), int64(4));
let %set_stream_comp1 = mnm.op.set_stream(int64(0), int64(1));
let %wait_for_comm = mnm.op.wait_event(int64(2), int64(1));
Expand All @@ -451,7 +452,7 @@ def expected():
builder.add_event(1, comp_stream)
builder.set_stream(0, comm_stream)
builder.wait_event(1, comm_stream)
x_3 = builder.call("_allreduce", [x_2, mnm.ir.const("sum")])
x_3 = builder.call("_allreduce", [x_2, mnm.ir.const("sum"), mnm.ir.const([])])
builder.add_event(4, comm_stream)
builder.set_stream(0, comp_stream)
builder.wait_event(4, comp_stream)
Expand Down

0 comments on commit b676af6

Please sign in to comment.