Skip to content

Commit

Permalink
impl
Browse files Browse the repository at this point in the history
  • Loading branch information
Tonny-Gu committed Jun 28, 2022
1 parent 5394633 commit 8e9e83e
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 12 deletions.
16 changes: 9 additions & 7 deletions include/raf/sharding.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@ class BaseShardSpec : public Value {
RAF_OBJECT_REF(BaseShardSpec, Value, BaseShardSpecObj);
};

class ShardSpecObj : public BaseShardSpecObj {
class ShardSpecObj final : public BaseShardSpecObj {
public:
Integer ndim_;
Integer nshard_;
Integer ngroup_;
bool mutable_;
int64_t ndim_;
int64_t nshard_;
int64_t ngroup_;
Array<Integer> ranks;
Array<Integer> logic_shape;
Array<Integer> logic_index_;
Expand All @@ -40,6 +41,7 @@ class ShardSpecObj : public BaseShardSpecObj {
Array<Integer> subgroup_index_;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("mutable", &mutable_);
v->Visit("ndim", &ndim_);
v->Visit("nshard", &nshard_);
v->Visit("ngroup", &ngroup_);
Expand All @@ -56,9 +58,9 @@ class ShardSpecObj : public BaseShardSpecObj {
RAF_FINAL_OBJECT(ShardSpecObj, BaseShardSpecObj);
};

class ShardSpec : public BaseShardSpec {
class ShardSpec final : public BaseShardSpec {
public:
static ShardSpec make(Array<Integer> ranks, Array<Integer> phy_shape, Array<Integer> subgroup_shape);
static ShardSpec make(Array<Integer> ranks, Array<Integer> phy_shape, Array<Integer> subgroup_shape, bool mutable_);
static int64_t GetRankIdx(Array<Integer> ranks);
RAF_OBJECT_REF(ShardSpec, BaseShardSpec, ShardSpecObj);
};
Expand All @@ -70,7 +72,7 @@ class UnsetShardSpecObj final : public BaseShardSpecObj {
RAF_FINAL_OBJECT(UnsetShardSpecObj, BaseShardSpecObj);
};

class UnsetShardSpec : public BaseShardSpec {
class UnsetShardSpec final : public BaseShardSpec {
public:
static UnsetShardSpec make() {
auto n = make_object<UnsetShardSpecObj>();
Expand Down
4 changes: 2 additions & 2 deletions python/raf/distributed/sharding/shardspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ class BaseShardSpec(Value):
class ShardSpec(BaseShardSpec):
"""Sharding Specifications"""

def __init__(self, ranks, phy_shape, subgroup_shape):
self.__init_handle_by_constructor__(_make.ShardSpec, ranks, phy_shape, subgroup_shape)
def __init__(self, ranks, phy_shape, subgroup_shape, mutable):
self.__init_handle_by_constructor__(_make.ShardSpec, ranks, phy_shape, subgroup_shape, mutable)


@register_node("raf.sharding.UnsetShardSpec")
Expand Down
7 changes: 4 additions & 3 deletions src/impl/sharding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ int64_t ShardSpec::GetRankIdx(Array<Integer> ranks) {
return -1;
}

ShardSpec ShardSpec::make(Array<Integer> ranks, Array<Integer> phy_shape, Array<Integer> subgroup_shape) {
ShardSpec ShardSpec::make(Array<Integer> ranks, Array<Integer> phy_shape, Array<Integer> subgroup_shape, bool mutable_) {
CHECK_EQ(phy_shape.size(), subgroup_shape.size());
auto ndim = phy_shape.size();
auto subgroup_index = std::vector<Integer>(ndim);
Expand All @@ -61,6 +61,7 @@ ShardSpec ShardSpec::make(Array<Integer> ranks, Array<Integer> phy_shape, Array<
}

auto spec = make_object<ShardSpecObj>();
spec->mutable_ = mutable_;
spec->ndim_ = ndim;
spec->nshard_ = nshard;
spec->ngroup_ = ngroup;
Expand Down Expand Up @@ -148,7 +149,7 @@ using tvm::runtime::ObjectRef;
std::string PrintAllocTable(const ObjectRef& ref) {
size_t dev_idx = 0;
const auto spec = Downcast<ShardSpec>(ref);
const auto ndim = spec->ndim_->value;
const auto ndim = spec->ndim_;

std::stringstream ss;

Expand Down Expand Up @@ -186,7 +187,7 @@ RAF_REGISTER_GLOBAL("raf.sharding.PrintAllocTable").set_body_typed(PrintAllocTab
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ShardSpecObj>([](const ObjectRef& ref, ReprPrinter* p) {
auto r = Downcast<ShardSpec>(ref);
auto ndim = r->ndim_->value;
auto ndim = r->ndim_;
if (r->nshard_ == 1) {
p->stream << "ShardSpec(Mirrored)";
} else {
Expand Down

0 comments on commit 8e9e83e

Please sign in to comment.