
Commit

refactor
Tonny-Gu committed Mar 18, 2022
1 parent 8a8b783 commit 2133a71
Showing 7 changed files with 59 additions and 73 deletions.
8 changes: 4 additions & 4 deletions include/raf/sharding.h
@@ -53,17 +53,17 @@ class ShardSpecObj final : public BaseShardSpecObj {
Array<Integer> replicas;
Array<Integer> logic_shape;
Array<Integer> logic_index;
Array<Integer> real_shape;
Array<Integer> real_index;
Array<Integer> phy_shape;
Array<Integer> phy_index;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("immutable", &immutable);
v->Visit("ranks", &ranks);
v->Visit("replicas", &replicas);
v->Visit("logic_shape", &logic_shape);
v->Visit("logic_index", &logic_index);
v->Visit("real_shape", &real_shape);
v->Visit("real_index", &real_index);
v->Visit("phy_shape", &phy_shape);
v->Visit("phy_index", &phy_index);
}

static constexpr const char* _type_key = "raf.sharding.ShardSpec";
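For orientation: the renamed fields separate the physical device grid (phy_shape, phy_index) from the logical shard grid (logic_shape, logic_index). Per dimension, the logical extent is the physical extent divided by the replication factor, as computed in ShardSpec::make below. A minimal Python sketch of that relationship, with illustrative values that are not part of this change:

# Illustrative only: derive the logical shard grid from the physical
# device grid and the per-dimension replication factors.
phy_shape = [4, 2]   # physical device grid: 4 x 2 = 8 ranks
replicas = [2, 1]    # dimension 0 keeps 2 replicas per shard group
logic_shape = [p // r for p, r in zip(phy_shape, replicas)]
assert logic_shape == [2, 2]  # the tensor itself is split 2 x 2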
2 changes: 1 addition & 1 deletion python/raf/_tvm_op/__init__.py
@@ -3,5 +3,5 @@

"""Compute definition and schedules for TVM operators"""
from . import loss, sgd, reduce, transform, broadcast, unary, nn, vision
from . import algorithm, init, random, argwhere
from . import algorithm, init, random, argwhere, sharding
from . import utils
1 change: 0 additions & 1 deletion python/raf/distributed/sharding/__init__.py
@@ -2,7 +2,6 @@
from raf._ffi.sharding._make import ShardOpCallAttrs
from .shardspec import BaseShardSpec, ReplicatedSpec, TupleShardSpec, ShardSpec
from .utils import (
get_dist_devices,
expand_when,
always_apply,
register_expansion_pattern,
4 changes: 2 additions & 2 deletions python/raf/distributed/sharding/shardspec.py
@@ -38,5 +38,5 @@ def __len__(self):
class ShardSpec(BaseShardSpec):
"""Annotation of Sharding Specifications"""

def __init__(self, ranks, real_shape, replicas, immutable=False):
self.__init_handle_by_constructor__(_make.ShardSpec, immutable, ranks, real_shape, replicas)
def __init__(self, ranks, phy_shape, replicas, immutable=False):
self.__init_handle_by_constructor__(_make.ShardSpec, immutable, ranks, phy_shape, replicas)
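Under the new parameter name, a spec is constructed the same way as before; a usage sketch mirroring the call in the updated test below (argument values are only examples):

from raf.distributed.sharding import ShardSpec

# Usage sketch: 4 ranks arranged as a 2 x 2 physical grid; replicas=[1, 2]
# replicates dimension 1 across 2 devices, so the tensor is only split
# along dimension 0.
spec = ShardSpec(ranks=[0, 1, 2, 3], phy_shape=[2, 2], replicas=[1, 2])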
13 changes: 1 addition & 12 deletions python/raf/distributed/sharding/utils.py
@@ -19,18 +19,7 @@
7: "kTuple",
8: "kOpaque",
}
# TODO: this pattern map is replicated mulitple times in source code


def get_dist_devices():
"""Return all available devices in the cluster as a list of Device Objects.
Returns
-------
ret: list
List of Device Objects
"""
return dist.get_context().dist_devices
# TODO: this pattern map is replicated multiple times in source code


def always_apply(call: relay.Call):
56 changes: 28 additions & 28 deletions src/impl/sharding.cc
@@ -33,12 +33,12 @@ ReplicatedSpec ReplicatedSpec::make(bool immutable) {
return ReplicatedSpec(n);
}

ShardSpec ShardSpec::make(bool immutable, Array<Integer> ranks, Array<Integer> real_shape,
ShardSpec ShardSpec::make(bool immutable, Array<Integer> ranks, Array<Integer> phy_shape,
Array<Integer> replicas) {
auto ndim = real_shape.size();
auto ndim = phy_shape.size();
CHECK_EQ(ndim, replicas.size());
auto n = make_object<ShardSpecObj>();
auto real_index = std::vector<Integer>(ndim);
auto phy_index = std::vector<Integer>(ndim);
auto logic_index = std::vector<Integer>(ndim);
auto logic_shape = std::vector<Integer>(ndim);

@@ -51,22 +51,22 @@ ShardSpec ShardSpec::make(bool immutable, Array<Integer> ranks, Array<Integer> r
}

for (int64_t i = ndim - 1; i >= 0; --i) {
logic_shape[i] = real_shape[i]->value / replicas[i]->value;
real_index[i] = rank_idx % real_shape[i]->value;
logic_index[i] = real_index[i]->value / replicas[i]->value;
rank_idx /= real_shape[i]->value;
logic_shape[i] = phy_shape[i]->value / replicas[i]->value;
phy_index[i] = rank_idx % phy_shape[i]->value;
logic_index[i] = phy_index[i]->value / replicas[i]->value;
rank_idx /= phy_shape[i]->value;
}

n->immutable = immutable;
n->ranks = std::move(ranks);
n->replicas = std::move(replicas);
n->real_shape = std::move(real_shape);
n->phy_shape = std::move(phy_shape);
n->logic_shape = Array<Integer>(logic_shape);
if (rank_idx == -1) {
n->real_index = NullValue<Array<Integer>>();
n->phy_index = NullValue<Array<Integer>>();
n->logic_index = NullValue<Array<Integer>>();
} else {
n->real_index = Array<Integer>(real_index);
n->phy_index = Array<Integer>(phy_index);
n->logic_index = Array<Integer>(logic_index);
}
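A rough Python rendering of the per-dimension index arithmetic above (a sketch under the assumption that rank_idx is this rank's position in ranks; the rank_idx == -1 case, where the local rank is absent, corresponds to the null indices in the else branch):

# Sketch only: mirrors the loop in ShardSpec::make for a non-negative rank_idx.
def shard_indices(rank_idx, phy_shape, replicas):
    ndim = len(phy_shape)
    phy_index = [0] * ndim
    logic_index = [0] * ndim
    for i in reversed(range(ndim)):
        phy_index[i] = rank_idx % phy_shape[i]        # position in the device grid
        logic_index[i] = phy_index[i] // replicas[i]  # position in the shard grid
        rank_idx //= phy_shape[i]
    return phy_index, logic_index

# Rank 3 on a 2 x 2 grid with replicas [1, 2] sits at physical cell (1, 1)
# and owns logical shard (1, 0).
assert shard_indices(3, [2, 2], [1, 2]) == ([1, 1], [1, 0])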

@@ -93,11 +93,11 @@ void Reshard_R2S(const CallValues& call) {
const DLTensor* x = args->x;
std::vector<int64_t> shape(x->shape, x->shape + x->ndim);
auto spec = Downcast<ShardSpec>(args->spec);
if (spec->real_index.defined()) {
if (spec->logic_index.defined()) {
for (int64_t i = 0; i < x->ndim; ++i) {
auto grid_dim_size = spec->real_shape[i]->value;
CHECK_EQ(x->shape[i] % grid_dim_size, 0) << "Currently automatic padding is unsupported.";
shape[i] /= grid_dim_size;
auto shard_dim_size = spec->logic_shape[i]->value;
CHECK_EQ(x->shape[i] % shard_dim_size, 0) << "Currently automatic padding is unsupported.";
shape[i] /= shard_dim_size;
}
call->out = TensorValue::Assemble(/*dev=*/x->device,
/*dtype=*/x->dtype,
@@ -120,12 +120,12 @@ Type Reshard_R2S_Infer(const CallValues& call) {
Array<PrimExpr> dshape = data->shape;
size_t ndim = dshape.size();
std::vector<PrimExpr> oshape(ndim);
CHECK(spec->real_index.defined());
CHECK(spec->logic_index.defined());
for (int64_t i = 0; i < ndim; ++i) {
auto grid_dim_size = spec->real_shape[i]->value;
auto shard_dim_size = spec->logic_shape[i]->value;
auto dim_size = Downcast<IntImm>(dshape[i])->value;
CHECK_EQ(dim_size % grid_dim_size, 0) << "Currently automatic padding is unsupported.";
oshape[i] = Integer(dim_size / grid_dim_size);
CHECK_EQ(dim_size % shard_dim_size, 0) << "Currently automatic padding is unsupported.";
oshape[i] = Integer(dim_size / shard_dim_size);
}
return TensorType(oshape, data->dtype);
}
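A small worked example of the inference rule above, using the same spec shape as the updated test (assumed values, for illustration only):

# A 4 x 4 tensor resharded with phy_shape [2, 2] and replicas [1, 2] has
# logic_shape [2, 1], so each rank receives a 4/2 x 4/1 = 2 x 4 slice.
full_shape = [4, 4]
logic_shape = [2, 1]
out_shape = [dim // shard for dim, shard in zip(full_shape, logic_shape)]
assert out_shape == [2, 4]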
@@ -148,22 +148,22 @@ using tvm::runtime::ObjectRef;
void PrintAllocTable(const ObjectRef& ref, ReprPrinter* p) {
/*size_t dev_idx = 0;
const auto obj = Downcast<ShardSpec>(ref);
const auto num_dim = obj->real_shape.size();
const auto num_dim = obj->phy_shape.size();
static thread_local size_t *indices = new size_t[num_dim];
std::function<void(int)> _print_alloc_table;
_print_alloc_table = [&](int depth) {
if (depth == num_dim) {
p->stream << (dev_idx != 0 ? " [" : "[");
for (size_t i = 0; i < num_dim; ++i) {
auto num_devices = obj->real_shape[i]->value;
auto num_devices = obj->phy_shape[i]->value;
auto index = std::to_string(indices[i]);
p->stream << (num_devices == 1 ? ":" : index)
<< (i != num_dim - 1 ? ", " : "");
}
auto dev_info = obj->ranks[dev_idx++].c_str();
p->stream << "]@" << dev_info;
} else {
auto subgroup_num = obj->real_shape[depth]->value;
auto subgroup_num = obj->phy_shape[depth]->value;
for (size_t i = 0; i < subgroup_num; ++i) {
indices[depth] = i;
_print_alloc_table(depth + 1);
Expand All @@ -182,12 +182,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ShardSpecObj>([](const ObjectRef& ref, ReprPrinter* p) {
auto r = Downcast<ShardSpec>(ref);
auto ndim = r->real_shape.size();
auto ndim = r->logic_shape.size();
p->stream << "ShardSpec(" << (r->immutable ? "Immut " : "") << "[";
for (size_t i = 0; i < ndim; ++i) {
auto grid_dim_size = r->real_shape[i]->value;
auto shard_dim_size = r->logic_shape[i]->value;
auto subgroup_size = r->replicas[i]->value;
p->stream << (grid_dim_size == 1 ? ":" : std::to_string(grid_dim_size))
p->stream << (shard_dim_size == 1 ? ":" : std::to_string(shard_dim_size))
<< (subgroup_size == 1 ? "" : "(x" + std::to_string(subgroup_size) + ")")
<< (i != ndim - 1 ? ", " : "");
}
@@ -235,10 +235,10 @@ Attrs ReshardSchema2Attrs(const ShardUnaryArgs* args) {
const DLTensor* x = args->x;
std::vector<Integer> begin(x->ndim);
std::vector<Integer> end(x->ndim);
CHECK(spec->real_index.defined());
CHECK(spec->logic_index.defined());
for (int i = 0; i < x->ndim; ++i) {
auto idx = spec->real_index[i]->value;
auto size = spec->real_shape[i]->value;
auto idx = spec->logic_index[i]->value;
auto size = spec->logic_shape[i]->value;
begin[i] = Integer((x->shape[i] / size) * idx);
end[i] = Integer((x->shape[i] / size) * (idx + 1));
}
@@ -251,7 +251,7 @@ HashKey ReshardHasher(const std::vector<Type>& param_types, const Type& y_type,
const ShardUnaryArgs* args) {
HashKey key = GenericHasher<nullptr_t>(param_types, y_type, nullptr);
auto spec = Downcast<ShardSpec>(args->spec);
for (auto array : {spec->ranks, spec->real_shape, spec->replicas}) {
for (auto array : {spec->ranks, spec->phy_shape, spec->replicas}) {
for (auto i : array) {
key << i->value;
}
48 changes: 23 additions & 25 deletions tests/python/pass/test_pass_sharding.py
@@ -1,6 +1,7 @@
# pylint: disable=missing-function-docstring, missing-class-docstring, invalid-name, protected-access
import pytest
import raf
import numpy as np
from raf._core.core_utils import str2dev
from raf._core.executor import interpret
from raf.distributed.sharding import (
@@ -11,12 +12,9 @@
ShardOpCallAttrs,
)
from raf._ffi.pass_ import SetShardOpCallAttrs, ToGraphNormalForm, ExpandShardOpCall, InferType
from raf._ffi.device import Device
from raf._lib import relay
from raf.distributed.sharding.utils import get_dist_devices
from raf.testing import randn
from raf.hybrid.hybrid import _make_argument, _unwrap
from raf import distributed as dist
from tvm.relay.analysis.analysis import post_order_visit


@@ -31,33 +29,17 @@ def forward(self, x, y):
return z

model = Model()
# m_x, _ = randn((128, 128))
# m_y, _ = randn((128, 128))
m_x = raf.array([1, 2, 3, 4])
m_y = raf.array([0, 0, 0, 0])
m_x = raf.array(np.arange(16, dtype="float").reshape((4, 4)))
m_y = raf.array(np.zeros(16, dtype="float").reshape((4, 4)))
record = model._internal(m_x, m_y)
mod_before = record.mod
mod_before = InferType()(mod_before)

def get_global_device_list(dev_type="cuda"):
dev_type_id = str2dev(dev_type).device_type
dctx = dist.get_context()
local_id = 6
local_id -= 1
dev_array = (
[Device(dev_type_id, i) for i in range(1, local_id)]
+ [dctx.local_device]
+ [Device(dev_type_id, i) for i in range(local_id, 16)]
)
return dev_array

# devices = get_global_device_list()
devices = get_dist_devices()
attrs = ShardOpCallAttrs(
TupleShardSpec([ReplicatedSpec(), ReplicatedSpec()]), ShardSpec(devices, [4], [1])
TupleShardSpec([ReplicatedSpec(), ReplicatedSpec()]), ShardSpec([3, 2, 1, 0], [2, 2], [1, 2])
)
# a = raf._reshard_r2s(m_x, ShardSpec(devices, [4, 4], [1, 2]))
# print(a)

print(m_x)
call_list = []
post_order_visit(
mod_before["main"].body,
@@ -68,12 +50,28 @@ def get_global_device_list(dev_type="cuda"):
mod0 = SetShardOpCallAttrs(attrs_map)(mod_before)
mod1 = ToGraphNormalForm()(mod0)
mod2 = ExpandShardOpCall()(mod1)
# print(raf._ffi.ir.AsText(mod2))
print(raf._ffi.ir.AsText(mod2))
call = relay.Call(op=mod2["main"], args=[_make_argument(x) for x in (m_x, m_y)])
result = _unwrap(interpret(call, mod2))
print(result)

def test_reshard_r2s():
class Model(raf.Model):
def build(self):
pass

@raf.model.trace
def forward(self, x, y):
z = raf.add(x, y)
return z
n_x = np.arange(16).reshape((4, 4))
m_x = raf.array(n_x)
m_y = raf._reshard_r2s(m_x, ShardSpec([0, 1, 2, 3], [2, 2], [1, 2]))
print(m_x)
print(m_y)


if __name__ == "__main__":
# pytest.main([__file__])
test_ShardOpCallAttrs()
# test_reshard_r2s()
