Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
Tonny-Gu committed Nov 9, 2021
1 parent 61ae89b commit 5989fe1
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 117 deletions.
78 changes: 39 additions & 39 deletions python/mnm/_op/imp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,45 +9,45 @@

__all__ = [
"_allgather", "_allreduce", "_broadcast", "_contrib_dropout", "_contrib_dropout_dx",
"_recv", "_reduce", "_reduce_scatter", "_send", "abs",
"adaptive_avg_pool2d", "adaptive_avg_pool2d_dx", "adaptive_max_pool2d", "adaptive_max_pool2d_dx", "add",
"add_event", "adv_index", "adv_index_dx", "all", "any",
"arange", "argmax", "argmin", "argsort", "argwhere",
"atan", "avg_pool2d", "avg_pool2d_dx", "batch_flatten", "batch_matmul",
"batch_matmul_nt", "batch_matmul_tn", "batch_matmul_tt", "batch_norm_infer", "batch_norm_train",
"batch_norm_train_dxwb", "bias_add", "broadcast_to", "broadcast_to_like", "cast",
"cast_like", "ceil", "clip", "clip_dx", "collapse_sum_like",
"compiler_begin", "compiler_end", "concatenate", "concatenate_dx", "conv2d",
"conv2d_dw", "conv2d_dx", "conv2d_transpose", "conv2d_transpose_dw", "conv2d_transpose_dx",
"copy", "cos", "cross_entropy", "cross_entropy_dpred", "cross_entropy_dtrue",
"cumsum", "dense", "device_copy", "divide", "embedding",
"embedding_dx", "equal", "erf", "erf_dx", "exp",
"expand_dims", "floor", "floor_divide", "full", "full_like",
"gather", "gather_dx", "gather_nd", "gather_nd_dx", "gelu",
"gelu_dx", "get_kept_dims", "get_reduce_axis", "get_valid_counts", "greater",
"greater_equal", "l2norm", "layer_norm", "layer_norm_dx", "left_shift",
"less", "less_equal", "log", "log2", "log_softmax",
"log_softmax_dx", "logical_and", "logical_not", "matmul", "matmul_nt",
"matmul_tn", "matmul_tt", "max", "max_pool2d", "max_pool2d_dx",
"maximum", "mean", "mean_dx", "mesh_grid", "min",
"minimum", "mod", "multiply", "ndarray_size", "negative",
"nll_loss", "nll_loss_dpred", "nll_loss_dtrue", "non_max_suppression", "not_equal",
"numel", "one_hot", "ones", "ones_like", "pad",
"power", "prod", "prod_dx", "relu", "relu_dx",
"repeat", "repeat_dx", "reshape", "resize2d", "resize2d_dx",
"reverse", "reverse_sequence", "right_shift", "roi_align", "roi_align_dx",
"round", "rsqrt", "scatter", "scatter_dx", "sequence_mask",
"set_stream", "sgd", "shape", "shape_as_tensor", "sigmoid",
"sigmoid_dx", "sign", "sin", "size", "smooth_l1_loss",
"smooth_l1_loss_dpred", "smooth_l1_loss_dtrue", "softmax", "softmax_dx", "sort",
"split", "sqrt", "sqrt_dx", "squeeze", "stack",
"stream_barrier", "stream_sync", "strided_slice", "strided_slice_dx", "subtract",
"sum", "sum_dx", "swap_axis", "take", "take_dx",
"tanh", "tanh_dx", "threefry_generate", "threefry_split", "threshold",
"threshold_dx", "topk", "transpose", "transpose_dx", "trunc",
"upper_bound_argwhere", "vm_alloc_storage", "vm_alloc_tensor", "vm_free", "vm_infer_type",
"vm_invoke_op", "vm_set_shape", "wait_event", "where", "zeros",
"zeros_like",
"_recv", "_reduce", "_reduce_scatter", "_reshard", "_reshard_r2s",
"_reshard_s2r", "_send", "abs", "adaptive_avg_pool2d", "adaptive_avg_pool2d_dx",
"adaptive_max_pool2d", "adaptive_max_pool2d_dx", "add", "add_event", "adv_index",
"adv_index_dx", "all", "any", "arange", "argmax",
"argmin", "argsort", "argwhere", "atan", "avg_pool2d",
"avg_pool2d_dx", "batch_flatten", "batch_matmul", "batch_matmul_nt", "batch_matmul_tn",
"batch_matmul_tt", "batch_norm_infer", "batch_norm_train", "batch_norm_train_dxwb", "bias_add",
"broadcast_to", "broadcast_to_like", "cast", "cast_like", "ceil",
"clip", "clip_dx", "collapse_sum_like", "compiler_begin", "compiler_end",
"concatenate", "concatenate_dx", "conv2d", "conv2d_dw", "conv2d_dx",
"conv2d_transpose", "conv2d_transpose_dw", "conv2d_transpose_dx", "copy", "cos",
"cross_entropy", "cross_entropy_dpred", "cross_entropy_dtrue", "cumsum", "dense",
"device_copy", "divide", "embedding", "embedding_dx", "equal",
"erf", "erf_dx", "exp", "expand_dims", "floor",
"floor_divide", "full", "full_like", "gather", "gather_dx",
"gather_nd", "gather_nd_dx", "gelu", "gelu_dx", "get_kept_dims",
"get_reduce_axis", "get_valid_counts", "greater", "greater_equal", "l2norm",
"layer_norm", "layer_norm_dx", "left_shift", "less", "less_equal",
"log", "log2", "log_softmax", "log_softmax_dx", "logical_and",
"logical_not", "matmul", "matmul_nt", "matmul_tn", "matmul_tt",
"max", "max_pool2d", "max_pool2d_dx", "maximum", "mean",
"mean_dx", "mesh_grid", "min", "minimum", "mod",
"multiply", "ndarray_size", "negative", "nll_loss", "nll_loss_dpred",
"nll_loss_dtrue", "non_max_suppression", "not_equal", "numel", "one_hot",
"ones", "ones_like", "pad", "power", "prod",
"prod_dx", "relu", "relu_dx", "repeat", "repeat_dx",
"reshape", "resize2d", "resize2d_dx", "reverse", "reverse_sequence",
"right_shift", "roi_align", "roi_align_dx", "round", "rsqrt",
"scatter", "scatter_dx", "sequence_mask", "set_stream", "sgd",
"shape", "shape_as_tensor", "sigmoid", "sigmoid_dx", "sign",
"sin", "size", "smooth_l1_loss", "smooth_l1_loss_dpred", "smooth_l1_loss_dtrue",
"softmax", "softmax_dx", "sort", "split", "sqrt",
"sqrt_dx", "squeeze", "stack", "stream_barrier", "stream_sync",
"strided_slice", "strided_slice_dx", "subtract", "sum", "sum_dx",
"swap_axis", "take", "take_dx", "tanh", "tanh_dx",
"threefry_generate", "threefry_split", "threshold", "threshold_dx", "topk",
"transpose", "transpose_dx", "trunc", "upper_bound_argwhere", "vm_alloc_storage",
"vm_alloc_tensor", "vm_free", "vm_infer_type", "vm_invoke_op", "vm_set_shape",
"wait_event", "where", "zeros", "zeros_like",
]

