Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WithFields for Tuples #9533

Merged
merged 20 commits into from
Nov 24, 2021
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,20 @@ class Tuple : public Expr {
TVM_DLL explicit Tuple(tvm::Array<relay::Expr> fields, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(Tuple, RelayExpr, TupleNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode);
};

/*!
* \brief Returns the tuple with given properties. A null property denotes 'no change'.
* Returns this if all properties are unchanged. Otherwise, returns a copy with the new fields.
* \param tuple The tuple to copy
* \param opt_fields The (optional) fields for the copied tuple. If none, ret_tuple->fields =
* tuple->fields.
* \param opt_span The (optional) span for the copied tuple. If none, ret_tuple->span = tuple->span.
*/
Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields = Optional<Array<Expr>>(),
Optional<Span> opt_span = Optional<Span>(nullptr));

/*!
* \brief Local variables used in the let expression.
*
Expand Down
23 changes: 23 additions & 0 deletions src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,29 @@ TVM_REGISTER_NODE_TYPE(TupleNode);
TVM_REGISTER_GLOBAL("relay.ir.Tuple").set_body_typed([](tvm::Array<relay::Expr> fields, Span span) {
return Tuple(fields, span);
});
Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields, Optional<Span> opt_span) {
Array<Expr> fields = opt_fields.value_or(tuple->fields);
Span span = opt_span.value_or(tuple->span);

bool all_fields_unchanged = true;
if (fields.size() == tuple->fields.size()) {
for (uint i = 0; i < fields.size(); i++) {
all_fields_unchanged &= fields[i].same_as(tuple->fields[i]);
}
} else {
all_fields_unchanged = false;
}

all_fields_unchanged = all_fields_unchanged && span.same_as(tuple->span);
if (all_fields_unchanged) {
return std::move(tuple);
} else {
TupleNode* cow_tuple_node = tuple.CopyOnWrite();
cow_tuple_node->fields = fields;
cow_tuple_node->span = span;
return GetRef<Tuple>(cow_tuple_node);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit, you can directly return tuple here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for all your help with the subtleties!

}
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TupleNode>([](const ObjectRef& ref, ReprPrinter* p) {
Expand Down
15 changes: 5 additions & 10 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,20 +177,15 @@ Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) { return GetRef<Expr>(op);

Expr ExprMutator::VisitExpr_(const OpNode* op) { return GetRef<Expr>(op); }

Expr ExprMutator::VisitExpr_(const TupleNode* op) {
Expr ExprMutator::VisitExpr_(const TupleNode* tuple_node) {
tvm::Array<Expr> fields;
bool all_fields_unchanged = true;
for (auto field : op->fields) {
fields.reserve(tuple_node->fields.size());

for (auto field : tuple_node->fields) {
auto new_field = this->Mutate(field);
fields.push_back(new_field);
all_fields_unchanged &= new_field.same_as(field);
}

if (all_fields_unchanged) {
return GetRef<Expr>(op);
} else {
return Tuple(fields, op->span);
}
return WithFields(GetRef<Tuple>(tuple_node), std::move(fields));
}

Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
Expand Down
18 changes: 10 additions & 8 deletions src/relay/transforms/annotate_target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -266,11 +266,11 @@ class AnnotateTargetRewriter : public ExprRewriter {

virtual std::unique_ptr<Call> RewriteVarCall(const Call& post_call) { return nullptr; }

Expr Rewrite_(const TupleNode* op, const Expr& post) override {
auto expr = Downcast<Tuple>(post);
Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) override {
auto tuple = Downcast<Tuple>(post);

auto target_n_args = AnnotateArgs(expr->fields);
auto new_expr = Tuple(std::get<1>(target_n_args));
auto target_n_args = AnnotateArgs(tuple->fields);
auto new_expr = WithFields(std::move(tuple), std::move(std::get<1>(target_n_args)));
op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
return std::move(new_expr);
}
Expand Down Expand Up @@ -370,13 +370,15 @@ class CallOpsTargetRewriter : public AnnotateTargetRewriter {
return new_call;
}

Expr Rewrite_(const TupleNode* op, const Expr& post) override {
auto expr = Downcast<Tuple>(post);
Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) override {
auto tuple = Downcast<Tuple>(post);
Array<Expr> new_fields;
for (auto f : expr->fields) {
new_fields.reserve(tuple->fields.size());

for (auto f : tuple->fields) {
new_fields.push_back(InsertCompilerEndAndPropogateTarget(f));
}
return std::move(Tuple(new_fields));
return WithFields(std::move(tuple), std::move(new_fields));
}

Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) override {
Expand Down
3 changes: 1 addition & 2 deletions src/relay/transforms/device_planner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -786,8 +786,7 @@ class DeviceCapturer : public ExprMutator {
for (const auto& field : tuple_node->fields) {
fields.push_back(VisitChild(tuple, field));
}
// TODO(mbs): Avoid copy
return Tuple(std::move(fields), tuple_node->span);
return WithFields(std::move(tuple), std::move(fields));
}

Expr VisitExpr_(const FunctionNode* function_node) final {
Expand Down
12 changes: 7 additions & 5 deletions src/relay/transforms/first_order_gradient.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,13 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
return ret;
}

ADValue VisitExpr_(const TupleNode* op) final {
auto tt = Downcast<TupleType>(op->checked_type());
ADValue VisitExpr_(const TupleNode* tuple_node) final {
auto tt = Downcast<TupleType>(tuple_node->checked_type());
std::vector<ADValue> ad_fields;
std::vector<Expr> field_bindings;
for (const auto& f : op->fields) {
Array<Expr> field_bindings;
field_bindings.reserve(tuple_node->fields.size());

for (const auto& f : tuple_node->fields) {
ADValue f_ad = VisitExpr(f);
if (!dynamic_cast<ADTensor*>(f_ad.get())) {
diag_ctx.EmitFatal(Diagnostic::Error(f->span)
Expand All @@ -209,7 +211,7 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
field_bindings.push_back(f_ad->get<ADTensor>().forward);
}
// reconstruct tuple using let-bound variables to avoid duplication
auto orig = Tuple(field_bindings);
auto orig = WithFields(GetRef<Tuple>(tuple_node), std::move(field_bindings));
orig->checked_type_ = tt;
auto ret = std::make_shared<ADTensor>(ll, orig, diag_ctx);
// for orig = tuple(x1, ..., xn), tuple_grad(x1, ..., xn, G) = [pi(G, 1), ..., pi(G, n)]
Expand Down
19 changes: 7 additions & 12 deletions src/relay/transforms/forward_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,21 +113,16 @@ class ForwardRewriter : private MixedModeMutator {
}
}

Expr Rewrite_(const TupleNode* op, const Expr& post) final {
Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) final {
tvm::Array<Expr> fields;
bool all_fields_unchanged = true;
const auto* post_node = post.as<TupleNode>();
for (size_t i = 0; i < op->fields.size(); ++i) {
auto new_field = this->GetTempExpr(op->fields[i], post_node->fields[i]);
fields.push_back(new_field);
all_fields_unchanged &= new_field.same_as(op->fields[i]);
}
fields.reserve(tuple_node->fields.size());

if (all_fields_unchanged) {
return GetRef<Expr>(op);
} else {
return Tuple(fields);
const auto* post_tuple_node = post.as<TupleNode>();
for (size_t i = 0; i < tuple_node->fields.size(); ++i) {
fields.push_back(this->GetTempExpr(tuple_node->fields[i], post_tuple_node->fields[i]));
}

return WithFields(GetRef<Tuple>(tuple_node), std::move(fields));
}

Expr Rewrite_(const CallNode* call_node, const Expr& post) final {
Expand Down
12 changes: 6 additions & 6 deletions src/relay/transforms/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -898,14 +898,14 @@ class FuseMutator : private MixedModeMutator {
}
}

Expr Rewrite_(const TupleNode* tuple, const Expr& post) {
auto* ret_group = gmap_.at(tuple)->FindRoot();
if (ret_group->root_ref == tuple) {
return ExprMutator::VisitExpr_(tuple);
Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) {
auto* ret_group = gmap_.at(tuple_node)->FindRoot();
if (ret_group->root_ref == tuple_node) {
return ExprMutator::VisitExpr_(tuple_node);
}
// This tuple is an intermediate node in the group
Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
return Tuple(new_fields);
Array<Expr> new_fields = GetNewArguments(tuple_node->fields, ret_group);
return WithFields(GetRef<Tuple>(tuple_node), std::move(new_fields));
}

Expr Rewrite_(const TupleGetItemNode* tuple_get, const Expr& post) {
Expand Down
8 changes: 5 additions & 3 deletions src/relay/transforms/memory_alloc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,18 +84,20 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
Function Rewrite(const Function& expr) { return Downcast<Function>(Mutate(expr)); }

private:
Expr VisitExpr_(const TupleNode* tn) final {
Expr VisitExpr_(const TupleNode* tuple_node) final {
LetList& scope = scopes_.back();
Array<Expr> new_fields;
for (auto field : tn->fields) {
new_fields.reserve(tuple_node->fields.size());

for (auto field : tuple_node->fields) {
auto new_field = Mutate(field);
if (new_field->IsInstance<ConstantNode>()) {
Var const_var("const", Type(nullptr));
new_field = scope.Push(const_var, new_field);
}
new_fields.push_back(new_field);
}
return Tuple(new_fields);
return WithFields(GetRef<Tuple>(tuple_node), std::move(new_fields));
}

void PreVisitLetBlock_(const LetNode* let_node) final { scopes_.emplace_back(); }
Expand Down
12 changes: 7 additions & 5 deletions src/relay/transforms/partial_eval.cc
Original file line number Diff line number Diff line change
Expand Up @@ -607,15 +607,17 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
return HasStatic(MkSTensor(op->data.CopyTo(device_)), ll->Push(GetRef<Expr>(op)));
}

PStatic VisitExpr_(const TupleNode* op, LetList* ll) final {
PStatic VisitExpr_(const TupleNode* tuple_node, LetList* ll) final {
std::vector<PStatic> value;
tvm::Array<Expr> expr;
for (const Expr& e : op->fields) {
tvm::Array<Expr> new_fields;
new_fields.reserve(tuple_node->fields.size());
for (const Expr& e : tuple_node->fields) {
PStatic ps = VisitExpr(e, ll);
value.push_back(ps);
expr.push_back(ps->dynamic);
new_fields.push_back(ps->dynamic);
}
return HasStatic(MkSTuple(value), ll->Push(Tuple(expr)));
return HasStatic(MkSTuple(value),
ll->Push(WithFields(GetRef<Tuple>(tuple_node), std::move(new_fields))));
}

PStatic VisitExpr_(const TupleGetItemNode* op, LetList* ll) final {
Expand Down
8 changes: 4 additions & 4 deletions src/relay/transforms/partition_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -455,18 +455,18 @@ IRModule FlattenTupleOutputs(IRModule module) {
// Arguments of annotation ops should be 1
ICHECK_EQ(call->args.size(), 1U);
auto annotated_op = Downcast<Call>(post)->args[0];
if (const auto* tn = annotated_op.as<TupleNode>()) {
if (const auto* tuple_node = annotated_op.as<TupleNode>()) {
Array<Expr> new_fields;
new_fields.reserve(tuple_node->fields.size());

// Here each input of the tuple will be annotated with compiler_ends
for (auto& tn_arg : tn->fields) {
for (auto& tn_arg : tuple_node->fields) {
new_fields.push_back((*make_end_op)(tn_arg, target));
}

// Return a tuple of compiler_ends in the place of the tuple that was
// annotated with a compiler_end.
auto out = Tuple(new_fields);
return std::move(out);
return WithFields(GetRef<Tuple>(tuple_node), std::move(new_fields));
}
}
return post;
Expand Down
18 changes: 10 additions & 8 deletions src/relay/transforms/split_args.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ class ArgumentSplitter : public ExprRewriter {
Expr Rewrite_(const CallNode* call, const Expr& post) final {
if (max_function_args_ < 0) return post;
if (call->op == concat_op_) {
auto op = call->args[0].as<TupleNode>();
auto tuple_node = call->args[0].as<TupleNode>();
const auto param = call->attrs.as<ConcatenateAttrs>();
int outputsNum = 1;
if (const auto* tuple_type = call->checked_type().as<TupleTypeNode>()) {
outputsNum = tuple_type->fields.size();
}
const int limit = max_function_args_ - outputsNum;
int argsNum = op->fields.size();
int argsNum = tuple_node->fields.size();
if (argsNum < limit) return post;
int splitNum = argsNum / limit;
splitNum = (argsNum % limit) ? splitNum + 1 : splitNum;
Expand All @@ -54,16 +54,18 @@ class ArgumentSplitter : public ExprRewriter {
int startIdx = i * limit;
int argsCount = std::min(limit, argsNum - startIdx);
tvm::Array<Expr> args;
args.reserve(argsCount);

for (int j = 0; j < argsCount; ++j) {
args.push_back(op->fields[j + startIdx]);
args.push_back(tuple_node->fields[j + startIdx]);
}
Tuple tuple(args);
Expr body = MakeConcatenate(tuple, param->axis);
Tuple new_tuple = WithFields(GetRef<Tuple>(tuple_node), std::move(args));
Expr body = MakeConcatenate(new_tuple, param->axis);
splitted[i] = StopFusion(body);
}
tvm::Array<Expr> tupleArgs(splitted);
Tuple tuple(tupleArgs);
return MakeConcatenate(tuple, param->axis);
tvm::Array<Expr> tuple_args(splitted);
Tuple new_tuple = WithFields(GetRef<Tuple>(tuple_node), std::move(tuple_args));
return MakeConcatenate(new_tuple, param->axis);
}
return post;
}
Expand Down
11 changes: 6 additions & 5 deletions src/relay/transforms/to_a_normal_form.cc
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,14 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)>, private transform::Lexi
return Compound(e, Call(VisitExpr(c->op), args, c->attrs, c->type_args), v);
}

Expr VisitExpr_(const TupleNode* t, const Var& v) final {
Expr e = GetRef<Expr>(t);
std::vector<Expr> fields;
for (const auto& a : t->fields) {
Expr VisitExpr_(const TupleNode* tuple_node, const Var& v) final {
Expr e = GetRef<Expr>(tuple_node);
Array<Expr> fields;
fields.reserve(tuple_node->fields.size());
for (const auto& a : tuple_node->fields) {
fields.push_back(VisitExpr(a));
}
return Compound(e, Tuple(fields), v);
return Compound(e, WithFields(GetRef<Tuple>(tuple_node), std::move(fields)), v);
}

Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final {
Expand Down
9 changes: 5 additions & 4 deletions src/relay/transforms/to_cps.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,14 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm, VarMap* vm,
});
}

Expr VisitExpr_(const TupleNode* op, const MCont& k) final {
Expr VisitExpr_(const TupleNode* tuple_node, const MCont& k) final {
tvm::Array<Expr> fields;
fields.reserve(tuple_node->fields.size());
std::function<Expr()> next;
next = [&]() {
return (fields.size() == op->fields.size())
? k(Tuple(fields))
: VisitExpr(op->fields[fields.size()], [&](const Expr& v) {
return (fields.size() == tuple_node->fields.size())
? k(WithFields(GetRef<Tuple>(tuple_node), std::move(fields)))
: VisitExpr(tuple_node->fields[fields.size()], [&](const Expr& v) {
fields.push_back(v);
return next();
});
Expand Down
11 changes: 7 additions & 4 deletions src/relay/transforms/transform_layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

#include "infer_layout_utils.h"
Expand Down Expand Up @@ -293,12 +294,13 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
// NOTE: do not support nested tuple
if (new_arg->IsInstance<TupleNode>()) {
Tuple tuple_new_arg = Downcast<Tuple>(new_arg);
std::vector<Expr> fields;
Array<Expr> fields;
fields.reserve(tuple_new_arg->fields.size());
for (auto x : tuple_new_arg->fields) {
Expr tmp = push_back_one_arg(x);
fields.push_back(tmp);
}
normal_new_args.push_back(Tuple(fields));
normal_new_args.push_back(WithFields(tuple_new_arg, std::move(fields)));
} else {
Expr tmp = push_back_one_arg(new_arg);
normal_new_args.push_back(tmp);
Expand Down Expand Up @@ -375,12 +377,13 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
for (auto arg : new_call->args) {
if (arg->IsInstance<TupleNode>()) { // unflatten tuple
Tuple tuple_arg = Downcast<Tuple>(arg);
std::vector<Expr> transformed_tuple_arg;
Array<Expr> transformed_tuple_arg;
transformed_tuple_arg.reserve(tuple_arg->fields.size());
for (auto arg_item : tuple_arg->fields) {
transformed_tuple_arg.push_back(memorizer.Transform(arg_item, new_in[pt], new_in2[pt]));
pt++;
}
transformed_args.push_back(Tuple(transformed_tuple_arg));
transformed_args.push_back(WithFields(tuple_arg, std::move(transformed_tuple_arg)));
} else {
transformed_args.push_back(memorizer.Transform(arg, new_in[pt], new_in2[pt]));
pt++;
Expand Down