From 8e9e83e7e4d119d9013dfbc24a9a2c30c1610718 Mon Sep 17 00:00:00 2001 From: NekoDaemon Date: Tue, 28 Jun 2022 08:31:11 +0000 Subject: [PATCH] impl --- include/raf/sharding.h | 16 +++++++++------- python/raf/distributed/sharding/shardspec.py | 4 ++-- src/impl/sharding.cc | 7 ++++--- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/include/raf/sharding.h b/include/raf/sharding.h index 7d3c1753..a9be6be3 100644 --- a/include/raf/sharding.h +++ b/include/raf/sharding.h @@ -26,11 +26,12 @@ class BaseShardSpec : public Value { RAF_OBJECT_REF(BaseShardSpec, Value, BaseShardSpecObj); }; -class ShardSpecObj : public BaseShardSpecObj { +class ShardSpecObj final : public BaseShardSpecObj { public: - Integer ndim_; - Integer nshard_; - Integer ngroup_; + bool mutable_; + int64_t ndim_; + int64_t nshard_; + int64_t ngroup_; Array ranks; Array logic_shape; Array logic_index_; @@ -40,6 +41,7 @@ class ShardSpecObj : public BaseShardSpecObj { Array subgroup_index_; void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("mutable", &mutable_); v->Visit("ndim", &ndim_); v->Visit("nshard", &nshard_); v->Visit("ngroup", &ngroup_); @@ -56,9 +58,9 @@ class ShardSpecObj : public BaseShardSpecObj { RAF_FINAL_OBJECT(ShardSpecObj, BaseShardSpecObj); }; -class ShardSpec : public BaseShardSpec { +class ShardSpec final : public BaseShardSpec { public: - static ShardSpec make(Array ranks, Array phy_shape, Array subgroup_shape); + static ShardSpec make(Array ranks, Array phy_shape, Array subgroup_shape, bool mutable_); static int64_t GetRankIdx(Array ranks); RAF_OBJECT_REF(ShardSpec, BaseShardSpec, ShardSpecObj); }; @@ -70,7 +72,7 @@ class UnsetShardSpecObj final : public BaseShardSpecObj { RAF_FINAL_OBJECT(UnsetShardSpecObj, BaseShardSpecObj); }; -class UnsetShardSpec : public BaseShardSpec { +class UnsetShardSpec final : public BaseShardSpec { public: static UnsetShardSpec make() { auto n = make_object(); diff --git a/python/raf/distributed/sharding/shardspec.py b/python/raf/distributed/sharding/shardspec.py index 0b064d50..314543ae 100644 --- a/python/raf/distributed/sharding/shardspec.py +++ b/python/raf/distributed/sharding/shardspec.py @@ -13,8 +13,8 @@ class BaseShardSpec(Value): class ShardSpec(BaseShardSpec): """Sharding Specifications""" - def __init__(self, ranks, phy_shape, subgroup_shape): - self.__init_handle_by_constructor__(_make.ShardSpec, ranks, phy_shape, subgroup_shape) + def __init__(self, ranks, phy_shape, subgroup_shape, mutable): + self.__init_handle_by_constructor__(_make.ShardSpec, ranks, phy_shape, subgroup_shape, mutable) @register_node("raf.sharding.UnsetShardSpec") diff --git a/src/impl/sharding.cc b/src/impl/sharding.cc index a7d550c8..8c5ba00f 100644 --- a/src/impl/sharding.cc +++ b/src/impl/sharding.cc @@ -37,7 +37,7 @@ int64_t ShardSpec::GetRankIdx(Array ranks) { return -1; } -ShardSpec ShardSpec::make(Array ranks, Array phy_shape, Array subgroup_shape) { +ShardSpec ShardSpec::make(Array ranks, Array phy_shape, Array subgroup_shape, bool mutable_) { CHECK_EQ(phy_shape.size(), subgroup_shape.size()); auto ndim = phy_shape.size(); auto subgroup_index = std::vector(ndim); @@ -61,6 +61,7 @@ ShardSpec ShardSpec::make(Array ranks, Array phy_shape, Array< } auto spec = make_object(); + spec->mutable_ = mutable_; spec->ndim_ = ndim; spec->nshard_ = nshard; spec->ngroup_ = ngroup; @@ -148,7 +149,7 @@ using tvm::runtime::ObjectRef; std::string PrintAllocTable(const ObjectRef& ref) { size_t dev_idx = 0; const auto spec = Downcast(ref); - const auto ndim = spec->ndim_->value; + const auto ndim = spec->ndim_; std::stringstream ss; @@ -186,7 +187,7 @@ RAF_REGISTER_GLOBAL("raf.sharding.PrintAllocTable").set_body_typed(PrintAllocTab TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { auto r = Downcast(ref); - auto ndim = r->ndim_->value; + auto ndim = r->ndim_; if (r->nshard_ == 1) { p->stream << "ShardSpec(Mirrored)"; } else {