From 6ef4ce6db2f03877a10554e306678ad1d72c3be0 Mon Sep 17 00:00:00 2001 From: "Huang, Guangtai" Date: Fri, 25 Feb 2022 01:40:40 +0800 Subject: [PATCH 1/4] [Dyn] Make resnet with known batchsize runable on cpu (#864) * init * fix * fix --- include/raf/op_utils.h | 7 +- python/raf/_tvm_op/transform.py | 1 + python/raf/amp/type_hints.py | 1 + python/raf/testing/mlp.py | 112 ++++++++++++++++++++++ scripts/src_codegen/def_op.py | 11 ++- scripts/src_codegen/def_schema.py | 51 ++-------- src/op/declare/nn.cc | 6 +- src/op/declare/reduce.cc | 3 +- src/op/declare/transform.cc | 78 +++++++++++---- src/op/dialect/cuda/embedding.cc | 6 +- src/op/dialect/tvm/nn.cc | 10 +- src/op/dialect/tvm/reduce.cc | 3 +- src/op/dialect/tvm/transform.cc | 98 +++++-------------- src/op/from_relay/transform.cc | 4 + src/op/grad/binary.cc | 59 ++---------- src/op/grad/gemm.cc | 25 ++--- src/op/grad/grad_utils.cc | 32 +++++++ src/op/grad/grad_utils.h | 6 ++ src/op/grad/nn.cc | 17 +--- src/op/grad/reduce.cc | 4 +- src/op/grad/transform.cc | 75 +++------------ src/op/ty/nn.cc | 14 +-- src/op/ty/reduce.cc | 12 +-- src/op/ty/transform.cc | 56 +++++++---- src/pass/type_infer.cc | 2 + tests/python/model/test_dynamic_model.py | 84 +++++++++++++++- tests/python/model/test_model_mlp.py | 103 ++++---------------- tests/python/op/tvm/test_tvm_transform.py | 26 +++++ tests/python/op/ty/test_type_transform.py | 44 +++++++++ tests/python/pass/test_pass_auto_cast.py | 2 +- tests/python/pass/test_pass_from_relay.py | 24 +++-- 31 files changed, 540 insertions(+), 436 deletions(-) create mode 100644 python/raf/testing/mlp.py diff --git a/include/raf/op_utils.h b/include/raf/op_utils.h index ed25f91c..671d00ea 100644 --- a/include/raf/op_utils.h +++ b/include/raf/op_utils.h @@ -121,7 +121,7 @@ inline bool IsInOpSet(const Expr& op, const OpSet& op_set) { inline bool IsReshapeOp(const Op& op) { static std::unordered_set reshape_ops{ Op::Get("raf.op.reshape"), Op::Get("raf.op.expand_dims"), Op::Get("raf.op.squeeze"), - Op::Get("raf.op.batch_flatten")}; + Op::Get("raf.op.batch_flatten"), Op::Get("raf.op.reshape_like")}; return IsInOpSet(op, reshape_ops); } @@ -179,8 +179,9 @@ inline Array GetShapeExprFromValue(const Value& value) { ICHECK(value.defined()); Array shape; if (auto ttv = value.as()) { - auto ndim = ttv->type->shape.size(); - for (size_t i = 0; i < ndim; ++i) { + auto ndim = ttv->type->shape[0].as(); + ICHECK(ndim) << "Expected IntImm, but got " << ttv->type->shape[0]->GetTypeKey(); + for (size_t i = 0; i < ndim->value; ++i) { shape.push_back(Any()); } } else { diff --git a/python/raf/_tvm_op/transform.py b/python/raf/_tvm_op/transform.py index 74e825ac..3ab45b21 100644 --- a/python/raf/_tvm_op/transform.py +++ b/python/raf/_tvm_op/transform.py @@ -129,6 +129,7 @@ def fcompute(*args): _reg.register_injective_schedule("raf.op.tvm.batch_flatten") _reg.register_injective_schedule("raf.op.tvm.arange") _reg.register_injective_schedule("raf.op.tvm.strided_slice") +_reg.register_reduce_schedule("raf.op.tvm.collapse_sum_like") @register_compute("raf.op.tvm.take_dx") diff --git a/python/raf/amp/type_hints.py b/python/raf/amp/type_hints.py index 73fa294d..e703ecf4 100644 --- a/python/raf/amp/type_hints.py +++ b/python/raf/amp/type_hints.py @@ -203,6 +203,7 @@ def _gen(args, ret_type, amp_dtype): register_op_cast_rule("raf.op.trunc", infer_cast(1)) register_op_cast_rule("raf.op.mesh_grid", infer_cast(2)) register_op_cast_rule("raf.op.reshape", infer_cast(1)) +register_op_cast_rule("raf.op.reshape_like", infer_cast(1)) 
register_op_cast_rule("raf.op.resize2d", infer_cast(1)) register_op_cast_rule("raf.op.ndarray_size", infer_cast(1)) register_op_cast_rule("raf.op.transpose", infer_cast(1)) diff --git a/python/raf/testing/mlp.py b/python/raf/testing/mlp.py new file mode 100644 index 00000000..2adc3331 --- /dev/null +++ b/python/raf/testing/mlp.py @@ -0,0 +1,112 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""MLP model""" +# pylint: disable=protected-access, attribute-defined-outside-init, too-many-locals +# pylint: disable=missing-class-docstring, too-many-arguments, missing-function-docstring +import torch.nn as nn +import torch.nn.functional as F + +import raf +from raf.model import Linear +from .common import check, randn_torch, t2m_param, one_hot_torch +from .utils import get_param, set_param + + +class TorchMlp(nn.Module): # pylint: disable=abstract-method + def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2): + super(TorchMlp, self).__init__() + self.fc1 = nn.Linear(num_inputs, num_hiddens1) + self.fc2 = nn.Linear(num_hiddens1, num_hiddens2) + self.fc3 = nn.Linear(num_hiddens2, num_outputs) + + def forward_infer(self, x): + y = self.fc1(x) + y = F.relu(y) + y = self.fc2(y) + y = F.relu(y) + y = self.fc3(y) + return y + + def forward(self, x, y_true=None): # pylint: disable=arguments-differ + y = self.forward_infer(x) + if self.training: + y_pred = F.log_softmax(y, dim=-1) + loss = F.nll_loss(y_pred, y_true) + return loss + return y + + +class RAFMlp(raf.Model): + # pylint: disable=attribute-defined-outside-init + def build(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2): + self.fc1 = Linear(num_inputs, num_hiddens1) + self.fc2 = Linear(num_hiddens1, num_hiddens2) + self.fc3 = Linear(num_hiddens2, num_outputs) + + @raf.model.trace + def forward_infer(self, x): + y = self.fc1(x) + y = raf.relu(y) + y = self.fc2(y) + y = raf.relu(y) + y = self.fc3(y) + return y + + @raf.model.trace + def forward(self, x, y_true): + y = self.forward_infer(x) + y_pred = raf.log_softmax(y) + loss = raf.nll_loss(y_true, y_pred) + return loss + + +def _param_map(t_model): + """maps from m_model parameter name to t_model parameter value""" + res = { + "fc1.w": t_model.fc1.weight, + "fc1.b": t_model.fc1.bias, + "fc2.w": t_model.fc2.weight, + "fc2.b": t_model.fc2.bias, + "fc3.w": t_model.fc3.weight, + "fc3.b": t_model.fc3.bias, + } + return res + + +def _init(m_model, t_model, device="cpu"): + """initialize meta model with parameters of torch model""" + # pylint: disable=no-member, line-too-long, too-many-statements + for m_name, t_w in _param_map(t_model).items(): + set_param(m_model, m_name, t2m_param(t_w, device=device)) + + +def check_params(m_model, t_model, atol=1e-4, rtol=1e-4): + """check the parameters of m_model and t_model""" + # pylint: disable=no-member, line-too-long, too-many-statements + for m_name, t_w in _param_map(t_model).items(): + m_w = get_param(m_model, m_name) + check(m_w, t_w, atol=atol, rtol=rtol) + + +def get_model(config, train=True): + """get MLP model""" + m_model = RAFMlp(*config) + t_model = TorchMlp(*config) + _init(m_model, t_model) + if train: + m_model.train_mode() + t_model.train() + else: + m_model.infer_mode() + t_model.eval() + return m_model, t_model + + +def get_input(config, batch_size=1, device="cpu", train=True): + """get MLP input""" + m_x, t_x = randn_torch([batch_size, config[0]], device=device, requires_grad=True) + if not train: + return [(m_x,), (t_x,)] + m_y, t_y = 
one_hot_torch(batch_size, num_classes=config[1], device=device) + return [(m_x, m_y), (t_x, t_y)] diff --git a/scripts/src_codegen/def_op.py b/scripts/src_codegen/def_op.py index c544044e..59336b25 100644 --- a/scripts/src_codegen/def_op.py +++ b/scripts/src_codegen/def_op.py @@ -96,12 +96,12 @@ Op(name="cross_entropy_dpred", schema_name="loss"), Op(name="cross_entropy_dtrue", schema_name="loss"), Op(name="reshape", schema_name="reshape"), + Op(name="reshape_like", schema_name="binary_like"), Op(name="resize2d", schema_name="resize2d"), Op(name="resize2d_dx", schema_name="resize2d_dx"), Op(name="ndarray_size", schema_name="unary"), Op(name="transpose", schema_name="transpose"), - Op(name="transpose_dx", schema_name="transpose_dx"), - Op(name="collapse_sum_like", schema_name="collapse_like"), + Op(name="transpose_dx", schema_name="transpose"), Op(name="sum", schema_name="sum"), Op(name="sum_dx", schema_name="sum_dx"), Op(name="cumsum", schema_name="cumsum"), @@ -135,8 +135,9 @@ Op(name="sequence_mask", schema_name="sequence_mask"), Op(name="reverse_sequence", schema_name="reverse_sequence"), Op(name="reverse", schema_name="reverse"), - Op(name="broadcast_to", schema_name="broadcast_to"), - Op(name="broadcast_to_like", schema_name="broadcast_to_like"), + Op(name="broadcast_to", schema_name="binary_to"), + Op(name="broadcast_to_like", schema_name="binary_like"), + Op(name="collapse_sum_like", schema_name="binary_like"), Op(name="concatenate", schema_name="concatenate"), Op(name="squeeze", schema_name="squeeze"), Op(name="stack", schema_name="stack"), @@ -159,7 +160,7 @@ Op(name="fuse_tensor", schema_name="fuse_tensor"), Op(name="defuse_tensor", schema_name="defuse_tensor"), Op(name="cast", schema_name="cast"), - Op(name="cast_like", schema_name="cast_like"), + Op(name="cast_like", schema_name="binary_like"), Op(name="gather", schema_name="gather"), Op(name="gather_dx", schema_name="gather_dx"), Op(name="gather_nd", schema_name="gather_nd"), diff --git a/scripts/src_codegen/def_schema.py b/scripts/src_codegen/def_schema.py index 36bca32f..45ee0f3d 100644 --- a/scripts/src_codegen/def_schema.py +++ b/scripts/src_codegen/def_schema.py @@ -164,7 +164,7 @@ Arg(name="x_or_w", cxx_type="value::BaseTensorValue"), Arg(name="y", cxx_type=OptionalTensor), Arg(name="dy", cxx_type="value::BaseTensorValue"), - Arg(name="shape", cxx_type=OptionalIntArray, cxx_normalizer="IntArray"), + Arg(name="shape", cxx_type="value::Value"), Arg(name="stride", cxx_type="std::vector", cxx_normalizer="IntTuple"), Arg(name="padding", cxx_type="std::vector", cxx_normalizer="IntTuple"), Arg(name="dilation", cxx_type="std::vector", cxx_normalizer="IntTuple"), @@ -174,7 +174,7 @@ Arg(name="x_or_w", cxx_type="value::BaseTensorValue"), Arg(name="y", cxx_type=OptionalTensor), Arg(name="dy", cxx_type="value::BaseTensorValue"), - Arg(name="shape", cxx_type=OptionalIntArray, cxx_normalizer="IntArray"), + Arg(name="shape", cxx_type="value::Value"), Arg(name="stride", cxx_type="std::vector", cxx_normalizer="IntTuple"), Arg(name="padding", cxx_type="std::vector", cxx_normalizer="IntTuple"), Arg(name="output_padding", cxx_type="std::vector", cxx_normalizer="IntTuple"), @@ -224,7 +224,7 @@ "nn.h::embedding_dx": [ Arg(name="dy", cxx_type="value::BaseTensorValue"), Arg(name="indices", cxx_type="value::BaseTensorValue"), - Arg(name="num_weight", cxx_type="std::vector", cxx_normalizer="IntTuple"), + Arg(name="num_weight", cxx_type="value::Value"), ], "nn.h::repeat": [ Arg(name="x", cxx_type="value::BaseTensorValue"), @@ -249,14 +249,6 @@ 
Arg(name="seq_axis", cxx_type="int", cxx_default=1), Arg(name="batch_axis", cxx_type="int", cxx_default=0), ], - "nn.h::broadcast_to": [ - Arg(name="x", cxx_type="value::BaseTensorValue"), - Arg(name="shape", cxx_type="std::vector", cxx_normalizer="IntTuple"), - ], - "nn.h::broadcast_to_like": [ - Arg(name="x", cxx_type="value::BaseTensorValue"), - Arg(name="broadcast_type", cxx_type="value::BaseTensorValue"), - ], "nn.h::concatenate": [ Arg(name="x", cxx_type="std::vector", cxx_normalizer="TensorTuple"), Arg(name="axis", cxx_type="int", cxx_default=0), @@ -364,10 +356,14 @@ Arg(name="y", cxx_type="value::BaseTensorValue"), Arg(name="dy", cxx_type="value::BaseTensorValue"), ], - "likes.h::collapse_like": [ + "likes.h::binary_to": [ Arg(name="x", cxx_type="value::BaseTensorValue"), Arg(name="shape", cxx_type="std::vector", cxx_normalizer="IntTuple"), ], + "likes.h::binary_like": [ + Arg(name="x", cxx_type="value::BaseTensorValue"), + Arg(name="like_type", cxx_type="value::BaseTensorValue"), + ], "likes.h::reshape": [ Arg(name="x", cxx_type="value::BaseTensorValue"), Arg(name="shape", cxx_type="value::Value"), @@ -436,6 +432,7 @@ ], "reduce.h::mean_dx": [ Arg(name="dy", cxx_type="value::BaseTensorValue"), + Arg(name="shape", cxx_type="value::Value"), Arg( name="axis", cxx_type="std::vector", @@ -443,13 +440,6 @@ cxx_default="{}", py_default=(), ), - Arg( - name="x_shape", - cxx_type="std::vector", - cxx_normalizer="IntTuple", - cxx_default="{}", - py_default="None", - ), Arg(name="keepdims", cxx_type="bool", cxx_default=False), Arg(name="exclude", cxx_type="bool", cxx_default=False), ], @@ -504,23 +494,6 @@ py_default="None", ), ], - "transform.h::transpose_dx": [ - Arg(name="dy", cxx_type="value::BaseTensorValue"), - Arg( - name="axes", - cxx_type="std::vector", - cxx_normalizer="IntTuple", - cxx_default="{}", - py_default="None", - ), - Arg( - name="primal_shape", - cxx_type="std::vector", - cxx_normalizer="IntTuple", - cxx_default="{}", - py_default="None", - ), - ], "transform.h::swap_axis": [ Arg(name="x", cxx_type="value::BaseTensorValue"), Arg(name="axis1", cxx_type="int"), @@ -545,10 +518,6 @@ Arg(name="data", cxx_type="value::BaseTensorValue"), Arg(name="dtype", cxx_type="std::string"), ], - "transform.h::cast_like": [ - Arg(name="data", cxx_type="value::BaseTensorValue"), - Arg(name="dtype_like", cxx_type="value::BaseTensorValue"), - ], "transform.h::strided_slice": [ Arg(name="x", cxx_type="value::BaseTensorValue"), Arg(name="begin", cxx_type="value::Value"), @@ -607,7 +576,7 @@ ], "transform.h::strided_slice_dx": [ Arg(name="dy", cxx_type="value::BaseTensorValue"), - Arg(name="primal_shape", cxx_type="std::vector", cxx_normalizer="IntTuple"), + Arg(name="shape", cxx_type="value::Value"), Arg(name="begin", cxx_type="std::vector", cxx_normalizer="IntTuple"), Arg(name="end", cxx_type="std::vector", cxx_normalizer="IntTuple"), Arg( diff --git a/src/op/declare/nn.cc b/src/op/declare/nn.cc index 3c32028f..760a08d5 100644 --- a/src/op/declare/nn.cc +++ b/src/op/declare/nn.cc @@ -310,9 +310,10 @@ void Conv2dDxw(const CallValues& call) { CHECK(args != nullptr); CHECK(args->shape.defined()); const DLTensor* x_or_w = args->x_or_w; + std::vector shape = GetShapeVecFromValue(args->shape); call->out = TensorValue::Assemble(/*dev=*/x_or_w->device, /*dtype=*/x_or_w->dtype, - /*shape=*/args->shape.value()); + /*shape=*/shape); call->device = x_or_w->device; } @@ -324,9 +325,10 @@ void Conv2dTransposeDxw(const CallValues& call) { CHECK(args != nullptr); CHECK(args->shape.defined()); const 
DLTensor* x_or_w = args->x_or_w; + std::vector shape = GetShapeVecFromValue(args->shape); call->out = TensorValue::Assemble(/*dev=*/x_or_w->device, /*dtype=*/x_or_w->dtype, - /*shape=*/args->shape.value()); + /*shape=*/shape); call->device = x_or_w->device; } diff --git a/src/op/declare/reduce.cc b/src/op/declare/reduce.cc index c9b9625c..6d7df3e2 100644 --- a/src/op/declare/reduce.cc +++ b/src/op/declare/reduce.cc @@ -9,6 +9,7 @@ */ #include #include "raf/op.h" +#include "raf/op_utils.h" #include "raf/tensor.h" #include "../schema/reduce.h" namespace raf { @@ -96,7 +97,7 @@ void MeanDxDecl(const CallValues& call) { const auto* args = call->args.as(); CHECK(args != nullptr); DLTensor* dy = args->dy; - std::vector shape = args->x_shape; + std::vector shape = GetShapeVecFromValue(args->shape); call->device = dy->device; call->out = TensorValue::Assemble(/*dev=*/dy->device, /*dtype=*/dy->dtype, diff --git a/src/op/declare/transform.cc b/src/op/declare/transform.cc index 487ff263..b76b7b42 100644 --- a/src/op/declare/transform.cc +++ b/src/op/declare/transform.cc @@ -161,6 +161,22 @@ RAF_OP_DECLARE("raf.op.reshape", [](const CallValues& call) { throw; }); +RAF_OP_DECLARE("raf.op.reshape_like", [](const CallValues& call) { + const auto* args = call->args.as(); + CHECK(args != nullptr); + DLTensor* x = args->x; + DLTensor* like_type = args->like_type; + std::vector shape(like_type->shape, like_type->shape + like_type->ndim); + call->device = x->device; + call->callee = ir::NullValue(); + CHECK(IsCompact(*x)) + << "NotImplementedError: for now we only support reshape on contiguous tensor"; + int64_t origin = std::accumulate(x->shape, x->shape + x->ndim, 1LL, std::multiplies()); + int64_t reshaped = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies()); + CHECK_EQ(origin, reshaped) << "ValueError: Number of elements mismatch after reshaping!"; + call->out = Downcast(args->x).CreateView(shape); +}); + RAF_OP_DECLARE("raf.op.resize2d", [](const CallValues& call) { const auto* args = call->args.as(); CHECK(args != nullptr); @@ -267,9 +283,10 @@ RAF_OP_DECLARE("raf.op.embedding_dx", [](const CallValues& call) { const auto* args = call->args.as(); CHECK(args != nullptr); DLTensor* dy = args->dy; + std::vector shape = GetShapeVecFromValue(args->num_weight); call->out = TensorValue::Assemble(/*dev=*/dy->device, /*dtype=*/dy->dtype, - /*shape=*/args->num_weight); + /*shape=*/shape); call->device = dy->device; }); @@ -410,7 +427,7 @@ RAF_OP_DECLARE("raf.op.strided_slice_dx", [](const CallValues& call) { std::vector stride_vec(num_axis, 1); if (IsCompact(*data)) { call->device = data->device; - std::vector shape = args->primal_shape; + std::vector shape = GetShapeVecFromValue(args->shape); call->out = TensorValue::Assemble(/*dev=*/data->device, /*dtype=*/data->dtype, /*shape=*/shape); @@ -462,7 +479,7 @@ RAF_OP_DECLARE("raf.op.reverse_sequence", [](const CallValues& call) { }); RAF_OP_DECLARE("raf.op.broadcast_to", [](const CallValues& call) { - const auto* args = call->args.as(); + const auto* args = call->args.as(); DLTensor* x = args->x; std::vector shape = args->shape; call->out = TensorValue::Assemble(/*dev=*/x->device, @@ -527,13 +544,31 @@ RAF_OP_DECLARE("raf.op.transpose", [](const CallValues& call) { }); RAF_OP_DECLARE("raf.op.transpose_dx", [](const CallValues& call) { - const auto* args = call->args.as(); + const auto* args = call->args.as(); CHECK(args != nullptr); - const DLTensor* dy = args->dy; - std::vector shape = args->primal_shape; + std::vector axes(args->axes.size(), -1); 
+ const DLTensor* dy = args->x; + int64_t* ishape = dy->shape; + int ndim = dy->ndim; + + std::vector oshape(ndim, -1); + if (axes.size() != 0) { + for (size_t i = 0; i < ndim; ++i) { + axes[args->axes[i]] = i; + } + CHECK_EQ(ndim, axes.size()); + for (size_t i = 0; i < ndim; ++i) { + int axis = axes[i] >= 0 ? axes[i] : axes[i] + ndim; + oshape[i] = ishape[axis]; + } + } else { + for (int i = 0; i < ndim; ++i) { + oshape[i] = ishape[ndim - i - 1]; + } + } call->out = TensorValue::Assemble(/*dev=*/dy->device, /*dtype=*/dy->dtype, - /*shape=*/shape); + /*shape=*/oshape); call->device = dy->device; }); @@ -572,17 +607,21 @@ RAF_OP_DECLARE("raf.op.swap_axis", [](const CallValues& call) { call->device = x->device; }); -RAF_OP_DECLARE("raf.op.broadcast_to_like", [](const CallValues& call) { - const auto* args = call->args.as(); +void BinaryShapeLike(const CallValues& call) { + const auto* args = call->args.as(); CHECK(args != nullptr); DLTensor* x = args->x; - DLTensor* broadcast_type = args->broadcast_type; - std::vector shape(broadcast_type->shape, broadcast_type->shape + broadcast_type->ndim); + DLTensor* like_type = args->like_type; + std::vector shape(like_type->shape, like_type->shape + like_type->ndim); call->out = TensorValue::Assemble(/*dev=*/x->device, - /*dtype=*/broadcast_type->dtype, + /*dtype=*/x->dtype, /*shape=*/shape); call->device = x->device; -}); +} + +RAF_OP_DECLARE("raf.op.broadcast_to_like", BinaryShapeLike); + +RAF_OP_DECLARE("raf.op.collapse_sum_like", BinaryShapeLike); RAF_OP_DECLARE("raf.op.stack", [](const CallValues& call) { const auto* args = call->args.as(); @@ -821,14 +860,15 @@ RAF_OP_DECLARE("raf.op.cast", [](const CallValues& call) { }); RAF_OP_DECLARE("raf.op.cast_like", [](const CallValues& call) { - const auto* args = call->args.as(); + const auto* args = call->args.as(); CHECK(args != nullptr); - DLTensor* dtype_like = args->dtype_like; - std::vector shape(dtype_like->shape, dtype_like->shape + dtype_like->ndim); - call->out = TensorValue::Assemble(/*dev=*/dtype_like->device, - /*dtype=*/dtype_like->dtype, + DLTensor* x = args->x; + DLTensor* like_type = args->like_type; + std::vector shape(x->shape, x->shape + x->ndim); + call->out = TensorValue::Assemble(/*dev=*/x->device, + /*dtype=*/like_type->dtype, /*shape=*/shape); - call->device = dtype_like->device; + call->device = x->device; }); RAF_OP_DECLARE("raf.op.gather", [](const CallValues& call) { diff --git a/src/op/dialect/cuda/embedding.cc b/src/op/dialect/cuda/embedding.cc index 38d3dcbf..9875858a 100644 --- a/src/op/dialect/cuda/embedding.cc +++ b/src/op/dialect/cuda/embedding.cc @@ -8,6 +8,7 @@ * \brief embedding_dx cuda backend */ #include "raf/op.h" +#include "raf/op_utils.h" #include "raf/device_api.h" #include "../../schema/nn.h" #include "./kernels/kernel_util.cuh" @@ -27,8 +28,9 @@ class EmbeddingDxImpl : public raf::op::OpEnv { static auto op = ir::Op::Get("raf.op.embedding_dx"); auto args = cv->args.as(); n_out_elements_ = 1; - for (int i = 0; i < args->num_weight.size(); ++i) { - n_out_elements_ *= args->num_weight[i]; + std::vector num_weight = GetShapeVecFromValue(args->num_weight); + for (int i = 0; i < num_weight.size(); ++i) { + n_out_elements_ *= num_weight[i]; } this->arg_indices = { fschema_index[op]("dy"), diff --git a/src/op/dialect/tvm/nn.cc b/src/op/dialect/tvm/nn.cc index aa2a68d8..b0ff077b 100644 --- a/src/op/dialect/tvm/nn.cc +++ b/src/op/dialect/tvm/nn.cc @@ -193,10 +193,9 @@ Attrs ConvDxwSchema2Attrs(const ConvDxwArgs* args) { 
attrs->dilation.push_back(IntImm(tvm::runtime::DataType::Int(64), dilation[i])); } // FIXME: (workaround) we use kernel size to store the shape of X (for dx) or W (for dw) - CHECK(args->shape.defined()); - auto shape = args->shape.value(); + auto shape = GetShapeExprFromValue(args->shape); for (int i = 0; i < shape.size(); ++i) { - attrs->kernel_size.push_back(IntImm(tvm::runtime::DataType::Int(32), shape[i]->value)); + attrs->kernel_size.push_back(shape[i]); } attrs->groups = args->groups; attrs->channels = NullValue(); @@ -267,10 +266,9 @@ Attrs ConvTransposeDxwSchema2Attrs(const ConvTransposeDxwArgs* args) { attrs->dilation.push_back(IntImm(tvm::runtime::DataType::Int(64), dilation[i])); } // FIXME: (workaround) we use kernel size to store the shape of X (for dx) or W (for dw) - CHECK(args->shape.defined()); - auto shape = args->shape.value(); + auto shape = GetShapeExprFromValue(args->shape); for (int i = 0; i < shape.size(); ++i) { - attrs->kernel_size.push_back(IntImm(tvm::runtime::DataType::Int(32), shape[i]->value)); + attrs->kernel_size.push_back(shape[i]); } attrs->groups = args->groups; attrs->channels = NullValue(); diff --git a/src/op/dialect/tvm/reduce.cc b/src/op/dialect/tvm/reduce.cc index 5aaaa79d..4b8d2c95 100644 --- a/src/op/dialect/tvm/reduce.cc +++ b/src/op/dialect/tvm/reduce.cc @@ -223,7 +223,7 @@ std::vector MeanDxSchemaArgNames(const op::CallValues& call) { Attrs MeanDxSchema2Attrs(const MeanDxArgs* args) { auto attrs = make_object(); - std::vector shape = args->x_shape; + std::vector shape = GetShapeVecFromValue(args->shape); auto ndim = shape.size(); for (int64_t s : shape) { attrs->shape.push_back(s); @@ -246,7 +246,6 @@ Attrs MeanDxSchema2Attrs(const MeanDxArgs* args) { HashKey MeanDxHasher(const std::vector& param_types, const Type& ret_type, const MeanDxArgs* args) { HashKey key = GenericHasher(param_types, ret_type, nullptr); - key << args->x_shape; key << args->axis; key << args->keepdims; key << args->exclude; diff --git a/src/op/dialect/tvm/transform.cc b/src/op/dialect/tvm/transform.cc index fd47f633..d5e17595 100644 --- a/src/op/dialect/tvm/transform.cc +++ b/src/op/dialect/tvm/transform.cc @@ -248,23 +248,15 @@ std::vector EmbeddingDxSchemaArgNames(const op::CallValues& call) { Attrs EmbeddingDxSchema2Attrs(const EmbeddingDxArgs* args) { auto attrs = make_object(); - for (auto v : args->num_weight) { + auto num_weight = GetShapeVecFromValue(args->num_weight); + for (auto v : num_weight) { attrs->dims.push_back(Integer(v)); } return Attrs(attrs); } -HashKey EmbeddingDxHasher(const std::vector& param_types, const Type& y_type, - const EmbeddingDxArgs* args) { - HashKey key = GenericHasher(param_types, y_type, nullptr); - for (auto v : args->num_weight) { - key << v; - } - return key; -} - RAF_TVM(embedding_dx, EmbeddingDx, EmbeddingDxArgs, EmbeddingDxSchema2Args, - EmbeddingDxSchemaArgNames, EmbeddingDxSchema2Attrs, EmbeddingDxHasher, kOpaque); + EmbeddingDxSchemaArgNames, EmbeddingDxSchema2Attrs, GenericHasher, kOpaque); std::vector SequenceMaskSchema2Args(const SequenceMaskArgs* args) { return {args->x, args->sequence_length}; @@ -343,27 +335,16 @@ RAF_TVM(reverse_sequence, ReverseSequence, ReverseSequenceArgs, ReverseSequenceS ReverseSequenceSchemaArgNames, ReverseSequenceSchema2Attrs, ReverseSequenceHasher, kInjective); -std::vector BroadcastToSchema2Args(const BroadcastToArgs* args) { +std::vector BinaryToSchema2Args(const BinaryToArgs* args) { return {args->x}; } -std::vector BroadcastToSchemaArgNames(const op::CallValues& call) { +std::vector 
BinaryToSchemaArgNames(const op::CallValues& call) { return {"x"}; } -Attrs BroadcastToSchema2Attrs(const BroadcastToArgs* args) { - auto attrs = make_object(); - std::vector shape; - shape.reserve(args->shape.size()); - for (size_t i = 0; i < args->shape.size(); ++i) { - shape.emplace_back(IntImm(ir::DataType::Int(32), args->shape[i])); - } - attrs->shape = Array(shape.begin(), shape.end()); - return Attrs(attrs); -} - -RAF_TVM(broadcast_to, BroadcastTo, BroadcastToArgs, BroadcastToSchema2Args, - BroadcastToSchemaArgNames, BroadcastToSchema2Attrs, GenericHasher, kBroadcast); +RAF_TVM(broadcast_to, BroadcastTo, BinaryToArgs, BinaryToSchema2Args, BinaryToSchemaArgNames, + GenericAttrs, GenericHasher, kBroadcast); std::vector TransposeSchema2Args(const TransposeArgs* args) { return {args->x}; @@ -394,50 +375,19 @@ HashKey TransposeHasher(const std::vector& param_types, const Type& y_type RAF_TVM(transpose, Transpose, TransposeArgs, TransposeSchema2Args, TransposeSchemaArgNames, TransposeSchema2Attrs, TransposeHasher, kInjective); -std::vector TransposeDxSchema2Args(const TransposeDxArgs* args) { - return {args->dy}; -} - -std::vector TransposeDxSchemaArgNames(const op::CallValues& call) { - return {"dy"}; -} - -Attrs TransposeDxSchema2Attrs(const TransposeDxArgs* args) { - auto attrs = make_object(); - std::vector axes; - axes.reserve(args->axes.size()); - for (size_t i = 0; i < args->axes.size(); ++i) { - axes.emplace_back(args->axes[i]); - } - attrs->axes = Array(axes.begin(), axes.end()); - return Attrs(attrs); -} - -HashKey TransposeDxHasher(const std::vector& param_types, const Type& y_type, - const TransposeDxArgs* args) { - HashKey key = GenericHasher(param_types, y_type, nullptr); - key << args->axes; - return key; -} - -RAF_TVM(transpose_dx, TransposeDx, TransposeDxArgs, TransposeDxSchema2Args, - TransposeDxSchemaArgNames, TransposeDxSchema2Attrs, TransposeDxHasher, kInjective); - -std::vector BroadcastToLikeSchema2Args(const BroadcastToLikeArgs* args) { - return {args->x, args->broadcast_type}; -} +RAF_TVM(transpose_dx, TransposeDx, TransposeArgs, TransposeSchema2Args, TransposeSchemaArgNames, + TransposeSchema2Attrs, TransposeHasher, kInjective); -std::vector BroadcastToLikeSchemaArgNames(const op::CallValues& call) { - return {"x", "broadcast_type"}; +std::vector BinaryLikeSchema2Args(const BinaryLikeArgs* args) { + return {args->x, args->like_type}; } -Attrs BroadcastToLikeSchema2Attrs(const BroadcastToLikeArgs* args) { - auto attrs = make_object(); - return Attrs(attrs); +std::vector BinaryLikeSchemaArgNames(const op::CallValues& call) { + return {"x", "like_type"}; } -RAF_TVM(broadcast_to_like, BroadcastToLike, BroadcastToLikeArgs, BroadcastToLikeSchema2Args, - BroadcastToLikeSchemaArgNames, BroadcastToLikeSchema2Attrs, GenericHasher, kBroadcast); +RAF_TVM(broadcast_to_like, BroadcastToLike, BinaryLikeArgs, BinaryLikeSchema2Args, + BinaryLikeSchemaArgNames, GenericAttrs, GenericHasher, kBroadcast); std::vector SplitSchema2Args(const SplitArgs* args) { return {args->x}; @@ -689,15 +639,7 @@ HashKey CastHasher(const std::vector& param_types, const Type& y_type, con RAF_TVM(cast, Cast, CastArgs, CastSchema2Args, CastSchemaArgNames, CastSchema2Attrs, CastHasher, kElemWise); -std::vector CastLikeSchema2Args(const CastLikeArgs* args) { - return {args->data, args->dtype_like}; -} - -std::vector CastLikeSchemaArgNames(const op::CallValues& call) { - return {"data", "dtype_like"}; -} - -RAF_TVM(cast_like, CastLike, CastLikeArgs, CastLikeSchema2Args, CastLikeSchemaArgNames, 
+RAF_TVM(cast_like, CastLike, BinaryLikeArgs, BinaryLikeSchema2Args, BinaryLikeSchemaArgNames, GenericAttrs, GenericHasher, kElemWise); std::vector GatherSchema2Args(const GatherArgs* args) { @@ -1079,8 +1021,9 @@ Attrs StridedSliceDxSchema2Attrs(const StridedSliceDxArgs* args) { end.emplace_back(args->end[i]); strides.emplace_back(args->strides[i]); } - for (int i = 0; i < args->primal_shape.size(); ++i) { - primal_shape.emplace_back(args->primal_shape[i]); + std::vector shape = GetShapeVecFromValue(args->shape); + for (int i = 0; i < shape.size(); ++i) { + primal_shape.emplace_back(shape[i]); } attrs->primal_shape = Array(primal_shape.begin(), primal_shape.end()); attrs->begin = Array(begin.begin(), begin.end()); @@ -1195,6 +1138,9 @@ HashKey CumsumHasher(const std::vector& param_types, const Type& ret_type, RAF_TVM(cumsum, Cumsum, CumsumArgs, CumsumSchema2Args, CumsumSchemaArgNames, CumsumSchema2Attrs, CumsumHasher, kOpaque); +RAF_TVM(collapse_sum_like, CollapseSumLike, BinaryLikeArgs, BinaryLikeSchema2Args, + BinaryLikeSchemaArgNames, GenericAttrs, GenericHasher, kCommReduce); + } // namespace tvm_dialect } // namespace op } // namespace raf diff --git a/src/op/from_relay/transform.cc b/src/op/from_relay/transform.cc index dd5563f8..fc669fe0 100644 --- a/src/op/from_relay/transform.cc +++ b/src/op/from_relay/transform.cc @@ -83,6 +83,10 @@ RAF_OP_FROM_RELAY("broadcast_to", "raf.op.broadcast_to", RAF_GENERIC_ATTR_OP_FROM_RELAY("broadcast_to_like", "raf.op.broadcast_to_like"); +RAF_GENERIC_ATTR_OP_FROM_RELAY("collapse_sum_like", "raf.op.collapse_sum_like"); + +RAF_GENERIC_ATTR_OP_FROM_RELAY("reshape_like", "raf.op.reshape_like"); + RAF_OP_FROM_RELAY("transpose", "raf.op.transpose", [&](const Attrs& attrs, const Array& args, const VarValueMap& val_map) { Array raf_args = args; diff --git a/src/op/grad/binary.cc b/src/op/grad/binary.cc index b1621400..31c952cf 100644 --- a/src/op/grad/binary.cc +++ b/src/op/grad/binary.cc @@ -26,17 +26,7 @@ Array AddGrad(const Expr& orig_call, const Array orig_args, const Va CHECK_GE(call->args.size(), 2); const Expr& x1 = call->args[0]; const Expr& x2 = call->args[1]; - - auto f = [&dy](const Expr& x) { - static auto collapse_axis = Op::Get("raf.op.get_reduce_axis"); - static auto collapse_keep = Op::Get("raf.op.get_kept_dims"); - static auto sum = Op::Get("raf.op.sum"); - Call axes = Call(collapse_axis, {dy, x}); - Call keep = Call(collapse_keep, {dy, x}); - return Call(sum, {dy, axes, keep, MakeConstant(BoolValue::make(false))}); - }; - - return {f(x1), f(x2)}; + return {GetCollapseSumLike(dy, x1), GetCollapseSumLike(dy, x2)}; } RAF_OP_GRAD("raf.op.add", AddGrad); @@ -48,15 +38,6 @@ Array SubGrad(const Expr& orig_call, const Array orig_args, const Va const Expr& x1 = call->args[0]; const Expr& x2 = call->args[1]; - auto f = [&dy](const Expr& x) { - static auto collapse_axis = Op::Get("raf.op.get_reduce_axis"); - static auto collapse_keep = Op::Get("raf.op.get_kept_dims"); - static auto sum = Op::Get("raf.op.sum"); - Call axes = Call(collapse_axis, {dy, x}); - Call keep = Call(collapse_keep, {dy, x}); - return Call(sum, {dy, axes, keep, MakeConstant(BoolValue::make(false))}); - }; - auto fs = [&dy](const Expr& x) { static auto collapse_axis = Op::Get("raf.op.get_reduce_axis"); static auto collapse_keep = Op::Get("raf.op.get_kept_dims"); @@ -67,7 +48,7 @@ Array SubGrad(const Expr& orig_call, const Array orig_args, const Va Call value = Call(sum, {dy, axes, keep, MakeConstant(BoolValue::make(false))}); return Call(neg, {value}); }; - return {f(x1), 
fs(x2)}; + return {GetCollapseSumLike(dy, x1), fs(x2)}; } RAF_OP_GRAD("raf.op.subtract", SubGrad); @@ -112,17 +93,8 @@ Array MulGrad(const Expr& orig_call, const Array orig_args, const Va CHECK_GE(call->args.size(), 2); const Expr& x1 = call->args[0]; const Expr& x2 = call->args[1]; - - auto f = [](const Expr& dx, const Expr& x) { - static auto collapse_axis = Op::Get("raf.op.get_reduce_axis"); - static auto collapse_keep = Op::Get("raf.op.get_kept_dims"); - static auto sum = Op::Get("raf.op.sum"); - Call axes = Call(collapse_axis, {dx, x}); - Call keep = Call(collapse_keep, {dx, x}); - return Call(sum, {dx, axes, keep, MakeConstant(BoolValue::make(false))}); - }; - - return {f(Call(op_multiply, {dy, x2}), x1), f(Call(op_multiply, {dy, x1}), x2)}; + return {GetCollapseSumLike(Call(op_multiply, {dy, x2}), x1), + GetCollapseSumLike(Call(op_multiply, {dy, x1}), x2)}; } RAF_OP_GRAD("raf.op.multiply", MulGrad); @@ -143,16 +115,8 @@ Array PowGrad(const Expr& orig_call, const Array orig_args, const Va Call x1_log = Call(op_log, {x1}); Call dx2 = Call(op_multiply, {y1, x1_log}); - auto f = [](const Expr& dx, const Expr& x) { - static auto collapse_axis = Op::Get("raf.op.get_reduce_axis"); - static auto collapse_keep = Op::Get("raf.op.get_kept_dims"); - static auto sum = Op::Get("raf.op.sum"); - Call axes = Call(collapse_axis, {dx, x}); - Call keep = Call(collapse_keep, {dx, x}); - return Call(sum, {dx, axes, keep, MakeConstant(BoolValue::make(false))}); - }; - - return {f(Call(op_multiply, {dy, dx1}), x1), f(Call(op_multiply, {dy, dx2}), x2)}; + return {GetCollapseSumLike(Call(op_multiply, {dy, dx1}), x1), + GetCollapseSumLike(Call(op_multiply, {dy, dx2}), x2)}; } RAF_OP_GRAD("raf.op.power", PowGrad); @@ -171,16 +135,7 @@ Array DivGrad(const Expr& orig_call, const Array orig_args, const Va dx2 = Call(op_multiply, {dx2, Call(op_divide, {x1, x2})}); dx2 = Call(op_divide, {dx2, x2}); - auto f = [](const Expr& dx, const Expr& x) { - static auto collapse_axis = Op::Get("raf.op.get_reduce_axis"); - static auto collapse_keep = Op::Get("raf.op.get_kept_dims"); - static auto sum = Op::Get("raf.op.sum"); - Call axes = Call(collapse_axis, {dx, x}); - Call keep = Call(collapse_keep, {dx, x}); - return Call(sum, {dx, axes, keep, MakeConstant(BoolValue::make(false))}); - }; - - return {f(dx1, x1), f(dx2, x2)}; + return {GetCollapseSumLike(dx1, x1), GetCollapseSumLike(dx2, x2)}; } RAF_OP_GRAD("raf.op.divide", DivGrad); diff --git a/src/op/grad/gemm.cc b/src/op/grad/gemm.cc index ca977f19..2d86e406 100644 --- a/src/op/grad/gemm.cc +++ b/src/op/grad/gemm.cc @@ -77,37 +77,28 @@ Array BatchMatmulGradImpl(const Expr& orig_call, const Array orig_ar const Expr& a = call->args[0]; const Expr& b = call->args[1]; - auto f = [](const Expr& dx, const Expr& x) { - static auto collapse_axis = Op::Get("raf.op.get_reduce_axis"); - static auto collapse_keep = Op::Get("raf.op.get_kept_dims"); - static auto sum = Op::Get("raf.op.sum"); - Call axes = Call(collapse_axis, {dx, x}); - Call keep = Call(collapse_keep, {dx, x}); - return Call(sum, {dx, axes, keep, MakeConstant(BoolValue::make(false))}); - }; - if (!transpose_a) { if (!transpose_b) { return { - f(Call(op_nt, {dy, b}), a), - f(Call(op_tn, {a, dy}), b), + GetCollapseSumLike(Call(op_nt, {dy, b}), a), + GetCollapseSumLike(Call(op_tn, {a, dy}), b), }; } else { return { - f(Call(op_nn, {dy, b}), a), - f(Call(op_tn, {dy, a}), b), + GetCollapseSumLike(Call(op_nn, {dy, b}), a), + GetCollapseSumLike(Call(op_tn, {dy, a}), b), }; } } else { if (!transpose_b) { return { - 
f(Call(op_nt, {b, dy}), a), - f(Call(op_nn, {a, dy}), b), + GetCollapseSumLike(Call(op_nt, {b, dy}), a), + GetCollapseSumLike(Call(op_nn, {a, dy}), b), }; } else { return { - f(Call(op_tt, {b, dy}), a), - f(Call(op_tt, {dy, a}), b), + GetCollapseSumLike(Call(op_tt, {b, dy}), a), + GetCollapseSumLike(Call(op_tt, {dy, a}), b), }; } } diff --git a/src/op/grad/grad_utils.cc b/src/op/grad/grad_utils.cc index 19f8495e..f02b8aed 100644 --- a/src/op/grad/grad_utils.cc +++ b/src/op/grad/grad_utils.cc @@ -15,6 +15,38 @@ namespace grad { using namespace raf::ir; +Expr GetShape(const Expr& expr) { + static auto op_shape = Op::Get("raf.op.shape"); + static auto op_size = Op::Get("raf.op.shape_as_tensor"); + if (expr->checked_type_.defined() && tvm::relay::IsDynamic(expr->checked_type_)) { + return Call(op_size, {expr}); + } + return Call(op_shape, {expr}); +} + +Expr GetReshapeLike(const Expr& x, const Expr& like_type) { + static auto reshape_like = Op::Get("raf.op.reshape_like"); + static auto reshape = Op::Get("raf.op.reshape"); + static auto shape = Op::Get("raf.op.shape"); + if (like_type->checked_type_.defined() && tvm::relay::IsDynamic(like_type->checked_type_)) { + return {Call(reshape_like, {x, like_type})}; + } + return {Call(reshape, {x, Call(shape, {like_type})})}; +} + +Expr GetCollapseSumLike(const Expr& x, const Expr& like_type) { + static auto collapse_axis = Op::Get("raf.op.get_reduce_axis"); + static auto collapse_keep = Op::Get("raf.op.get_kept_dims"); + static auto sum = Op::Get("raf.op.sum"); + static auto collapse_sum_like = Op::Get("raf.op.collapse_sum_like"); + if (like_type->checked_type_.defined() && tvm::relay::IsDynamic(like_type->checked_type_)) { + return Call(collapse_sum_like, {x, like_type}); + } + Call axes = Call(collapse_axis, {x, like_type}); + Call keep = Call(collapse_keep, {x, like_type}); + return Call(sum, {x, axes, keep, MakeConstant(value::BoolValue::make(false))}); +}; + Array AsTupleExpr(const Expr& expr, int numel) { if (const auto* tuple = expr.as()) { Array result; diff --git a/src/op/grad/grad_utils.h b/src/op/grad/grad_utils.h index 3ee19f41..72818181 100644 --- a/src/op/grad/grad_utils.h +++ b/src/op/grad/grad_utils.h @@ -15,6 +15,12 @@ namespace raf { namespace op { namespace grad { +Expr GetShape(const Expr& expr); + +Expr GetReshapeLike(const Expr& x, const Expr& like_type); + +ir::Expr GetCollapseSumLike(const ir::Expr& x, const ir::Expr& like_type); + ir::Array AsTupleExpr(const ir::Expr& expr, int numel); template diff --git a/src/op/grad/nn.cc b/src/op/grad/nn.cc index 443e6e4e..d7da12f1 100644 --- a/src/op/grad/nn.cc +++ b/src/op/grad/nn.cc @@ -18,16 +18,9 @@ namespace grad { using namespace raf::ir; using namespace raf::value; -Expr Shape(const Expr& expr) { - static auto op_shape = Op::Get("raf.op.shape"); - return Call(op_shape, {expr}); -} - Array BiasAddGrad(const Expr& orig_call, const Array orig_args, const Var& y, const Expr& dy) { using namespace raf::value; - static auto reshape = Op::Get("raf.op.reshape"); - static auto shape = Op::Get("raf.op.shape"); static auto sum = Op::Get("raf.op.sum"); const CallNode* call = orig_call.as(); const Expr& x = call->args[0]; @@ -35,7 +28,7 @@ Array BiasAddGrad(const Expr& orig_call, const Array orig_args, cons const Expr& axis = call->args[2]; Expr keep_dims = MakeConstant(ScalarValue::make((int64_t)0)); Expr exclude = MakeConstant(ScalarValue::make(true)); - return {Call(reshape, {dy, Call(shape, {x})}), Call(sum, {dy, axis, keep_dims, exclude})}; + return {GetReshapeLike(dy, x), Call(sum, {dy, 
axis, keep_dims, exclude})}; } RAF_OP_GRAD("raf.op.bias_add", BiasAddGrad); @@ -133,8 +126,8 @@ Array Conv2dGrad(const Expr& orig_call, const Array orig_args, const } // dx: w, y, dy, shape(x), stride, padding, dilation, groups // dw: x, y, dy, shape(w), stride, padding, dilation, groups - return {Call(op_dx, {w, y, dy, Shape(x), stride, padding, dilation, groups}), - Call(op_dw, {x, y, dy, Shape(w), stride, padding, dilation, groups})}; + return {Call(op_dx, {w, y, dy, GetShape(x), stride, padding, dilation, groups}), + Call(op_dw, {x, y, dy, GetShape(w), stride, padding, dilation, groups})}; } RAF_OP_GRAD("raf.op.conv2d", Conv2dGrad); @@ -161,8 +154,8 @@ Array Conv2dTransGrad(const Expr& orig_call, const Array orig_args, } // dx: w, y, dy, shape(x), stride, padding, dilation, groups // dw: x, y, dy, shape(w), stride, padding, dilation, groups - return {Call(op_dx, {w, y, dy, Shape(x), stride, padding, output_padding, dilation, groups}), - Call(op_dw, {x, y, dy, Shape(w), stride, padding, output_padding, dilation, groups})}; + return {Call(op_dx, {w, y, dy, GetShape(x), stride, padding, output_padding, dilation, groups}), + Call(op_dw, {x, y, dy, GetShape(w), stride, padding, output_padding, dilation, groups})}; } RAF_OP_GRAD("raf.op.conv2d_transpose", Conv2dTransGrad); diff --git a/src/op/grad/reduce.cc b/src/op/grad/reduce.cc index e52ee125..557d9df6 100644 --- a/src/op/grad/reduce.cc +++ b/src/op/grad/reduce.cc @@ -18,14 +18,12 @@ using namespace raf::ir; Array MeanGrad(const Expr& orig_call, const Array orig_args, const Expr& y, const Expr& dy) { static auto mean_dx = Op::Get("raf.op.mean_dx"); - static auto shape = Op::Get("raf.op.shape"); const CallNode* call = orig_call.as(); CHECK_GE(call->args.size(), 3); const Expr& axis = call->args[1]; const Expr& keepdims = call->args[2]; const Expr& exclude = call->args[3]; - const Expr& x_shape = Call(shape, {call->args[0]}); - return {Call(mean_dx, {dy, axis, x_shape, keepdims, exclude})}; + return {Call(mean_dx, {dy, GetShape(call->args[0]), axis, keepdims, exclude})}; } RAF_OP_GRAD("raf.op.mean", MeanGrad); diff --git a/src/op/grad/transform.cc b/src/op/grad/transform.cc index a6cdfc2e..99e6c542 100644 --- a/src/op/grad/transform.cc +++ b/src/op/grad/transform.cc @@ -43,25 +43,23 @@ Array AdvIndexGrad(const Expr& orig_call, const Array orig_args, con RAF_OP_GRAD("raf.op.adv_index", AdvIndexGrad); -Array BatchFlattenGrad(const Expr& orig_call, const Array orig_args, const Var& y, - const Expr& dy) { - static auto reshape = Op::Get("raf.op.reshape"); - static auto shape = Op::Get("raf.op.shape"); +Array ReshapeOpGrad(const Expr& orig_call, const Array orig_args, const Var& y, + const Expr& dy) { const CallNode* call = orig_call.as(); - return {Call(reshape, {dy, Call(shape, {call->args[0]})})}; + CHECK(call != nullptr); + const Expr& x = call->args[0]; + return {GetReshapeLike(dy, x)}; } -RAF_OP_GRAD("raf.op.batch_flatten", BatchFlattenGrad); +RAF_OP_GRAD("raf.op.batch_flatten", ReshapeOpGrad); Array TransposeGrad(const Expr& orig_call, const Array orig_args, const Var& y, const Expr& dy) { static auto transpose_dx = Op::Get("raf.op.transpose_dx"); - static auto shape = Op::Get("raf.op.shape"); const CallNode* call = orig_call.as(); CHECK(call != nullptr); const Expr& axes = call->args[1]; - const Expr& primal_shape = Call(shape, {call->args[0]}); - return {Call(transpose_dx, {dy, axes, primal_shape})}; + return {Call(transpose_dx, {dy, axes})}; } RAF_OP_GRAD("raf.op.transpose", TransposeGrad); @@ -96,16 +94,7 @@ Array 
BroadcastToGrad(const Expr& orig_call, const Array orig_args, const CallNode* call = orig_call.as(); CHECK(call != nullptr); const Expr& x = call->args[0]; - auto f = [&dy](const Expr& x) { - static auto collapse_axis = Op::Get("raf.op.get_reduce_axis"); - static auto collapse_keep = Op::Get("raf.op.get_kept_dims"); - static auto sum = Op::Get("raf.op.sum"); - Call axes = Call(collapse_axis, {dy, x}); - Call keep = Call(collapse_keep, {dy, x}); - return Call(sum, {dy, axes, keep, MakeConstant(value::BoolValue::make(false))}); - }; - - return {f(x)}; + return {GetCollapseSumLike(dy, x)}; } RAF_OP_GRAD("raf.op.broadcast_to", BroadcastToGrad); @@ -228,35 +217,11 @@ Array ClipGrad(const Expr& orig_call, const Array orig_args, const V RAF_OP_GRAD("raf.op.clip", ClipGrad); -Array ExpandDimsGrad(const Expr& orig_call, const Array orig_args, const Var& y, - const Expr& dy) { - static auto reshape = Op::Get("raf.op.reshape"); - static auto shape = Op::Get("raf.op.shape"); - const CallNode* call = orig_call.as(); - return {Call(reshape, {dy, Call(shape, {call->args[0]})})}; -} +RAF_OP_GRAD("raf.op.expand_dims", ReshapeOpGrad); -RAF_OP_GRAD("raf.op.expand_dims", ExpandDimsGrad); - -Array ReshapeGrad(const Expr& orig_call, const Array orig_args, const Var& y, - const Expr& dy) { - static auto reshape = Op::Get("raf.op.reshape"); - static auto shape = Op::Get("raf.op.shape"); - const CallNode* call = orig_call.as(); - return {Call(reshape, {dy, Call(shape, {call->args[0]})})}; -} - -RAF_OP_GRAD("raf.op.reshape", ReshapeGrad); - -Array SqueezeGrad(const Expr& orig_call, const Array orig_args, const Var& y, - const Expr& dy) { - static auto reshape = Op::Get("raf.op.reshape"); - static auto shape = Op::Get("raf.op.shape"); - const CallNode* call = orig_call.as(); - return {Call(reshape, {dy, Call(shape, {call->args[0]})})}; -} +RAF_OP_GRAD("raf.op.reshape", ReshapeOpGrad); -RAF_OP_GRAD("raf.op.squeeze", SqueezeGrad); +RAF_OP_GRAD("raf.op.squeeze", ReshapeOpGrad); Array TakeGrad(const Expr& orig_call, const Array orig_args, const Var& y, const Expr& dy) { @@ -275,13 +240,11 @@ RAF_OP_GRAD("raf.op.take", TakeGrad); Array EmbeddingGrad(const Expr& orig_call, const Array orig_args, const Var& y, const Expr& dy) { static auto op_dx = Op::Get("raf.op.embedding_dx"); - static auto shape = Op::Get("raf.op.shape"); const CallNode* call = orig_call.as(); CHECK_EQ(call->args.size(), 2); const Expr& x = call->args[0]; const Expr& indices = call->args[1]; - const Expr& x_shape = Call(shape, {x}); - return {Call(op_dx, {dy, indices, x_shape})}; + return {Call(op_dx, {dy, indices, GetShape(x)})}; } RAF_OP_GRAD("raf.op.embedding", EmbeddingGrad); @@ -351,15 +314,13 @@ RAF_OP_GRAD("raf.op.full_like", NoGrads<1>); Array StridedSliceGrad(const Expr& orig_call, const Array orig_args, const Var& y, const Expr& dy) { static auto op_slice_dx = Op::Get("raf.op.strided_slice_dx"); - static auto shape = Op::Get("raf.op.shape"); const CallNode* call = orig_call.as(); CHECK(call != nullptr); const Expr& begin = call->args[1]; const Expr& end = call->args[2]; const Expr& strides = call->args[3]; const Expr& mode = call->args[4]; - const Expr& primal_shape = Call(shape, {call->args[0]}); - return {Call(op_slice_dx, {dy, primal_shape, begin, end, strides, mode})}; + return {Call(op_slice_dx, {dy, GetShape(call->args[0]), begin, end, strides, mode})}; } RAF_OP_GRAD("raf.op.strided_slice", StridedSliceGrad); @@ -377,15 +338,7 @@ Array WhereGrad(const Expr& orig_call, const Array orig_args, const const Expr& dx1 = Call(where, {cond, dy, 
zero}); const Expr& dx2 = Call(where, {cond, zero, dy}); - auto f = [](const Expr& dx, const Expr& x) { - static auto collapse_axis = Op::Get("raf.op.get_reduce_axis"); - static auto collapse_keep = Op::Get("raf.op.get_kept_dims"); - static auto sum = Op::Get("raf.op.sum"); - Call axes = Call(collapse_axis, {dx, x}); - Call keep = Call(collapse_keep, {dx, x}); - return Call(sum, {dx, axes, keep, MakeConstant(value::BoolValue::make(false))}); - }; - return {NullValue(), f(dx1, x1), f(dx2, x2)}; + return {NullValue(), GetCollapseSumLike(dx1, x1), GetCollapseSumLike(dx2, x2)}; } RAF_OP_GRAD("raf.op.where", WhereGrad); diff --git a/src/op/ty/nn.cc b/src/op/ty/nn.cc index b44d57ca..120c7522 100644 --- a/src/op/ty/nn.cc +++ b/src/op/ty/nn.cc @@ -148,12 +148,9 @@ Type Conv2DDxwInfer(const CallValues& value) { const auto* args = value->args.as(); CHECK(args != nullptr); TensorType dy = Downcast(GetType(args->dy)); - Array res; if (args->shape.defined()) { - for (auto value : args->shape.value()) { - res.push_back(Integer(value->value)); - } - return TensorType(res, dy->dtype); + Array shape = GetShapeExprFromValue(args->shape); + return TensorType(shape, dy->dtype); } else { return IncompleteType(tvm::kType); } @@ -166,12 +163,9 @@ Type Conv2DTransposeDxwInfer(const CallValues& value) { const auto* args = value->args.as(); CHECK(args != nullptr); TensorType dy = Downcast(GetType(args->dy)); - Array res; if (args->shape.defined()) { - for (auto value : args->shape.value()) { - res.push_back(Integer(value->value)); - } - return TensorType(res, dy->dtype); + Array shape = GetShapeExprFromValue(args->shape); + return TensorType(shape, dy->dtype); } else { return IncompleteType(tvm::kType); } diff --git a/src/op/ty/reduce.cc b/src/op/ty/reduce.cc index cc588948..456d2cfa 100644 --- a/src/op/ty/reduce.cc +++ b/src/op/ty/reduce.cc @@ -7,14 +7,15 @@ * \file src/op/ty/reduce.cc * \brief Typing of reduction operators */ -#include +#include +#include #include +#include +#include "raf/op_utils.h" #include "raf/type.h" #include "../schema/likes.h" #include "../schema/reduce.h" #include "./utils.h" -#include -#include namespace raf { namespace op { @@ -168,10 +169,7 @@ Type MeanDxInfer(const CallValues& value) { const auto* args = value->args.as(); CHECK(args != nullptr); TensorType dy = Downcast(GetType(args->dy)); - Array oshape; - for (int s : args->x_shape) { - oshape.push_back(PrimExpr(s)); - } + Array oshape = GetShapeExprFromValue(args->shape); return TensorType(oshape, dy->dtype); } diff --git a/src/op/ty/transform.cc b/src/op/ty/transform.cc index efd0555d..adb78b24 100644 --- a/src/op/ty/transform.cc +++ b/src/op/ty/transform.cc @@ -113,12 +113,26 @@ Type TransposeInfer(const CallValues& value) { RAF_OP_TYPE("raf.op.transpose", "Transpose", TransposeInfer); Type TransposeDxInfer(const CallValues& value) { - const auto* args = value->args.as(); + const auto* args = value->args.as(); CHECK(args != nullptr); - TensorType dy = Downcast(GetType(args->dy)); + std::vector axes(args->axes.size(), -1); + TensorType dy = Downcast(GetType(args->x)); + size_t ndim = dy->shape.size(); + Array oshape; - for (auto dim : args->primal_shape) { - oshape.push_back(IntImm(DataType::Int(32), dim)); + if (axes.size() != 0) { + for (size_t i = 0; i < ndim; ++i) { + axes[args->axes[i]] = i; + } + CHECK_EQ(axes.size(), ndim); + for (size_t i = 0; i < ndim; ++i) { + int64_t axis = axes[i] < 0 ? 
axes[i] + ndim : axes[i]; + oshape.push_back(dy->shape[axis]); + } + } else { + for (size_t i = 0; i < ndim; ++i) { + oshape.push_back(dy->shape[ndim - i - 1]); + } } return TensorType(oshape, dy->dtype); } @@ -308,10 +322,7 @@ Type EmbeddingDxInfer(const CallValues& value) { const auto* args = value->args.as(); CHECK(args != nullptr); TensorType dy = Downcast(GetType(args->dy)); - std::vector shape; - for (auto val : args->num_weight) { - shape.push_back(Integer(val)); - } + auto shape = GetShapeExprFromValue(args->num_weight); return TensorType(shape, dy->dtype); } @@ -459,10 +470,11 @@ Type CastInfer(const CallValues& value) { RAF_OP_TYPE("raf.op.cast", "Cast", CastInfer); Type CastLikeInfer(const CallValues& value) { - const auto* args = value->args.as(); + const auto* args = value->args.as(); CHECK(args != nullptr); - TensorType dtype_like = Downcast(GetType(args->dtype_like)); - return dtype_like; + TensorType x = Downcast(GetType(args->x)); + TensorType like_type = Downcast(GetType(args->like_type)); + return TensorType(x->shape, like_type->dtype); } RAF_OP_TYPE("raf.op.cast_like", "CastLike", CastLikeInfer); @@ -533,7 +545,7 @@ Type ReverseSequenceInfer(const CallValues& value) { RAF_OP_TYPE("raf.op.reverse_sequence", "ReverseSequence", ReverseSequenceInfer); Type BroadcastToInfer(const CallValues& value) { - const auto* args = value->args.as(); + const auto* args = value->args.as(); CHECK(args != nullptr); std::vector shape = args->shape; Array oshape; @@ -546,14 +558,19 @@ Type BroadcastToInfer(const CallValues& value) { RAF_OP_TYPE("raf.op.broadcast_to", "BroadcastTo", BroadcastToInfer); -Type BroadcastToLikeInfer(const CallValues& value) { - const auto* args = value->args.as(); +Type BinaryShapeLikeInfer(const CallValues& value) { + const auto* args = value->args.as(); CHECK(args != nullptr); - TensorType broadcast_type = Downcast(GetType(args->broadcast_type)); - return broadcast_type; + TensorType x = Downcast(GetType(args->x)); + TensorType like_type = Downcast(GetType(args->like_type)); + return TensorType(like_type->shape, x->dtype); } -RAF_OP_TYPE("raf.op.broadcast_to_like", "BroadcastToLike", BroadcastToLikeInfer); +RAF_OP_TYPE("raf.op.broadcast_to_like", "BroadcastToLike", BinaryShapeLikeInfer); + +RAF_OP_TYPE("raf.op.collapse_sum_like", "CollapseSumLike", BinaryShapeLikeInfer); + +RAF_OP_TYPE("raf.op.reshape_like", "ReshapeLike", BinaryShapeLikeInfer); Type RepeatInfer(const CallValues& value) { const auto* args = value->args.as(); @@ -782,10 +799,7 @@ Type StridedSliceDxInfer(const CallValues& value) { const auto* args = value->args.as(); CHECK(args != nullptr); TensorType dy = Downcast(GetType(args->dy)); - Array oshape; - for (auto dim : args->primal_shape) { - oshape.push_back(IntImm(DataType::Int(32), dim)); - } + Array oshape = GetShapeExprFromValue(args->shape); return TensorType(oshape, dy->dtype); } diff --git a/src/pass/type_infer.cc b/src/pass/type_infer.cc index 15f00b2b..fa4b1829 100644 --- a/src/pass/type_infer.cc +++ b/src/pass/type_infer.cc @@ -116,6 +116,8 @@ class TypeInferencer : public ExprMutator { // can be either types or tensor values, depends on whether // they have already been evaluated/constant-folded. // Therefore it is essential to deal with both cases in their declare functions. 
+ + // TODO(@hgt312): refactor concatenate_dx be a base op and only use types static std::unordered_set shape_list{ "raf.op.shape", "raf.op.get_reduce_axis", "raf.op.get_kept_dims", "raf.op.concatenate_dx"}; if (opn && shape_list.count(opn->name)) { diff --git a/tests/python/model/test_dynamic_model.py b/tests/python/model/test_dynamic_model.py index 85e4e02c..032567c6 100644 --- a/tests/python/model/test_dynamic_model.py +++ b/tests/python/model/test_dynamic_model.py @@ -3,9 +3,19 @@ import numpy as np import pytest +import torch + import raf from tvm import relay -from raf.testing import get_testable_devices, check, run_vm_model, get_vm_executor, resnet +from raf.testing import ( + randn_torch, + get_testable_devices, + check, + run_vm_model, + get_vm_executor, + resnet, + mlp, +) from raf._core.ndarray import Symbol from raf.model.trace import _get_func_inputs @@ -64,10 +74,9 @@ def forward(self, x): @pytest.mark.parametrize("device", ["cpu"]) @pytest.mark.parametrize("fuse", [True, False]) -def test_resnet(device, fuse): +def test_resnet_forward(device, fuse): # pylint: disable=invalid-name, protected-access m_model, _ = resnet.get_model([1, 1, 1, 1], False) - m_model.infer_mode() m_model.to(device=device) x_ty = relay.TensorType((relay.Any(), 3, 224, 224)) @@ -85,5 +94,74 @@ def test_resnet(device, fuse): check(m_res, v_res) +@pytest.mark.parametrize("device", ["cpu"]) +@pytest.mark.parametrize("fuse", [True, False]) +def test_resnet_backward(device, fuse): + # pylint: disable=invalid-name, protected-access, too-many-locals + m_model, t_model = resnet.get_model([1, 1, 1, 1]) + m_model.to(device=device) + t_model.to(device=device) + m_optimizer = raf.optim.sgd.with_sgd(learning_rate=0.1, momentum=0.01)(m_model) + t_optimizer = torch.optim.SGD(t_model.parameters(), lr=0.1, momentum=0.01) + + x_ty = relay.TensorType((relay.Any(), 3, 224, 224)) + x = Symbol.make_var("x", x_ty) + yhat_ty = relay.TensorType((relay.Any(),)) + yhat = Symbol.make_var("yhat", yhat_ty) + dy_ty = relay.TensorType(()) + dy = Symbol.make_var("dy", dy_ty) + record = m_optimizer._internal(dy, x, yhat) + mod = record.mod + + m_dy, t_dy = randn_torch((), device=device, requires_grad=True) + m_in, t_in = resnet.get_input(batch_size=1, device=device) + vm = get_vm_executor(mod, device, 2, not fuse) + inputs = _get_func_inputs(record, (m_dy, *m_in), {}, get_handle=False) + m_loss = vm(*inputs)[0][0] + + t_optimizer.zero_grad() + t_loss = t_model(*t_in) + t_loss.backward(t_dy) + t_optimizer.step() + + check(m_loss, t_loss, atol=1e-3, rtol=1e-3) + resnet.check_params(m_model, t_model, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize("config", [(784, 10, 256, 256)]) +@pytest.mark.parametrize("device", ["cpu"]) +@pytest.mark.parametrize("fuse", [True, False]) +def test_mlp(config, device, fuse): + # pylint: disable=invalid-name, protected-access, too-many-locals + m_model, t_model = mlp.get_model(config) + m_model.to(device=device) + t_model.to(device=device) + m_optimizer = raf.optim.sgd.with_sgd(learning_rate=0.1, momentum=0.01)(m_model) + t_optimizer = torch.optim.SGD(t_model.parameters(), lr=0.1, momentum=0.01) + + x_ty = relay.TensorType((relay.Any(), config[0])) + x = Symbol.make_var("x", x_ty) + yhat_ty = relay.TensorType((relay.Any(),)) + yhat = Symbol.make_var("yhat", yhat_ty) + dy_ty = relay.TensorType(()) + dy = Symbol.make_var("dy", dy_ty) + record = m_optimizer._internal(dy, x, yhat) + mod = record.mod + + m_dy, t_dy = randn_torch((), device=device, requires_grad=True) + m_in, t_in = mlp.get_input(config, 
batch_size=1, device=device) + vm = get_vm_executor(mod, device, 2, not fuse) + inputs = _get_func_inputs(record, (m_dy, *m_in), {}, get_handle=False) + m_loss = vm(*inputs)[0] + + t_optimizer.zero_grad() + t_loss = t_model(*t_in) + t_loss.backward(t_dy) + t_optimizer.step() + + check(m_loss, t_loss, atol=1e-4, rtol=1e-4) + mlp.check_params(m_model, t_model) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/model/test_model_mlp.py b/tests/python/model/test_model_mlp.py index 26b3f278..a09b3bdf 100644 --- a/tests/python/model/test_model_mlp.py +++ b/tests/python/model/test_model_mlp.py @@ -2,90 +2,29 @@ # SPDX-License-Identifier: Apache-2.0 import pytest -import torch.nn as nn -import torch.nn.functional as F +import torch import raf -from raf.model import Linear -from raf.testing import check, one_hot_torch, randn_torch, t2m_param - - -class TorchMlp(nn.Module): # pylint: disable=abstract-method - def __init__(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2): - super(TorchMlp, self).__init__() - self.fc1 = nn.Linear(num_inputs, num_hiddens1) - self.fc2 = nn.Linear(num_hiddens1, num_hiddens2) - self.fc3 = nn.Linear(num_hiddens2, num_outputs) - - def forward(self, x): # pylint: disable=arguments-differ - x = self.fc1(x) - x = F.relu(x) - x = self.fc2(x) - x = F.relu(x) - x = self.fc3(x) - return x - - -class RAFMlp(raf.Model): - # pylint: disable=attribute-defined-outside-init - def build(self, num_inputs, num_outputs, num_hiddens1, num_hiddens2): - self.fc1 = Linear(num_inputs, num_hiddens1) - self.fc2 = Linear(num_hiddens1, num_hiddens2) - self.fc3 = Linear(num_hiddens2, num_outputs) - - @raf.model.trace - def forward(self, x): - x = self.fc1(x) - x = raf.relu(x) - x = self.fc2(x) - x = raf.relu(x) - x = self.fc3(x) - return x - - -@pytest.mark.skipif(not raf.build.with_cuda(), reason="CUDA is not enabled") -@pytest.mark.parametrize( - "config", - [ - (784, 10, 256, 256), - (512, 64, 128, 128), - (4, 2, 3, 3), - ], -) -@pytest.mark.parametrize("is_train", [True, False]) -def test_mlp(config, is_train): - m_model = RAFMlp(*config) - m_model.to(device="cuda") - t_model = TorchMlp(*config) - t_model.to(device="cuda") - m_model.fc1.w = t2m_param(t_model.fc1.weight) - m_model.fc1.b = t2m_param(t_model.fc1.bias) - m_model.fc2.w = t2m_param(t_model.fc2.weight) - m_model.fc2.b = t2m_param(t_model.fc2.bias) - m_model.fc3.w = t2m_param(t_model.fc3.weight) - m_model.fc3.b = t2m_param(t_model.fc3.bias) - - m_x, t_x = randn_torch((1, config[0]), requires_grad=is_train, device="cuda") - m_y, t_y = one_hot_torch(batch_size=1, num_classes=config[-1]) - if is_train: - m_model.train_mode() - t_model.train() - else: - m_model.infer_mode() - t_model.eval() - m_y = m_model(m_x) - t_y = t_model(t_x) - if is_train: - m_dy, t_dy = randn_torch(m_y.shape, std=m_y.numpy().std() * 0.0001, device="cuda") - t_y.backward(t_dy) - m_y.backward(m_dy) - check(m_model.fc1.w.grad, t_model.fc1.weight.grad, rtol=1e-4, atol=1e-4) - check(m_model.fc1.b.grad, t_model.fc1.bias.grad, rtol=1e-4, atol=1e-4) - check(m_model.fc2.w.grad, t_model.fc2.weight.grad, rtol=1e-4, atol=1e-4) - check(m_model.fc2.b.grad, t_model.fc2.bias.grad, rtol=1e-4, atol=1e-4) - check(m_model.fc3.w.grad, t_model.fc3.weight.grad, rtol=1e-4, atol=1e-4) - check(m_model.fc3.b.grad, t_model.fc3.bias.grad, rtol=1e-4, atol=1e-4) - check(m_y, t_y, rtol=1e-4, atol=1e-4) +from raf.testing import mlp, check, randn_torch, run_vm_model + + +@pytest.mark.parametrize("config", [(784, 10, 256, 256)]) 
+@pytest.mark.parametrize("device", ["cpu"]) +def test_mlp(config, device): + m_model, t_model = mlp.get_model(config) + m_model.to(device=device) + t_model.to(device=device) + m_optimizer = raf.optim.sgd.with_sgd(learning_rate=0.1, momentum=0.01)(m_model) + t_optimizer = torch.optim.SGD(t_model.parameters(), lr=0.1, momentum=0.01) + m_dy, t_dy = randn_torch((), device=device, requires_grad=False) + m_in, t_in = mlp.get_input(config, batch_size=1, device=device) + m_loss = run_vm_model(m_optimizer, device, [m_dy, *m_in])[0] + t_optimizer.zero_grad() + t_loss = t_model(*t_in) + t_loss.backward(t_dy) + t_optimizer.step() + check(m_loss, t_loss, atol=1e-4, rtol=1e-4) + mlp.check_params(m_model, t_model, atol=1e-4, rtol=1e-4) if __name__ == "__main__": diff --git a/tests/python/op/tvm/test_tvm_transform.py b/tests/python/op/tvm/test_tvm_transform.py index d7123f2f..f1d5ba43 100644 --- a/tests/python/op/tvm/test_tvm_transform.py +++ b/tests/python/op/tvm/test_tvm_transform.py @@ -280,6 +280,32 @@ def test_broadcast_to_like(shape, device): check(m_x.grad, np.ones(shape[0], dtype="float32") * (n_dy.size / n_x.size)) +@pytest.mark.parametrize("device", get_testable_devices()) +@pytest.mark.parametrize("shape", [[[4, 5], [3, 4, 5]], [[4, 2, 2], [3, 4, 2, 2]]]) +def test_collapse_sum_like(shape, device): + model = TestModel(raf._op.sym.collapse_sum_like) + m_x, n_x = randn(shape[1], device=device) + m_like_type, _ = randn(shape[0], device=device) + m_y = model(m_x, m_like_type) + v_y = run_vm_model(model, device, [m_x, m_like_type]) + n_y = np.sum(n_x, 0) + check(m_y, n_y) + check(v_y, n_y) + + +@pytest.mark.parametrize("device", get_testable_devices()) +@pytest.mark.parametrize("shape", [[[4, 5], [5, 2, 2]], [[4, 2, 2], [2, 8]]]) +def test_reshape_like(shape, device): + model = TestModel(raf._op.sym.reshape_like) + m_x, n_x = randn(shape[1], device=device) + m_like_type, _ = randn(shape[0], device=device) + m_y = model(m_x, m_like_type) + v_y = run_vm_model(model, device, [m_x, m_like_type]) + n_y = np.reshape(n_x, m_y.shape) + check(m_y, n_y) + check(v_y, n_y) + + @pytest.mark.parametrize("device", get_testable_devices()) @pytest.mark.parametrize("shape", [[10, 20, 30]]) @pytest.mark.parametrize("axis", [0, 1]) diff --git a/tests/python/op/ty/test_type_transform.py b/tests/python/op/ty/test_type_transform.py index 08495d1b..4292e285 100644 --- a/tests/python/op/ty/test_type_transform.py +++ b/tests/python/op/ty/test_type_transform.py @@ -222,6 +222,50 @@ def forward(self, x, broadcast_type): # pylint: disable=no-self-use check_type(m_func, expected_type) +@pytest.mark.parametrize("shape", [[[1, 4, 1], [1, 4, 1]], [[4, 1, 1], [3, 4, 2, 2]]]) +def test_collapse_sum_like(shape): + class CollapseSumLikeModel(raf.Model): + def build(self): + pass + + @raf.model.trace + def forward(self, x, like_type): # pylint: disable=no-self-use + return raf.collapse_sum_like(x, like_type) + + model = CollapseSumLikeModel() + m_x, _ = randn(shape[0]) + like_type, _ = randn(shape[1]) + m_func = model._internal(m_x, like_type).mod["main"] + m_func = run_infer_type(m_func) + x_ty = TensorType(shape[0]) + like_ty = TensorType(shape[1]) + y_ty = TensorType(shape[1]) + expected_type = FuncType([x_ty, like_ty], y_ty) + check_type(m_func, expected_type) + + +@pytest.mark.parametrize("shape", [[[1, 4, 1], [2, 2]], [[4, 1, 1], [1, 2, 2]]]) +def test_reshape_like(shape): + class ReshapeLikeModel(raf.Model): + def build(self): + pass + + @raf.model.trace + def forward(self, x, like_type): # pylint: disable=no-self-use + return 
raf.reshape_like(x, like_type) + + model = ReshapeLikeModel() + m_x, _ = randn(shape[0]) + like_type, _ = randn(shape[1]) + m_func = model._internal(m_x, like_type).mod["main"] + m_func = run_infer_type(m_func) + x_ty = TensorType(shape[0]) + like_ty = TensorType(shape[1]) + y_ty = TensorType(shape[1]) + expected_type = FuncType([x_ty, like_ty], y_ty) + check_type(m_func, expected_type) + + @pytest.mark.parametrize("dtype", ["float32"]) @pytest.mark.parametrize( "shape", diff --git a/tests/python/pass/test_pass_auto_cast.py b/tests/python/pass/test_pass_auto_cast.py index fc4b1c0c..43d8545e 100644 --- a/tests/python/pass/test_pass_auto_cast.py +++ b/tests/python/pass/test_pass_auto_cast.py @@ -358,7 +358,7 @@ def build(self): @raf.model.trace def forward(self, x): - return raf._op.sym.mean_dx(x) + return raf._op.sym.mean_dx(x, ()) m_x, _ = randn((), dtype="float32", device=device) model = Model() diff --git a/tests/python/pass/test_pass_from_relay.py b/tests/python/pass/test_pass_from_relay.py index ed77396f..f589679c 100644 --- a/tests/python/pass/test_pass_from_relay.py +++ b/tests/python/pass/test_pass_from_relay.py @@ -404,27 +404,31 @@ def forward(self, x): check_from_relay(model, r_func, [m_x]) -@pytest.mark.xfail(reason="broadcast_to_like with static shape will be simplified to broadcast_to") -@pytest.mark.parametrize("shape", [[[1, 4, 1], [1, 2, 4, 1]]]) +@pytest.mark.xfail(reason="binary shape like ops with static shape will be simplified") +@pytest.mark.parametrize("op", ["broadcast_to_like", "collapse_sum_like", "reshape_like"]) +@pytest.mark.parametrize("shape", [[[1, 4, 1], [1, 4, 1]]]) @pytest.mark.parametrize("dtype", ["float32"]) -def test_broadcast_to_like(shape, dtype): - class BroadcastToLike(raf.Model): +def test_binary_shape_like(op, shape, dtype): + m_op = getattr(raf._op.sym, op) + r_op = getattr(_relay, op) + + class BinaryShapeLike(raf.Model): def build(self): pass @raf.model.trace - def forward(self, x, broadcast_type): # pylint: disable=no-self-use - return raf.broadcast_to_like(x, broadcast_type) + def forward(self, x, like_type): # pylint: disable=no-self-use + return m_op(x, like_type) - model = BroadcastToLike() + model = BinaryShapeLike() m_x, _ = randn(shape[0], dtype=dtype) - broadcast_type, _ = randn(shape[1], dtype=dtype) + like_type, _ = randn(shape[1], dtype=dtype) r_x = _relay.var("x", shape=shape[0]) r_b = _relay.var("b", shape=shape[1]) - r_func = _relay.Function(params=[r_x, r_b], body=_relay.broadcast_to_like(r_x, r_b)) + r_func = _relay.Function(params=[r_x, r_b], body=r_op(r_x, r_b)) - check_from_relay(model, r_func, [m_x, broadcast_type]) + check_from_relay(model, r_func, [m_x, like_type]) @pytest.mark.parametrize("shape", [[(2, 2), (1, 0)], [(2, 2), None]]) From 171342cecf8a5837466ffab45e2738a51680569b Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Thu, 24 Feb 2022 14:42:16 -0800 Subject: [PATCH 2/4] [Refactor] Change CI Meta -> RAF (#878) --- .github/workflows/ci_lint.yml | 2 +- .github/workflows/ci_unit_test.yml | 4 ++-- .github/workflows/deploy_docker.yml | 2 +- ci/batch/README.md | 6 +++--- ci/batch/backup-ccache.sh | 2 +- docker/batch/entry.sh | 2 +- docker/push.sh | 2 +- docs/wiki/3_dev_guide/Memory-Pool.md | 4 ++-- 8 files changed, 12 insertions(+), 12 deletions(-) diff --git a/.github/workflows/ci_lint.yml b/.github/workflows/ci_lint.yml index d4c78aca..51549940 100644 --- a/.github/workflows/ci_lint.yml +++ b/.github/workflows/ci_lint.yml @@ -21,7 +21,7 @@ jobs: if: github.repository == 'meta-project/meta' runs-on: self-hosted 
container: - image: metaprojdev/meta:ci_cpu-v0.18 + image: metaprojdev/raf:ci_cpu-v0.18 steps: - name: Checkout repository uses: actions/checkout@v2 diff --git a/.github/workflows/ci_unit_test.yml b/.github/workflows/ci_unit_test.yml index 76e25110..5279e685 100644 --- a/.github/workflows/ci_unit_test.yml +++ b/.github/workflows/ci_unit_test.yml @@ -20,8 +20,8 @@ jobs: if: github.repository == 'meta-project/meta' runs-on: self-hosted outputs: - cpu_image: "metaprojdev/meta:ci_cpu-v0.18" - gpu_image: "metaprojdev/meta:ci_gpu-v0.20" + cpu_image: "metaprojdev/raf:ci_cpu-v0.18" + gpu_image: "metaprojdev/raf:ci_gpu-v0.20" skip_ci: ${{ steps.job_info.outputs.skip_ci }} ref: ${{ steps.job_info.outputs.ref }} repo: ${{ steps.job_info.outputs.repo }} diff --git a/.github/workflows/deploy_docker.yml b/.github/workflows/deploy_docker.yml index 40445fe8..f9cf33fd 100644 --- a/.github/workflows/deploy_docker.yml +++ b/.github/workflows/deploy_docker.yml @@ -35,7 +35,7 @@ jobs: with: context: docker file: docker/Dockerfile.${{ github.event.inputs.type }} - tags: metaprojdev/meta:${{ github.event.inputs.type }}-${{ github.event.inputs.tag }} + tags: metaprojdev/raf:${{ github.event.inputs.type }}-${{ github.event.inputs.tag }} push: true - name: Image digest run: echo ${{ steps.docker_build.outputs.digest }} diff --git a/ci/batch/README.md b/ci/batch/README.md index e072358f..d5dca777 100644 --- a/ci/batch/README.md +++ b/ci/batch/README.md @@ -167,7 +167,7 @@ AWS Batch has to be properly configured to make the above flow working as expect "type": "container", "parameters": {}, "containerProperties": { - "image": "metaprojdev/meta:ci_gpu-v0.20", + "image": "metaprojdev/raf:ci_gpu-v0.20", "command": [], "jobRoleArn": ***, "executionRoleArn": ***, @@ -210,7 +210,7 @@ AWS Batch has to be properly configured to make the above flow working as expect "type": "container", "parameters": {}, "containerProperties": { - "image": "metaprojdev/meta:ci_cpu-v0.18", + "image": "metaprojdev/raf:ci_cpu-v0.18", "command": [], "jobRoleArn": ***, "executionRoleArn": ***, @@ -249,7 +249,7 @@ AWS Batch has to be properly configured to make the above flow working as expect "type": "container", "parameters": {}, "containerProperties": { - "image": "metaprojdev/meta:ci_gpu-v0.20", + "image": "metaprojdev/raf:ci_gpu-v0.20", "command": [], "jobRoleArn": ***, "executionRoleArn": ***, diff --git a/ci/batch/backup-ccache.sh b/ci/batch/backup-ccache.sh index 16fbc5fe..021918bb 100644 --- a/ci/batch/backup-ccache.sh +++ b/ci/batch/backup-ccache.sh @@ -10,7 +10,7 @@ MODE=$1 # upload or download PLATFORM=$2 # CPU, GPU, or multi-GPU TAG=$3 # e.g., refs/heads/main, pr-7 -S3_BUCKET="ci-meta" +S3_BUCKET="ci-raf" S3_FOLDER=`echo cache-${TAG} | sed 's/\//_/g'` S3_PATH="s3://$S3_BUCKET/$S3_FOLDER" diff --git a/docker/batch/entry.sh b/docker/batch/entry.sh index a50cf7b2..a9d61048 100644 --- a/docker/batch/entry.sh +++ b/docker/batch/entry.sh @@ -13,7 +13,7 @@ SOURCE_REF=$1 REPO=$2 COMMAND=$3 SAVE_OUTPUT=$4 -REMOTE_FOLDER=$5 # e.g., s3://ci-meta/pr-7 +REMOTE_FOLDER=$5 # e.g., s3://ci-raf/pr-7 echo "Job Info" echo "-------------------------------------" diff --git a/docker/push.sh b/docker/push.sh index d711ed6b..860ec6e4 100755 --- a/docker/push.sh +++ b/docker/push.sh @@ -31,7 +31,7 @@ PASSWORD="$1" shift 1 LOCAL_IMAGE_NAME=raf.${CONTAINER_TYPE}:latest -REMOTE_IMAGE_NAME=${DOCKER_HUB_ACCOUNT}/meta:${CONTAINER_TYPE}-${VERSION} +REMOTE_IMAGE_NAME=${DOCKER_HUB_ACCOUNT}/raf:${CONTAINER_TYPE}-${VERSION} echo "Login docker hub" docker login -u 
${DOCKER_HUB_ACCOUNT} -p ${PASSWORD} diff --git a/docs/wiki/3_dev_guide/Memory-Pool.md b/docs/wiki/3_dev_guide/Memory-Pool.md index 9d326900..0739276a 100644 --- a/docs/wiki/3_dev_guide/Memory-Pool.md +++ b/docs/wiki/3_dev_guide/Memory-Pool.md @@ -7,7 +7,7 @@ This document introduces the Memory Pool of RAF. ## Strategies -Currently, there are two types of memory pool in meta: (1) no_pool, (2) page_unit_pool. +Currently, there are two types of memory pool in RAF: (1) no_pool, (2) page_unit_pool. By default, we choose page_unit_pool as our memory pool, which could bring down the running time by almost 50% for rn50/vgg/etc compared with no_pool. The memory usage of these two strategies are similar. Here is an experiment on ResNet50 with Tesla T4 (15109MB) @@ -115,4 +115,4 @@ Then you can create the Pool Class that derived from `raf::memory_pool::MemoryPo Remember to register your pool in the cpp file you created, the code should be like: `RAF_REGISTER_GLOBAL("raf.memory_pool._make.your_pool").set_body_typed(YourPool::make);` -After re-make meta, you can enable your pool by calling `InitPool(contxt, pool_name)`. +After re-make RAF, you can enable your pool by calling `InitPool(contxt, pool_name)`. From a5c99c13d5a1c6b2b2c3594f08ddebaeb473a624 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 25 Feb 2022 18:37:28 -0800 Subject: [PATCH 3/4] [CI] Miner improvements (#879) --- .github/workflows/ci_unit_test.yml | 31 ++++++++++++++++++++++++++++++ README.md | 10 +++------- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci_unit_test.yml b/.github/workflows/ci_unit_test.yml index 5279e685..0db71890 100644 --- a/.github/workflows/ci_unit_test.yml +++ b/.github/workflows/ci_unit_test.yml @@ -205,3 +205,34 @@ jobs: --command "bash ./ci/batch/cli.sh config_cmake GPU 75 && bash ./ci/batch/cli.sh compile build multi-GPU ${{ needs.check_status.outputs.job_tag }} && bash ./ci/batch/cli.sh unit_test multi-GPU" + + update_ci_badge: + needs: [test_on_cpu, test_on_gpu, test_on_multi_gpu] + if: github.repository == 'meta-project/meta' + runs-on: self-hosted + steps: + - uses: haya14busa/action-workflow_run-status@v1 + - name: Checkout repository + # No need to checkout submodules because we only need to get the HEAD commit hash. + - name: Generate CI badge + id: badge + run: | + # env vars are unavailable in job.if so we have to implement it here. + if [ "${{ needs.check_status.outputs.pr }}" != "" ]; then + echo "No need to update badge for PR CI. Skip." + exit 0 + fi + head_commit=$(git rev-parse --short HEAD) + echo "::set-output name=gist_id::630a36600930c8d68e6b15f16333b532" + echo "::set-output name=message::${head_commit}" + - name: Update CI badge + # Intentionally fail this step with empty gist_id. 
+ uses: schneegans/dynamic-badges-action@v1.1.0 + continue-on-error: true + with: + auth: ${{ secrets.DEPLOY_ACCESS_TOKEN }} + gistID: ${{ steps.badge.outputs.gist_id }} + filename: raf-ci-badge-last-pass.json + label: CI-Last-Success + message: ${{ steps.badge.outputs.message }} + color: blue diff --git a/README.md b/README.md index 998f702c..1f65ea7d 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,10 @@ -RAF -=== - -[![CI-Lint](https://github.com/meta-project/meta/actions/workflows/ci_lint.yml/badge.svg)](https://github.com/meta-project/meta/actions/workflows/ci_lint.yml) -[![CI-UnitTest](https://github.com/meta-project/meta/actions/workflows/ci_unit_test.yml/badge.svg)](https://github.com/meta-project/meta/actions/workflows/ci_unit_test.yml) +RAF: RAF Accelerates deep learning Frameworks +============================================= +![CI-Lass-Pass](https://img.shields.io/endpoint?url=https://gist.githubusercontent.com/aire-meta-bot/630a36600930c8d68e6b15f16333b532/raw/raf-ci-badge-last-pass.json) Please refer to our [wiki](docs/wiki) for more information. - - From df572d227f9d201f32b474ea6750d954da81328b Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 25 Feb 2022 19:27:27 -0800 Subject: [PATCH 4/4] Update ci_unit_test.yml --- .github/workflows/ci_unit_test.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci_unit_test.yml b/.github/workflows/ci_unit_test.yml index 0db71890..159d2db5 100644 --- a/.github/workflows/ci_unit_test.yml +++ b/.github/workflows/ci_unit_test.yml @@ -214,6 +214,7 @@ jobs: - uses: haya14busa/action-workflow_run-status@v1 - name: Checkout repository # No need to checkout submodules because we only need to get the HEAD commit hash. + uses: actions/checkout@v2 - name: Generate CI badge id: badge run: |
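Note on the new `collapse_sum_like` / `reshape_like` tests added in PATCH 1/4 above: the VM tests check against `np.sum(n_x, 0)` and `np.reshape(n_x, m_y.shape)`, which only exercise the single-extra-leading-axis case. Below is a minimal NumPy sketch of the more general op semantics, assuming the usual Relay-style reverse-broadcast behavior for `collapse_sum_like`; the `*_ref` helper names are illustrative only and are not part of this patch series or the RAF codebase.

```python
# Reference sketch (assumption: collapse_sum_like reverses broadcasting so the
# result takes the shape of `like`; reshape_like reinterprets x with like's
# shape). These helpers are hypothetical oracles, not RAF APIs.
import numpy as np


def collapse_sum_like_ref(x, like):
    """Sum `x` over the broadcast axes so the result has `like`'s shape."""
    ndim_diff = x.ndim - like.ndim
    # Sum out any extra leading axes first.
    y = x.sum(axis=tuple(range(ndim_diff))) if ndim_diff > 0 else x
    # Then sum (keeping dims) over axes where `like` is size 1 but `y` is not.
    axes = tuple(
        i for i, (dy, dl) in enumerate(zip(y.shape, like.shape)) if dl == 1 and dy != 1
    )
    return y.sum(axis=axes, keepdims=True) if axes else y


def reshape_like_ref(x, like):
    """Reinterpret `x` with `like`'s shape; the element counts must match."""
    return np.reshape(x, like.shape)


# Matches the expectations used in test_collapse_sum_like / test_reshape_like:
x = np.random.randn(3, 4, 5).astype("float32")
like = np.zeros((4, 5), dtype="float32")
assert np.allclose(collapse_sum_like_ref(x, like), x.sum(axis=0))
assert reshape_like_ref(np.arange(8), np.zeros((2, 4))).shape == (2, 4)
```

If more shape combinations are added to the parametrize lists later, a general oracle along these lines could replace the hard-coded `np.sum(n_x, 0)` expectation.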