Commit 4a50d9d

extend shardspec family

Tonny-Gu committed Aug 21, 2021
1 parent 38d46bd commit 4a50d9d
Showing 6 changed files with 127 additions and 52 deletions.
101 changes: 66 additions & 35 deletions include/mnm/sharding.h
@@ -13,59 +13,85 @@ namespace sharding {
 using namespace mnm::ir;
 using namespace mnm::value;
 
-/* Sharding Specifications */
-class ShardSpecObj : public Object {
+/* BaseShardSpec */
+class BaseShardSpecObj : public Object {
  public:
   bool immutable;
-  bool replicated;
 
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("immutable", &immutable);
+  }
+  static constexpr const uint32_t _type_index = ir::TypeIndex::kDynamic;
+  static constexpr const char* _type_key = "mnm.sharding.BaseShardSpec";
+  MNM_BASE_OBJECT(BaseShardSpecObj, Object);
+};
+
+class BaseShardSpec : public ObjectRef {
+ public:
+  MNM_OBJECT_REF(BaseShardSpec, ObjectRef, BaseShardSpecObj);
+};
+
+/* ReplicatedSpec */
+class ReplicatedSpecObj final : public BaseShardSpecObj {
+ public:
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("immutable", &immutable);
+  }
+  static constexpr const char* _type_key = "mnm.sharding.ReplicatedSpec";
+  MNM_FINAL_OBJECT(ReplicatedSpecObj, BaseShardSpecObj);
+};
+
+class ReplicatedSpec final : public BaseShardSpec {
+ public:
+  static ReplicatedSpec make(bool immutable);
+  MNM_OBJECT_REF(ReplicatedSpec, BaseShardSpec, ReplicatedSpecObj);
+};
+
+/* ShardSpec */
+class ShardSpecObj final : public BaseShardSpecObj {
+ public:
   Array<Device> assigned_devices;
   Array<Integer> num_devices_on_dim;
   Array<Integer> num_replicas_on_dim;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("immutable", &immutable);
-    v->Visit("replicated", &replicated);
     v->Visit("assigned_devices", &assigned_devices);
     v->Visit("num_devices_on_dim", &num_devices_on_dim);
     v->Visit("num_replicas_on_dim", &num_replicas_on_dim);
   }
 
-  static constexpr const uint32_t _type_index = tvm::TypeIndex::kDynamic;
   static constexpr const char* _type_key = "mnm.sharding.ShardSpec";
 
-  MNM_FINAL_OBJECT(ShardSpecObj, Object);
+  MNM_FINAL_OBJECT(ShardSpecObj, BaseShardSpecObj);
 };
 
-class ShardSpec : public ObjectRef {
+class ShardSpec final : public BaseShardSpec {
  public:
-  static ShardSpec make(bool immutable, bool replicated,
+  static ShardSpec make(bool immutable,
                         Array<Device> assigned_devices,
                         Array<Integer> num_devices_on_dim,
                         Array<Integer> num_replicas_on_dim);
 
-  const void print_alloc_table(std::ostream& ostream = std::cout) const {
+  void printAllocTable(std::ostream& out = std::cout) const {
     size_t dev_idx = 0;
     const auto obj = this->operator->();
     const auto num_dim = obj->num_devices_on_dim.size();
     static thread_local size_t *indices = new size_t[num_dim];
     std::function<void(int)> _print_alloc_table;
     _print_alloc_table = [&](int depth) {
       if (depth == num_dim) {
-        ostream << "[";
+        out << "[";
         for (size_t i = 0; i < num_dim; ++i) {
           auto num_devices = obj->num_devices_on_dim[i]->value;
-          auto num_replicas = obj->num_replicas_on_dim.defined() ?
-                              obj->num_replicas_on_dim[i]->value : 1;
+          auto num_replicas = obj->num_replicas_on_dim[i]->value;
           if (num_devices == 1) {
-            ostream << ":, ";
+            out << ":, ";
           } else {
             auto index = indices[i] / num_replicas;
-            ostream << index << ", ";
+            out << index << ", ";
           }
         }
         auto dev_info = obj->assigned_devices[dev_idx++].c_str();
-        ostream << "\b\b]@" << dev_info << " ";
+        out << "\b\b]@" << dev_info << " ";
       } else {
         auto num_devices = obj->num_devices_on_dim[depth]->value;
         for (size_t i = 0; i < num_devices; ++i) {
@@ -77,32 +103,37 @@ class ShardSpec : public ObjectRef {
     _print_alloc_table(0);
   }
 
-  const char* c_str() const {
-    static thread_local char buf[2048];
-    std::stringstream sstream;
-    const auto obj = this->operator->();
-    sstream.clear();
-    sstream << (obj->immutable ? "Immutable " : "");
-    if (obj->replicated) {
-      sstream << "Replicated ";
-    } else {
-      print_alloc_table(sstream);
-      sstream << "\b";
-    }
-    strncpy(buf, sstream.str().c_str(), 2048);
-    return buf;
-  }
-
-  MNM_OBJECT_REF(ShardSpec, ObjectRef, ShardSpecObj);
+  MNM_OBJECT_REF(ShardSpec, BaseShardSpec, ShardSpecObj);
+};
+
+/* TupleShardSpec */
+class TupleShardSpecObj final : public BaseShardSpecObj {
+ public:
+  Array<BaseShardSpec> tuple_elem;
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("immutable", &immutable);
+    v->Visit("tuple_elem", &tuple_elem);
+  }
+  static constexpr const char* _type_key = "mnm.sharding.TupleShardSpec";
+  MNM_FINAL_OBJECT(TupleShardSpecObj, BaseShardSpecObj);
+};
+
+class TupleShardSpec final : public BaseShardSpec {
+ public:
+  static TupleShardSpec make(bool immutable,
+                             Array<BaseShardSpec> tuple_elem);
+  MNM_OBJECT_REF(TupleShardSpec, BaseShardSpec, TupleShardSpecObj);
 };
 
 struct ShardOpAttrs : public tvm::AttrsNode<ShardOpAttrs> {
-  Array<ShardSpec> shard_out;
+  BaseShardSpec shard_in, shard_out;
   TVM_DECLARE_ATTRS(ShardOpAttrs, "mnm.attrs.ShardOpAttrs") {
-    TVM_ATTR_FIELD(shard_out).set_default(NullValue<Array<ShardSpec> >())
+    TVM_ATTR_FIELD(shard_in).set_default(NullValue<BaseShardSpec>())
+        .describe("Sharding Specifications of inputs");
+    TVM_ATTR_FIELD(shard_out).set_default(NullValue<BaseShardSpec>())
         .describe("Sharding Specifications of outputs");
   }
 };
 
 } // namespace sharding
-} // namespace mnm
+} // namespace mnm
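
Taken together, the header change replaces the single ShardSpec with a small family: BaseShardSpec is the common base carrying the immutable flag, ReplicatedSpec means every device holds the full tensor, ShardSpec pins tensor slices to devices with optional per-dimension replication, and TupleShardSpec carries one spec per tuple element. ShardOpAttrs accordingly stores one BaseShardSpec per direction instead of an Array<ShardSpec>, so any family member can be attached to an op. A minimal Python sketch of the three constructors this commit exposes through the FFI (registered in src/impl/sharding.cc below); the devices list is an assumption, built the way get_dist_device_array() does in the test at the bottom:

from mnm._ffi.sharding import _make

# ReplicatedSpec(immutable): the whole tensor lives on every device.
replicated = _make.ReplicatedSpec(False)

# ShardSpec(immutable, assigned_devices, num_devices_on_dim, num_replicas_on_dim):
# `devices` (hypothetical) would be a list of 4 * 2 = 8 Device objects.
# sharded = _make.ShardSpec(False, devices, [4, 2], [2, 1])

# TupleShardSpec(immutable, tuple_elem): one spec per tuple field.
# grouped = _make.TupleShardSpec(False, [replicated, replicated])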
2 changes: 2 additions & 0 deletions python/mnm/_ffi/sharding/_make/__init__.py
@@ -2,4 +2,6 @@
 # pylint: disable=redefined-builtin,line-too-long
 # pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring
 from __future__ import absolute_import
+from ._internal import ReplicatedSpec
 from ._internal import ShardSpec
+from ._internal import TupleShardSpec
4 changes: 4 additions & 0 deletions python/mnm/_ffi/sharding/_make/_internal.py
@@ -3,4 +3,8 @@
 """Auto generated. Do not touch."""
 from mnm._lib import _APIS
 # Defined in ./src/impl/sharding.cc
+ReplicatedSpec = _APIS.get("mnm.sharding._make.ReplicatedSpec", None)
+# Defined in ./src/impl/sharding.cc
 ShardSpec = _APIS.get("mnm.sharding._make.ShardSpec", None)
+# Defined in ./src/impl/sharding.cc
+TupleShardSpec = _APIS.get("mnm.sharding._make.TupleShardSpec", None)
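
These stubs are plain name lookups in TVM's global packed-function table, so they only resolve if the matching MNM_REGISTER_GLOBAL calls in src/impl/sharding.cc below were compiled in; a one-line sanity check, offered as a sketch:

from mnm._lib import _APIS
# _APIS.get returns None instead of raising when a registration is absent.
assert _APIS.get("mnm.sharding._make.TupleShardSpec", None) is not None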
51 changes: 45 additions & 6 deletions src/impl/sharding.cc
@@ -15,31 +15,70 @@ namespace sharding {
 using namespace mnm::ir;
 using namespace mnm::value;
 
-ShardSpec ShardSpec::make(bool immutable, bool replicated,
+ReplicatedSpec ReplicatedSpec::make(bool immutable) {
+  auto n = make_object<ReplicatedSpecObj>();
+  n->immutable = immutable;
+  return ReplicatedSpec(n);
+}
+
+ShardSpec ShardSpec::make(bool immutable,
                           Array<Device> assigned_devices,
                           Array<Integer> num_devices_on_dim,
                           Array<Integer> num_replicas_on_dim) {
-  ObjectPtr<ShardSpecObj> n = make_object<ShardSpecObj>();
+  auto n = make_object<ShardSpecObj>();
   n->immutable = immutable;
-  n->replicated = replicated;
   n->assigned_devices = std::move(assigned_devices);
   n->num_devices_on_dim = std::move(num_devices_on_dim);
   n->num_replicas_on_dim = std::move(num_replicas_on_dim);
   return ShardSpec(n);
 }
 
+TupleShardSpec TupleShardSpec::make(bool immutable,
+                                    Array<BaseShardSpec> tuple_elem) {
+  auto n = make_object<TupleShardSpecObj>();
+  n->immutable = immutable;
+  n->tuple_elem = tuple_elem;
+  return TupleShardSpec(n);
+}
+
+MNM_REGISTER_GLOBAL("mnm.sharding._make.ReplicatedSpec").set_body_typed(ReplicatedSpec::make);
 MNM_REGISTER_GLOBAL("mnm.sharding._make.ShardSpec").set_body_typed(ShardSpec::make);
+MNM_REGISTER_GLOBAL("mnm.sharding._make.TupleShardSpec").set_body_typed(TupleShardSpec::make);
 
+MNM_REGISTER_OBJECT_NO_REFLECT(BaseShardSpecObj);
+MNM_REGISTER_OBJECT_REFLECT(ReplicatedSpecObj);
 MNM_REGISTER_OBJECT_REFLECT(ShardSpecObj);
+MNM_REGISTER_OBJECT_REFLECT(TupleShardSpecObj);
 
 using tvm::ReprPrinter;
 using tvm::runtime::ObjectRef;
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<ReplicatedSpecObj>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto r = Downcast<ReplicatedSpec>(ref);
+      p->stream << "ReplicatedSpec("
+                << (r->immutable ? "Immut" : "")
+                << ")";
+    });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<ShardSpecObj>([](const ObjectRef& ref, ReprPrinter* p) {
-      auto* n = static_cast<const ShardSpecObj*>(ref.get());
-      p->stream << "ShardSpec(" << GetRef<ShardSpec>(n).c_str() << ")";
+      auto r = Downcast<ShardSpec>(ref);
+      p->stream << "ShardSpec("
+                << (r->immutable ? "Immut " : "");
+      r.printAllocTable(p->stream);
+      p->stream << "\b)";
     });
 
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<TupleShardSpecObj>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto r = Downcast<TupleShardSpec>(ref);
+      p->stream << "TupleShardSpec("
+                << (r->immutable ? "Immut" : "")
+                << ")";
+    });
+
 TVM_REGISTER_NODE_TYPE(ShardOpAttrs);
 
 } // namespace sharding
-} // namespace mnm
+} // namespace mnm
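
The three ReprPrinter dispatches above give each spec a compact textual form: ReplicatedSpec and TupleShardSpec print only their mutability flag, while ShardSpec also renders its allocation table. A hedged sketch of the resulting reprs, assuming the specs are built via the FFI wrappers:

from mnm._ffi.sharding import _make

print(_make.ReplicatedSpec(True))    # ReplicatedSpec(Immut)
print(_make.ReplicatedSpec(False))   # ReplicatedSpec()
tup = _make.TupleShardSpec(False, [_make.ReplicatedSpec(False)])
print(tup)                           # TupleShardSpec() -- elements are not shown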
8 changes: 3 additions & 5 deletions src/pass/init_shard_op_attrs.cc
@@ -24,12 +24,10 @@ class ShardAttrsInstaller : public ExprMutator {
  public:
   Expr VisitExpr_(const CallNode* node) override {
     const Expr& callee = node->op;
-    static auto default_spec = ShardSpec::make(false, true,
-                                               NullValue<Array<Device> >(),
-                                               NullValue<Array<Integer> >(),
-                                               NullValue<Array<Integer> >());
+    static auto default_spec = ReplicatedSpec::make(false);
     static auto default_attrs = make_object<ShardOpAttrs>();
-    default_attrs->shard_out = {default_spec};
+    default_attrs->shard_in = default_spec;
+    default_attrs->shard_out = default_spec;
     if (callee->IsInstance<OpNode>()) {
       return Call(node->op, node->args, Attrs(default_attrs));
     }
13 changes: 7 additions & 6 deletions tests/python/pass/test_pass_init_shard_op_attrs.py
@@ -14,7 +14,7 @@
 
 def get_dist_device_array(dev_type="cuda"):
     if dev_type not in get_device_list():
-        raise RuntimeError("Unsupported Device Type: " + dev_type)
+        raise RuntimeError("Non-existing Device Type: " + dev_type)
     dev_type_id = str2dev(dev_type).device_type
     dctx = dist.get_context()
     dev_array = [Device(dev_type_id, i) for i in range(dctx.size*16)]
@@ -42,12 +42,13 @@ def forward(self, x):
     # mod_before = AutoDiff(record.requires_grads)(mod_before)
     # mod_before = InferType()(mod_before)
     # mod_before = GradientInputSelection()(mod_before)
-    mod = InitShardOpAttrs()(mod_before)
-    func_after = InferType()(mod)["main"]
-    print(func_after.astext())
+    mod = InitShardOpAttrs()(mod_before)["main"]
+    print(mod.astext())
+    #func_after = InferType()(mod)["main"]
+    #print(func_after.astext())
 
 if __name__ == "__main__":
     # pytest.main([__file__])
     test_shardOpAttrs()
-    # shardspec = ShardSpec(False, False, get_dist_device_array(), [4, 2], [2, 2])
-    # print(shardspec)
+    # shardspec = ShardSpec(False, get_dist_device_array(), [8, 1], [2, 1])
+    # print(shardspec)

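For reference, a sketch of the allocation table the commented-out ShardSpec above would print, assuming the recursion elided from printAllocTable enumerates devices in row-major order and that a Device renders like cuda(0). With num_devices_on_dim = [8, 1] and num_replicas_on_dim = [2, 1], consecutive device pairs replicate each row slice, and the single-device dimension prints ':':

# shardspec = ShardSpec(False, get_dist_device_array(), [8, 1], [2, 1])
# print(shardspec)  # expected shape of the output (device names illustrative):
# ShardSpec([0, :]@cuda(0) [0, :]@cuda(1) [1, :]@cuda(2) [1, :]@cuda(3)
#           [2, :]@cuda(4) [2, :]@cuda(5) [3, :]@cuda(6) [3, :]@cuda(7))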