From e8872866acd82e24fb20cd86bcc6fbdb409eec47 Mon Sep 17 00:00:00 2001
From: Mark Shields <87091372+mbs-octoml@users.noreply.github.com>
Date: Thu, 23 Sep 2021 09:41:13 -0700
Subject: [PATCH] BUG #8013: Remove register_alter_op_layout example from
 dev/use_pass_infra.py (#9076)

* BUG #8013: Remove register_alter_op_layout example from dev/use_pass_infra.py

This tutorial registers a global layout transformation for conv2d, for
all targets, which is not well-formed. Later uses of conv2d in the
tutorials pick that layout up and then fail an assert in the conv2d
type relation. Better would be to register the transform for an
entirely fake target, but that is beyond my current level of expertise.

In general our use of sphinx/sphinx_gallery for running and rendering
the tutorials is highly suspect, since there is no inter-example
isolation:
- Examples using tensorflow will gobble up GPU memory and not give it
  back.
- Any example which uses one of the (many!) global registration
  mechanisms must ensure the registration is safe across all tutorials
  (see the sketch below).

I recall a thread with the sphinx_gallery maintainers in which they
said they'd prefer not to work on process-level isolation, but it is
probably worth pinging them again.
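For illustration only (not part of this change): one way a tutorial
could clean up after itself under the current non-isolated setup. This
is a sketch; it assumes, as TVM's unit tests do via their TempOpAttr
helper, that Op.reset_attr drops the FTVMAlterOpLayout attribute
installed by register_alter_op_layout. alter_conv2d is the tutorial's
own example function.

    from tvm import relay

    @relay.op.register_alter_op_layout("nn.conv2d", level=101)
    def alter_conv2d(attrs, inputs, tinfos, out_type):
        # Force a blocked layout, as the removed tutorial example did.
        data, weight = inputs
        new_attrs = dict(attrs)
        new_attrs["data_layout"] = "NCHW16c"
        return relay.nn.conv2d(data, weight, **new_attrs)

    # ... run the demonstration that needs the custom layout ...

    # Drop the global registration so later tutorials see the default
    # conv2d lowering again (assumes reset_attr behaves as in the unit
    # tests).
    relay.op.get("nn.conv2d").reset_attr("FTVMAlterOpLayout")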
While digging into this I noticed we had a slicing cast in
AlterOpLayout, due to a derived class of ObjectRef introducing
virtuals. I moved the virtuals to the corresponding Node classes. In
this case we got away with it, since the ObjectRef happened not to get
copied, but we were on very thin ice; a minimal illustration of the
hazard follows.
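For illustration only (hypothetical Handle/DerivedHandle classes, not
TVM types): copying a handle that declares virtuals into a base-typed
variable slices off the derived part, so virtual calls on the copy
dispatch to the base implementation.

    #include <iostream>

    // Handle plays the role of an ObjectRef-style handle class that
    // (wrongly) declares virtual methods on the handle itself.
    struct Handle {
      virtual ~Handle() = default;
      virtual const char* CallWithNewLayouts() const { return "base"; }
    };

    struct DerivedHandle : public Handle {
      const char* CallWithNewLayouts() const override { return "derived"; }
    };

    int main() {
      DerivedHandle derived;
      Handle sliced = derived;  // copy-initialization slices the object
      std::cout << sliced.CallWithNewLayouts() << "\n";  // prints "base"
      return 0;
    }

Keeping the virtuals on the Node class avoids this, because the node is
always held behind a pointer inside the ObjectRef, so dispatch always
goes through the un-sliced object.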
* [checkpoint] Whoops, forgot there was an extra AlterOpLayout

I should have run it locally; there goes 6hrs of CI.
---
 src/relay/transforms/alter_op_layout.cc | 41 ++++++++++++++-----------
 src/relay/transforms/convert_layout.cc  | 39 ++++++++++++-----------
 src/relay/transforms/transform_layout.h | 36 +++++++++++-----------
 tutorials/dev/use_pass_infra.py         | 30 ------------------
 4 files changed, 62 insertions(+), 84 deletions(-)

diff --git a/src/relay/transforms/alter_op_layout.cc b/src/relay/transforms/alter_op_layout.cc
index 9afdb7210cba..f347eddae760 100644
--- a/src/relay/transforms/alter_op_layout.cc
+++ b/src/relay/transforms/alter_op_layout.cc
@@ -50,19 +50,6 @@ namespace alter_op_layout {
 class AlterTransformMemorizerNode : public TransformMemorizerNode {
  public:
   static constexpr const char* _type_key = "relay.alter_op_layout.AlterTransformMemorizerNode";
-};
-
-/*!
- * \brief Container that provides the transformation function for alter layout..
- */
-class AlterTransformMemorizer : public TransformMemorizer {
- public:
-  AlterTransformMemorizer() {}
-  explicit AlterTransformMemorizer(ObjectPtr<Object> n) : TransformMemorizer(n) {}
-
-  AlterTransformMemorizerNode* operator->() {
-    return static_cast<AlterTransformMemorizerNode*>(get_mutable());
-  }
 
   /*!
    * \brief Defines the call transformation for AlterOpLayout pass. The new layouts are defined by
@@ -102,7 +89,23 @@ class AlterTransformMemorizer : public TransformMemorizer {
     return GetRef<Call>(new_call);
   }
 
-  using TransformMemorizer::CallWithNewLayouts;
+  Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) override {
+    return CallWithNewLayouts(ref_call, ref_call->attrs, new_args);
+  }
+};
+
+/*!
+ * \brief Container that provides the transformation function for alter layout.
+ */
+class AlterTransformMemorizer : public TransformMemorizer {
+ public:
+  AlterTransformMemorizer() = default;
+  explicit AlterTransformMemorizer(ObjectPtr<Object> n) : TransformMemorizer(n) {}
+
+  AlterTransformMemorizerNode* operator->() {
+    return static_cast<AlterTransformMemorizerNode*>(get_mutable());
+  }
+  using ContainerType = AlterTransformMemorizerNode;
 };
 
 /*!
@@ -113,10 +116,12 @@ class AlterTransformMemorizer : public TransformMemorizer {
  */
 Expr AlterOpLayout(const Expr& expr) {
   // TODO(@icemelon9): need to rerun type inference after applying an alter op.
-  AlterTransformMemorizer alterMemorizer(make_object<AlterTransformMemorizerNode>());
-  auto fcontext = [&](const Call& call) -> ObjectRef { return alterMemorizer; };
-
-  return ForwardRewrite(expr, LayoutRewriter<AlterTransformMemorizer>, fcontext);
+  AlterTransformMemorizer alter_memorizer(make_object<AlterTransformMemorizerNode>());
+  std::function<ObjectRef(const Call&)> fcontext = [=](const Call& call) -> ObjectRef {
+    return alter_memorizer;
+  };
+  FForwardRewrite rewrite_func = LayoutRewriter<AlterTransformMemorizer>;
+  return ForwardRewrite(expr, rewrite_func, fcontext);
 }
 
 }  // namespace alter_op_layout
diff --git a/src/relay/transforms/convert_layout.cc b/src/relay/transforms/convert_layout.cc
index e74ea0115857..e10be508529e 100644
--- a/src/relay/transforms/convert_layout.cc
+++ b/src/relay/transforms/convert_layout.cc
@@ -58,22 +58,6 @@ class ConvertTransformMemorizerNode : public TransformMemorizerNode {
   explicit ConvertTransformMemorizerNode(Map<String, Array<String>> desired_layouts)
       : desired_layouts_(std::move(desired_layouts)) {}
 
-  /*! \brief A mapping of op_name to array of desired layouts for each input. */
-  Map<String, Array<String>> desired_layouts_;
-};
-
-/*!
- * \brief Container that provides the transformation function for convert layout.
- */
-class ConvertTransformMemorizer : public TransformMemorizer {
- public:
-  ConvertTransformMemorizer() {}
-  explicit ConvertTransformMemorizer(ObjectPtr<Object> n) : TransformMemorizer(n) {}
-
-  ConvertTransformMemorizerNode* operator->() {
-    return static_cast<ConvertTransformMemorizerNode*>(get_mutable());
-  }
-
   /*!
    * \brief Defines the call transformation for ConvertLayout pass. The new layouts should be the
    * desired layout as specified by the user.
@@ -89,7 +73,7 @@ class ConvertTransformMemorizer : public TransformMemorizer {
     Expr new_e;
     bool modified = false;
     if (fconvert_layout.count(op)) {
-      auto desired_layouts = operator->()->desired_layouts_;
+      auto desired_layouts = desired_layouts_;
       if (desired_layouts.find(op->name) != desired_layouts.end()) {
         tvm::Array<tvm::te::Tensor> tinfos;
         for (auto& expr : ref_call->args) {
@@ -124,7 +108,26 @@ class ConvertTransformMemorizer : public TransformMemorizer {
     return Call(new_call->op, new_call->args, new_call->attrs, new_call->type_args,
                 ref_call->span);
   }
 
-  using TransformMemorizer::CallWithNewLayouts;
+  Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) override {
+    return CallWithNewLayouts(ref_call, ref_call->attrs, new_args);
+  }
+
+  /*! \brief A mapping of op_name to array of desired layouts for each input. */
+  Map<String, Array<String>> desired_layouts_;
+};
+
+/*!
+ * \brief Container that provides the transformation function for convert layout.
+ */
+class ConvertTransformMemorizer : public TransformMemorizer {
+ public:
+  ConvertTransformMemorizer() = default;
+  explicit ConvertTransformMemorizer(ObjectPtr<Object> n) : TransformMemorizer(n) {}
+
+  ConvertTransformMemorizerNode* operator->() {
+    return static_cast<ConvertTransformMemorizerNode*>(get_mutable());
+  }
+  using ContainerType = ConvertTransformMemorizerNode;
 };
 
diff --git a/src/relay/transforms/transform_layout.h b/src/relay/transforms/transform_layout.h
index fbb7bc9cd985..7bfb31a299ad 100644
--- a/src/relay/transforms/transform_layout.h
+++ b/src/relay/transforms/transform_layout.h
@@ -57,6 +57,21 @@ class TransformMemorizerNode : public Object {
     }
   };
 
+  /*!
+   * \brief Defines the call transformation for derived passes. The new layouts are defined by
+   * packed funcs registered for different targets.
+   * \param ref_call The original call.
+   * \param new_attrs Updated attributes consistent with new layouts.
+   * \param new_args The traversed/recursed args to the call.
+   * \return The new Call after calling the packed func.
+   */
+  virtual Call CallWithNewLayouts(const Call& ref_call, Attrs new_attrs,
+                                  const std::vector<Expr>& new_args) = 0;
+
+  virtual Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) {
+    return CallWithNewLayouts(ref_call, ref_call->attrs, new_args);
+  }
+
   /*! \brief The memorizer map. */
   std::unordered_map<TransformKey, Expr, key_hash> memo;
 
@@ -69,11 +84,9 @@ class TransformMemorizerNode : public Object {
  */
 class TransformMemorizer : public ObjectRef {
  public:
-  TransformMemorizer() {}
+  TransformMemorizer() = default;
   explicit TransformMemorizer(ObjectPtr<Object> n) : ObjectRef(n) {}
 
-  virtual ~TransformMemorizer() {}
-
   TransformMemorizerNode* operator->() {
     return static_cast<TransformMemorizerNode*>(get_mutable());
   }
@@ -146,19 +159,6 @@ class TransformMemorizer : public ObjectRef {
     return MakeLayoutTransform(input_expr, new_src_layout.name(), dst_layout.name());
   }
 
-  /*!
-   * \brief Defines the call transformation for derived passes. The new layouts are defined by
-   * used for different targets using a packed func.
-   * \param ref_call The original call.
-   * \param new_attrs Updated attributes consistent with new layouts.
-   * \param new_args The traversed/recursed args to the call.
-   * \return The new Call after calling the packed func.
-   */
-  virtual Call CallWithNewLayouts(const Call& ref_call, Attrs new_attrs,
-                                  const std::vector<Expr>& new_args) = 0;
-  virtual Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) {
-    return CallWithNewLayouts(ref_call, ref_call->attrs, new_args);
-  }
   using ContainerType = TransformMemorizerNode;
 };
 
@@ -312,7 +312,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
     if (ref_call->op.as<OpNode>()) {
       Op op = Downcast<Op>(ref_call->op);
       if (falter_layout.count(op) && !finfer_layout.count(op)) {
-        return memorizer.CallWithNewLayouts(ref_call, normal_new_args);
+        return memorizer->CallWithNewLayouts(ref_call, normal_new_args);
       }
     }
   }
@@ -349,7 +349,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
   }
 
   // new_op = alter(op)
-  Call new_call = memorizer.CallWithNewLayouts(ref_call, infer_out->new_attrs, normal_new_args);
+  Call new_call = memorizer->CallWithNewLayouts(ref_call, infer_out->new_attrs, normal_new_args);
 
   // new_in2, new_out = op.infer(new_in)
   if (new_call->op->IsInstance<OpNode>()) {
diff --git a/tutorials/dev/use_pass_infra.py b/tutorials/dev/use_pass_infra.py
index 468c4d40b942..67cdfdedce0e 100644
--- a/tutorials/dev/use_pass_infra.py
+++ b/tutorials/dev/use_pass_infra.py
@@ -69,20 +69,6 @@ def example():
     return relay.Function([x, weight], z2)
 
 
-###############################################################################
-# Let us register layout alteration for a conv2d op so that we can apply the
-# layout alteration pass on the example. How alter layout pass works is out
-# the scope of this tutorial.
-
-
-@relay.op.register_alter_op_layout("nn.conv2d", level=101)
-def alter_conv2d(attrs, inputs, tinfos, out_type):
-    data, weight = inputs
-    new_attrs = dict(attrs)
-    new_attrs["data_layout"] = "NCHW16c"
-    return relay.nn.conv2d(data, weight, **new_attrs)
-
-
 ###############################################################################
 # Optimize the Program
 # --------------------
@@ -188,21 +174,6 @@ def alter_conv2d(attrs, inputs, tinfos, out_type):
 mod3 = seq(mod)
 print(mod3)
 
-###############################################################################
-# The passes applied so far are target independent. The pass infra also
-# provides a means to make pass target-aware. For example, the layout
-# alteration pass falls in such category.
-
-with tvm.transform.PassContext(opt_level=3):
-    mod4 = seq(mod)
-print(mod4)
-
-seq1 = tvm.transform.Sequential([relay.transform.AlterOpLayout()])
-with tvm.transform.PassContext(opt_level=3):
-    with tvm.target.Target("llvm"):
-        mod5 = seq1(mod)
-print(mod5)
-
 ##############################################################################
 # Implement a Pass Using Python Decorator
 # ------------------------------------------
@@ -257,7 +228,6 @@ def visit_constant(self, c):
         tvm.transform.PrintIR(),
         relay.transform.EliminateCommonSubexpr(),
         relay.transform.FuseOps(),
-        relay.transform.AlterOpLayout(),
     ]
 )
 