Skip to content

Commit

Permalink
BUG apache#8013: Remove register_alter_op_layout example from dev/use…
Browse files Browse the repository at this point in the history
…_pass_infra.py

This tutorial registers a global layout transformation for conv2d for all
targets which is not well-formed. Later uses of conv2d in the tutorials
pick that layout up then assert fail in the conv2d type-relation.

Better would be to register a transform for an entirely fake target, but
that is beyond my current level of expertise.

In general our use of sphinx/sphinx_gallery for running and rendering the
tutorials is highly suspect since there is no inter-example isolation:
 - Examples using tensorflow will gobble up GPU memory and not give it back.
 - Any examples which use any of the (many!) global registration mechanisms
   need to ensure the registrant is safe across all tutorials.
I recall seeing a thread with the sphinx_gallery where they said they'd prefer
not to work on process-level isolation, but it's probably worth pinging again.

While digging into this I noticed we had a slicing cast in AlterOpLayout due
to a derived class of ObjectRef introducing virtuals. I moved the virtuals to
the corresponding Node classes. In this case we got away with it since the
ObjectRef happened to not get copied but we were on very thin ice.
  • Loading branch information
mbs-octoml committed Sep 22, 2021
1 parent 4c8531d commit 0c85654
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 83 deletions.
41 changes: 23 additions & 18 deletions src/relay/transforms/alter_op_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,6 @@ namespace alter_op_layout {
class AlterTransformMemorizerNode : public TransformMemorizerNode {
public:
static constexpr const char* _type_key = "relay.alter_op_layout.AlterTransformMemorizerNode";
};

/*!
* \brief Container that provides the transformation function for alter layout..
*/
class AlterTransformMemorizer : public TransformMemorizer {
public:
AlterTransformMemorizer() {}
explicit AlterTransformMemorizer(ObjectPtr<Object> n) : TransformMemorizer(n) {}

AlterTransformMemorizerNode* operator->() {
return static_cast<AlterTransformMemorizerNode*>(get_mutable());
}

/*!
* \brief Defines the call transformation for AlterOpLayout pass. The new layouts are defined by
Expand Down Expand Up @@ -102,7 +89,23 @@ class AlterTransformMemorizer : public TransformMemorizer {
return GetRef<Call>(new_call);
}

using TransformMemorizer::CallWithNewLayouts;
Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) override {
return CallWithNewLayouts(ref_call, ref_call->attrs, new_args);
}
};

/*!
* \brief Container that provides the transformation function for alter layout..
*/
class AlterTransformMemorizer : public TransformMemorizer {
public:
AlterTransformMemorizer() = default;
explicit AlterTransformMemorizer(ObjectPtr<Object> n) : TransformMemorizer(n) {}

AlterTransformMemorizerNode* operator->() {
return static_cast<AlterTransformMemorizerNode*>(get_mutable());
}

using ContainerType = AlterTransformMemorizerNode;
};

Expand All @@ -113,10 +116,12 @@ class AlterTransformMemorizer : public TransformMemorizer {
*/
Expr AlterOpLayout(const Expr& expr) {
// TODO(@icemelon9): need to rerun type inference after applying an alter op.
AlterTransformMemorizer alterMemorizer(make_object<AlterTransformMemorizerNode>());
auto fcontext = [&](const Call& call) -> ObjectRef { return alterMemorizer; };

return ForwardRewrite(expr, LayoutRewriter<AlterTransformMemorizer>, fcontext);
AlterTransformMemorizer alter_memorizer(make_object<AlterTransformMemorizerNode>());
std::function<ObjectRef(const Call&)> fcontext = [=](const Call& call) -> ObjectRef {
return alter_memorizer;
};
FForwardRewrite rewrite_func = LayoutRewriter<AlterTransformMemorizer>;
return ForwardRewrite(expr, rewrite_func, fcontext);
}

} // namespace alter_op_layout
Expand Down
39 changes: 21 additions & 18 deletions src/relay/transforms/convert_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,22 +58,6 @@ class ConvertTransformMemorizerNode : public TransformMemorizerNode {
explicit ConvertTransformMemorizerNode(Map<String, Array<String>> desired_layouts)
: desired_layouts_(std::move(desired_layouts)) {}

/*! \brief A mapping of op_name to array of desired layouts for each input. */
Map<String, Array<String>> desired_layouts_;
};

/*!
* \brief Container that provides the transformation function for convert layout.
*/
class ConvertTransformMemorizer : public TransformMemorizer {
public:
ConvertTransformMemorizer() {}
explicit ConvertTransformMemorizer(ObjectPtr<Object> n) : TransformMemorizer(n) {}

ConvertTransformMemorizerNode* operator->() {
return static_cast<ConvertTransformMemorizerNode*>(get_mutable());
}

/*!
* \brief Defines the call transformation for ConvertLayout pass. The new layouts should be the
* desired layout as specified by the user.
Expand All @@ -89,7 +73,7 @@ class ConvertTransformMemorizer : public TransformMemorizer {
Expr new_e;
bool modified = false;
if (fconvert_layout.count(op)) {
auto desired_layouts = operator->()->desired_layouts_;
auto desired_layouts = desired_layouts_;
if (desired_layouts.find(op->name) != desired_layouts.end()) {
tvm::Array<tvm::te::Tensor> tinfos;
for (auto& expr : ref_call->args) {
Expand Down Expand Up @@ -124,7 +108,26 @@ class ConvertTransformMemorizer : public TransformMemorizer {
return Call(new_call->op, new_call->args, new_call->attrs, new_call->type_args, ref_call->span);
}

using TransformMemorizer::CallWithNewLayouts;
Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) override {
return CallWithNewLayouts(ref_call, ref_call->attrs, new_args);
}

/*! \brief A mapping of op_name to array of desired layouts for each input. */
Map<String, Array<String>> desired_layouts_;
};

/*!
* \brief Container that provides the transformation function for convert layout.
*/
class ConvertTransformMemorizer : public TransformMemorizer {
public:
ConvertTransformMemorizer() = default;
explicit ConvertTransformMemorizer(ObjectPtr<Object> n) : TransformMemorizer(n) {}

ConvertTransformMemorizerNode* operator->() {
return static_cast<ConvertTransformMemorizerNode*>(get_mutable());
}

using ContainerType = ConvertTransformMemorizerNode;
};

Expand Down
36 changes: 18 additions & 18 deletions src/relay/transforms/transform_layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,21 @@ class TransformMemorizerNode : public Object {
}
};

/*!
* \brief Defines the call transformation for derived passes. The new layouts are defined by
* used for different targets using a packed func.
* \param ref_call The original call.
* \param new_attrs Updated attributes consistent with new layouts.
* \param new_args The traversed/recursed args to the call.
* \return The new Call after calling the packed func.
*/
virtual Call CallWithNewLayouts(const Call& ref_call, Attrs new_attrs,
const std::vector<Expr>& new_args) = 0;

virtual Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) {
return CallWithNewLayouts(ref_call, ref_call->attrs, new_args);
}

