Skip to content

Commit

Permalink
Merge branch 'awslabs:main' into sharding
Browse files Browse the repository at this point in the history
  • Loading branch information
Tonny-Gu authored Aug 7, 2022
2 parents dd26fbc + 3136651 commit 61f3dd0
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 67 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 75ec1c to 4ec868
28 changes: 27 additions & 1 deletion python/raf/distributed/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
"""Collective communication operators"""
from .._op import sym
from .communicator import get_communicator
from .._core.ndarray import Symbol
from .._core.module import IRModule
from .._ffi.pass_ import ExtractBinding, InferType


def allreduce(x, computation="sum", rank_list=None):
Expand Down Expand Up @@ -131,7 +134,7 @@ def reduce_scatter(x, computation="sum", rank_list=None):
Parameters
----------
x : List[Tensor]
x : Tensor or List[Tensor]
A tensor, or a list of tensors of equal shape;
when a list is given, replica i receives the reduction of x[i] over all replicas
computation: string
Expand All @@ -151,6 +154,29 @@ def reduce_scatter(x, computation="sum", rank_list=None):
reduction result of x[rank] over all replicas,
where rank represents rank number of the current process
"""
comm = get_communicator()
if rank_list:
for group in rank_list:
if comm.rank in group:
size = len(group)
break
else:
size = 1
else:
size = comm.size

if isinstance(x, (tuple, list)):
assert len(x) == size, "Invalid size of tensor list"
body = Symbol.make_tuple(x)._Symbol__handle
body = ExtractBinding(body, [])
mod = IRModule.from_expr(body)
mod = InferType()(mod)
ret_list = mod["main"].checked_type.ret_type
single_tensor_type = ret_list.fields[0]
for tensor_type in ret_list.fields:
assert single_tensor_type == tensor_type, "Invalid tensor shape"
x = sym.concatenate(x, axis=0)

return sym._reduce_scatter(x, computation, rank_list=rank_list)


Expand Down
2 changes: 1 addition & 1 deletion scripts/src_codegen/def_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@
),
],
"communication.h::reduce_scatter": [
Arg(name="x", cxx_type="std::vector<value::BaseTensorValue>", cxx_normalizer="TensorTuple"),
Arg(name="x", cxx_type="value::BaseTensorValue"),
Arg(name="computation", cxx_type="std::string", cxx_default='"sum"', py_default='"sum"'),
Arg(name="rank_list", cxx_type="value::Value", cxx_default="nullptr"),
],
Expand Down
19 changes: 8 additions & 11 deletions src/op/declare/collective_comm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,20 +124,17 @@ RAF_OP_DECLARE("raf.op._group_allgather", GroupAllGather)
void ReduceScatter(const CallValues& call) {
const auto* args = call->args.as<ReduceScatterArgs>();
CHECK(args != nullptr);
std::vector<BaseTensorValue> tvs = args->x;
CHECK_GE(tvs.size(), 1U);
const DLTensor* x = tvs[0];
const DLTensor* x = args->x;
std::vector<int64_t> shape(x->shape, x->shape + x->ndim);
if (tvs.size() == 1) {
int size = GetGlobalCommunicator()->size;
CHECK(shape[0] % size == 0);
shape[0] = shape[0] / size;
int size;
if (args->rank_list.defined()) {
size = Communicator::Get("void", args->rank_list)->size;
} else {
for (const auto& tv : tvs) {
const DLTensor* x = tv;
CHECK(shape == std::vector<int64_t>(x->shape, x->shape + x->ndim));
}
size = GetGlobalCommunicator()->size;
}
CHECK(shape[0] % size == 0) << "Input tensor with first dim shape " << shape[0]
<< " cannot be scattered to " << size << "devices evenly";
shape[0] = shape[0] / size;
call->device = x->device;
call->out = TensorValue::Assemble(/*ctx=*/x->device,
/*dtype=*/x->dtype,
Expand Down
29 changes: 6 additions & 23 deletions src/op/dialect/nccl/nccl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,6 @@ RAF_REGISTER_DIALECT_OP(nccl, _group_allgather, 10);
RAF_OP_ENV_MAKER("raf.op.nccl._group_allgather", NCCLGroupAllGather::make);

class NCCLReduceScatter : public NCCLOpEnv {
void* in_buffer;
size_t size_in_bytes;
size_t size;
ncclRedOp_t compute;

Expand Down Expand Up @@ -282,9 +280,8 @@ class NCCLReduceScatter : public NCCLOpEnv {
}

const DLTensor* out = cv->out;
size_in_bytes = BytesCompactTensor(*out);
size_t size_in_bytes = BytesCompactTensor(*out);
size = size_in_bytes / (out->dtype.bits / 8);
RequestWorkspace(&in_buffer, cv->device, size_in_bytes * GetGlobalCommunicator()->size);
}

public:
Expand All @@ -297,33 +294,19 @@ class NCCLReduceScatter : public NCCLOpEnv {

void Execute(const CallValues& cv) {
auto args = cv->args.as<raf::op::schema::ReduceScatterArgs>();
Execute({TupleValue::make(ir::Array<Value>(args->x.begin(), args->x.end()))}, cv->out);
Execute({args->x}, cv->out);
}

void Execute(const std::vector<value::Value>& inputs, value::Value output) {
auto comm_ref = GetRef<Communicator>(reinterpret_cast<CommunicatorObj*>(communicator));
ncclComm_t nccl_comm = Downcast<NCCLCommunicator>(comm_ref)->nccl_comm;
size_t offset = 0;
DLTensor* out = output;
DType dtype;

auto tv = Downcast<value::TupleValue>(inputs[0]);
if (tv->fields.size() == 1) {
DLTensor* x = tv->fields[0];
dtype = x->dtype;
NCCL_CALL(ncclReduceScatter(x->data, out->data, size, dtype, compute, nccl_comm,
(cudaStream_t)stream));
} else {
for (int i = 0; i < tv->fields.size(); ++i) {
DLTensor* x = tv->fields[i];
void* buffer_data_at_offset = reinterpret_cast<uint8_t*>(in_buffer) + size_in_bytes * i;
cudaMemcpyAsync(buffer_data_at_offset, x->data, size_in_bytes, cudaMemcpyDeviceToDevice,
(cudaStream_t)stream);
dtype = x->dtype;
}
NCCL_CALL(ncclReduceScatter(in_buffer, out->data, size, dtype, compute, nccl_comm,
(cudaStream_t)stream));
}
DLTensor* x = inputs[0];
dtype = x->dtype;
NCCL_CALL(ncclReduceScatter(x->data, out->data, size, dtype, compute, nccl_comm,
(cudaStream_t)stream));
}

static OpEnv* make(const CallValues& cv) {
Expand Down
26 changes: 9 additions & 17 deletions src/op/ty/collective_comm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,15 @@ Type ReduceScatterInfer(const CallValues& value) {

const auto* args = value->args.as<ReduceScatterArgs>();
CHECK(args != nullptr);
CHECK_GE(args->x.size(), 1U);
const auto& ty = GetType(args->x[0]);
if (args->x.size() == 1) {
int size = GetGlobalCommunicator()->size;
auto tpn = ty.as<TensorTypeNode>();
auto shape = tpn->shape;
auto old_size = shape[0].as<IntImmNode>()->value;
CHECK(old_size % size == 0);
auto new_size = old_size / size;
shape.Set(0, Integer(new_size));
return TensorType(shape, DataType(tpn->dtype));
} else {
for (const auto& x : args->x) {
(*structural_equal)(GetType(x), ty, true, true);
}
return ty;
}
const auto& ty = GetType(args->x);
int size = GetGlobalCommunicator()->size;
auto tpn = ty.as<TensorTypeNode>();
auto shape = tpn->shape;
auto old_size = shape[0].as<IntImmNode>()->value;
CHECK(old_size % size == 0);
auto new_size = old_size / size;
shape.Set(0, Integer(new_size));
return TensorType(shape, DataType(tpn->dtype));
}

RAF_OP_TYPE("raf.op._reduce_scatter", "NCCLReduceScatter", ReduceScatterInfer);
Expand Down
9 changes: 3 additions & 6 deletions src/pass/partition_gradient.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,14 +217,12 @@ class GradientPartitioner : public ExprMutator {
* // if NCCL version is >= 2.10
* let %1 = op(%0); // A backward op to generate gradient
* let %2 = pad(%1, ...); // %1 is the complete local gradient
* let %3 = Tuple(%2);
* let %4 = reduce_scatter(%3, avg);
* let %3 = reduce_scatter(%2, avg);
* // else NCCL version is < 2.10
* let %1 = op(%0); // A backward op to generate gradient
* let %2 = pad(%1, ...); // %1 is the complete local gradient
* let %3 = Tuple(%2);
* let %4 = reduce_scatter(%3, sum);
* let %5 = divide(%4, ...)
* let %3 = reduce_scatter(%2, sum);
* let %4 = divide(%3, ...)
* The desired IR for ZeRO-2 is if bucket_size_ > 2, which means group reduce_scatter:
* // if NCCL version is >= 2.10
* let %1 = op(%0); // A backward op to generate gradient
Expand Down Expand Up @@ -272,7 +270,6 @@ class GradientPartitioner : public ExprMutator {
grad_var = GenPadCall(scope, grad_var);
if (bucket_size_ < 2) {
// Do not group reduce_scatter
grad_var = scope->Push(Tuple({grad_var}));
auto reduce_scatter_var = scope->Push(Call(reduce_scatter_op, {grad_var, compute}));
if (divide_expr.defined()) {
// update the divide op args
Expand Down
11 changes: 4 additions & 7 deletions tests/python/distributed/test_collective_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,15 +282,14 @@ def forward(self, x1, x2):

@pytest.mark.skipif(skip_dist_test(min_rank_num=2, require_exact_rank=True), reason=SKIP_REASON)
@pytest.mark.parametrize("computation", ["sum", "prod", "min", "max"])
def test_reduce_scatter(computation):
def test_reduce_scatter_tensor_list(computation):
class TestModel(raf.Model):
def build(self):
pass

@raf.model.trace
def forward(self, x, y):
z = Symbol.make_tuple([x, y])
out = raf.reduce_scatter(z, computation=computation)
out = raf.reduce_scatter([x, y], computation=computation)
return out

if computation == "avg" and raf.build.with_nccl() < 21000:
Expand Down Expand Up @@ -345,8 +344,7 @@ def build(self):

@raf.model.trace
def forward(self, x, y):
z = Symbol.make_tuple([x, y])
out = raf.reduce_scatter(z, computation=computation, rank_list=rank_list)
out = raf.reduce_scatter([x, y], computation=computation, rank_list=rank_list)
return out

if computation == "avg" and raf.build.with_nccl() < 21000:
Expand Down Expand Up @@ -662,8 +660,7 @@ def build(self):

@raf.model.trace
def forward(self, x):
z = Symbol.make_tuple([x])
out = raf.reduce_scatter(z, computation=computation)
out = raf.reduce_scatter(x, computation=computation)
return out

if computation == "avg" and raf.build.with_nccl() < 21000:
Expand Down

0 comments on commit 61f3dd0

Please sign in to comment.