implement
Tonny-Gu committed Oct 26, 2021
1 parent cf306f7 commit 11d08d2
Showing 10 changed files with 72 additions and 47 deletions.
7 changes: 3 additions & 4 deletions python/mnm/_op/imp.py
@@ -103,10 +103,9 @@ def _reduce_scatter(x):
     return imp_utils.ret(ffi._reduce_scatter(x))
 
 @set_module("mnm")
-def _reshard(x, spec):
-    x = imp_utils.to_tensor(x)
-    spec = imp_utils.to_any(spec)
-    return imp_utils.ret(ffi._reshard(x, spec))
+def _reshard(x):
+    x = imp_utils.to_any(x)
+    return imp_utils.ret(ffi._reshard(x))
 
 @set_module("mnm")
 def _reshard_r2s(x, spec):
7 changes: 3 additions & 4 deletions python/mnm/_op/sym.py
@@ -94,10 +94,9 @@ def _reduce_scatter(x):
     x = sym_utils.to_tensor_tuple(x)
     return Symbol.from_expr(ffi._reduce_scatter(x))
 
-def _reshard(x, spec):
-    x = sym_utils.to_tensor(x)
-    spec = sym_utils.to_any(spec)
-    return Symbol.from_expr(ffi._reshard(x, spec))
+def _reshard(x):
+    x = sym_utils.to_any(x)
+    return Symbol.from_expr(ffi._reshard(x))
 
 def _reshard_r2s(x, spec):
     x = sym_utils.to_tensor(x)
2 changes: 1 addition & 1 deletion python/mnm/_tvm_op/sharding.py
@@ -19,6 +19,6 @@
 def compute_reshard_r2s(attr, inputs, output_type):
     # pylint: disable=unused-argument
     x = inputs[0]
-    return [_topi.strided_slice(x, [0, 0], [1, 1])]
+    return [_topi.strided_slice(x, attr.begin, attr.end)]
 
 _reg.register_injective_schedule("mnm.op.tvm._reshard_r2s")
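
The compute now slices the replicated tensor with per-device bounds taken from the op attributes instead of the previously hard-coded [0, 0]–[1, 1] window. A minimal illustration of what those bounds select, using NumPy slicing as a stand-in for `_topi.strided_slice` (the tensor, grid, and subgroup index below are made-up example values, not project code):

    import numpy as np

    x = np.arange(16).reshape(4, 4)   # the replicated 4x4 tensor
    begin, end = [2, 0], [4, 2]       # e.g. bounds for the shard at subgroup index (1, 0) on a 2x2 grid
    local = x[tuple(slice(b, e) for b, e in zip(begin, end))]
    print(local)                      # rows 2:4, columns 0:2 -> [[ 8  9] [12 13]]
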
3 changes: 2 additions & 1 deletion python/mnm/distributed/sharding/shardspec.py
@@ -3,9 +3,10 @@
 from mnm._core.core_utils import register_node
 from mnm._ffi.sharding import _make
 from mnm._lib import Object
+from mnm._core.value import Value
 
 @register_node("mnm.sharding.BaseShardSpec")
-class BaseShardSpec(Object):
+class BaseShardSpec(Value):
     """Base type of Sharding Specifications"""
 
 @register_node("mnm.sharding.ReplicatedSpec")
42 changes: 28 additions & 14 deletions python/mnm/distributed/sharding/utils.py
@@ -6,6 +6,7 @@
 from mnm._ffi.op import GetOp
 from mnm._lib import _register_func, relay
 from mnm.distributed.sharding.shardspec import ReplicatedSpec, ShardSpec, TupleShardSpec
+from mnm._core.value import Value
 from mnm import distributed as dist
 from tvm.relay import Call
 
@@ -30,8 +31,6 @@ def get_dist_devices():
     """
     return dist.get_context().dist_devices
 
-_expansion_patterns = {}
-
 def always_apply(call: relay.Call):
     """Always apply this pattern to expand op call"""
     return True
@@ -45,14 +44,21 @@ def expand_when(cond, priority=1):
     A function answering this expansion pattern is eligible under particular conditions
     (e.g. with particular sharding specifications)
     """
+    if not hasattr(expand_when, "counter"):
+        expand_when.counter = 0
+    if not hasattr(expand_when, "patterns"):
+        expand_when.patterns = {}
+
     def decorator(pyfunc):
         if not hasattr(pyfunc, "op_names"):
             raise ValueError("Must register expansion pattern first")
         for op_name in pyfunc.op_names:
             op = GetOp(op_name) if op_name != "_fallback" else "_fallback"
-            if op not in _expansion_patterns:
-                _expansion_patterns[op] = PriorityQueue()
-            _expansion_patterns[op].put((-priority, cond, pyfunc))
+            if op not in expand_when.patterns:
+                expand_when.patterns[op] = PriorityQueue()
+            print((-priority, expand_when.counter, cond, pyfunc))
+            expand_when.patterns[op].put((-priority, expand_when.counter, cond, pyfunc))
+            expand_when.counter += 1
         return pyfunc
     return decorator

@@ -84,9 +90,9 @@ def extract_shardOpCall(call):
 def expand_shardOpCall(call: relay.Call):
     """Match an eligible expansion pattern and return expanded IR expr"""
     print("expand: ", call, call.attrs)
-    patterns = _expansion_patterns[call.op if call.op in _expansion_patterns else "_fallback"]
+    patterns = expand_when.patterns[call.op if call.op in expand_when.patterns else "_fallback"]
     for pattern in patterns.queue:
-        _, cond, irgen = pattern
+        _, _, cond, irgen = pattern
         print(cond(call))
         if cond(call):
             break
@@ -95,10 +101,19 @@ def expand_shardOpCall(call: relay.Call):
 @expand_when(lambda call: isinstance(call.attrs.shard_in, ReplicatedSpec) and \
              isinstance(call.attrs.shard_out, ShardSpec), priority=1)
 @register_expansion_pattern("mnm.op._reshard")
-def reshard_to_strided_slice(call: relay.Call):
-    """_reshard -> strided_slice"""
-    # GetOp("mnm.op.strided_slice")
-    return call
+def reshard_replicated_to_sharded(call: relay.Call):
+    """_reshard -> _reshard_r2s (strided_slice)"""
+    _, args, _, sout = extract_shardOpCall(call)
+    spec = Value.as_const_expr(sout)
+    return relay.Call(GetOp("mnm.op._reshard_r2s"), [args[0], spec])
+
+@expand_when(lambda call: isinstance(call.attrs.shard_in, ShardSpec) and \
+             isinstance(call.attrs.shard_out, ReplicatedSpec), priority=1)
+@register_expansion_pattern("mnm.op._reshard")
+def reshard_sharded_to_replicated(call: relay.Call):
+    """_reshard -> _reshard_s2r (allgather)"""
+    _, args, sin, _ = extract_shardOpCall(call)
+    return relay.Call(GetOp("mnm.op._reshard_s2r"), [args[0], sin])
 
 @expand_when(always_apply, priority=0)
 @register_expansion_pattern("mnm.op._reshard")
@@ -112,9 +127,8 @@ def add_or_sub(call: relay.Call):
     """add/sub -> (reshard) add/sub"""
     op, args, sin, sout = extract_shardOpCall(call)
     if not sin[0] == sin[1] == sout:
-        args = [relay.Call(GetOp("mnm.op._reshard"),
-                           [args[i]],
-                           ShardOpAttrs(sin[i], sout)) for i in (0, 1)] + args[2:]
+        args = [relay.Call(GetOp("mnm.op._reshard"), [args[i]], ShardOpAttrs(sin[i], sout))
+                for i in (0, 1)] + args[2:]
     return relay.Call(op, args)
 
 @expand_when(always_apply)
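
A side note on the new `expand_when.counter`: `PriorityQueue` orders tuples element by element, so two patterns registered with the same priority would otherwise fall through to comparing the condition/generator callables, which raises `TypeError` in Python 3. Inserting a unique, monotonically increasing counter as the second tuple element breaks ties before any callable is compared. A minimal standalone sketch of that behavior (illustrative names only, not project code):

    from queue import PriorityQueue

    q = PriorityQueue()
    counter = 0
    for priority, name in [(1, "r2s"), (1, "s2r"), (0, "fallback")]:
        # (-priority, counter, payload): higher priority pops first, FIFO within equal priority
        q.put((-priority, counter, lambda call, n=name: n))
        counter += 1

    while not q.empty():
        _, _, cond = q.get()
        print(cond(None))   # r2s, s2r, fallback
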
7 changes: 3 additions & 4 deletions python/mnm/ir/op.py
@@ -101,11 +101,10 @@ def _reduce_scatter(x, attrs=None):
     x = op_utils.to_tensor_tuple(x)
     return relay.Call(op, [x], attrs)
 
-def _reshard(x, spec, attrs=None):
+def _reshard(x, attrs=None):
     op = GetOp("mnm.op._reshard")
-    x = op_utils.to_tensor(x)
-    spec = op_utils.to_any(spec)
-    return relay.Call(op, [x, spec], attrs)
+    x = op_utils.to_any(x)
+    return relay.Call(op, [x], attrs)
 
 def _reshard_r2s(x, spec, attrs=None):
     op = GetOp("mnm.op._reshard_r2s")
3 changes: 2 additions & 1 deletion scripts/src_codegen/def_op.py
@@ -178,8 +178,9 @@
     Op(name="roi_align", schema_name="roi_align"),
     Op(name="roi_align_dx", schema_name="roi_align_dx"),
     # Sharding ops
-    Op(name="_reshard", schema_name="shard_unary"),
+    Op(name="_reshard", schema_name="unary"),
     Op(name="_reshard_r2s", schema_name="shard_unary"),
+    Op(name="_reshard_s2r", schema_name="shard_unary"),
     # Stream ops
     Op(name="set_stream", schema_name="set_stream"),
     Op(name="add_event", schema_name="event"),
32 changes: 23 additions & 9 deletions src/impl/sharding.cc
@@ -15,6 +15,7 @@
 #include "../op/schema/ufunc.h"
 #include "../op/schema/sharding.h"
 #include "../op/dialect/tvm/tvm_utils.h"
+#include "../op/dialect/tvm/tvm_attrs.h"
 #include <string>
 
 namespace mnm {
@@ -80,6 +81,7 @@ Attrs ShardOpAttrs::make(BaseShardSpec shard_in, BaseShardSpec shard_out) {
   return Attrs(attrs);
 }
 
+/*
 void GetSliceRange(const CallValues& call) {
   const auto* args = call->args.as<ShardUnaryArgs>();
   CHECK(args != nullptr);
@@ -105,7 +107,7 @@ void GetSliceRange(const CallValues& call) {
     call->out = ir::NullValue<Value>();
   }
   call->callee = ir::NullValue<OpValue>();
-}
+}*/
 
 void Reshard_R2S(const CallValues& call) {
   const auto* args = call->args.as<ShardUnaryArgs>();
@@ -115,9 +117,9 @@ void Reshard_R2S(const CallValues& call) {
   auto spec = Downcast<ShardSpec>(args->spec);
   if (spec->_subgroup_idx.defined()) {
     for (int64_t i = 0; i < x->ndim; ++i) {
-      auto num_subgroup = spec->grid_shape[i]->value;
-      CHECK_EQ(x->shape[i] % num_subgroup , 0) << "Currently automaic padding is unsupported.";
-      shape[i] /= num_subgroup;
+      auto grid_dim_size = spec->grid_shape[i]->value;
+      CHECK_EQ(x->shape[i] % grid_dim_size , 0) << "Currently automaic padding is unsupported.";
+      shape[i] /= grid_dim_size;
     }
     call->out = TensorValue::Assemble(/*dev=*/x->device,
                                       /*dtype=*/x->dtype,
@@ -142,10 +144,10 @@ Type Reshard_R2S_Infer(const CallValues& call) {
   std::vector<PrimExpr> oshape(ndim);
   CHECK(spec->_subgroup_idx.defined());
   for (int64_t i = 0; i < ndim; ++i) {
-    auto num_subgroup = spec->grid_shape[i]->value;
+    auto grid_dim_size = spec->grid_shape[i]->value;
     auto dim_size = Downcast<IntImm>(dshape[i])->value;
-    CHECK_EQ(dim_size % num_subgroup, 0) << "Currently automaic padding is unsupported.";
-    oshape[i] = Integer(dim_size / num_subgroup);
+    CHECK_EQ(dim_size % grid_dim_size, 0) << "Currently automaic padding is unsupported.";
+    oshape[i] = Integer(dim_size / grid_dim_size);
   }
   return TensorType(oshape, data->dtype);
 }
@@ -258,8 +260,20 @@ std::vector<std::string> ReshardSchemaArgNames(const op::CallValues& call) {
 }
 
 Attrs ReshardSchema2Attrs(const ShardUnaryArgs* args) {
-  auto attrs = make_object<ShardUnaryAttrs>();
-  attrs->spec = Downcast<ShardSpec>(args->spec);
+  auto attrs = make_object<StridedSliceAttrs>();
+  auto spec = Downcast<ShardSpec>(args->spec);
+  const DLTensor* x = args->x;
+  std::vector<Integer> begin(x->ndim);
+  std::vector<Integer> end(x->ndim);
+  CHECK(spec->_subgroup_idx.defined());
+  for (int i = 0; i < x->ndim; ++i) {
+    auto idx = spec->_subgroup_idx[i]->value;
+    auto size = spec->grid_shape[i]->value;
+    begin[i] = Integer((x->shape[i] / size) * idx);
+    end[i] = Integer((x->shape[i] / size) * (idx + 1));
+  }
+  attrs->begin = Array<Integer>(begin);
+  attrs->end = Array<Integer>(end);
   return Attrs(attrs);
 }

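For reference, the begin/end arithmetic in `ReshardSchema2Attrs` reduces to the following (a rough Python rendering of the loop above, not project code, assuming the evenly divisible shapes that the `CHECK_EQ`s enforce):

    def local_slice_range(shape, grid_shape, subgroup_idx):
        # Per-dimension strided_slice bounds for this device's shard.
        begin, end = [], []
        for dim, grid, idx in zip(shape, grid_shape, subgroup_idx):
            shard = dim // grid          # shard size along this dimension
            begin.append(shard * idx)    # first element owned by this subgroup
            end.append(shard * (idx + 1))
        return begin, end

    # Worked example: a 4x4 tensor on a 2x2 grid; subgroup index (1, 0)
    # owns rows 2:4 and columns 0:2.
    print(local_slice_range([4, 4], [2, 2], [1, 0]))   # ([2, 0], [4, 2])
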
12 changes: 5 additions & 7 deletions src/op/regs/regs.cc
@@ -1389,10 +1389,8 @@ MNM_REGISTER_GLOBAL("mnm.op.imp._reduce_scatter").set_body([](TVMArgs args, TVMR
 });
 
 MNM_REGISTER_GLOBAL("mnm.op.imp._reshard").set_body([](TVMArgs args, TVMRetValue* ret) {
-  MNM_PRELUDE(_reshard, 2, ffi2schema::ShardUnary,
-              schema::ShardUnaryArgs);  // NOLINT(whitespace/line_length)
-  MNM_SET_ENV(vpack->x[0], schema2value::Tensor(schema->x));
-  MNM_SET_ENV(vpack->x[1], schema2value::ArrayLike(schema->spec));
+  MNM_PRELUDE(_reshard, 1, ffi2schema::Unary, schema::UnaryArgs);  // NOLINT(whitespace/line_length)
+  MNM_SET_ENV(vpack->x[0], schema2value::ArrayLike(schema->x));
   MNM_SET_ENV(vpack->y, value);
   *ret = MNM_RET();
 });
@@ -4269,7 +4267,7 @@ MNM_REGISTER_GLOBAL("mnm.op.sym._recv").set_body(MNM_SYMBOLIC_API(_recv, 4, Recv
 MNM_REGISTER_GLOBAL("mnm.op.sym._reduce").set_body(MNM_SYMBOLIC_API(_reduce, 3, CommReduce));
 MNM_REGISTER_GLOBAL("mnm.op.sym._reduce_scatter")
     .set_body(MNM_SYMBOLIC_API(_reduce_scatter, 1, ReduceScatter));
-MNM_REGISTER_GLOBAL("mnm.op.sym._reshard").set_body(MNM_SYMBOLIC_API(_reshard, 2, ShardUnary));
+MNM_REGISTER_GLOBAL("mnm.op.sym._reshard").set_body(MNM_SYMBOLIC_API(_reshard, 1, Unary));
 MNM_REGISTER_GLOBAL("mnm.op.sym._reshard_r2s")
     .set_body(MNM_SYMBOLIC_API(_reshard_r2s, 2, ShardUnary));
 MNM_REGISTER_GLOBAL("mnm.op.sym._send").set_body(MNM_SYMBOLIC_API(_send, 3, Send));
@@ -7669,9 +7667,9 @@ MNM_BIND_SCHEMA("mnm.op._reduce_scatter", names::_reduce_scatter,
 MNM_BIND_SCHEMA_FIELD_INDEX("mnm.op._reduce_scatter", names::_reduce_scatter,
                             schema_field_idx::ReduceScatter);  // NOLINT(whitespace/line_length)
 MNM_BIND_SCHEMA("mnm.op._reshard", names::_reshard,
-                value2schema::ShardUnary);  // NOLINT(whitespace/line_length)
+                value2schema::Unary);  // NOLINT(whitespace/line_length)
 MNM_BIND_SCHEMA_FIELD_INDEX("mnm.op._reshard", names::_reshard,
-                            schema_field_idx::ShardUnary);  // NOLINT(whitespace/line_length)
+                            schema_field_idx::Unary);  // NOLINT(whitespace/line_length)
 MNM_BIND_SCHEMA("mnm.op._reshard_r2s", names::_reshard_r2s,
                 value2schema::ShardUnary);  // NOLINT(whitespace/line_length)
 MNM_BIND_SCHEMA_FIELD_INDEX("mnm.op._reshard_r2s", names::_reshard_r2s,
4 changes: 2 additions & 2 deletions tests/python/pass/test_pass_sharding.py
@@ -39,8 +39,8 @@ def get_global_device_list(dev_type="cuda"):
     devices = get_global_device_list()
     attrs = ShardOpAttrs(TupleShardSpec([ReplicatedSpec(), ReplicatedSpec()]),
                          ShardSpec(devices, [4, 4], [2, 2]))
+    mnm._reshard_r2s(m_x, ShardSpec(devices, [4, 4], [2, 2]))
+    return
-    #a = mnm._reshard_r2s(m_x, ShardSpec(devices, [4, 4], [1, 2]))
-    #print(a)
     call_list = []
     post_order_visit(mod_before["main"].body,
                      lambda op: call_list.append(op) if isinstance(op, relay.Call) else None)
