lint
Tonny-Gu committed Mar 16, 2022
1 parent 8f590b8 commit bd00a6c
Showing 9 changed files with 120 additions and 85 deletions.
23 changes: 11 additions & 12 deletions include/raf/sharding.h
@@ -52,8 +52,8 @@ class ShardSpecObj final : public BaseShardSpecObj {
Array<Device> assigned_devices;
Array<Integer> grid_shape;
Array<Integer> subgroup_sizes;
Array<Integer> _subgroup_idx; // consider put it in ShardLocalContext
Array<Integer> _subgroup_idx; // consider put it in ShardLocalContext

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("immutable", &immutable);
v->Visit("assigned_devices", &assigned_devices);
@@ -68,10 +68,8 @@ class ShardSpecObj final : public BaseShardSpecObj {

class ShardSpec final : public BaseShardSpec {
public:
static ShardSpec make(bool immutable,
Array<Device> assigned_devices,
Array<Integer> partition_shape,
Array<Integer> subgroup_sizes);
static ShardSpec make(bool immutable, Array<Device> assigned_devices,
Array<Integer> partition_shape, Array<Integer> subgroup_sizes);
RAF_OBJECT_REF(ShardSpec, BaseShardSpec, ShardSpecObj);
};

@@ -89,19 +87,20 @@ class TupleShardSpecObj final : public BaseShardSpecObj {

class TupleShardSpec final : public BaseShardSpec {
public:
static TupleShardSpec make(bool immutable,
Array<BaseShardSpec> tuple_elem);
static TupleShardSpec make(bool immutable, Array<BaseShardSpec> tuple_elem);
RAF_OBJECT_REF(TupleShardSpec, BaseShardSpec, TupleShardSpecObj);
};

struct ShardOpCallAttrs : public tvm::AttrsNode<ShardOpCallAttrs> {
static Attrs make(BaseShardSpec shard_in, BaseShardSpec shard_out);
BaseShardSpec shard_in, shard_out;
TVM_DECLARE_ATTRS(ShardOpCallAttrs, "raf.attrs.ShardOpCallAttrs") {
TVM_ATTR_FIELD(shard_in).set_default(NullValue<BaseShardSpec>())
.describe("Sharding Specifications of inputs");
TVM_ATTR_FIELD(shard_out).set_default(NullValue<BaseShardSpec>())
.describe("Sharding Specifications of outputs");
TVM_ATTR_FIELD(shard_in)
.set_default(NullValue<BaseShardSpec>())
.describe("Sharding Specifications of inputs");
TVM_ATTR_FIELD(shard_out)
.set_default(NullValue<BaseShardSpec>())
.describe("Sharding Specifications of outputs");
}
};

4 changes: 3 additions & 1 deletion python/raf/_tvm_op/sharding.py
@@ -15,10 +15,12 @@

_topi = _tvm.topi # pylint: disable=invalid-name,no-member


@register_compute("raf.op.tvm._reshard_r2s")
def compute_reshard_r2s(attr, inputs, output_type):
# pylint: disable=unused-argument
x = inputs[0]
return [_topi.strided_slice(x, attr.begin, attr.end)]

_reg.register_injective_schedule("raf.op.tvm._reshard_r2s")

_reg.register_injective_schedule("raf.op.tvm._reshard_r2s")
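
A note on the compute registered above: `_reshard_r2s` keeps only the local shard of the full tensor via a strided slice, and the output extent along dimension i is shape[i] // grid_shape[i] (matching the divisibility check in Reshard_R2S further down in src/impl/sharding.cc). The NumPy sketch below is illustrative only, assuming each device owns the contiguous block addressed by its grid coordinate; the actual begin/end computation is not part of this diff.

import numpy as np

def local_shard(x, grid_shape, grid_coord):
    """Return the block of x owned by the device at grid_coord (illustration only)."""
    slices = []
    for size, parts, idx in zip(x.shape, grid_shape, grid_coord):
        # Mirrors CHECK_EQ(x->shape[i] % grid_dim_size, 0): no automatic padding.
        assert size % parts == 0
        block = size // parts
        slices.append(slice(idx * block, (idx + 1) * block))
    return x[tuple(slices)]

x = np.arange(16).reshape(4, 4)
print(local_shard(x, grid_shape=(2, 2), grid_coord=(0, 1)))  # rows 0-1, cols 2-3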
9 changes: 7 additions & 2 deletions python/raf/distributed/sharding/__init__.py
@@ -1,5 +1,10 @@
"""RAF sharding system"""
from raf._ffi.sharding._make import ShardOpCallAttrs
from .shardspec import BaseShardSpec, ReplicatedSpec, TupleShardSpec, ShardSpec
from .utils import get_dist_devices, expand_when, always_apply, register_expansion_pattern, \
extract_shardOpCall
from .utils import (
get_dist_devices,
expand_when,
always_apply,
register_expansion_pattern,
extract_shardOpCall,
)
23 changes: 12 additions & 11 deletions python/raf/distributed/sharding/shardspec.py
@@ -5,20 +5,25 @@
from raf._lib import Object
from raf._core.value import Value


@register_node("raf.sharding.BaseShardSpec")
class BaseShardSpec(Value):
"""Base type of Sharding Specifications"""


@register_node("raf.sharding.ReplicatedSpec")
class ReplicatedSpec(BaseShardSpec):
"""Annotation denoting every node has a copy of the data"""

def __init__(self, immutable=False):
self.__init_handle_by_constructor__(_make.ReplicatedSpec, immutable)


@register_node("raf.sharding.TupleShardSpec")
class TupleShardSpec(BaseShardSpec):
"""Annotation of a tuple that will usually be used
when having multiple input or output tensors"""
when having multiple input or output tensors"""

def __init__(self, tuple_elem, immutable=False):
assert isinstance(tuple_elem, list)
self.__init_handle_by_constructor__(_make.TupleShardSpec, immutable, tuple_elem)
@@ -29,16 +34,12 @@ def __getitem__(self, index: int):
def __len__(self):
return len(self.tuple_elem)


@register_node("raf.sharding.ShardSpec")
class ShardSpec(BaseShardSpec):
"""Generic annotation of Sharding Specifications"""
def __init__(self,
devices_in_grid,
grid_shape,
subgroup_sizes,
immutable=False):
self.__init_handle_by_constructor__(_make.ShardSpec,
immutable,
devices_in_grid,
grid_shape,
subgroup_sizes)

def __init__(self, devices_in_grid, grid_shape, subgroup_sizes, immutable=False):
self.__init_handle_by_constructor__(
_make.ShardSpec, immutable, devices_in_grid, grid_shape, subgroup_sizes
)
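
For reference, the reformatted constructors above compose as in the following sketch. It assumes a distributed context is already initialized (get_dist_devices reads it) and a 4-device, 2x2 partition; the variable names are illustrative.

from raf.distributed.sharding import (
    ReplicatedSpec,
    ShardSpec,
    TupleShardSpec,
    get_dist_devices,
)

devices = get_dist_devices()  # every device in the cluster, as Device objects
# Partition a 2-D tensor over a 2x2 device grid, one device per subgroup.
sharded = ShardSpec(devices, [2, 2], [1, 1])
replicated = ReplicatedSpec()  # every device keeps a full copy
both = TupleShardSpec([sharded, replicated])  # e.g. one sharded and one replicated operand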
46 changes: 35 additions & 11 deletions python/raf/distributed/sharding/utils.py
@@ -19,8 +19,9 @@
7: "kTuple",
8: "kOpaque",
}
#TODO: this pattern map is replicated multiple times in source code

# TODO: this pattern map is replicated multiple times in source code


def get_dist_devices():
"""Return all available devices in the cluster as a list of Device Objects.
@@ -31,10 +32,12 @@ def get_dist_devices():
"""
return dist.get_context().dist_devices


def always_apply(call: relay.Call):
"""Always apply this pattern to expand op call"""
return True


def expand_when(cond, priority=1):
"""Specify the priority and the condition when this expansion pattern should be used.
@@ -59,8 +62,10 @@ def decorator(pyfunc):
expand_when.patterns[op].put((-priority, expand_when.counter, cond, pyfunc))
expand_when.counter += 1
return pyfunc

return decorator


def register_expansion_pattern(op_name):
"""Register an expansion pattern that converts a full-sized op into a partitioned-size op
@@ -76,15 +81,19 @@ def decorator(pyfunc):
@functools.wraps(pyfunc)
def new_pyfunc(call: relay.Call):
return pyfunc(call)

setattr(new_pyfunc, "op_names", op_names)
return new_pyfunc

return decorator


def extract_shardOpCall(call):
"""Return some frequently-used object attributes as a tuple"""
assert isinstance(call, relay.Call)
return (call.op, call.args, call.attrs.shard_in, call.attrs.shard_out)


@_register_func("raf.sharding._match_expansion_pattern")
def expand_shardOpCall(call: relay.Call):
"""Match an eligible expansion pattern and return expanded IR expr"""
@@ -95,47 +104,62 @@ def expand_shardOpCall(call: relay.Call):
break
return irgen(call)

@expand_when(lambda call: isinstance(call.attrs.shard_in, ReplicatedSpec) and \
isinstance(call.attrs.shard_out, ShardSpec), priority=1)

@expand_when(
lambda call: isinstance(call.attrs.shard_in, ReplicatedSpec)
and isinstance(call.attrs.shard_out, ShardSpec),
priority=1,
)
@register_expansion_pattern("raf.op._reshard")
def reshard_replicated_to_sharded(call: relay.Call):
"""_reshard -> _reshard_r2s (strided_slice)"""
_, args, _, sout = extract_shardOpCall(call)
spec = Value.as_const_expr(sout)
return relay.Call(GetOp("raf.op._reshard_r2s"), [args[0], spec])

@expand_when(lambda call: isinstance(call.attrs.shard_in, ShardSpec) and \
isinstance(call.attrs.shard_out, ReplicatedSpec), priority=1)

@expand_when(
lambda call: isinstance(call.attrs.shard_in, ShardSpec)
and isinstance(call.attrs.shard_out, ReplicatedSpec),
priority=1,
)
@register_expansion_pattern("raf.op._reshard")
def reshard_sharded_to_replicated(call: relay.Call):
"""_reshard -> _reshard_s2r (allgather)"""
_, args, sin, _ = extract_shardOpCall(call)
return relay.Call(GetOp("raf.op._reshard_s2r"), [args[0], sin])


@expand_when(always_apply, priority=0)
@register_expansion_pattern("raf.op._reshard")
def reshard_mismatch(call: relay.Call):
"""_reshard -> <error>"""
raise NotImplementedError("Unable to process the given sharding specifications")


@expand_when(always_apply)
@register_expansion_pattern(["raf.op.add", "raf.op.subtract"])
def add_or_sub(call: relay.Call):
"""add/sub -> (reshard) add/sub"""
op, args, sin, sout = extract_shardOpCall(call)
if not sin[0] == sin[1] == sout:
args = [relay.Call(GetOp("raf.op._reshard"), [args[i]], ShardOpCallAttrs(sin[i], sout))
for i in (0, 1)] + args[2:]
args = [
relay.Call(GetOp("raf.op._reshard"), [args[i]], ShardOpCallAttrs(sin[i], sout))
for i in (0, 1)
] + args[2:]
return relay.Call(op, args)


@expand_when(always_apply)
@register_expansion_pattern("_fallback")
def fallback_reshard_to_replicated(call: relay.Call):
"""Gather partitioned tensors for op without matched patterns"""
op, args, attrs = call.op, call.args, call.attrs
if len(args) != 1 or \
isinstance(attrs.shard_in, TupleShardSpec) or \
isinstance(attrs.shard_out, TupleShardSpec):
if (
len(args) != 1
or isinstance(attrs.shard_in, TupleShardSpec)
or isinstance(attrs.shard_out, TupleShardSpec)
):
raise NotImplementedError("Currently converting multiple args is not supported")
new_attrs = ShardOpCallAttrs(attrs.shard_in, ReplicatedSpec())
new_args = [relay.Call(GetOp("raf.op._reshard"), args, new_attrs)]
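
Taken together, the utilities above form a small registry: register_expansion_pattern tags a generator function with the op names it serves, expand_when pushes it into a per-op priority queue ordered by (-priority, insertion order), and expand_shardOpCall pops candidates until one whose condition holds is found. Below is a minimal sketch of registering a custom pattern, assuming a hypothetical op name ("raf.op.my_elemwise") and the same decorator stack used by add_or_sub:

from tvm import relay  # the import path used inside utils.py is not shown in this diff
from raf.distributed.sharding import (
    ShardSpec,
    expand_when,
    extract_shardOpCall,
    register_expansion_pattern,
)

@expand_when(lambda call: isinstance(call.attrs.shard_out, ShardSpec), priority=2)
@register_expansion_pattern("raf.op.my_elemwise")  # hypothetical op, for illustration
def my_elemwise(call: relay.Call):
    """Pass the call through unchanged when the output is already sharded."""
    op, args, _, _ = extract_shardOpCall(call)
    return relay.Call(op, args)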
2 changes: 1 addition & 1 deletion python/raf/hybrid/hybrid.py
@@ -45,7 +45,7 @@ def find_invoker_name(namespace) -> str:
float: Value.as_const_expr,
bool: Value.as_const_expr,
np.ndarray: Value.as_const_expr,
NDArray: lambda x: x._ndarray__handle # pylint: disable=protected-access
NDArray: lambda x: x._ndarray__handle, # pylint: disable=protected-access
}


35 changes: 13 additions & 22 deletions src/impl/sharding.cc
@@ -33,10 +33,8 @@ ReplicatedSpec ReplicatedSpec::make(bool immutable) {
return ReplicatedSpec(n);
}

ShardSpec ShardSpec::make(bool immutable,
Array<Device> assigned_devices,
Array<Integer> partition_shape,
Array<Integer> subgroup_sizes) {
ShardSpec ShardSpec::make(bool immutable, Array<Device> assigned_devices,
Array<Integer> partition_shape, Array<Integer> subgroup_sizes) {
auto ndim = partition_shape.size();
CHECK_EQ(ndim, subgroup_sizes.size());
auto n = make_object<ShardSpecObj>();
@@ -49,7 +47,7 @@ ShardSpec ShardSpec::make(bool immutable,
device_rank = i;
break;
}
} // perhaps it is improper to calculate runtime data here
} // perhaps it is improper to calculate runtime data here

for (int64_t i = ndim - 1; i >= 0; --i) {
grid_shape[i] = partition_shape[i]->value / subgroup_sizes[i]->value;
@@ -61,13 +59,13 @@ ShardSpec ShardSpec::make(bool immutable,
n->assigned_devices = std::move(assigned_devices);
n->grid_shape = Array<Integer>(grid_shape.begin(), grid_shape.end());
n->subgroup_sizes = std::move(subgroup_sizes);
n->_subgroup_idx = (device_rank != -1) ? Array<Integer>(_subgroup_idx.begin(), _subgroup_idx.end()) :
NullValue<Array<Integer>>();
n->_subgroup_idx = (device_rank != -1)
? Array<Integer>(_subgroup_idx.begin(), _subgroup_idx.end())
: NullValue<Array<Integer>>();
return ShardSpec(n);
}

TupleShardSpec TupleShardSpec::make(bool immutable,
Array<BaseShardSpec> tuple_elem) {
TupleShardSpec TupleShardSpec::make(bool immutable, Array<BaseShardSpec> tuple_elem) {
auto n = make_object<TupleShardSpecObj>();
n->immutable = immutable;
n->tuple_elem = tuple_elem;
Expand All @@ -90,7 +88,7 @@ void Reshard_R2S(const CallValues& call) {
if (spec->_subgroup_idx.defined()) {
for (int64_t i = 0; i < x->ndim; ++i) {
auto grid_dim_size = spec->grid_shape[i]->value;
CHECK_EQ(x->shape[i] % grid_dim_size , 0) << "Currently automatic padding is unsupported.";
CHECK_EQ(x->shape[i] % grid_dim_size, 0) << "Currently automatic padding is unsupported.";
shape[i] /= grid_dim_size;
}
call->out = TensorValue::Assemble(/*dev=*/x->device,
@@ -170,17 +168,14 @@ void PrintAllocTable(const ObjectRef& ref, ReprPrinter* p) {
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ReplicatedSpecObj>([](const ObjectRef& ref, ReprPrinter* p) {
auto r = Downcast<ReplicatedSpec>(ref);
p->stream << "ReplicatedSpec"
<< (r->immutable ? "(Immut)" : "");
p->stream << "ReplicatedSpec" << (r->immutable ? "(Immut)" : "");
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ShardSpecObj>([](const ObjectRef& ref, ReprPrinter* p) {
auto r = Downcast<ShardSpec>(ref);
auto ndim = r->grid_shape.size();
p->stream << "ShardSpec("
<< (r->immutable ? "Immut " : "")
<< "[";
p->stream << "ShardSpec(" << (r->immutable ? "Immut " : "") << "[";
for (size_t i = 0; i < ndim; ++i) {
auto grid_dim_size = r->grid_shape[i]->value;
auto subgroup_size = r->subgroup_sizes[i]->value;
@@ -194,17 +189,14 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TupleShardSpecObj>([](const ObjectRef& ref, ReprPrinter* p) {
auto r = Downcast<TupleShardSpec>(ref);
p->stream << "TupleShardSpec"
<< (r->immutable ? "(Immut)" : "");
p->stream << "TupleShardSpec" << (r->immutable ? "(Immut)" : "");
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ShardOpCallAttrs>([](const ObjectRef& ref, ReprPrinter* p) {
const auto* n = static_cast<const ShardOpCallAttrs*>(ref.get());
p->stream << "ShardOpCallAttrs("
<< "in=" << n->shard_in
<< " out=" << n->shard_out
<< ")";
<< "in=" << n->shard_in << " out=" << n->shard_out << ")";
});

TVM_REGISTER_NODE_TYPE(ShardOpCallAttrs);
@@ -252,8 +244,7 @@ HashKey ReshardHasher(const std::vector<Type>& param_types, const Type& y_type,
HashKey key = GenericHasher<nullptr_t>(param_types, y_type, nullptr);
auto spec = Downcast<ShardSpec>(args->spec);
for (auto i : spec->assigned_devices) {
key << i->device_id
<< i->device_type.operator int();
key << i->device_id << i->device_type.operator int();
}
for (auto i : spec->grid_shape) {
key << i->value;
