From 4a50d9d72e568318be53c6abadf387d53073fb75 Mon Sep 17 00:00:00 2001
From: Tonny-Gu
Date: Sat, 21 Aug 2021 08:04:06 +0000
Subject: [PATCH] extend shardspec family

---
 include/mnm/sharding.h                      | 101 ++++++++++++------
 python/mnm/_ffi/sharding/_make/__init__.py  |   2 +
 python/mnm/_ffi/sharding/_make/_internal.py |   4 +
 src/impl/sharding.cc                        |  51 +++++++--
 src/pass/init_shard_op_attrs.cc             |   8 +-
 .../pass/test_pass_init_shard_op_attrs.py   |  13 +--
 6 files changed, 127 insertions(+), 52 deletions(-)

diff --git a/include/mnm/sharding.h b/include/mnm/sharding.h
index f52205b3..2fb886be 100644
--- a/include/mnm/sharding.h
+++ b/include/mnm/sharding.h
@@ -13,38 +13,65 @@ namespace sharding {
 using namespace mnm::ir;
 using namespace mnm::value;
 
-/* Sharding Specifications */
-class ShardSpecObj : public Object {
+/* BaseShardSpec */
+class BaseShardSpecObj : public Object {
  public:
   bool immutable;
-  bool replicated;
-
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("immutable", &immutable);
+  }
+  static constexpr const uint32_t _type_index = ir::TypeIndex::kDynamic;
+  static constexpr const char* _type_key = "mnm.sharding.BaseShardSpec";
+  MNM_BASE_OBJECT(BaseShardSpecObj, Object);
+};
+
+class BaseShardSpec : public ObjectRef {
+ public:
+  MNM_OBJECT_REF(BaseShardSpec, ObjectRef, BaseShardSpecObj);
+};
+
+/* ReplicatedSpec */
+class ReplicatedSpecObj final : public BaseShardSpecObj {
+ public:
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("immutable", &immutable);
+  }
+  static constexpr const char* _type_key = "mnm.sharding.ReplicatedSpec";
+  MNM_FINAL_OBJECT(ReplicatedSpecObj, BaseShardSpecObj);
+};
+
+class ReplicatedSpec final : public BaseShardSpec {
+ public:
+  static ReplicatedSpec make(bool immutable);
+  MNM_OBJECT_REF(ReplicatedSpec, BaseShardSpec, ReplicatedSpecObj);
+};
+
+/* ShardSpec */
+class ShardSpecObj final : public BaseShardSpecObj {
+ public:
   Array<Device> assigned_devices;
   Array<Integer> num_devices_on_dim;
   Array<Integer> num_replicas_on_dim;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("immutable", &immutable);
-    v->Visit("replicated", &replicated);
     v->Visit("assigned_devices", &assigned_devices);
     v->Visit("num_devices_on_dim", &num_devices_on_dim);
     v->Visit("num_replicas_on_dim", &num_replicas_on_dim);
   }
-  static constexpr const uint32_t _type_index = tvm::TypeIndex::kDynamic;
   static constexpr const char* _type_key = "mnm.sharding.ShardSpec";
-
-  MNM_FINAL_OBJECT(ShardSpecObj, Object);
+  MNM_FINAL_OBJECT(ShardSpecObj, BaseShardSpecObj);
 };
 
