From a14d4935b9edecd33f720e7427c34c161bd5bb13 Mon Sep 17 00:00:00 2001 From: Zhen Jia <53954057+zhen-jia@users.noreply.github.com> Date: Fri, 22 Apr 2022 18:42:23 -0700 Subject: [PATCH 01/37] [Dist][Pass] ZeRO optimization (#23) * group ops * add tests * lint * lint * lint * init * group allgather * lint * update * license header * black format * lint * lint * lint * fix bug for nccl <2100 * update * address comments --- include/raf/dist_config.h | 2 + include/raf/pass.h | 6 + python/raf/optim/data_parallel.py | 6 +- python/raf/optim/lans.py | 3 - src/common/shape_utils.h | 22 ++ src/impl/vm/compiler.cc | 9 +- src/op/dialect/cuda/lans.cc | 61 ++-- src/op/dialect/nccl/communication_utils.h | 1 + src/op/ty/transform.cc | 2 +- src/pass/group_allgather.cc | 306 ++++++++++++++++++ src/pass/partition_gradient.cc | 118 +++++-- src/pass/type_infer.cc | 1 - tests/python/optim/test_lans.py | 4 +- tests/python/optim/test_sgd.py | 4 +- .../python/pass/test_pass_estimate_memory.py | 2 +- .../python/pass/test_pass_group_allgather.py | 102 ++++++ .../pass/test_pass_partition_gradient.py | 65 ++-- 17 files changed, 621 insertions(+), 93 deletions(-) create mode 100644 src/pass/group_allgather.cc create mode 100644 tests/python/pass/test_pass_group_allgather.py diff --git a/include/raf/dist_config.h b/include/raf/dist_config.h index 682db080..56e99ed3 100644 --- a/include/raf/dist_config.h +++ b/include/raf/dist_config.h @@ -22,12 +22,14 @@ class DistConfigObj : public ir::Object { int zero_opt_level = 0; int auto_dp_profiling_start_iter = 2; int auto_dp_profiling_end_iter = 4; + int64_t group_bucket_size = 5000000000; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("enable_data_parallel", &enable_data_parallel); v->Visit("zero_opt_level", &zero_opt_level); v->Visit("auto_dp_profiling_start_iter", &auto_dp_profiling_start_iter); v->Visit("auto_dp_profiling_end_iter", &auto_dp_profiling_end_iter); + v->Visit("group_bucket_size", &group_bucket_size); } public: diff --git a/include/raf/pass.h b/include/raf/pass.h index 48451b9b..3fa10b56 100644 --- a/include/raf/pass.h +++ b/include/raf/pass.h @@ -399,6 +399,12 @@ Pass IOSStreamSchedule(); Pass Deduplicate(int forward_steps, bool consider_type, bool must_dominate, ir::Optional salt); +/*! + * \brief This pass works in ANF and group allgather operators for ZeRO. + * \return The created pass. + */ +Pass GroupAllgather(); + // Helper functions /*! 
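For orientation, a minimal sketch of how the knobs introduced above might be set from the Python side, assuming raf.distributed.get_config() exposes the DistConfig attributes registered in dist_config.h as writable properties (the updated unit tests below mock an object with the same attribute names); this is an illustration under those assumptions, not code from the patch:

    import raf

    # Assumption: these attributes are settable from Python and mirror DistConfigObj.
    dcfg = raf.distributed.get_config()
    dcfg.enable_data_parallel = True      # run data-parallel training
    dcfg.zero_opt_level = 2               # ZeRO-2: partition gradients via (group_)reduce_scatter
    dcfg.group_bucket_size = 5 * 10**9    # elements per communication bucket (matches the default above)

With zero_opt_level > 1 and group_bucket_size > 1, the VM compiler change in this patch additionally schedules the GroupAllgather pass on CUDA devices.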
diff --git a/python/raf/optim/data_parallel.py b/python/raf/optim/data_parallel.py index 02bc0893..cc71236a 100644 --- a/python/raf/optim/data_parallel.py +++ b/python/raf/optim/data_parallel.py @@ -49,7 +49,11 @@ def forward(self, *args, **kwargs): # passes.append(AutoDataParallel()) if dcfg.zero_opt_level > 0: passes.append(InferType()) - passes.append(PartitionGradient(dcfg.zero_opt_level, comm.size, comm.rank)) + passes.append( + PartitionGradient( + dcfg.zero_opt_level, comm.size, comm.rank, dcfg.group_bucket_size + ) + ) record = self.model._internal(*args, **kwargs) mod = record.mod diff --git a/python/raf/optim/lans.py b/python/raf/optim/lans.py index 47fddfc7..38b89c9e 100644 --- a/python/raf/optim/lans.py +++ b/python/raf/optim/lans.py @@ -252,9 +252,6 @@ def forward(self, dy, *args, **kwargs): if "float" not in w.dtype: continue - if self.dtype != "float32": - dxi = _op.cast(dxi, "float32") - g_list.append(dxi) x_list.append(w) m_list.append(m) diff --git a/src/common/shape_utils.h b/src/common/shape_utils.h index 83da2ff0..0c9102ee 100644 --- a/src/common/shape_utils.h +++ b/src/common/shape_utils.h @@ -214,6 +214,28 @@ inline int64_t BytesCompactType(const Type& type) { throw; } +inline int64_t GetElementNum(const Expr& var) { + int64_t n; + CHECK(var->checked_type_.defined()); + if (var->checked_type().as()) { + n = 0; + for (auto field : Downcast(var)->fields) { + int64_t fn = GetElementNum(field); + n += fn; + } + } else { + n = 1; + auto var_type = var->checked_type().as(); + CHECK(var_type != nullptr); + for (int i = 0; i < var_type->shape.size(); ++i) { + PrimExpr k = var_type->shape[i]; + int64_t k_v = k.as()->value; + n *= k_v; + } + } + return n; +} + } // namespace shape_utils } // namespace common } // namespace raf diff --git a/src/impl/vm/compiler.cc b/src/impl/vm/compiler.cc index a5ae7a8e..e5e14d12 100644 --- a/src/impl/vm/compiler.cc +++ b/src/impl/vm/compiler.cc @@ -842,13 +842,18 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const DeviceMap& device tvm::With dctx((*it).second); pass::PassContext pass_ctx = pass::PassContext::Current(); tvm::With ctx(pass_ctx); - + auto dcfg = DistConfig::Global(); + auto device_t = (*it).second.device_type(); Array pass_seqs; // optimization passes that work on ANF pass_seqs.push_back(pass::GradInputSelect()); pass_seqs.push_back(pass::InlineLet()); pass_seqs.push_back(pass::DeadCodeElimination()); + // enable group all gather for ZeRO. 
+ if (dcfg->zero_opt_level > 1 && dcfg->group_bucket_size > 1 && device_t == DevType::kCUDA()) { + pass_seqs.push_back(pass::GroupAllgather()); + } bool enable_stream_schedule = true; if (!pass_ctx->GetConfig("raf.vm.optimize.anf_only", Bool(false)).value()) { @@ -865,7 +870,7 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const DeviceMap& device pass_seqs.push_back(pass::EraseType()); // optimization passes that transform BBNF into ANF - if ((*it).second.device_type() == DevType::kCUDA()) { + if (device_t == DevType::kCUDA()) { if (DistConfig::Global()->enable_data_parallel) { // The current design of EnforceSync assumes ops are executed on multiple CUDA streams: // all computation ops are executed on a computation stream, and all communication diff --git a/src/op/dialect/cuda/lans.cc b/src/op/dialect/cuda/lans.cc index 5f460c86..c5d3d0ed 100644 --- a/src/op/dialect/cuda/lans.cc +++ b/src/op/dialect/cuda/lans.cc @@ -80,6 +80,11 @@ class LansImpl : public raf::op::OpEnv { RequestWorkspace(¶m_norm_tensor_, cv->device, 4 * param_group_n_); RequestWorkspace(&update_m_norm_, cv->device, 4 * param_group_n_); RequestWorkspace(&q_norm_tensor_, cv->device, 4 * param_group_n_); + + static auto cuda_device_api = DeviceAPI::Get(DevType::kCUDA()); + compute_stream_ = cuda_device_api->GetStream(); + cpu_ctx_.device_type = kDLCPU; + cpu_ctx_.device_id = 0; } void Execute(const CallValues& cv) override { @@ -90,19 +95,14 @@ class LansImpl : public raf::op::OpEnv { } void Execute(const std::vector& inputs, Value output) override { - static auto cuda_device_api = DeviceAPI::Get(DevType::kCUDA()); - void* compute_stream = cuda_device_api->GetStream(); TupleValue tuple = ir::Downcast(inputs[0]); DLTensor* t0 = ir::Downcast(tuple->fields[0]); CHECK(t0->dtype.code == kDLFloat); CHECK((t0->dtype.bits == 32) || (t0->dtype.bits == 16)); - DLDevice cpu_ctx; - cpu_ctx.device_type = kDLCPU; - cpu_ctx.device_id = 0; auto* tstep = inputs[1].as(); tensor::Tensor step_tensor = tstep->tensor; CHECK(step_tensor->ndim == 0); - tvm::runtime::NDArray step_array = step_tensor.CopyTo(cpu_ctx); + tvm::runtime::NDArray step_array = step_tensor.CopyTo(cpu_ctx_); float fstep = reinterpret_cast(step_array->data)[0]; int step = (int)fstep; float bias_correction1 = 1.0f; @@ -116,35 +116,26 @@ class LansImpl : public raf::op::OpEnv { beta3 = 1 - beta1_; } - switch (t0->dtype.bits) { - case 32: { - std::vector tlist; - for (int i = 0; i < param_group_n_; ++i) { - DLTensor* tensor = ir::Downcast(tuple->fields[i]); - tlist.push_back(static_cast(tensor->data)); - } - tlist.push_back(static_cast(q_tensor_buf_)); - for (int i = 1; i < numels_.size(); ++i) { - tlist.push_back(static_cast(q_tensor_buf_) + numels_[i - 1]); - } - for (int i = param_group_n_; i < tuple->fields.size(); ++i) { - DLTensor* tensor = ir::Downcast(tuple->fields[i]); - tlist.push_back(static_cast(tensor->data)); - } - multi_tensor_lans_cuda( - CHUNK_SIZE, tlist, learning_rate_, beta1_, beta2_, eps_, bias_correction_, - bias_correction1, bias_correction2, beta3, weight_decay_, grad_averaging_, mode_, - normalize_grad_, numels_, compute_stream, static_cast(output_per_tensor_), - static_cast(grad_norm_tensor_), static_cast(param_norm_tensor_), - static_cast(update_m_norm_), static_cast(q_norm_tensor_), - max_chunks_per_tensor_); - break; - } - default: { - LOG(FATAL) << "Unsupported dtype: " << DType(t0->dtype).c_str(); - throw; - } + std::vector tlist; + for (int i = 0; i < param_group_n_; ++i) { + DLTensor* tensor = ir::Downcast(tuple->fields[i]); + 
tlist.push_back(static_cast(tensor->data)); + } + tlist.push_back(static_cast(q_tensor_buf_)); + for (int i = 1; i < numels_.size(); ++i) { + tlist.push_back(static_cast(q_tensor_buf_) + numels_[i - 1]); + } + for (int i = param_group_n_; i < tuple->fields.size(); ++i) { + DLTensor* tensor = ir::Downcast(tuple->fields[i]); + tlist.push_back(static_cast(tensor->data)); } + multi_tensor_lans_cuda( + CHUNK_SIZE, tlist, learning_rate_, beta1_, beta2_, eps_, bias_correction_, bias_correction1, + bias_correction2, beta3, weight_decay_, grad_averaging_, mode_, normalize_grad_, numels_, + compute_stream_, static_cast(output_per_tensor_), + static_cast(grad_norm_tensor_), static_cast(param_norm_tensor_), + static_cast(update_m_norm_), static_cast(q_norm_tensor_), + max_chunks_per_tensor_); } std::string name() const override { @@ -174,6 +165,8 @@ class LansImpl : public raf::op::OpEnv { void* q_norm_tensor_; int max_chunks_per_tensor_; void* q_tensor_buf_; + void* compute_stream_; + DLDevice cpu_ctx_; }; RAF_REGISTER_DIALECT_OP(cuda, lans, 20); diff --git a/src/op/dialect/nccl/communication_utils.h b/src/op/dialect/nccl/communication_utils.h index 1ff8ba5d..d1f7d2bf 100644 --- a/src/op/dialect/nccl/communication_utils.h +++ b/src/op/dialect/nccl/communication_utils.h @@ -33,6 +33,7 @@ inline DType::operator ncclDataType_t() const { switch (code) { case kDLInt: if (bits == 8) return ncclInt8; + if (bits == 64) return ncclInt64; break; case kDLUInt: if (bits == 8) return ncclUint8; diff --git a/src/op/ty/transform.cc b/src/op/ty/transform.cc index d3065da6..3065138a 100644 --- a/src/op/ty/transform.cc +++ b/src/op/ty/transform.cc @@ -712,7 +712,6 @@ Type StridedSliceInfer(const CallValues& value) { auto dshape = data->shape; int64_t num_axis = dshape.size(); - Array begin = GetShapeExprFromValue(args->begin); Array end = GetShapeExprFromValue(args->end); auto is_any = [](PrimExpr expr) { return expr->IsInstance(); }; @@ -812,6 +811,7 @@ Type StridedSliceInfer(const CallValues& value) { slice_range = end_v - begin_v; step = stride_v; } + CHECK_NE(step, 0) << "step can not be zero "; oshape[i] = Integer((slice_range + step - 1) / step); } diff --git a/src/pass/group_allgather.cc b/src/pass/group_allgather.cc new file mode 100644 index 00000000..e6eb970f --- /dev/null +++ b/src/pass/group_allgather.cc @@ -0,0 +1,306 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +/*! + * \file group_allgather.cc + * \brief ZeRO optimzed graph, this pass group the cast, if there is, and allgather ops. 
+ */ +#include "raf/pass.h" +#include "./common.h" +#include "./let_list.h" +#include "../common/shape_utils.h" +#include "raf/op_utils.h" +#include "raf/dist_config.h" +#include "raf/communicator.h" +#include + +namespace raf { +namespace pass { +namespace group_comm { + +using namespace raf::op; +using raf::distributed::DistConfig; +using namespace raf::distributed::communicator; +class CommGrouper : public ExprMutator { + public: + CommGrouper(const Function& func) : func_(func) { + auto dcfg = DistConfig::Global(); + auto comm = GetGlobalCommunicator(); + local_rank_ = comm->local_rank; + bucket_size_ = dcfg->group_bucket_size; + auto ell = ExplicitLetList::make(func->body); + auto ret = ell->exprs.back().as(); + ret_var_ = ell->vars.back(); + for (int i = 2; i < ret->fields.size(); ++i) { + params_.Set(Downcast(ret->fields[i]), Expr()); + } + scopes_.emplace_back(new LetList); + } + + Function Group() { + if (params_.empty()) { + return func_; + } + Array func_params{func_->params}; + auto new_body = this->Mutate(func_->body); + return Function(func_params, new_body, func_->ret_type, {}, func_->attrs); + } + + Expr VisitExpr_(const LetNode* node) { + static Op zeros_op = Op::Get("raf.op.zeros"); + static Op group_cast = Op::Get("raf.op.group_cast"); + static Op group_allgather = Op::Get("raf.op._group_allgather"); + + scopes_.emplace_back(new LetList); + auto scope = scopes_.back().get(); + Expr body; + do { + auto curr_var = node->var; + auto value = VisitExpr(node->value); + + bool comm_node = false; + + Nodes re_nodes = MatchParamUpdateWithAllGather(node); + + Var update_var = re_nodes.update_var; + if (update_var.defined()) { + comm_node = true; + auto slice_node = re_nodes.slice_node; + auto gather_node = re_nodes.gather_node; + auto cast_node = re_nodes.cast_node; + auto add_node = re_nodes.add_node; + + auto gather_var = gather_node->var; + int64_t size = common::shape_utils::GetElementNum(update_var); + auto var_type = gather_var->checked_type_.as(); + CHECK(var_type != nullptr); + Var zero_input; + + Expr allgather_input; + Expr allgather_output; + Call cast_call; + Call slice_call; + + Call gather_call = Downcast(gather_node->value); + Call add_call = Downcast(add_node->value); + + if (cast_node) { + cast_allgather_ = true; + cast_call = Downcast(cast_node->value); + allgather_input = cast_call->args[0]; + } else { + allgather_input = gather_call->args[0]; + } + + if (slice_node) { + slice_call = Downcast(slice_node->value); + slice_dic_[allgather_inputs_.size()] = slice_call->args[2]; + zero_input = scope->Push( + Call(zeros_op, + {MakeConstant(ArrayToIntTuple(var_type->shape)), + MakeConstant(StringValue::make(DLDataType2String(var_type->dtype))), + MakeConstant(StringValue::make("cuda(" + std::to_string(local_rank_) + ")"))})); + allgather_output = zero_input; + } else { + allgather_output = add_call->args[2]; + } + + if (curr_size_ + size < bucket_size_) { + curr_size_ += size; + allgather_inputs_.push_back(allgather_input); + allgather_outputs_.push_back(allgather_output); + update_params_.push_back(update_var); + } else { + curr_size_ = size; + Var gather_input; + if (cast_node) { + auto cast_input = scope->Push(Tuple(allgather_inputs_)); + gather_input = scope->Push(Call(group_cast, {cast_input, cast_call->args[1]})); + } else { + gather_input = scope->Push(Tuple(allgather_inputs_)); + } + auto gather_output = scope->Push(Tuple(allgather_outputs_)); + auto output = scope->Push( + Call(group_allgather, {gather_input, gather_call->args[1], gather_output})); + for (int i 
= 0; i < allgather_inputs_.size(); ++i) { + auto out_tensor = scope->Push(TupleGetItem(output, i)); + if (slice_dic_.count(i)) { + out_tensor = scope->Push( + Call(slice_op_, + {out_tensor, MakeConstant(TupleValue::make({ScalarValue::make(0)})), + slice_dic_[i], MakeConstant(TupleValue::make({ScalarValue::make(1)}))})); + } + params_.Set(update_params_[i], out_tensor); + } + + allgather_inputs_ = {allgather_input}; + allgather_outputs_ = {allgather_output}; + if (slice_node) { + slice_dic_.clear(); + slice_dic_[allgather_inputs_.size() - 1] = slice_call->args[2]; + } + update_params_ = {update_var}; + } + node = add_node; + } else if (curr_var == ret_var_) { + comm_node = true; + if (allgather_inputs_.size() > 1) { + auto gather_input = scope->Push(Tuple(allgather_inputs_)); + auto gather_output = scope->Push(Tuple(allgather_outputs_)); + if (cast_allgather_) { + gather_input = scope->Push( + Call(group_cast, {gather_input, MakeConstant(StringValue::make("float16"))})); + } + auto output = scope->Push(Call( + group_allgather, {gather_input, MakeConstant(ScalarValue::make(0)), gather_output})); + for (int i = 0; i < allgather_inputs_.size(); ++i) { + auto out_tensor = scope->Push(TupleGetItem(output, i)); + if (slice_dic_.count(i)) { + out_tensor = scope->Push( + Call(slice_op_, + {out_tensor, MakeConstant(TupleValue::make({ScalarValue::make(0)})), + slice_dic_[i], MakeConstant(TupleValue::make({ScalarValue::make(1)}))})); + } + params_.Set(update_params_[i], out_tensor); + } + } + Array tuple; + auto ret_value = value.as(); + tuple.push_back(ret_value->fields[0]); + tuple.push_back(ret_value->fields[1]); + for (int j = 2; j < ret_value->fields.size(); ++j) { + auto key = Downcast(ret_value->fields[j]); + if (params_[key].defined()) { + tuple.push_back(params_[key]); + } else { + tuple.push_back(key); + } + } + scope->Push(curr_var, Tuple(tuple)); + } + if (comm_node == false) { + scope->Push(curr_var, value); + } + body = node->body; + node = body.as(); + + } while (node); + auto ret = scopes_.back()->Get(this->Mutate(body)); + scopes_.pop_back(); + return ret; + } + + private: + struct Nodes { + Var update_var; + const LetNode* cast_node; + const LetNode* slice_node; + const LetNode* gather_node; + const LetNode* add_node; + }; + // TODO @zhen-jia we will have an ANF-based pattern matching mechanism in the future. 
+ inline Nodes MatchParamUpdateWithAllGather(const LetNode* node) {
+ const LetNode* gather_node = nullptr;
+ const LetNode* cast_node = nullptr;
+ const LetNode* visit_node = nullptr;
+ if (IsOp(node, cast_op_)) {
+ // Matching cast -> allgather -> parameter update through the in-place add
+ cast_node = node;
+ gather_node = node->body.as();
+ if (IsOp(gather_node, allgather_op_)) {
+ visit_node = gather_node->body.as();
+ }
+ } else if (IsOp(node, allgather_op_)) {
+ // Matching allgather -> parameter update through the in-place add
+ gather_node = node;
+ visit_node = node->body.as();
+ }
+ if (visit_node) {
+ auto result_nodes = FindUpdateVar(visit_node);
+ Var update_var = result_nodes.update_var;
+ auto slice_node = result_nodes.slice_node;
+ auto add_node = result_nodes.add_node;
+ return Nodes{update_var, cast_node, slice_node, gather_node, add_node};
+ }
+ return Nodes{NullValue(), nullptr, nullptr, nullptr, nullptr};
+ }
+
+ inline Nodes FindUpdateVar(const LetNode* node) {
+ if (IsOp(node, slice_op_)) {
+ // Matching strided_slice followed by an add node
+ auto add_node = node->body.as();
+ if (IsOp(add_node, add_op_)) {
+ auto update_var = add_node->var;
+ if (params_.count(update_var)) {
+ return Nodes{update_var, nullptr, node, nullptr, add_node};
+ }
+ }
+ } else {
+ // No slice op; only match the add node
+ if (IsOp(node, add_op_)) {
+ auto update_var = node->var;
+ if (params_.count(update_var)) {
+ return Nodes{update_var, nullptr, nullptr, nullptr, node};
+ }
+ }
+ }
+ return Nodes{NullValue(), nullptr, nullptr, nullptr, nullptr};
+ }
+
+ inline bool IsOp(const LetNode* node, Op op) {
+ if (node->value.as()) {
+ auto call = Downcast(node->value);
+ auto opn = Downcast(call->op);
+ if (opn == op) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ /*! \brief The target function. */
+ Function func_;
+ /*! \brief The scope stack of the let list. */
+ std::vector> scopes_;
+ /*! \brief The parameters of the target function. */
+ Map params_;
+ /*! \brief The inputs of allgather. */
+ std::vector allgather_inputs_;
+ /*! \brief The outputs of allgather. */
+ std::vector allgather_outputs_;
+ /*! \brief The parameters that need to be updated. */
+ std::vector update_params_;
+ /*! \brief Track the tensors that need to be sliced. */
+ std::unordered_map slice_dic_;
+ /*! \brief Group bucket size. */
+ size_t bucket_size_;
+ /*! \brief The current bucket size for the group. */
+ size_t curr_size_ = 0;
+ /*! \brief Whether there is a cast op before allgather. */
+ bool cast_allgather_ = false;
+ /*! \brief The return var. */
+ Var ret_var_;
+ /*! \brief Local rank.
*/ + int local_rank_; + // ops using in this pass + Op add_op_ = Op::Get("raf.op.add"); + Op allgather_op_ = Op::Get("raf.op._allgather"); + Op slice_op_ = Op::Get("raf.op.strided_slice"); + Op cast_op_ = Op::Get("raf.op.cast"); +}; +} // namespace group_comm + +Pass GroupAllgather() { + TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { return group_comm::CommGrouper(f).Group(); }; + auto group_allgather_pass = CreateRAFFunctionPass(pass_func, 0, "GroupAllgather", {}); + + return RAFSequential({InferType(), group_allgather_pass, InferType()}, "GroupAllgather"); +} // namespace pass + +RAF_REGISTER_GLOBAL("raf.pass_.GroupAllgather").set_body_typed(GroupAllgather); + +} // namespace pass +} // namespace raf diff --git a/src/pass/partition_gradient.cc b/src/pass/partition_gradient.cc index 61b28e51..0f48c3b2 100644 --- a/src/pass/partition_gradient.cc +++ b/src/pass/partition_gradient.cc @@ -24,8 +24,8 @@ namespace partition_gradient { class GradientPartitioner : public ExprMutator { public: - GradientPartitioner(int opt_level, int n_part, const Function& func) - : opt_level_(opt_level), n_part_(n_part), func_(func) { + GradientPartitioner(int opt_level, int n_part, int64_t bucket_size, const Function& func) + : opt_level_(opt_level), n_part_(n_part), bucket_size_(bucket_size), func_(func) { // Build the var to expr map for the ANF. Map var_to_expr; auto ell = ExplicitLetList::make(func->body); @@ -50,12 +50,13 @@ class GradientPartitioner : public ExprMutator { grad_tuple_var_ = Downcast(tuple->fields[tgi->index]); grads = var_to_expr[grad_tuple_var_]; } - - for (auto field : Downcast(grads)->fields) { + auto grad_fields = Downcast(grads)->fields; + for (auto field : grad_fields) { CHECK(field->IsInstance()) << "Expected a var in the gradient tuple, but got " << field->GetTypeKey(); grads_.Set(Downcast(field), Expr()); } + last_all_reduce_ = Downcast(grad_fields[grad_fields.size() - 1]); scopes_.emplace_back(new LetList); } @@ -84,8 +85,7 @@ class GradientPartitioner : public ExprMutator { if (grads_.count(curr_var) > 0) { // The curr_var is a complete gradient. CHECK(!grads_[curr_var].defined()); - auto grad_var = SliceGrad(scope, curr_var, value, opt_level_); - grads_.Set(curr_var, grad_var); + SliceGrad(scope, curr_var, value, opt_level_); } else if (curr_var == grad_tuple_var_) { // Replace gradients with sliced ones. Array fields; @@ -212,20 +212,34 @@ class GradientPartitioner : public ExprMutator { * let %4 = TupleGetItem(%3, rank); * TODO(comaniac): Add %rank to the function argument if rank_ is unknown. * - * The desired IR for ZeRO-2 is: + * The desired IR for ZeRO-2 is if bucket_size_ < 2: * // if NCCL version is >= 2.10 * let %1 = op(%0); // A backward op to generate gradient * let %2 = pad(%1, ...); // %1 is the complete local gradient - * let %3 = split(%2, ...); + * let %3 = Tuple(%2); * let %4 = reduce_scatter(%3, avg); * // else NCCL version is < 2.10 * let %1 = op(%0); // A backward op to generate gradient * let %2 = pad(%1, ...); // %1 is the complete local gradient - * let %3 = split(%2, ...); + * let %3 = Tuple(%2); * let %4 = reduce_scatter(%3, sum); * let %5 = divide(%4, ...) 
- */ - Var SliceGrad(LetList* scope, const Var& var, const Expr& value, int opt_level) { + * The desired IR for ZeRO-2 is if bucket_size_ > 2, which means group reduce_scatter: + * // if NCCL version is >= 2.10 + * let %1 = op(%0); // A backward op to generate gradient + * let %2 = pad(%1, ...); // %1 is the complete local gradient + * let %3 = Tuple(%2); + * let %4 = group_reduce_scatter(%3, avg); + * // else NCCL version is < 2.10 + * let %1 = op(%0); // A backward op to generate gradient + * let %2 = pad(%1, ...); // %1 is the complete local gradient + * let %3 = Tuple(%2); + * let %4 = group_reduce_scatter(%3, sum); + * let %5 = %4.0 + * let %6 = divide(%5, ...) + +*/ + void SliceGrad(LetList* scope, const Var& var, const Expr& value, int opt_level) { static const Op& split_op = Op::Get("raf.op.split"); static const Op& reduce_scatter_op = Op::Get("raf.op._reduce_scatter"); Expr allreduce_expr, divide_expr; @@ -235,11 +249,12 @@ class GradientPartitioner : public ExprMutator { // no need to apply ZeRO-2. opt_level = 1; } - auto grad_var = var; + Var grad_var; if (opt_level > 1) { // ZeRO-2: Replace the AllReduce with ReduceScatter. auto first_arg = Downcast(GetNArg(allreduce_expr, 0)); // The 1st arg of allreduce is a tuple of tensors. + Constant compute = Downcast(GetNArg(allreduce_expr, 1)); auto arg_tuple = Downcast(var_to_expr_[first_arg]); CHECK_EQ(arg_tuple->fields.size(), 1U) << "Not supported yet"; @@ -251,28 +266,70 @@ class GradientPartitioner : public ExprMutator { } else { grad_var = Downcast(arg_tuple->fields[0]); } + + int64_t size = common::shape_utils::GetElementNum(grad_var); + grad_var = GenPadCall(scope, grad_var); + if (bucket_size_ < 2) { + // Do not group redcue_scatter + grad_var = scope->Push(Tuple({grad_var})); + auto reduce_scatter_var = scope->Push(Call(reduce_scatter_op, {grad_var, compute})); + if (divide_expr.defined()) { + // update the divide op args + auto divide_call = divide_expr.as(); + reduce_scatter_var = + scope->Push(Call(divide_call->op, {reduce_scatter_var, divide_call->args[1]})); + } + grads_.Set(var, reduce_scatter_var); + } else { + if (var == last_all_reduce_) { + scatter_input_.push_back(grad_var); + scatter_var_.push_back(var); + divide_expr_.push_back(divide_expr); + IssueGroupScatter(scope, compute); + return; + } + if (curr_size_ + size < bucket_size_) { + scatter_input_.push_back(grad_var); + scatter_var_.push_back(var); + divide_expr_.push_back(divide_expr); + curr_size_ += size; + } else { + IssueGroupScatter(scope, compute); + scatter_input_.push_back(grad_var); + scatter_var_.push_back(var); + divide_expr_.push_back(divide_expr); + curr_size_ = size; + } + } } else { // ZeRO-1: Keep AllReduce (or the backward op if data parallel is disabled). 
scope->Push(var, value); + grad_var = GenPadCall(scope, var); + grad_var = scope->Push(Call(split_op, {grad_var, MakeConstant(ScalarValue::make(n_part_)), + MakeConstant(ScalarValue::make(0))})); + auto replace_var = scope->Push(TupleGetItem(grad_var, rank_)); + grads_.Set(var, replace_var); } + } - grad_var = GenPadCall(scope, grad_var); - grad_var = scope->Push(Call(split_op, {grad_var, MakeConstant(ScalarValue::make(n_part_)), - MakeConstant(ScalarValue::make(0))})); + void IssueGroupScatter(LetList* scope, Constant compute) { + static const Op& group_reduce_scatter = Op::Get("raf.op._group_reduce_scatter"); + auto inputs = scope->Push(Tuple(scatter_input_)); - if (opt_level > 1) { - auto compute = Downcast(GetNArg(allreduce_expr, 1)); - auto reduce_scatter_var = scope->Push(Call(reduce_scatter_op, {grad_var, compute})); + auto scatter_out = scope->Push(Call(group_reduce_scatter, {inputs, compute})); + for (int i = 0; i < scatter_var_.size(); ++i) { + auto update_var = scope->Push(TupleGetItem(scatter_out, i)); + auto divide_expr = divide_expr_[i]; if (divide_expr.defined()) { // update the divide op args auto divide_call = divide_expr.as(); - return scope->Push(Call(divide_call->op, {reduce_scatter_var, divide_call->args[1]})); + update_var = scope->Push(Call(divide_call->op, {update_var, divide_call->args[1]})); } - return reduce_scatter_var; + grads_.Set(scatter_var_[i], update_var); } - return scope->Push(TupleGetItem(grad_var, rank_)); + scatter_input_ = {}; + scatter_var_ = {}; } - /*! \brief The scope stack of the let list. */ std::vector> scopes_; /*! \brief The optimization level (ZeRO-n). */ @@ -289,14 +346,27 @@ class GradientPartitioner : public ExprMutator { Var grad_tuple_var_; /*! \brief Mapping from let-binding var to the expression. */ Map var_to_expr_; + /*! \brief The bucket size for group scatter. */ + int64_t bucket_size_; + /*! \brief The current bucket size for group scatter. */ + int64_t curr_size_ = 0; + /*! \brief The last all reduce in graph. */ + Var last_all_reduce_; + /*! \brief The inputs for group scatter. */ + std::vector scatter_input_; + /*! \brief The group scatter var. */ + std::vector scatter_var_; + /*! \brief Divide expr after allreduce for NCCL version < 2.10. 
*/ + std::vector divide_expr_; }; } // namespace partition_gradient -Pass PartitionGradient(int opt_level, int n_part, int rank) { +Pass PartitionGradient(int opt_level, int n_part, int rank, int64_t bucket_size) { TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { - return partition_gradient::GradientPartitioner(opt_level, n_part, f).Partition(rank); + return partition_gradient::GradientPartitioner(opt_level, n_part, bucket_size, f) + .Partition(rank); }; auto partition_gradient = CreateRAFFunctionPass(pass_func, 0, "PartitionGradientFunc", {}); return RAFSequential({partition_gradient, EraseType(), DeadCodeElimination()}, diff --git a/src/pass/type_infer.cc b/src/pass/type_infer.cc index fa4b1829..81723e46 100644 --- a/src/pass/type_infer.cc +++ b/src/pass/type_infer.cc @@ -89,7 +89,6 @@ class TypeInferencer : public ExprMutator { Expr VisitExpr_(const CallNode* call) override { static const Op& invoke_op = Op::Get("raf.op.vm.invoke_op"); const OpNode* opn = call->op.as(); - if (opn && GetRef(opn) == invoke_op) { // Since invoke_op use the second argument (input tuple) to invoke // the first argument (op or closure), we need to update the var map diff --git a/tests/python/optim/test_lans.py b/tests/python/optim/test_lans.py index 43e8f58c..91699b67 100644 --- a/tests/python/optim/test_lans.py +++ b/tests/python/optim/test_lans.py @@ -218,6 +218,7 @@ class MockConfig: def __init__(self): self.enable_data_parallel = True self.zero_opt_level = 2 + self.group_bucket_size = 50000000 mock_get_config.return_value = MockConfig() @@ -259,7 +260,8 @@ def __init__(self): # Verify IR. This model has 7 parameters and 9 gradients # (gradients for input data and ytrure are useless). - assert text.count("raf.op._reduce_scatter") == 9, text + # The 9 _reduce_scatters are grouped. So only 1 _group_reduce_scatter. + assert text.count("raf.op._group_reduce_scatter") == 1, text assert text.count("raf.op._allgather") == 7, text assert text.count("raf.op.strided_slice") == 7, text diff --git a/tests/python/optim/test_sgd.py b/tests/python/optim/test_sgd.py index 437d3a16..eb160ada 100644 --- a/tests/python/optim/test_sgd.py +++ b/tests/python/optim/test_sgd.py @@ -253,6 +253,7 @@ class MockConfig: def __init__(self): self.enable_data_parallel = True self.zero_opt_level = 2 + self.group_bucket_size = 50000000 mock_get_config.return_value = MockConfig() @@ -295,7 +296,8 @@ def __init__(self): # Verify IR. This model has 8 parameters and 10 gradients # (gradients for input data and ytrure are useless). - assert text.count("raf.op._reduce_scatter") == 10, text + # The 10 _reduce_scatters are grouped. So only 1 _group_reduce_scatter. + assert text.count("raf.op._group_reduce_scatter") == 1, text assert text.count("raf.op._allgather") == 8, text assert text.count("raf.op.strided_slice") == 8, text diff --git a/tests/python/pass/test_pass_estimate_memory.py b/tests/python/pass/test_pass_estimate_memory.py index df58fae7..0c80ce68 100644 --- a/tests/python/pass/test_pass_estimate_memory.py +++ b/tests/python/pass/test_pass_estimate_memory.py @@ -27,7 +27,7 @@ def verify_memory(mod, device, expected_trace, disable_fusion=True, include_para for (name, mem), expected in zip(trace, expected_trace): assert name != "unknown" if isinstance(expected, tuple): # The expected memory could be a range. 
- assert expected[0] < mem < expected[1] + assert expected[0] <= mem < expected[1] else: check(mem, expected) diff --git a/tests/python/pass/test_pass_group_allgather.py b/tests/python/pass/test_pass_group_allgather.py new file mode 100644 index 00000000..115692db --- /dev/null +++ b/tests/python/pass/test_pass_group_allgather.py @@ -0,0 +1,102 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +# pylint: disable=unused-variable,protected-access +from unittest.mock import patch +import pytest + +import raf +from raf.model import Conv2d, Linear, BatchNorm +from raf.testing import one_hot_torch, randn +from raf._ffi import pass_ + + +class RAFTest(raf.Model): + # pylint: disable=attribute-defined-outside-init + def build(self, input_shape=28, num_classes=10): + self.conv1 = Conv2d(in_channels=3, out_channels=6, kernel_size=5, padding=2, bias=False) + self.bn1 = BatchNorm(6) + self.linear1 = Linear((input_shape // 2) ** 2 * 6, num_classes) + + # pylint: enable=attribute-defined-outside-init + + @raf.model.trace + def forward(self, x, y_true): + y_pred = self.forward_infer(x) + y_pred = raf.log_softmax(y_pred) + loss = raf.nll_loss(y_true=y_true, y_pred=y_pred) + return loss + + @raf.model.trace + def forward_infer(self, x): + out = self.bn1(self.conv1(x)) + out = raf.sigmoid(out) + out = raf.avg_pool2d(out, (2, 2), (2, 2)) + out = raf.batch_flatten(out) + out = self.linear1(out) + return out + + +def optimize(mod): + mod = pass_.ToGraphNormalForm()(mod) + mod = pass_.ToBasicBlockNormalForm()(mod) + mod = pass_.ToANormalForm()(mod) + mod = pass_.InferType()(mod) + mod = pass_.GroupAllgather()(mod) + return mod + + +def lower(model, args): + mod = model._internal(*args).mod + return optimize(mod) + + +@pytest.mark.skipif(not raf.build.with_cuda(), reason="CUDA is not enabled") +@patch("raf.distributed.get_communicator") +@patch("raf.distributed.get_config") +def test_group(mock_get_config, mock_get_comm): + # pylint: disable=too-many-locals, protected-access + # Mock the context to let with_lans generate the desired IR. + class MockConfig: + def __init__(self): + self.enable_data_parallel = True + self.zero_opt_level = 2 + self.group_bucket_size = 5000000000 + + mock_get_config.return_value = MockConfig() + + class MockComm: + def __init__(self): + self.size = 4 + self.local_rank = 0 + self.rank = 3 + + mock_get_comm.return_value = MockComm() + shape, n_classes = 28, 10 + batch_size = 7 + m_model = RAFTest(shape, 10) + m_model.train_mode() + m_optimizer = raf.optim.lans.with_lans()(m_model) + + device = "cuda" + m_x, _ = randn([batch_size, 3, shape, shape], requires_grad=True, device=device) + m_dy, _ = randn((), device=device, requires_grad=False) + m_ytrue, _ = one_hot_torch(batch_size=batch_size, num_classes=n_classes, device=device) + args = [m_dy, m_x, m_ytrue] + + record = m_optimizer._internal(*args) + mod = record.mod + + func = lower(m_optimizer, [*args])["main"] + text = raf.ir.AsText(func) + ## Verify IR. This model has 7 parameters and 9 gradients + ## There should be only 1 group GroupAllgather that perform + ## group operations for 9 gradients. + assert text.count("raf.op._group_allgather") == 1, text + assert text.count("raf.op.strided_slice") == 7, text + ## Using "zeros(" to exclude "zeros_like" op. 
+ assert text.count("raf.op.zeros(") == 7, text + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/pass/test_pass_partition_gradient.py b/tests/python/pass/test_pass_partition_gradient.py index 5aa576f4..36904ebc 100644 --- a/tests/python/pass/test_pass_partition_gradient.py +++ b/tests/python/pass/test_pass_partition_gradient.py @@ -2,9 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 # pylint: disable=too-many-arguments, protected-access, attribute-defined-outside-init -# pylint: disable=too-many-locals +# pylint: disable=too-many-locals, too-many-branches import re - from unittest.mock import patch import pytest @@ -17,14 +16,17 @@ def verify_ir(opt_level, ad_model, args, rank_size, rank, n_grad, n_pad): + bucket_size = 5000000000 record = ad_model._internal(*args) mod = record.mod mod = InferType()(mod) - mod = PartitionGradient(opt_level, rank_size, rank)(mod) + mod = PartitionGradient(opt_level, rank_size, rank, bucket_size)(mod) mod = InferType()(mod) text = raf.ir.AsText(mod) assert text.count("raf.op.pad") == n_pad - assert text.count("raf.op.split") == n_grad + if opt_level == 1: + assert text.count("raf.op.split") == n_grad + nccl_version = raf.build.with_nccl() if opt_level == 1: @@ -33,44 +35,59 @@ def verify_ir(opt_level, ad_model, args, rank_size, rank, n_grad, n_pad): # let %x_2 = %x_1.3; # ... # let %a10 = (..., %x_2, ...); - slice_grad_regex = fr"let %x_(\d+) = %x_\d+\.{rank};" + grad_regex = rf"let %x_(\d+) = %x_\d+\.{rank};" elif opt_level == 2: - # ZeRO-2 uses reduce_scatter to slice gradients. - assert text.count("raf.op._reduce_scatter") == n_grad, text + # ZeRO-2 uses group_reduce_scatter to slice gradients. + assert text.count("raf.op._group_reduce_scatter") == 1, text # Gradients will be sliced as follows in ZeRO-2: # # if NCCL version is >= 2.10 - # let %x_1 = raf.op.split(%a0, 4); - # let %x_2 = raf.op._recuce_scatter(%x_1, avg); + # let %x = (%x_0, %x1, ...); + # let %y = raf.op._group_recuce_scatter(%x, avg); + # let %y_0 = %y.0 + # let %y_1 = %y.1 # ... - # let %a10 = (..., %x_2, ...); + # let %a10 = (..., %y_0, %y_1, ...); # # else NCCL version is < 2.10 - # let %x_1 = raf.op.split(%a0, 4); - # let %x_2 = raf.op._recuce_scatter(%x_1, sum); - # let %x_3 = raf.op.divide(%x_2, ...) + # let %x = (%x_0, %x1, ...); + # let %y = raf.op._group_recuce_scatter(%x, avg); + # let %y_0 = %y.0 + # let %y_3 = raf.op.divide(%y_0, ...) # ... - # let %a10 = (..., %x_3, ...); + # let %a10 = (..., %y_3, ...); if nccl_version >= 21000: - slice_grad_regex = fr"let %x_(\d+) = raf.op._reduce_scatter.+" + grad_regex = rf"let %x_(\d+) = raf.op._group_reduce_scatter.+" else: - slice_grad_regex = fr"let %x_(\d+) = raf.op.divide.+" + grad_regex = rf"let %x_(\d+) = raf.op.divide.+" else: assert False, "Unsupported opt_level %d" % opt_level # Verify that the output gradient tuple contains all sliced gradients. 
verify_grad_tuple = False split_grads = set() + find_grad = False for line in text.split("\n"): - tokens = re.search(slice_grad_regex, line) + tokens = re.search(grad_regex, line) if tokens: - split_grads.add(f"%x_{tokens.group(1)}") + grad_var = f"%x_{tokens.group(1)}" + split_grads.add(grad_var) + find_grad = True continue - - tokens = re.search(r"let .+ = \((.+)\);", line) - if tokens: - if all([g in split_grads for g in tokens.group(1).replace(" ", "").split(",")]): - verify_grad_tuple = True - break + if find_grad: + if opt_level == 2 and nccl_version >= 2100: + pattern = rf"let %x_(\d+) = {grad_var}.+;" + tokens = re.search(pattern, line) + if tokens: + if grad_var in split_grads: + split_grads = set() + split_grads.add(f"%x_{tokens.group(1)}") + continue + + tokens = re.search(r"let .+ = \((.+)\);", line) + if tokens: + if all([g in split_grads for g in tokens.group(1).replace(" ", "").split(",")]): + verify_grad_tuple = True + break assert verify_grad_tuple if raf.build.with_cuda(): From 65268b4d327225370295bffc1b996bb364cb0f74 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Sat, 23 Apr 2022 16:27:17 +0800 Subject: [PATCH 02/37] [Fix] Support MPI w/o NCCL (#24) * [Fix] Support MPI w/o NCCL * fix * Update CMakeLists.txt Co-authored-by: Huang, Guangtai * fix Co-authored-by: Huang, Guangtai --- CMakeLists.txt | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 697a30fc..2f39531f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -160,17 +160,22 @@ else() endif() if (${RAF_USE_NCCL} STREQUAL "OFF") - set(RAF_DISTRIBUTED_SOURCE_FILES "") + set(RAF_NCCL_SOURCE_FILES "") else () set(RAF_CXX_FLAGS ${RAF_CXX_FLAGS} -DRAF_USE_NCCL) - file(GLOB_RECURSE RAF_DISTRIBUTED_SOURCE_FILES - ${CMAKE_CURRENT_LIST_DIR}/src/distributed/cuda/*.cc + file(GLOB_RECURSE RAF_NCCL_SOURCE_FILES + ${CMAKE_CURRENT_LIST_DIR}/src/distributed/cuda/nccl*.cc ${CMAKE_CURRENT_LIST_DIR}/src/op/dialect/nccl/*.cc ) endif() -if (${RAF_USE_MPI} STREQUAL "ON") +if (${RAF_USE_MPI} STREQUAL "OFF") + set(RAF_MPI_SOURCE_FILES "") +else () set(RAF_CXX_FLAGS ${RAF_CXX_FLAGS} -DRAF_USE_MPI) + file(GLOB_RECURSE RAF_MPI_SOURCE_FILES + ${CMAKE_CURRENT_LIST_DIR}/src/distributed/cuda/mpi*.cc + ) endif() set(RAF_SOURCE_FILES @@ -179,7 +184,8 @@ set(RAF_SOURCE_FILES ${RAF_CUDNN_SOURCE_FILES} ${RAF_CUBLAS_SOURCE_FILES} ${RAF_CUTLASS_SOURCE_FILES} - ${RAF_DISTRIBUTED_SOURCE_FILES} + ${RAF_MPI_SOURCE_FILES} + ${RAF_NCCL_SOURCE_FILES} ) add_library(raf_objs OBJECT ${RAF_SOURCE_FILES}) From e807f7513ae318e6ddd9534cc1a184ddaf6a4e53 Mon Sep 17 00:00:00 2001 From: "Huang, Guangtai" Date: Sat, 23 Apr 2022 18:28:17 +0800 Subject: [PATCH 03/37] Update communicator.py --- python/raf/distributed/communicator.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/raf/distributed/communicator.py b/python/raf/distributed/communicator.py index b321e424..d7566b2e 100644 --- a/python/raf/distributed/communicator.py +++ b/python/raf/distributed/communicator.py @@ -65,8 +65,6 @@ def dumps(self): "rank", "local_size", "local_rank", - "world_size", - "world_rank", ] return {attr: getattr(self, attr) for attr in attr_keys} From 675d2800b00f73f01857fd7be8cd41814a57aad0 Mon Sep 17 00:00:00 2001 From: AIREMetaBot <100344401+aire-meta-bot@users.noreply.github.com> Date: Tue, 26 Apr 2022 00:03:28 +0800 Subject: [PATCH 04/37] [TVM] Update Submodule (#26) --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 0b9bcf0e..d2db9cb0 160000 
--- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 0b9bcf0e7ab5c898d80ff8057dca022979e5e797 +Subproject commit d2db9cb0d839e32778f461b77e59f6418282a511 From 72d7be65d3ec2fb6813d411b071afdde0e5f1e24 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 2 May 2022 01:24:13 +0800 Subject: [PATCH 05/37] [InferType] Fix closure type error (#28) * [InferType] Fix closure type error * fix * fix * fix * fix --- src/pass/type_infer.cc | 161 ++++++++++++++-------- tests/python/pass/test_pass_infer_type.py | 52 +++++++ 2 files changed, 157 insertions(+), 56 deletions(-) diff --git a/src/pass/type_infer.cc b/src/pass/type_infer.cc index 81723e46..85350096 100644 --- a/src/pass/type_infer.cc +++ b/src/pass/type_infer.cc @@ -51,14 +51,19 @@ class TypeInferencer : public ExprMutator { } Expr VisitExpr_(const VarNode* op) override { - if (!op->checked_type_.defined()) { - if (op->type_annotation.defined()) { - op->checked_type_ = op->type_annotation; + auto var = GetRef(op); + if (closure_param_map_.count(var) > 0) { + // Use the updated closure parame var. + var = closure_param_map_[var]; + } + if (!var->checked_type_.defined()) { + if (var->type_annotation.defined()) { + var->checked_type_ = var->type_annotation; } else { - op->checked_type_ = IncompleteType(kType); + var->checked_type_ = IncompleteType(kType); } } - return GetRef(op); + return var; } Expr VisitExpr_(const GlobalVarNode* op) override { @@ -71,7 +76,7 @@ class TypeInferencer : public ExprMutator { return std::move(GetRef(op)); } - CallValues SchemaToValue(Array args, const OpNode* op) { + CallValues SchemaToValue(Array args, const Op op) { CallValues call_values = CallValues::make(); Array arg_values; for (const auto& arg : args) { @@ -81,8 +86,8 @@ class TypeInferencer : public ExprMutator { arg_values.push_back(GetValue(arg)); } } - call_values->args = GetOpAttr(GetRef(op), "FRAFSchema")(arg_values); - call_values->callee = OpValue::make(GetRef(op)); + call_values->args = GetOpAttr(op, "FRAFSchema")(arg_values); + call_values->callee = OpValue::make(op); return call_values; } @@ -120,7 +125,7 @@ class TypeInferencer : public ExprMutator { static std::unordered_set shape_list{ "raf.op.shape", "raf.op.get_reduce_axis", "raf.op.get_kept_dims", "raf.op.concatenate_dx"}; if (opn && shape_list.count(opn->name)) { - CallValues call_values = SchemaToValue(args, opn); + CallValues call_values = SchemaToValue(args, GetRef(opn)); declare_op[GetRef(opn)](call_values); if (call_values->out.defined()) { Expr re = ir::MakeConstant(call_values->out); @@ -132,44 +137,52 @@ class TypeInferencer : public ExprMutator { UpdateFuncParamVarMap(fn, call->args); } - Expr op = VisitExpr(call->op); - Call ret = Call(op, args, call->attrs, call->type_args); - if (const FunctionNode* fn = ret->op.as()) { - ret->checked_type_ = InferClosure(ret); - } else if (const GlobalVarNode* gvn = ret->op.as()) { - ret->op->checked_type_ = - Unify(gvn->checked_type(), mod_->Lookup(GetRef(gvn))->checked_type()); - ret->checked_type_ = InferClosure(ret); - } else if (const OpNode* opn = ret->op.as()) { - ret->checked_type_ = InferPrimitive(ret, opn); - } else if (ret->op.as() || ret->op.as()) { - // handle recursive func call when op is a var node - if (op->checked_type()->IsInstance()) { - ret->checked_type_ = IncompleteType(kType); - } else { - // The var node can be a result of the output type of a func call. A var node - // here is valid if it points to a function. Check that the type is a FuncType - // and the args of the Call match the type of the FuncType. 
If yes, return the - // FuncType's ret_type. - if (const auto* var_node = ret->op.as()) { - VisitPrimitiveClosureFromCallerArgs(var_node, call->args); - } - const FuncTypeNode* fty_node = ret->op->checked_type_.as(); - CHECK(fty_node); - for (size_t i = 0; i < fty_node->arg_types.size(); i++) { - ret->args[i]->checked_type_ = Unify(fty_node->arg_types[i], ret->args[i]->checked_type()); + Call ret = Call(call->op, args, call->attrs, call->type_args); + if (const FunctionNode* fn = call->op.as()) { + auto ret_type = InferClosure(ret, GetRef(fn)); + ret = Call(VisitExpr(call->op), args, call->attrs, call->type_args); + ret->checked_type_ = ret_type; + } else if (const GlobalVarNode* gvn = call->op.as()) { + auto fn = Downcast(mod_->Lookup(GetRef(gvn))); + auto ret_type = InferClosure(ret, fn); + ret = Call(VisitExpr(call->op), args, call->attrs, call->type_args); + ret->op->checked_type_ = Unify(gvn->checked_type(), fn->checked_type()); + ret->checked_type_ = ret_type; + } else { + Expr op = VisitExpr(call->op); + ret = Call(op, args, call->attrs, call->type_args); + if (const OpNode* opn = ret->op.as()) { + ret->checked_type_ = InferPrimitive(ret, GetRef(opn)); + } else if (ret->op.as() || ret->op.as()) { + // handle recursive func call when op is a var node + if (op->checked_type()->IsInstance()) { + ret->checked_type_ = IncompleteType(kType); + } else { + // The var node can be a result of the output type of a func call. A var node + // here is valid if it points to a function. Check that the type is a FuncType + // and the args of the Call match the type of the FuncType. If yes, return the + // FuncType's ret_type. + if (const auto* var_node = ret->op.as()) { + VisitPrimitiveClosureFromCallerArgs(var_node, call->args); + } + const FuncTypeNode* fty_node = ret->op->checked_type_.as(); + CHECK(fty_node); + for (size_t i = 0; i < fty_node->arg_types.size(); i++) { + ret->args[i]->checked_type_ = + Unify(fty_node->arg_types[i], ret->args[i]->checked_type()); + } + ret->checked_type_ = fty_node->ret_type; } - ret->checked_type_ = fty_node->ret_type; + } else if (const auto* ftn = op->checked_type().as()) { + ret->checked_type_ = ftn->ret_type; + } else { + LOG(FATAL) << "Invalid op type: " << call->op->GetTypeKey(); } - } else if (const auto* ftn = op->checked_type().as()) { - ret->checked_type_ = ftn->ret_type; - } else { - LOG(FATAL) << "Invalid op type: " << call->op->GetTypeKey(); } return ret; } - Type InferPrimitive(const Call& call, const OpNode* op) { + Type InferPrimitive(const Call& call, const Op op) { // Only type inference from leaf to root is supported. // Thus incomplete inputs will not be inferred from outputs. // Instead, the incompleteness propogates. @@ -194,14 +207,42 @@ class TypeInferencer : public ExprMutator { } } - Type InferClosure(const Call& call) { + Type InferClosure(const Call& call, const Function& fn) { // TODO(@hzfan): perform template param deduction to eliminate type_params - FuncType fty = Downcast(call->op->checked_type()); - CHECK_EQ(call->args.size(), fty->arg_types.size()); + bool update_closure = false; + Array new_params; for (size_t i = 0; i < call->args.size(); ++i) { - Unify(call->args[i]->checked_type(), fty->arg_types[i]); + try { + // Try to unify caller type and param type. 
+ Unify(call->args[i]->checked_type(), fn->params[i]->type_annotation); + new_params.push_back(MakeVar(fn->params[i]->name_hint(), fn->params[i]->type_annotation)); + } catch (const dmlc::Error& e) { + // If caller type and closure parameter type are inconsistent and this is the first caller, + // update the closure parameter type; othewise throw an error. + CHECK(visited_closures_.find(fn) == visited_closures_.end()) + << "The following closure is called more than once " + << "but callers have inconsistent types:" << std::endl + << raf::ir::AsText(fn) << std::endl + << e.what(); + update_closure = true; + new_params.push_back(MakeVar(fn->params[i]->name_hint(), call->args[i]->checked_type())); + } + } + + Function new_fn = fn; + if (update_closure) { + // If param types have to be updated, create a new closure with updated param types. + // Note that in this case we also have to mutate the closure body to use the updated + // param vars, so the closure body cannot be visited in advance. + for (size_t i = 0; i < new_params.size(); ++i) { + closure_param_map_[fn->params[i]] = new_params[i]; + } + new_fn = WithFields(fn, new_params); + UpdateFuncParamVarMap(new_fn.as(), call->args); } - return fty->ret_type; + new_fn = Downcast(VisitExpr(new_fn)); + visited_closures_[fn] = new_fn; + return Downcast(new_fn->checked_type())->ret_type; } void UpdateFuncParamVarMap(const FunctionNode* fn, const Array& args) { @@ -352,25 +393,29 @@ class TypeInferencer : public ExprMutator { } Expr VisitExpr_(const FunctionNode* op) override { - if (visited_.count(GetRef(op))) { - if (!op->checked_type_.defined()) { - op->checked_type_ = IncompleteType(kType); + auto fn = GetRef(op); + if (visited_closures_.count(fn) > 0) { + fn = visited_closures_[fn]; + } + if (visited_.count(fn)) { + if (!fn->checked_type_.defined()) { + fn->checked_type_ = IncompleteType(kType); } - return GetRef(op); + return fn; } - visited_.insert(GetRef(op)); + visited_.insert(fn); Array params; Array param_types; - for (const auto& p : op->params) { + for (const auto& p : fn->params) { Var param = Downcast(VisitExpr(p)); params.push_back(param); param_types.push_back(param->checked_type()); } - Expr body = VisitExpr(op->body); + Expr body = VisitExpr(fn->body); Type ret_type = - op->ret_type.defined() ? Unify(body->checked_type(), op->ret_type) : body->checked_type(); - Function func(params, body, ret_type, op->type_params, op->attrs); - func->checked_type_ = FuncType(param_types, ret_type, op->type_params, {}); + fn->ret_type.defined() ? Unify(body->checked_type(), fn->ret_type) : body->checked_type(); + Function func(params, body, ret_type, fn->type_params, fn->attrs); + func->checked_type_ = FuncType(param_types, ret_type, fn->type_params, {}); return func; } @@ -380,6 +425,10 @@ class TypeInferencer : public ExprMutator { * E.g. Let %a = %b; Let %c = some_op(%a). The var_value_map_ will map %b to some_op. */ std::unordered_map var_value_map_; + /*! \brief Mapping from original closures to visited ones (may have type-updated params). */ + std::unordered_map visited_closures_; + /*! \brief Mapping from original closure params to type-updated ones. */ + std::unordered_map closure_param_map_; /*! 
\brief Track visited Expr to avoid indefinite recursion in IR with recursive functions */ std::unordered_set visited_; }; diff --git a/tests/python/pass/test_pass_infer_type.py b/tests/python/pass/test_pass_infer_type.py index 9a0576df..1638b5cd 100644 --- a/tests/python/pass/test_pass_infer_type.py +++ b/tests/python/pass/test_pass_infer_type.py @@ -358,6 +358,58 @@ def forward(self, x): mod = raf._ffi.pass_.InferType()(mod) +@pytest.mark.skipif(not raf.build.with_cuda(), reason="CUDA is not enabled") +def test_closure_param_type_update(): + shape = (10, 10) + + class Model(raf.Model): + def build(self): + pass + + @raf.model.trace + def forward(self, x): + out = raf._contrib_dropout(x) # pylint: disable=no-member + out = raf.reshape(out[0], [100]) + out = raf.cast(out, "float16") + return out + + model = Model() + m_x, _ = randn(shape, dtype="float32") + mod = model._internal(m_x).mod + with raf.Device("cuda"): + mod = raf._ffi.pass_.ToGraphNormalForm()(mod) + mod = raf._ffi.pass_.ToBasicBlockNormalForm()(mod) + mod = raf._ffi.pass_.FuseTVM()(mod) + mod = raf._ffi.pass_.DispatchDialect()(mod) + mod = raf._ffi.pass_.EraseType()(mod) + mod = raf._ffi.pass_.ToANormalForm()(mod) + mod = raf._ffi.pass_.InlinePrimitives()(mod) + mod = raf._ffi.pass_.InferType()(mod) + + # def @main(%x: Tensor[(10, 10), float32]) -> Tensor[(100), float16] { + # let %x1 = raf.op.cudnn._contrib_dropout(%x, float64(0.5), nullptr) + # /* ty=(Tensor[(10, 10), float32], float32, uint8, Tensor[(13), uint8]) */; + # %2 = fn (%p0: (Tensor[(10, 10), float32], float32, uint8, Tensor[(13), uint8]), + # %p1: (int32,), %p2_v2: int64, Primitive=1, Dialect="tvm") + # -> Tensor[(100), float16] { + # %0 = %p0_v2.0; + # %1 = raf.op.tvm.reshape(%0, %p1_v2, bool(0)) /* ty=Tensor[(100), float32] */; + # raf.op.tvm.cast(%1, %p2_v2) /* ty=Tensor[(100), float16] */ + # }; + # let %x3 = %2(%x1, TupleValue([int32(100)]), str"float16") /* ty=Tensor[(100), float16] */; + # %x3 + # } + hit_count = 0 + for line in raf.ir.AsText(mod).split("\n"): + if line.find("raf.op.cudnn._contrib_dropout") != -1: + hit_count += 1 + assert line.find("ty=(Tensor[(10, 10), float32], float32, uint8") != -1 + elif line.find("fn") != -1: + hit_count += 1 + assert line.find("Tensor[(10, 10), float32], float32, uint8") != -1 + assert hit_count == 2 + + def test_multi_functions(): # Create a symbolic model and run it class Add(raf.Model): From 6ff9fc5d527735798778686c894807b221d344b0 Mon Sep 17 00:00:00 2001 From: AIREMetaBot <100344401+aire-meta-bot@users.noreply.github.com> Date: Tue, 3 May 2022 01:07:00 +0800 Subject: [PATCH 06/37] [TVM] Update Submodule (#29) Co-authored-by: SubmoduleUpdaterBot --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index d2db9cb0..b6b0bafd 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit d2db9cb0d839e32778f461b77e59f6418282a511 +Subproject commit b6b0bafdef15bb5491c38770668ddf73ddd02af2 From e097fadac32eafb4b60321cb4d324b6b13d8d190 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 4 May 2022 04:39:10 +0800 Subject: [PATCH 07/37] [Bugfix] Fix bug in #28 (#31) * [Bugfix] Fix bug in #28 * format * format * fix * fix --- src/pass/type_infer.cc | 34 +++++++++----- tests/python/pass/test_pass_infer_type.py | 57 ++++++++++++++--------- 2 files changed, 57 insertions(+), 34 deletions(-) diff --git a/src/pass/type_infer.cc b/src/pass/type_infer.cc index 85350096..dc2d56f7 100644 --- a/src/pass/type_infer.cc +++ b/src/pass/type_infer.cc @@ -210,39 +210,49 
@@ class TypeInferencer : public ExprMutator { Type InferClosure(const Call& call, const Function& fn) { // TODO(@hzfan): perform template param deduction to eliminate type_params bool update_closure = false; + Function curr_fn = fn; + if (visited_closures_.count(fn) > 0) { + curr_fn = visited_closures_[fn]; + } + Array new_params; for (size_t i = 0; i < call->args.size(); ++i) { try { // Try to unify caller type and param type. - Unify(call->args[i]->checked_type(), fn->params[i]->type_annotation); - new_params.push_back(MakeVar(fn->params[i]->name_hint(), fn->params[i]->type_annotation)); + Unify(call->args[i]->checked_type(), curr_fn->params[i]->type_annotation); + new_params.push_back( + MakeVar(curr_fn->params[i]->name_hint(), curr_fn->params[i]->type_annotation)); } catch (const dmlc::Error& e) { // If caller type and closure parameter type are inconsistent and this is the first caller, // update the closure parameter type; othewise throw an error. - CHECK(visited_closures_.find(fn) == visited_closures_.end()) + CHECK(visited_closures_.find(curr_fn) == visited_closures_.end()) << "The following closure is called more than once " << "but callers have inconsistent types:" << std::endl - << raf::ir::AsText(fn) << std::endl + << raf::ir::AsText(curr_fn) << std::endl << e.what(); update_closure = true; - new_params.push_back(MakeVar(fn->params[i]->name_hint(), call->args[i]->checked_type())); + new_params.push_back( + MakeVar(curr_fn->params[i]->name_hint(), call->args[i]->checked_type())); } } - Function new_fn = fn; if (update_closure) { // If param types have to be updated, create a new closure with updated param types. // Note that in this case we also have to mutate the closure body to use the updated // param vars, so the closure body cannot be visited in advance. for (size_t i = 0; i < new_params.size(); ++i) { - closure_param_map_[fn->params[i]] = new_params[i]; + closure_param_map_[curr_fn->params[i]] = new_params[i]; } - new_fn = WithFields(fn, new_params); - UpdateFuncParamVarMap(new_fn.as(), call->args); + curr_fn = WithFields(curr_fn, new_params); + UpdateFuncParamVarMap(curr_fn.as(), call->args); } - new_fn = Downcast(VisitExpr(new_fn)); - visited_closures_[fn] = new_fn; - return Downcast(new_fn->checked_type())->ret_type; + curr_fn = Downcast(VisitExpr(curr_fn)); + + // Mark both the original and updated closure as visited because they are not allowed + // to be updated anymore. 
+ visited_closures_[fn] = curr_fn; + visited_closures_[curr_fn] = curr_fn; + return Downcast(curr_fn->checked_type())->ret_type; } void UpdateFuncParamVarMap(const FunctionNode* fn, const Array& args) { diff --git a/tests/python/pass/test_pass_infer_type.py b/tests/python/pass/test_pass_infer_type.py index 1638b5cd..8ba03237 100644 --- a/tests/python/pass/test_pass_infer_type.py +++ b/tests/python/pass/test_pass_infer_type.py @@ -368,9 +368,20 @@ def build(self): @raf.model.trace def forward(self, x): + # First closure out = raf._contrib_dropout(x) # pylint: disable=no-member out = raf.reshape(out[0], [100]) out = raf.cast(out, "float16") + + # Second closure + out = raf._contrib_dropout(out) # pylint: disable=no-member + out = raf.cast(out[0], "float32") + out = raf.reshape(out, shape) + + # Should reuse the first closure + out = raf._contrib_dropout(out) # pylint: disable=no-member + out = raf.reshape(out[0], [100]) + out = raf.cast(out, "float16") return out model = Model() @@ -386,28 +397,30 @@ def forward(self, x): mod = raf._ffi.pass_.InlinePrimitives()(mod) mod = raf._ffi.pass_.InferType()(mod) - # def @main(%x: Tensor[(10, 10), float32]) -> Tensor[(100), float16] { - # let %x1 = raf.op.cudnn._contrib_dropout(%x, float64(0.5), nullptr) - # /* ty=(Tensor[(10, 10), float32], float32, uint8, Tensor[(13), uint8]) */; - # %2 = fn (%p0: (Tensor[(10, 10), float32], float32, uint8, Tensor[(13), uint8]), - # %p1: (int32,), %p2_v2: int64, Primitive=1, Dialect="tvm") - # -> Tensor[(100), float16] { - # %0 = %p0_v2.0; - # %1 = raf.op.tvm.reshape(%0, %p1_v2, bool(0)) /* ty=Tensor[(100), float32] */; - # raf.op.tvm.cast(%1, %p2_v2) /* ty=Tensor[(100), float16] */ - # }; - # let %x3 = %2(%x1, TupleValue([int32(100)]), str"float16") /* ty=Tensor[(100), float16] */; - # %x3 - # } - hit_count = 0 - for line in raf.ir.AsText(mod).split("\n"): - if line.find("raf.op.cudnn._contrib_dropout") != -1: - hit_count += 1 - assert line.find("ty=(Tensor[(10, 10), float32], float32, uint8") != -1 - elif line.find("fn") != -1: - hit_count += 1 - assert line.find("Tensor[(10, 10), float32], float32, uint8") != -1 - assert hit_count == 2 + class ResultChecker(relay.ExprVisitor): + def __init__(self): + super(ResultChecker, self).__init__() + self.cudnn_dropout_cnt = 0 + + def visit_let(self, let): + call_op = let.value.op + if ( + not isinstance(call_op, relay.Function) + and call_op.name == "raf.op.cudnn._contrib_dropout" + ): + # Make sure 3 dropouts are dispatched to cudnn. + self.cudnn_dropout_cnt += 1 + super().visit_let(let) + + def visit_function(self, fn): + # The mask in CuDNN dropout has 0-dim so the closure param should be updated. 
+ dropout_mask_type = fn.params[0].checked_type.fields[1] + assert len(dropout_mask_type.shape) == 0 + super().visit_function(fn) + + checker = ResultChecker() + checker.visit(mod["main"].body) + assert checker.cudnn_dropout_cnt == 3 def test_multi_functions(): From e353ad7f9dbb13abadcdc7a388ee638d071fdc43 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 4 May 2022 06:28:53 +0800 Subject: [PATCH 08/37] [VM] Thread safe error message handling (#30) --- include/raf/op.h | 8 +++- src/impl/op.cc | 48 +++++++++++-------- src/op/dialect/cutlass/conv.cc | 8 ++-- src/op/dialect/cutlass/cutlass_fusion.cc | 2 +- src/op/dialect/cutlass/gemm.cc | 8 ++-- src/op/dialect/tvm/tvm_utils.h | 4 +- .../python/pass/test_pass_estimate_memory.py | 2 +- 7 files changed, 47 insertions(+), 33 deletions(-) diff --git a/include/raf/op.h b/include/raf/op.h index ea1ecfdd..a30eecfa 100644 --- a/include/raf/op.h +++ b/include/raf/op.h @@ -32,8 +32,6 @@ class Requests; namespace raf { namespace op { -extern std::vector dispatch_error_msgs; - class CallValuesNode : public ir::Object { public: mutable value::Value callee; @@ -79,6 +77,12 @@ class OpEnv { */ virtual void Execute(const std::vector& inputs, value::Value output) = 0; + /*! \brief Whether this OpEnv is valid. */ + std::vector error_msgs; + bool HasError() { + return !error_msgs.empty(); + } + void RequestWorkspace(void** dest, const Device& device, int64_t nbytes); void RequestStream(void** dest, const Device& device, int tag_idx); void RequestDistributed(void** dest, const std::string& name, const value::Value rank_list); diff --git a/src/impl/op.cc b/src/impl/op.cc index 76f9d83b..7fecd374 100644 --- a/src/impl/op.cc +++ b/src/impl/op.cc @@ -32,8 +32,6 @@ using namespace raf::value; using executor::Executor; using requests::Requests; -std::vector dispatch_error_msgs; - CallValues CallValues::make(value::Value callee, ir::Attrs args) { ObjectPtr n = make_object(); n->callee = std::move(callee); @@ -117,26 +115,31 @@ const OpEnvMaker* OpEnvMaker::Get(const std::string& op_name) { return TRegistry::Get()->Find(op_name); } -std::shared_ptr OpEnvMaker::Make(const std::string& op_name, const CallValues& call) { +OpEnvPtr OpEnvMaker::Make(const std::string& op_name, const CallValues& call) { auto maker = OpEnvMaker::Get(op_name); CHECK(maker) << "Cannot find an OpEnvMaker registered to " << op_name; auto env = (*maker)(call); - return std::shared_ptr(env); + return OpEnvPtr(env); } // Implementation : helper functions -std::shared_ptr DispatchSingleOp(const CallValues& call) { - dispatch_error_msgs.clear(); +OpEnvPtr DispatchSingleOp(const CallValues& call) { + std::vector dispatch_error_msgs; + Op op = Downcast(call->callee)->op; std::string skip_dialect = ""; // Try dispatch directly auto maker = OpEnvMaker::Get(op->name); if (maker != nullptr) { - auto env = std::shared_ptr((*maker)(call)); - if (env != nullptr) { + auto env = OpEnvPtr((*maker)(call)); + if (env && !env->HasError()) { DLOG(INFO) << "Dispatch to " << op->name; return env; + } else if (env) { + for (auto msg : env->error_msgs) { + dispatch_error_msgs.push_back(msg); + } } } if (IsDialectOp(op)) { @@ -153,9 +156,17 @@ std::shared_ptr DispatchSingleOp(const CallValues& call) { } auto dialect_op = Op::Get(entry.dialect_op); dialect_op->op_type = op->op_type; - if (auto env = OpEnvMaker::Make(dialect_op->name, call)) { - DLOG(INFO) << "Dispatch to " << dialect_op->name; - return env; + auto maker = OpEnvMaker::Get(dialect_op->name); + if (maker != nullptr) { + auto env = OpEnvPtr((*maker)(call)); + if 
(env && !env->HasError()) { + DLOG(INFO) << "Dispatch to " << dialect_op->name; + return env; + } else if (env) { + for (auto msg : env->error_msgs) { + dispatch_error_msgs.push_back(msg); + } + } } } @@ -165,12 +176,10 @@ std::shared_ptr DispatchSingleOp(const CallValues& call) { ss << "\n" << msg; } LOG(FATAL) << ss.str(); - dispatch_error_msgs.clear(); return nullptr; } -std::shared_ptr DispatchFusedOp(const CallValues& call) { - dispatch_error_msgs.clear(); +OpEnvPtr DispatchFusedOp(const CallValues& call) { auto clo = Downcast(call->callee); auto func = clo->func; ICHECK(func->HasNonzeroAttr(attr::kPrimitive)) @@ -181,20 +190,21 @@ std::shared_ptr DispatchFusedOp(const CallValues& call) { std::ostringstream os; os << "raf.op." << dialect.value() << "._fused_op"; auto op_env = OpEnvMaker::Make(os.str(), call); - if (op_env == nullptr && !dispatch_error_msgs.empty()) { + if (!op_env || op_env->HasError()) { std::stringstream ss; ss << "Failed to dispatch fused op:"; - for (auto msg : dispatch_error_msgs) { - ss << "\n\t" << msg; + if (op_env) { + for (auto msg : op_env->error_msgs) { + ss << "\n\t" << msg; + } } ss << "\nName: " << os.str() << "\n" << ir::AsText(func); LOG(FATAL) << ss.str(); - dispatch_error_msgs.clear(); } return op_env; } -std::shared_ptr Dispatch(const CallValues& call) { +OpEnvPtr Dispatch(const CallValues& call) { if (call->callee.as()) { return DispatchSingleOp(call); } else if (call->callee.as()) { diff --git a/src/op/dialect/cutlass/conv.cc b/src/op/dialect/cutlass/conv.cc index f5a91edf..33a22ab0 100644 --- a/src/op/dialect/cutlass/conv.cc +++ b/src/op/dialect/cutlass/conv.cc @@ -102,16 +102,16 @@ OpEnv* CutlassConv2dOpEnv::make(const CallValues& cv) { if (!matched_pattern || !valid) { std::stringstream ss; ss << "[CUTLASS] Cannot JIT: matched pattern? " << matched_pattern << ", valid? " << valid; - dispatch_error_msgs.push_back(ss.str()); - return nullptr; + op_env->error_msgs.push_back(ss.str()); + return op_env.release(); } try { op_env->Init(cv); } catch (const dmlc::Error& e) { std::stringstream ss; ss << "[CUTLASS] Failed to JIT: " << e.what(); - dispatch_error_msgs.push_back(ss.str()); - return nullptr; + op_env->error_msgs.push_back(ss.str()); + return op_env.release(); } return op_env.release(); } diff --git a/src/op/dialect/cutlass/cutlass_fusion.cc b/src/op/dialect/cutlass/cutlass_fusion.cc index 23fe996a..c1cce5ce 100644 --- a/src/op/dialect/cutlass/cutlass_fusion.cc +++ b/src/op/dialect/cutlass/cutlass_fusion.cc @@ -75,7 +75,7 @@ OpEnv* FusedFuncBuild(const op::CallValues& call) { OpEnv* env = nullptr; auto fmake_tune = [&env, &call](FMaker maker) { env = maker(call); - if (env) { + if (!env->HasError()) { Tune(call, env); } }; diff --git a/src/op/dialect/cutlass/gemm.cc b/src/op/dialect/cutlass/gemm.cc index 73ef4c12..7cf99830 100644 --- a/src/op/dialect/cutlass/gemm.cc +++ b/src/op/dialect/cutlass/gemm.cc @@ -191,16 +191,16 @@ OpEnv* CutlassMatmulOpEnv::make(const CallValues& cv) { if (!matched_pattern || !valid) { std::stringstream ss; ss << "[CUTLASS] Cannot JIT: matched pattern? " << matched_pattern << ", valid? 
" << valid; - dispatch_error_msgs.push_back(ss.str()); - return nullptr; + op_env->error_msgs.push_back(ss.str()); + return op_env.release(); } try { op_env->Init(cv); } catch (const dmlc::Error& e) { std::stringstream ss; ss << "[CUTLASS] Failed to JIT: " << e.what(); - dispatch_error_msgs.push_back(ss.str()); - return nullptr; + op_env->error_msgs.push_back(ss.str()); + return op_env.release(); } return op_env.release(); } diff --git a/src/op/dialect/tvm/tvm_utils.h b/src/op/dialect/tvm/tvm_utils.h index 6b7d0ef5..da14b0dc 100644 --- a/src/op/dialect/tvm/tvm_utils.h +++ b/src/op/dialect/tvm/tvm_utils.h @@ -277,9 +277,9 @@ extern MetaPersistCache CacheLoweredFunc; std::stringstream ss; \ ss << "[TVM] Failed to JIT: " << env->env_name << ": " << e.what(); \ auto msg = ss.str(); \ - dispatch_error_msgs.push_back(msg); \ + env->error_msgs.push_back(msg); \ DLOG(WARNING) << msg; \ - return nullptr; \ + return env; \ } \ } \ return env; \ diff --git a/tests/python/pass/test_pass_estimate_memory.py b/tests/python/pass/test_pass_estimate_memory.py index 0c80ce68..1cbfb7a0 100644 --- a/tests/python/pass/test_pass_estimate_memory.py +++ b/tests/python/pass/test_pass_estimate_memory.py @@ -27,7 +27,7 @@ def verify_memory(mod, device, expected_trace, disable_fusion=True, include_para for (name, mem), expected in zip(trace, expected_trace): assert name != "unknown" if isinstance(expected, tuple): # The expected memory could be a range. - assert expected[0] <= mem < expected[1] + assert expected[0] <= mem <= expected[1], "{expected[0]} <= {mem} <= {expected[1]}" else: check(mem, expected) From d82ea4e47cf0ab0240547e48d2ee79b162d7caea Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Thu, 5 May 2022 09:16:51 +0800 Subject: [PATCH 09/37] [VM] Cache fused func JIT (#32) * [VM] Cache fused func JIT * use tvm AsText * fix --- src/op/dialect/cutlass/conv.cc | 8 +-- src/op/dialect/cutlass/conv_utils.cc | 12 ++-- src/op/dialect/cutlass/conv_utils.h | 8 +-- src/op/dialect/cutlass/cutlass_fusion.cc | 73 +++++++++++++++++++----- src/op/dialect/cutlass/cutlass_utils.cc | 2 +- src/op/dialect/cutlass/cutlass_utils.h | 6 +- src/op/dialect/cutlass/gemm.cc | 8 +-- src/op/dialect/cutlass/gemm_utils.cc | 12 ++-- src/op/dialect/cutlass/gemm_utils.h | 8 +-- src/op/dialect/tvm/tvm_fusion.cc | 60 ++++++++++++++----- 10 files changed, 137 insertions(+), 60 deletions(-) diff --git a/src/op/dialect/cutlass/conv.cc b/src/op/dialect/cutlass/conv.cc index 33a22ab0..9830c12c 100644 --- a/src/op/dialect/cutlass/conv.cc +++ b/src/op/dialect/cutlass/conv.cc @@ -96,14 +96,14 @@ void CutlassConv2dOpEnv::Init(const CallValues& cv) { } OpEnv* CutlassConv2dOpEnv::make(const CallValues& cv) { - std::unique_ptr op_env(std::make_unique(cv)); + CutlassConv2dOpEnv* op_env = new CutlassConv2dOpEnv(cv); auto matched_pattern = op_env->Pattern(cv); auto valid = op_env->IsValid(cv); if (!matched_pattern || !valid) { std::stringstream ss; ss << "[CUTLASS] Cannot JIT: matched pattern? " << matched_pattern << ", valid? 
" << valid; op_env->error_msgs.push_back(ss.str()); - return op_env.release(); + return op_env; } try { op_env->Init(cv); @@ -111,9 +111,9 @@ OpEnv* CutlassConv2dOpEnv::make(const CallValues& cv) { std::stringstream ss; ss << "[CUTLASS] Failed to JIT: " << e.what(); op_env->error_msgs.push_back(ss.str()); - return op_env.release(); + return op_env; } - return op_env.release(); + return op_env; } void CutlassConv2dOpEnv::Execute(const std::vector& inputs, Value output) { diff --git a/src/op/dialect/cutlass/conv_utils.cc b/src/op/dialect/cutlass/conv_utils.cc index 2153d660..9a763654 100644 --- a/src/op/dialect/cutlass/conv_utils.cc +++ b/src/op/dialect/cutlass/conv_utils.cc @@ -25,7 +25,7 @@ void CutlassConvOpEnv::InitConvOperation( int Q = (W + 2 * pad_w - ((S - 1) * dilation_w + 1)) / stride_w + 1; - functional_key_ = std::make_unique( + functional_key_ = std::make_shared( provider_, ConvKind::kFprop, element_A, layout_A, element_B, layout_B, element_C, layout_A, element_accumulator, element_compute, epilogue_math_op); @@ -35,7 +35,7 @@ void CutlassConvOpEnv::InitConvOperation( CHECK(!operators_it->second.empty()); preference_key_ = - std::make_unique(compute_capability(), IteratorAlgorithmID::kOptimized); + std::make_shared(compute_capability(), IteratorAlgorithmID::kOptimized); Operation const* operation = find_conv2d_operation(operators_it, *preference_key_, preferred_name); @@ -74,7 +74,7 @@ void CutlassConvOpEnv::InitConvOperation( }; } -std::vector> CutlassConvOpEnv::ListTunableConfigs() { +std::vector> CutlassConvOpEnv::ListTunableConfigs() { // Tunable configuration: kernel_name std::vector kernel_names; auto operators_it = SingletonExt::get().operation_table.conv2d_operations.find(*functional_key_); @@ -94,14 +94,14 @@ std::vector> CutlassConvOpEnv::ListTunableConfigs } } } - std::vector> rets; + std::vector> rets; for (const auto& name : kernel_names) { - rets.push_back(std::make_unique(name)); + rets.push_back(std::make_shared(name)); } return rets; } -void CutlassConvOpEnv::SetTunableConfig(const std::unique_ptr& tunable) { +void CutlassConvOpEnv::SetTunableConfig(const std::shared_ptr& tunable) { tunable_ = *static_cast(tunable.get()); } diff --git a/src/op/dialect/cutlass/conv_utils.h b/src/op/dialect/cutlass/conv_utils.h index 541d5d9d..7d60fe74 100644 --- a/src/op/dialect/cutlass/conv_utils.h +++ b/src/op/dialect/cutlass/conv_utils.h @@ -35,9 +35,9 @@ class CutlassConvOpEnv : public CutlassOpEnv { explicit CutlassConvOpEnv(const CallValues& call) : CutlassOpEnv(call) { } - std::vector> ListTunableConfigs() override; + std::vector> ListTunableConfigs() override; - void SetTunableConfig(const std::unique_ptr& tunable) override; + void SetTunableConfig(const std::shared_ptr& tunable) override; /*! * \brief Initialize a convolution operator @@ -85,9 +85,9 @@ class CutlassConvOpEnv : public CutlassOpEnv { /*! \brief Convolution operator arguments */ ConvArguments arguments_; /*! \brief Conv functional key */ - std::unique_ptr functional_key_; + std::shared_ptr functional_key_; /*! \brief Conv functional key */ - std::unique_ptr preference_key_; + std::shared_ptr preference_key_; /*! 
\brief Tunable configuration for cutlass conv */ ConvTunableConfig tunable_; }; diff --git a/src/op/dialect/cutlass/cutlass_fusion.cc b/src/op/dialect/cutlass/cutlass_fusion.cc index c1cce5ce..3aec8cec 100644 --- a/src/op/dialect/cutlass/cutlass_fusion.cc +++ b/src/op/dialect/cutlass/cutlass_fusion.cc @@ -9,6 +9,7 @@ */ #include +#include "raf/cache.h" #include "raf/value.h" #include "raf/profiler.h" #include "raf/registry.h" @@ -30,23 +31,69 @@ using namespace raf::ir; using namespace raf::value; using raf::registry::TypedPackedFunc; +/*! \brief The persist cache entry of the best CUTLASS tuned config. */ +class CUTLASSConfigCacheEntry { + public: + explicit CUTLASSConfigCacheEntry() { + } + + CUTLASSConfigCacheEntry(std::shared_ptr config) : config_(config) { + } + + std::shared_ptr GetConfig() { + return config_; + } + + static CUTLASSConfigCacheEntry Load(const std::string path) { + // Not support deserialization yet. + throw; + } + + bool Save(const std::string& path) { + // Not support serialization yet. + return false; + } + + private: + /*! \brief The tunable config. */ + std::shared_ptr config_; +}; + +MetaPersistCache CacheConfig("cutlass_fusion_config"); + +HashKey HashFusedFunc(const Function& func) { + HashKey key; + key << tvm::AsText(func, true); + return key; +} + OpEnv* Tune(const op::CallValues& call, OpEnv* op_env) { CutlassOpEnv* env = static_cast(op_env); - std::vector> tunable = env->ListTunableConfigs(); - const int number = 10, repeat = 1, min_repeat_ms = 0; - std::unique_ptr best; - double min_time = std::numeric_limits::max(); - for (auto& i : tunable) { - env->SetTunableConfig(i); - env->Init(call); - Array result = TimeEvaluator(TypedPackedFunc([&]() { env->Execute(call); }), - call->device, number, repeat, min_repeat_ms)(); - CHECK_EQ(result.size(), 1U); - if (result[0]->value < min_time) { - min_time = result[0]->value; - best = std::move(i); + auto key = HashFusedFunc(Downcast(call->callee)->func); + std::shared_ptr best; + + if (const auto* compiled = CacheConfig.Get(key.byte_vector)) { + CUTLASSConfigCacheEntry entry = *compiled; + best = entry.GetConfig(); + } else { + std::vector> tunable = env->ListTunableConfigs(); + const int number = 10, repeat = 1, min_repeat_ms = 0; + double min_time = std::numeric_limits::max(); + for (auto& config : tunable) { + env->SetTunableConfig(config); + env->Init(call); + Array result = + TimeEvaluator(TypedPackedFunc([&]() { env->Execute(call); }), call->device, + number, repeat, min_repeat_ms)(); + CHECK_EQ(result.size(), 1U); + if (result[0]->value < min_time) { + min_time = result[0]->value; + best = config; + } } + CacheConfig.Set(key.byte_vector, CUTLASSConfigCacheEntry(best)); } + env->SetTunableConfig(best); env->Init(call); return env; diff --git a/src/op/dialect/cutlass/cutlass_utils.cc b/src/op/dialect/cutlass/cutlass_utils.cc index 844f7f2b..6337e6f8 100644 --- a/src/op/dialect/cutlass/cutlass_utils.cc +++ b/src/op/dialect/cutlass/cutlass_utils.cc @@ -46,7 +46,7 @@ void CutlassOpEnv::RequestWorkspace(void** dest, const Device& device, int64_t n *dest = workspace_mem_->data; } -std::ostream& operator<<(std::ostream& stream, const std::unique_ptr& config) { +std::ostream& operator<<(std::ostream& stream, const std::shared_ptr& config) { config->AsText(stream); return stream; } diff --git a/src/op/dialect/cutlass/cutlass_utils.h b/src/op/dialect/cutlass/cutlass_utils.h index 3b87c316..5abf6e6e 100644 --- a/src/op/dialect/cutlass/cutlass_utils.h +++ b/src/op/dialect/cutlass/cutlass_utils.h @@ -53,10 +53,10 @@ class 
CutlassOpEnv : public raf::op::OpEnv { void RequestWorkspace(void** dest, const Device& device, int64_t nbytes); /*! \brief Set tunable configuration */ - virtual void SetTunableConfig(const std::unique_ptr& tunable) = 0; + virtual void SetTunableConfig(const std::shared_ptr& tunable) = 0; /*! \brief List all possible configs */ - virtual std::vector> ListTunableConfigs() = 0; + virtual std::vector> ListTunableConfigs() = 0; /*! \brief Initialize with default configuration */ virtual void Init(const CallValues& call) = 0; @@ -111,7 +111,7 @@ struct TunableConfig { std::string kernel_name; }; -std::ostream& operator<<(std::ostream& stream, const std::unique_ptr& config); +std::ostream& operator<<(std::ostream& stream, const std::shared_ptr& config); std::ostream& operator<<(std::ostream& stream, const SplitKMode& mode); diff --git a/src/op/dialect/cutlass/gemm.cc b/src/op/dialect/cutlass/gemm.cc index 7cf99830..bd89c8f9 100644 --- a/src/op/dialect/cutlass/gemm.cc +++ b/src/op/dialect/cutlass/gemm.cc @@ -185,14 +185,14 @@ void CutlassMatmulOpEnv::Init(const CallValues& cv) { } OpEnv* CutlassMatmulOpEnv::make(const CallValues& cv) { - std::unique_ptr op_env(std::make_unique(cv)); + CutlassMatmulOpEnv* op_env = new CutlassMatmulOpEnv(cv); auto matched_pattern = op_env->Pattern(cv); auto valid = op_env->IsValid(cv); if (!matched_pattern || !valid) { std::stringstream ss; ss << "[CUTLASS] Cannot JIT: matched pattern? " << matched_pattern << ", valid? " << valid; op_env->error_msgs.push_back(ss.str()); - return op_env.release(); + return op_env; } try { op_env->Init(cv); @@ -200,9 +200,9 @@ OpEnv* CutlassMatmulOpEnv::make(const CallValues& cv) { std::stringstream ss; ss << "[CUTLASS] Failed to JIT: " << e.what(); op_env->error_msgs.push_back(ss.str()); - return op_env.release(); + return op_env; } - return op_env.release(); + return op_env; } void CutlassMatmulOpEnv::Execute(const std::vector& inputs, Value output) { diff --git a/src/op/dialect/cutlass/gemm_utils.cc b/src/op/dialect/cutlass/gemm_utils.cc index dd969f38..ea255557 100644 --- a/src/op/dialect/cutlass/gemm_utils.cc +++ b/src/op/dialect/cutlass/gemm_utils.cc @@ -23,7 +23,7 @@ void CutlassGemmOpEnv::InitGemmOperation( int ldd, int batch_count, int64_t batch_stride_A, int64_t batch_stride_B, int64_t batch_stride_C, int64_t batch_stride_D, EpilogueKindExt epilogue_math_op, const std::string& preferred_name) { - functional_key_ = std::make_unique( + functional_key_ = std::make_shared( provider_, GemmKind::kUniversal, element_compute, element_scalar, element_A, layout_A, ComplexTransform::kNone, element_B, layout_B, ComplexTransform::kNone, element_C, epilogue_math_op); @@ -51,7 +51,7 @@ void CutlassGemmOpEnv::InitGemmOperation( ptr_D_check, ldd, 0, kMaximumAlignmentSize); // Find the best kernel in descending order of preference. 
- preference_key_ = std::make_unique(compute_capability(), alignment); + preference_key_ = std::make_shared(compute_capability(), alignment); Operation const* operation = find_gemm_operation(operators_it, *preference_key_, preferred_name); CHECK(operation); operation_ = operation; @@ -84,7 +84,7 @@ void CutlassGemmOpEnv::InitGemmOperation( batch_stride_D}; } -std::vector> CutlassGemmOpEnv::ListTunableConfigs() { +std::vector> CutlassGemmOpEnv::ListTunableConfigs() { // Tunable configuration: split_k_slices // Split axis k into 1 slice (no slicing) or 4 slices const static std::vector split_k_slices = {1, 2, 4, 8}; @@ -112,18 +112,18 @@ std::vector> CutlassGemmOpEnv::ListTunableConfigs } } } - std::vector> rets; + std::vector> rets; for (const auto& name : kernel_names) { for (const auto& i_split_k_slices : split_k_slices) { for (const auto& i_split_k_mode : split_k_mode) { - rets.push_back(std::make_unique(name, i_split_k_mode, i_split_k_slices)); + rets.push_back(std::make_shared(name, i_split_k_mode, i_split_k_slices)); } } } return rets; } -void CutlassGemmOpEnv::SetTunableConfig(const std::unique_ptr& tunable) { +void CutlassGemmOpEnv::SetTunableConfig(const std::shared_ptr& tunable) { tunable_ = *static_cast(tunable.get()); } diff --git a/src/op/dialect/cutlass/gemm_utils.h b/src/op/dialect/cutlass/gemm_utils.h index 65272b49..96db7c64 100644 --- a/src/op/dialect/cutlass/gemm_utils.h +++ b/src/op/dialect/cutlass/gemm_utils.h @@ -43,9 +43,9 @@ class CutlassGemmOpEnv : public CutlassOpEnv { explicit CutlassGemmOpEnv(const CallValues& call) : CutlassOpEnv(call) { } - std::vector> ListTunableConfigs() override; + std::vector> ListTunableConfigs() override; - void SetTunableConfig(const std::unique_ptr& tunable) override; + void SetTunableConfig(const std::shared_ptr& tunable) override; /*! * \brief Initialize a gemm operator @@ -94,9 +94,9 @@ class CutlassGemmOpEnv : public CutlassOpEnv { /*! \brief Gemm operator arguments */ GemmUniversalArguments arguments_; /*! \brief Gemm functional key */ - std::unique_ptr functional_key_; + std::shared_ptr functional_key_; /*! \brief Gemm functional key */ - std::unique_ptr preference_key_; + std::shared_ptr preference_key_; /*! 
\brief Tunable configuration for cutlass gemm */ GemmTunableConfig tunable_; }; diff --git a/src/op/dialect/tvm/tvm_fusion.cc b/src/op/dialect/tvm/tvm_fusion.cc index 5443f4bb..bc9edb6b 100644 --- a/src/op/dialect/tvm/tvm_fusion.cc +++ b/src/op/dialect/tvm/tvm_fusion.cc @@ -173,9 +173,9 @@ class Cast2TVMDialect : public ExprMutator { * \brief Converter from raf style (all inputs are arguments) to * tvm style (inputs are explicitly marked as arguments or attrs) */ -class Meta2TVM : public ExprMutator { +class RAF2TVM : public ExprMutator { public: - Meta2TVM(const CallValues& call, const DevType& dev_type) + RAF2TVM(const CallValues& call, const DevType& dev_type) : func_(Downcast(call->callee)->func), call_values_getter_(call), device_type_(dev_type) { @@ -251,26 +251,56 @@ class Meta2TVM : public ExprMutator { DevType device_type_; }; +HashKey HashFusedFunc(const Function& func) { + HashKey key; + key << tvm::AsText(func, true); + return key; +} + OpEnv* FusedFuncBuild(const op::CallValues& call) { tvm::relay::tec::TECompiler te_compiler; auto env = std::make_unique(); Device dev = call->device; + + // Determine cache + MetaPersistCache* cache; + if (dev.device_type() == DevType::kCPU()) { + cache = &CacheBuildCpu; + } else if (dev.device_type() == DevType::kCUDA()) { + cache = &CacheBuildCuda; + } else { + LOG(FATAL) << "NotImplementedError: device is not supported " << dev.device_type().c_str(); + throw; + } + tvm::Target target = dev.tvm_target(); CHECK(dev.device_type() == DevType::kCPU() || dev.device_type() == DevType::kCUDA()) << "NotImplementedError: target is not supported " << dev.device_type().c_str(); - Meta2TVM meta_to_tvm(call, dev.device_type()); - Function func = Downcast(meta_to_tvm()); - // TODO(@hzfan): add cache for raf - te_compiler->Clear(); - env->env_name = TruncateName(GetUniqueName(meta_to_tvm.func_name)); - try { - env->f = te_compiler->JIT(tvm::relay::tec::CCacheKey(func, target)); - } catch (const dmlc::Error& e) { - if (!AllowJitFailure()) { - LOG(FATAL) << "Failed to build a fused op " << env->env_name << ": " << e.what(); + RAF2TVM raf_to_tvm(call, dev.device_type()); + Function func = Downcast(raf_to_tvm()); + env->env_name = TruncateName(GetUniqueName(raf_to_tvm.func_name)); + + auto key = HashFusedFunc(Downcast(call->callee)->func); + TVMModuleCacheEntry entry; + if (const auto* compiled = cache->Get(key.byte_vector)) { + entry = *compiled; + } else { + te_compiler->Clear(); + try { + auto cached_key = tvm::relay::tec::CCacheKey(func, target); + auto cached_func = te_compiler->Lower(cached_key, [](String name) { return name; }); + auto mod = tvm::build(cached_func->funcs, cached_key->target, Target(nullptr)); + entry = TVMModuleCacheEntry(mod, cached_func->prim_fn_var->name_hint); + cache->Set(key.byte_vector, entry); + } catch (const dmlc::Error& e) { + if (!AllowJitFailure()) { + LOG(FATAL) << "Failed to build a fused op " << env->env_name << ": " << e.what(); + } } } - env->arg_indices = meta_to_tvm.arg_indices; + + env->f = entry.GetFunction(); + env->arg_indices = raf_to_tvm.arg_indices; Array args = GetListArgs(call->args); for (const int& i : env->arg_indices) { GetDLTensor(args[i], &env->inputs); @@ -299,8 +329,8 @@ float CalcFuncGFLOPS(const op::CallValues& call, const Array& param_types, new_call->out = call->out; new_call->device = call->device; - Meta2TVM meta_to_tvm(new_call, device.device_type()); - Function tvm_func = Downcast(meta_to_tvm()); + RAF2TVM raf_to_tvm(new_call, device.device_type()); + Function tvm_func = 
Downcast(raf_to_tvm()); tvm::Target target = device.tvm_target(); auto cache_key = tvm::relay::tec::CCacheKey(tvm_func, target); From f87d8b6c4edc7bcb74ea46336487ea9d7e9096e2 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 6 May 2022 02:07:32 +0800 Subject: [PATCH 10/37] [Remat] Do not remat non-deterministic ops (#33) --- include/raf/op_utils.h | 6 ++++++ src/pass/rematerialization.cc | 14 +++++++++++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/include/raf/op_utils.h b/include/raf/op_utils.h index 671d00ea..096f1eaf 100644 --- a/include/raf/op_utils.h +++ b/include/raf/op_utils.h @@ -125,6 +125,12 @@ inline bool IsReshapeOp(const Op& op) { return IsInOpSet(op, reshape_ops); } +inline bool IsNonDeterministicOp(const Op& op) { + static std::unordered_set non_deterministic_ops{ + Op::Get("raf.op._contrib_dropout")}; + return IsInOpSet(op, non_deterministic_ops); +} + inline bool IsMemcpyOp(const Expr& op) { static OpSet memcpy_ops = { Op::Get("raf.op.fuse_tensor"), diff --git a/src/pass/rematerialization.cc b/src/pass/rematerialization.cc index ab04a50a..ec4cce49 100644 --- a/src/pass/rematerialization.cc +++ b/src/pass/rematerialization.cc @@ -920,7 +920,19 @@ class Rematerializer::TensorAnalyzer : public ExprVisitor { float compute_cost = 0.0f; int64_t ws_size = 0; - if (profiler_) { + + // Get the call node op if applicable. + Op op; + if (auto call_node = exprs[i].as()) { + if (auto op_node = call_node->op.as()) { + op = GetRef(op_node); + } + } + + if (op.defined() && IsNonDeterministicOp(op)) { + // Non-deterministic ops cannot be recomputed + compute_cost = std::numeric_limits::max(); + } else if (profiler_) { // Try to profile the op auto exec_time_and_ws_size = profiler_->ProfileOp(exprs[i]); // Default is to repeat once, so we take the first element From 067afed72201ac7612d4a68d6ae2d2f56c377355 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 6 May 2022 07:10:43 +0800 Subject: [PATCH 11/37] [Bugfix] Cache fusion JIT / serialization (#35) --- src/impl/ir_ext.cc | 2 +- src/impl/serialization.cc | 6 +++--- src/op/dialect/cutlass/cutlass_fusion.cc | 2 +- src/op/dialect/tvm/tvm_fusion.cc | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/impl/ir_ext.cc b/src/impl/ir_ext.cc index 5c01d2f0..5b423deb 100644 --- a/src/impl/ir_ext.cc +++ b/src/impl/ir_ext.cc @@ -71,7 +71,7 @@ ObjectRef ConstantExtractValue(RelayConstant _node) { return node->value; } -Var MakeVar_(Id vid, Type type_annotation, Var may_share = Var()) { +Var MakeVar_(Id vid, Type type_annotation, Var may_share) { ObjectPtr n = make_object(); n->vid = std::move(vid); n->type_annotation = std::move(type_annotation); diff --git a/src/impl/serialization.cc b/src/impl/serialization.cc index 61f40682..c9e5e64f 100644 --- a/src/impl/serialization.cc +++ b/src/impl/serialization.cc @@ -212,18 +212,18 @@ Value DeserializeValue(dmlc::Stream* strm) { } case kOpValue: { strm->Read(&str); - Op op = Downcast(tvm::LoadJSON(str)); + Op op = Downcast(serialization::LoadJSON(str)); return OpValue::make(op); } case kClosureValue: { strm->Read(&str); - auto func = Downcast(tvm::LoadJSON(str)); + auto func = Downcast(serialization::LoadJSON(str)); uint64_t cnt; std::unordered_map env; strm->Read(&cnt); for (uint64_t i = 0; i < cnt; ++i) { strm->Read(&str); - Var var = Downcast(tvm::LoadJSON(str)); + Var var = Downcast(serialization::LoadJSON(str)); Value val = DeserializeValue(strm); env.emplace(var, val); } diff --git a/src/op/dialect/cutlass/cutlass_fusion.cc 
b/src/op/dialect/cutlass/cutlass_fusion.cc index 3aec8cec..9e1651bc 100644 --- a/src/op/dialect/cutlass/cutlass_fusion.cc +++ b/src/op/dialect/cutlass/cutlass_fusion.cc @@ -63,7 +63,7 @@ MetaPersistCache CacheConfig("cutlass_fusion_config"); HashKey HashFusedFunc(const Function& func) { HashKey key; - key << tvm::AsText(func, true); + key << raf::ir::AsText(func, true); return key; } diff --git a/src/op/dialect/tvm/tvm_fusion.cc b/src/op/dialect/tvm/tvm_fusion.cc index bc9edb6b..43dd6fe5 100644 --- a/src/op/dialect/tvm/tvm_fusion.cc +++ b/src/op/dialect/tvm/tvm_fusion.cc @@ -253,7 +253,7 @@ class RAF2TVM : public ExprMutator { HashKey HashFusedFunc(const Function& func) { HashKey key; - key << tvm::AsText(func, true); + key << raf::ir::AsText(func, true); return key; } From 5ac70e6912b8f22fe4f17d8787bcec180737f0b8 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 6 May 2022 08:45:40 +0800 Subject: [PATCH 12/37] [Op] Dropout workspace (#37) --- src/op/dialect/cudnn/dropout.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/op/dialect/cudnn/dropout.cc b/src/op/dialect/cudnn/dropout.cc index 0895de1b..537b46bd 100644 --- a/src/op/dialect/cudnn/dropout.cc +++ b/src/op/dialect/cudnn/dropout.cc @@ -144,7 +144,7 @@ class DropoutImplementedByCUDNNDropoutBackward : public raf::op::OpEnv { float dropout; size_t stateSizeInBytes; size_t reserveSpaceSizeInBytes; - std::shared_ptr states; + void* states; explicit DropoutImplementedByCUDNNDropoutBackward(const CallValues& cv) { this->arg_indices = {/*dy=*/0, /*reserve_space=*/1}; @@ -159,9 +159,7 @@ class DropoutImplementedByCUDNNDropoutBackward : public raf::op::OpEnv { CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropoutDesc)); CUDNN_CALL( cudnnDropoutGetStatesSize(CUDNNThreadEntry::ThreadLocal()->handle, &stateSizeInBytes)); - states = Memory::Alloc(cv->device, stateSizeInBytes); - CUDNN_CALL(cudnnSetDropoutDescriptor(dropoutDesc, CUDNNThreadEntry::ThreadLocal()->handle, - dropout, states->data, stateSizeInBytes, 0)); + RequestWorkspace(&states, cv->device, stateSizeInBytes); } public: @@ -182,6 +180,8 @@ class DropoutImplementedByCUDNNDropoutBackward : public raf::op::OpEnv { DLTensor* dy = args->dy; DLTensor* reserve_space = args->reserve_space; + CUDNN_CALL(cudnnRestoreDropoutDescriptor(dropoutDesc, CUDNNThreadEntry::ThreadLocal()->handle, + dropout, states, stateSizeInBytes, 0)); CUDNN_CALL(cudnnDropoutBackward(CUDNNThreadEntry::ThreadLocal()->handle, dropoutDesc, dydesc, dy->data, dxdesc, dx->data, reserve_space->data, reserveSpaceSizeInBytes)); @@ -193,6 +193,8 @@ class DropoutImplementedByCUDNNDropoutBackward : public raf::op::OpEnv { DLTensor* dy = inputs[0]; DLTensor* reserve_space = inputs[1]; + CUDNN_CALL(cudnnRestoreDropoutDescriptor(dropoutDesc, CUDNNThreadEntry::ThreadLocal()->handle, + dropout, states, stateSizeInBytes, 0)); CUDNN_CALL(cudnnDropoutBackward(CUDNNThreadEntry::ThreadLocal()->handle, dropoutDesc, dydesc, dy->data, dxdesc, dx->data, reserve_space->data, reserveSpaceSizeInBytes)); From 75168d123efcfabe2132331fb48f8b14ccf298bd Mon Sep 17 00:00:00 2001 From: Jie Wang Date: Fri, 6 May 2022 16:54:51 -0700 Subject: [PATCH 13/37] [Bugfix] Fix bug in LANS optimizer with float16 model (#38) Co-authored-by: Jie Wang --- python/raf/optim/lans.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/raf/optim/lans.py b/python/raf/optim/lans.py index 38b89c9e..b84f6ca1 100644 --- a/python/raf/optim/lans.py +++ b/python/raf/optim/lans.py @@ -295,9 +295,12 @@ def forward(self, dy, 
*args, **kwargs): ) next_w = _op.add(new_weight, self.zero, out=p) else: - # LANS inplace upates the weight - # So the new weight is just the input weight - next_w = new_w + if self.dtype != "float32": + next_w = _op.add(new_w, self.zero, out=p) + else: + # LANS inplace upates the weight + # So the new weight is just the input weight + next_w = new_w trace_mutate_attr(param_model, name.split(".")[-1], next_w) trace_mutate_attr(self, f"{name}.m", next_m) From 249db5f7b64522fd5912c95b5eb285eb6b8d10ff Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 9 May 2022 12:51:59 -0700 Subject: [PATCH 14/37] [Op] Fix dropout (#39) * wip * wip * done * test * fone * lint * fix * comment --- include/raf/op_utils.h | 2 +- python/raf/_tvm_op/nn.py | 5 +- src/op/declare/nn.cc | 13 +-- src/op/dialect/cudnn/cudnn_utils.h | 2 - src/op/dialect/cudnn/dropout.cc | 156 ++++++++++++++++++------- src/op/grad/nn.cc | 2 +- src/op/ty/nn.cc | 8 +- src/pass/assign_device.cc | 17 --- src/profiler/op_profiler.cc | 6 +- tests/python/op/cudnn/test_cudnn_nn.py | 57 +++++---- 10 files changed, 154 insertions(+), 114 deletions(-) diff --git a/include/raf/op_utils.h b/include/raf/op_utils.h index 096f1eaf..69336656 100644 --- a/include/raf/op_utils.h +++ b/include/raf/op_utils.h @@ -127,7 +127,7 @@ inline bool IsReshapeOp(const Op& op) { inline bool IsNonDeterministicOp(const Op& op) { static std::unordered_set non_deterministic_ops{ - Op::Get("raf.op._contrib_dropout")}; + Op::Get("raf.op._contrib_dropout"), Op::Get("raf.op._contrib_dropout_dx")}; return IsInOpSet(op, non_deterministic_ops); } diff --git a/python/raf/_tvm_op/nn.py b/python/raf/_tvm_op/nn.py index 4dbb02fd..19a16e5b 100644 --- a/python/raf/_tvm_op/nn.py +++ b/python/raf/_tvm_op/nn.py @@ -349,8 +349,7 @@ def compute_contrib_dropout(attr, inputs, output_type): _tvm.tir.const(1 / (1 - p), "float32"), ), ) - # states and reserve_space are valid in cudnn only - states = _topi.full((), dtype="uint8", fill_value=0.0) + # reserve_space is valid in cudnn only reserve_space_shape = () if len(output_type.fields[-1].shape) > 0: # Reserve_space is not scalar type. It is dispatched from the base op @@ -360,7 +359,7 @@ def compute_contrib_dropout(attr, inputs, output_type): x_ty = _tvm.relay.TensorType(x.shape, dtype=x.dtype) reserve_space_shape = (GetDropoutReserveSpaceSizeInBytes(x_ty),) reserve_space = _topi.full(reserve_space_shape, dtype="uint8", fill_value=0.0) - return [ret, mask, states, reserve_space] + return [ret, mask, reserve_space] _reg.register_injective_schedule("raf.op.tvm._contrib_dropout") diff --git a/src/op/declare/nn.cc b/src/op/declare/nn.cc index f1d21d03..f73c2035 100644 --- a/src/op/declare/nn.cc +++ b/src/op/declare/nn.cc @@ -396,7 +396,6 @@ void ContribDropout(const CallValues& call) { CHECK(args != nullptr); const DLTensor* x = args->x; std::vector shape(x->shape, x->shape + x->ndim); - std::vector states_shape; std::vector reserve_space_shape; // The CUDNN compute generates reserve_space for backward usage. 
#ifdef RAF_USE_CUDA @@ -407,12 +406,6 @@ void ContribDropout(const CallValues& call) { reserve_space_shape.push_back(reserve_space_size_in_bytes->value); } #endif - if (args->in_states.defined()) { - const DLTensor* in_states = args->in_states.value(); - for (size_t i = 0; i < in_states->ndim; i++) { - states_shape.push_back(tvm::Integer(in_states->shape[i])); - } - } TensorValue output = TensorValue::Assemble(/*dev=*/x->device, /*dtype=*/x->dtype, /*shape=*/shape); @@ -425,14 +418,10 @@ void ContribDropout(const CallValues& call) { /*dtype=*/DType(DTypeCode::kFloat(), 32), /*shape=*/mask_shape); // valid for cudnn only - TensorValue out_states = TensorValue::Assemble(/*dev=*/x->device, - /*dtype=*/DType(DTypeCode::kUInt(), 8), - /*shape=*/states_shape); - // valid for cudnn only TensorValue reserve_space = TensorValue::Assemble(/*dev=*/x->device, /*dtype=*/DType(DTypeCode::kUInt(), 8), /*shape=*/reserve_space_shape); - call->out = TupleValue::make(tvm::Array({output, mask, out_states, reserve_space})); + call->out = TupleValue::make(tvm::Array({output, mask, reserve_space})); call->device = x->device; } diff --git a/src/op/dialect/cudnn/cudnn_utils.h b/src/op/dialect/cudnn/cudnn_utils.h index 17fbd7d9..58c64545 100644 --- a/src/op/dialect/cudnn/cudnn_utils.h +++ b/src/op/dialect/cudnn/cudnn_utils.h @@ -398,8 +398,6 @@ inline size_t ComputeStorageInBytes(const ir::TensorType& type) { return size; } -TensorValue GetDropoutState(double dropout, int64_t seed); - } // namespace cudnn } // namespace op } // namespace raf diff --git a/src/op/dialect/cudnn/dropout.cc b/src/op/dialect/cudnn/dropout.cc index 537b46bd..c1f4f9f3 100644 --- a/src/op/dialect/cudnn/dropout.cc +++ b/src/op/dialect/cudnn/dropout.cc @@ -7,6 +7,7 @@ * \file src/op/dialect/cudnn/dropout.cc * \brief cuDNN dropout operators. 
*/ +#include #include "raf/ir.h" #include "raf/registry.h" #include "raf/op_utils.h" @@ -30,36 +31,81 @@ int64_t GetDropoutStateSizeInBytes() { return stateSizeInBytes; } -int64_t GetDropoutReserveSpaceSizeInBytes(TensorType x) { - size_t reserveSpaceSizeInBytes; - cudnnTensorDescriptor_t xdesc = NormalizeTensorType(x); - CUDNN_CALL(cudnnDropoutGetReserveSpaceSize(xdesc, &reserveSpaceSizeInBytes)); - CUDNN_CALL(cudnnDestroyTensorDescriptor(xdesc)); - return reserveSpaceSizeInBytes; -} +class DropoutStatePool { + public: + explicit DropoutStatePool(const Device& dev) : device(dev) { + } -TensorValue GetDropoutState(double dropout, int64_t seed) { - Device device(DevType::kCUDA(), 0); - size_t stateSizeInBytes = GetDropoutStateSizeInBytes(); - cudnnDropoutDescriptor_t dropoutDesc; - CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropoutDesc)); - std::shared_ptr memory = memory_pool::Memory::Alloc(device, stateSizeInBytes); - TensorValue state = - TensorValue::Assemble(device, DType(DTypeCode::kInt(), 8), - {static_cast(stateSizeInBytes)}, {}, memory->data, memory); - DLTensor* dlt = state; - CUDNN_CALL(cudnnSetDropoutDescriptor(dropoutDesc, CUDNNThreadEntry::ThreadLocal()->handle, - dropout, dlt->data, stateSizeInBytes, - static_cast(seed))); - CUDNN_CALL(cudnnDestroyDropoutDescriptor(dropoutDesc)); - return state; -} + ~DropoutStatePool() { + for (auto pair : state_pool) { + if (pair.second != nullptr) { + pair.second.reset(); + } + } + } + + static std::shared_ptr Get(const Device& dev) { + static registry::PerDeviceStore* pool = + new registry::PerDeviceStore(); + std::shared_ptr& ret = pool->Get(dev); + if (ret == nullptr) { + std::lock_guard lock(pool->mutex_); + if (ret == nullptr) { + ret = std::make_shared(dev); + } + } + return ret; + } + + /*! + * \brief Get an existing dropout state buffer of the given ratio. + * If the state buffer for the ratio is not created, then this function initializes a new one. + */ + std::pair, bool> GetState(float ratio) { + bool init = false; + std::lock_guard lock(mutex); + if (state_pool.count(ratio) == 0) { + size_t stateSizeInBytes = GetDropoutStateSizeInBytes(); + state_pool[ratio] = memory_pool::Memory::Alloc(device, stateSizeInBytes); + init = true; + } + return {state_pool[ratio], init}; + } + + public: + Device device; + std::unordered_map> state_pool; + std::mutex mutex; +}; RAF_REGISTER_GLOBAL("raf.backend.cudnn.GetDropoutStateSizeInBytes") .set_body_typed(GetDropoutStateSizeInBytes); -RAF_REGISTER_GLOBAL("raf.backend.cudnn.GetDropoutState").set_body_typed(GetDropoutState); +RAF_REGISTER_GLOBAL("raf.backend.cudnn.GetDropoutState") + .set_body_typed([](const Device& dev, double dropout) { + auto state_n_init = DropoutStatePool::Get(dev)->GetState(dropout); + auto buf = state_n_init.first; + auto stateSizeInBytes = GetDropoutStateSizeInBytes(); + if (state_n_init.second) { + // If the state buffer is newly created, then initialize it. 
+ cudnnDropoutDescriptor_t dropoutDesc; + auto seed = tvm::support::LinearCongruentialEngine::DeviceRandom(); + CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropoutDesc)); + CUDNN_CALL(cudnnSetDropoutDescriptor(dropoutDesc, CUDNNThreadEntry::ThreadLocal()->handle, + dropout, buf->data, stateSizeInBytes, seed)); + CUDNN_CALL(cudnnDestroyDropoutDescriptor(dropoutDesc)); + } + TensorValue state = TensorValue::Assemble(dev, DType(DTypeCode::kInt(), 8), + {stateSizeInBytes}, {}, buf->data, buf); + return state; + }); RAF_REGISTER_GLOBAL("raf.backend.cudnn.GetDropoutReserveSpaceSizeInBytes") - .set_body_typed(GetDropoutReserveSpaceSizeInBytes); + .set_body_typed([](TensorType x) { + size_t reserveSpaceSizeInBytes; + cudnnTensorDescriptor_t xdesc = NormalizeTensorType(x); + CUDNN_CALL(cudnnDropoutGetReserveSpaceSize(xdesc, &reserveSpaceSizeInBytes)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(xdesc)); + return (int64_t)reserveSpaceSizeInBytes; + }); static auto fschema_index = ir::Op::GetAttrMap("FRAFSchemaFieldIndex"); @@ -73,24 +119,46 @@ class DropoutImplementedByCUDNNDropoutForward : public raf::op::OpEnv { explicit DropoutImplementedByCUDNNDropoutForward(const CallValues& cv) { auto op = Op::Get("raf.op._contrib_dropout"); - this->arg_indices = { - fschema_index[op]("x"), - fschema_index[op]("in_states"), - }; auto args = cv->args.as(); + dropout = args->p; TupleValue tv = Downcast(cv->out); DLTensor* x = args->x; DLTensor* out = tv->fields[0]; - DLTensor* state = args->in_states.value(); + void* state_data = nullptr; + + // Note that we do not put "in_states" in arg_indices because we do not expect + // in_states to be used in the VM. + this->arg_indices = {fschema_index[op]("x")}; + + bool is_first_dropout = false; + if (args->in_states.get() == nullptr) { + // If no state is provided, use the internal one. + auto state_n_init = DropoutStatePool::Get(cv->device)->GetState(dropout); + auto buf = state_n_init.first; + is_first_dropout = state_n_init.second; + state_data = buf->data; + } else { + DLTensor* state_tensor = args->in_states.value(); + state_data = state_tensor->data; + } xdesc = NormalizeTensorType(SquashTensorShape(x, {})); ydesc = NormalizeTensorType(SquashTensorShape(out, {})); - dropout = args->p; CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropoutDesc)); - CUDNN_CALL( - cudnnDropoutGetStatesSize(CUDNNThreadEntry::ThreadLocal()->handle, &stateSizeInBytes)); + stateSizeInBytes = GetDropoutStateSizeInBytes(); CUDNN_CALL(cudnnDropoutGetReserveSpaceSize(xdesc, &reserveSpaceSizeInBytes)); - CUDNN_CALL(cudnnRestoreDropoutDescriptor(dropoutDesc, CUDNNThreadEntry::ThreadLocal()->handle, - dropout, state->data, stateSizeInBytes, 0)); + + if (is_first_dropout) { + // Initial dropout desc with a certain ratio. This desc will be shared by all dropout ops + // with the same ratio. + auto seed = tvm::support::LinearCongruentialEngine::DeviceRandom(); + CUDNN_CALL(cudnnSetDropoutDescriptor(dropoutDesc, CUDNNThreadEntry::ThreadLocal()->handle, + dropout, state_data, stateSizeInBytes, seed)); + } else { + // The dropout desc has been globally initialized so we just restore it. Note that + // in this case random seend has no effect so we simply put 0. 
+ CUDNN_CALL(cudnnRestoreDropoutDescriptor(dropoutDesc, CUDNNThreadEntry::ThreadLocal()->handle, + dropout, state_data, stateSizeInBytes, 0)); + } } public: @@ -110,7 +178,7 @@ class DropoutImplementedByCUDNNDropoutForward : public raf::op::OpEnv { TupleValue tv = Downcast(cv->out); DLTensor* x = args->x; DLTensor* out = tv->fields[0]; - DLTensor* reserve_space = tv->fields[3]; + DLTensor* reserve_space = tv->fields[2]; CUDNN_CALL(cudnnDropoutForward(CUDNNThreadEntry::ThreadLocal()->handle, dropoutDesc, xdesc, x->data, ydesc, out->data, reserve_space->data, @@ -118,11 +186,11 @@ class DropoutImplementedByCUDNNDropoutForward : public raf::op::OpEnv { } void Execute(const std::vector& inputs, Value output) { - CHECK_EQ(inputs.size(), 2); + CHECK_GE(inputs.size(), 1); TupleValue tv = Downcast(output); DLTensor* x = inputs[0]; DLTensor* out = tv->fields[0]; - DLTensor* reserve_space = tv->fields[3]; + DLTensor* reserve_space = tv->fields[2]; CUDNN_CALL(cudnnDropoutForward(CUDNNThreadEntry::ThreadLocal()->handle, dropoutDesc, xdesc, x->data, ydesc, out->data, reserve_space->data, @@ -144,22 +212,24 @@ class DropoutImplementedByCUDNNDropoutBackward : public raf::op::OpEnv { float dropout; size_t stateSizeInBytes; size_t reserveSpaceSizeInBytes; - void* states; explicit DropoutImplementedByCUDNNDropoutBackward(const CallValues& cv) { this->arg_indices = {/*dy=*/0, /*reserve_space=*/1}; auto args = cv->args.as(); + dropout = args->p; DLTensor* dx = cv->out; DLTensor* dy = args->dy; DLTensor* reserve_space = args->reserve_space; dxdesc = NormalizeTensorType(SquashTensorShape(dx, {})); dydesc = NormalizeTensorType(SquashTensorShape(dy, {})); - dropout = args->p; reserveSpaceSizeInBytes = ComputeStorageInBytes(SquashTensorShape(reserve_space, {})); CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropoutDesc)); CUDNN_CALL( cudnnDropoutGetStatesSize(CUDNNThreadEntry::ThreadLocal()->handle, &stateSizeInBytes)); - RequestWorkspace(&states, cv->device, stateSizeInBytes); + + void* state_data = DropoutStatePool::Get(cv->device)->GetState(dropout).first->data; + CUDNN_CALL(cudnnRestoreDropoutDescriptor(dropoutDesc, CUDNNThreadEntry::ThreadLocal()->handle, + dropout, state_data, stateSizeInBytes, 0)); } public: @@ -180,8 +250,6 @@ class DropoutImplementedByCUDNNDropoutBackward : public raf::op::OpEnv { DLTensor* dy = args->dy; DLTensor* reserve_space = args->reserve_space; - CUDNN_CALL(cudnnRestoreDropoutDescriptor(dropoutDesc, CUDNNThreadEntry::ThreadLocal()->handle, - dropout, states, stateSizeInBytes, 0)); CUDNN_CALL(cudnnDropoutBackward(CUDNNThreadEntry::ThreadLocal()->handle, dropoutDesc, dydesc, dy->data, dxdesc, dx->data, reserve_space->data, reserveSpaceSizeInBytes)); @@ -193,8 +261,6 @@ class DropoutImplementedByCUDNNDropoutBackward : public raf::op::OpEnv { DLTensor* dy = inputs[0]; DLTensor* reserve_space = inputs[1]; - CUDNN_CALL(cudnnRestoreDropoutDescriptor(dropoutDesc, CUDNNThreadEntry::ThreadLocal()->handle, - dropout, states, stateSizeInBytes, 0)); CUDNN_CALL(cudnnDropoutBackward(CUDNNThreadEntry::ThreadLocal()->handle, dropoutDesc, dydesc, dy->data, dxdesc, dx->data, reserve_space->data, reserveSpaceSizeInBytes)); diff --git a/src/op/grad/nn.cc b/src/op/grad/nn.cc index e1a91de2..c0f7460b 100644 --- a/src/op/grad/nn.cc +++ b/src/op/grad/nn.cc @@ -38,7 +38,7 @@ Array ContribDropoutGrad(const Expr& orig_call, const Array orig_arg const static auto dropout_dx = Op::Get("raf.op._contrib_dropout_dx"); const Expr& dy = AsTupleExpr(dout, 2)[0]; const Expr& mask = TupleGetItem(y, 1); - const Expr& 
reserve_space = TupleGetItem(y, 3); + const Expr& reserve_space = TupleGetItem(y, 2); const Expr& p = orig_args[1]; return {Call(dropout_dx, {dy, mask, reserve_space, p})}; } diff --git a/src/op/ty/nn.cc b/src/op/ty/nn.cc index 379a76d3..6f0e3973 100644 --- a/src/op/ty/nn.cc +++ b/src/op/ty/nn.cc @@ -312,18 +312,12 @@ Type ContribDropoutInfer(const CallValues& value) { reserve_space = TensorType(reserve_space_shape, DataType::UInt(8)); } #endif - TensorType states_ty; - if (args->in_states.defined()) { - states_ty = Downcast(GetType(args->in_states.value())); - } else { - states_ty = TensorType({}, DataType::UInt(8)); - } Array mask_shape; if (include_mask) { mask_shape = x_ty->shape; } TensorType mask_ty(mask_shape, DataType::Float(32)); - return TupleType(Array{x_ty, mask_ty, states_ty, reserve_space}); + return TupleType(Array{x_ty, mask_ty, reserve_space}); } static const auto ContribDropoutBase = ContribDropoutInfer; diff --git a/src/pass/assign_device.cc b/src/pass/assign_device.cc index 3e3254b0..b22ce50e 100644 --- a/src/pass/assign_device.cc +++ b/src/pass/assign_device.cc @@ -196,23 +196,6 @@ class DeviceAssigner : public ExprMutator { } return (*fmap[node_op->name])(node, visited_args, device_str_); } - static const Op& dropout_op = Op::Get("raf.op._contrib_dropout"); - if (Downcast(node->op) == dropout_op) { - if (device_str_ == "cpu" && node->args.size() > 2) { - tvm::Array new_args; - new_args.push_back(node->args[0]); - new_args.push_back(node->args[1]); - return Call(node->op, new_args, node->attrs); - } else if (device_str_ == "cuda" && node->args.size() < 3) { -#ifdef RAF_CXX_USE_CUDNN - tvm::Array new_args = node->args; - auto val = ir::ConstantExtractValue(Downcast(node->args[1])); - new_args.push_back(MakeConstant( - raf::op::cudnn::GetDropoutState(val.as()->value, 4458794440442597400L))); - return Call(node->op, new_args, node->attrs); -#endif - } - } return ExprMutator::VisitExpr_(node); } diff --git a/src/profiler/op_profiler.cc b/src/profiler/op_profiler.cc index 6ee706c6..6d60b04f 100644 --- a/src/profiler/op_profiler.cc +++ b/src/profiler/op_profiler.cc @@ -91,7 +91,11 @@ OpWithData::~OpWithData() { } // Free the input and output buffers. 
- inputs.clear(); + try { + inputs.clear(); + } catch (dmlc::Error& e) { + return; + } } OpEnvPtr OpProfiler::GetOpEnv(const Expr& op) { diff --git a/tests/python/op/cudnn/test_cudnn_nn.py b/tests/python/op/cudnn/test_cudnn_nn.py index c7d3ae8f..07eeb296 100644 --- a/tests/python/op/cudnn/test_cudnn_nn.py +++ b/tests/python/op/cudnn/test_cudnn_nn.py @@ -3,7 +3,6 @@ # pylint: disable=too-many-locals,too-many-arguments,protected-access,attribute-defined-outside-init # pylint: disable=no-self-use,no-member -import random import pytest import torch import torch.nn.functional as F @@ -267,7 +266,7 @@ def forward(self, m_x, m_m, m_v, m_w, m_b): check(m_b.grad, t_b.grad, rtol=rtol, atol=atol) -@with_dialect(["cudnn", "tvm"]) +@with_dialect(["cudnn"]) @pytest.mark.skipif(not raf.build.with_cuda(), reason="CUDA is not enabled") @pytest.mark.parametrize("dropout", [0.4, 0.6]) def test_raf_dropout(dropout): @@ -286,39 +285,47 @@ def check_dropout(x, y, dx=None, dy=None): class TestModel(raf.Model): def build(self): self.dropout = dropout - self.dropout_state = ndarray.from_tensor_value( - raf._ffi.backend.cudnn.GetDropoutState(dropout, random.getrandbits(63)) - ).to(device="cuda") @raf.model.trace def forward(self, x): - return raf._contrib_dropout(x, dropout, self.dropout_state) + return raf._contrib_dropout(x, dropout) shape, dtype = [1024, 1024], "float32" x, _ = randint(shape, low=10, high=20, dtype=dtype, device="cuda") x.requires_grad = True - model = TestModel() - state_0 = model.dropout_state.to(device="cuda") - m_y = model(x)[0] - state_1 = model.dropout_state.to(device="cuda") - check_dropout(x, m_y) - v_y = run_vm_model(model, "cuda", [x])[0] - state_2 = model.dropout_state.to(device="cuda") - check_dropout(x, v_y) - # state updates (cudnn enforce state inplace updates) - n_y = model(x)[0] - assert not np.array_equal(numpy(state_0), numpy(state_1)) - assert not np.array_equal(numpy(state_1), numpy(state_2)) - assert not np.array_equal(numpy(m_y), numpy(v_y)) - assert not np.array_equal(numpy(m_y), numpy(n_y)) + + # get the random state. 
Note that its values will be modified after each dropout call + state = ndarray.from_tensor_value( + raf._ffi.backend.cudnn.GetDropoutState(raf.Device("cuda"), dropout) + ) + state_0 = numpy(state) + + # forward + m_y0 = raf._contrib_dropout(x, dropout)[0] + check_dropout(x, m_y0) + + # check whether the state is updated + m_y1 = raf._contrib_dropout(x, dropout)[0] + state_1 = numpy(state) + assert not np.array_equal(state_0, state_1) + assert not np.array_equal(numpy(m_y0), numpy(m_y1)) + # reproducible - model.dropout_state = state_0 - r_y = model(x)[0] - check(m_y, r_y) + r_y = raf._contrib_dropout(x, dropout, raf.array(state_0, device="cuda"))[0] + check(m_y0, r_y) + # backward dy, _ = randn_torch(shape, dtype=dtype, device="cuda") - m_y.backward(dy) - check_dropout(x, m_y, x.grad, dy) + m_y0.backward(dy) + check_dropout(x, m_y0, x.grad, dy) + + # VM + model = TestModel() + v_y = run_vm_model(model, "cuda", [x])[0] + state_2 = numpy(state) + assert not np.array_equal(state_1, state_2) + check_dropout(x, v_y) + assert not np.array_equal(numpy(m_y1), numpy(v_y)) if __name__ == "__main__": From 21deaab50672bc335e530db5c3d5ae598cc620e1 Mon Sep 17 00:00:00 2001 From: AIREMetaBot <100344401+aire-meta-bot@users.noreply.github.com> Date: Tue, 10 May 2022 04:21:28 +0800 Subject: [PATCH 15/37] [TVM] Update Submodule 2022-05-09-13-20-01 (#40) --- tests/python/pass/test_pass_estimate_memory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/pass/test_pass_estimate_memory.py b/tests/python/pass/test_pass_estimate_memory.py index 1cbfb7a0..228efd92 100644 --- a/tests/python/pass/test_pass_estimate_memory.py +++ b/tests/python/pass/test_pass_estimate_memory.py @@ -27,7 +27,7 @@ def verify_memory(mod, device, expected_trace, disable_fusion=True, include_para for (name, mem), expected in zip(trace, expected_trace): assert name != "unknown" if isinstance(expected, tuple): # The expected memory could be a range. - assert expected[0] <= mem <= expected[1], "{expected[0]} <= {mem} <= {expected[1]}" + assert expected[0] <= mem <= expected[1], f"{expected[0]} <= {mem} <= {expected[1]}" else: check(mem, expected) From fac2ce094e3878483fc504c615301868e1393176 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 9 May 2022 16:51:52 -0700 Subject: [PATCH 16/37] [Pass] Fix liveness analysis with folded constant (#41) --- src/pass/liveness_analysis.cc | 21 ++++++++++++-- .../pass/test_pass_liveness_analysis.py | 28 ++++++++++++++++++- 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/src/pass/liveness_analysis.cc b/src/pass/liveness_analysis.cc index b6032eea..4f14fd83 100644 --- a/src/pass/liveness_analysis.cc +++ b/src/pass/liveness_analysis.cc @@ -162,7 +162,6 @@ void LivenessAnalyzer::ForwardAnalyzer::VisitExpr_(const CallNode* node) { CHECK(var != nullptr) << "Expected the first argument of reshape op to be a Var, but got " << node->args[0]->GetTypeKey(); this->VisitExpr_(var); - } else { Var dummy = analyzer_->CreateTensorVar(node->checked_type()); analyzer_->Init(let_var_, dummy); @@ -176,6 +175,14 @@ void LivenessAnalyzer::ForwardAnalyzer::VisitExpr_(const TupleNode* node) { if (field.as()) { // Ignore constant fields (e.g., NoGradValue) var = Downcast(field); + } else if (auto const_node = field.as()) { + // If the constant is a tensor, it is folded by constant folding and initialized + // in this tuple. For example: `let %a2 = (%p0, %a1, tensor(5x5, float32, cuda(0)))`. 
+ // In this case, we have to create a dummy liveness var for it, because this tensor + // might be used in the future such as `let %a3 = %a2.2`. + if (auto ttype = const_node->checked_type().as()) { + var = analyzer_->CreateTensorVar(GetRef(ttype)); + } } fields.push_back(var); } @@ -281,7 +288,17 @@ void LivenessAnalyzer::BackwardAnalyzer::VisitExpr_(const CallNode* node) { } void LivenessAnalyzer::BackwardAnalyzer::VisitExpr_(const TupleNode* node) { - analyzer_->live_[let_var_] = analyzer_->vset_[MergeLive(let_var_)]; + Array var_fields; + for (const auto& field : node->fields) { + // If the field is a constant, then its life must start from this tuple, + // so we do not merge its life with other fields. + if (field.as()) { + var_fields.push_back(Downcast(field)); + } + } + Var d1 = analyzer_->Merge(var_fields); + Var d2 = MergeLive(d1, let_var_); + analyzer_->live_[let_var_] = analyzer_->vset_[d2]; } void LivenessAnalyzer::BackwardAnalyzer::VisitExpr_(const TupleGetItemNode* node) { diff --git a/tests/python/pass/test_pass_liveness_analysis.py b/tests/python/pass/test_pass_liveness_analysis.py index 43a1abcf..7d1ec308 100644 --- a/tests/python/pass/test_pass_liveness_analysis.py +++ b/tests/python/pass/test_pass_liveness_analysis.py @@ -7,7 +7,7 @@ import raf from raf._lib import tvm, relay from raf.ir import ScopeBuilder -from raf._ffi.pass_ import InferType, LivenessAnalysis, ManifestAlloc +from raf._ffi.pass_ import InferType, LivenessAnalysis, ManifestAlloc, FoldConstant from raf.testing import randn @@ -213,6 +213,32 @@ def test_direct_assign(): verify_live_in_set(mod, expected) +def test_folded_const_in_tuple(): + shape = (5, 5) + + sb = ScopeBuilder() + p0 = raf.ir.var("p0", shape=shape) + a1 = sb.let("a1", raf.ir.op.relu(p0)) + x0 = sb.let("x0", raf.ir.op.zeros(raf.ir.const(shape), raf.ir.const("float32"))) + # x0 will be folded and we should have "let %a2 = (%p0, %a1, tensor(5x5, float32, cpu(0)))" + a2 = sb.let("a2", relay.Tuple([p0, a1, x0])) + a3 = sb.let("a3", relay.TupleGetItem(a2, 2)) + sb.ret(a3) + mod = tvm.IRModule.from_expr(relay.Function([p0], sb.get())) + mod = InferType()(mod) + mod = FoldConstant()(mod) + mod = InferType()(mod) + + expected = { + "n_0": {}, + "a1": {"param_0"}, + "a2": {"param_0", "t_0"}, + "a3": {"t_1"}, # t_1 is the folded const generated at %a2 + "n_1": {"t_1"}, + } + verify_live_in_set(mod, expected) + + def test_reshape(): shape = (10, 10) From 79c28b1b764a62c8039fc3182cce4666c48063a9 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 10 May 2022 10:27:12 -0700 Subject: [PATCH 17/37] [Remat] Fix remat with concat (#42) * [Remat] Fix remat with concat * fix --- src/pass/rematerialization.cc | 6 +++- .../pass/test_pass_rematerialization.py | 33 +++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/src/pass/rematerialization.cc b/src/pass/rematerialization.cc index ec4cce49..f9e7d1f5 100644 --- a/src/pass/rematerialization.cc +++ b/src/pass/rematerialization.cc @@ -592,6 +592,11 @@ class Rematerializer : public ExprMutator { static const auto reshape_op = Op::Get("raf.op.reshape"); auto tensor_infos = tensor_infos_.GetTensorInfoFromLetVar(target_let_var); + if (tensor_infos.size() > 1) { + // TODO: Handle tuple. This happens for ops taking a tuple as an argument (e.g., concat). + // Ideally we should check each element in the tuple and correct their types. + return target_let_var; + } auto curr_let_var = tensor_infos[0]->let_var; // Do nothing if the types are already match. 
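// A minimal sketch of the tuple case guarded by the early return above (assuming the
// argument tuple of a concatenate was rematerialized as a whole):
//   let %t = (%a1, %a2);                 // GetTensorInfoFromLetVar(%t) yields 2 entries
//   let %c = raf.op.concatenate(%t, 1);
// Correcting the type of each tuple field is deferred for now (see the TODO above).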
@@ -603,7 +608,6 @@ class Rematerializer : public ExprMutator { // If not match, then the target let_var must be in a tensor type. auto target_type = target_let_var->checked_type().as(); CHECK(target_type != nullptr); - CHECK_EQ(tensor_infos.size(), 1U); // Need to generate a TupleGetItem node. if (auto tuple_type_node = curr_let_var->checked_type().as()) { diff --git a/tests/python/pass/test_pass_rematerialization.py b/tests/python/pass/test_pass_rematerialization.py index 86ffe35f..348e8a26 100644 --- a/tests/python/pass/test_pass_rematerialization.py +++ b/tests/python/pass/test_pass_rematerialization.py @@ -633,5 +633,38 @@ def expected(): verify_remat(model, args, 4.01172, expected(), (5.01172, 4.01172)) +def test_concat(): + """ + A simple test program to check whether the remat pass can handle concat. + No actual rematerialization is taking place. + """ + device = "cpu" + shape = (16, 16, 64, 64) # 4 MBs + + def get_mod(): + relu_op = raf._ffi.op.GetOp("raf.op.relu") + concat_op = raf._ffi.op.GetOp("raf.op.concatenate") + + # param: 8 MBs + p_0 = raf.ir.var("p0", shape=shape) + p_1 = raf.ir.var("p1", shape=shape) + + sb = ScopeBuilder() + a_1 = sb.let("a1", relay.Call(relu_op, [p_0])) + a_2 = sb.let("a2", relay.Call(relu_op, [p_1])) + a_3 = sb.let("a3", relay.Tuple([a_1, a_2])) + a_4 = sb.let("a4", relay.Call(concat_op, [a_3, raf.ir.const(1)])) + sb.ret(a_4) + func = relay.Function([p_0, p_1], sb.get()) + return tvm.IRModule.from_expr(func) + + m_p0, _ = randn(shape, device=device) + m_p1, _ = randn(shape, device=device) + + # Set the memory budget to be higher than the peak + # The IR should remain unchanged after the remat pass + verify_remat(get_mod(), [m_p0, m_p1], 32, get_mod()["main"], (24.00, 24.00)) + + if __name__ == "__main__": pytest.main([__file__]) From 6117b0c2c1ef08a7761c4b77bd2957dd3a0bae54 Mon Sep 17 00:00:00 2001 From: "Huang, Guangtai" Date: Wed, 11 May 2022 19:05:44 +0800 Subject: [PATCH 18/37] [Op] `strided_set` to avoid use te.gradient (#43) * scatter_strided_slice -> strided_set * test * strided_slice_dx use strided_set * address comments --- python/raf/_tvm_op/transform.py | 48 +++++----- python/raf/testing/__init__.py | 1 - python/raf/testing/topi/__init__.py | 9 -- .../topi/scatter_strided_slice_python.py | 78 ---------------- scripts/src_codegen/def_op.py | 2 +- scripts/src_codegen/def_schema.py | 27 +++--- src/op/declare/transform.cc | 36 ++++---- src/op/dialect/tvm/transform.cc | 89 +++++++++---------- src/op/ty/transform.cc | 18 ++-- tests/python/op/tvm/test_tvm_transform.py | 7 +- 10 files changed, 113 insertions(+), 202 deletions(-) delete mode 100644 python/raf/testing/topi/__init__.py delete mode 100644 python/raf/testing/topi/scatter_strided_slice_python.py diff --git a/python/raf/_tvm_op/transform.py b/python/raf/_tvm_op/transform.py index 3ca96f50..7cc6d561 100644 --- a/python/raf/_tvm_op/transform.py +++ b/python/raf/_tvm_op/transform.py @@ -1,15 +1,24 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 -# pylint: disable=missing-function-docstring, undefined-loop-variable, unused-argument +# pylint: disable=missing-function-docstring, undefined-loop-variable, unused-argument, invalid-name """Compute definition and schedules for data transform operators""" +import numpy as np + from raf._tvm_op.nn import schedule_generic from .._lib import register_compute from .._lib import strategy from .._lib import tvm as _tvm # pylint: disable=unused-import from .._lib import _reg -_topi = _tvm.topi # pylint: disable=invalid-name,no-member +_topi = _tvm.topi # pylint: disable=no-member + + +def _to_const_tensor_1d(x, dtype="int32"): + x = _topi.utils.get_const_tuple(x) + x = np.array(x, dtype=dtype) + x = _topi.utils.const_vector(x) + return x @register_compute("raf.op.tvm.embedding") @@ -102,24 +111,17 @@ def fcompute(*args): return [out] -@register_compute("raf.op.tvm.scatter_strided_slice") -def scatter_strided_slice_compute(attrs, inputs, output_type): - x, src = inputs - begin, end, strides, slice_mode = attrs.begin, attrs.end, attrs.strides, attrs.slice_mode - - ones = _tvm.topi.full_like(src, _tvm.tir.const(1.0, x.dtype)) - var = _tvm.te.placeholder(shape=x.shape, dtype=x.dtype) - slices = _topi.nn.strided_slice(var, begin, end, strides, None, slice_mode) - matched_slices = _tvm.te.gradient(slices, [var], head=ones)[0] - matched_slices_value = _tvm.te.gradient(slices, [var], head=src)[0] - - out = matched_slices_value * matched_slices + x * (1 - matched_slices) - return [out] +@register_compute("raf.op.tvm.strided_set") +def strided_set_compute(attrs, inputs, output_type): + data, v = inputs + begin = _to_const_tensor_1d(attrs.begin) + end = _to_const_tensor_1d(attrs.end) + strides = _to_const_tensor_1d(attrs.strides) + return [_topi.strided_set(data, v, begin, end, strides)] _reg.register_strategy("raf.op.tvm.scatter", strategy.scatter_strategy) _reg.register_injective_schedule("raf.op.tvm.scatter_dx") -_reg.register_injective_schedule("raf.op.tvm.scatter_strided_slice") _reg.register_injective_schedule("raf.op.tvm.transpose_dx") _reg.register_injective_schedule("raf.op.tvm.transpose") _reg.register_injective_schedule("raf.op.tvm.swap_axis") @@ -145,6 +147,7 @@ def scatter_strided_slice_compute(attrs, inputs, output_type): _reg.register_injective_schedule("raf.op.tvm.batch_flatten") _reg.register_injective_schedule("raf.op.tvm.arange") _reg.register_injective_schedule("raf.op.tvm.strided_slice") +_reg.register_injective_schedule("raf.op.tvm.strided_set") _reg.register_reduce_schedule("raf.op.tvm.collapse_sum_like") @@ -187,12 +190,13 @@ def take_dx_compute(attrs, inputs, output_type): @register_compute("raf.op.tvm.strided_slice_dx") def strided_slice_dx_compute(attrs, inputs, output_type): - dy = inputs[0] - begin, end, strides, slice_mode = attrs.begin, attrs.end, attrs.strides, attrs.slice_mode - var = _tvm.te.placeholder(shape=attrs.primal_shape, dtype=dy.dtype) - out = _topi.nn.strided_slice(var, begin, end, strides, None, slice_mode) - grads = _tvm.te.gradient(out, [var], head=dy) - return grads + assert attrs.slice_mode == "end" + v = inputs[0] + data = _topi.full(attrs.primal_shape, v.dtype, 0.0) + begin = _to_const_tensor_1d(attrs.begin) + end = _to_const_tensor_1d(attrs.end) + strides = _to_const_tensor_1d(attrs.strides) + return [_topi.strided_set(data, v, begin, end, strides)] _reg.register_injective_schedule("raf.op.tvm.strided_slice_dx") diff --git a/python/raf/testing/__init__.py b/python/raf/testing/__init__.py index 26464288..a458ec48 
100644 --- a/python/raf/testing/__init__.py +++ b/python/raf/testing/__init__.py @@ -5,6 +5,5 @@ from .common import * from .pt_models import * from .utils import * -from .topi import * from . import resnet from . import resnet_cifar10 diff --git a/python/raf/testing/topi/__init__.py b/python/raf/testing/topi/__init__.py deleted file mode 100644 index b58ccdb0..00000000 --- a/python/raf/testing/topi/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 - -"""TOPI Testing Util functions. - -Used to verify the correctness of operators in TOPI . -""" - -from .scatter_strided_slice_python import * diff --git a/python/raf/testing/topi/scatter_strided_slice_python.py b/python/raf/testing/topi/scatter_strided_slice_python.py deleted file mode 100644 index 5244d30b..00000000 --- a/python/raf/testing/topi/scatter_strided_slice_python.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 - -# pylint: disable=too-many-arguments - -"""scatter_strided_slice in python""" - - -def scatter_strided_slice_python(data, src, begin, end, strides, slice_mode="end", axes=None): - """Python version of scatter strided slice operator. - - Parameters - ---------- - data : numpy.ndarray - Input data - - begin : list - Beginning of the slices. - - end : list - End of the slices. - - strides : list - The stride of each slice. - - slice_mode : str, optional - The slice mode [end, size]. - end: The default slice mode, ending indices for the slice. - size: The input strides will be ignored, input end in this mode indicates - the sizeof a slice starting at the location specified by begin. If end[i] is -1, - all remaining elements in that dimension are included in the slice. - - axes : list, optional - Axes along which slicing is applied - - Returns - ------- - result : numpy.ndarray - The sliced result. 
- """ - strides = [] if strides is None else strides - if axes is not None: - rank = len(data.shape) - new_begin = [0] * rank - new_end = [data.shape[i] for i in range(rank)] - new_strides = [1] * rank - - for i, axis in enumerate(axes): - new_begin[axis] = begin[i] - new_end[axis] = end[i] - if len(strides) > i: - new_strides[axis] = strides[i] - - begin = new_begin - end = new_end - strides = new_strides - - slices = [] - for i in range(len(data.shape)): - new_stride = None - if slice_mode == "end" and i < len(strides): - new_stride = strides[i] - - new_begin = begin[i] if i < len(begin) else None - if i >= len(end): - new_end = None - elif slice_mode == "size": - if end[i] < 0: - new_end = None - else: - new_end = new_begin + end[i] - else: - new_end = end[i] - - slices.append(slice(new_begin, new_end, new_stride)) - - data[tuple(slices)] = src - return data diff --git a/scripts/src_codegen/def_op.py b/scripts/src_codegen/def_op.py index 62edc5ef..f792d6e3 100644 --- a/scripts/src_codegen/def_op.py +++ b/scripts/src_codegen/def_op.py @@ -132,6 +132,7 @@ Op(name="threefry_split", schema_name="threefry_split"), Op(name="strided_slice", schema_name="strided_slice"), Op(name="strided_slice_dx", schema_name="strided_slice_dx"), + Op(name="strided_set", schema_name="strided_set"), Op(name="sequence_mask", schema_name="sequence_mask"), Op(name="reverse_sequence", schema_name="reverse_sequence"), Op(name="reverse", schema_name="reverse"), @@ -148,7 +149,6 @@ Op(name="layer_norm_train", schema_name="layer_norm"), Op(name="scatter", schema_name="scatter"), Op(name="scatter_dx", schema_name="scatter_dx"), - Op(name="scatter_strided_slice", schema_name="scatter_strided_slice"), Op(name="layer_norm_dx", schema_name="layer_norm_dx"), Op(name="layer_norm_train_dx", schema_name="layer_norm_train_dx"), Op(name="concatenate_dx", schema_name="concatenate"), diff --git a/scripts/src_codegen/def_schema.py b/scripts/src_codegen/def_schema.py index 93d9a3b3..acdd5ec6 100644 --- a/scripts/src_codegen/def_schema.py +++ b/scripts/src_codegen/def_schema.py @@ -492,20 +492,6 @@ Arg(name="src", cxx_type="value::BaseTensorValue"), Arg(name="axis", cxx_type="value::Value"), ], - "transform.h::scatter_strided_slice": [ - Arg(name="x", cxx_type="value::BaseTensorValue"), - Arg(name="src", cxx_type="value::BaseTensorValue"), - Arg(name="begin", cxx_type="std::vector", cxx_normalizer="IntTuple"), - Arg(name="end", cxx_type="std::vector", cxx_normalizer="IntTuple"), - Arg( - name="strides", - cxx_type="std::vector", - cxx_normalizer="IntTuple", - cxx_default="{}", - py_default="None", - ), - Arg(name="slice_mode", cxx_type="std::string", cxx_default='"end"', py_default='"end"'), - ], "transform.h::transpose": [ Arg(name="x", cxx_type="value::BaseTensorValue"), Arg( @@ -561,6 +547,19 @@ ), Arg(name="slice_mode", cxx_type="std::string", cxx_default='"end"', py_default='"end"'), ], + "transform.h::strided_set": [ + Arg(name="data", cxx_type="value::BaseTensorValue"), + Arg(name="v", cxx_type="value::BaseTensorValue"), + Arg(name="begin", cxx_type="std::vector", cxx_normalizer="IntTuple"), + Arg(name="end", cxx_type="std::vector", cxx_normalizer="IntTuple"), + Arg( + name="strides", + cxx_type="std::vector", + cxx_normalizer="IntTuple", + cxx_default="{}", + py_default="None", + ), + ], "likes.h::sum_dx": [ Arg(name="x", cxx_type="value::BaseTensorValue"), Arg(name="dy", cxx_type="value::BaseTensorValue"), diff --git a/src/op/declare/transform.cc b/src/op/declare/transform.cc index b7de5365..14c4776c 100644 --- 
a/src/op/declare/transform.cc +++ b/src/op/declare/transform.cc @@ -437,6 +437,24 @@ RAF_OP_DECLARE("raf.op.strided_slice_dx", [](const CallValues& call) { throw; }); +RAF_OP_DECLARE("raf.op.strided_set", [](const CallValues& call) { + const auto* args = call->args.as(); + CHECK(args != nullptr); + DLTensor* data = args->data; + DLTensor* v = args->v; + + CHECK(!args->begin.empty()) << "strided_set received invalid begin"; + CHECK(!args->end.empty()) << "strided_set received invalid end"; + CHECK_EQ(args->begin.size(), args->end.size()) << "begin.size() != end.size()"; + CHECK_EQ(data->ndim, v->ndim); + + std::vector shape(data->shape, data->shape + data->ndim); + call->device = data->device; + call->out = TensorValue::Assemble(/*dev=*/data->device, + /*dtype=*/data->dtype, + /*shape=*/shape); +}); + RAF_OP_DECLARE("raf.op.sequence_mask", [](const CallValues& call) { const auto* args = call->args.as(); CHECK(args != nullptr); @@ -825,24 +843,6 @@ RAF_OP_DECLARE("raf.op.scatter_dx", [](const CallValues& call) { call->device = x->device; }); -RAF_OP_DECLARE("raf.op.scatter_strided_slice", [](const CallValues& call) { - const auto* args = call->args.as(); - CHECK(args != nullptr); - DLTensor* x = args->x; - DLTensor* src_tensor = args->src; - - CHECK(!args->begin.empty()) << "scatter_strided_slice received invalid begin"; - CHECK(!args->end.empty()) << "scatter_strided_slice received invalid end"; - CHECK_EQ(args->begin.size(), args->end.size()) << "begin.size() != end.size()"; - CHECK_EQ(x->ndim, src_tensor->ndim); - - std::vector shape(x->shape, x->shape + x->ndim); - call->device = x->device; - call->out = TensorValue::Assemble(/*dev=*/x->device, - /*dtype=*/x->dtype, - /*shape=*/shape); -}); - RAF_OP_DECLARE("raf.op.clip", [](const CallValues& call) { const auto* args = call->args.as(); CHECK(args != nullptr); diff --git a/src/op/dialect/tvm/transform.cc b/src/op/dialect/tvm/transform.cc index d6e53973..e9bc44a1 100644 --- a/src/op/dialect/tvm/transform.cc +++ b/src/op/dialect/tvm/transform.cc @@ -496,52 +496,6 @@ HashKey ScatterDxHasher(const std::vector& param_types, const Type& y_type RAF_TVM(scatter_dx, ScatterDx, ScatterDxArgs, ScatterDxSchema2Args, ScatterDxSchemaArgNames, ScatterDxSchema2Attrs, ScatterDxHasher, kInjective); -std::vector ScatterStridedSliceSchema2Args(const ScatterStridedSliceArgs* args) { - return {args->x, args->src}; -} - -std::vector ScatterStridedSliceSchemaArgNames(const op::CallValues& call) { - return {"x", "src"}; -} - -Attrs ScatterStridedSliceSchema2Attrs(const ScatterStridedSliceArgs* args) { - auto attrs = make_object(); - CHECK_EQ(args->begin.size(), args->end.size()); - CHECK_EQ(args->begin.size(), args->strides.size()); - std::vector begin, end, strides; - TensorType x_type = Downcast(GetType(args->x)); - int i; - for (i = 0; i < args->begin.size(); i++) { - begin.emplace_back(args->begin[i]); - end.emplace_back(args->end[i]); - strides.emplace_back(args->strides[i]); - } - for (; i < x_type->shape.size(); i++) { - begin.emplace_back(0); - end.emplace_back(x_type->shape[i].as()->value); - strides.emplace_back(1); - } - attrs->begin = Array(begin.begin(), begin.end()); - attrs->end = Array(end.begin(), end.end()); - attrs->strides = Array(strides.begin(), strides.end()); - attrs->slice_mode = args->slice_mode; - return Attrs(attrs); -} - -HashKey ScatterStridedSliceHasher(const std::vector& param_types, const Type& y_type, - const ScatterStridedSliceArgs* args) { - HashKey key = GenericHasher(param_types, y_type, nullptr); - key << args->begin; - key 
<< args->end; - key << args->strides; - key << args->slice_mode; - return key; -} - -RAF_TVM(scatter_strided_slice, ScatterStridedSlice, ScatterStridedSliceArgs, - ScatterStridedSliceSchema2Args, ScatterStridedSliceSchemaArgNames, - ScatterStridedSliceSchema2Attrs, ScatterStridedSliceHasher, kInjective); - std::vector ConcatenateSchema2Args(const ConcatenateArgs* args) { std::vector ret; for (auto v : args->x) { @@ -1122,6 +1076,49 @@ HashKey StridedSliceDxHasher(const std::vector& param_types, const Type& y RAF_TVM(strided_slice_dx, StridedSliceDx, StridedSliceDxArgs, StridedSliceDxSchema2Args, StridedSliceDxSchemaArgNames, StridedSliceDxSchema2Attrs, StridedSliceDxHasher, kInjective); +std::vector StridedSetSchema2Args(const StridedSetArgs* args) { + return {args->data, args->v}; +} + +std::vector StridedSetSchemaArgNames(const op::CallValues& call) { + return {"data", "v"}; +} + +Attrs StridedSetSchema2Attrs(const StridedSetArgs* args) { + auto attrs = make_object(); + CHECK_EQ(args->begin.size(), args->end.size()); + CHECK_EQ(args->begin.size(), args->strides.size()); + std::vector begin, end, strides; + TensorType data_type = Downcast(GetType(args->data)); + int i; + for (i = 0; i < args->begin.size(); i++) { + begin.emplace_back(args->begin[i]); + end.emplace_back(args->end[i]); + strides.emplace_back(args->strides[i]); + } + for (; i < data_type->shape.size(); i++) { + begin.emplace_back(0); + end.emplace_back(data_type->shape[i].as()->value); + strides.emplace_back(1); + } + attrs->begin = Array(begin.begin(), begin.end()); + attrs->end = Array(end.begin(), end.end()); + attrs->strides = Array(strides.begin(), strides.end()); + return Attrs(attrs); +} + +HashKey StridedSetHasher(const std::vector& param_types, const Type& y_type, + const StridedSetArgs* args) { + HashKey key = GenericHasher(param_types, y_type, nullptr); + key << args->begin; + key << args->end; + key << args->strides; + return key; +} + +RAF_TVM(strided_set, StridedSet, StridedSetArgs, StridedSetSchema2Args, StridedSetSchemaArgNames, + StridedSetSchema2Attrs, StridedSetHasher, kInjective); + std::vector WhereSchema2Args(const WhereArgs* args) { return {args->condition, args->x, args->y}; } diff --git a/src/op/ty/transform.cc b/src/op/ty/transform.cc index 3065138a..b3b63c42 100644 --- a/src/op/ty/transform.cc +++ b/src/op/ty/transform.cc @@ -459,15 +459,6 @@ Type ScatterDxInfer(const CallValues& value) { RAF_OP_TYPE("raf.op.scatter_dx", "ScatterDx", ScatterDxInfer); -Type ScatterStridedSliceInfer(const CallValues& value) { - const auto* args = value->args.as(); - CHECK(args != nullptr); - TensorType x = Downcast(GetType(args->x)); - return x; -} - -RAF_OP_TYPE("raf.op.scatter_strided_slice", "ScaterStridedSlice", ScatterStridedSliceInfer); - Type CastInfer(const CallValues& value) { const auto* args = value->args.as(); CHECK(args != nullptr); @@ -829,6 +820,15 @@ Type StridedSliceDxInfer(const CallValues& value) { RAF_OP_TYPE("raf.op.strided_slice", "StridedSlice", StridedSliceInfer); RAF_OP_TYPE("raf.op.strided_slice_dx", "StridedSliceDx", StridedSliceDxInfer); +Type StridedSetInfer(const CallValues& value) { + const auto* args = value->args.as(); + CHECK(args != nullptr); + TensorType data = Downcast(GetType(args->data)); + return data; +} + +RAF_OP_TYPE("raf.op.strided_set", "StridedSet", StridedSetInfer); + Type SqueezeInfer(const CallValues& value) { const auto* args = value->args.as(); CHECK(args != nullptr); diff --git a/tests/python/op/tvm/test_tvm_transform.py b/tests/python/op/tvm/test_tvm_transform.py 
index afab6482..ec172d08 100644 --- a/tests/python/op/tvm/test_tvm_transform.py +++ b/tests/python/op/tvm/test_tvm_transform.py @@ -19,7 +19,6 @@ randint, check, run_vm_model, - scatter_strided_slice_python, ) import tvm.topi.testing as npx # pylint: disable=no-name-in-module @@ -242,7 +241,7 @@ def test_scatter(shape, axis, device): ], ) @pytest.mark.parametrize("dtype", ["float16", "float32"]) -def test_scatter_strided_slice(device, params, dtype): +def test_strided_set(device, params, dtype): # Skip float16 tests on CPU since it may not be supported and not much performance benefit. if dtype == "float16" and device == "cpu": pytest.skip("float16 is not supported on CPU") @@ -251,9 +250,9 @@ def test_scatter_strided_slice(device, params, dtype): n_slice = npx.strided_slice_python(n_x, begin, end, strides) m_src, n_src = randn(n_slice.shape, device=device, dtype=dtype) - model = TestModel(raf._op.sym.scatter_strided_slice, begin=begin, end=end, strides=strides) + model = TestModel(raf._op.sym.strided_set, begin=begin, end=end, strides=strides) m_y = model(m_x, m_src) - n_y = scatter_strided_slice_python(n_x, n_src, begin, end, strides) + n_y = npx.strided_set_python(n_x, n_src, begin, end, strides) v_y = run_vm_model(model, device, [m_x, m_src]) check(m_y, n_y) From 28d1aa6d861e11829a4c751b6a84909dcfb0925d Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 11 May 2022 16:03:40 -0700 Subject: [PATCH 19/37] [Util] Env config to dump IRs (#45) * [Util] Env config to dump IRs * fix * fix --- include/raf/cache.h | 19 ++--------------- include/raf/file.h | 34 +++++++++++++++++++++++++++++++ python/raf/optim/data_parallel.py | 12 +++++------ python/raf/optim/optim.py | 2 +- src/impl/model.cc | 4 ++-- src/impl/vm/compiler.cc | 2 +- src/pass/pass_manager.cc | 32 +++++++++++++++++++++++++++-- 7 files changed, 76 insertions(+), 29 deletions(-) create mode 100644 include/raf/file.h diff --git a/include/raf/cache.h b/include/raf/cache.h index 4a5b470d..f4ac3ee3 100644 --- a/include/raf/cache.h +++ b/include/raf/cache.h @@ -4,7 +4,7 @@ */ /*! - * \file src/op/cache.h + * \file cache.h * \brief The RAF cache. */ #pragma once @@ -12,6 +12,7 @@ #include #include #include +#include "./file.h" #include "./op.h" #include "./value.h" @@ -319,22 +320,6 @@ class MetaPersistCache : public MetaCache, public MetaCacheMetric { return path_ + "/" + std::to_string(hashed_key); } - inline void CreateDir(const std::string& path) { - if (mkdir(path.c_str(), S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH) == -1) { - if (errno != EEXIST) { - LOG(FATAL) << "Failed to create directory " << path << ": " << strerror(errno); - throw; - } - } - } - - inline bool DirExists(const std::string& path) { - std::ifstream ifs(path); - auto ret = ifs.good(); - ifs.close(); - return ret; - } - inline void AddMetric(const std::string name, size_t val) { metrics_[name] += val; } diff --git a/include/raf/file.h b/include/raf/file.h new file mode 100644 index 00000000..6f8be91f --- /dev/null +++ b/include/raf/file.h @@ -0,0 +1,34 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +/*! + * \file file.h + * \brief File/Directory manipulation functions. 
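 *
 * A small usage sketch (both helpers are defined below; CreateDir is a no-op when the
 * directory already exists, and DirExists is only a cheap best-effort check):
 *   raf::CreateDir("/tmp/raf_ir");
 *   CHECK(raf::DirExists("/tmp/raf_ir"));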
+ */ +#pragma once + +#include +#include +#include +#include +#include "dmlc/logging.h" + +namespace raf { +inline void CreateDir(const std::string& path) { + if (mkdir(path.c_str(), S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH) == -1) { + if (errno != EEXIST) { + LOG(FATAL) << "Failed to create directory " << path << ": " << strerror(errno); + throw; + } + } +} + +inline bool DirExists(const std::string& path) { + std::ifstream ifs(path); + auto ret = ifs.good(); + ifs.close(); + return ret; +} +} // namespace raf diff --git a/python/raf/optim/data_parallel.py b/python/raf/optim/data_parallel.py index cc71236a..2fa9ece2 100644 --- a/python/raf/optim/data_parallel.py +++ b/python/raf/optim/data_parallel.py @@ -40,25 +40,25 @@ def build(self, model): @trace def forward(self, *args, **kwargs): # pylint: disable=protected-access, missing-function-docstring - passes = [] dcfg = dist.get_config() comm = dist.get_communicator() + record = self.model._internal(*args, **kwargs) + mod = record.mod + # TODO: Refactor AutoDataParallel to let it work on the IR after InlineBackward # so that it can be applied here. # if dcfg.enable_data_parallel: # passes.append(AutoDataParallel()) if dcfg.zero_opt_level > 0: + passes = [] passes.append(InferType()) passes.append( PartitionGradient( dcfg.zero_opt_level, comm.size, comm.rank, dcfg.group_bucket_size ) ) - - record = self.model._internal(*args, **kwargs) - mod = record.mod - seq = RAFSequential(passes) - mod = seq(mod) + seq = RAFSequential(passes, name="with_data_parallel") + mod = seq(mod) inputs = _get_func_inputs(record, args, kwargs) out = inline(mod["main"], inputs) y = out[0] diff --git a/python/raf/optim/optim.py b/python/raf/optim/optim.py index 9e7f30c9..3c65126e 100644 --- a/python/raf/optim/optim.py +++ b/python/raf/optim/optim.py @@ -80,7 +80,7 @@ def forward(self, dy, *args, **kwargs): # TODO: Refactor AutoDataParallel to let it work on the IR after InlineBackward. passes += [AutoDataParallel()] passes += [InferType(), FoldConstant(), DeadCodeElimination(), InlineBackward()] - seq = RAFSequential(passes) + seq = RAFSequential(passes, name="with_autodiff") mod = seq(mod) inputs = _get_func_inputs(record, args, kwargs) inputs = inputs + [get_symbol_handle(dy)] diff --git a/src/impl/model.cc b/src/impl/model.cc index 54e94ace..1ce9365f 100644 --- a/src/impl/model.cc +++ b/src/impl/model.cc @@ -58,7 +58,7 @@ ObjectRef RunModel(ir::IRModule mod, Array args) { if (!requires_grad) { // TODO(haibin): add simplify inference pass - simplify the compute of // BN, LN, Dropout, GN, etc. 
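// Note: the sequence names used below (e.g. "interpreter_infer_optimize") also determine
// the sub-directory that per-pass IR dumps are written to when the RAF_DUMP_IR_TO
// environment variable is set (see the pass_manager.cc changes later in this patch).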
- raf::pass::RAFSequential seq({CanonicalizeOps(), FoldConstant()}); + raf::pass::RAFSequential seq({CanonicalizeOps(), FoldConstant()}, "interpreter_infer_optimize"); updated_mod = seq(updated_mod); func = Downcast(updated_mod->Lookup("main")); auto call_node = Call(func, args); @@ -81,7 +81,7 @@ ObjectRef RunModel(ir::IRModule mod, Array args) { // run const folding pass passes.push_back(FoldConstant()); - raf::pass::RAFSequential seq(passes); + raf::pass::RAFSequential seq(passes, "interpreter_optimize"); updated_mod = seq(updated_mod); func = Downcast(updated_mod->Lookup("main")); TupleValue result = Downcast(Interpret(Call(func, args), updated_mod)); diff --git a/src/impl/vm/compiler.cc b/src/impl/vm/compiler.cc index e5e14d12..35bd5fe0 100644 --- a/src/impl/vm/compiler.cc +++ b/src/impl/vm/compiler.cc @@ -930,7 +930,7 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const DeviceMap& device pass_seqs.push_back(pass::ManifestAlloc()); pass_seqs.push_back(pass::MemoryPlan()); - pass::RAFSequential seq(pass_seqs); + pass::RAFSequential seq(pass_seqs, "vm_compiler_optimize"); return seq(mod); } diff --git a/src/pass/pass_manager.cc b/src/pass/pass_manager.cc index 2e862aec..c74d6cbf 100644 --- a/src/pass/pass_manager.cc +++ b/src/pass/pass_manager.cc @@ -7,10 +7,9 @@ * \file src/pass/pass_manager.cc * \brief Infrastructure for transformation passes. */ - -#include #include +#include "raf/file.h" #include "raf/pass.h" #include "raf/pass_manager.h" #include "raf/registry.h" @@ -98,10 +97,38 @@ inline Pass GetPass(const String& pass_name) { return (*f)(); } +std::string DumpAfterPassIRToFile(std::string dump_ir_path, const IRModule& mod, size_t idx, + std::string pass_name) { + if (dump_ir_path.empty()) { + return ""; + } + // Dump the IR to the folder. + std::ofstream ofs(dump_ir_path + "/" + std::to_string(idx) + "_" + pass_name + ".txt"); + ofs << raf::ir::AsText(mod); + return dump_ir_path; +} + // TODO(zhiics): we currenlty only sequentially execute each pass in // a RAFSequential without the consideration of their orders. The phase // ordering problem needs to be handled in the future. IRModule RAFSequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) const { + const char* raf_dump_ir_to = getenv("RAF_DUMP_IR_TO"); + std::string dump_ir_path = ""; + if (raf_dump_ir_to != nullptr) { + dump_ir_path = std::string(raf_dump_ir_to); + // Create parent directory if it doesn't exist. + CreateDir(dump_ir_path); + + // Create a unique sequence directory. 
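// A sketch of the resulting layout (the file names are illustrative and follow whatever
// passes the sequence actually runs):
//   $RAF_DUMP_IR_TO/<sequence_name>/0_init.txt
//   $RAF_DUMP_IR_TO/<sequence_name>/1_<first_pass_name>.txt
//   ...
// If a directory for this sequence name already exists, a "_1" suffix is appended.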
+ dump_ir_path += "/" + pass_info->name; + if (DirExists(dump_ir_path)) { + dump_ir_path += "_1"; + } + CreateDir(dump_ir_path); + DumpAfterPassIRToFile(dump_ir_path, mod, 0, "init"); + } + + size_t pass_cnt = 1; for (const Pass& pass : passes) { ICHECK(pass.defined()) << "Found undefined pass for optimization."; const PassInfo& pass_info = pass->Info(); @@ -111,6 +138,7 @@ IRModule RAFSequentialNode::operator()(IRModule mod, const PassContext& pass_ctx mod = GetPass(it)(std::move(mod), pass_ctx); } mod = pass(std::move(mod), pass_ctx); + DumpAfterPassIRToFile(dump_ir_path, mod, pass_cnt++, pass_info->name); } return mod; } From 704040cc6dbe5d3e36853f4800fafa6c7d3eca06 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Thu, 12 May 2022 20:25:26 -0700 Subject: [PATCH 20/37] [Op] Fix power_dx (#47) --- python/raf/testing/common.py | 13 +++++++---- src/op/grad/binary.cc | 30 ++++++++++++++++---------- tests/python/op/tvm/test_tvm_binary.py | 19 +++++++++++++++- 3 files changed, 46 insertions(+), 16 deletions(-) diff --git a/python/raf/testing/common.py b/python/raf/testing/common.py index 2de3b3ba..421c88ac 100644 --- a/python/raf/testing/common.py +++ b/python/raf/testing/common.py @@ -42,13 +42,12 @@ def get_arr_addr(arr): def numpy(x): """Helper function to convert x to numpy""" import torch - import mxnet as mx if isinstance(x, (raf.ndarray, raf._core.value.TensorValue)): return x.numpy() if isinstance(x, torch.Tensor): return x.detach().cpu().numpy() - if isinstance(x, mx.nd.NDArray): + if hasattr(x, "asnumpy"): return x.asnumpy() if np.isscalar(x): return np.array(x) @@ -56,11 +55,17 @@ def numpy(x): return x -def check(m_x, m_y, *, rtol=1e-5, atol=1e-5): +def check(m_x, m_y, *, rtol=1e-5, atol=1e-5, dump_name_when_error=None): """Helper function to check if m_x and m_y are equal""" m_x = numpy(m_x) m_y = numpy(m_y) - np.testing.assert_allclose(m_x, m_y, rtol=rtol, atol=atol) + try: + np.testing.assert_allclose(m_x, m_y, rtol=rtol, atol=atol) + except Exception as err: # pylint: disable=broad-except + if dump_name_when_error is not None: + m_x.tofile(dump_name_when_error + "_x.npy", sep=",") + m_y.tofile(dump_name_when_error + "_y.npy", sep=",") + raise Exception(err) def to_torch_dev(device_str): diff --git a/src/op/grad/binary.cc b/src/op/grad/binary.cc index 31c952cf..cc389de5 100644 --- a/src/op/grad/binary.cc +++ b/src/op/grad/binary.cc @@ -8,6 +8,7 @@ * \brief Declaration of gradients */ #include "./grad_utils.h" +#include "raf/op_utils.h" namespace raf { namespace op { @@ -102,21 +103,28 @@ RAF_OP_GRAD("raf.op.multiply", MulGrad); Array PowGrad(const Expr& orig_call, const Array orig_args, const Var& y, const Expr& dy) { static auto op_power = Op::Get("raf.op.power"); + static auto op_ones_like = Op::Get("raf.op.ones_like"); static auto op_multiply = Op::Get("raf.op.multiply"); static auto op_log = Op::Get("raf.op.log"); - static auto op_divide = Op::Get("raf.op.divide"); + static auto op_subtract = Op::Get("raf.op.subtract"); const CallNode* call = orig_call.as(); CHECK_GE(call->args.size(), 2); - const Expr& x1 = call->args[0]; - const Expr& x2 = call->args[1]; - Call y1 = Call(op_power, {x1, x2}); - Call y2 = Call(op_divide, {y1, x1}); - Call dx1 = Call(op_multiply, {x2, y2}); - Call x1_log = Call(op_log, {x1}); - Call dx2 = Call(op_multiply, {y1, x1_log}); - - return {GetCollapseSumLike(Call(op_multiply, {dy, dx1}), x1), - GetCollapseSumLike(Call(op_multiply, {dy, dx2}), x2)}; + const Expr& x = call->args[0]; + const Expr& a = call->args[1]; + + // dx = a * x^(a-1) + Call ones = 
Call(op_ones_like, {a}); + Call a_minus_one = Call(op_subtract, {a, ones, MakeNull(), MakeNull()}); + Call x_pow_minus_one = Call(op_power, {x, a_minus_one}); + Call dx = Call(op_multiply, {a, x_pow_minus_one}); + + // da = x^a * log(x) + Call x_pow = Call(op_power, {x, a}); + Call x_log = Call(op_log, {x}); + Call da = Call(op_multiply, {x_pow, x_log}); + + return {GetCollapseSumLike(Call(op_multiply, {dy, dx}), x), + GetCollapseSumLike(Call(op_multiply, {dy, da}), a)}; } RAF_OP_GRAD("raf.op.power", PowGrad); diff --git a/tests/python/op/tvm/test_tvm_binary.py b/tests/python/op/tvm/test_tvm_binary.py index c620ff3c..edacc4e5 100644 --- a/tests/python/op/tvm/test_tvm_binary.py +++ b/tests/python/op/tvm/test_tvm_binary.py @@ -71,7 +71,6 @@ def test_binary_ops_without_grad(ops, shape, dtype, device): [ (torch.mul, raf._op.sym.multiply), (torch.div, raf._op.sym.divide), - (torch.pow, raf._op.sym.power), (torch.add, raf._op.sym.add), (torch.sub, raf._op.sym.subtract), ], @@ -89,6 +88,24 @@ def test_binary_ops_with_grad(ops, shape, dtype, device): verify_op(m_op, [m_x1, m_x2], device, t_y, m_dy, [t_x1.grad, t_x2.grad]) +@pytest.mark.parametrize("device", get_testable_devices()) +@pytest.mark.parametrize("dtype", ["float32"]) +def test_power(dtype, device): + x1 = np.random.randn(2, 2).astype("float32") + x1[0][0] = 0 # Assign 0 to test the corner case. + t_x1 = torch.Tensor(x1).to(device) + t_x1.requires_grad = True + m_x1 = raf.array(x1, device=device) + m_x1.requires_grad = True + + m_x2, t_x2 = randn_torch((), dtype=dtype, device=device, requires_grad=True) + t_y = torch.pow(t_x1, t_x2) + m_dy, t_dy = randn_torch(t_y.shape, dtype=dtype, device=device) + t_y.backward(t_dy) + + verify_op(raf._op.sym.power, [m_x1, m_x2], device, t_y, m_dy, [t_x1.grad, t_x2.grad]) + + # logical_and only allows bool input s @pytest.mark.parametrize("device", get_testable_devices()) @pytest.mark.parametrize( From dfee6f5235876736661ca4d3b3511a559d2df69a Mon Sep 17 00:00:00 2001 From: Zach Zheng Date: Fri, 13 May 2022 09:34:54 -0700 Subject: [PATCH 21/37] [Fix] Path clearing for Razor distributed execution on GPU (#46) * [Fix] Path clearing for Razor distributed execution on GPU * Minor fix * Fix bug when rank not in rank_list * Rename threadlocal to get --- include/raf/communicator.h | 4 +- src/device_api/cuda/cuda.cc | 2 + src/distributed/common/communicator.cc | 10 ++-- src/distributed/cuda/nccl_communicator.cc | 2 + src/op/dialect/nccl/nccl.cc | 56 +++++++++++------------ 5 files changed, 37 insertions(+), 37 deletions(-) diff --git a/include/raf/communicator.h b/include/raf/communicator.h index 62fc49d4..2e1a9253 100644 --- a/include/raf/communicator.h +++ b/include/raf/communicator.h @@ -108,8 +108,8 @@ class CommunicatorPool { } static CommunicatorPool* Get() { - static CommunicatorPool* instance = dmlc::ThreadLocalStore::Get(); - return instance; + static CommunicatorPool instance; + return &instance; } Communicator GetCommunicator(const std::string& name, const Value rank_list) { diff --git a/src/device_api/cuda/cuda.cc b/src/device_api/cuda/cuda.cc index dcbf9581..b727b65a 100644 --- a/src/device_api/cuda/cuda.cc +++ b/src/device_api/cuda/cuda.cc @@ -35,6 +35,7 @@ class CUDADeviceAPI final : public DeviceAPI { } void* AllocMemory(int64_t nbytes, int64_t alignment) override { + CUDA_CALL(cudaSetDevice(device_id_)); void* ptr = nullptr; // TODO(@junrushao1994): make sure it is correct CHECK_EQ(512 % alignment, 0); @@ -49,6 +50,7 @@ class CUDADeviceAPI final : public DeviceAPI { #if CUDA_VERSION >= 11030 
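// Also bind the calling thread to the requested GPU here, so that later CUDA/NCCL calls
// issued from this thread land on the intended device in multi-GPU distributed runs.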
void SetDevice(const int dev_id) override { device_id_ = dev_id; + CUDA_CALL(cudaSetDevice(dev_id)); } static cudaMemPool_t GetCUDAMemoryPool(int dev_id) { diff --git a/src/distributed/common/communicator.cc b/src/distributed/common/communicator.cc index fcbf7bbd..abe4775c 100644 --- a/src/distributed/common/communicator.cc +++ b/src/distributed/common/communicator.cc @@ -118,15 +118,15 @@ class GlobalCommunicatorEntry { public: GlobalCommunicatorEntry() = default; - static GlobalCommunicatorEntry* ThreadLocal() { - using TLS = dmlc::ThreadLocalStore; - return TLS::Get(); + static GlobalCommunicatorEntry* Get() { + static GlobalCommunicatorEntry entry; + return &entry; } Communicator comm; }; Communicator GetGlobalCommunicator() { - auto entry = GlobalCommunicatorEntry::ThreadLocal(); + auto entry = GlobalCommunicatorEntry::Get(); if (!entry->comm.defined()) { #ifdef RAF_USE_MPI Communicator comm = Communicator::Get("mpi"); @@ -139,7 +139,7 @@ Communicator GetGlobalCommunicator() { } void SetDefaultCommunicator(std::string name) { - auto entry = GlobalCommunicatorEntry::ThreadLocal(); + auto entry = GlobalCommunicatorEntry::Get(); entry->comm = Communicator::Get(name); } diff --git a/src/distributed/cuda/nccl_communicator.cc b/src/distributed/cuda/nccl_communicator.cc index 123d8d88..62bcaf1d 100644 --- a/src/distributed/cuda/nccl_communicator.cc +++ b/src/distributed/cuda/nccl_communicator.cc @@ -112,6 +112,8 @@ NCCLCommunicator NCCLCommunicator::make(Value rank_list) { } else { // Create Sub-communicator InitSubCommunicator(obj.get(), rank_list, global_comm); + cudaSetDevice(global_comm->local_rank); + obj->parent_comm = global_comm; // sync NCCL id between ranks diff --git a/src/op/dialect/nccl/nccl.cc b/src/op/dialect/nccl/nccl.cc index b01f0c36..ad589656 100644 --- a/src/op/dialect/nccl/nccl.cc +++ b/src/op/dialect/nccl/nccl.cc @@ -13,6 +13,7 @@ #include "raf/op_utils.h" #include "raf/dist_config.h" #include "raf/nccl_communicator.h" +#include "../../../src/common/cuda_utils.h" #include "../../schema/communication.h" #include "./communication_utils.h" @@ -27,16 +28,23 @@ using stream_pool::StreamTagEnum; RAF_REGISTER_DIALECT("nccl").set_enable(DevType::kCUDA()); -class NCCLAllReduce : public raf::op::OpEnv { +class NCCLOpEnv : public raf::op::OpEnv { + protected: void* stream; void* communicator; + explicit NCCLOpEnv(const CallValues& cv) { + CUDA_CALL(cudaSetDevice(cv->device.device_id())); + } +}; + +class NCCLAllReduce : public NCCLOpEnv { void* fused_data; size_t total_size = 0; std::vector tuple_sizes; DType dtype; ncclRedOp_t compute; - explicit NCCLAllReduce(const CallValues& cv) { + explicit NCCLAllReduce(const CallValues& cv) : NCCLOpEnv(cv) { auto op = ir::Op::Get("raf.op._allreduce"); auto fschema_index = ir::Op::GetAttrMap("FRAFSchemaFieldIndex"); auto args = cv->args.as(); @@ -146,10 +154,8 @@ class NCCLAllReduce : public raf::op::OpEnv { RAF_REGISTER_DIALECT_OP(nccl, _allreduce, 10); RAF_OP_ENV_MAKER("raf.op.nccl._allreduce", NCCLAllReduce::make); -class NCCLAllGather : public raf::op::OpEnv { - void* stream; - void* communicator; - explicit NCCLAllGather(const CallValues& cv) { +class NCCLAllGather : public NCCLOpEnv { + explicit NCCLAllGather(const CallValues& cv) : NCCLOpEnv(cv) { auto op = ir::Op::Get("raf.op._allgather"); auto fschema_index = ir::Op::GetAttrMap("FRAFSchemaFieldIndex"); auto args = cv->args.as(); @@ -192,10 +198,8 @@ class NCCLAllGather : public raf::op::OpEnv { RAF_REGISTER_DIALECT_OP(nccl, _allgather, 10); RAF_OP_ENV_MAKER("raf.op.nccl._allgather", 
NCCLAllGather::make); -class NCCLGroupAllGather : public raf::op::OpEnv { - void* stream; - void* communicator; - explicit NCCLGroupAllGather(const CallValues& cv) { +class NCCLGroupAllGather : public NCCLOpEnv { + explicit NCCLGroupAllGather(const CallValues& cv) : NCCLOpEnv(cv) { auto op = ir::Op::Get("raf.op._group_allgather"); auto fschema_index = ir::Op::GetAttrMap("FRAFSchemaFieldIndex"); this->arg_indices = {fschema_index[op]("tensor_list")}; @@ -246,15 +250,13 @@ class NCCLGroupAllGather : public raf::op::OpEnv { RAF_REGISTER_DIALECT_OP(nccl, _group_allgather, 10); RAF_OP_ENV_MAKER("raf.op.nccl._group_allgather", NCCLGroupAllGather::make); -class NCCLReduceScatter : public raf::op::OpEnv { - void* stream; - void* communicator; +class NCCLReduceScatter : public NCCLOpEnv { void* in_buffer; size_t size_in_bytes; size_t size; ncclRedOp_t compute; - explicit NCCLReduceScatter(const CallValues& cv) { + explicit NCCLReduceScatter(const CallValues& cv) : NCCLOpEnv(cv) { auto op = ir::Op::Get("raf.op._reduce_scatter"); auto fschema_index = ir::Op::GetAttrMap("FRAFSchemaFieldIndex"); this->arg_indices = {fschema_index[op]("x")}; @@ -332,13 +334,11 @@ class NCCLReduceScatter : public raf::op::OpEnv { RAF_REGISTER_DIALECT_OP(nccl, _reduce_scatter, 10); RAF_OP_ENV_MAKER("raf.op.nccl._reduce_scatter", NCCLReduceScatter::make); -class NCCLGroupReduceScatter : public raf::op::OpEnv { - void* stream; - void* communicator; +class NCCLGroupReduceScatter : public NCCLOpEnv { std::vector sizes; ncclRedOp_t compute; - explicit NCCLGroupReduceScatter(const CallValues& cv) { + explicit NCCLGroupReduceScatter(const CallValues& cv) : NCCLOpEnv(cv) { auto op = ir::Op::Get("raf.op._group_reduce_scatter"); auto fschema_index = ir::Op::GetAttrMap("FRAFSchemaFieldIndex"); this->arg_indices = {fschema_index[op]("tensor_list")}; @@ -415,16 +415,14 @@ class NCCLGroupReduceScatter : public raf::op::OpEnv { RAF_REGISTER_DIALECT_OP(nccl, _group_reduce_scatter, 10); RAF_OP_ENV_MAKER("raf.op.nccl._group_reduce_scatter", NCCLGroupReduceScatter::make); -class NCCLBroadcast : public raf::op::OpEnv { - void* stream; - void* communicator; +class NCCLBroadcast : public NCCLOpEnv { void* fused_data; size_t total_size = 0; std::vector tuple_sizes; DType dtype; int root; - explicit NCCLBroadcast(const CallValues& cv) { + explicit NCCLBroadcast(const CallValues& cv) : NCCLOpEnv(cv) { auto op = ir::Op::Get("raf.op._broadcast"); auto fschema_index = ir::Op::GetAttrMap("FRAFSchemaFieldIndex"); this->arg_indices = {fschema_index[op]("x")}; @@ -507,12 +505,10 @@ class NCCLBroadcast : public raf::op::OpEnv { RAF_REGISTER_DIALECT_OP(nccl, _broadcast, 10); RAF_OP_ENV_MAKER("raf.op.nccl._broadcast", NCCLBroadcast::make); -class NCCLSend : public raf::op::OpEnv { - void* stream; - void* communicator; +class NCCLSend : public NCCLOpEnv { int peer; - explicit NCCLSend(const CallValues& cv) { + explicit NCCLSend(const CallValues& cv) : NCCLOpEnv(cv) { auto op = ir::Op::Get("raf.op._send"); auto fschema_index = ir::Op::GetAttrMap("FRAFSchemaFieldIndex"); this->arg_indices = {fschema_index[op]("x")}; @@ -553,14 +549,14 @@ class NCCLSend : public raf::op::OpEnv { RAF_REGISTER_DIALECT_OP(nccl, _send, 10); RAF_OP_ENV_MAKER("raf.op.nccl._send", NCCLSend::make); -class NCCLRecv : public raf::op::OpEnv { +class NCCLRecv : public NCCLOpEnv { void* stream; void* communicator; int peer; std::vector shape; DType dtype; - explicit NCCLRecv(const CallValues& cv) { + explicit NCCLRecv(const CallValues& cv) : NCCLOpEnv(cv) { RequestStream(&stream, 
cv->device, StreamTagEnum::CudaCommunicate()); RequestDistributed(&communicator, "nccl", NullValue()); const auto* args = cv->args.as(); @@ -598,7 +594,7 @@ class NCCLRecv : public raf::op::OpEnv { RAF_REGISTER_DIALECT_OP(nccl, _recv, 10); RAF_OP_ENV_MAKER("raf.op.nccl._recv", NCCLRecv::make); -class NCCLReduce : public raf::op::OpEnv { +class NCCLReduce : public NCCLOpEnv { void* stream; void* communicator; ncclRedOp_t compute; @@ -608,7 +604,7 @@ class NCCLReduce : public raf::op::OpEnv { std::vector tuple_sizes; void* fused_data; - explicit NCCLReduce(const CallValues& cv) { + explicit NCCLReduce(const CallValues& cv) : NCCLOpEnv(cv) { auto op = ir::Op::Get("raf.op._reduce"); auto fschema_index = ir::Op::GetAttrMap("FRAFSchemaFieldIndex"); this->arg_indices = {fschema_index[op]("x")}; From a3716b43314721dcd1c8667d95863f279dd71269 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 13 May 2022 21:44:23 -0700 Subject: [PATCH 22/37] [Flaky] Correct expected memory in estimate_memory test (#48) --- tests/python/pass/test_pass_estimate_memory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/pass/test_pass_estimate_memory.py b/tests/python/pass/test_pass_estimate_memory.py index 228efd92..7ec03f45 100644 --- a/tests/python/pass/test_pass_estimate_memory.py +++ b/tests/python/pass/test_pass_estimate_memory.py @@ -71,7 +71,7 @@ def get_mod(): # The memory at Conv2D should be 1 MB+workspace, but the workspace should be # freed afterward, so the following ReLU should only have 2 MBs. - verify_memory(get_mod(), "cuda", [(1, 2), 2, 1], True) + verify_memory(get_mod(), "cuda", [(1, float("inf")), 2, 1], True) if __name__ == "__main__": From 4453c32eb9a4a61f971f3fb8d6d626bd0f727c56 Mon Sep 17 00:00:00 2001 From: AIREMetaBot <100344401+aire-meta-bot@users.noreply.github.com> Date: Tue, 17 May 2022 02:45:06 +0800 Subject: [PATCH 23/37] [TVM] Update Submodule (#49) Co-authored-by: SubmoduleUpdaterBot --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index b6b0bafd..02d57bbc 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit b6b0bafdef15bb5491c38770668ddf73ddd02af2 +Subproject commit 02d57bbc062cf9bd47c03d4355ccd660ed68091a From 522be9484a7c81fca57ae9c60ffda22090c9dafc Mon Sep 17 00:00:00 2001 From: Yuan Zhou Date: Mon, 16 May 2022 13:19:40 -0700 Subject: [PATCH 24/37] [Refactor] Refactor ANF partition pass to extract some common utilities. (#50) --- src/pass/anf_partition.cc | 194 +----------------------------------- src/pass/partition_utils.cc | 145 +++++++++++++++++++++++++++ src/pass/partition_utils.h | 122 +++++++++++++++++++++++ 3 files changed, 269 insertions(+), 192 deletions(-) create mode 100644 src/pass/partition_utils.cc create mode 100644 src/pass/partition_utils.h diff --git a/src/pass/anf_partition.cc b/src/pass/anf_partition.cc index 736a2bd7..fdb94ec8 100644 --- a/src/pass/anf_partition.cc +++ b/src/pass/anf_partition.cc @@ -45,6 +45,7 @@ #include #include "./common.h" #include "./liveness_analysis.h" +#include "./partition_utils.h" namespace raf { namespace pass { @@ -57,137 +58,6 @@ using binding::BindingEntry; using binding::BindNDArray; using binding::LookupBinding; -class PartitionFunction { - public: - explicit PartitionFunction(std::string name) : func_name_(std::move(name)) { - } - - /*! - * \brief Add a non-annotated Expr to the Function, - * find out the inputs and outputs of this partition function. - * \param var The Var that the expr bind to. 
- * \param expr The Expr to be pushed. - */ - void Push(Var var, Expr expr) { - // Push the inputs and outputs into ins_ and outs_. - outs_.push_back(var); - // Push the Expr into ell_. - ell_.vars.push_back(var); - ell_.exprs.push_back(expr); - } - - /*! - * \brief Remove the outputs which won't be used later. - * \param analyzer LinvenessAnalyzer which has been run on the original program. - * \param next_var The next variable follows the current PartitionFunction. - * \param is_final_ret whether next_var is the final return output of the origianl program. - */ - void TrimOutputs(liveness_analysis::LivenessAnalyzer& analyzer, const Var& next_var, - bool is_final_ret) { - if (is_final_ret) { - outs_ = {next_var}; - } else if (analyzer.IsSuccess()) { - std::unordered_set live_var = - analyzer.GetLiveVars(next_var); - std::vector trimmed_outs; - for (auto out_expr : outs_) { - Var out = Downcast(out_expr); - if (analyzer.IsAlive(out, live_var)) { - trimmed_outs.push_back(out); - } - } - outs_ = trimmed_outs; - } - } - - /*! - * \brief Export the partition functions into ExplicitLetList. - * \param part_func_vars The map from old vars into func_named vars - * \return The ExplicitLetList with partition function packed. - */ - void ExportTo(ExplicitLetList* ret_ell, Map& intermediate_var_2_func_out) { - // Because anf will auto-capture the global vars, we don't need to push ins_ into params. - // If the Var inside ins_ is inside the outs_, which indicate that this input is given by - // the expr inside this function. Then replace the usage of old vars with func_named vars. - Array params_array; - - // Surround the Values in the outs_ with a TupleNode. And replace the old - // vars with part_func named vars. - CHECK_GT(outs_.size(), 0); - if (outs_.size() > 1) { - Tuple outs_tuple = Tuple(outs_); - std::string outs_var_name = func_name_ + "_outs"; - Var outs_var = MakeVar(outs_var_name, {}); - ell_.vars.push_back(outs_var); - ell_.exprs.push_back(outs_tuple); - ell_.ret = outs_var; - } else { - ell_.ret = Downcast(outs_[0]); - } - // Assemble the partition function. - Expr body = ell_.AsExpr(); - // replace the usage of intermediate tensors in previous function - // to be their function outputs. - VarSubstitutor substitutor(intermediate_var_2_func_out); - body = substitutor.Substitute(body); - // it's a closure with 0 params and outputs live intermediate variables. - /* - func() { - let %a0 = ... - let %a1 = ... - let %a2 = ... - let %func_outs = (%a0, %a1, %a2) - %func_outs - } - */ - auto func = Function({}, body, {}, {}); - - // Insert the CallNode for the function - // and TupleGetItemNode to get the outputs from the function. 
- Var func_var = MakeVar(func_name_, {}); - ret_ell->vars.push_back(func_var); - ret_ell->exprs.push_back(func); - - // Call the partition function - auto func_call = Call(func_var, {}, Attrs()); - std::string ret_var_name = func_var->name_hint() + "_ret"; - Var ret_var = MakeVar(ret_var_name, {}); - // let %func = func() {} - // let %func_ret = Call(%func, {}) - ret_ell->vars.push_back(ret_var); - ret_ell->exprs.push_back(func_call); - if (outs_.size() > 1) { - // get the outputs TupleNode, - // let %func_0 = %func_ret.0 - // let %func_1 = %func_ret.1 - // let %func_2 = %func_ret.2 - for (size_t i = 0; i < outs_.size(); ++i) { - int index = i; - String var_name = String(func_name_ + "_ret_" + std::to_string(index)); - TupleGetItem tgi = TupleGetItem(ret_var, index, {}); - Var tgi_var = MakeVar(var_name, {}); - ret_ell->vars.push_back(tgi_var); - ret_ell->exprs.push_back(tgi); - // placeholder ret - CHECK(!intermediate_var_2_func_out.count(Downcast(outs_[i]))) - << "Duplicated output " << outs_[i]; - intermediate_var_2_func_out.Set(Downcast(outs_[i]), tgi_var); - } - } else { - CHECK(!intermediate_var_2_func_out.count(Downcast(outs_[0]))) - << "Duplicated output " << outs_[0]; - intermediate_var_2_func_out.Set(Downcast(outs_[0]), ret_var); - } - } - - /*! \brief The function name of the partition function. */ - std::string func_name_; - /*! \brief The LetNodes to construct the partition function. */ - ExplicitLetList ell_; - /*! \brief The outputs of this partition function. */ - std::vector outs_; -}; - class Partitioner final : public ExprMutator { public: explicit Partitioner(int max_num_ops, Var boundary, liveness_analysis::LivenessAnalyzer& analyzer) @@ -258,72 +128,12 @@ class Partitioner final : public ExprMutator { liveness_analysis::LivenessAnalyzer& analyzer_; }; -/** - * \brief Find the boundary of the partitions, so to keep the rest IR unchanged. - * It tries not to include tuples/tgis in the last partitioned function, - * this helps to avoid an output of func1 being captured by func2 for nothing - * but just to return as a final result, for example, - * func(...) { - * out1 = ...; - * out2 = ...; // <- this is the boundary. - * final_res = (out1, out2); - * return final_res; - * } - * - * If we partition everything, the previous program might end-up as, - * func(...) { - * func1() { - * return out1; - * } - * out1 = func1(); - * func2() { - * out2 = ... - * return (out1, out2); // out1 is captured for nothing but serving as a final output. - * } - * final_res = func2(); - * return final_res; - * } - * - * Instead it is much easier to do analysis later on the following program, - * func(...) { - * func1() { - * return out1; - * } - * out1 = func1(); - * func2() { - * return out2; - * } - * out2 = func2(); - * final_res = (out1, out2); - * return final_res; - * } - * - * \param f the original un-partitioned function. - * \return the boundary variable defines the last partition. - */ -Var GetPartitionBoundary(Function f) { - auto ell = ExplicitLetList::make(f->body); - std::unordered_set ret_vars{ell->ret.get()}; - for (int i = ell->vars.size() - 1; i >= 0; --i) { - Var var = ell->vars[i]; - Expr value = ell->exprs[i]; - // skip - // let %var = (..., ...) or - // let %var = tuple.0 - // etc. 
- if (!value.as() && !value.as()) { - return var; - } - } - return ell->ret; -} - } // namespace anf_partition Pass PartitionANF(int max_num_ops) { TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { - Var boundary = anf_partition::GetPartitionBoundary(f); + Var boundary = pass::GetPartitionBoundary(f); auto analyzer = liveness_analysis::LivenessAnalyzer(f); analyzer.Run(); anf_partition::Partitioner partitioner(max_num_ops, boundary, analyzer); diff --git a/src/pass/partition_utils.cc b/src/pass/partition_utils.cc new file mode 100644 index 00000000..b143cc85 --- /dev/null +++ b/src/pass/partition_utils.cc @@ -0,0 +1,145 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +/*! + * \file partition_utils.cc + * \brief Utility functions for partitioning ANF into multiple functions. An + * example can be found in the comments of anf_partition.cc. + */ + +#include "./partition_utils.h" + +namespace raf { +namespace pass { + +/** + * \brief Find the boundary of the partitions, so to keep the rest IR unchanged. + * \param f the original un-partitioned function. + * \return the boundary variable defines the last partition. + */ +Var GetPartitionBoundary(Function f) { + auto ell = ExplicitLetList::make(f->body); + std::unordered_set ret_vars{ell->ret.get()}; + for (int i = ell->vars.size() - 1; i >= 0; --i) { + Var var = ell->vars[i]; + Expr value = ell->exprs[i]; + // skip + // let %var = (..., ...) or + // let %var = tuple.0 + // etc. + if (!value.as() && !value.as()) { + return var; + } + } + return ell->ret; +} + +/*! + * \brief Remove the outputs which won't be used later. + * \param analyzer LinvenessAnalyzer which has been run on the original program. + * \param next_var The next variable follows the current PartitionFunction. + * \param is_final_ret whether next_var is the final return output of the origianl program. + */ +void PartitionFunction::TrimOutputs(liveness_analysis::LivenessAnalyzer& analyzer, + const Var& next_var, bool is_final_ret) { + if (is_final_ret) { + outs_ = {next_var}; + } else if (analyzer.IsSuccess()) { + std::unordered_set live_var = + analyzer.GetLiveVars(next_var); + std::vector trimmed_outs; + for (auto out_expr : outs_) { + Var out = Downcast(out_expr); + if (analyzer.IsAlive(out, live_var)) { + trimmed_outs.push_back(out); + } + } + outs_ = trimmed_outs; + } +} + +/*! + * \brief Export the partition functions into ExplicitLetList. + * \param part_func_vars The map from old vars into func_named vars + * \return The ExplicitLetList with partition function packed. + */ +void PartitionFunction::ExportTo(ExplicitLetList* ret_ell, + Map& intermediate_var_2_func_out) { + // Because anf will auto-capture the global vars, we don't need to push ins_ into params. + // If the Var inside ins_ is inside the outs_, which indicate that this input is given by + // the expr inside this function. Then replace the usage of old vars with func_named vars. + Array params_array; + + // Surround the Values in the outs_ with a TupleNode. And replace the old + // vars with part_func named vars. + CHECK_GT(outs_.size(), 0); + if (outs_.size() > 1) { + Tuple outs_tuple = Tuple(outs_); + std::string outs_var_name = func_name_ + "_outs"; + Var outs_var = MakeVar(outs_var_name, {}); + ell_.vars.push_back(outs_var); + ell_.exprs.push_back(outs_tuple); + ell_.ret = outs_var; + } else { + ell_.ret = Downcast(outs_[0]); + } + // Assemble the partition function. 
+ Expr body = ell_.AsExpr(); + // replace the usage of intermediate tensors in previous function + // to be their function outputs. + VarSubstitutor substitutor(intermediate_var_2_func_out); + body = substitutor.Substitute(body); + // it's a closure with 0 params and outputs live intermediate variables. + /* + func() { + let %a0 = ... + let %a1 = ... + let %a2 = ... + let %func_outs = (%a0, %a1, %a2) + %func_outs + } + */ + auto func = Function({}, body, {}, {}); + + // Insert the CallNode for the function + // and TupleGetItemNode to get the outputs from the function. + Var func_var = MakeVar(func_name_, {}); + ret_ell->vars.push_back(func_var); + ret_ell->exprs.push_back(func); + + // Call the partition function + auto func_call = Call(func_var, {}, Attrs()); + std::string ret_var_name = func_var->name_hint() + "_ret"; + Var ret_var = MakeVar(ret_var_name, {}); + // let %func = func() {} + // let %func_ret = Call(%func, {}) + ret_ell->vars.push_back(ret_var); + ret_ell->exprs.push_back(func_call); + if (outs_.size() > 1) { + // get the outputs TupleNode, + // let %func_0 = %func_ret.0 + // let %func_1 = %func_ret.1 + // let %func_2 = %func_ret.2 + for (size_t i = 0; i < outs_.size(); ++i) { + int index = i; + String var_name = String(func_name_ + "_ret_" + std::to_string(index)); + TupleGetItem tgi = TupleGetItem(ret_var, index, {}); + Var tgi_var = MakeVar(var_name, {}); + ret_ell->vars.push_back(tgi_var); + ret_ell->exprs.push_back(tgi); + // placeholder ret + CHECK(!intermediate_var_2_func_out.count(Downcast(outs_[i]))) + << "Duplicated output " << outs_[i]; + intermediate_var_2_func_out.Set(Downcast(outs_[i]), tgi_var); + } + } else { + CHECK(!intermediate_var_2_func_out.count(Downcast(outs_[0]))) + << "Duplicated output " << outs_[0]; + intermediate_var_2_func_out.Set(Downcast(outs_[0]), ret_var); + } +} + +} // namespace pass +} // namespace raf diff --git a/src/pass/partition_utils.h b/src/pass/partition_utils.h new file mode 100644 index 00000000..a54ea657 --- /dev/null +++ b/src/pass/partition_utils.h @@ -0,0 +1,122 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +/*! + * \file partition_utils.h + * \brief Utility functions for partitioning ANF into multiple functions. An + * example can be found in the comments of anf_partition.cc. + */ + +#include "raf/op.h" +#include "raf/ir.h" +#include "raf/binding.h" +#include "raf/pass.h" +#include "raf/ir_ext.h" +#include +#include +#include +#include "./common.h" +#include "./liveness_analysis.h" + +namespace raf { +namespace pass { + +using namespace raf::ir; +using namespace raf::op; +using namespace raf::binding; +using binding::BindingEntry; +using binding::BindNDArray; +using binding::LookupBinding; + +class PartitionFunction { + public: + explicit PartitionFunction(std::string name) : func_name_(std::move(name)) { + } + + /*! + * \brief Add a non-annotated Expr to the Function, + * find out the inputs and outputs of this partition function. + * \param var The Var that the expr bind to. + * \param expr The Expr to be pushed. + */ + void Push(Var var, Expr expr) { + // Push the inputs and outputs into ins_ and outs_. + outs_.push_back(var); + // Push the Expr into ell_. + ell_.vars.push_back(var); + ell_.exprs.push_back(expr); + } + + /*! + * \brief Remove the outputs which won't be used later. + * \param analyzer LinvenessAnalyzer which has been run on the original program. + * \param next_var The next variable follows the current PartitionFunction. 
+ * \param is_final_ret whether next_var is the final return output of the origianl program. + */ + void TrimOutputs(liveness_analysis::LivenessAnalyzer& analyzer, const Var& next_var, + bool is_final_ret); + + /*! + * \brief Export the partition functions into ExplicitLetList. + * \param part_func_vars The map from old vars into func_named vars + * \return The ExplicitLetList with partition function packed. + */ + void ExportTo(ExplicitLetList* ret_ell, Map& intermediate_var_2_func_out); + + /*! \brief The function name of the partition function. */ + std::string func_name_; + /*! \brief The LetNodes to construct the partition function. */ + ExplicitLetList ell_; + /*! \brief The outputs of this partition function. */ + std::vector outs_; +}; + +/** + * \brief Find the boundary of the partitions, so to keep the rest IR unchanged. + * It tries not to include tuples/tgis in the last partitioned function, + * this helps to avoid an output of func1 being captured by func2 for nothing + * but just to return as a final result, for example, + * func(...) { + * out1 = ...; + * out2 = ...; // <- this is the boundary. + * final_res = (out1, out2); + * return final_res; + * } + * + * If we partition everything, the previous program might end-up as, + * func(...) { + * func1() { + * return out1; + * } + * out1 = func1(); + * func2() { + * out2 = ... + * return (out1, out2); // out1 is captured for nothing but serving as a final output. + * } + * final_res = func2(); + * return final_res; + * } + * + * Instead it is much easier to do analysis later on the following program, + * func(...) { + * func1() { + * return out1; + * } + * out1 = func1(); + * func2() { + * return out2; + * } + * out2 = func2(); + * final_res = (out1, out2); + * return final_res; + * } + * + * \param f the original un-partitioned function. + * \return the boundary variable defines the last partition. + */ +Var GetPartitionBoundary(Function f); + +} // namespace pass +} // namespace raf From 6c77327dee9665a0a844637d30f2bd709cfbbcff Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 17 May 2022 12:49:18 -0700 Subject: [PATCH 25/37] [Remat] Fix reshape param (#51) * [Remat] Fix reshape param * fix --- docs/wiki/2_user_guide/Train-Model.md | 4 ++-- docs/wiki/2_user_guide/Train-PyTorch-Model.md | 4 ++-- python/raf/testing/common.py | 18 ++++++++++-------- src/pass/rematerialization.cc | 10 +++++++--- tests/python/distributed/test_data_parallel.py | 2 +- tests/python/frontend/test_mxnet_model.py | 4 ++-- tests/python/frontend/test_pytorch.py | 6 +++--- tests/python/model/test_model_lenet.py | 2 +- tests/python/model/test_resnet_50_imagenet.py | 2 +- tests/python/optim/test_lans.py | 6 +++--- tests/python/optim/test_sgd.py | 6 +++--- tests/python/pass/test_pass_group_allgather.py | 2 +- .../pass/test_pass_partition_gradient.py | 2 +- 13 files changed, 37 insertions(+), 31 deletions(-) diff --git a/docs/wiki/2_user_guide/Train-Model.md b/docs/wiki/2_user_guide/Train-Model.md index 312f3af6..6ae40e71 100644 --- a/docs/wiki/2_user_guide/Train-Model.md +++ b/docs/wiki/2_user_guide/Train-Model.md @@ -134,7 +134,7 @@ executor = None for _ in range(num_epoch): # Prepare input data, use random data as example here r_x, _ = randn_torch(input_shape, device=device, dtype="float32") - r_ytrue, _ = one_hot_torch(batch_size=batch_size, num_classes=10, device=device) + r_ytrue, _ = one_hot_torch(size=batch_size, num_classes=10, device=device) args = [dy, r_x, r_ytrue] # Initialize the VM at the first iteration. 
@@ -158,7 +158,7 @@ batch_size = 8 dy, _ = randn_torch((), std=0.0, mean=1.0, requires_grad=False, device=device) # dy = tensor(1.0) r_x, _ = randn_torch(input_shape, device=device, dtype="float32") -r_ytrue, _ = one_hot_torch(batch_size=batch_size, num_classes=10, device=device) +r_ytrue, _ = one_hot_torch(size=batch_size, num_classes=10, device=device) args = [dy, r_x, r_ytrue] ret = optimizer(*args) diff --git a/docs/wiki/2_user_guide/Train-PyTorch-Model.md b/docs/wiki/2_user_guide/Train-PyTorch-Model.md index 84ef4f29..b25105db 100644 --- a/docs/wiki/2_user_guide/Train-PyTorch-Model.md +++ b/docs/wiki/2_user_guide/Train-PyTorch-Model.md @@ -75,7 +75,7 @@ from raf._op import sym # prepare random data, they just provides shape and dtype info r_x, _ = randn_torch(input_shape, device=device, dtype=dtype) -r_ytrue, _ = one_hot_torch(batch_size=batch_size, num_classes=10, device=device) +r_ytrue, _ = one_hot_torch(size=batch_size, num_classes=10, device=device) out = r_model.record(r_x) y_pred = sym.log_softmax(out) @@ -130,7 +130,7 @@ executor = get_vm_executor(record.mod, device) for _ in range(num_epoch): # prepare input data, use random data as example here r_x, _ = randn_torch(input_shape, device=device, dtype=dtype) - r_ytrue, _ = one_hot_torch(batch_size=batch_size, num_classes=10, device=device) + r_ytrue, _ = one_hot_torch(size=batch_size, num_classes=10, device=device) args = [dy, r_x, r_ytrue] ret = run_vm_executor(executor, record, args, device) loss = ret[0] # ret[0][0] for some models diff --git a/python/raf/testing/common.py b/python/raf/testing/common.py index 421c88ac..b87607fb 100644 --- a/python/raf/testing/common.py +++ b/python/raf/testing/common.py @@ -143,29 +143,31 @@ def randn_mxnet( return m_x, mx_x -def one_hot_torch(batch_size, num_classes, device="cpu"): +def one_hot_torch(size, num_classes, device="cpu"): """Helper function to generate one hot tensors in raf and torch""" import torch - targets = np.random.randint(0, num_classes, size=batch_size) + size = tuple(size) if isinstance(size, (list, tuple)) else (size,) + targets = np.random.randint(0, num_classes, size=size) m_x = raf.array(targets, device=device) t_x = torch.tensor( targets, requires_grad=False, device=to_torch_dev(device) ) # pylint: disable=not-callable - assert list(m_x.shape) == [batch_size] - assert list(t_x.shape) == [batch_size] + assert tuple(m_x.shape) == size + assert tuple(t_x.shape) == size return m_x, t_x -def one_hot_mxnet(batch_size, num_classes, device="cpu"): +def one_hot_mxnet(size, num_classes, device="cpu"): """Helper function to generate one hot tensors in raf and mxnet""" import mxnet as mx - targets = np.random.randint(0, num_classes, size=batch_size) + size = tuple(size) if isinstance(size, (list, tuple)) else (size,) + targets = np.random.randint(0, num_classes, size=size) raf_x = raf.array(targets, device=device) mx_x = mx.nd.array(targets, ctx=mx.cpu()) # pylint: disable=not-callable - assert list(raf_x.shape) == [batch_size] - assert list(mx_x.shape) == [batch_size] + assert tuple(raf_x.shape) == size + assert tuple(mx_x.shape) == size return raf_x, mx_x diff --git a/src/pass/rematerialization.cc b/src/pass/rematerialization.cc index f9e7d1f5..9fa631f8 100644 --- a/src/pass/rematerialization.cc +++ b/src/pass/rematerialization.cc @@ -667,8 +667,12 @@ class Rematerializer : public ExprMutator { auto tensor_infos = tensor_infos_.GetTensorInfoFromLetVar(var); auto latest_let_var = tensor_infos[0]->let_var; + if (let_vars_.count(latest_let_var) == 0) { + // The representative 
let-var may has no binded expression (i.e., parameter). + return var; + } auto call_node = let_vars_[latest_let_var].as(); - CHECK(call_node != nullptr) << "Tensor " << latest_let_var + CHECK(call_node != nullptr) << "Tensor " << latest_let_var << " with alias " << var << " is not a parameter nor generated by a call node: " << raf::ir::AsText(let_vars_[latest_let_var]); @@ -933,8 +937,8 @@ class Rematerializer::TensorAnalyzer : public ExprVisitor { } } - if (op.defined() && IsNonDeterministicOp(op)) { - // Non-deterministic ops cannot be recomputed + if (op.defined() && IsNonDeterministicOp(op) && IsCollectiveOp(op)) { + // Non-deterministic and collective ops cannot be recomputed compute_cost = std::numeric_limits::max(); } else if (profiler_) { // Try to profile the op diff --git a/tests/python/distributed/test_data_parallel.py b/tests/python/distributed/test_data_parallel.py index 3f2990ea..177f8bd4 100644 --- a/tests/python/distributed/test_data_parallel.py +++ b/tests/python/distributed/test_data_parallel.py @@ -49,7 +49,7 @@ def run_model(device): m_model.train_mode() m_x, _ = randn([4, 3, 28, 28], device=device, requires_grad=True) - m_y, _ = one_hot_torch(batch_size=4, num_classes=10, device=device) + m_y, _ = one_hot_torch(size=4, num_classes=10, device=device) m_dy, _ = randn((), device=device) model_train = raf.optim.sgd.with_sgd()(m_model) diff --git a/tests/python/frontend/test_mxnet_model.py b/tests/python/frontend/test_mxnet_model.py index b586df2f..cf7da6c9 100644 --- a/tests/python/frontend/test_mxnet_model.py +++ b/tests/python/frontend/test_mxnet_model.py @@ -36,7 +36,7 @@ def test_backward_check(device, mx_model): ) mx_model[1].hybridize(static_alloc=True, static_shape=True) x, mx_x = randn_mxnet((5, 3, 224, 224), requires_grad=True, device=device) - m_ytrue, mx_ytrue = one_hot_mxnet(batch_size=5, num_classes=1000, device=device) + m_ytrue, mx_ytrue = one_hot_mxnet(size=5, num_classes=1000, device=device) raf_model = raf.frontend.from_mxnet(mx_model[1], ["x"]) out = raf_model.record(x) @@ -77,7 +77,7 @@ def test_forward_check(device, mx_model): mx_model[1].hybridize(static_alloc=True, static_shape=True) x, mx_x = randn_mxnet((5, 3, 224, 224), requires_grad=True, device=device) - m_ytrue, _ = one_hot_mxnet(batch_size=5, num_classes=1000, device=device) + m_ytrue, _ = one_hot_mxnet(size=5, num_classes=1000, device=device) raf_model = raf.frontend.from_mxnet(mx_model[1], ["x"]) raf_model.train_mode() diff --git a/tests/python/frontend/test_pytorch.py b/tests/python/frontend/test_pytorch.py index efe1b3d0..6b1e8d00 100644 --- a/tests/python/frontend/test_pytorch.py +++ b/tests/python/frontend/test_pytorch.py @@ -73,7 +73,7 @@ def test_lenet(shape_dict, mode): check(m_y, t_y, rtol=tol, atol=tol) return - m_ytrue, t_ytrue = one_hot_torch(batch_size=batch_size, num_classes=10, device=device) + m_ytrue, t_ytrue = one_hot_torch(size=batch_size, num_classes=10, device=device) m_dy, t_dy = randn_torch((), std=0.0, mean=1.0, device=device, requires_grad=False, dtype=dtype) # append loss function @@ -165,7 +165,7 @@ def test_conv_bn(shape_dict, mode, fuse): check(m_y, t_y, rtol=1e-4, atol=1e-4) return - m_ytrue, t_ytrue = one_hot_torch(batch_size=batch_size, num_classes=6 * 24 * 24, device=device) + m_ytrue, t_ytrue = one_hot_torch(size=batch_size, num_classes=6 * 24 * 24, device=device) m_dy, t_dy = randn_torch((), std=0.0, mean=1.0, device=device, requires_grad=False) # append loss function @@ -285,7 +285,7 @@ def test_mm_dropout(shape_dict, p, device, mode): check(m_y, t_y, 
rtol=1e-4, atol=1e-4) return - m_ytrue, t_ytrue = one_hot_torch(batch_size=batch_size, num_classes=30, device=device) + m_ytrue, t_ytrue = one_hot_torch(size=batch_size, num_classes=30, device=device) # append loss function out = m_model.record(m_x) diff --git a/tests/python/model/test_model_lenet.py b/tests/python/model/test_model_lenet.py index a93dab76..73f0e7f4 100644 --- a/tests/python/model/test_model_lenet.py +++ b/tests/python/model/test_model_lenet.py @@ -98,7 +98,7 @@ def test_lenet(config): m_model.linear3.w = t2m_param(t_model.linear3.weight) m_model.linear3.b = t2m_param(t_model.linear3.bias) m_x, t_x = randn_torch([1, 3, config[0], config[0]], requires_grad=True, device="cuda") - m_y, t_y = one_hot_torch(batch_size=1, num_classes=config[1], device="cuda") + m_y, t_y = one_hot_torch(size=1, num_classes=config[1], device="cuda") print("### Switch to training mode") m_model.train_mode() diff --git a/tests/python/model/test_resnet_50_imagenet.py b/tests/python/model/test_resnet_50_imagenet.py index 85d84594..4e54383d 100644 --- a/tests/python/model/test_resnet_50_imagenet.py +++ b/tests/python/model/test_resnet_50_imagenet.py @@ -510,7 +510,7 @@ def test_r50_v1_imagenet(): # pylint: disable=too-many-statements m_x, t_x = randn_torch( [1, 3, 224, 224], requires_grad=True, device="cuda" ) # pylint: disable=unused-variable - m_y, t_y = one_hot_torch(batch_size=1, num_classes=1000, device="cuda") + m_y, t_y = one_hot_torch(size=1, num_classes=1000, device="cuda") m_x.requires_grad = True m_model.train_mode() t_model.train() diff --git a/tests/python/optim/test_lans.py b/tests/python/optim/test_lans.py index 91699b67..af3c2003 100644 --- a/tests/python/optim/test_lans.py +++ b/tests/python/optim/test_lans.py @@ -117,7 +117,7 @@ def test_lans(config): for i in range(batch_size): t_optimizer.zero_grad() m_x, t_x = randn_torch([1, 3, config[1], config[1]], requires_grad=True, device="cuda") - m_y, t_y = one_hot_torch(batch_size=1, num_classes=config[2], device="cuda") + m_y, t_y = one_hot_torch(size=1, num_classes=config[2], device="cuda") m_loss = m_model(m_x, m_y) t_loss = t_model(t_x, t_y) m_loss.backward() @@ -192,7 +192,7 @@ def test_traced_lans(config): for i in range(batch_size): m_dy, t_dy = randn_torch((), std=0.0, mean=1.0, device=device, requires_grad=False) m_x, t_x = randn_torch([1, 3, config[1], config[1]], requires_grad=True, device=device) - m_y, t_y = one_hot_torch(batch_size=1, num_classes=config[2], device=device) + m_y, t_y = one_hot_torch(size=1, num_classes=config[2], device=device) m_loss = run_vm_model(m_optimizer, device, [m_dy, m_x, m_y]) t_optimizer.zero_grad() t_loss = t_model(t_x, t_y) @@ -238,7 +238,7 @@ def __init__(self): device = "cuda" m_x, _ = randn_torch([batch_size, 3, shape, shape], requires_grad=True, device=device) m_dy, _ = randn_torch((), std=0.0, mean=1.0, device=device, requires_grad=False) - m_ytrue, _ = one_hot_torch(batch_size=batch_size, num_classes=n_classes, device=device) + m_ytrue, _ = one_hot_torch(size=batch_size, num_classes=n_classes, device=device) args = [m_dy, m_x, m_ytrue] record = m_optimizer._internal(*args) diff --git a/tests/python/optim/test_sgd.py b/tests/python/optim/test_sgd.py index eb160ada..f1e383fb 100644 --- a/tests/python/optim/test_sgd.py +++ b/tests/python/optim/test_sgd.py @@ -104,7 +104,7 @@ def test_sgd(config): for i in range(batch_size): t_optimizer.zero_grad() m_x, t_x = randn_torch([1, 3, config[1], config[1]], requires_grad=True, device="cuda") - m_y, t_y = one_hot_torch(batch_size=1, 
num_classes=config[2], device="cuda") + m_y, t_y = one_hot_torch(size=1, num_classes=config[2], device="cuda") m_loss = m_model(m_x, m_y) t_loss = t_model(t_x, t_y) m_loss.backward() @@ -197,7 +197,7 @@ def test_traced_sgd(config): for i in range(batch_size): m_dy, t_dy = randn_torch((), std=0.0, mean=1.0, device=device, requires_grad=False) m_x, t_x = randn_torch([1, 3, config[1], config[1]], requires_grad=True, device=device) - m_y, t_y = one_hot_torch(batch_size=1, num_classes=config[2], device=device) + m_y, t_y = one_hot_torch(size=1, num_classes=config[2], device=device) m_loss = run_vm_model(m_optimizer, device, [m_dy, m_x, m_y]) t_optimizer.zero_grad() t_loss = t_model(t_x, t_y) @@ -273,7 +273,7 @@ def __init__(self): device = "cuda" m_x, _ = randn_torch([batch_size, 3, shape, shape], requires_grad=True, device=device) m_dy, _ = randn_torch((), std=0.0, mean=1.0, device=device, requires_grad=False) - m_ytrue, _ = one_hot_torch(batch_size=batch_size, num_classes=n_classes, device=device) + m_ytrue, _ = one_hot_torch(size=batch_size, num_classes=n_classes, device=device) args = [m_dy, m_x, m_ytrue] record = m_optimizer._internal(*args) diff --git a/tests/python/pass/test_pass_group_allgather.py b/tests/python/pass/test_pass_group_allgather.py index 115692db..bb532d04 100644 --- a/tests/python/pass/test_pass_group_allgather.py +++ b/tests/python/pass/test_pass_group_allgather.py @@ -81,7 +81,7 @@ def __init__(self): device = "cuda" m_x, _ = randn([batch_size, 3, shape, shape], requires_grad=True, device=device) m_dy, _ = randn((), device=device, requires_grad=False) - m_ytrue, _ = one_hot_torch(batch_size=batch_size, num_classes=n_classes, device=device) + m_ytrue, _ = one_hot_torch(size=batch_size, num_classes=n_classes, device=device) args = [m_dy, m_x, m_ytrue] record = m_optimizer._internal(*args) diff --git a/tests/python/pass/test_pass_partition_gradient.py b/tests/python/pass/test_pass_partition_gradient.py index 36904ebc..75120770 100644 --- a/tests/python/pass/test_pass_partition_gradient.py +++ b/tests/python/pass/test_pass_partition_gradient.py @@ -143,7 +143,7 @@ def __init__(self): ad_model = with_autodiff(model) m_x, _ = randn((batch, 3, 28, 28), dtype="float32") m_dy, _ = randn((), dtype="float32") - m_ytrue, _ = one_hot_torch(batch_size=batch, num_classes=10) + m_ytrue, _ = one_hot_torch(size=batch, num_classes=10) if batch == 8: # The gradient of conv2d_dx and nll_loss is dividable so no padding is need. 
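Note on the testing-helper change applied throughout the tests above: `one_hot_torch` and `one_hot_mxnet` now take `size` instead of `batch_size`, and `size` may be either a scalar or a tuple. A minimal usage sketch of the updated helper; the shapes and device below are illustrative only:

```python
from raf.testing import one_hot_torch

# A scalar size keeps the previous behavior: a 1-D tensor of `size` random labels.
m_y, t_y = one_hot_torch(size=8, num_classes=10, device="cpu")
assert tuple(m_y.shape) == (8,)

# Tuple sizes are now accepted as well, e.g. per-token labels for a batch of sequences.
m_y, t_y = one_hot_torch(size=(8, 128), num_classes=10, device="cpu")
assert tuple(m_y.shape) == (8, 128)
```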
From 1a71ac7dc47ddf168c46ba9691bdb3cd482533e5 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 18 May 2022 13:54:29 -0700 Subject: [PATCH 26/37] [Misc] Scope Timer (#54) * [Misc] Scope Timer * fix * fix * reset --- CMakeLists.txt | 1 + include/raf/profiler.h | 33 ----------- include/raf/scope_timer.h | 107 ++++++++++++++++++++++++++++++++++++ src/profiler/scope_timer.cc | 25 +++++++++ 4 files changed, 133 insertions(+), 33 deletions(-) create mode 100644 include/raf/scope_timer.h create mode 100644 src/profiler/scope_timer.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 2f39531f..a1edc319 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -106,6 +106,7 @@ file(GLOB_RECURSE RAF_CXX_SOURCE_FILES ${CMAKE_CURRENT_LIST_DIR}/src/impl/*.cc ${CMAKE_CURRENT_LIST_DIR}/src/profiler/memory_profiler.cc ${CMAKE_CURRENT_LIST_DIR}/src/profiler/op_profiler.cc + ${CMAKE_CURRENT_LIST_DIR}/src/profiler/scope_timer.cc ${CMAKE_CURRENT_LIST_DIR}/src/profiler/base/*.cc ${CMAKE_CURRENT_LIST_DIR}/src/distributed/common/*.cc ) diff --git a/include/raf/profiler.h b/include/raf/profiler.h index bbf61c77..8ab03ba5 100644 --- a/include/raf/profiler.h +++ b/include/raf/profiler.h @@ -251,38 +251,5 @@ inline void ProfilerHelper::collect() { Profiler::Get()->AddNewProfileStat(categories_, name_, start_time_, end_time_, args); } -/*! - * \brief A helper class and macro to profile the execution time of a scope (e.g., function). - * This is used for debugging purpose. For example: - * void some_func() { - * RAF_TIMED_SEC("some_func") - * // do something; - * } - * The profiled time is then the life time of the created TimeSection object, and will be - * logged to stderr. - */ -class TimedSection { - public: - explicit TimedSection(std::string name, bool in_us = false) - : name_(name), start_(ProfileStat::NowInMicrosec()), in_us_(in_us) { - } - - ~TimedSection() { - auto timed = ProfileStat::NowInMicrosec() - start_; - if (in_us_) { - LOG(INFO) << "Timed " << name_ << ": " << timed << "us"; - } else { - LOG(INFO) << "Timed " << name_ << ": " << std::setprecision(2) << timed / 1000000.0 << "s"; - } - } - - private: - std::string name_; - uint64_t start_; - bool in_us_; -}; -#define RAF_TIMED_SEC(name) raf::profiler::TimedSection timed_section(name); -#define RAF_TIMED_US(name) raf::profiler::TimedSection timed_section(name, true); - } // namespace profiler } // namespace raf diff --git a/include/raf/scope_timer.h b/include/raf/scope_timer.h new file mode 100644 index 00000000..ba129f50 --- /dev/null +++ b/include/raf/scope_timer.h @@ -0,0 +1,107 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +/*! + * \file scope_timer.h + * \brief Scope timer that times the execution time of a code scope (e.g., function). 
+ */ +#pragma once +#include "raf/profiler.h" + +namespace raf { +namespace scope_timer { + +class ScopeTimerPool { + public: + explicit ScopeTimerPool() { + } + + ~ScopeTimerPool() { + } + + static std::shared_ptr Get() { + static registry::PerDeviceStore* pool = + new registry::PerDeviceStore(); + std::shared_ptr& ret = pool->Get(Device(DevType::kCPU(), 0)); + if (ret == nullptr) { + std::lock_guard lock(pool->mutex_); + if (ret == nullptr) { + ret = std::make_shared(); + } + } + return ret; + } + + void AddSample(std::string name, float timed) { + std::lock_guard lock(mutex); + time_pool[name].push_back(timed); + } + + void DumpReport() { + LOG(INFO) << "Scope Timer Report:"; + for (auto& kv : time_pool) { + float total = 0; + float max_sample = 0; + for (auto& t : kv.second) { + total += t; + max_sample = std::max(max_sample, t); + } + float avg = total / kv.second.size(); + LOG(INFO) << kv.first << ": " << total << "s (max " << max_sample << "s, avg " << avg + << "s) from " << kv.second.size() << " samples"; + } + } + + void Reset() { + std::lock_guard lock(mutex); + time_pool.clear(); + } + + std::vector GetSamplesByName(std::string name) { + return time_pool[name]; + } + + public: + std::unordered_map> time_pool; + std::mutex mutex; +}; + +/*! + * \brief A helper class and macro to profile the execution time of a scope (e.g., function). + * This is used for debugging/profiling purpose. For example: + * void some_func() { + * RAF_TIMED_SEC("some_func") + * // do something; + * } + * The profiled time is then the life time of the created TimeSection object, and will be + * recorded in the global scope_timer_pool and could be dumped later. + */ +class TimedSection { + public: + explicit TimedSection(std::string name, bool flush, std::shared_ptr pool) + : name_(name), flush_(flush), pool_(pool), start_(profiler::ProfileStat::NowInMicrosec()) { + } + + ~TimedSection() { + auto now = profiler::ProfileStat::NowInMicrosec(); + float timed = (now - start_) / 1e6; + pool_->AddSample(name_, timed); + if (flush_) { + LOG(INFO) << "Timed " << name_ << ": " << timed << "s"; + } + } + + private: + std::string name_; + bool flush_; + std::shared_ptr pool_; + uint64_t start_; +}; +#define RAF_TIMED(name, flush) \ + auto scope_timer_pool = raf::scope_timer::ScopeTimerPool::Get(); \ + raf::scope_timer::TimedSection timed_section(name, flush, scope_timer_pool); + +} // namespace scope_timer +} // namespace raf diff --git a/src/profiler/scope_timer.cc b/src/profiler/scope_timer.cc new file mode 100644 index 00000000..410bea7c --- /dev/null +++ b/src/profiler/scope_timer.cc @@ -0,0 +1,25 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +/*! + * \file src/profiler/scope_timer.cc + * \brief Scope timer to measure execution time of a code scope. 
+ */ +#include "raf/registry.h" +#include "raf/scope_timer.h" + +namespace raf { +namespace scope_timer { + +RAF_REGISTER_GLOBAL("raf.scope_timer.DumpReport").set_body_typed([]() { + ScopeTimerPool::Get()->DumpReport(); +}); + +RAF_REGISTER_GLOBAL("raf.scope_timer.Reset").set_body_typed([]() { + ScopeTimerPool::Get()->Reset(); +}); + +} // namespace scope_timer +} // namespace raf From a5aaacafbb25ed217498c734497a553b231d8301 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Wed, 18 May 2022 16:30:45 -0700 Subject: [PATCH 27/37] [Autogen] add indent for schema field index if condition in regs.cc (#55) --- scripts/src_codegen/main_cxx_reg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/src_codegen/main_cxx_reg.py b/scripts/src_codegen/main_cxx_reg.py index 57125098..a36efebc 100644 --- a/scripts/src_codegen/main_cxx_reg.py +++ b/scripts/src_codegen/main_cxx_reg.py @@ -590,7 +590,7 @@ def gen_schema_field_idx(_schema): schema_name = snake_to_pascal(schema_name) args = [] for i, entry in enumerate(schema): - args.append(ARG.format(I=i, FIELD=entry.name)) + args.append(" " + ARG.format(I=i, FIELD=entry.name)) args = "\n".join(map(add_no_lint, args)) return VALUE_TO_SCHEMA.format(SCHEMA_NAME=schema_name, ARGS=args) From a6828f6e8b2248dbfbf1b68118dd2f38b7fcff9f Mon Sep 17 00:00:00 2001 From: Zhen Zhang Date: Sun, 22 May 2022 20:02:23 -0400 Subject: [PATCH 28/37] [DOC] fix import errors of the code in "user guide" (#57) * fix the import error * fixed import, undefined var, getting loss var --- docs/wiki/2_user_guide/Train-Model.md | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/docs/wiki/2_user_guide/Train-Model.md b/docs/wiki/2_user_guide/Train-Model.md index 6ae40e71..50b513c9 100644 --- a/docs/wiki/2_user_guide/Train-Model.md +++ b/docs/wiki/2_user_guide/Train-Model.md @@ -9,8 +9,8 @@ The programming model of implementing a deep learning mode in RAF is basically t ```python import raf -from raf.model import Conv2D, BatchNorm -from raf.testing import randn_torch, get_vm_executor, run_vm_executor +from raf.model import Conv2d, BatchNorm, Sequential, Linear +from raf.testing import randn_torch, get_vm_executor, run_vm_executor, one_hot_torch class RAFBottleneck(raf.Model): expansion = 4 @@ -125,10 +125,11 @@ RAF offers a virtual machine (VM) runtime to execute the model training process. ```python batch_size = 8 - +input_shape = (batch_size, 3, 224, 224) dy, _ = randn_torch((), std=0.0, mean=1.0, requires_grad=False, device=device) # dy = tensor(1.0) # Training loop +num_epoch = 2 record = None executor = None for _ in range(num_epoch): @@ -143,8 +144,8 @@ for _ in range(num_epoch): executor = get_vm_executor(record.mod, device) ret = run_vm_executor(executor, record, args, device) - loss = ret[0] # ret[0][0] for some models - print("Loss:", loss) + loss = ret[0][0] # ret[0] for some models + print("Loss:", loss.numpy()) ``` One major different as PyTorch is RAF needs to initialize a virtual machine in the first iteration. The initialization involves graph level optimization and VM bytecode compilation. In addition, when running the VM executor in the first iteration, RAF performs just-in-time (JIT) compilation to code generate each kernel, so it may take a bit longer. 
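The first-iteration overhead described above (VM initialization plus per-kernel JIT compilation) can be observed by timing each iteration of the training loop. A rough sketch reusing the guide's own helpers; it assumes `executor`, `record`, `args`, `device`, and `num_epoch` have already been set up as in the loop above, and it reuses one batch of inputs for simplicity:

```python
import time

iter_times = []
for step in range(num_epoch):
    start = time.time()
    ret = run_vm_executor(executor, record, args, device)
    iter_times.append(time.time() - start)

# The first iteration includes JIT compilation of the generated kernels.
print("first iteration:", iter_times[0])
if len(iter_times) > 1:
    print("steady state (min of later iterations):", min(iter_times[1:]))
```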
From fc66f0df9a7d2b69a36820915914ffb0aa74582b Mon Sep 17 00:00:00 2001 From: Zhen Jia <53954057+zhen-jia@users.noreply.github.com> Date: Mon, 23 May 2022 16:47:21 -0700 Subject: [PATCH 29/37] [Bugfix] Fix bugs and clean code for LANS (#56) * fix and clean * lint --- python/raf/optim/lans.py | 8 +++++++- src/op/dialect/cuda/lans.cc | 12 ++---------- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/python/raf/optim/lans.py b/python/raf/optim/lans.py index b84f6ca1..85cf3932 100644 --- a/python/raf/optim/lans.py +++ b/python/raf/optim/lans.py @@ -1,7 +1,7 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -# pylint: disable=invalid-name, missing-function-docstring, too-many-instance-attributes, too-many-locals, too-many-statements, protected-access, too-many-arguments +# pylint: disable=invalid-name, missing-function-docstring, too-many-instance-attributes, too-many-locals, too-many-statements, protected-access, too-many-arguments, too-many-branches """LANS optimizer.""" import numpy as np @@ -258,6 +258,12 @@ def forward(self, dy, *args, **kwargs): v_list.append(v) ntensor += 1 + if self.dtype != "float32": + fp32_g = _op.group_cast(g_list, "float32") + g_list = [] + for i in range(ntensor): + g_list.append(fp32_g[i]) + tensor_list = g_list + x_list + m_list + v_list output_list = _op.lans( tensor_list, diff --git a/src/op/dialect/cuda/lans.cc b/src/op/dialect/cuda/lans.cc index c5d3d0ed..1da9d18d 100644 --- a/src/op/dialect/cuda/lans.cc +++ b/src/op/dialect/cuda/lans.cc @@ -21,7 +21,6 @@ using namespace raf::value; using device_api::DeviceAPI; #define CHUNK_SIZE 65536 #define FLOAT_BYTES 4 -#define HALF_BYTES 2 class LansImpl : public raf::op::OpEnv { public: @@ -38,7 +37,7 @@ class LansImpl : public raf::op::OpEnv { DLTensor* t0 = ir::Downcast(args->tensor_list[0]); auto datatype = t0->dtype; CHECK(datatype.code == kDLFloat); - CHECK((datatype.bits == 32) || (datatype.bits == 16)); + CHECK((datatype.bits == 32)) << "LANS only takes FP32 inputs"; beta1_ = args->beta1; beta2_ = args->beta2; @@ -64,11 +63,7 @@ class LansImpl : public raf::op::OpEnv { tensor_elements += numel; } max_chunks_per_tensor_ = -1; - if (datatype.bits == 32) { - RequestWorkspace(&q_tensor_buf_, cv->device, FLOAT_BYTES * tensor_elements); - } else { - RequestWorkspace(&q_tensor_buf_, cv->device, HALF_BYTES * tensor_elements); - } + RequestWorkspace(&q_tensor_buf_, cv->device, FLOAT_BYTES * tensor_elements); for (int t = 0; t < param_group_n_; t++) { int max_chunks_this_tensor = (numels_[t] + CHUNK_SIZE - 1) / CHUNK_SIZE; if (max_chunks_this_tensor > max_chunks_per_tensor_) { @@ -96,9 +91,6 @@ class LansImpl : public raf::op::OpEnv { void Execute(const std::vector& inputs, Value output) override { TupleValue tuple = ir::Downcast(inputs[0]); - DLTensor* t0 = ir::Downcast(tuple->fields[0]); - CHECK(t0->dtype.code == kDLFloat); - CHECK((t0->dtype.bits == 32) || (t0->dtype.bits == 16)); auto* tstep = inputs[1].as(); tensor::Tensor step_tensor = tstep->tensor; CHECK(step_tensor->ndim == 0); From 37337e0eca5a0db0a52b3a8295d690458f5275ef Mon Sep 17 00:00:00 2001 From: AIREMetaBot <100344401+aire-meta-bot@users.noreply.github.com> Date: Tue, 24 May 2022 07:47:38 +0800 Subject: [PATCH 30/37] [TVM] Update Submodule 2022-05-23-13-15-12 (#58) * [TVM] Update Submodule * __VisitAttrs_ -> _tvm_VisitAttrs * fix * fix * fix Co-authored-by: SubmoduleUpdaterBot Co-authored-by: Cody Yu --- 3rdparty/tvm | 2 +- include/raf/op.h | 2 +- src/impl/ir_ext.cc | 4 ++-- 
src/op/grad/binary.cc | 14 ++++++++------ tests/python/op/tvm/test_tvm_binary.py | 12 ++++++++---- 5 files changed, 20 insertions(+), 14 deletions(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 02d57bbc..6247bf48 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 02d57bbc062cf9bd47c03d4355ccd660ed68091a +Subproject commit 6247bf48aaa59be9549dd8c342702c6005f16c5f diff --git a/include/raf/op.h b/include/raf/op.h index a30eecfa..8b660982 100644 --- a/include/raf/op.h +++ b/include/raf/op.h @@ -337,7 +337,7 @@ inline T GetOpAttrOrDefault(const ir::Op& op, const std::string attr_name, T def static constexpr const char* _type_key = type_key; \ RAF_FINAL_OBJECT(class_name, ::tvm::BaseAttrsNode) \ template \ - void __VisitAttrs__(FVisit& __fvisit__) { \ + void _tvm_VisitAttrs(FVisit& _tvm_fvisit) { \ } #define RAF_OP_GRAD(op_name, body) \ diff --git a/src/impl/ir_ext.cc b/src/impl/ir_ext.cc index 5b423deb..b9cd7c8c 100644 --- a/src/impl/ir_ext.cc +++ b/src/impl/ir_ext.cc @@ -126,11 +126,11 @@ std::string AsText(const ObjectRef& node, bool show_meta_data) { std::string ret = tvm::AsText(node, show_meta_data, annotate); size_t index = 0; while (true) { - index = ret.find("-114514", index); + index = ret.find("-114514i64", index); if (index == std::string::npos) { break; } - ret.replace(index, 7, ""); + ret.replace(index, 10, ""); } return ret; } diff --git a/src/op/grad/binary.cc b/src/op/grad/binary.cc index cc389de5..2e73ff15 100644 --- a/src/op/grad/binary.cc +++ b/src/op/grad/binary.cc @@ -119,12 +119,14 @@ Array PowGrad(const Expr& orig_call, const Array orig_args, const Va Call dx = Call(op_multiply, {a, x_pow_minus_one}); // da = x^a * log(x) - Call x_pow = Call(op_power, {x, a}); - Call x_log = Call(op_log, {x}); - Call da = Call(op_multiply, {x_pow, x_log}); - - return {GetCollapseSumLike(Call(op_multiply, {dy, dx}), x), - GetCollapseSumLike(Call(op_multiply, {dy, da}), a)}; + // FIXME: PyTorch defines d(x^a)/da at x = 0 and a < 0 to be -inf. This is due to + // d(x^a)/da -> -inf for fixed a as x -> +0. However, this requires either an if-branch + // or special handling at the kernel level, so we disable the gradient of "a" for a now. + // Call x_pow = Call(op_power, {x, a}); + // Call x_log = Call(op_log, {x}); + // Call da = Call(op_multiply, {x_pow, x_log}); + + return {GetCollapseSumLike(Call(op_multiply, {dy, dx}), x), ir::NullValue()}; } RAF_OP_GRAD("raf.op.power", PowGrad); diff --git a/tests/python/op/tvm/test_tvm_binary.py b/tests/python/op/tvm/test_tvm_binary.py index edacc4e5..15c46b17 100644 --- a/tests/python/op/tvm/test_tvm_binary.py +++ b/tests/python/op/tvm/test_tvm_binary.py @@ -91,19 +91,23 @@ def test_binary_ops_with_grad(ops, shape, dtype, device): @pytest.mark.parametrize("device", get_testable_devices()) @pytest.mark.parametrize("dtype", ["float32"]) def test_power(dtype, device): - x1 = np.random.randn(2, 2).astype("float32") - x1[0][0] = 0 # Assign 0 to test the corner case. 
+ x1 = np.abs(np.random.randn(2, 2).astype("float32")) + 1e-5 + # Corner case: 0 + x1[0][0] = 0 + # Corner case: negative value + x1[0][1] = -x1[0][1] t_x1 = torch.Tensor(x1).to(device) t_x1.requires_grad = True m_x1 = raf.array(x1, device=device) m_x1.requires_grad = True - m_x2, t_x2 = randn_torch((), dtype=dtype, device=device, requires_grad=True) + m_x2, t_x2 = randn_torch((), dtype=dtype, device=device, requires_grad=False, positive=True) t_y = torch.pow(t_x1, t_x2) m_dy, t_dy = randn_torch(t_y.shape, dtype=dtype, device=device) t_y.backward(t_dy) - verify_op(raf._op.sym.power, [m_x1, m_x2], device, t_y, m_dy, [t_x1.grad, t_x2.grad]) + # Note that we do not compare the gradient of x2 because it is disabled for now. + verify_op(raf._op.sym.power, [m_x1, m_x2], device, t_y, m_dy, [t_x1.grad]) # logical_and only allows bool input s From 0639ef59673cead092b7351d7425caba8dded684 Mon Sep 17 00:00:00 2001 From: Jie Wang Date: Thu, 26 May 2022 16:58:59 -0700 Subject: [PATCH 31/37] [Tracing] Support multiple inputs (#59) Co-authored-by: Jie Wang --- python/raf/frontend/pytorch.py | 63 +++++++++++++++++++--------------- 1 file changed, 35 insertions(+), 28 deletions(-) diff --git a/python/raf/frontend/pytorch.py b/python/raf/frontend/pytorch.py index 158d1925..55a1f90f 100644 --- a/python/raf/frontend/pytorch.py +++ b/python/raf/frontend/pytorch.py @@ -15,7 +15,7 @@ from ..frontend.model import FrameworkModel -def trace_model(model, input_type, input_shape): +def trace_model(model, shape_dict): """Trace PyTorch model. Parameters @@ -23,11 +23,10 @@ def trace_model(model, input_type, input_shape): model: torch.nn.Module The PyTorch module to be converted. - input_type: str - Input type. - - input_shape: Tuple[int, ...] - Input shape + shape_dict: Dict[str, + Union[Tuple[Tuple[int, ...], str], + Tuple[Tuple[int, ...], str, int]] + A map from input name to its shape, type, and maximal value (optional). 
Returns ------- @@ -49,8 +48,8 @@ def __init__(self, model): super().__init__() self.model = model - def forward(self, inp): - out = self.model(inp) + def forward(self, *inputs): + out = self.model(*inputs) if isinstance(out, list): ordered_outs = [out[0][key] for key in self.od_model_output_keys if key in out[0]] return tuple(ordered_outs) @@ -65,7 +64,7 @@ def dtype(self): return param.dtype return torch.float32 - def inner(model, input_type, input_shape): + def inner(model, shape_dict): """Wrap the tracing process so that we could empty PyTorch CUDA cache afterward.""" model = TraceWrapper(model) model.eval() @@ -80,16 +79,24 @@ def inner(model, input_type, input_shape): comm = dist.get_communicator() device = "cuda:" + str(comm.local_rank) - if input_type.startswith("float"): - input_data = torch.randn(input_shape, dtype=getattr(torch, input_type), device=device) - else: - assert input_type.startswith("int64"), "Unsupported input type %s" % input_type - input_data = torch.randint(10000, input_shape, device=device) + example_inputs = [] + for _, input_info in shape_dict.items(): + input_shape = input_info[0] + input_type = input_info[1] + if input_type.startswith("int64"): + max_val = 10000 if len(input_info) == 2 else input_info[2] + input_data = torch.randint(max_val + 1, input_shape, device=device) + elif input_type.startswith("float"): + input_data = torch.randn( + input_shape, dtype=getattr(torch, input_type), device=device + ) + else: + raise ValueError("Unsupported input type %s" % input_type) + example_inputs.append(input_data) with torch.no_grad(): model.to(device=device) - model(input_data) - scripted_model = torch.jit.trace(model, input_data).eval() + scripted_model = torch.jit.trace(model, tuple(example_inputs)).eval() if device.startswith("cuda"): model.to(device="cpu") @@ -97,7 +104,7 @@ def inner(model, input_type, input_shape): return scripted_model - scripted_model = inner(model, input_type, input_shape) + scripted_model = inner(model, shape_dict) if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() @@ -113,9 +120,10 @@ def from_pytorch(model, shape_dict, model_file=None, hash_file=None): model: torch.nn.Module The PyTorch module to be converted. - shape_dict: Dict[str, Tuple[Tuple[int, ...], str]] - A map from input name to its shape and type. Note that we currently only support - the model with a single input. + shape_dict: Dict[str, + Union[Tuple[Tuple[int, ...], str], + Tuple[Tuple[int, ...], str, int]] + A map from input name to its shape, type, and maximal value (optional). model_file: str The file that stores the scripted model @@ -127,11 +135,6 @@ def from_pytorch(model, shape_dict, model_file=None, hash_file=None): model: FrameworkModel The converted FrameworkModel. 
""" - if len(shape_dict) > 1: - raise RuntimeError( - "Do not support PyTorch model with multiple inputs (%d) yet" % len(shape_dict) - ) - input_name, (input_shape, input_type) = list(shape_dict.items())[0] if model_file is not None and hash_file is not None: model_hash = hashlib.md5(str(model).encode(encoding="UTF-8")).hexdigest() if os.path.exists(model_file) and os.path.exists(hash_file): @@ -144,14 +147,18 @@ def from_pytorch(model, shape_dict, model_file=None, hash_file=None): except: raise RuntimeError("Loading scripted model failed") else: - scripted_model = trace_model(model, input_type, input_shape) + scripted_model = trace_model(model, shape_dict) scripted_model.eval() scripted_model.save(model_file) with open(hash_file, "w") as hashf: hashf.write(model_hash) else: - scripted_model = trace_model(model, input_type, input_shape) - shape_list = [(input_name, (input_shape, input_type))] + scripted_model = trace_model(model, shape_dict) + shape_list = [] + for input_name, input_info in list(shape_dict.items()): + input_shape = input_info[0] + input_type = input_info[1] + shape_list.append((input_name, (input_shape, input_type))) relay_mod, relay_params = relay.frontend.from_pytorch(scripted_model, shape_list) meta_mod = FromRelay()(relay_mod) meta_params = OrderedDict() From 10fad16bfc01e19330b30cb10a2449a90ef9c3bc Mon Sep 17 00:00:00 2001 From: TIAN Ye Date: Tue, 31 May 2022 12:27:35 +0800 Subject: [PATCH 32/37] fix pass name (#62) --- src/pass/switch_train.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pass/switch_train.cc b/src/pass/switch_train.cc index b58cd2ab..535e0c3c 100644 --- a/src/pass/switch_train.cc +++ b/src/pass/switch_train.cc @@ -112,8 +112,8 @@ Pass SwitchTrainOp(bool to_train_op) { PassContext pc) { return switch_train::OpReplacer(f, to_train_op).Replace(); }; - auto switch_train = CreateRAFFunctionPass(pass_func, 0, "SwithTrainOp", {}); - return RAFSequential({switch_train}, "SwithTrainOp"); + auto switch_train = CreateRAFFunctionPass(pass_func, 0, "SwitchTrainOp", {}); + return RAFSequential({switch_train}, "SwitchTrainOp"); } RAF_REGISTER_GLOBAL("raf.pass_.SwitchTrainOp").set_body_typed(SwitchTrainOp); From de2f73913ce9ce94b84b6acdb6cc6a1066b266fc Mon Sep 17 00:00:00 2001 From: AIREMetaBot <100344401+aire-meta-bot@users.noreply.github.com> Date: Wed, 1 Jun 2022 01:30:32 +0800 Subject: [PATCH 33/37] [TVM] Update Submodule (#61) Co-authored-by: SubmoduleUpdaterBot --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 6247bf48..c6415d14 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 6247bf48aaa59be9549dd8c342702c6005f16c5f +Subproject commit c6415d14928d1e09f4bd3105c7a5ddf87f92166b From dbc507aaf342725e804ddbcf77bb11485b748e7c Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 31 May 2022 18:41:29 -0700 Subject: [PATCH 34/37] [Op] [TVM] Lightweight TOPI schedule tuner (#60) * [Op][TVM] Lightweight tuner * dispatch * trigger * test * test * done * fix * comment * fix * enable by default --- python/raf/_tvm_op/reduce.py | 297 ++++++++++++++++++++++++++++++++++- python/raf/_tvm_op/utils.py | 110 +++++++++++++ 2 files changed, 399 insertions(+), 8 deletions(-) diff --git a/python/raf/_tvm_op/reduce.py b/python/raf/_tvm_op/reduce.py index 17e3feb7..5a4caadb 100644 --- a/python/raf/_tvm_op/reduce.py +++ b/python/raf/_tvm_op/reduce.py @@ -3,11 +3,17 @@ # pylint: disable=missing-function-docstring, too-many-locals, unused-argument """Reduction compute 
definition and schedules.""" +from operator import mul +from functools import reduce + +import numpy as np + from raf._tvm_op.nn import schedule_generic from .._lib import register_compute from .._lib import generic_func from .._lib import tvm as _tvm from .._lib import _reg +from .utils import profile_schedule _topi = _tvm.topi # pylint: disable=invalid-name, no-member @@ -61,10 +67,266 @@ def schedule_sum(attrs, outs, target): return _topi.generic.schedule_injective(outs) +def _schedule_cuda_sum_long_reduce(op, sch, **kwargs): + """The helper function for scheduling sum with long reduction length for CUDA. + In this case, we want to parallelize the reduction to keep the GPU busy. This is modified + from TOPI reduce schedule for CUDA. + + Parameters + ---------- + op: tvm.Operation + The operator being scheduled. + + sch: tvm.schedule.Schedule + The working schedule. + + **kwargs: Dict[str, List[Any]] + Tunable parameters. If not presents, the default values will be used. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + # pylint: disable=invalid-name + # Setup the tunable parameter value. + num_thread = kwargs.get("num_thread", 32) + thread_x = _tvm.te.thread_axis((0, num_thread), "threadIdx.x") + + data_out = op.output(0) + + # Fuse and rfactor the reduce axis + fused_reduce = sch[data_out].fuse( + *[sch[data_out].op.reduce_axis[i] for i in range(len(sch[data_out].op.reduce_axis))] + ) + _, ki = sch[data_out].split(fused_reduce, factor=num_thread) + data_out_rf = sch.rfactor(data_out, ki) + tx = sch[data_out].op.reduce_axis[0] + sch[data_out].bind(tx, thread_x) + sch[data_out_rf].compute_at(sch[data_out], tx) + + if len(sch[data_out].op.axis) > 0: + # There are one or more axes to not reduced. Here we bind them to threads and blocks + # for parallelism. + block_x = _tvm.te.thread_axis("blockIdx.x") + thread_y = _tvm.te.thread_axis((0, num_thread), "threadIdx.y") + + # Fuse and split the axis + fused_outer = sch[data_out].fuse( + *[sch[data_out].op.axis[i] for i in range(len(sch[data_out].op.axis))] + ) + bx, outer_in = sch[data_out].split(fused_outer, factor=num_thread) + + # Bind non-reduced axes to threads and blocks + sch[data_out].bind(outer_in, thread_y) + sch[data_out].bind(bx, block_x) + sch[data_out].set_store_predicate( + _tvm.tir.all( + thread_x.equal(0), block_x * num_thread + thread_y < reduce(mul, data_out.shape) + ) + ) + else: + # All axes are reduced. + sch[data_out].set_store_predicate(thread_x.equal(0)) + return sch + + +@profile_schedule( + num_thread=[16, 32, 64], validator=lambda _, reduce_last_axis: not reduce_last_axis +) +def schedule_cuda_sum_long_reduce(outs, reduce_last_axis, **kwargs): + """Schedule sum for CUDA. This schedule targets to the sum with long reduction length. + In this case, we want to parallelize the reduction to keep the GPU busy. This is modified + from TOPI reduce schedule for CUDA. + + In addition, this schedule is tunable if the last axis is not reduced. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of reduce in the format of an array of tensors. + + reduce_last_axis: bool + A hint indicating whether the last axis is reduced. + + **kwargs: Dict[str, List[Any]] + Tunable parameters. If not presents, the default values will be used. + + Returns + ------- + sch: Schedule + The computation schedule for the op. 
+ """ + # pylint: disable=unused-argument + outs = [outs] if isinstance(outs, _tvm.te.tensor.Tensor) else outs + sch = _tvm.te.create_schedule([x.op for x in outs]) + scheduled_ops = [] + + def _enable_auto_inline(sch): + def is_scheduled(stage): + # auto inline requires the attach type is AttachType.kGroupRoot + conds = [ + len(stage.relations) == 0, + stage.attach_type == 1, + stage.all_iter_vars == stage.leaf_iter_vars, + ] + if not all(conds): + return True + return False + + for stage in sch.stages: + if not stage.is_output and isinstance(stage.op, _tvm.te.ComputeOp): + if is_scheduled(stage) or len(stage.op.reduce_axis) != 0: + return False + return True + + enable_auto_inline = _enable_auto_inline(sch) + + def traverse_before_reduce(tensor): + """Internal traverse function""" + operator = tensor.op + if isinstance(operator, _tvm.te.PlaceholderOp): + return + if _topi.tag.is_injective(operator.tag): + sch[operator].compute_inline() + for inp_tensor in operator.input_tensors: + if inp_tensor.op not in scheduled_ops: + traverse_before_reduce(inp_tensor) + else: + raise RuntimeError("Unsupported operator: %s" % operator.tag) + + scheduled_ops.append(operator) + + def traverse_after_reduce(tensor): + """Internal traverse function""" + operator = tensor.op + if _topi.tag.is_broadcast(operator.tag): + if operator not in scheduled_ops: + _topi.schedule_injective_from_existing( # pylint: disable=no-member + sch, operator.output(0) + ) + for inp_tensor in operator.input_tensors: + if inp_tensor.op not in scheduled_ops: + if enable_auto_inline: + traverse_before_reduce(inp_tensor) + else: + traverse_after_reduce(inp_tensor) + elif operator.tag == "comm_reduce": + if operator not in scheduled_ops: + _schedule_cuda_sum_long_reduce(operator, sch, **kwargs) + for inp_tensor in operator.input_tensors: + if inp_tensor.op not in scheduled_ops: + traverse_before_reduce(inp_tensor) + elif isinstance(operator, _tvm.te.PlaceholderOp): + pass + else: + raise RuntimeError("Unsupported operator tag: %s" % operator.tag) + + scheduled_ops.append(operator) + + for out in outs: + traverse_after_reduce(out) + + return sch + + +@profile_schedule(num_thread=[16, 32, 64], max_block=[128, 256, 512]) +def schedule_cuda_short_reduce(outs, **kwargs): + """Schedule sum for CUDA. This schedule targets to the sum with short reduction length. + In this case, each thread is responsible for reduction. The parallelization is across + the output elements. This is modified from TOPI injective schedule for CUDA. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of injective in the format of an array of tensors. + + **kwargs: Dict[str, List[Any]] + Tunable parameters. If not presents, the default values will be used. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + # pylint: disable=invalid-name + # Tunable parameters. + num_thread = kwargs.get( + "num_thread", _tvm.target.Target.current(allow_none=False).max_num_threads + ) + max_block = kwargs.get("max_block", 256) + + def find_nearest_small_factor(num, target): + """Find the nearest factor of the given number that is smaller than the target.""" + for i in range(target, 0, -1): + if num % i == 0: + return i + # Unreachable because i=1 must hold. 
+ return -1 + + outs = [outs] if isinstance(outs, _tvm.te.tensor.Tensor) else outs + sch = _tvm.te.create_schedule([x.op for x in outs]) + + _tvm.te.schedule.AutoInlineInjective(sch) # pylint: disable=no-member + for out in outs: + if not _topi.utils.is_empty_shape(out.shape): + fused = sch[out].fuse(*sch[out].op.axis) + + # Vectorize on fp16 data type to enable half2 for better memory bandwidth utilization. + vector_width = 2 if out.dtype == "float16" else 1 + + out_len = _topi.utils.prod(out.shape) + + try: + const_size = _topi.utils.get_const_int(out_len) + + # Adjust block and thread to make sure they are dividable so that vectorize can be + # correctly applied. + if vector_width > 1 and const_size % vector_width == 0: + remain_total_size = const_size // vector_width + cand_sizes = [0, 0] + for idx, max_size in enumerate([num_thread, max_block]): + cand_sizes[idx] = ( + max_size + if remain_total_size % max_size == 0 + else find_nearest_small_factor(remain_total_size, max_size) + ) + remain_total_size //= cand_sizes[idx] + + # If the product of candidate dividable (block * thread) is too small, + # then the performance may be worse even half2 is enabled. Note that 0.7 + # is just a heuristic ratio and may not be optimal for all workloads. + if np.prod(cand_sizes) / (max_block * num_thread) >= 0.7: + num_thread, max_block = cand_sizes + + need_block_split = const_size > max_block * num_thread * vector_width + except ValueError: + need_block_split = False + const_size = 0 + + if vector_width > 1: + fused, vec = sch[out].split(fused, vector_width) + sch[out].vectorize(vec) + + if need_block_split: + xo, xi = sch[out].split(fused, factor=num_thread * max_block) + bx, tx = sch[out].split(xi, factor=num_thread) + sch[out].reorder(bx, tx, xo) + sch[out].bind(bx, _tvm.te.thread_axis("blockIdx.x")) + sch[out].bind(tx, _tvm.te.thread_axis("threadIdx.x")) + else: + if const_size != 0 and const_size < num_thread: + bx, tx = sch[out].split(fused, factor=const_size) + else: + bx, tx = sch[out].split(fused, factor=num_thread) + sch[out].bind(tx, _tvm.te.thread_axis("threadIdx.x")) + sch[out].bind(bx, _tvm.te.thread_axis("blockIdx.x")) + return sch + + @schedule_sum.register(["cuda", "gpu"]) def schedule_sum_cuda(attrs, outs, target): # pylint: disable=unused-argument - # pylint: disable=invalid-name def get_num_elements(axes): extents = [int(iv.dom.extent) for iv in axes] n_elems = 1 @@ -72,20 +334,39 @@ def get_num_elements(axes): n_elems *= extent return n_elems + def get_sum_input(tensor): + operator = tensor.op + if len(operator.input_tensors) != 1: + return None + + input_tensor = operator.input_tensors[0] + if isinstance(input_tensor.op, _tvm.te.PlaceholderOp): + return input_tensor + return None + with target: out = outs[0] num_out_elements = get_num_elements(out.op.axis) num_reduce_elements = get_num_elements(out.op.reduce_axis) - # We want to saturate the GPU cores by parallelization. There are 2 scenarios - # 1) Reduce dimension is small - In this case, each thread is responsible for reduction. - # The parallelization is across the output elements. - # 2) Reduce dimension is large - We want to parallelize the reduction to keep the GPU busy. - # Here we fall back to TVM schedule. + # Whether the last axis is reduced. Note that axis=-1 should already be proceed in advance. + reduce_last_axis = False + input_tensor = get_sum_input(out) + if input_tensor is not None: + # Only try to analyze the workload with sum as the last op. 
+ reduce_axis = [False for _ in range(len(input_tensor.shape))] + for axis in attrs.axis: + reduce_axis[axis.value] = True + if attrs.exclude == 1: + reduce_axis = [not axis for axis in reduce_axis] + reduce_last_axis = reduce_axis[-1] + + # We attempt to saturate the GPU cores by parallelization, so we dispatch + # the sum workloads to two schedules based on their reduction length. if num_out_elements > num_reduce_elements: - return _topi.cuda.schedule_injective(outs) + return schedule_cuda_short_reduce(outs) - return _topi.cuda.schedule_reduce(outs) + return schedule_cuda_sum_long_reduce(outs, reduce_last_axis=reduce_last_axis) _reg.register_schedule("raf.op.tvm.sum", schedule_sum) diff --git a/python/raf/_tvm_op/utils.py b/python/raf/_tvm_op/utils.py index c86d211a..0e7813a5 100644 --- a/python/raf/_tvm_op/utils.py +++ b/python/raf/_tvm_op/utils.py @@ -5,8 +5,12 @@ # pylint: disable=protected-access import os +import numpy as np + import tvm +from raf import distributed as dist + @tvm._ffi.register_func("raf._tvm_op.utils.export_library") def export_library(mod, path): @@ -45,3 +49,109 @@ def load_module(path): if not os.path.exists(path): raise RuntimeError("Module file does not exist {}".format(path)) return tvm.runtime.module.load_module(path) + + +def profile_schedule(**params): + """A lightwight tuner for TOPI schedules. It is similar to AutoTVM but very lightweight. + It can be used as follows: + + ```python + @profile_schedule(num_thread=[8, 16, 32, 64], validator=should_tune_this_workload) + def _schedule_cuda(outs, **kwargs): + num_thread = kwargs.get("num_thread", 32) # Get tuned value or default. + ... + return sch + + @schedule_sum.register(["cuda", "gpu"]) + def schedule_cuda(attrs, outs, target): + with target: + return _schedule_cuda(outs) + ``` + + The above code snippet profiles 4 schedules with different num_thread and returns + the best one. Note that "validator" is a reserved keyword, which optionally specifies + a function to check whether the workload should be tuned or not. + + Since we directly use tvm.build to compile and evaluate the schedule + without heavy RFC mechanism and reuse the random data inputs, this is lightwieght + compared to AutoTVM and auto-schedule and can be used for JIT compilation. However, + develoeprs should control the tuning space to avoid long JIT time. It is recommended + to have <10 tuning space when using this function. + """ + validator = lambda _, **kwargs: True + if "validator" in params: + validator = params["validator"] + del params["validator"] + enable = os.environ.get("RAF_JIT_TUNE", True) + comm = dist.get_communicator() + local_rank = comm.local_rank + + def _wrapper(sch_func): + def _profile(outs, **kwargs): + # If not enabled, do not pass any tunable parameters so that the schedule + # function will use the built-in default values. + if not enable or not validator(outs, **kwargs): + return sch_func(outs, **kwargs) + + outs = [outs] if isinstance(outs, tvm.te.tensor.Tensor) else outs + + # Collect arguments. + args_set = set() + + def collect_args(tensor): + operator = tensor.op + if isinstance(operator, tvm.te.PlaceholderOp): + args_set.add(tensor) + else: + for inp_tensor in operator.input_tensors: + collect_args(inp_tensor) + + for out in outs: + collect_args(out) + + # Generate random input data for profiling. 
+ tvm_target = tvm.target.Target.current() + tvm_device = tvm.device(str(tvm_target), local_rank) + args = list(args_set) + args_data = [] + for arg in args: + shape = [s.value for s in arg.shape] + args_data.append( + tvm.nd.array(np.random.uniform(size=shape).astype(arg.dtype), tvm_device) + ) + + # Profiling + def profile_param(param_list, param_dict): + if param_list: + # One or more parameter values are not visited yet, iterating + # the values of the first undetermined parameter recursively. + best_sch_n_latency = (None, float("inf")) + key, vals = param_list[0] + for val in vals: + param_dict[key] = val + sch, latency = profile_param(param_list[1:], param_dict) + if best_sch_n_latency[0] is None or latency < best_sch_n_latency[1]: + best_sch_n_latency = (sch, latency) + assert best_sch_n_latency is not None + return best_sch_n_latency + + # All parameter values are determined and in param_dict, evaluate the schedule. + sch_kwargs = kwargs.copy() + sch_kwargs.update(param_dict) + sch = sch_func(outs, **sch_kwargs) + try: + func = tvm.build(sch, args, tvm_target) + # Run 5 times and take the median value to avoid outliers. + evaluator = func.time_evaluator(func.entry_name, tvm_device, number=5) + latency = evaluator(*args_data).median + except Exception: # pylint: disable=broad-except + latency = float("inf") + return sch, latency + + sch, _ = profile_param(list(params.items()), {}) + del args_data + return sch + + return _profile + + return _wrapper From 615342aa6888b9954012d416b2846a93be8c9ad3 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 3 Jun 2022 15:51:18 -0700 Subject: [PATCH 35/37] [Op][TVM] Miner improvements for TOPI tuner (#64) * test * a * fix * a * a * a --- python/raf/_tvm_op/reduce.py | 18 +++++++++++------- python/raf/_tvm_op/utils.py | 34 ++++++++++++++++++++++++---------- 2 files changed, 35 insertions(+), 17 deletions(-) diff --git a/python/raf/_tvm_op/reduce.py b/python/raf/_tvm_op/reduce.py index 5a4caadb..47e64b95 100644 --- a/python/raf/_tvm_op/reduce.py +++ b/python/raf/_tvm_op/reduce.py @@ -13,7 +13,7 @@ from .._lib import generic_func from .._lib import tvm as _tvm from .._lib import _reg -from .utils import profile_schedule +from .utils import get_cuda_max_thread, profile_schedule _topi = _tvm.topi # pylint: disable=invalid-name, no-member @@ -89,11 +89,14 @@ def _schedule_cuda_sum_long_reduce(op, sch, **kwargs): The computation schedule for the op. """ # pylint: disable=invalid-name - # Setup the tunable parameter value. - num_thread = kwargs.get("num_thread", 32) - thread_x = _tvm.te.thread_axis((0, num_thread), "threadIdx.x") + # Whether this workload is reducing all axes. data_out = op.output(0) + all_reduce = len(sch[data_out].op.axis) == 0 + + # Setup the tunable parameter value. + num_thread = kwargs.get("num_thread", get_cuda_max_thread() if all_reduce else 32) + thread_x = _tvm.te.thread_axis((0, num_thread), "threadIdx.x") # Fuse and rfactor the reduce axis fused_reduce = sch[data_out].fuse( @@ -105,7 +108,7 @@ def _schedule_cuda_sum_long_reduce(op, sch, **kwargs): sch[data_out].bind(tx, thread_x) sch[data_out_rf].compute_at(sch[data_out], tx) - if len(sch[data_out].op.axis) > 0: + if not all_reduce: # There are one or more axes to not reduced. Here we bind them to threads and blocks # for parallelism. 
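The fuse / split / rfactor / bind sequence in this schedule follows the standard TVM cross-thread reduction recipe. A minimal standalone `te` sketch of that recipe, independent of this pass (tensor names and sizes are ours):

```python
import tvm
from tvm import te

n = te.var("n")
A = te.placeholder((n, 4096), name="A")
k = te.reduce_axis((0, 4096), name="k")
B = te.compute((n,), lambda i: te.sum(A[i, k], axis=k), name="B")

sch = te.create_schedule(B.op)
num_thread = 32
# Split the reduction axis and rfactor the inner part so each thread owns a
# partial sum; the final combine is bound to threadIdx.x.
ko, ki = sch[B].split(B.op.reduce_axis[0], factor=num_thread)
BF = sch.rfactor(B, ki)
tx = te.thread_axis("threadIdx.x")
sch[B].bind(sch[B].op.reduce_axis[0], tx)
sch[BF].compute_at(sch[B], sch[B].op.reduce_axis[0])
sch[B].bind(B.op.axis[0], te.thread_axis("blockIdx.x"))
print(tvm.lower(sch, [A, B], simple_mode=True))
```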
block_x = _tvm.te.thread_axis("blockIdx.x") @@ -132,7 +135,8 @@ def _schedule_cuda_sum_long_reduce(op, sch, **kwargs): @profile_schedule( - num_thread=[16, 32, 64], validator=lambda _, reduce_last_axis: not reduce_last_axis + num_thread=[32, 64, get_cuda_max_thread()], + validator=lambda _, reduce_last_axis: not reduce_last_axis, ) def schedule_cuda_sum_long_reduce(outs, reduce_last_axis, **kwargs): """Schedule sum for CUDA. This schedule targets to the sum with long reduction length. @@ -230,7 +234,7 @@ def traverse_after_reduce(tensor): return sch -@profile_schedule(num_thread=[16, 32, 64], max_block=[128, 256, 512]) +@profile_schedule(num_thread=[32, 64, get_cuda_max_thread()], max_block=[256, 512]) def schedule_cuda_short_reduce(outs, **kwargs): """Schedule sum for CUDA. This schedule targets to the sum with short reduction length. In this case, each thread is responsible for reduction. The parallelization is across diff --git a/python/raf/_tvm_op/utils.py b/python/raf/_tvm_op/utils.py index 0e7813a5..074fbc46 100644 --- a/python/raf/_tvm_op/utils.py +++ b/python/raf/_tvm_op/utils.py @@ -4,6 +4,7 @@ """Utilities for processing TVM ops.""" # pylint: disable=protected-access import os +from threading import Thread import numpy as np @@ -51,6 +52,11 @@ def load_module(path): return tvm.runtime.module.load_module(path) +def get_cuda_max_thread(): + """A helper function to obtain the maximum number of threads per block.""" + return tvm.target.Target("cuda").max_num_threads + + def profile_schedule(**params): """A lightwight tuner for TOPI schedules. It is similar to AutoTVM but very lightweight. It can be used as follows: @@ -82,7 +88,7 @@ def schedule_cuda(attrs, outs, target): if "validator" in params: validator = params["validator"] del params["validator"] - enable = os.environ.get("RAF_JIT_TUNE", True) + enable = bool(int(os.environ.get("RAF_JIT_TUNE", 1))) comm = dist.get_communicator() local_rank = comm.local_rank @@ -110,7 +116,7 @@ def collect_args(tensor): collect_args(out) # Generate random input data for profiling. - tvm_target = tvm.target.Target.current() + tvm_target = tvm.target.Target.current(allow_none=False) tvm_device = tvm.device(str(tvm_target), local_rank) args = list(args_set) args_data = [] @@ -120,6 +126,17 @@ def collect_args(tensor): tvm.nd.array(np.random.uniform(size=shape).astype(arg.dtype), tvm_device) ) + def _build_n_profile(sch, args, tvm_target, ret): + """Build and profile a schedule. This is supposed to be used in an isolated env.""" + try: + func = tvm.build(sch, args, tvm_target) + # Run 5 times and take the median value to avoid outliers. + evaluator = func.time_evaluator(func.entry_name, tvm_device, number=5) + ret["latency"] = evaluator(*args_data).median + except Exception as err: # pylint: disable=broad-except + ret["latency"] = float("inf") + ret["error"] = str(err) + # Profiling def profile_param(param_list, param_dict): if param_list: @@ -132,20 +149,17 @@ def profile_param(param_list, param_dict): sch, latency = profile_param(param_list[1:], param_dict) if best_sch_n_latency[0] is None or latency < best_sch_n_latency[1]: best_sch_n_latency = (sch, latency) - assert best_sch_n_latency is not None return best_sch_n_latency # All parameter values are determined and in param_dict, evaluate the schedule. sch_kwargs = kwargs.copy() sch_kwargs.update(param_dict) sch = sch_func(outs, **sch_kwargs) - try: - func = tvm.build(sch, args, tvm_target) - # Run 5 times and take the median value to avoid outliers. 
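A small but easy-to-miss fix in this patch is the `RAF_JIT_TUNE` parsing: `os.environ.get` returns strings, so with the old default of `True` the flag could never actually be turned off. A quick illustration of the difference:

```python
import os

os.environ["RAF_JIT_TUNE"] = "0"

# Old behavior: any non-empty string is truthy, so "0" still enables tuning.
print(bool(os.environ.get("RAF_JIT_TUNE", True)))    # True (unintended)

# New behavior: parse the value as an integer first.
print(bool(int(os.environ.get("RAF_JIT_TUNE", 1))))  # False (intended)
```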
- evaluator = func.time_evaluator(func.entry_name, tvm_device, number=5) - latency = evaluator(*args_data).median - except Exception: # pylint: disable=broad-except - latency = float("inf") + ret = {} + thd = Thread(target=_build_n_profile, args=(sch, args, tvm_target, ret)) + thd.start() + thd.join() + latency = ret["latency"] return sch, latency sch, _ = profile_param(list(params.items()), {}) From 25372d7a9dc662f1519d73dda7f1df6251441a1b Mon Sep 17 00:00:00 2001 From: StrongSpoon <35829812+StrongSpoon@users.noreply.github.com> Date: Mon, 6 Jun 2022 18:49:21 +0800 Subject: [PATCH 36/37] [Op] [TVM] update reduce_scatter (#63) * update reduce_scatter: add new argument rank_list * remove useless comment and write a specific test case for reduce_scatter with rank_list * Update code format in function test_reduce_scatter_with_rank_list * amend an unused argument * write test cases for more computation * fixed undeclared args --- python/raf/distributed/op.py | 12 ++++- scripts/src_codegen/def_schema.py | 1 + src/op/dialect/nccl/nccl.cc | 2 +- .../test_collective_communication.py | 51 +++++++++++++++++++ 4 files changed, 63 insertions(+), 3 deletions(-) diff --git a/python/raf/distributed/op.py b/python/raf/distributed/op.py index cf16a7d5..92a5cb56 100644 --- a/python/raf/distributed/op.py +++ b/python/raf/distributed/op.py @@ -126,7 +126,7 @@ def reduce(x, root, computation="sum"): return sym._reduce(x, root, computation) -def reduce_scatter(x, computation="sum"): +def reduce_scatter(x, computation="sum", rank_list=None): """Performs reduction then scatter Parameters @@ -136,6 +136,14 @@ def reduce_scatter(x, computation="sum"): replica i receives reduction of x[i] over all replicas computation: string The reduction operation, default is sum + rank_list: [[int]] + The list of ranks to communicate. This parameter will split the ranks + (MPI / NCCL processes) into multiple groups as specified by the user, + and each rank will only communicate within the group. If the rank list + leaves empty, the ranks won't get split. Note that this operator is + collective, which means ranks, whether they are in the rank_list or not, + must invoke this along with other ranks. The rank not in the rank_list + will run in standalone mode. 
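To make the grouping semantics concrete, the following pure-NumPy sketch (not the raf API, just the arithmetic) shows what each rank receives when four ranks are split into `rank_list=[[0, 1], [2, 3]]` and each rank contributes a two-element tuple:

```python
import numpy as np

rank_list = [[0, 1], [2, 3]]
world = 4
# Rank r contributes [x, y] with x = r + 1 and y = -(r + 1), matching the test below.
inputs = {r: [np.full((2,), r + 1), np.full((2,), -(r + 1))] for r in range(world)}

outputs = {}
for group in rank_list:
    for slot, dst in enumerate(group):
        # The slot-th member of each group receives the reduction over that group only.
        outputs[dst] = sum(inputs[src][slot] for src in group)

print(outputs[0])  # [3 3]    (1 + 2, reduced within group [0, 1])
print(outputs[3])  # [-7 -7]  (-3 + -4, reduced within group [2, 3])
```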
Returns ------- @@ -143,7 +151,7 @@ def reduce_scatter(x, computation="sum"): reduction result of x[rank] over all replicas, where rank represents rank number of the current process """ - return sym._reduce_scatter(x, computation) + return sym._reduce_scatter(x, computation, rank_list=rank_list) def group_reduce_scatter(tensor_list, computation="sum"): diff --git a/scripts/src_codegen/def_schema.py b/scripts/src_codegen/def_schema.py index acdd5ec6..ccbd687f 100644 --- a/scripts/src_codegen/def_schema.py +++ b/scripts/src_codegen/def_schema.py @@ -725,6 +725,7 @@ "communication.h::reduce_scatter": [ Arg(name="x", cxx_type="std::vector", cxx_normalizer="TensorTuple"), Arg(name="computation", cxx_type="std::string", cxx_default='"sum"', py_default='"sum"'), + Arg(name="rank_list", cxx_type="value::Value", cxx_default="nullptr"), ], "communication.h::group_reduce_scatter": [ Arg( diff --git a/src/op/dialect/nccl/nccl.cc b/src/op/dialect/nccl/nccl.cc index ad589656..fd779eae 100644 --- a/src/op/dialect/nccl/nccl.cc +++ b/src/op/dialect/nccl/nccl.cc @@ -261,8 +261,8 @@ class NCCLReduceScatter : public NCCLOpEnv { auto fschema_index = ir::Op::GetAttrMap("FRAFSchemaFieldIndex"); this->arg_indices = {fschema_index[op]("x")}; RequestStream(&stream, cv->device, StreamTagEnum::CudaCommunicate()); - RequestDistributed(&communicator, "nccl", NullValue()); auto args = cv->args.as(); + RequestDistributed(&communicator, "nccl", args->rank_list); if (args->computation.compare("sum") == 0) { compute = ncclSum; } else if (args->computation.compare("prod") == 0) { diff --git a/tests/python/distributed/test_collective_communication.py b/tests/python/distributed/test_collective_communication.py index 78561a5a..4fd2936e 100644 --- a/tests/python/distributed/test_collective_communication.py +++ b/tests/python/distributed/test_collective_communication.py @@ -335,6 +335,57 @@ def forward(self, x, y): check(m_out, n_out) +@pytest.mark.skipif(skip_dist_test(min_rank_num=4, require_exact_rank=True), reason=SKIP_REASON) +@pytest.mark.parametrize("computation", ["sum", "prod", "min", "max"]) +@pytest.mark.parametrize("rank_list", [[[0, 1], [2, 3]]]) +def test_reduce_scatter_with_rank_list(computation, rank_list): + class TestModel(raf.Model): + def build(self): + pass + + @raf.model.trace + def forward(self, x, y): + z = Symbol.make_tuple([x, y]) + out = raf.reduce_scatter(z, computation=computation, rank_list=rank_list) + return out + + if computation == "avg" and raf.build.with_nccl() < 21000: + pytest.skip("avg is not supported in NCCL < 2.10") + + model = TestModel() + total_rank, rank, local_rank = get_dist_comm_info(verbose=True) + device = f"cuda({local_rank})" + n_ones = np.ones(shape=(4, 4), dtype="float32") + n_x = n_ones * (rank + 1) + n_y = -n_ones * (rank + 1) + m_x, m_y = raf.array(n_x, device=device), raf.array(n_y, device=device) + model.to(device=device) + m_out = run_model(model, [m_x, m_y], device) + for group in rank_list: + if rank in group: + ones = np.ones(shape=(4, 4), dtype="float32") + if computation == "sum": + even_out = ones * sum(np.array(group) + 1) + odd_out = -even_out + elif computation == "prod": + even_out = ones * np.prod([(temp_rank + 1) for temp_rank in np.array(group)]) + odd_out = even_out + elif computation == "min": + even_out = ones * min(np.array(group) + 1) + odd_out = -ones * max(np.array(group) + 1) + elif computation == "max": + even_out = ones * max(np.array(group) + 1) + odd_out = -ones * min(np.array(group) + 1) + elif computation == "avg": + even_out = ones * 
sum(np.array(group) + 1) + even_out = even_out / total_rank + odd_out = -even_out + if rank % 2 == 0: + check(m_out, even_out) + else: + check(m_out, odd_out) + + @pytest.mark.skipif(skip_dist_test(min_rank_num=2, require_exact_rank=True), reason=SKIP_REASON) def test_send_recv(): shape = [2, 2] From c8ddbc93f8382df62315aa59720275c1f8054bc1 Mon Sep 17 00:00:00 2001 From: AIREMetaBot <100344401+aire-meta-bot@users.noreply.github.com> Date: Tue, 7 Jun 2022 05:30:11 +0800 Subject: [PATCH 37/37] [TVM] Update Submodule 2022-06-06-13-15-38 (#65) * [TVM] Update Submodule * fix alignment * checker * fix test * fix test * fix Co-authored-by: SubmoduleUpdaterBot Co-authored-by: Cody Yu --- 3rdparty/tvm | 2 +- include/raf/device.h | 3 ++- src/device_api/cuda/cuda.cc | 5 ----- src/impl/memory_pool.cc | 9 +++++++++ src/pass/manifest_alloc.cc | 4 ++-- tests/cpp/test_memory_pool.cc | 6 +++--- tests/python/pass/test_pass_manifest_alloc.py | 12 ++++++------ 7 files changed, 23 insertions(+), 18 deletions(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index c6415d14..609d6af1 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit c6415d14928d1e09f4bd3105c7a5ddf87f92166b +Subproject commit 609d6af17605d657909549e908876f4335206bd6 diff --git a/include/raf/device.h b/include/raf/device.h index 6edc881c..e7946434 100644 --- a/include/raf/device.h +++ b/include/raf/device.h @@ -11,6 +11,7 @@ #include #include "dlpack/dlpack.h" #include "tvm/runtime/c_runtime_api.h" +#include "tvm/runtime/device_api.h" #include "tvm/runtime/ndarray.h" #include "tvm/support/with.h" #include "./enum_base.h" @@ -20,7 +21,7 @@ namespace raf { using namespace raf::ir; -constexpr int64_t kDefaultMemoryAlignment = 64; +constexpr int64_t kDefaultMemoryAlignment = tvm::runtime::kAllocAlignment; class DTypeCode final : public EnumBase { public: diff --git a/src/device_api/cuda/cuda.cc b/src/device_api/cuda/cuda.cc index b727b65a..74f5fec9 100644 --- a/src/device_api/cuda/cuda.cc +++ b/src/device_api/cuda/cuda.cc @@ -37,8 +37,6 @@ class CUDADeviceAPI final : public DeviceAPI { void* AllocMemory(int64_t nbytes, int64_t alignment) override { CUDA_CALL(cudaSetDevice(device_id_)); void* ptr = nullptr; - // TODO(@junrushao1994): make sure it is correct - CHECK_EQ(512 % alignment, 0); CUDA_CALL(cudaMalloc(&ptr, nbytes)); return ptr; } @@ -78,9 +76,6 @@ class CUDADeviceAPI final : public DeviceAPI { static auto cuda_pool = GetCUDAMemoryPool(device_id_); void* ptr = nullptr; - // TODO(@junrushao1994): make sure it is correct - CHECK_EQ(512 % alignment, 0); - try { CUDA_CALL( cudaMallocFromPoolAsync(&ptr, nbytes, cuda_pool, static_cast(stream))); diff --git a/src/impl/memory_pool.cc b/src/impl/memory_pool.cc index 1a99ea33..2d9cf199 100644 --- a/src/impl/memory_pool.cc +++ b/src/impl/memory_pool.cc @@ -75,6 +75,12 @@ class MemoryPoolManager { PerDeviceStore reg; }; +inline void CheckAlignment(int64_t alignment) { + CHECK_EQ(alignment % kDefaultMemoryAlignment, 0U) + << "Requested memory with alignment " << alignment << " is not aligned to " + << kDefaultMemoryAlignment; +} + int64_t Memory::GetAllocBytes(const Device& dev, int64_t nbytes) { MemoryPoolManager* mgr = MemoryPoolManager::Get(); return mgr->GetPool(dev, "")->GetAllocBytes(nbytes); @@ -82,12 +88,14 @@ int64_t Memory::GetAllocBytes(const Device& dev, int64_t nbytes) { std::shared_ptr Memory::Alloc(const Device& dev, int64_t nbytes, int64_t alignment) { MemoryPoolManager* mgr = MemoryPoolManager::Get(); + CheckAlignment(alignment); return mgr->GetPool(dev, 
"")->Alloc(nbytes, alignment); } std::shared_ptr Memory::AllocAsync(const Device& dev, int64_t nbytes, void* stream, int64_t alignment) { MemoryPoolManager* mgr = MemoryPoolManager::Get(); + CheckAlignment(alignment); return mgr->GetPool(dev, "")->AllocAsync(nbytes, stream, alignment); } @@ -95,6 +103,7 @@ std::vector > Memory::AllocBatch(const Device& dev, const std::vector& nbytes, int64_t alignment) { MemoryPoolManager* mgr = MemoryPoolManager::Get(); + CheckAlignment(alignment); return mgr->GetPool(dev, "")->AllocBatch(nbytes, alignment); } diff --git a/src/pass/manifest_alloc.cc b/src/pass/manifest_alloc.cc index 7160ca2e..9c210816 100644 --- a/src/pass/manifest_alloc.cc +++ b/src/pass/manifest_alloc.cc @@ -217,8 +217,8 @@ class ManifestAllocMutator : public ExprMutator { private: Expr ComputeAlignment(DataType dtype) { int64_t align = dtype.bits() / 8 * dtype.lanes(); - if (align < 64) { - align = 64; + if (align < kDefaultMemoryAlignment) { + align = kDefaultMemoryAlignment; } return MakeConstant(ScalarValue::make(align)); } diff --git a/tests/cpp/test_memory_pool.cc b/tests/cpp/test_memory_pool.cc index cf2eda84..13be9bf7 100644 --- a/tests/cpp/test_memory_pool.cc +++ b/tests/cpp/test_memory_pool.cc @@ -23,7 +23,7 @@ TEST(NoPool, CPU) { ASSERT_EQ(result->data, nullptr); } for (int memory : {11, 19, 2019, 1024124}) { - for (int align : {16, (int)kDefaultMemoryAlignment, 512, 1024, 4096}) { + for (int align : {(int)kDefaultMemoryAlignment, 512, 1024, 4096}) { std::shared_ptr result = Memory::Alloc(dev, memory, align); ASSERT_EQ(result.use_count(), 1); int64_t address = (int64_t)result->data; @@ -42,7 +42,7 @@ TEST(PageUnitPool, CPU) { ASSERT_EQ(result->data, nullptr); } for (int memory : {11, 19, 2019, 1024124}) { - for (int align : {16, (int)kDefaultMemoryAlignment, 512, 1024, 4096}) { + for (int align : {(int)kDefaultMemoryAlignment, 512, 1024, 4096}) { std::shared_ptr result = Memory::Alloc(dev, memory, align); ASSERT_EQ(result.use_count(), 2); int64_t address = (int64_t)result->data; @@ -52,7 +52,7 @@ TEST(PageUnitPool, CPU) { auto pool_size = Memory::GetPoolSize(dev); ASSERT_EQ(pool_size.first, 0); // No chunk is used. - std::shared_ptr result = Memory::Alloc(dev, 4096, 64); + std::shared_ptr result = Memory::Alloc(dev, 4096, kDefaultMemoryAlignment); pool_size = Memory::GetPoolSize(dev); auto used_size = pool_size.first * 1048576.0; auto abs_diff = (used_size > 4096) ? 
used_size - 4096 : 4096 - used_size; diff --git a/tests/python/pass/test_pass_manifest_alloc.py b/tests/python/pass/test_pass_manifest_alloc.py index 6a4611ea..39bf5e31 100644 --- a/tests/python/pass/test_pass_manifest_alloc.py +++ b/tests/python/pass/test_pass_manifest_alloc.py @@ -69,9 +69,9 @@ def forward(self, x): "\n".join(text.splitlines()[:-4]) == """#[version = "0.0.5"] fn (%x: Tensor[(2, 2), float32]) -> Tensor[(meta[tir.Div][0], 2), int32] { - let %x_0 = raf.op.vm.alloc_storage(int64(32), int64(64), int32(1), int32(0), str"int32"); + let %x_0 = raf.op.vm.alloc_storage(int64(32), int64(128), int32(1), int32(0), str"int32"); let %x_1 = raf.op.vm.alloc_tensor(%x_0, [4, 2], str"int32", [4, 2]); - let %x_2 = raf.op.vm.alloc_storage(int64(16), int64(64), int32(1), int32(0), str"int64"); + let %x_2 = raf.op.vm.alloc_storage(int64(16), int64(128), int32(1), int32(0), str"int64"); let %x_3 = raf.op.vm.alloc_tensor(%x_2, [2], str"int64", [2]); let %x_4 = raf.op.upper_bound.argwhere; let %x_5 = (%x,); @@ -84,12 +84,12 @@ def forward(self, x): let %x_11 = %x_10.1; let %x_12 = %x_11.0; let %x_13 = %x_11.1; - let %x_14 = raf.op.vm.alloc_storage(%x_13, int64(64), int32(1), int32(0), str"int32"); + let %x_14 = raf.op.vm.alloc_storage(%x_13, int64(128), int32(1), int32(0), str"int32"); let %x_15 = raf.op.vm.alloc_tensor(%x_14, %x_12, str"int32", %x_12); let %x_16 = %x_10.2; let %x_17 = %x_16.0; let %x_18 = %x_16.1; - let %x_19 = raf.op.vm.alloc_storage(%x_18, int64(64), int32(1), int32(0), str"int64"); + let %x_19 = raf.op.vm.alloc_storage(%x_18, int64(128), int32(1), int32(0), str"int64"); let %x_20 = raf.op.vm.alloc_tensor(%x_19, %x_17, str"int64", %x_17); let %x_21 = (%x_15, %x_20); let %x_22 = raf.op.vm.invoke_op(%x_9, %x_8, %x_21); @@ -108,7 +108,7 @@ def forward(self, x): let %x_28 = %x_27.1; let %x_29 = %x_28.0; let %x_30 = %x_28.1; - let %x_31 = raf.op.vm.alloc_storage(%x_30, int64(64), int32(1), int32(0), str"int32"); + let %x_31 = raf.op.vm.alloc_storage(%x_30, int64(128), int32(1), int32(0), str"int32"); let %x_32 = raf.op.vm.alloc_tensor(%x_31, %x_29, str"int32", %x_29); let %x_33 = %x_27.0; let %x_34 = (%x_32,); @@ -171,7 +171,7 @@ def forward(self): # ManifestAlloc should allocate the tensor on the specific device # for init ops and memory related ops. In this case, we expect int32(2) and int32(1), # which mean cuda(1). - assert 'raf.op.vm.alloc_storage(int64(100), int64(64), int32(2), int32(1), str"int32")' in text + assert 'raf.op.vm.alloc_storage(int64(100), int64(128), int32(2), int32(1), str"int32")' in text if __name__ == "__main__":
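Taken together, the alignment changes in this patch reduce to a simple rule, sketched below in Python for reference; the 128-byte constant reflects the updated test expectations (kDefaultMemoryAlignment now tracks tvm::runtime::kAllocAlignment), and the helper names are ours.

```python
K_DEFAULT_MEMORY_ALIGNMENT = 128  # assumed value of tvm::runtime::kAllocAlignment

def compute_alignment(dtype_bits, lanes=1):
    """Per-tensor alignment: the element size in bytes, but never below the default."""
    return max(dtype_bits // 8 * lanes, K_DEFAULT_MEMORY_ALIGNMENT)

def check_alignment(alignment):
    """Mirror of the new CheckAlignment(): requests must be multiples of the default."""
    assert alignment % K_DEFAULT_MEMORY_ALIGNMENT == 0, (
        f"Requested memory with alignment {alignment} is not aligned to "
        f"{K_DEFAULT_MEMORY_ALIGNMENT}"
    )

print(compute_alignment(32))  # float32 -> 128
check_alignment(512)          # ok: 512 is a multiple of 128
check_alignment(4096)         # ok
```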