From b58c83c4ee40d8cf30a9d063addc83497aca80e9 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Thu, 18 Nov 2021 17:26:50 -0800 Subject: [PATCH 1/2] Implement WithFields for Relay exprs --- include/tvm/relay/adt.h | 39 ++++++ include/tvm/relay/expr.h | 165 ++++++++++++++++++++++++ include/tvm/relay/function.h | 29 +++++ src/relay/ir/adt.cc | 46 +++++++ src/relay/ir/expr.cc | 166 +++++++++++++++++++++++++ src/relay/ir/expr_functor.cc | 154 ++++++++--------------- src/relay/ir/function.cc | 51 ++++++++ src/relay/transforms/device_planner.cc | 49 +++----- src/relay/transforms/partial_eval.cc | 4 +- 9 files changed, 570 insertions(+), 133 deletions(-) diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index b5dcab5e0bfc..b1d4d5975cb8 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -260,8 +260,25 @@ class Clause : public ObjectRef { TVM_DLL explicit Clause(Pattern lhs, Expr rhs); TVM_DEFINE_OBJECT_REF_METHODS(Clause, ObjectRef, ClauseNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ClauseNode); }; +/*! + * \brief Returns the clause with given properties. A null property denotes 'no change'. + * Returns clause if all properties are unchanged. Otherwise, returns a copy with the new fields. + * \param clause The clause to copy. + * \param opt_lhs The (optional) lhs for the copied clause. If none, ret_clause->lhs = clause->lhs. + * \param opt_rhs The (optional) rhs for the copied clause. If none, + * ret_clause->rhs = clause->rhs. + * \return If all + * properties are null or the same as the property in the input clause (i.e., opt_lhs is null or + * opt_lhs.value() == clause->lhs, etc.), then we return clause. Otherwise, we return a copy of + * clause with the different fields overwritten. (i.e., if opt_lhs.value() != clause->lhs, then + * ret_clause->lhs = opt_lhs.value()). + */ +Clause WithFields(Clause clause, Optional opt_lhs = Optional(), + Optional opt_rhs = Optional()); + /*! \brief ADT pattern matching exression. */ class Match; /*! \brief Match container node. */ @@ -315,8 +332,30 @@ class Match : public Expr { TVM_DLL Match(Expr data, tvm::Array clauses, bool complete = true, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Match, RelayExpr, MatchNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(MatchNode); }; +/*! + * \brief Returns the match with given properties. A null property denotes 'no change'. + * Returns match if all properties are unchanged. Otherwise, returns a copy with the new fields. + * \param match The match to copy. + * \param opt_data The (optional) data for the copied match. If none, ret_match->data = match->data. + * \param opt_clauses The (optional) clauses for the copied match. If none, ret_match->clauses = + * match->clauses. + * \param opt_complete The (optional) complete for the copied match. If none, ret_match->complete = + * match->complete. + * \param opt_span The (optional) span for the copied match. If none, ret_match->span = match->span. + * \return If all properties are null or the same as the + * property in the input match (i.e., opt_clauses is null or opt_clauses.value() == match->clauses, + * etc.), then we return match. Otherwise, we return a copy of match with the different fields + * overwritten. (i.e., if opt_clauses.value() != match->clauses, then ret_match->clauses = + * opt_clauses.value()). + */ +Match WithFields(Match match, Optional opt_data = Optional(), + Optional> opt_clauses = Optional>(), + Optional opt_complete = Optional(), + Optional opt_span = Optional()); + } // namespace relay } // namespace tvm diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 8077bbff14c0..f57b2d1a1952 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -230,8 +230,26 @@ class Var : public Expr { TVM_DLL Var(Id vid, Type type_annotation, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode); }; +/*! + * \brief Returns the var with given properties. A null property denotes 'no change'. + * Returns var if all properties are unchanged. Otherwise, returns a copy with the new fields. + * \param var The var to copy. + * \param opt_vid The (optional) vid for the copied var. If none, ret_var->vid = var->vid. + * \param opt_type_annotation The (optional) type_annotation for the copied var. If none, + * ret_var->type_annotation = var->type_annotation. + * \param opt_span The (optional) span for the copied var. If none, ret_var->span = var->span. + * \return If all properties are null or the same as the property in the input var + * (i.e., opt_vid is null or opt_vid.value() == var->vid, etc.), then we return var. Otherwise, + * we return a copy of call with the different fields overwritten. (i.e., if + * opt_vid.value() != var->vid, then ret_var->vid = opt_.value()). + */ +Var WithFields(Var var, Optional opt_vid = Optional(), + Optional opt_type_annotation = Optional(), + Optional opt_span = Optional()); + /*! * \brief Call corresponds to operator invocation. * Corresponds to the operator in computational graph terminology. @@ -331,8 +349,31 @@ class Call : public Expr { Array type_args = Array(), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Call, RelayExpr, CallNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode); }; +/*! + * \brief Returns the call with given properties. A null property denotes 'no change'. + * Returns call if all properties are unchanged. Otherwise, returns a copy with the new fields. + * \param call The call to copy. + * \param opt_op The (optional) op for the copied call. If none, ret_call->op = call->op. + * \param opt_args The (optional) args for the copied call. If none, ret_call->args = call->args. + * \param opt_attrs The (optional) attrs for the copied call. If none, ret_call->attrs = + * call->attrs. + * \param opt_type_args The (optional) type args for the copied call. If none, + * ret_call->type_args = call->type_args. + * \param opt_span The (optional) span for the copied call. If none, ret_call->span = call->span. + * \return If all properties are null or the same as the property in the input call + * (i.e., opt_op is null or opt_op.value() == call->op, etc.), then we return call. Otherwise, we + * return a copy of call with the different fields overwritten. (i.e., if opt_op.value() != + * call->op, then ret_call->op = opt_op.value()). + */ +Call WithFields(Call call, Optional opt_op = Optional(), + Optional> opt_args = Optional>(), + Optional opt_attrs = Optional(), + Optional> opt_type_args = Optional>(), + Optional opt_span = Optional()); + /*! * \brief Let binding that binds a local var and optionally a type annotation. * @@ -405,8 +446,27 @@ class Let : public Expr { TVM_DLL Let(Var var, Expr value, Expr body, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Let, RelayExpr, LetNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(LetNode); }; +/*! + * \brief Returns the let with given properties. A null property denotes 'no change'. + * Returns let if all properties are unchanged. Otherwise, returns a copy with the new fields. + * \param let The let to copy. + * \param opt_var The (optional) var for the copied let. If none, ret_let->op = let->op. + * \param opt_value The (optional) value for the copied let. If none, ret_let->args = let->args. + * \param opt_body The (optional) body for the copied let. If none, ret_let->attrs = let->attrs. + * \param opt_span The (optional) span for the copied let. If none, ret_let->span = let->span. + * \return If all properties are null or the same as the property in the input let (i.e., opt_var is + * null or opt_var.value() == let->var, etc.), then we return let. Otherwise, we return a copy of + * let with the different fields overwritten. (i.e., if opt_var.value() != let->var, then + * ret_let->var = opt_var.value()). + */ +Let WithFields(Let let, Optional opt_var = Optional(), + Optional opt_value = Optional(), + Optional opt_body = Optional(), + Optional opt_span = Optional()); + /*! * \brief Condition expression * @@ -466,8 +526,32 @@ class If : public Expr { TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(If, RelayExpr, IfNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(IfNode); }; +/*! + * \brief Returns the if_expr with given properties. A null property denotes 'no change'. + * Returns if_expr if all properties are unchanged. Otherwise, returns a copy with the new fields. + * \param if_expr The if expression to copy. + * \param opt_cond The (optional) cond for the copied if_expr. If none, ret_if->cond = + * if_expr->cond. + * \param opt_true_branch The (optional) true_branch for the copied if_expr. If none, + * ret_if->true_branch = ret_if->false_branch. + * \param opt_false_branch The (optional) false_branch + * for the copied if_expr. If none, ret_if->false_branch = if_expr->false_branch. + * \param opt_span + * The (optional) span for the copied if_expr. If none, ret_if->span = if_expr->span. + * \return If all + * properties are null or the same as the property in the input if_expr (i.e., opt_cond is null or + * opt_cond.value() == if_expr->cond, etc.), then we return if_expr. Otherwise, we return a copy of + * if_expr with the different fields overwritten. (i.e., if opt_cond.value() != if_expr->cond, then + * ret_if->cond = opt_cond.value()). + */ +If WithFields(If if_expr, Optional opt_cond = Optional(), + Optional opt_true_branch = Optional(), + Optional opt_false_branch = Optional(), + Optional opt_span = Optional()); + /*! \brief Get index-th field out of a tuple. */ class TupleGetItem; class TupleGetItemNode : public ExprNode { @@ -508,8 +592,30 @@ class TupleGetItem : public Expr { TVM_DLL TupleGetItem(Expr tuple, int index, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, RelayExpr, TupleGetItemNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleGetItemNode); }; +/*! + * \brief Returns the tuple_get_item with given properties. A null property denotes 'no change'. + * Returns if_expr if all properties are unchanged. Otherwise, returns a copy with the new fields. + * \param tuple_get_item The tuple_get_item to copy. + * \param opt_tuple The (optional) tuple for the copied tuple_get_item. If none, + * ret_tuple_get_item->tuple = tuple_get_item->tuple. + * \param opt_index The (optional) index for the copied tuple_get_item. If none, + * ret_tuple_get_item->index = tuple_get_item->index. + * \param + * opt_span The (optional) span for the copied tuple_get_item. If none, + * ret_tuple_get_item->span = tuple_get_item->span. + * \return If all properties are null or the same as the property in the input tuple_get_item + * (i.e., opt_tuple is null or opt_tuple.value() == tuple_get_item->tuple, etc.), then we return + * tuple_get_item. Otherwise, we return a copy of tuple_get_item with the different fields + * overwritten. (i.e., if opt_tuple.value() != tuple_get_item->tuple, then + * ret_tuple_get_item->tuple = opt_tuple.value()). + */ +TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple = Optional(), + Optional opt_index = Optional(), + Optional opt_span = Optional()); + /*! \brief Create a new Reference out of initial value. */ class RefCreate; class RefCreateNode : public ExprNode { @@ -547,8 +653,27 @@ class RefCreate : public Expr { TVM_DLL explicit RefCreate(Expr value, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(RefCreate, RelayExpr, RefCreateNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(RefCreateNode); }; +/*! + * \brief Returns the ref create with given properties. A null property denotes 'no change'. + * Returns ref_create if all properties are unchanged. Otherwise, returns a copy with the new + * fields. + * \param ref_create The ref_create to copy. + * \param opt_value The (optional) value for the copied ref_create. If none, + * ret_ref_create->value = ref_create->value. + * \param opt_span The (optional) span for the copied ref_create. If none, + * ret_ref_create->span = ref_create->span. + * \return If all properties are null or the same as the property in the input ref_create + * (i.e., opt_value is null or opt_value.value() == ref_create->value, etc.), then we return + * ref_create. Otherwise, we return a copy of ref_create with the different fields overwritten. + * (i.e., if opt_value.value() != ref_create->value, then + * ret_ref_create->value = opt_value.value()). + */ +RefCreate WithFields(RefCreate ref_create, Optional opt_value = Optional(), + Optional opt_span = Optional()); + /*! \brief Get value out of Reference. */ class RefRead; class RefReadNode : public ExprNode { @@ -586,7 +711,26 @@ class RefRead : public Expr { TVM_DLL explicit RefRead(Expr ref, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(RefRead, RelayExpr, RefReadNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(RefReadNode); }; + +/*! + * \brief Returns the ref read with given properties. A null property denotes 'no change'. + * Returns ref_read if all properties are unchanged. Otherwise, returns a copy with the new fields. + * \param ref_read The ref_read to copy. + * \param opt_ref The (optional) ref for the copied ref_read. If none, ret_ref_read->ref = + * ref_read->ref. + * \param opt_span + * The (optional) span for the copied ref_read. If none, ret_ref_read->span = ref_read->span. + * \return If all properties are null or the same as the property in the input ref_read + * (i.e., opt_ref is null or opt_ref.value() == ref_read->ref, etc.), then we return ref_read. + * Otherwise, we return a copy of ref_read with the different fields overwritten. + * (i.e., if opt_ref.value() != ref_read->ref, then + * ret_ref_read->ref = opt_ref.value()). + */ +RefRead WithFields(RefRead ref_read, Optional opt_ref = Optional(), + Optional opt_span = Optional()); + /*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */ class RefWrite; class RefWriteNode : public ExprNode { @@ -629,8 +773,29 @@ class RefWrite : public Expr { TVM_DLL RefWrite(Expr ref, Expr value, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(RefWrite, RelayExpr, RefWriteNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(RefWriteNode); }; +/*! + * \brief Returns the ref write with given properties. A null property denotes 'no change'. + * Returns ref_write if all properties are unchanged. Otherwise, returns a copy with the new fields. + * \param ref_write The ref_write to copy. + * \param opt_ref The (optional) ref for the copied ref_write. If none, + * ret_ref_write->ref = ref_write->ref. + * \param opt_value The (optional) value for the copied ref_write. If none, + * ret_ref_write->value = ref_write->value. + * \param opt_span + * The (optional) span for the copied ref_write. If none, ret_ref_write->span = ref_write->span. + * \return If all properties are null or the same as the property in the input ref_write + * (i.e., opt_ref is null or opt_ref.value() == ref_write->ref, etc.), then we return ref_write. + * Otherwise, we return a copy of ref_write with the different fields overwritten. + * (i.e., if ref_write.value() != ref_write->ref, then + * ret_ref_write->ref = opt_ref.value()). + */ +RefWrite WithFields(RefWrite ref_write, Optional opt_ref = Optional(), + Optional opt_value = Optional(), + Optional opt_span = Optional()); + /*! * \brief Base class of the temporary expression. * diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index 9170bc53ea02..16351c94f821 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -119,6 +119,35 @@ class Function : public BaseFunc { TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode); }; +/*! + * \brief Returns the function with given properties. A null property denotes 'no change'. + * Returns function if all properties are unchanged. Otherwise, returns a copy with the new fields. + * \param function The function to copy. + * \param opt_params The (optional) params for the copied function. If none, + * ret_function->params = function->params. + * \param opt_body The (optional) body for the copied function. If none, + * ret_function->body = function->body. + * \param opt_ret_type The (optional) return type for the copied function. If none, + * ret_function->ret_type = function->ret_type. + * \param opt_ty_params The (optional) type params for the copied function. If none, + * ret_function->type_params = function->type_params. + * \param opt_attrs + * The (optional) attributes for the copied function. If none, + * ret_function->attrs = function->attrs. + * \param opt_span The (optional) span for the copied function. If none, + * ret_function->span = function->span. + * \return If all properties are null or the same as the property in the input function + * (i.e., opt_params is null or opt_params.value() == function->params, etc.), then we return + * function. Otherwise, we return a copy of function with the different fields overwritten. (i.e., + * if opt_params.value() != function->params, then ret_function->params = opt_params.value()). + */ +Function WithFields(Function function, Optional> opt_params = Optional>(), + Optional opt_body = Optional(), + Optional opt_ret_type = Optional(), + Optional> opt_ty_params = Optional>(), + Optional opt_attrs = Optional(), + Optional opt_span = Optional()); + /*! * \brief namespace of the attributes that can be attached to a relay::Function. */ diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc index ba9743cc35bf..c2b8fd641d03 100644 --- a/src/relay/ir/adt.cc +++ b/src/relay/ir/adt.cc @@ -104,6 +104,20 @@ Clause::Clause(Pattern lhs, Expr rhs) { data_ = std::move(n); } +Clause WithFields(Clause clause, Optional opt_lhs, Optional opt_rhs) { + Pattern lhs = opt_lhs.value_or(clause->lhs); + Expr rhs = opt_rhs.value_or(clause->rhs); + + bool unchanged = lhs.same_as(clause->lhs) && rhs.same_as(clause->rhs); + + if (!unchanged) { + ClauseNode* cow_clause_node = clause.CopyOnWrite(); + cow_clause_node->lhs = lhs; + cow_clause_node->rhs = rhs; + } + return std::move(clause); +} + TVM_REGISTER_NODE_TYPE(ClauseNode); TVM_REGISTER_GLOBAL("relay.ir.Clause").set_body_typed([](Pattern lhs, Expr rhs) { @@ -125,6 +139,38 @@ Match::Match(Expr data, tvm::Array clauses, bool complete, Span span) { data_ = std::move(n); } +Match WithFields(Match match, Optional opt_data, Optional> opt_clauses, + Optional opt_complete, Optional opt_span) { + Expr data = opt_data.value_or(match->data); + Array clauses = opt_clauses.value_or(match->clauses); + Bool complete = opt_complete.value_or(Bool(match->complete)); + Span span = opt_span.value_or(match->span); + + bool unchanged = + data.same_as(match->data) && (complete == match->complete) && span.same_as(match->span); + + // Check that all clauses are unchanged + if (unchanged) { + bool all_clauses_unchanged = true; + if (clauses.size() == match->clauses.size()) { + for (size_t i = 0; i < clauses.size(); i++) { + all_clauses_unchanged &= clauses[i].same_as(match->clauses[i]); + } + } else { + all_clauses_unchanged = false; + } + unchanged &= all_clauses_unchanged; + } + if (!unchanged) { + MatchNode* cow_match_node = match.CopyOnWrite(); + cow_match_node->data = data; + cow_match_node->clauses = clauses; + cow_match_node->complete = complete; + cow_match_node->span = span; + } + return std::move(match); +} + TVM_REGISTER_NODE_TYPE(MatchNode); TVM_REGISTER_GLOBAL("relay.ir.Match") diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 59e8c9ee9d0c..8998f4e1573d 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -112,6 +112,24 @@ Var::Var(Id vid, Type type_annotation, Span span) { data_ = std::move(n); } +Var WithFields(Var var, Optional opt_vid, Optional opt_type_annotation, + Optional opt_span) { + Id vid = opt_vid.value_or(var->vid); + Type type_annotation = opt_type_annotation.value_or(var->type_annotation); + Span span = opt_span.value_or(var->span); + + bool unchanged = vid.same_as(var->vid) && type_annotation.same_as(var->type_annotation) && + span.same_as(var->span); + + if (!unchanged) { + VarNode* cow_var_node = var.CopyOnWrite(); + cow_var_node->vid = vid; + cow_var_node->type_annotation = type_annotation; + cow_var_node->span = span; + } + return std::move(var); +} + TVM_REGISTER_NODE_TYPE(VarNode); TVM_REGISTER_GLOBAL("relay.ir.Var").set_body_typed([](String str, Type type_annotation) { @@ -139,6 +157,55 @@ Call::Call(Expr op, Array args, Attrs attrs, Array type_args, Span s data_ = std::move(n); } +Call WithFields(Call call, Optional opt_op, Optional> opt_args, + Optional opt_attrs, Optional> opt_type_args, + Optional opt_span) { + Expr op = opt_op.value_or(call->op); + Array args = opt_args.value_or(call->args); + Attrs attrs = opt_attrs.value_or(call->attrs); + Array type_args = opt_type_args.value_or(call->type_args); + Span span = opt_span.value_or(call->span); + + bool unchanged = op.same_as(call->op) && attrs.same_as(call->attrs) && span.same_as(call->span); + + // Check that the args are unchanged + if (unchanged) { + bool all_args_unchanged = true; + if (args.size() == call->args.size()) { + for (size_t i = 0; i < args.size(); i++) { + all_args_unchanged &= args[i].same_as(call->args[i]); + } + } else { + all_args_unchanged = false; + } + unchanged &= all_args_unchanged; + } + + // Check that the type_args are unchanged + if (unchanged) { + bool all_type_args_unchanged = true; + if (type_args.size() == call->type_args.size()) { + for (size_t i = 0; i < type_args.size(); i++) { + all_type_args_unchanged &= type_args[i].same_as(call->type_args[i]); + } + } else { + all_type_args_unchanged = false; + } + + unchanged &= all_type_args_unchanged; + } + + if (!unchanged) { + CallNode* cow_call_node = call.CopyOnWrite(); + cow_call_node->op = op; + cow_call_node->args = args; + cow_call_node->attrs = attrs; + cow_call_node->type_args = type_args; + cow_call_node->span = span; + } + return std::move(call); +} + TVM_REGISTER_NODE_TYPE(CallNode); TVM_REGISTER_GLOBAL("relay.ir.Call") @@ -162,6 +229,26 @@ Let::Let(Var var, Expr value, Expr body, Span span) { data_ = std::move(n); } +Let WithFields(Let let, Optional opt_var, Optional opt_value, Optional opt_body, + Optional opt_span) { + Var var = opt_var.value_or(let->var); + Expr value = opt_value.value_or(let->value); + Expr body = opt_body.value_or(let->body); + Span span = opt_span.value_or(let->span); + + bool unchanged = var.same_as(let->var) && value.same_as(let->value) && body.same_as(let->body) && + span.same_as(let->span); + + if (!unchanged) { + LetNode* cow_let_node = let.CopyOnWrite(); + cow_let_node->var = var; + cow_let_node->value = value; + cow_let_node->body = body; + cow_let_node->span = span; + } + return std::move(let); +} + TVM_REGISTER_NODE_TYPE(LetNode); TVM_REGISTER_GLOBAL("relay.ir.Let").set_body_typed([](Var var, Expr value, Expr body) { @@ -183,6 +270,25 @@ If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) { data_ = std::move(n); } +If WithFields(If if_expr, Optional opt_cond, Optional opt_true_branch, + Optional opt_false_branch, Optional opt_span) { + Expr cond = opt_cond.value_or(if_expr->cond); + Expr true_branch = opt_true_branch.value_or(if_expr->true_branch); + Expr false_branch = opt_false_branch.value_or(if_expr->false_branch); + Span span = opt_span.value_or(if_expr->span); + + bool unchanged = cond.same_as(if_expr->cond) && true_branch.same_as(if_expr->true_branch) && + false_branch.same_as(if_expr->false_branch) && span.same_as(if_expr->span); + + if (!unchanged) { + IfNode* cow_if_node = if_expr.CopyOnWrite(); + cow_if_node->cond = cond; + cow_if_node->true_branch = true_branch; + cow_if_node->false_branch = false_branch; + } + return std::move(if_expr); +} + TVM_REGISTER_NODE_TYPE(IfNode); TVM_REGISTER_GLOBAL("relay.ir.If") @@ -205,6 +311,23 @@ TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) { data_ = std::move(n); } +TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, + Optional opt_index, Optional opt_span) { + Expr tuple = opt_tuple.value_or(tuple_get_item->tuple); + Integer index = opt_index.value_or(tuple_get_item->index); + Span span = opt_span.value_or(tuple_get_item->span); + + bool unchanged = tuple.same_as(tuple_get_item->tuple) && (index == tuple_get_item->index) && + span.same_as(tuple_get_item->span); + if (!unchanged) { + TupleGetItemNode* cow_tuple_get_item_node = tuple_get_item.CopyOnWrite(); + cow_tuple_get_item_node->tuple = tuple; + cow_tuple_get_item_node->index = index; + cow_tuple_get_item_node->span = span; + } + return std::move(tuple_get_item); +} + TVM_REGISTER_NODE_TYPE(TupleGetItemNode); TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int index) { @@ -224,6 +347,19 @@ RefCreate::RefCreate(Expr value, Span span) { data_ = std::move(n); } +RefCreate WithFields(RefCreate ref_create, Optional opt_value, Optional opt_span) { + Expr value = opt_value.value_or(ref_create->value); + Span span = opt_span.value_or(ref_create->span); + + bool unchanged = value.same_as(ref_create->value) && span.same_as(ref_create->span); + if (!unchanged) { + RefCreateNode* cow_ref_create_node = ref_create.CopyOnWrite(); + cow_ref_create_node->value = value; + cow_ref_create_node->span = span; + } + return std::move(ref_create); +} + TVM_REGISTER_NODE_TYPE(RefCreateNode); TVM_REGISTER_GLOBAL("relay.ir.RefCreate").set_body_typed([](Expr value) { @@ -243,6 +379,19 @@ RefRead::RefRead(Expr ref, Span span) { data_ = std::move(n); } +RefRead WithFields(RefRead ref_read, Optional opt_ref, Optional opt_span) { + Expr ref = opt_ref.value_or(ref_read->ref); + Span span = opt_span.value_or(ref_read->span); + + bool unchanged = ref.same_as(ref_read->ref) && span.same_as(ref_read->span); + if (!unchanged) { + RefReadNode* cow_ref_read_node = ref_read.CopyOnWrite(); + cow_ref_read_node->ref = ref; + cow_ref_read_node->span = span; + } + return std::move(ref_read); +} + TVM_REGISTER_NODE_TYPE(RefReadNode); TVM_REGISTER_GLOBAL("relay.ir.RefRead").set_body_typed([](Expr ref) { return RefRead(ref); }); @@ -261,6 +410,23 @@ RefWrite::RefWrite(Expr ref, Expr value, Span span) { data_ = std::move(n); } +RefWrite WithFields(RefWrite ref_write, Optional opt_ref, Optional opt_value, + Optional opt_span) { + Expr ref = opt_ref.value_or(ref_write->ref); + Expr value = opt_value.value_or(ref_write->value); + Span span = opt_span.value_or(ref_write->span); + + bool unchanged = ref.same_as(ref_write->ref) && value.same_as(ref_write->value) && + span.same_as(ref_write->span); + if (!unchanged) { + RefWriteNode* cow_ref_write_node = ref_write.CopyOnWrite(); + cow_ref_write_node->ref = ref; + cow_ref_write_node->value = value; + cow_ref_write_node->span = span; + } + return std::move(ref_write); +} + TVM_REGISTER_NODE_TYPE(RefWriteNode); TVM_REGISTER_GLOBAL("relay.ir.RefWrite").set_body_typed([](Expr ref, Expr value) { diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index 08c9b9643caf..a4b37adc3f83 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -25,6 +25,7 @@ * the cost of using functional updates. */ #include +#include #include #include #include @@ -160,15 +161,12 @@ Expr ExprMutator::VisitExpr(const Expr& expr) { } } -Expr ExprMutator::VisitExpr_(const VarNode* op) { - if (op->type_annotation.defined()) { - auto type = this->VisitType(op->type_annotation); - if (!op->type_annotation.same_as(type)) { - return Var(op->vid, type, op->span); - } +Expr ExprMutator::VisitExpr_(const VarNode* var_node) { + Type type_annotation = var_node->type_annotation; + if (var_node->type_annotation.defined()) { + type_annotation = this->VisitType(var_node->type_annotation); } - // default case return self. - return GetRef(op); + return WithFields(GetRef(var_node), std::move(var_node->vid), std::move(type_annotation)); } Expr ExprMutator::VisitExpr_(const ConstantNode* op) { return GetRef(op); } @@ -188,147 +186,101 @@ Expr ExprMutator::VisitExpr_(const TupleNode* tuple_node) { return WithFields(GetRef(tuple_node), std::move(fields)); } -Expr ExprMutator::VisitExpr_(const FunctionNode* op) { +Expr ExprMutator::VisitExpr_(const FunctionNode* func_node) { tvm::Array ty_params; - bool all_ty_params_unchanged = true; - for (auto ty_param : op->type_params) { + for (auto ty_param : func_node->type_params) { TypeVar new_ty_param = Downcast(VisitType(ty_param)); ty_params.push_back(new_ty_param); - all_ty_params_unchanged &= new_ty_param.same_as(ty_param); } tvm::Array params; - bool all_params_unchanged = true; - for (auto param : op->params) { + for (auto param : func_node->params) { Var new_param = Downcast(this->Mutate(param)); params.push_back(new_param); - all_params_unchanged &= param.same_as(new_param); } - auto ret_type = this->VisitType(op->ret_type); - auto body = this->Mutate(op->body); + auto ret_type = this->VisitType(func_node->ret_type); + auto body = this->Mutate(func_node->body); - if (all_ty_params_unchanged && all_params_unchanged && ret_type.same_as(op->ret_type) && - body.same_as(op->body)) { - return GetRef(op); - } else { - return Function(params, body, ret_type, ty_params, op->attrs, op->span); - } + return WithFields(GetRef(func_node), std::move(params), std::move(body), std::move(ret_type), std::move(ty_params)); } Expr ExprMutator::VisitExpr_(const CallNode* call_node) { auto new_op = this->Mutate(call_node->op); - bool unchanged = call_node->op.same_as(new_op); tvm::Array ty_args; + ty_args.reserve(call_node->type_args.size()); + for (auto ty_arg : call_node->type_args) { auto new_ty_arg = this->VisitType(ty_arg); ty_args.push_back(new_ty_arg); - unchanged &= new_ty_arg.same_as(ty_arg); } tvm::Array call_args; + call_args.reserve(call_node->args.size()); for (auto arg : call_node->args) { auto new_arg = this->Mutate(arg); call_args.push_back(new_arg); - unchanged &= new_arg.same_as(arg); } - if (unchanged) { - return GetRef(call_node); - } else { - return Call(new_op, call_args, call_node->attrs, ty_args, call_node->span); - } + return WithFields(GetRef(call_node), std::move(new_op), std::move(call_args), {}, + std::move(ty_args)); } -Expr ExprMutator::VisitExpr_(const LetNode* op) { - Var var = Downcast(this->Mutate(op->var)); - auto value = this->Mutate(op->value); - auto body = this->Mutate(op->body); +Expr ExprMutator::VisitExpr_(const LetNode* let_node) { + Var var = Downcast(this->Mutate(let_node->var)); + auto value = this->Mutate(let_node->value); + auto body = this->Mutate(let_node->body); - if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { - return GetRef(op); - } else { - return Let(var, value, body, op->span); - } + return WithFields(GetRef(let_node), std::move(var), std::move(value), std::move(body)); } -Expr ExprMutator::VisitExpr_(const IfNode* op) { - auto guard = this->Mutate(op->cond); - auto true_b = this->Mutate(op->true_branch); - auto false_b = this->Mutate(op->false_branch); - if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && - op->false_branch.same_as(false_b)) { - return GetRef(op); - } else { - return If(guard, true_b, false_b, op->span); - } +Expr ExprMutator::VisitExpr_(const IfNode* if_node) { + auto cond = this->Mutate(if_node->cond); + auto true_b = this->Mutate(if_node->true_branch); + auto false_b = this->Mutate(if_node->false_branch); + + return WithFields(GetRef(if_node), std::move(cond), std::move(true_b), std::move(false_b)); } Expr ExprMutator::VisitExpr_(const TupleGetItemNode* get_item) { - auto t = this->Mutate(get_item->tuple); - if (get_item->tuple == t) { - return GetRef(get_item); - } else { - return TupleGetItem(t, get_item->index, get_item->span); - } + Expr tuple = this->Mutate(get_item->tuple); + return WithFields(GetRef(get_item), std::move(tuple)); } -Expr ExprMutator::VisitExpr_(const RefCreateNode* op) { - Expr value = this->Mutate(op->value); - if (value.same_as(op->value)) { - return GetRef(op); - } else { - return RefCreate(value, op->span); - } +Expr ExprMutator::VisitExpr_(const RefCreateNode* ref_create) { + Expr value = this->Mutate(ref_create->value); + return WithFields(GetRef(ref_create), std::move(value)); } -Expr ExprMutator::VisitExpr_(const RefReadNode* op) { - Expr ref = this->Mutate(op->ref); - if (ref.same_as(op->ref)) { - return GetRef(op); - } else { - return RefRead(ref, op->span); - } +Expr ExprMutator::VisitExpr_(const RefReadNode* ref_read) { + Expr ref = this->Mutate(ref_read->ref); + return WithFields(GetRef(ref_read), std::move(ref)); } -Expr ExprMutator::VisitExpr_(const RefWriteNode* op) { - Expr ref = this->Mutate(op->ref); - Expr value = this->Mutate(op->value); - if (ref.same_as(op->ref) && value.same_as(op->value)) { - return GetRef(op); - } else { - return RefWrite(ref, value, op->span); - } +Expr ExprMutator::VisitExpr_(const RefWriteNode* ref_write) { + Expr ref = this->Mutate(ref_write->ref); + Expr value = this->Mutate(ref_write->value); + return WithFields(GetRef(ref_write), std::move(ref), std::move(value)); } Expr ExprMutator::VisitExpr_(const ConstructorNode* c) { return GetRef(c); } -Expr ExprMutator::VisitExpr_(const MatchNode* m) { - bool unchanged = true; - std::vector clauses; - for (const Clause& p : m->clauses) { - Clause c = VisitClause(p); - clauses.push_back(c); - unchanged &= c.same_as(p); +Expr ExprMutator::VisitExpr_(const MatchNode* match_node) { + Array clauses; + for (const Clause& p : match_node->clauses) { + clauses.push_back(VisitClause(p)); } - Expr data = Mutate(m->data); - unchanged &= data.same_as(m->data); + Expr data = Mutate(match_node->data); - if (unchanged) { - return GetRef(m); - } - return Match(data, clauses, m->complete, m->span); + return WithFields(GetRef(match_node), std::move(data), std::move(clauses)); } -Clause ExprMutator::VisitClause(const Clause& c) { - Pattern p = VisitPattern(c->lhs); - Expr rhs = Mutate(c->rhs); - if (p.same_as(c->lhs) && rhs.same_as(c->rhs)) { - return c; - } - return Clause(p, rhs); +Clause ExprMutator::VisitClause(const Clause& clause) { + Pattern lhs = VisitPattern(clause->lhs); + Expr rhs = Mutate(clause->rhs); + return WithFields(std::move(clause), std::move(lhs), std::move(rhs)); } Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; } @@ -507,9 +459,9 @@ class ExprBinder : public MixedModeMutator, PatternMutator { Pattern VisitPattern(const Pattern& p) final { return PatternMutator::VisitPattern(p); } - Clause VisitClause(const Clause& c) final { - Pattern pat = VisitPattern(c->lhs); - return Clause(pat, VisitExpr(c->rhs)); + Clause VisitClause(const Clause& clause) final { + Pattern lhs = VisitPattern(clause->lhs); + return WithFields(std::move(clause), std::move(lhs), VisitExpr(clause->rhs)); } Var VisitVar(const Var& v) final { diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index 83ac55fce085..92c2e7f70c8e 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -40,6 +40,57 @@ Function::Function(tvm::Array params, Expr body, Type ret_type, data_ = std::move(n); } +Function WithFields(Function function, Optional> opt_params, Optional opt_body, + Optional opt_ret_type, Optional> opt_ty_params, + Optional opt_attrs, Optional opt_span) { + Array params = opt_params.value_or(function->params); + Expr body = opt_body.value_or(function->body); + Type ret_type = opt_ret_type.value_or(function->ret_type); + Array ty_params = opt_ty_params.value_or(function->type_params); + DictAttrs attrs = opt_attrs.value_or(function->attrs); + Span span = opt_span.value_or(function->span); + + bool unchanged = body.same_as(function->body) && ret_type.same_as(function->ret_type) && + attrs.same_as(function->attrs) && span.same_as(function->span); + + // Check that all the type params are unchanged + if (unchanged) { + bool all_ty_params_unchanged = true; + if (ty_params.size() == function->type_params.size()) { + for (size_t i = 0; i < ty_params.size(); i++) { + all_ty_params_unchanged &= ty_params[i].same_as(function->type_params[i]); + } + } else { + all_ty_params_unchanged = false; + } + unchanged &= all_ty_params_unchanged; + } + + // Check that all the params are unchanged + if (unchanged) { + bool all_params_unchanged = true; + if (params.size() == function->params.size()) { + for (size_t i = 0; i < params.size(); i++) { + all_params_unchanged &= params[i].same_as(function->params[i]); + } + } else { + all_params_unchanged = false; + } + unchanged &= all_params_unchanged; + } + + if (!unchanged) { + FunctionNode* cow_function_node = function.CopyOnWrite(); + cow_function_node->params = params; + cow_function_node->body = body; + cow_function_node->ret_type = ret_type; + cow_function_node->type_params = ty_params; + cow_function_node->attrs = attrs; + cow_function_node->span = span; + } + return std::move(function); +} + FuncType FunctionNode::func_type_annotation() const { Array param_types; for (auto param : this->params) { diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc index afa598b9a782..9296c958fe8c 100644 --- a/src/relay/transforms/device_planner.cc +++ b/src/relay/transforms/device_planner.cc @@ -291,10 +291,9 @@ class RewriteOnDevices : public ExprMutator { private: Expr VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final { Expr tuple = VisitExpr(tuple_get_item_node->tuple); - // TODO(mbs): Avoid copy. - Expr tuple_get_item = - TupleGetItem(tuple, tuple_get_item_node->index, tuple_get_item_node->span); OnDeviceProps props = GetOnDeviceProps(tuple); + + Expr tuple_get_item = WithFields(GetRef(tuple_get_item_node), std::move(tuple)); if (props.body.defined() && !props.is_fixed) { VLOG(2) << "wrapping tuple get item:" << std::endl << PrettyPrint(GetRef(tuple_get_item_node)) << std::endl @@ -307,9 +306,9 @@ class RewriteOnDevices : public ExprMutator { Expr VisitExpr_(const LetNode* let_node) final { auto expr = GetRef(let_node); - std::vector> bindings; + std::vector> bindings; while (const auto* inner_let_node = expr.as()) { - Expr inner_let = GetRef(inner_let_node); + Let inner_let = GetRef(inner_let_node); Expr value = VisitExpr(inner_let_node->value); OnDeviceProps props = GetOnDeviceProps(value); if (props.body.defined() && !props.is_fixed) { @@ -318,14 +317,13 @@ class RewriteOnDevices : public ExprMutator { << "to be fixed to SEScope " << props.se_scope; value = OnDevice(props.body, props.se_scope, /*is_fixed=*/true); } - bindings.emplace_back(inner_let_node->var, value, inner_let_node->span); + bindings.emplace_back(inner_let, value); expr = inner_let_node->body; } expr = VisitExpr(expr); - // TODO(mbs): Avoid copy. for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { - expr = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr), expr, - /*span=*/std::get<2>(*itr)); + expr = WithFields(/*let=*/std::move(std::get<0>(*itr)), /*var = unchanged*/ {}, + /*value=*/std::move(std::get<1>(*itr)), /*body=*/std::move(expr)); } return expr; } @@ -339,9 +337,8 @@ class RewriteOnDevices : public ExprMutator { << "to be fixed to SEScope " << props.se_scope; body = OnDevice(props.body, props.se_scope, /*is_fixed=*/true); } - // TODO(mbs): Avoid copy - return Function(function_node->params, body, function_node->ret_type, - function_node->type_params, function_node->attrs, function_node->span); + return WithFields(GetRef(function_node), std::move(function_node->params), + std::move(body)); } }; @@ -820,9 +817,8 @@ class DeviceCapturer : public ExprMutator { /*expected_se_scope=*/result_se_scope, /*child_se_scope=*/GetSEScope(function_node->body), function_node->body); - // TODO(mbs): Avoid copy - Function func = Function(function_node->params, body, function_node->ret_type, - function_node->type_params, function_node->attrs, function_node->span); + Function func = WithFields(GetRef(function_node), std::move(function_node->params), + std::move(body)); return FunctionOnDevice(func, std::move(param_se_scopes), std::move(result_se_scope)); } @@ -884,9 +880,7 @@ class DeviceCapturer : public ExprMutator { /*child_se_scope=*/GetSEScope(call_node->args[i]), call_node->args[i])); } - // TODO(mbs): Avoid copy - return Call(std::move(op), std::move(args), call_node->attrs, call_node->type_args, - call_node->span); + return WithFields(GetRef(call_node), std::move(op), std::move(args)); } Expr VisitExpr_(const LetNode* let_node) final { @@ -925,37 +919,33 @@ class DeviceCapturer : public ExprMutator { Expr cond = VisitChild(ife, if_node->cond); Expr true_branch = VisitChild(ife, if_node->true_branch); Expr false_branch = VisitChild(ife, if_node->false_branch); - // TODO(mbs): Avoid copy - return If(cond, true_branch, false_branch, if_node->span); + return WithFields(std::move(ife), std::move(cond), std::move(true_branch), + std::move(false_branch)); } Expr VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final { auto tuple_get_item = GetRef(tuple_get_item_node); Expr tuple = VisitChild(tuple_get_item, tuple_get_item_node->tuple); - // TODO(mbs): Avoid copy - return TupleGetItem(tuple, tuple_get_item_node->index, tuple_get_item_node->span); + return WithFields(std::move(tuple_get_item), std::move(tuple)); } Expr VisitExpr_(const RefCreateNode* ref_create_node) final { auto ref_create = GetRef(ref_create_node); Expr value = VisitChild(ref_create, ref_create_node->value); - // TODO(mbs): Avoid copy - return RefCreate(value, ref_create_node->span); + return WithFields(std::move(ref_create), std::move(value)); } Expr VisitExpr_(const RefReadNode* ref_read_node) final { auto ref_read = GetRef(ref_read_node); Expr ref = VisitChild(ref_read, ref_read_node->ref); - // TODO(mbs): Avoid copy - return RefRead(ref, ref_read_node->span); + return WithFields(std::move(ref_read), std::move(ref)); } Expr VisitExpr_(const RefWriteNode* ref_write_node) final { auto ref_write = GetRef(ref_write_node); Expr ref = VisitChild(ref_write, ref_write_node->ref); Expr value = VisitChild(ref_write, ref_write_node->value); - // TODO(mbs): Avoid copy - return RefWrite(ref, value, ref_write_node->span); + return WithFields(std::move(ref_write), std::move(ref), std::move(value)); } Expr VisitExpr_(const MatchNode* match_node) final { @@ -968,8 +958,7 @@ class DeviceCapturer : public ExprMutator { Expr rhs = VisitChild(match, clause->rhs); clauses.push_back(Clause(lhs, rhs)); } - // TODO(mbs): Avoid copy - return Match(data, std::move(clauses), match_node->complete, match_node->span); + return WithFields(std::move(match), std::move(data), std::move(clauses)); } SEScope GetSEScope(const Expr& expr) { diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc index 7388d9f7eb32..8f5e9e146d54 100644 --- a/src/relay/transforms/partial_eval.cc +++ b/src/relay/transforms/partial_eval.cc @@ -615,8 +615,8 @@ class PartialEvaluator : public ExprFunctor value.push_back(ps); expr.push_back(ps->dynamic); } - // Note(@electriclilies): The partial evaluator seems to do some weird stuff with sharing. - // Changing Tuple(expr) to WithFields(op, expr) causes some strange failures. + // Note: The partial evaluator seems to do some weird stuff with sharing. Changing Tuple(expr) + // to WithFields(op, expr) causes failures in the partial evaluator tests. return HasStatic(MkSTuple(value), ll->Push(Tuple(expr))); } From f7d9bf6a00a97d195eaa877a6145e0f96292a614 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Wed, 24 Nov 2021 15:08:42 -0800 Subject: [PATCH 2/2] lint --- src/relay/ir/expr_functor.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index a4b37adc3f83..a08de39d0abb 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -203,7 +203,8 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* func_node) { auto ret_type = this->VisitType(func_node->ret_type); auto body = this->Mutate(func_node->body); - return WithFields(GetRef(func_node), std::move(params), std::move(body), std::move(ret_type), std::move(ty_params)); + return WithFields(GetRef(func_node), std::move(params), std::move(body), + std::move(ret_type), std::move(ty_params)); } Expr ExprMutator::VisitExpr_(const CallNode* call_node) {