From 2133a71dba43ea94ba71e3997b86fc0c42e193cd Mon Sep 17 00:00:00 2001
From: NekoDaemon
Date: Fri, 18 Mar 2022 09:19:49 +0000
Subject: [PATCH] refactor

---
 include/raf/sharding.h                       |  8 +--
 python/raf/_tvm_op/__init__.py               |  2 +-
 python/raf/distributed/sharding/__init__.py  |  1 -
 python/raf/distributed/sharding/shardspec.py |  4 +-
 python/raf/distributed/sharding/utils.py     | 13 +----
 src/impl/sharding.cc                         | 56 ++++++++++----------
 tests/python/pass/test_pass_sharding.py      | 48 ++++++++---------
 7 files changed, 59 insertions(+), 73 deletions(-)

diff --git a/include/raf/sharding.h b/include/raf/sharding.h
index 15e5fc5d..e6856d66 100644
--- a/include/raf/sharding.h
+++ b/include/raf/sharding.h
@@ -53,8 +53,8 @@ class ShardSpecObj final : public BaseShardSpecObj {
   Array replicas;
   Array logic_shape;
   Array logic_index;
-  Array real_shape;
-  Array real_index;
+  Array phy_shape;
+  Array phy_index;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("immutable", &immutable);
@@ -62,8 +62,8 @@
     v->Visit("replicas", &replicas);
     v->Visit("logic_shape", &logic_shape);
     v->Visit("logic_index", &logic_index);
-    v->Visit("real_shape", &real_shape);
-    v->Visit("real_index", &real_index);
+    v->Visit("phy_shape", &phy_shape);
+    v->Visit("phy_index", &phy_index);
   }
 
   static constexpr const char* _type_key = "raf.sharding.ShardSpec";
diff --git a/python/raf/_tvm_op/__init__.py b/python/raf/_tvm_op/__init__.py
index 76bda77b..c24d4f6c 100644
--- a/python/raf/_tvm_op/__init__.py
+++ b/python/raf/_tvm_op/__init__.py
@@ -3,5 +3,5 @@
 """Compute definition and schedules for TVM operators"""
 from . import loss, sgd, reduce, transform, broadcast, unary, nn, vision
-from . import algorithm, init, random, argwhere
+from . import algorithm, init, random, argwhere, sharding
 from . import utils
diff --git a/python/raf/distributed/sharding/__init__.py b/python/raf/distributed/sharding/__init__.py
index e3069654..1a8615c5 100644
--- a/python/raf/distributed/sharding/__init__.py
+++ b/python/raf/distributed/sharding/__init__.py
@@ -2,7 +2,6 @@
 from raf._ffi.sharding._make import ShardOpCallAttrs
 from .shardspec import BaseShardSpec, ReplicatedSpec, TupleShardSpec, ShardSpec
 from .utils import (
-    get_dist_devices,
     expand_when,
     always_apply,
     register_expansion_pattern,
diff --git a/python/raf/distributed/sharding/shardspec.py b/python/raf/distributed/sharding/shardspec.py
index 089eb636..01a0e4c5 100644
--- a/python/raf/distributed/sharding/shardspec.py
+++ b/python/raf/distributed/sharding/shardspec.py
@@ -38,5 +38,5 @@ def __len__(self):
 class ShardSpec(BaseShardSpec):
     """Annotation of Sharding Specifications"""
 
-    def __init__(self, ranks, real_shape, replicas, immutable=False):
-        self.__init_handle_by_constructor__(_make.ShardSpec, immutable, ranks, real_shape, replicas)
+    def __init__(self, ranks, phy_shape, replicas, immutable=False):
+        self.__init_handle_by_constructor__(_make.ShardSpec, immutable, ranks, phy_shape, replicas)
diff --git a/python/raf/distributed/sharding/utils.py b/python/raf/distributed/sharding/utils.py
index 645f97e7..c6a8711c 100644
--- a/python/raf/distributed/sharding/utils.py
+++ b/python/raf/distributed/sharding/utils.py
@@ -19,18 +19,7 @@
     7: "kTuple",
     8: "kOpaque",
 }
-# TODO: this pattern map is replicated mulitple times in source code
-
-
-def get_dist_devices():
-    """Return all available devices in the cluster as a list of Device Objects.
-
-    Returns
-    -------
-    ret: list
-        List of Device Objects
-    """
-    return dist.get_context().dist_devices
+# TODO: this pattern map is replicated multiple times in source code
 
 
 def always_apply(call: relay.Call):
diff --git a/src/impl/sharding.cc b/src/impl/sharding.cc
index d7a641a2..3c3515ca 100644
--- a/src/impl/sharding.cc
+++ b/src/impl/sharding.cc
@@ -33,12 +33,12 @@ ReplicatedSpec ReplicatedSpec::make(bool immutable) {
   return ReplicatedSpec(n);
 }
 
-ShardSpec ShardSpec::make(bool immutable, Array ranks, Array real_shape,
+ShardSpec ShardSpec::make(bool immutable, Array ranks, Array phy_shape,
                           Array replicas) {
-  auto ndim = real_shape.size();
+  auto ndim = phy_shape.size();
   CHECK_EQ(ndim, replicas.size());
   auto n = make_object();
-  auto real_index = std::vector(ndim);
+  auto phy_index = std::vector(ndim);
   auto logic_index = std::vector(ndim);
   auto logic_shape = std::vector(ndim);
@@ -51,22 +51,22 @@ ShardSpec ShardSpec::make(bool immutable, Array ranks, Array r
   }
 
   for (int64_t i = ndim - 1; i >= 0; --i) {
-    logic_shape[i] = real_shape[i]->value / replicas[i]->value;
-    real_index[i] = rank_idx % real_shape[i]->value;
-    logic_index[i] = real_index[i]->value / replicas[i]->value;
-    rank_idx /= real_shape[i]->value;
+    logic_shape[i] = phy_shape[i]->value / replicas[i]->value;
+    phy_index[i] = rank_idx % phy_shape[i]->value;
+    logic_index[i] = phy_index[i]->value / replicas[i]->value;
+    rank_idx /= phy_shape[i]->value;
   }
 
   n->immutable = immutable;
   n->ranks = std::move(ranks);
   n->replicas = std::move(replicas);
-  n->real_shape = std::move(real_shape);
+  n->phy_shape = std::move(phy_shape);
   n->logic_shape = Array(logic_shape);
   if (rank_idx == -1) {
-    n->real_index = NullValue>();
+    n->phy_index = NullValue>();
     n->logic_index = NullValue>();
   } else {
-    n->real_index = Array(real_index);
+    n->phy_index = Array(phy_index);
     n->logic_index = Array(logic_index);
   }
@@ -93,11 +93,11 @@ void Reshard_R2S(const CallValues& call) {
   const DLTensor* x = args->x;
   std::vector shape(x->shape, x->shape + x->ndim);
   auto spec = Downcast(args->spec);
-  if (spec->real_index.defined()) {
+  if (spec->logic_index.defined()) {
     for (int64_t i = 0; i < x->ndim; ++i) {
-      auto grid_dim_size = spec->real_shape[i]->value;
-      CHECK_EQ(x->shape[i] % grid_dim_size, 0) << "Currently automaic padding is unsupported.";
-      shape[i] /= grid_dim_size;
+      auto shard_dim_size = spec->logic_shape[i]->value;
+      CHECK_EQ(x->shape[i] % shard_dim_size, 0) << "Currently automaic padding is unsupported.";
+      shape[i] /= shard_dim_size;
     }
     call->out = TensorValue::Assemble(/*dev=*/x->device,
                                       /*dtype=*/x->dtype,
@@ -120,12 +120,12 @@ Type Reshard_R2S_Infer(const CallValues& call) {
   Array dshape = data->shape;
   size_t ndim = dshape.size();
   std::vector oshape(ndim);
-  CHECK(spec->real_index.defined());
+  CHECK(spec->logic_index.defined());
   for (int64_t i = 0; i < ndim; ++i) {
-    auto grid_dim_size = spec->real_shape[i]->value;
+    auto shard_dim_size = spec->logic_shape[i]->value;
     auto dim_size = Downcast(dshape[i])->value;
-    CHECK_EQ(dim_size % grid_dim_size, 0) << "Currently automaic padding is unsupported.";
-    oshape[i] = Integer(dim_size / grid_dim_size);
+    CHECK_EQ(dim_size % shard_dim_size, 0) << "Currently automaic padding is unsupported.";
+    oshape[i] = Integer(dim_size / shard_dim_size);
   }
   return TensorType(oshape, data->dtype);
 }
@@ -148,14 +148,14 @@ using tvm::runtime::ObjectRef;
 void PrintAllocTable(const ObjectRef& ref, ReprPrinter* p) {
   /*size_t dev_idx = 0;
   const auto obj = Downcast(ref);
-  const auto num_dim = obj->real_shape.size();
+  const auto num_dim = obj->phy_shape.size();
   static thread_local size_t *indices = new size_t[num_dim];
   std::function _print_alloc_table;
   _print_alloc_table = [&](int depth) {
     if (depth == num_dim) {
       p->stream << (dev_idx != 0 ? " [" : "[");
       for (size_t i = 0; i < num_dim; ++i) {
-        auto num_devices = obj->real_shape[i]->value;
+        auto num_devices = obj->phy_shape[i]->value;
         auto index = std::to_string(indices[i]);
         p->stream << (num_devices == 1 ? ":" : index)
                   << (i != num_dim - 1 ? ", " : "");
       }
       auto dev_info = obj->ranks[dev_idx++].c_str();
       p->stream << "]@" << dev_info;
     } else {
-      auto subgroup_num = obj->real_shape[depth]->value;
+      auto subgroup_num = obj->phy_shape[depth]->value;
       for (size_t i = 0; i < subgroup_num; ++i) {
         indices[depth] = i;
         _print_alloc_table(depth + 1);
@@ -182,12 +182,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) {
       auto r = Downcast(ref);
-      auto ndim = r->real_shape.size();
+      auto ndim = r->logic_shape.size();
       p->stream << "ShardSpec(" << (r->immutable ? "Immut " : "") << "[";
       for (size_t i = 0; i < ndim; ++i) {
-        auto grid_dim_size = r->real_shape[i]->value;
+        auto shard_dim_size = r->logic_shape[i]->value;
         auto subgroup_size = r->replicas[i]->value;
-        p->stream << (grid_dim_size == 1 ? ":" : std::to_string(grid_dim_size))
+        p->stream << (shard_dim_size == 1 ? ":" : std::to_string(shard_dim_size))
                   << (subgroup_size == 1 ? "" : "(x" + std::to_string(subgroup_size) + ")")
                   << (i != ndim - 1 ? ", " : "");
       }
@@ -235,10 +235,10 @@ Attrs ReshardSchema2Attrs(const ShardUnaryArgs* args) {
   const DLTensor* x = args->x;
   std::vector begin(x->ndim);
   std::vector end(x->ndim);
-  CHECK(spec->real_index.defined());
+  CHECK(spec->logic_index.defined());
   for (int i = 0; i < x->ndim; ++i) {
-    auto idx = spec->real_index[i]->value;
-    auto size = spec->real_shape[i]->value;
+    auto idx = spec->logic_index[i]->value;
+    auto size = spec->logic_shape[i]->value;
     begin[i] = Integer((x->shape[i] / size) * idx);
     end[i] = Integer((x->shape[i] / size) * (idx + 1));
   }
@@ -251,7 +251,7 @@ HashKey ReshardHasher(const std::vector& param_types, const Type& y_type,
                       const ShardUnaryArgs* args) {
   HashKey key = GenericHasher(param_types, y_type, nullptr);
   auto spec = Downcast(args->spec);
-  for (auto array : {spec->ranks, spec->real_shape, spec->replicas}) {
+  for (auto array : {spec->ranks, spec->phy_shape, spec->replicas}) {
     for (auto i : array) {
       key << i->value;
     }
diff --git a/tests/python/pass/test_pass_sharding.py b/tests/python/pass/test_pass_sharding.py
index 4151cf59..cbe1bafc 100644
--- a/tests/python/pass/test_pass_sharding.py
+++ b/tests/python/pass/test_pass_sharding.py
@@ -1,6 +1,7 @@
 # pylint: disable=missing-function-docstring, missing-class-docstring, invalid-name, protected-access
 import pytest
 import raf
+import numpy as np
 from raf._core.core_utils import str2dev
 from raf._core.executor import interpret
 from raf.distributed.sharding import (
@@ -11,12 +12,9 @@
     ShardOpCallAttrs,
 )
 from raf._ffi.pass_ import SetShardOpCallAttrs, ToGraphNormalForm, ExpandShardOpCall, InferType
-from raf._ffi.device import Device
 from raf._lib import relay
-from raf.distributed.sharding.utils import get_dist_devices
 from raf.testing import randn
 from raf.hybrid.hybrid import _make_argument, _unwrap
-from raf import distributed as dist
 from tvm.relay.analysis.analysis import post_order_visit
 
 
@@ -31,33 +29,17 @@ def forward(self, x, y):
             return z
 
     model = Model()
-    # m_x, _ = randn((128, 128))
-    # m_y, _ = randn((128, 128))
-    m_x = raf.array([1, 2, 3, 4])
-    m_y = raf.array([0, 0, 0, 0])
+    m_x = raf.array(np.arange(16, dtype="float").reshape((4, 4)))
+    m_y = raf.array(np.zeros(16, dtype="float").reshape((4, 4)))
     record = model._internal(m_x, m_y)
     mod_before = record.mod
     mod_before = InferType()(mod_before)
 
-    def get_global_device_list(dev_type="cuda"):
-        dev_type_id = str2dev(dev_type).device_type
-        dctx = dist.get_context()
-        local_id = 6
-        local_id -= 1
-        dev_array = (
-            [Device(dev_type_id, i) for i in range(1, local_id)]
-            + [dctx.local_device]
-            + [Device(dev_type_id, i) for i in range(local_id, 16)]
-        )
-        return dev_array
-
-    # devices = get_global_device_list()
-    devices = get_dist_devices()
     attrs = ShardOpCallAttrs(
-        TupleShardSpec([ReplicatedSpec(), ReplicatedSpec()]), ShardSpec(devices, [4], [1])
+        TupleShardSpec([ReplicatedSpec(), ReplicatedSpec()]), ShardSpec([3, 2, 1, 0], [2, 2], [1, 2])
     )
-    # a = raf._reshard_r2s(m_x, ShardSpec(devices, [4, 4], [1, 2]))
-    # print(a)
+
+    print(m_x)
     call_list = []
     post_order_visit(
         mod_before["main"].body,
@@ -68,12 +50,28 @@
     mod0 = SetShardOpCallAttrs(attrs_map)(mod_before)
     mod1 = ToGraphNormalForm()(mod0)
     mod2 = ExpandShardOpCall()(mod1)
-    # print(raf._ffi.ir.AsText(mod2))
+    print(raf._ffi.ir.AsText(mod2))
     call = relay.Call(op=mod2["main"], args=[_make_argument(x) for x in (m_x, m_y)])
     result = _unwrap(interpret(call, mod2))
     print(result)
 
 
+def test_reshard_r2s():
+    class Model(raf.Model):
+        def build(self):
+            pass
+
+        @raf.model.trace
+        def forward(self, x, y):
+            z = raf.add(x, y)
+            return z
+
+    n_x = np.arange(16).reshape((4, 4))
+    m_x = raf.array(n_x)
+    m_y = raf._reshard_r2s(m_x, ShardSpec([0, 1, 2, 3], [2, 2], [1, 2]))
+    print(m_x)
+    print(m_y)
+
+
 if __name__ == "__main__":
     # pytest.main([__file__])
     test_ShardOpCallAttrs()
+    # test_reshard_r2s()
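
Note (illustrative sketch, not part of the patch): the renamed fields follow the arithmetic in ShardSpec::make above, where a device's position in `ranks` is decomposed dimension by dimension into a physical grid index (`phy_index`) and a replica-collapsed logical index (`logic_index`), with `logic_shape = phy_shape / replicas`. The plain-Python helper below restates that arithmetic only to make the rename easier to review; the function name `describe_shard_spec` and its early `None` return for ranks outside the spec are assumptions made for this sketch, not APIs in the codebase.

# Sketch: mirrors the per-dimension index math of ShardSpec::make in src/impl/sharding.cc.
def describe_shard_spec(rank, ranks, phy_shape, replicas):
    """Return (logic_shape, phy_index, logic_index) for `rank`, or None if `rank` is absent."""
    assert len(phy_shape) == len(replicas)
    if rank not in ranks:
        # Corresponds to the branch where rank_idx stays -1 and the indices are left undefined.
        return None
    rank_idx = ranks.index(rank)
    ndim = len(phy_shape)
    logic_shape = [0] * ndim
    phy_index = [0] * ndim
    logic_index = [0] * ndim
    # Walk dimensions from innermost to outermost, exactly as the C++ loop does.
    for i in range(ndim - 1, -1, -1):
        logic_shape[i] = phy_shape[i] // replicas[i]
        phy_index[i] = rank_idx % phy_shape[i]
        logic_index[i] = phy_index[i] // replicas[i]
        rank_idx //= phy_shape[i]
    return logic_shape, phy_index, logic_index


# Example with the spec used in test_ShardOpCallAttrs, ShardSpec([3, 2, 1, 0], [2, 2], [1, 2]):
# device rank 3 sits at position 0 in `ranks`, so phy_index == [0, 0] and logic_index == [0, 0],
# while logic_shape == [2, 1] because the second dimension is fully replicated.
print(describe_shard_spec(3, [3, 2, 1, 0], [2, 2], [1, 2]))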