/*! \brief The memorizer map. */
std::unordered_map<TransformKey, Expr, key_hash> memo;

Expand All @@ -69,11 +84,9 @@ class TransformMemorizerNode : public Object {
*/
class TransformMemorizer : public ObjectRef {
public:
TransformMemorizer() {}
TransformMemorizer() = default;
explicit TransformMemorizer(ObjectPtr<Object> n) : ObjectRef(n) {}

virtual ~TransformMemorizer() {}

TransformMemorizerNode* operator->() {
return static_cast<TransformMemorizerNode*>(get_mutable());
}
Expand Down Expand Up @@ -146,19 +159,6 @@ class TransformMemorizer : public ObjectRef {
return MakeLayoutTransform(input_expr, new_src_layout.name(), dst_layout.name());
}

/*!
* \brief Defines the call transformation for derived passes. The new layouts are defined by
* used for different targets using a packed func.
* \param ref_call The original call.
* \param new_attrs Updated attributes consistent with new layouts.
* \param new_args The traversed/recursed args to the call.
* \return The new Call after calling the packed func.
*/
virtual Call CallWithNewLayouts(const Call& ref_call, Attrs new_attrs,
const std::vector<Expr>& new_args) = 0;
virtual Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) {
return CallWithNewLayouts(ref_call, ref_call->attrs, new_args);
}
using ContainerType = TransformMemorizerNode;
};

Expand Down Expand Up @@ -312,7 +312,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
if (ref_call->op.as<OpNode>()) {
Op op = Downcast<Op>(ref_call->op);
if (falter_layout.count(op) && !finfer_layout.count(op)) {
return memorizer.CallWithNewLayouts(ref_call, normal_new_args);
return memorizer->CallWithNewLayouts(ref_call, normal_new_args);
}
}
}
Expand Down Expand Up @@ -349,7 +349,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
}

// new_op = alter(op)
Call new_call = memorizer.CallWithNewLayouts(ref_call, infer_out->new_attrs, normal_new_args);
Call new_call = memorizer->CallWithNewLayouts(ref_call, infer_out->new_attrs, normal_new_args);

// new_in2, new_out = op.infer(new_in)
if (new_call->op->IsInstance<OpNode>()) {
Expand Down
29 changes: 0 additions & 29 deletions tutorials/dev/use_pass_infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,20 +69,6 @@ def example():
return relay.Function([x, weight], z2)


###############################################################################
# Let us register layout alteration for a conv2d op so that we can apply the
# layout alteration pass on the example. How alter layout pass works is out
# the scope of this tutorial.


@relay.op.register_alter_op_layout("nn.conv2d", level=101)
def alter_conv2d(attrs, inputs, tinfos, out_type):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs["data_layout"] = "NCHW16c"
return relay.nn.conv2d(data, weight, **new_attrs)


###############################################################################
# Optimize the Program
# --------------------
Expand Down Expand Up @@ -188,21 +174,6 @@ def alter_conv2d(attrs, inputs, tinfos, out_type):
mod3 = seq(mod)
print(mod3)

###############################################################################
# The passes applied so far are target independent. The pass infra also
# provides a means to make pass target-aware. For example, the layout
# alteration pass falls in such category.

with tvm.transform.PassContext(opt_level=3):
mod4 = seq(mod)
print(mod4)

seq1 = tvm.transform.Sequential([relay.transform.AlterOpLayout()])
with tvm.transform.PassContext(opt_level=3):
with tvm.target.Target("llvm"):
mod5 = seq1(mod)
print(mod5)

##############################################################################
# Implement a Pass Using Python Decorator
# ------------------------------------------
Expand Down

0 comments on commit 0c85654

Please sign in to comment.