From beadb143a45d9db9cf7d5657e772ba38b0f84b0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Fri, 28 Jun 2019 08:11:51 -0700 Subject: [PATCH] [Relay] Feature Detection (#3238) * init init lint rename ci fix add add some doc save add some test add some test lint lint lint * fix build --- include/tvm/expr.h | 2 +- include/tvm/relay/feature.h | 170 ++++++++++++++++++ include/tvm/relay/type.h | 4 +- python/tvm/relay/feature.py | 41 +++++ python/tvm/relay/ir_pass.py | 27 ++- python/tvm/relay/prelude.py | 6 +- src/relay/ir/type.cc | 4 +- src/relay/pass/alter_op_layout.cc | 6 +- src/relay/pass/feature.cc | 104 +++++++++++ src/relay/pass/let_list.h | 3 +- src/relay/pass/partial_eval.cc | 4 - src/relay/pass/pass_util.h | 15 +- src/relay/pass/util.cc | 4 +- src/relay/pass/well_formed.cc | 4 +- tests/python/relay/test_feature.py | 67 +++++++ .../relay/test_pass_to_a_normal_form.py | 10 +- .../relay/test_pass_to_graph_normal_form.py | 7 +- 17 files changed, 448 insertions(+), 30 deletions(-) create mode 100644 include/tvm/relay/feature.h create mode 100644 python/tvm/relay/feature.py create mode 100644 src/relay/pass/feature.cc create mode 100644 tests/python/relay/test_feature.py diff --git a/include/tvm/expr.h b/include/tvm/expr.h index 57cd4fdadd75..60a6d971ad44 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -140,7 +140,7 @@ class Integer : public Expr { */ operator int64_t() const { CHECK(node_ != nullptr) - << " Trying get reference a null Integer"; + << " Trying to reference a null Integer"; return (*this)->value; } /*! \brief type indicate the container type */ diff --git a/include/tvm/relay/feature.h b/include/tvm/relay/feature.h new file mode 100644 index 000000000000..a8b60e7806fe --- /dev/null +++ b/include/tvm/relay/feature.h @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relay/feature.h + * \brief Detect features used in Expr/Module. + */ +#ifndef TVM_RELAY_FEATURE_H_ +#define TVM_RELAY_FEATURE_H_ + +#include +#include +#include + +namespace tvm { +namespace relay { + +/*! \brief Different kinds of relay feature a program might use. */ +enum Feature : int { + fVar = 0, + fGlobalVar = 1, + fConstant = 2, + fTuple = 3, + fTupleGetItem = 4, + fFunction = 5, + fOp = 6, + fCall = 7, + fLet = 8, + fIf = 9, + fRefCreate = 10, + fRefRead = 11, + fRefWrite = 12, + fConstructor = 13, + fMatch = 14, + /*! \brief Whether any non-atom fragment of the program is shared, making the program a graph. */ + fGraph = 15, + /*! \brief Whether there is local fixpoint in the program. */ + fLetRec = 16 +}; + +constexpr size_t feature_count = 17; + +/*! + * \brief A finite set of Feature. + */ +class FeatureSet { + public: + FeatureSet(const FeatureSet&) = default; + /*! \brief A singleton set containing a single Feature. */ + explicit FeatureSet(Feature ft) { + bs_.set(static_cast(ft)); + } + explicit FeatureSet(const tvm::Array& ft) { + for (Integer i : ft) { + (*this) += Feature(static_cast(i)); + } + } + explicit operator Array() const { + Array ret; + for (size_t i = 0; i < feature_count; ++i) { + if (bs_[i]) { + ret.push_back(Integer(i)); + } + } + return ret; + } + /*! \brief A set that contain all the Feature. */ + static FeatureSet AllFeature() { + FeatureSet fs; + fs.bs_.flip(); + return fs; + } + /*! \brief The empty set. Contain no Feature. */ + static FeatureSet NoFeature() { + FeatureSet fs; + return fs; + } + template + FeatureSet& operator+=(const T& rhs) { + bs_ |= FeatureSet(rhs).bs_; + return *this; + } + /*! \brief Set union. */ + template + FeatureSet operator+(const T& rhs) const { + FeatureSet fs(*this); + fs += rhs; + return fs; + } + template + FeatureSet& operator-=(const T& rhs) { + bs_ &= ~(FeatureSet(rhs)).bs_; + return *this; + } + /*! \brief Set difference. */ + template + FeatureSet operator-(const T& rhs) const { + FeatureSet fs(*this); + fs -= rhs; + return fs; + } + /*! + * \brief Is this a subset of rhs? + * + * \param rhs another FeatureSet. + * + * \return true only if this is a subset of rhs. + */ + bool is_subset_of(const FeatureSet& rhs) const { + return ((*this) - rhs).bs_.none(); + } + + private: + std::bitset bs_; + FeatureSet() = default; + explicit FeatureSet(const std::bitset& bs) : bs_(bs) { } +}; + +class Expr; +/*! + * \brief Calculate the feature of the program. + * + * \param expr The expression. + * + * \return The FeatureSet. + */ +FeatureSet DetectFeature(const Expr& expr); + +struct Module; +/*! + * \brief Calculate the feature of the program. + * + * \param mod The module. + * + * \return The FeatureSet. + */ +FeatureSet DetectFeature(const Module& mod); + +/*! + * \brief Calculate the feature of the program. + * + * \param expr The expression. + * \param mod The module. + * + * \return The FeatureSet. + */ +inline FeatureSet DetectFeature(const Expr& expr, const Module& mod) { + return DetectFeature(expr) + DetectFeature(mod); +} + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_FEATURE_H_ diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 452e3b6eb864..e42ef1f65ba2 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -116,10 +116,10 @@ class TensorTypeNode : public BaseTensorTypeNode { RELAY_DEFINE_NODE_REF(TensorType, TensorTypeNode, Type); -/*! \brief possible kinds of Type */ +/*! \brief Possible kinds of Type. */ enum Kind : int { - /*! \brief template variable in shape expression */ kType = 0, + /*! \brief Template variable in shape expression. */ kShapeVar = 1, kBaseType = 2, kShape = 3, diff --git a/python/tvm/relay/feature.py b/python/tvm/relay/feature.py new file mode 100644 index 000000000000..68502672682d --- /dev/null +++ b/python/tvm/relay/feature.py @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name +"""The type nodes of the Relay language.""" +from enum import IntEnum + +class Feature(IntEnum): + """ The features a program might contain. """ + fVar = 0 + fGlobalVar = 1 + fConstant = 2 + fTuple = 3 + fTupleGetItem = 4 + fFunction = 5 + fOp = 6 + fCall = 7 + fLet = 8 + fIf = 9 + fRefCreate = 10 + fRefRead = 11 + fRefWrite = 12 + fConstructor = 13 + fMatch = 14 + """ Whether any non-atom fragment of the program is shared, making the program a graph. """ + fGraph = 15 + """ Whether there is local fixpoint in the program. """ + fLetRec = 16 diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index dd0f54c664ca..1748571cb316 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -25,6 +25,7 @@ from .expr import Expr from .ty import Type from .module import Module +from .feature import Feature def post_order_visit(expr, fvisit): @@ -604,7 +605,6 @@ def gradient(expr, mod=None, mode='higher_order'): raise Exception('unknown mode') - def get_total_mac_number(expr): """ Count the number of MACs (multiply-accumulate) of a model @@ -641,6 +641,7 @@ def eliminate_common_subexpr(expr, fskip=None): """ return _ir_pass.eliminate_common_subexpr(expr, fskip) + def partial_evaluate(expr, mod=None): """ Evaluate the static fragment of the code. @@ -660,6 +661,7 @@ def partial_evaluate(expr, mod=None): """ return _ir_pass.partial_evaluate(expr, mod) + def unmatched_cases(match, mod=None): """ Finds cases that the match expression does not catch, if any. @@ -677,3 +679,26 @@ def unmatched_cases(match, mod=None): Patterns that the match expression does not catch. """ return _ir_pass.unmatched_cases(match, mod) + + +def detect_feature(a, b=None): + """ + Detect the feature used in a relay program. + + Parameters + ---------- + a : Union[tvm.relay.Expr, tvm.relay.Module] + The input expression or module. + + b : Optional[Union[tvm.relay.Expr, tvm.relay.Module]] + The input expression or module. + The two arguments cannot both be expression or module. + + Returns + ------- + features : Set[Feature] + Features used in the program. + """ + if isinstance(a, Module): + a, b = b, a + return set([Feature(int(x)) for x in _ir_pass.detect_feature(a, b)]) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index af0497e3801a..fcb2d67a3314 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -23,8 +23,8 @@ from .adt import Constructor, TypeData, Clause, Match from .adt import PatternConstructor, PatternVar, PatternWildcard from .parser import fromtext - __PRELUDE_PATH__ = os.path.dirname(os.path.realpath(__file__)) +from .module import Module class Prelude: """Contains standard definitions.""" @@ -486,7 +486,9 @@ def load_prelude(self): self.compose = self.mod.get_global_var("compose") - def __init__(self, mod): + def __init__(self, mod=None): + if mod is None: + mod = Module() self.mod = mod self.load_prelude() self.define_list_adt() diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index 8f0bdcba2b1b..35a12052949e 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc index d623393049a6..cc71968fba58 100644 --- a/src/relay/pass/alter_op_layout.cc +++ b/src/relay/pass/alter_op_layout.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -18,7 +18,7 @@ */ /*! - * Copyright (c) 2018 by Contributors + * Copyright (c) 2019 by Contributors * \file alter_op_layout.cc * \brief Alternate the layouts of operators or replace primitive operators with other expressions. This pass can be used for computing convolution in diff --git a/src/relay/pass/feature.cc b/src/relay/pass/feature.cc new file mode 100644 index 000000000000..e86ca0621112 --- /dev/null +++ b/src/relay/pass/feature.cc @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file feature.cc + * \brief Detect features used in Expr/Module + */ +#include +#include +#include +#include +#include +#include "pass_util.h" + +namespace tvm { +namespace relay { + +FeatureSet DetectFeature(const Expr& expr) { + if (!expr.defined()) { + return FeatureSet::NoFeature(); + } + struct FeatureDetector : ExprVisitor { + std::unordered_set visited_; + FeatureSet fs = FeatureSet::NoFeature(); + void VisitExpr(const Expr& expr) final { + if (visited_.count(expr) == 0) { + ExprVisitor::VisitExpr(expr); + } else { + if (!IsAtomic(expr)) { + fs += fGraph; + } + } + } +#define DETECT_CONSTRUCT(CONSTRUCT_NAME, STMT) \ + void VisitExpr_(const CONSTRUCT_NAME##Node* op) final { \ + STMT \ + fs += f##CONSTRUCT_NAME; \ + ExprVisitor::VisitExpr_(op); \ + } +#define DETECT_DEFAULT_CONSTRUCT(CONSTRUCT_NAME) DETECT_CONSTRUCT(CONSTRUCT_NAME, {}) + DETECT_DEFAULT_CONSTRUCT(Var) + DETECT_DEFAULT_CONSTRUCT(GlobalVar) + DETECT_DEFAULT_CONSTRUCT(Constant) + DETECT_DEFAULT_CONSTRUCT(Tuple) + DETECT_DEFAULT_CONSTRUCT(TupleGetItem) + DETECT_DEFAULT_CONSTRUCT(Function) + DETECT_DEFAULT_CONSTRUCT(Op) + DETECT_DEFAULT_CONSTRUCT(Call) + DETECT_CONSTRUCT(Let, { + for (const Var& v : FreeVars(op->value)) { + if (op->var == v) { + fs += fLetRec; + } + } + }) + DETECT_DEFAULT_CONSTRUCT(If) + DETECT_DEFAULT_CONSTRUCT(RefCreate) + DETECT_DEFAULT_CONSTRUCT(RefRead) + DETECT_DEFAULT_CONSTRUCT(RefWrite) + DETECT_DEFAULT_CONSTRUCT(Constructor) + DETECT_DEFAULT_CONSTRUCT(Match) +#undef DETECT_DEFAULT_CONSTRUCT + } fd; + fd(expr); + return fd.fs; +} + +FeatureSet DetectFeature(const Module& mod) { + FeatureSet fs = FeatureSet::NoFeature(); + if (mod.defined()) { + for (const auto& f : mod->functions) { + fs += DetectFeature(f.second); + } + } + return fs; +} + +Array PyDetectFeature(const Expr& expr, const Module& mod) { + FeatureSet fs = DetectFeature(expr) + DetectFeature(mod); + return static_cast>(fs); +} + +TVM_REGISTER_API("relay._ir_pass.detect_feature") +.set_body_typed(PyDetectFeature); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/let_list.h b/src/relay/pass/let_list.h index bd36a15c843c..9f56b22fc13e 100644 --- a/src/relay/pass/let_list.h +++ b/src/relay/pass/let_list.h @@ -39,7 +39,8 @@ namespace tvm { namespace relay { -/*! \brief LetList allow you to transform expression into variables, so you can copy them around. +/*! + * \brief LetList allow you to transform expression into variables, so you can copy them around. * one can insert into the LetList by calling Push, and wrap an expression with bindings with Get. * additionally, there is the 'With' function, which automatically call Get. */ diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index f1ca573d3e0e..b95c5844f8a4 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -389,10 +389,6 @@ FInterpreter CPUInterpreter() { return CreateInterpreter(Module(nullptr), CPUContext(), target); } -bool IsAtomic(const Expr& e) { - return e.as() || e.as() || e.as() || e.as(); -} - using FuncId = int; /*! diff --git a/src/relay/pass/pass_util.h b/src/relay/pass/pass_util.h index 38d8b0bd9040..386d1d889ea8 100644 --- a/src/relay/pass/pass_util.h +++ b/src/relay/pass/pass_util.h @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -97,6 +97,17 @@ inline Expr TransformF(const std::function& func, const Expr& } } +/*! + * \brief Decide whether the expression atomic or not? + * \param e the expression + * \return + * is it atomic? + * if so, the compute cost of the expression is bounded so it can be copy without graph mode. + */ +inline bool IsAtomic(const Expr& e) { + return e.as() || e.as() || e.as() || e.as(); +} + } // namespace relay } // namespace tvm #endif // TVM_RELAY_PASS_PASS_UTIL_H_ diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index 8e02cf127bfd..3ec4f75cd1ad 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/src/relay/pass/well_formed.cc b/src/relay/pass/well_formed.cc index 4eaaa934e78b..dea937481289 100644 --- a/src/relay/pass/well_formed.cc +++ b/src/relay/pass/well_formed.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY diff --git a/tests/python/relay/test_feature.py b/tests/python/relay/test_feature.py new file mode 100644 index 000000000000..637e184704f2 --- /dev/null +++ b/tests/python/relay/test_feature.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from tvm import relay +from tvm.relay.ir_pass import detect_feature, gradient +from tvm.relay.feature import Feature +from tvm.relay.prelude import Prelude + +def test_prelude(): + p = Prelude() + feats = detect_feature(p.mod) + assert feats == set([ + Feature.fVar, + Feature.fGlobalVar, + Feature.fConstant, + Feature.fTuple, + Feature.fTupleGetItem, + Feature.fFunction, + Feature.fOp, + Feature.fCall, + Feature.fLet, + Feature.fIf, + Feature.fConstructor, + Feature.fMatch + ]) + + +def test_ad(): + shape = (10, 10) + dtype = 'float32' + t = relay.TensorType(shape, dtype) + x = relay.var("x", t) + func = relay.Function([x], x + x) + back_func = relay.ir_pass.infer_type(gradient(func)) + feats = detect_feature(back_func) + assert feats == set([ + Feature.fVar, + Feature.fTuple, + Feature.fTupleGetItem, + Feature.fFunction, + Feature.fOp, + Feature.fCall, + Feature.fLet, + Feature.fRefCreate, + Feature.fRefRead, + Feature.fRefWrite + ]) + + +if __name__ == '__main__': + test_prelude() + test_ad() diff --git a/tests/python/relay/test_pass_to_a_normal_form.py b/tests/python/relay/test_pass_to_a_normal_form.py index db40c86d4b28..9a2570eabb11 100644 --- a/tests/python/relay/test_pass_to_a_normal_form.py +++ b/tests/python/relay/test_pass_to_a_normal_form.py @@ -17,11 +17,12 @@ import numpy as np import tvm from tvm import relay -from tvm.relay.ir_pass import to_a_normal_form, alpha_equal, infer_type +from tvm.relay.ir_pass import to_a_normal_form, alpha_equal, infer_type, detect_feature from tvm.relay import op, create_executor from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue from tvm.relay.prelude import Prelude from tvm.relay.testing import add_nat_definitions, count +from tvm.relay.feature import Feature def check_eval(expr, expected_result, mod=None, rtol=1e-07): @@ -37,9 +38,9 @@ def test_explicit_bound(): y = op.add(x, x) z = op.add(y, y) f = relay.Function([], op.add(z, z)) - assert not "let" in f.astext() # assert the values are implicitly bounded + assert not Feature.fLet in detect_feature(f) anf = to_a_normal_form(f) - assert "let" in anf.astext() # assert the values are explicitly bounded + assert Feature.fLet in detect_feature(anf) check_eval(f(), 8.0) check_eval(anf(), 8.0) @@ -144,7 +145,7 @@ def test_nat_add(): assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat()) assert count(p, intrp.evaluate(add(s(z()), s(z())))) == 2 assert count(p, intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2 - assert "let" in mod[add].astext() + assert Feature.fLet in detect_feature(mod[add]) def test_let(): @@ -173,7 +174,6 @@ def test_function(): test_if() test_recursion() test_ref() - test_add() test_let() test_nat_add() test_function() diff --git a/tests/python/relay/test_pass_to_graph_normal_form.py b/tests/python/relay/test_pass_to_graph_normal_form.py index 75975663a20c..6d9bd6ac254e 100644 --- a/tests/python/relay/test_pass_to_graph_normal_form.py +++ b/tests/python/relay/test_pass_to_graph_normal_form.py @@ -17,8 +17,9 @@ import numpy as np import tvm from tvm import relay -from tvm.relay.ir_pass import to_graph_normal_form, to_a_normal_form, alpha_equal +from tvm.relay.ir_pass import to_graph_normal_form, to_a_normal_form, alpha_equal, detect_feature from tvm.relay import op, create_executor +from tvm.relay.feature import Feature from tvm.relay.backend.interpreter import Value, TupleValue @@ -56,8 +57,8 @@ def test_round_trip(): f = relay.Function([], relay.Let(x, relay.const(1), body)) g = to_graph_normal_form(f) h = to_a_normal_form(g) - assert "let" in f.astext() - assert not "let" in g.astext() + assert Feature.fLet in detect_feature(f) + assert not Feature.fLet in detect_feature(g) check_eval(f, [], 8.0) check_eval(g, [], 8.0) check_eval(h, [], 8.0)