@set_module("mnm")
Expand Down
78 changes: 39 additions & 39 deletions python/mnm/_op/sym.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,45 +9,45 @@

__all__ = [
"_allgather", "_allreduce", "_broadcast", "_contrib_dropout", "_contrib_dropout_dx",
"_recv", "_reduce", "_reduce_scatter", "_send", "abs",
"adaptive_avg_pool2d", "adaptive_avg_pool2d_dx", "adaptive_max_pool2d", "adaptive_max_pool2d_dx", "add",
"add_event", "adv_index", "adv_index_dx", "all", "any",
"arange", "argmax", "argmin", "argsort", "argwhere",
"atan", "avg_pool2d", "avg_pool2d_dx", "batch_flatten", "batch_matmul",
"batch_matmul_nt", "batch_matmul_tn", "batch_matmul_tt", "batch_norm_infer", "batch_norm_train",
"batch_norm_train_dxwb", "bias_add", "broadcast_to", "broadcast_to_like", "cast",
"cast_like", "ceil", "clip", "clip_dx", "collapse_sum_like",
"compiler_begin", "compiler_end", "concatenate", "concatenate_dx", "conv2d",
"conv2d_dw", "conv2d_dx", "conv2d_transpose", "conv2d_transpose_dw", "conv2d_transpose_dx",
"copy", "cos", "cross_entropy", "cross_entropy_dpred", "cross_entropy_dtrue",
"cumsum", "dense", "device_copy", "divide", "embedding",
"embedding_dx", "equal", "erf", "erf_dx", "exp",
"expand_dims", "floor", "floor_divide", "full", "full_like",
"gather", "gather_dx", "gather_nd", "gather_nd_dx", "gelu",
"gelu_dx", "get_kept_dims", "get_reduce_axis", "get_valid_counts", "greater",
"greater_equal", "l2norm", "layer_norm", "layer_norm_dx", "left_shift",
"less", "less_equal", "log", "log2", "log_softmax",
"log_softmax_dx", "logical_and", "logical_not", "matmul", "matmul_nt",
"matmul_tn", "matmul_tt", "max", "max_pool2d", "max_pool2d_dx",
"maximum", "mean", "mean_dx", "mesh_grid", "min",
"minimum", "mod", "multiply", "ndarray_size", "negative",
"nll_loss", "nll_loss_dpred", "nll_loss_dtrue", "non_max_suppression", "not_equal",
"numel", "one_hot", "ones", "ones_like", "pad",
"power", "prod", "prod_dx", "relu", "relu_dx",
"repeat", "repeat_dx", "reshape", "resize2d", "resize2d_dx",
"reverse", "reverse_sequence", "right_shift", "roi_align", "roi_align_dx",
"round", "rsqrt", "scatter", "scatter_dx", "sequence_mask",
"set_stream", "sgd", "shape", "shape_as_tensor", "sigmoid",
"sigmoid_dx", "sign", "sin", "size", "smooth_l1_loss",
"smooth_l1_loss_dpred", "smooth_l1_loss_dtrue", "softmax", "softmax_dx", "sort",
"split", "sqrt", "sqrt_dx", "squeeze", "stack",
"stream_barrier", "stream_sync", "strided_slice", "strided_slice_dx", "subtract",
"sum", "sum_dx", "swap_axis", "take", "take_dx",
"tanh", "tanh_dx", "threefry_generate", "threefry_split", "threshold",
"threshold_dx", "topk", "transpose", "transpose_dx", "trunc",
"upper_bound_argwhere", "vm_alloc_storage", "vm_alloc_tensor", "vm_free", "vm_infer_type",
"vm_invoke_op", "vm_set_shape", "wait_event", "where", "zeros",
"zeros_like",
"_recv", "_reduce", "_reduce_scatter", "_reshard", "_reshard_r2s",
"_reshard_s2r", "_send", "abs", "adaptive_avg_pool2d", "adaptive_avg_pool2d_dx",
"adaptive_max_pool2d", "adaptive_max_pool2d_dx", "add", "add_event", "adv_index",
"adv_index_dx", "all", "any", "arange", "argmax",
"argmin", "argsort", "argwhere", "atan", "avg_pool2d",
"avg_pool2d_dx", "batch_flatten", "batch_matmul", "batch_matmul_nt", "batch_matmul_tn",
"batch_matmul_tt", "batch_norm_infer", "batch_norm_train", "batch_norm_train_dxwb", "bias_add",
"broadcast_to", "broadcast_to_like", "cast", "cast_like", "ceil",
"clip", "clip_dx", "collapse_sum_like", "compiler_begin", "compiler_end",
"concatenate", "concatenate_dx", "conv2d", "conv2d_dw", "conv2d_dx",
"conv2d_transpose", "conv2d_transpose_dw", "conv2d_transpose_dx", "copy", "cos",
"cross_entropy", "cross_entropy_dpred", "cross_entropy_dtrue", "cumsum", "dense",
"device_copy", "divide", "embedding", "embedding_dx", "equal",
"erf", "erf_dx", "exp", "expand_dims", "floor",
"floor_divide", "full", "full_like", "gather", "gather_dx",
"gather_nd", "gather_nd_dx", "gelu", "gelu_dx", "get_kept_dims",
"get_reduce_axis", "get_valid_counts", "greater", "greater_equal", "l2norm",
"layer_norm", "layer_norm_dx", "left_shift", "less", "less_equal",
"log", "log2", "log_softmax", "log_softmax_dx", "logical_and",
"logical_not", "matmul", "matmul_nt", "matmul_tn", "matmul_tt",
"max", "max_pool2d", "max_pool2d_dx", "maximum", "mean",
"mean_dx", "mesh_grid", "min", "minimum", "mod",
"multiply", "ndarray_size", "negative", "nll_loss", "nll_loss_dpred",
"nll_loss_dtrue", "non_max_suppression", "not_equal", "numel", "one_hot",
"ones", "ones_like", "pad", "power", "prod",
"prod_dx", "relu", "relu_dx", "repeat", "repeat_dx",
"reshape", "resize2d", "resize2d_dx", "reverse", "reverse_sequence",
"right_shift", "roi_align", "roi_align_dx", "round", "rsqrt",
"scatter", "scatter_dx", "sequence_mask", "set_stream", "sgd",
"shape", "shape_as_tensor", "sigmoid", "sigmoid_dx", "sign",
"sin", "size", "smooth_l1_loss", "smooth_l1_loss_dpred", "smooth_l1_loss_dtrue",
"softmax", "softmax_dx", "sort", "split", "sqrt",
"sqrt_dx", "squeeze", "stack", "stream_barrier", "stream_sync",
"strided_slice", "strided_slice_dx", "subtract", "sum", "sum_dx",
"swap_axis", "take", "take_dx", "tanh", "tanh_dx",
"threefry_generate", "threefry_split", "threshold", "threshold_dx", "topk",
"transpose", "transpose_dx", "trunc", "upper_bound_argwhere", "vm_alloc_storage",
"vm_alloc_tensor", "vm_free", "vm_infer_type", "vm_invoke_op", "vm_set_shape",
"wait_event", "where", "zeros", "zeros_like",
]