-class ShardSpec : public ObjectRef {
+class ShardSpec final : public BaseShardSpec {
  public:
-  static ShardSpec make(bool immutable, bool replicated,
+  static ShardSpec make(bool immutable,
                         Array<Device> assigned_devices,
                         Array<Integer> num_devices_on_dim,
                         Array<Integer> num_replicas_on_dim);
 
-  const void print_alloc_table(std::ostream& ostream = std::cout) const {
+  void printAllocTable(std::ostream& out = std::cout) const {
     size_t dev_idx = 0;
     const auto obj = this->operator->();
     const auto num_dim = obj->num_devices_on_dim.size();
@@ -52,20 +79,19 @@ class ShardSpec : public ObjectRef {
     std::function<void(int)> _print_alloc_table;
     _print_alloc_table = [&](int depth) {
       if (depth == num_dim) {
-        ostream << "[";
+        out << "[";
         for (size_t i = 0; i < num_dim; ++i) {
           auto num_devices = obj->num_devices_on_dim[i]->value;
-          auto num_replicas = obj->num_replicas_on_dim.defined() ?
-                              obj->num_replicas_on_dim[i]->value : 1;
+          auto num_replicas = obj->num_replicas_on_dim[i]->value;
           if (num_devices == 1) {
-            ostream << ":, ";
+            out << ":, ";
           } else {
             auto index = indices[i] / num_replicas;
-            ostream << index << ", ";
+            out << index << ", ";
           }
         }
         auto dev_info = obj->assigned_devices[dev_idx++].c_str();
-        ostream << "\b\b]@" << dev_info << " ";
+        out << "\b\b]@" << dev_info << " ";
       } else {
         auto num_devices = obj->num_devices_on_dim[depth]->value;
         for (size_t i = 0; i < num_devices; ++i) {
@@ -77,32 +103,37 @@ class ShardSpec : public ObjectRef {
     _print_alloc_table(0);
   }
 
-  const char* c_str() const {
-    static thread_local char buf[2048];
-    std::stringstream sstream;
-    const auto obj = this->operator->();
-    sstream.clear();
-    sstream << (obj->immutable ? "Immutable " : "");
-    if (obj->replicated) {
-      sstream << "Replicated ";
-    } else {
-      print_alloc_table(sstream);
-      sstream << "\b";
-    }
-    strncpy(buf, sstream.str().c_str(), 2048);
-    return buf;
+  MNM_OBJECT_REF(ShardSpec, BaseShardSpec, ShardSpecObj);
+};
+
+/* TupleShardSpec */
+class TupleShardSpecObj final : public BaseShardSpecObj {
+ public:
+  Array<BaseShardSpec> tuple_elem;
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("immutable", &immutable);
+    v->Visit("tuple_elem", &tuple_elem);
   }
+  static constexpr const char* _type_key = "mnm.sharding.TupleShardSpec";
+  MNM_FINAL_OBJECT(TupleShardSpecObj, BaseShardSpecObj);
+};
 
-  MNM_OBJECT_REF(ShardSpec, ObjectRef, ShardSpecObj);
+class TupleShardSpec final : public BaseShardSpec {
+ public:
+  static TupleShardSpec make(bool immutable,
+                             Array<BaseShardSpec> tuple_elem);
+  MNM_OBJECT_REF(TupleShardSpec, BaseShardSpec, TupleShardSpecObj);
 };
 
 struct ShardOpAttrs : public tvm::AttrsNode<ShardOpAttrs> {
-  Array<ShardSpec> shard_out;
+  BaseShardSpec shard_in, shard_out;
   TVM_DECLARE_ATTRS(ShardOpAttrs, "mnm.attrs.ShardOpAttrs") {
-    TVM_ATTR_FIELD(shard_out).set_default(NullValue<Array<ShardSpec>>())
+    TVM_ATTR_FIELD(shard_in).set_default(NullValue<BaseShardSpec>())
+        .describe("Sharding Specifications of inputs");
+    TVM_ATTR_FIELD(shard_out).set_default(NullValue<BaseShardSpec>())
         .describe("Sharding Specifications of outputs");
   }
 };
 
 }  // namespace sharding
-}  // namespace mnm
\ No newline at end of file
+}  // namespace mnm
diff --git a/python/mnm/_ffi/sharding/_make/__init__.py b/python/mnm/_ffi/sharding/_make/__init__.py
index e9dec18d..c7311602 100644
--- a/python/mnm/_ffi/sharding/_make/__init__.py
+++ b/python/mnm/_ffi/sharding/_make/__init__.py
@@ -2,4 +2,6 @@
 # pylint: disable=redefined-builtin,line-too-long
 # pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring
 from __future__ import absolute_import
+from ._internal import ReplicatedSpec
 from ._internal import ShardSpec
+from ._internal import TupleShardSpec
diff --git a/python/mnm/_ffi/sharding/_make/_internal.py b/python/mnm/_ffi/sharding/_make/_internal.py
index 15e2f9d5..61b2d146 100644
--- a/python/mnm/_ffi/sharding/_make/_internal.py
+++ b/python/mnm/_ffi/sharding/_make/_internal.py
@@ -3,4 +3,8 @@
 """Auto generated. Do not touch."""
 from mnm._lib import _APIS
 # Defined in ./src/impl/sharding.cc
+ReplicatedSpec = _APIS.get("mnm.sharding._make.ReplicatedSpec", None)
+# Defined in ./src/impl/sharding.cc
 ShardSpec = _APIS.get("mnm.sharding._make.ShardSpec", None)
+# Defined in ./src/impl/sharding.cc
+TupleShardSpec = _APIS.get("mnm.sharding._make.TupleShardSpec", None)
diff --git a/src/impl/sharding.cc b/src/impl/sharding.cc
index d51efac8..efd2d911 100644
--- a/src/impl/sharding.cc
+++ b/src/impl/sharding.cc
@@ -15,31 +15,70 @@ namespace sharding {
 using namespace mnm::ir;
 using namespace mnm::value;
 
-ShardSpec ShardSpec::make(bool immutable, bool replicated,
+ReplicatedSpec ReplicatedSpec::make(bool immutable) {
+  auto n = make_object<ReplicatedSpecObj>();
+  n->immutable = immutable;
+  return ReplicatedSpec(n);
+}
+
+ShardSpec ShardSpec::make(bool immutable,
                          Array<Device> assigned_devices,
                          Array<Integer> num_devices_on_dim,
                          Array<Integer> num_replicas_on_dim) {
-  ObjectPtr<ShardSpecObj> n = make_object<ShardSpecObj>();
+  auto n = make_object<ShardSpecObj>();
   n->immutable = immutable;
-  n->replicated = replicated;
   n->assigned_devices = std::move(assigned_devices);
   n->num_devices_on_dim = std::move(num_devices_on_dim);
   n->num_replicas_on_dim = std::move(num_replicas_on_dim);
   return ShardSpec(n);
 }
 
+TupleShardSpec TupleShardSpec::make(bool immutable,
+                                    Array<BaseShardSpec> tuple_elem) {
+  auto n = make_object<TupleShardSpecObj>();
+  n->immutable = immutable;
+  n->tuple_elem = tuple_elem;
+  return TupleShardSpec(n);
+}
+
+MNM_REGISTER_GLOBAL("mnm.sharding._make.ReplicatedSpec").set_body_typed(ReplicatedSpec::make);
 MNM_REGISTER_GLOBAL("mnm.sharding._make.ShardSpec").set_body_typed(ShardSpec::make);
+MNM_REGISTER_GLOBAL("mnm.sharding._make.TupleShardSpec").set_body_typed(TupleShardSpec::make);
+
+MNM_REGISTER_OBJECT_NO_REFLECT(BaseShardSpecObj);
+MNM_REGISTER_OBJECT_REFLECT(ReplicatedSpecObj);
 MNM_REGISTER_OBJECT_REFLECT(ShardSpecObj);
+MNM_REGISTER_OBJECT_REFLECT(TupleShardSpecObj);
 
 using tvm::ReprPrinter;
 using tvm::runtime::ObjectRef;
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<ReplicatedSpecObj>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto r = Downcast<ReplicatedSpec>(ref);
+      p->stream << "ReplicatedSpec("
+                << (r->immutable ? "Immut" : "")
+                << ")";
+    });
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<ShardSpecObj>([](const ObjectRef& ref, ReprPrinter* p) {
-      auto* n = static_cast<const ShardSpecObj*>(ref.get());
-      p->stream << "ShardSpec(" << GetRef<ShardSpec>(n).c_str() << ")";
+      auto r = Downcast<ShardSpec>(ref);
+      p->stream << "ShardSpec("
+                << (r->immutable ? "Immut " : "");
+      r.printAllocTable(p->stream);
+      p->stream << "\b)";
+    });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<TupleShardSpecObj>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto r = Downcast<TupleShardSpec>(ref);
+      p->stream << "TupleShardSpec("
+                << (r->immutable ? "Immut" : "")
+                << ")";
     });
 
 TVM_REGISTER_NODE_TYPE(ShardOpAttrs);
 
 }  // namespace sharding
-}  // namespace mnm
\ No newline at end of file
+}  // namespace mnm
diff --git a/src/pass/init_shard_op_attrs.cc b/src/pass/init_shard_op_attrs.cc
index 350929bf..5abb8eeb 100644
--- a/src/pass/init_shard_op_attrs.cc
+++ b/src/pass/init_shard_op_attrs.cc
@@ -24,12 +24,10 @@ class ShardAttrsInstaller : public ExprMutator {
  public:
   Expr VisitExpr_(const CallNode* node) override {
     const Expr& callee = node->op;
-    static auto default_spec = ShardSpec::make(false, true,
-                                               NullValue<Array<Device>>(),
-                                               NullValue<Array<Integer>>(),
-                                               NullValue<Array<Integer>>());
+    static auto default_spec = ReplicatedSpec::make(false);
     static auto default_attrs = make_object<ShardOpAttrs>();
-    default_attrs->shard_out = {default_spec};
+    default_attrs->shard_in = default_spec;
+    default_attrs->shard_out = default_spec;
     if (callee->IsInstance<OpNode>()) {
       return Call(node->op, node->args, Attrs(default_attrs));
     }
diff --git a/tests/python/pass/test_pass_init_shard_op_attrs.py b/tests/python/pass/test_pass_init_shard_op_attrs.py
index 7c39a410..0bab116f 100644
--- a/tests/python/pass/test_pass_init_shard_op_attrs.py
+++ b/tests/python/pass/test_pass_init_shard_op_attrs.py
@@ -14,7 +14,7 @@
 
 def get_dist_device_array(dev_type="cuda"):
     if dev_type not in get_device_list():
-        raise RuntimeError("Unsupported Device Type: " + dev_type)
+        raise RuntimeError("Nonexistent Device Type: " + dev_type)
     dev_type_id = str2dev(dev_type).device_type
     dctx = dist.get_context()
     dev_array = [Device(dev_type_id, i) for i in range(dctx.size*16)]
@@ -42,12 +42,13 @@ def forward(self, x):
     # mod_before = AutoDiff(record.requires_grads)(mod_before)
     # mod_before = InferType()(mod_before)
     # mod_before = GradientInputSelection()(mod_before)
-    mod = InitShardOpAttrs()(mod_before)
-    func_after = InferType()(mod)["main"]
-    print(func_after.astext())
+    mod = InitShardOpAttrs()(mod_before)["main"]
+    print(mod.astext())
+    # func_after = InferType()(mod)["main"]
+    # print(func_after.astext())
 
 if __name__ == "__main__":
     # pytest.main([__file__])
     test_shardOpAttrs()
-    # shardspec = ShardSpec(False, False, get_dist_device_array(), [4, 2], [2, 2])
-    # print(shardspec)
+    # shardspec = ShardSpec(False, get_dist_device_array(), [8, 1], [2, 1])
+    # print(shardspec)
\ No newline at end of file