Skip to content

Commit

Permalink
WithFields for Tuples (#9533)
Browse files Browse the repository at this point in the history
  • Loading branch information
electriclilies authored Nov 24, 2021
1 parent 0195afc commit 3c48cad
Show file tree
Hide file tree
Showing 15 changed files with 108 additions and 71 deletions.
12 changes: 12 additions & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,20 @@ class Tuple : public Expr {
TVM_DLL explicit Tuple(tvm::Array<relay::Expr> fields, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(Tuple, RelayExpr, TupleNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode);
};

/*!
* \brief Returns the tuple with given properties. A null property denotes 'no change'.
* Returns this if all properties are unchanged. Otherwise, returns a copy with the new fields.
* \param tuple The tuple to copy
* \param opt_fields The (optional) fields for the copied tuple. If none, ret_tuple->fields =
* tuple->fields.
* \param opt_span The (optional) span for the copied tuple. If none, ret_tuple->span = tuple->span.
*/
Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields = Optional<Array<Expr>>(),
Optional<Span> opt_span = Optional<Span>(nullptr));

/*!
* \brief Local variables used in the let expression.
*
Expand Down
21 changes: 21 additions & 0 deletions src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,27 @@ TVM_REGISTER_NODE_TYPE(TupleNode);
TVM_REGISTER_GLOBAL("relay.ir.Tuple").set_body_typed([](tvm::Array<relay::Expr> fields, Span span) {
return Tuple(fields, span);
});
Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields, Optional<Span> opt_span) {
Array<Expr> fields = opt_fields.value_or(tuple->fields);
Span span = opt_span.value_or(tuple->span);

bool all_fields_unchanged = true;
if (fields.size() == tuple->fields.size()) {
for (size_t i = 0; i < fields.size(); i++) {
all_fields_unchanged &= fields[i].same_as(tuple->fields[i]);
}
} else {
all_fields_unchanged = false;
}

all_fields_unchanged = all_fields_unchanged && span.same_as(tuple->span);
if (!all_fields_unchanged) {
TupleNode* cow_tuple_node = tuple.CopyOnWrite();
cow_tuple_node->fields = fields;
cow_tuple_node->span = span;
}
return std::move(tuple);
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TupleNode>([](const ObjectRef& ref, ReprPrinter* p) {
Expand Down
15 changes: 5 additions & 10 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,20 +177,15 @@ Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) { return GetRef<Expr>(op);

Expr ExprMutator::VisitExpr_(const OpNode* op) { return GetRef<Expr>(op); }

Expr ExprMutator::VisitExpr_(const TupleNode* op) {
Expr ExprMutator::VisitExpr_(const TupleNode* tuple_node) {
tvm::Array<Expr> fields;
bool all_fields_unchanged = true;
for (auto field : op->fields) {
fields.reserve(tuple_node->fields.size());

for (auto field : tuple_node->fields) {
auto new_field = this->Mutate(field);
fields.push_back(new_field);
all_fields_unchanged &= new_field.same_as(field);
}

if (all_fields_unchanged) {
return GetRef<Expr>(op);
} else {
return Tuple(fields, op->span);
}
return WithFields(GetRef<Tuple>(tuple_node), std::move(fields));
}

Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
Expand Down
18 changes: 10 additions & 8 deletions src/relay/transforms/annotate_target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -266,11 +266,11 @@ class AnnotateTargetRewriter : public ExprRewriter {

virtual std::unique_ptr<Call> RewriteVarCall(const Call& post_call) { return nullptr; }

Expr Rewrite_(const TupleNode* op, const Expr& post) override {
auto expr = Downcast<Tuple>(post);
Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) override {
auto tuple = Downcast<Tuple>(post);

auto target_n_args = AnnotateArgs(expr->fields);
auto new_expr = Tuple(std::get<1>(target_n_args));
auto target_n_args = AnnotateArgs(tuple->fields);
auto new_expr = WithFields(std::move(tuple), std::move(std::get<1>(target_n_args)));
op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
return std::move(new_expr);
}
Expand Down Expand Up @@ -370,13 +370,15 @@ class CallOpsTargetRewriter : public AnnotateTargetRewriter {
return new_call;
}

Expr Rewrite_(const TupleNode* op, const Expr& post) override {
auto expr = Downcast<Tuple>(post);
Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) override {
auto tuple = Downcast<Tuple>(post);
Array<Expr> new_fields;
for (auto f : expr->fields) {
new_fields.reserve(tuple->fields.size());

for (auto f : tuple->fields) {
new_fields.push_back(InsertCompilerEndAndPropogateTarget(f));
}
return std::move(Tuple(new_fields));
return WithFields(std::move(tuple), std::move(new_fields));
}

Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) override {
Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/device_planner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -786,8 +786,7 @@ class DeviceCapturer : public ExprMutator {
for (const auto& field : tuple_node->fields) {
fields.push_back(VisitChild(tuple, field));
}
// TODO(mbs): Avoid copy
return Tuple(std::move(fields), tuple_node->span);
return WithFields(std::move(tuple), std::move(fields));
}

Expr VisitExpr_(const FunctionNode* function_node) final {
Expand Down
12 changes: 7 additions & 5 deletions src/relay/transforms/first_order_gradient.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,13 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
return ret;
}

ADValue VisitExpr_(const TupleNode* op) final {
auto tt = Downcast<TupleType>(op->checked_type());
ADValue VisitExpr_(const TupleNode* tuple_node) final {
auto tt = Downcast<TupleType>(tuple_node->checked_type());
std::vector<ADValue> ad_fields;
std::vector<Expr> field_bindings;
for (const auto& f : op->fields) {
Array<Expr> field_bindings;
field_bindings.reserve(tuple_node->fields.size());

for (const auto& f : tuple_node->fields) {
ADValue f_ad = VisitExpr(f);
if (!dynamic_cast<ADTensor*>(f_ad.get())) {
diag_ctx.EmitFatal(Diagnostic::Error(f->span)
Expand All @@ -209,7 +211,7 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
field_bindings.push_back(f_ad->get<ADTensor>().forward);
}
// reconstruct tuple using let-bound variables to avoid duplication
auto orig = Tuple(field_bindings);
auto orig = WithFields(GetRef<Tuple>(tuple_node), std::move(field_bindings));
orig->checked_type_ = tt;
auto ret = std::make_shared<ADTensor>(ll, orig, diag_ctx);
// for orig = tuple(x1, ..., xn), tuple_grad(x1, ..., xn, G) = [pi(G, 1), ..., pi(G, n)]
Expand Down
19 changes: 7 additions & 12 deletions src/relay/transforms/forward_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,21 +113,16 @@ class ForwardRewriter : private MixedModeMutator {
}
}

Expr Rewrite_(const TupleNode* op, const Expr& post) final {
Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) final {
tvm::Array<Expr> fields;
bool all_fields_unchanged = true;
const auto* post_node = post.as<TupleNode>();
for (size_t i = 0; i < op->fields.size(); ++i) {
auto new_field = this->GetTempExpr(op->fields[i], post_node->fields[i]);
fields.push_back(new_field);
all_fields_unchanged &= new_field.same_as(op->fields[i]);
}
fields.reserve(tuple_node->fields.size());

if (all_fields_unchanged) {
return GetRef<Expr>(op);
} else {
return Tuple(fields);
const auto* post_tuple_node = post.as<TupleNode>();
for (size_t i = 0; i < tuple_node->fields.size(); ++i) {
fields.push_back(this->GetTempExpr(tuple_node->fields[i], post_tuple_node->fields[i]));
}

return WithFields(GetRef<Tuple>(tuple_node), std::move(fields));
}

Expr Rewrite_(const CallNode* call_node, const Expr& post) final {
Expand Down
12 changes: 6 additions & 6 deletions src/relay/transforms/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -898,14 +898,14 @@ class FuseMutator : private MixedModeMutator {
}
}

Expr Rewrite_(const TupleNode* tuple, const Expr& post) {
auto* ret_group = gmap_.at(tuple)->FindRoot();
if (ret_group->root_ref == tuple) {
return ExprMutator::VisitExpr_(tuple);
Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) {
auto* ret_group = gmap_.at(tuple_node)->FindRoot();
if (ret_group->root_ref == tuple_node) {
return ExprMutator::VisitExpr_(tuple_node);
}
// This tuple is an intermediate node in the group
Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
return Tuple(new_fields);
Array<Expr> new_fields = GetNewArguments(tuple_node->fields, ret_group);
return WithFields(GetRef<Tuple>(tuple_node), std::move(new_fields));
}

Expr Rewrite_(const TupleGetItemNode* tuple_get, const Expr& post) {
Expand Down
8 changes: 5 additions & 3 deletions src/relay/transforms/memory_alloc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,18 +84,20 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
Function Rewrite(const Function& expr) { return Downcast<Function>(Mutate(expr)); }

private:
Expr VisitExpr_(const TupleNode* tn) final {
Expr VisitExpr_(const TupleNode* tuple_node) final {
LetList& scope = scopes_.back();
Array<Expr> new_fields;
for (auto field : tn->fields) {
new_fields.reserve(tuple_node->fields.size());

for (auto field : tuple_node->fields) {
auto new_field = Mutate(field);
if (new_field->IsInstance<ConstantNode>()) {
Var const_var("const", Type(nullptr));
new_field = scope.Push(const_var, new_field);
}
new_fields.push_back(new_field);
}
return Tuple(new_fields);
return WithFields(GetRef<Tuple>(tuple_node), std::move(new_fields));
}

void PreVisitLetBlock_(const LetNode* let_node) final { scopes_.emplace_back(); }
Expand Down
2 changes: 2 additions & 0 deletions src/relay/transforms/partial_eval.cc
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,8 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
value.push_back(ps);
expr.push_back(ps->dynamic);
}
// Note(@electriclilies): The partial evaluator seems to do some weird stuff with sharing.
// Changing Tuple(expr) to WithFields(op, expr) causes some strange failures.
return HasStatic(MkSTuple(value), ll->Push(Tuple(expr)));
}

Expand Down
8 changes: 4 additions & 4 deletions src/relay/transforms/partition_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -455,18 +455,18 @@ IRModule FlattenTupleOutputs(IRModule module) {
// Arguments of annotation ops should be 1
ICHECK_EQ(call->args.size(), 1U);
auto annotated_op = Downcast<Call>(post)->args[0];
if (const auto* tn = annotated_op.as<TupleNode>()) {
if (const auto* tuple_node = annotated_op.as<TupleNode>()) {
Array<Expr> new_fields;
new_fields.reserve(tuple_node->fields.size());

// Here each input of the tuple will be annotated with compiler_ends
for (auto& tn_arg : tn->fields) {
for (auto& tn_arg : tuple_node->fields) {
new_fields.push_back((*make_end_op)(tn_arg, target));
}

// Return a tuple of compiler_ends in the place of the tuple that was
// annotated with a compiler_end.
auto out = Tuple(new_fields);
return std::move(out);
return WithFields(GetRef<Tuple>(tuple_node), std::move(new_fields));
}
}
return post;
Expand Down
18 changes: 10 additions & 8 deletions src/relay/transforms/split_args.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ class ArgumentSplitter : public ExprRewriter {
Expr Rewrite_(const CallNode* call, const Expr& post) final {
if (max_function_args_ < 0) return post;
if (call->op == concat_op_) {
auto op = call->args[0].as<TupleNode>();
auto tuple_node = call->args[0].as<TupleNode>();
const auto param = call->attrs.as<ConcatenateAttrs>();
int outputsNum = 1;
if (const auto* tuple_type = call->checked_type().as<TupleTypeNode>()) {
outputsNum = tuple_type->fields.size();
}
const int limit = max_function_args_ - outputsNum;
int argsNum = op->fields.size();
int argsNum = tuple_node->fields.size();
if (argsNum < limit) return post;
int splitNum = argsNum / limit;
splitNum = (argsNum % limit) ? splitNum + 1 : splitNum;
Expand All @@ -54,16 +54,18 @@ class ArgumentSplitter : public ExprRewriter {
int startIdx = i * limit;
int argsCount = std::min(limit, argsNum - startIdx);
tvm::Array<Expr> args;
args.reserve(argsCount);

for (int j = 0; j < argsCount; ++j) {
args.push_back(op->fields[j + startIdx]);
args.push_back(tuple_node->fields[j + startIdx]);
}
Tuple tuple(args);
Expr body = MakeConcatenate(tuple, param->axis);
Tuple new_tuple = WithFields(GetRef<Tuple>(tuple_node), std::move(args));
Expr body = MakeConcatenate(new_tuple, param->axis);
splitted[i] = StopFusion(body);
}
tvm::Array<Expr> tupleArgs(splitted);
Tuple tuple(tupleArgs);
return MakeConcatenate(tuple, param->axis);
tvm::Array<Expr> tuple_args(splitted);
Tuple new_tuple = WithFields(GetRef<Tuple>(tuple_node), std::move(tuple_args));
return MakeConcatenate(new_tuple, param->axis);
}
return post;
}
Expand Down
11 changes: 6 additions & 5 deletions src/relay/transforms/to_a_normal_form.cc
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,14 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)>, private transform::Lexi
return Compound(e, Call(VisitExpr(c->op), args, c->attrs, c->type_args), v);
}

Expr VisitExpr_(const TupleNode* t, const Var& v) final {
Expr e = GetRef<Expr>(t);
std::vector<Expr> fields;
for (const auto& a : t->fields) {
Expr VisitExpr_(const TupleNode* tuple_node, const Var& v) final {
Expr e = GetRef<Expr>(tuple_node);
Array<Expr> fields;
fields.reserve(tuple_node->fields.size());
for (const auto& a : tuple_node->fields) {
fields.push_back(VisitExpr(a));
}
return Compound(e, Tuple(fields), v);
return Compound(e, WithFields(GetRef<Tuple>(tuple_node), std::move(fields)), v);
}

Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final {
Expand Down
9 changes: 5 additions & 4 deletions src/relay/transforms/to_cps.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,14 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm, VarMap* vm,
});
}

Expr VisitExpr_(const TupleNode* op, const MCont& k) final {
Expr VisitExpr_(const TupleNode* tuple_node, const MCont& k) final {
tvm::Array<Expr> fields;
fields.reserve(tuple_node->fields.size());
std::function<Expr()> next;
next = [&]() {
return (fields.size() == op->fields.size())
? k(Tuple(fields))
: VisitExpr(op->fields[fields.size()], [&](const Expr& v) {
return (fields.size() == tuple_node->fields.size())
? k(WithFields(GetRef<Tuple>(tuple_node), std::move(fields)))
: VisitExpr(tuple_node->fields[fields.size()], [&](const Expr& v) {
fields.push_back(v);
return next();
});
Expand Down
11 changes: 7 additions & 4 deletions src/relay/transforms/transform_layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

#include "infer_layout_utils.h"
Expand Down Expand Up @@ -293,12 +294,13 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
// NOTE: do not support nested tuple
if (new_arg->IsInstance<TupleNode>()) {
Tuple tuple_new_arg = Downcast<Tuple>(new_arg);
std::vector<Expr> fields;
Array<Expr> fields;
fields.reserve(tuple_new_arg->fields.size());
for (auto x : tuple_new_arg->fields) {
Expr tmp = push_back_one_arg(x);
fields.push_back(tmp);
}
normal_new_args.push_back(Tuple(fields));
normal_new_args.push_back(WithFields(tuple_new_arg, std::move(fields)));
} else {
Expr tmp = push_back_one_arg(new_arg);
normal_new_args.push_back(tmp);
Expand Down Expand Up @@ -375,12 +377,13 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
for (auto arg : new_call->args) {
if (arg->IsInstance<TupleNode>()) { // unflatten tuple
Tuple tuple_arg = Downcast<Tuple>(arg);
std::vector<Expr> transformed_tuple_arg;
Array<Expr> transformed_tuple_arg;
transformed_tuple_arg.reserve(tuple_arg->fields.size());
for (auto arg_item : tuple_arg->fields) {
transformed_tuple_arg.push_back(memorizer.Transform(arg_item, new_in[pt], new_in2[pt]));
pt++;
}
transformed_args.push_back(Tuple(transformed_tuple_arg));
transformed_args.push_back(WithFields(tuple_arg, std::move(transformed_tuple_arg)));
} else {
transformed_args.push_back(memorizer.Transform(arg, new_in[pt], new_in2[pt]));
pt++;
Expand Down

0 comments on commit 3c48cad

Please sign in to comment.