def _allgather(x, axis):
Expand Down
78 changes: 39 additions & 39 deletions python/mnm/ir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,45 +8,45 @@

__all__ = [
"_allgather", "_allreduce", "_broadcast", "_contrib_dropout", "_contrib_dropout_dx",
"_recv", "_reduce", "_reduce_scatter", "_send", "abs",
"adaptive_avg_pool2d", "adaptive_avg_pool2d_dx", "adaptive_max_pool2d", "adaptive_max_pool2d_dx", "add",
"add_event", "adv_index", "adv_index_dx", "all", "any",
"arange", "argmax", "argmin", "argsort", "argwhere",
"atan", "avg_pool2d", "avg_pool2d_dx", "batch_flatten", "batch_matmul",
"batch_matmul_nt", "batch_matmul_tn", "batch_matmul_tt", "batch_norm_infer", "batch_norm_train",
"batch_norm_train_dxwb", "bias_add", "broadcast_to", "broadcast_to_like", "cast",
"cast_like", "ceil", "clip", "clip_dx", "collapse_sum_like",
"compiler_begin", "compiler_end", "concatenate", "concatenate_dx", "conv2d",
"conv2d_dw", "conv2d_dx", "conv2d_transpose", "conv2d_transpose_dw", "conv2d_transpose_dx",
"copy", "cos", "cross_entropy", "cross_entropy_dpred", "cross_entropy_dtrue",
"cumsum", "dense", "device_copy", "divide", "embedding",
"embedding_dx", "equal", "erf", "erf_dx", "exp",
"expand_dims", "floor", "floor_divide", "full", "full_like",
"gather", "gather_dx", "gather_nd", "gather_nd_dx", "gelu",
"gelu_dx", "get_kept_dims", "get_reduce_axis", "get_valid_counts", "greater",
"greater_equal", "l2norm", "layer_norm", "layer_norm_dx", "left_shift",
"less", "less_equal", "log", "log2", "log_softmax",
"log_softmax_dx", "logical_and", "logical_not", "matmul", "matmul_nt",
"matmul_tn", "matmul_tt", "max", "max_pool2d", "max_pool2d_dx",
"maximum", "mean", "mean_dx", "mesh_grid", "min",
"minimum", "mod", "multiply", "ndarray_size", "negative",
"nll_loss", "nll_loss_dpred", "nll_loss_dtrue", "non_max_suppression", "not_equal",
"numel", "one_hot", "ones", "ones_like", "pad",
"power", "prod", "prod_dx", "relu", "relu_dx",
"repeat", "repeat_dx", "reshape", "resize2d", "resize2d_dx",
"reverse", "reverse_sequence", "right_shift", "roi_align", "roi_align_dx",
"round", "rsqrt", "scatter", "scatter_dx", "sequence_mask",
"set_stream", "sgd", "shape", "shape_as_tensor", "sigmoid",
"sigmoid_dx", "sign", "sin", "size", "smooth_l1_loss",
"smooth_l1_loss_dpred", "smooth_l1_loss_dtrue", "softmax", "softmax_dx", "sort",
"split", "sqrt", "sqrt_dx", "squeeze", "stack",
"stream_barrier", "stream_sync", "strided_slice", "strided_slice_dx", "subtract",
"sum", "sum_dx", "swap_axis", "take", "take_dx",
"tanh", "tanh_dx", "threefry_generate", "threefry_split", "threshold",
"threshold_dx", "topk", "transpose", "transpose_dx", "trunc",
"upper_bound_argwhere", "vm_alloc_storage", "vm_alloc_tensor", "vm_free", "vm_infer_type",
"vm_invoke_op", "vm_set_shape", "wait_event", "where", "zeros",
"zeros_like",
"_recv", "_reduce", "_reduce_scatter", "_reshard", "_reshard_r2s",
"_reshard_s2r", "_send", "abs", "adaptive_avg_pool2d", "adaptive_avg_pool2d_dx",
"adaptive_max_pool2d", "adaptive_max_pool2d_dx", "add", "add_event", "adv_index",
"adv_index_dx", "all", "any", "arange", "argmax",
"argmin", "argsort", "argwhere", "atan", "avg_pool2d",
"avg_pool2d_dx", "batch_flatten", "batch_matmul", "batch_matmul_nt", "batch_matmul_tn",
"batch_matmul_tt", "batch_norm_infer", "batch_norm_train", "batch_norm_train_dxwb", "bias_add",
"broadcast_to", "broadcast_to_like", "cast", "cast_like", "ceil",
"clip", "clip_dx", "collapse_sum_like", "compiler_begin", "compiler_end",
"concatenate", "concatenate_dx", "conv2d", "conv2d_dw", "conv2d_dx",
"conv2d_transpose", "conv2d_transpose_dw", "conv2d_transpose_dx", "copy", "cos",
"cross_entropy", "cross_entropy_dpred", "cross_entropy_dtrue", "cumsum", "dense",
"device_copy", "divide", "embedding", "embedding_dx", "equal",
"erf", "erf_dx", "exp", "expand_dims", "floor",
"floor_divide", "full", "full_like", "gather", "gather_dx",
"gather_nd", "gather_nd_dx", "gelu", "gelu_dx", "get_kept_dims",
"get_reduce_axis", "get_valid_counts", "greater", "greater_equal", "l2norm",
"layer_norm", "layer_norm_dx", "left_shift", "less", "less_equal",
"log", "log2", "log_softmax", "log_softmax_dx", "logical_and",
"logical_not", "matmul", "matmul_nt", "matmul_tn", "matmul_tt",
"max", "max_pool2d", "max_pool2d_dx", "maximum", "mean",
"mean_dx", "mesh_grid", "min", "minimum", "mod",
"multiply", "ndarray_size", "negative", "nll_loss", "nll_loss_dpred",
"nll_loss_dtrue", "non_max_suppression", "not_equal", "numel", "one_hot",
"ones", "ones_like", "pad", "power", "prod",
"prod_dx", "relu", "relu_dx", "repeat", "repeat_dx",
"reshape", "resize2d", "resize2d_dx", "reverse", "reverse_sequence",
"right_shift", "roi_align", "roi_align_dx", "round", "rsqrt",
"scatter", "scatter_dx", "sequence_mask", "set_stream", "sgd",
"shape", "shape_as_tensor", "sigmoid", "sigmoid_dx", "sign",
"sin", "size", "smooth_l1_loss", "smooth_l1_loss_dpred", "smooth_l1_loss_dtrue",
"softmax", "softmax_dx", "sort", "split", "sqrt",
"sqrt_dx", "squeeze", "stack", "stream_barrier", "stream_sync",
"strided_slice", "strided_slice_dx", "subtract", "sum", "sum_dx",
"swap_axis", "take", "take_dx", "tanh", "tanh_dx",
"threefry_generate", "threefry_split", "threshold", "threshold_dx", "topk",
"transpose", "transpose_dx", "trunc", "upper_bound_argwhere", "vm_alloc_storage",
"vm_alloc_tensor", "vm_free", "vm_infer_type", "vm_invoke_op", "vm_set_shape",
"wait_event", "where", "zeros", "zeros_like",
]

