diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index aa341949b3d6..8077bbff14c0 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -142,8 +142,20 @@ class Tuple : public Expr { TVM_DLL explicit Tuple(tvm::Array fields, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Tuple, RelayExpr, TupleNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode); }; +/*! + * \brief Returns the tuple with given properties. A null property denotes 'no change'. + * Returns this if all properties are unchanged. Otherwise, returns a copy with the new fields. + * \param tuple The tuple to copy + * \param opt_fields The (optional) fields for the copied tuple. If none, ret_tuple->fields = + * tuple->fields. + * \param opt_span The (optional) span for the copied tuple. If none, ret_tuple->span = tuple->span. + */ +Tuple WithFields(Tuple tuple, Optional> opt_fields = Optional>(), + Optional opt_span = Optional(nullptr)); + /*! * \brief Local variables used in the let expression. * diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index c7a81f9f0f03..59e8c9ee9d0c 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -76,6 +76,27 @@ TVM_REGISTER_NODE_TYPE(TupleNode); TVM_REGISTER_GLOBAL("relay.ir.Tuple").set_body_typed([](tvm::Array fields, Span span) { return Tuple(fields, span); }); +Tuple WithFields(Tuple tuple, Optional> opt_fields, Optional opt_span) { + Array fields = opt_fields.value_or(tuple->fields); + Span span = opt_span.value_or(tuple->span); + + bool all_fields_unchanged = true; + if (fields.size() == tuple->fields.size()) { + for (size_t i = 0; i < fields.size(); i++) { + all_fields_unchanged &= fields[i].same_as(tuple->fields[i]); + } + } else { + all_fields_unchanged = false; + } + + all_fields_unchanged = all_fields_unchanged && span.same_as(tuple->span); + if (!all_fields_unchanged) { + TupleNode* cow_tuple_node = tuple.CopyOnWrite(); + cow_tuple_node->fields = fields; + cow_tuple_node->span = span; + } + return std::move(tuple); +} TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index e9441f1b3e58..08c9b9643caf 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -177,20 +177,15 @@ Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) { return GetRef(op); Expr ExprMutator::VisitExpr_(const OpNode* op) { return GetRef(op); } -Expr ExprMutator::VisitExpr_(const TupleNode* op) { +Expr ExprMutator::VisitExpr_(const TupleNode* tuple_node) { tvm::Array fields; - bool all_fields_unchanged = true; - for (auto field : op->fields) { + fields.reserve(tuple_node->fields.size()); + + for (auto field : tuple_node->fields) { auto new_field = this->Mutate(field); fields.push_back(new_field); - all_fields_unchanged &= new_field.same_as(field); - } - - if (all_fields_unchanged) { - return GetRef(op); - } else { - return Tuple(fields, op->span); } + return WithFields(GetRef(tuple_node), std::move(fields)); } Expr ExprMutator::VisitExpr_(const FunctionNode* op) { diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index b12e25a425b6..df1a858f8d0b 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -266,11 +266,11 @@ class AnnotateTargetRewriter : public ExprRewriter { virtual std::unique_ptr RewriteVarCall(const Call& post_call) { return nullptr; } - Expr Rewrite_(const TupleNode* op, const Expr& post) override { - auto expr = Downcast(post); + Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) override { + auto tuple = Downcast(post); - auto target_n_args = AnnotateArgs(expr->fields); - auto new_expr = Tuple(std::get<1>(target_n_args)); + auto target_n_args = AnnotateArgs(tuple->fields); + auto new_expr = WithFields(std::move(tuple), std::move(std::get<1>(target_n_args))); op_expr_to_target_[new_expr] = std::get<0>(target_n_args); return std::move(new_expr); } @@ -370,13 +370,15 @@ class CallOpsTargetRewriter : public AnnotateTargetRewriter { return new_call; } - Expr Rewrite_(const TupleNode* op, const Expr& post) override { - auto expr = Downcast(post); + Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) override { + auto tuple = Downcast(post); Array new_fields; - for (auto f : expr->fields) { + new_fields.reserve(tuple->fields.size()); + + for (auto f : tuple->fields) { new_fields.push_back(InsertCompilerEndAndPropogateTarget(f)); } - return std::move(Tuple(new_fields)); + return WithFields(std::move(tuple), std::move(new_fields)); } Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) override { diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc index d6ab566a336e..afa598b9a782 100644 --- a/src/relay/transforms/device_planner.cc +++ b/src/relay/transforms/device_planner.cc @@ -786,8 +786,7 @@ class DeviceCapturer : public ExprMutator { for (const auto& field : tuple_node->fields) { fields.push_back(VisitChild(tuple, field)); } - // TODO(mbs): Avoid copy - return Tuple(std::move(fields), tuple_node->span); + return WithFields(std::move(tuple), std::move(fields)); } Expr VisitExpr_(const FunctionNode* function_node) final { diff --git a/src/relay/transforms/first_order_gradient.cc b/src/relay/transforms/first_order_gradient.cc index 3419cb670a28..9408d16d87e9 100644 --- a/src/relay/transforms/first_order_gradient.cc +++ b/src/relay/transforms/first_order_gradient.cc @@ -195,11 +195,13 @@ struct FirstOrderReverseAD : ExprFunctor { return ret; } - ADValue VisitExpr_(const TupleNode* op) final { - auto tt = Downcast(op->checked_type()); + ADValue VisitExpr_(const TupleNode* tuple_node) final { + auto tt = Downcast(tuple_node->checked_type()); std::vector ad_fields; - std::vector field_bindings; - for (const auto& f : op->fields) { + Array field_bindings; + field_bindings.reserve(tuple_node->fields.size()); + + for (const auto& f : tuple_node->fields) { ADValue f_ad = VisitExpr(f); if (!dynamic_cast(f_ad.get())) { diag_ctx.EmitFatal(Diagnostic::Error(f->span) @@ -209,7 +211,7 @@ struct FirstOrderReverseAD : ExprFunctor { field_bindings.push_back(f_ad->get().forward); } // reconstruct tuple using let-bound variables to avoid duplication - auto orig = Tuple(field_bindings); + auto orig = WithFields(GetRef(tuple_node), std::move(field_bindings)); orig->checked_type_ = tt; auto ret = std::make_shared(ll, orig, diag_ctx); // for orig = tuple(x1, ..., xn), tuple_grad(x1, ..., xn, G) = [pi(G, 1), ..., pi(G, n)] diff --git a/src/relay/transforms/forward_rewrite.cc b/src/relay/transforms/forward_rewrite.cc index 1212ad7f19be..23c45a90a5e3 100644 --- a/src/relay/transforms/forward_rewrite.cc +++ b/src/relay/transforms/forward_rewrite.cc @@ -113,21 +113,16 @@ class ForwardRewriter : private MixedModeMutator { } } - Expr Rewrite_(const TupleNode* op, const Expr& post) final { + Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) final { tvm::Array fields; - bool all_fields_unchanged = true; - const auto* post_node = post.as(); - for (size_t i = 0; i < op->fields.size(); ++i) { - auto new_field = this->GetTempExpr(op->fields[i], post_node->fields[i]); - fields.push_back(new_field); - all_fields_unchanged &= new_field.same_as(op->fields[i]); - } + fields.reserve(tuple_node->fields.size()); - if (all_fields_unchanged) { - return GetRef(op); - } else { - return Tuple(fields); + const auto* post_tuple_node = post.as(); + for (size_t i = 0; i < tuple_node->fields.size(); ++i) { + fields.push_back(this->GetTempExpr(tuple_node->fields[i], post_tuple_node->fields[i])); } + + return WithFields(GetRef(tuple_node), std::move(fields)); } Expr Rewrite_(const CallNode* call_node, const Expr& post) final { diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index 960f56957ebb..247ae33c178a 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -898,14 +898,14 @@ class FuseMutator : private MixedModeMutator { } } - Expr Rewrite_(const TupleNode* tuple, const Expr& post) { - auto* ret_group = gmap_.at(tuple)->FindRoot(); - if (ret_group->root_ref == tuple) { - return ExprMutator::VisitExpr_(tuple); + Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) { + auto* ret_group = gmap_.at(tuple_node)->FindRoot(); + if (ret_group->root_ref == tuple_node) { + return ExprMutator::VisitExpr_(tuple_node); } // This tuple is an intermediate node in the group - Array new_fields = GetNewArguments(tuple->fields, ret_group); - return Tuple(new_fields); + Array new_fields = GetNewArguments(tuple_node->fields, ret_group); + return WithFields(GetRef(tuple_node), std::move(new_fields)); } Expr Rewrite_(const TupleGetItemNode* tuple_get, const Expr& post) { diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc index acea12fb8560..ddbc6069ebe6 100644 --- a/src/relay/transforms/memory_alloc.cc +++ b/src/relay/transforms/memory_alloc.cc @@ -84,10 +84,12 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { Function Rewrite(const Function& expr) { return Downcast(Mutate(expr)); } private: - Expr VisitExpr_(const TupleNode* tn) final { + Expr VisitExpr_(const TupleNode* tuple_node) final { LetList& scope = scopes_.back(); Array new_fields; - for (auto field : tn->fields) { + new_fields.reserve(tuple_node->fields.size()); + + for (auto field : tuple_node->fields) { auto new_field = Mutate(field); if (new_field->IsInstance()) { Var const_var("const", Type(nullptr)); @@ -95,7 +97,7 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { } new_fields.push_back(new_field); } - return Tuple(new_fields); + return WithFields(GetRef(tuple_node), std::move(new_fields)); } void PreVisitLetBlock_(const LetNode* let_node) final { scopes_.emplace_back(); } diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc index ccdd9c92cc27..7388d9f7eb32 100644 --- a/src/relay/transforms/partial_eval.cc +++ b/src/relay/transforms/partial_eval.cc @@ -615,6 +615,8 @@ class PartialEvaluator : public ExprFunctor value.push_back(ps); expr.push_back(ps->dynamic); } + // Note(@electriclilies): The partial evaluator seems to do some weird stuff with sharing. + // Changing Tuple(expr) to WithFields(op, expr) causes some strange failures. return HasStatic(MkSTuple(value), ll->Push(Tuple(expr))); } diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index 99799fdeb5f0..4a21bc87411b 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -455,18 +455,18 @@ IRModule FlattenTupleOutputs(IRModule module) { // Arguments of annotation ops should be 1 ICHECK_EQ(call->args.size(), 1U); auto annotated_op = Downcast(post)->args[0]; - if (const auto* tn = annotated_op.as()) { + if (const auto* tuple_node = annotated_op.as()) { Array new_fields; + new_fields.reserve(tuple_node->fields.size()); // Here each input of the tuple will be annotated with compiler_ends - for (auto& tn_arg : tn->fields) { + for (auto& tn_arg : tuple_node->fields) { new_fields.push_back((*make_end_op)(tn_arg, target)); } // Return a tuple of compiler_ends in the place of the tuple that was // annotated with a compiler_end. - auto out = Tuple(new_fields); - return std::move(out); + return WithFields(GetRef(tuple_node), std::move(new_fields)); } } return post; diff --git a/src/relay/transforms/split_args.cc b/src/relay/transforms/split_args.cc index eb647ce5e2a5..fbb2d73d1db0 100644 --- a/src/relay/transforms/split_args.cc +++ b/src/relay/transforms/split_args.cc @@ -37,14 +37,14 @@ class ArgumentSplitter : public ExprRewriter { Expr Rewrite_(const CallNode* call, const Expr& post) final { if (max_function_args_ < 0) return post; if (call->op == concat_op_) { - auto op = call->args[0].as(); + auto tuple_node = call->args[0].as(); const auto param = call->attrs.as(); int outputsNum = 1; if (const auto* tuple_type = call->checked_type().as()) { outputsNum = tuple_type->fields.size(); } const int limit = max_function_args_ - outputsNum; - int argsNum = op->fields.size(); + int argsNum = tuple_node->fields.size(); if (argsNum < limit) return post; int splitNum = argsNum / limit; splitNum = (argsNum % limit) ? splitNum + 1 : splitNum; @@ -54,16 +54,18 @@ class ArgumentSplitter : public ExprRewriter { int startIdx = i * limit; int argsCount = std::min(limit, argsNum - startIdx); tvm::Array args; + args.reserve(argsCount); + for (int j = 0; j < argsCount; ++j) { - args.push_back(op->fields[j + startIdx]); + args.push_back(tuple_node->fields[j + startIdx]); } - Tuple tuple(args); - Expr body = MakeConcatenate(tuple, param->axis); + Tuple new_tuple = WithFields(GetRef(tuple_node), std::move(args)); + Expr body = MakeConcatenate(new_tuple, param->axis); splitted[i] = StopFusion(body); } - tvm::Array tupleArgs(splitted); - Tuple tuple(tupleArgs); - return MakeConcatenate(tuple, param->axis); + tvm::Array tuple_args(splitted); + Tuple new_tuple = WithFields(GetRef(tuple_node), std::move(tuple_args)); + return MakeConcatenate(new_tuple, param->axis); } return post; } diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index 0814e73ab73d..f958a600551e 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -248,13 +248,14 @@ class Fill : ExprFunctor, private transform::Lexi return Compound(e, Call(VisitExpr(c->op), args, c->attrs, c->type_args), v); } - Expr VisitExpr_(const TupleNode* t, const Var& v) final { - Expr e = GetRef(t); - std::vector fields; - for (const auto& a : t->fields) { + Expr VisitExpr_(const TupleNode* tuple_node, const Var& v) final { + Expr e = GetRef(tuple_node); + Array fields; + fields.reserve(tuple_node->fields.size()); + for (const auto& a : tuple_node->fields) { fields.push_back(VisitExpr(a)); } - return Compound(e, Tuple(fields), v); + return Compound(e, WithFields(GetRef(tuple_node), std::move(fields)), v); } Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final { diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc index b7f9cafbc7dc..0f889cd6ff7f 100644 --- a/src/relay/transforms/to_cps.cc +++ b/src/relay/transforms/to_cps.cc @@ -210,13 +210,14 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm, VarMap* vm, }); } - Expr VisitExpr_(const TupleNode* op, const MCont& k) final { + Expr VisitExpr_(const TupleNode* tuple_node, const MCont& k) final { tvm::Array fields; + fields.reserve(tuple_node->fields.size()); std::function next; next = [&]() { - return (fields.size() == op->fields.size()) - ? k(Tuple(fields)) - : VisitExpr(op->fields[fields.size()], [&](const Expr& v) { + return (fields.size() == tuple_node->fields.size()) + ? k(WithFields(GetRef(tuple_node), std::move(fields))) + : VisitExpr(tuple_node->fields[fields.size()], [&](const Expr& v) { fields.push_back(v); return next(); }); diff --git a/src/relay/transforms/transform_layout.h b/src/relay/transforms/transform_layout.h index 7bfb31a299ad..56affb581fd1 100644 --- a/src/relay/transforms/transform_layout.h +++ b/src/relay/transforms/transform_layout.h @@ -32,6 +32,7 @@ #include #include #include +#include #include #include "infer_layout_utils.h" @@ -293,12 +294,13 @@ Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const Obj // NOTE: do not support nested tuple if (new_arg->IsInstance()) { Tuple tuple_new_arg = Downcast(new_arg); - std::vector fields; + Array fields; + fields.reserve(tuple_new_arg->fields.size()); for (auto x : tuple_new_arg->fields) { Expr tmp = push_back_one_arg(x); fields.push_back(tmp); } - normal_new_args.push_back(Tuple(fields)); + normal_new_args.push_back(WithFields(tuple_new_arg, std::move(fields))); } else { Expr tmp = push_back_one_arg(new_arg); normal_new_args.push_back(tmp); @@ -375,12 +377,13 @@ Expr LayoutRewriter(const Call& ref_call, const Array& new_args, const Obj for (auto arg : new_call->args) { if (arg->IsInstance()) { // unflatten tuple Tuple tuple_arg = Downcast(arg); - std::vector transformed_tuple_arg; + Array transformed_tuple_arg; + transformed_tuple_arg.reserve(tuple_arg->fields.size()); for (auto arg_item : tuple_arg->fields) { transformed_tuple_arg.push_back(memorizer.Transform(arg_item, new_in[pt], new_in2[pt])); pt++; } - transformed_args.push_back(Tuple(transformed_tuple_arg)); + transformed_args.push_back(WithFields(tuple_arg, std::move(transformed_tuple_arg))); } else { transformed_args.push_back(memorizer.Transform(arg, new_in[pt], new_in2[pt])); pt++;