def _allgather(x, axis, attrs=None):
Expand Down
40 changes: 40 additions & 0 deletions src/op/regs/regs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1019,6 +1019,13 @@ Attrs Sgd(const TVMArgs& values, GradTape* tapes) {
return Attrs(attrs);
}

Attrs ShardUnary(const TVMArgs& values, GradTape* tapes) {
MNM_PRELUDE(schema::ShardUnaryArgs, 2); // NOLINT(whitespace/line_length)
MNM_TAPE(0, ffi2schema::Tensor, x);
MNM_TAPE(1, ffi2schema::ArrayLike, spec);
return Attrs(attrs);
}

Attrs Size(const TVMArgs& values, GradTape* tapes) {
MNM_PRELUDE(schema::SizeArgs, 2); // NOLINT(whitespace/line_length)
MNM_TAPE(0, ffi2schema::Tensor, x);
Expand Down Expand Up @@ -4044,6 +4051,13 @@ Array<Expr> Sgd(const TVMArgs& values) {
MNM_RET();
}

Array<Expr> ShardUnary(const TVMArgs& values) {
MNM_PRELUDE(2);
MNM_ARG(0, ffi2expr::Tensor, x);
MNM_ARG(1, ffi2expr::ArrayLike, spec);
MNM_RET();
}

Array<Expr> Size(const TVMArgs& values) {
MNM_PRELUDE(2);
MNM_ARG(0, ffi2expr::Tensor, x);
Expand Down Expand Up @@ -4326,6 +4340,11 @@ MNM_REGISTER_GLOBAL("mnm.op.sym._recv").set_body(MNM_SYMBOLIC_API(_recv, 4, Recv
MNM_REGISTER_GLOBAL("mnm.op.sym._reduce").set_body(MNM_SYMBOLIC_API(_reduce, 3, CommReduce));
MNM_REGISTER_GLOBAL("mnm.op.sym._reduce_scatter")
.set_body(MNM_SYMBOLIC_API(_reduce_scatter, 2, ReduceScatter));
MNM_REGISTER_GLOBAL("mnm.op.sym._reshard").set_body(MNM_SYMBOLIC_API(_reshard, 1, Unary));
MNM_REGISTER_GLOBAL("mnm.op.sym._reshard_r2s")
.set_body(MNM_SYMBOLIC_API(_reshard_r2s, 2, ShardUnary));
MNM_REGISTER_GLOBAL("mnm.op.sym._reshard_s2r")
.set_body(MNM_SYMBOLIC_API(_reshard_s2r, 2, ShardUnary));
MNM_REGISTER_GLOBAL("mnm.op.sym._send").set_body(MNM_SYMBOLIC_API(_send, 3, Send));
MNM_REGISTER_GLOBAL("mnm.op.sym.abs").set_body(MNM_SYMBOLIC_API(abs, 1, Unary));
MNM_REGISTER_GLOBAL("mnm.op.sym.adaptive_avg_pool2d")
Expand Down Expand Up @@ -5448,6 +5467,14 @@ Attrs Sgd(const Array<Value>& values) {
return Attrs(attrs);
}

template <const char* op_name>
Attrs ShardUnary(const Array<Value>& values) {
MNM_PRELUDE(2, 2, schema::ShardUnaryArgs);
MNM_REQUIRED(0, value2schema::Tensor, x);
MNM_REQUIRED(1, value2schema::ArrayLike, spec);
return Attrs(attrs);
}

template <const char* op_name>
Attrs Size(const Array<Value>& values) {
MNM_PRELUDE(1, 2, schema::SizeArgs);
Expand Down Expand Up @@ -7233,6 +7260,18 @@ int Sgd(const std::string& field) {
return -1;
}

template <const char* op_name>
int ShardUnary(const std::string& field) {
if (field == "x") {
return 0;
}
if (field == "spec") {
return 1;
}
LOG(WARNING) << "Cannot find " << field << " in the schema of op " << op_name;
return -1;
}

template <const char* op_name>
int Size(const std::string& field) {
if (field == "x") {
Expand Down Expand Up @@ -8580,6 +8619,7 @@ MNM_REGISTER_OBJECT_REFLECT(SequenceMaskArgs);
MNM_REGISTER_OBJECT_REFLECT(SetShapeArgs);
MNM_REGISTER_OBJECT_REFLECT(SetStreamArgs);
MNM_REGISTER_OBJECT_REFLECT(SgdArgs);
MNM_REGISTER_OBJECT_REFLECT(ShardUnaryArgs);
MNM_REGISTER_OBJECT_REFLECT(SizeArgs);
MNM_REGISTER_OBJECT_REFLECT(SoftmaxArgs);
MNM_REGISTER_OBJECT_REFLECT(SoftmaxDxArgs);
Expand Down

0 comments on commit 5989fe1

Please sign in to comment.