From a75674d1fe586e0306e5d8ff8b9d052967374a6d Mon Sep 17 00:00:00 2001 From: An Wang Date: Wed, 1 Sep 2021 10:42:28 -0700 Subject: [PATCH 01/11] add extractor --- python/tvm/relay/analysis/analysis.py | 15 +++ src/relay/analysis/extract_operators.cc | 70 +++++++++++ .../relay/test_analysis_extract_operators.py | 116 ++++++++++++++++++ 3 files changed, 201 insertions(+) create mode 100644 src/relay/analysis/extract_operators.cc create mode 100644 tests/python/relay/test_analysis_extract_operators.py diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index c7b6c60849a1..2c888d95cd1e 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -384,6 +384,21 @@ def extract_fused_functions(mod): return ret +def extract_operators(mod): + """Pass to extract operator frequencies from an IRModule. + + Parameters + ---------- + mod : tvm.IRModule + + Returns + ------- + ret : Dict[str, int] + Dict of operator name to the number of times it appears in mod + """ + return _ffi_api.ExtractOperators(mod) + + def search_fc_transpose(expr): """Search fc weight name in the patten: y = nn.dense(x, transpose(w, [1, 0])) diff --git a/src/relay/analysis/extract_operators.cc b/src/relay/analysis/extract_operators.cc new file mode 100644 index 000000000000..6cb2f5919c7f --- /dev/null +++ b/src/relay/analysis/extract_operators.cc @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file extract_operators.cc + * \brief Extract operator frequencies from an IRModule + */ +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +using PackedOperatorFrequencyMap = Map; + +class OperatorExtractorWrapper : private ExprVisitor { + public: + explicit OperatorExtractorWrapper(const IRModule& mod) : mod_(mod) {} + + Map Extract() { + VisitExpr(this->mod_->Lookup("main")); + + return this->operator_freqs; + } + + private: + const IRModule mod_; + // Map of operator name to the number of times they appear in the module. + Map operator_freqs; + + void VisitExpr_(const OpNode* n) final { + + auto it = this->operator_freqs.find(n->name); + if (it == this->operator_freqs.end()) { + this->operator_freqs.Set(n->name, 0U); + } + std::cout << n->name << std::endl; + + this->operator_freqs.Set(n->name, 1 + this->operator_freqs.at(n->name)); + + ExprVisitor::VisitExpr_(n); + } +}; + +PackedOperatorFrequencyMap ExtractOperatorsPacked(const IRModule& mod) { + return OperatorExtractorWrapper(mod).Extract(); +} + +TVM_REGISTER_GLOBAL("relay.analysis.ExtractOperators").set_body_typed(ExtractOperatorsPacked); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_analysis_extract_operators.py b/tests/python/relay/test_analysis_extract_operators.py new file mode 100644 index 000000000000..e4985dfd366c --- /dev/null +++ b/tests/python/relay/test_analysis_extract_operators.py @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test function extraction""" +import tvm +from tvm import relay +from tvm.relay.testing.resnet import get_workload + + +def get_conv_net(): + """This gets the net for a case described in fuse_ops.cc: + conv2d + / | \ + / | \ + op op op + \ | / + \ | / + elemwise add + | + """ + dshape = (1, 1, 5, 1) + x = relay.var("x", shape=dshape) + y = relay.nn.conv2d(x, relay.var("w1"), + kernel_size=(3, 3), + padding=(1, 1), + channels=1) + + x1 = relay.nn.conv2d(y, relay.var("w2"), + kernel_size=(3, 3), + padding=(1, 1), + channels=1) + x2 = relay.nn.conv2d(y, relay.var("w3"), + kernel_size=(3, 3), + padding=(1, 1), + channels=1) + x3 = relay.nn.conv2d(y, relay.var("w4"), + kernel_size=(3, 3), + padding=(1, 1), + channels=1) + + z = relay.add(x1, x2) + z = relay.add(x3, z) + + return tvm.IRModule.from_expr(z) + + +def get_conv2d(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight1 = relay.var('weight1', shape=(3, 3, 64, 32)) + y = relay.nn.conv2d(x, weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout='NHWC', + kernel_layout='HWIO') + return tvm.IRModule.from_expr(y) + + +def test_extract_identity(): + mod = get_conv2d() + items = relay.analysis.extract_operators(mod) + assert len(items) == 1 + breakpoint() + print("extract_identity") + + assert items["nn.conv2d"] == 1 + + +def test_extract_conv_net(): + mod = get_conv_net() + items = relay.analysis.extract_operators(mod) + breakpoint() + functions = list(items.values()) + assert len(functions) == 2 + x = functions[0] + y = functions[1] + + def is_conv(func): + conv2d = relay.op.op.get("nn.conv2d") + call_node = func.body + return call_node.op == conv2d + + def is_conv_add(func): + add = relay.op.op.get("add") + call_node = func.body + maybe_conv_module = tvm.IRModule.from_expr(call_node.args[0]) + return call_node.op == add and is_conv(maybe_conv_module["main"]) + + # Function traversal order isn't obvious, so checking both orders is more consistent + assert (is_conv(x) and is_conv_add(y)) or (is_conv_add(x) and is_conv(y)) + + +def test_extract_resnet(): + mod, _params = get_workload() + items = relay.analysis.extract_operators(mod) + breakpoint() + assert len(items) == 34 + + +if __name__ == '__main__': + test_extract_identity() + test_extract_conv_net() + test_extract_resnet() From 1ffb2d63d803caaa865e48db55002a63a172e74a Mon Sep 17 00:00:00 2001 From: An Wang Date: Mon, 13 Sep 2021 13:31:44 -0700 Subject: [PATCH 02/11] extract to array --- python/tvm/relay/analysis/analysis.py | 6 +- src/relay/analysis/extract_operators.cc | 23 ++-- .../relay/test_analysis_extract_operators.py | 114 ++++++++---------- 3 files changed, 57 insertions(+), 86 deletions(-) diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index 2c888d95cd1e..d07109f861fb 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -385,7 +385,7 @@ def extract_fused_functions(mod): def extract_operators(mod): - """Pass to extract operator frequencies from an IRModule. + """Pass to extract operator names from an IRModule. Parameters ---------- @@ -393,8 +393,8 @@ def extract_operators(mod): Returns ------- - ret : Dict[str, int] - Dict of operator name to the number of times it appears in mod + ret : List[str] + List of unique operator names """ return _ffi_api.ExtractOperators(mod) diff --git a/src/relay/analysis/extract_operators.cc b/src/relay/analysis/extract_operators.cc index 6cb2f5919c7f..022f8b130f98 100644 --- a/src/relay/analysis/extract_operators.cc +++ b/src/relay/analysis/extract_operators.cc @@ -19,7 +19,7 @@ /*! * \file extract_operators.cc - * \brief Extract operator frequencies from an IRModule + * \brief Extract unique operators from an IRModule */ #include #include @@ -29,38 +29,29 @@ namespace tvm { namespace relay { -using PackedOperatorFrequencyMap = Map; - class OperatorExtractorWrapper : private ExprVisitor { public: explicit OperatorExtractorWrapper(const IRModule& mod) : mod_(mod) {} - Map Extract() { + Array Extract() { VisitExpr(this->mod_->Lookup("main")); - return this->operator_freqs; + return this->operators; } private: const IRModule mod_; - // Map of operator name to the number of times they appear in the module. - Map operator_freqs; + // Array of unique operator names + Array operators; void VisitExpr_(const OpNode* n) final { - - auto it = this->operator_freqs.find(n->name); - if (it == this->operator_freqs.end()) { - this->operator_freqs.Set(n->name, 0U); - } - std::cout << n->name << std::endl; - - this->operator_freqs.Set(n->name, 1 + this->operator_freqs.at(n->name)); + this->operators.push_back(n->name); ExprVisitor::VisitExpr_(n); } }; -PackedOperatorFrequencyMap ExtractOperatorsPacked(const IRModule& mod) { +Array ExtractOperatorsPacked(const IRModule& mod) { return OperatorExtractorWrapper(mod).Extract(); } diff --git a/tests/python/relay/test_analysis_extract_operators.py b/tests/python/relay/test_analysis_extract_operators.py index e4985dfd366c..f178c90e1480 100644 --- a/tests/python/relay/test_analysis_extract_operators.py +++ b/tests/python/relay/test_analysis_extract_operators.py @@ -21,96 +21,76 @@ def get_conv_net(): - """This gets the net for a case described in fuse_ops.cc: - conv2d - / | \ - / | \ - op op op - \ | / - \ | / - elemwise add - | + """This gets the net for: + conv2d + / | + / | + conv2d | + \ | + \ | + elemwise add + | """ dshape = (1, 1, 5, 1) x = relay.var("x", shape=dshape) - y = relay.nn.conv2d(x, relay.var("w1"), - kernel_size=(3, 3), - padding=(1, 1), - channels=1) + y = relay.nn.conv2d(x, relay.var("w1"), kernel_size=(3, 3), padding=(1, 1), channels=1) + x1 = relay.nn.conv2d(y, relay.var("w2"), kernel_size=(3, 3), padding=(1, 1), channels=1) - x1 = relay.nn.conv2d(y, relay.var("w2"), - kernel_size=(3, 3), - padding=(1, 1), - channels=1) - x2 = relay.nn.conv2d(y, relay.var("w3"), - kernel_size=(3, 3), - padding=(1, 1), - channels=1) - x3 = relay.nn.conv2d(y, relay.var("w4"), - kernel_size=(3, 3), - padding=(1, 1), - channels=1) - - z = relay.add(x1, x2) - z = relay.add(x3, z) + z = relay.add(y, x1) return tvm.IRModule.from_expr(z) def get_conv2d(): x = relay.var("x", shape=(1, 56, 56, 64)) - weight1 = relay.var('weight1', shape=(3, 3, 64, 32)) - y = relay.nn.conv2d(x, weight1, - channels=32, - kernel_size=(3, 3), - padding=(1, 1), - data_layout='NHWC', - kernel_layout='HWIO') + weight1 = relay.var("weight1", shape=(3, 3, 64, 32)) + y = relay.nn.conv2d( + x, + weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) return tvm.IRModule.from_expr(y) def test_extract_identity(): mod = get_conv2d() - items = relay.analysis.extract_operators(mod) - assert len(items) == 1 - breakpoint() - print("extract_identity") - - assert items["nn.conv2d"] == 1 + ops = relay.analysis.extract_operators(mod) + assert len(ops) == 1 + assert ops[0] == "nn.conv2d" def test_extract_conv_net(): mod = get_conv_net() - items = relay.analysis.extract_operators(mod) - breakpoint() - functions = list(items.values()) - assert len(functions) == 2 - x = functions[0] - y = functions[1] - - def is_conv(func): - conv2d = relay.op.op.get("nn.conv2d") - call_node = func.body - return call_node.op == conv2d - - def is_conv_add(func): - add = relay.op.op.get("add") - call_node = func.body - maybe_conv_module = tvm.IRModule.from_expr(call_node.args[0]) - return call_node.op == add and is_conv(maybe_conv_module["main"]) - - # Function traversal order isn't obvious, so checking both orders is more consistent - assert (is_conv(x) and is_conv_add(y)) or (is_conv_add(x) and is_conv(y)) + ops = relay.analysis.extract_operators(mod) + assert len(ops) == 2 + assert "add" in ops + assert "nn.conv2d" in ops def test_extract_resnet(): mod, _params = get_workload() - items = relay.analysis.extract_operators(mod) - breakpoint() - assert len(items) == 34 - - -if __name__ == '__main__': + expected_ops = [ + "nn.batch_norm", + "nn.conv2d", + "nn.relu", + "nn.max_pool2d", + "add", + "nn.global_avg_pool2d", + "nn.batch_flatten", + "nn.dense", + "nn.bias_add", + "nn.softmax", + ] + ops = relay.analysis.extract_operators(mod) + assert len(ops) == len(expected_ops) + assert all([op in ops for op in expected_ops]) + + +if __name__ == "__main__": test_extract_identity() test_extract_conv_net() test_extract_resnet() From 70f49d690cb74b4cf535a6a6dd75b0b5037e056d Mon Sep 17 00:00:00 2001 From: An Wang Date: Mon, 13 Sep 2021 13:35:53 -0700 Subject: [PATCH 03/11] add comments --- python/tvm/relay/analysis/analysis.py | 2 +- src/relay/analysis/extract_operators.cc | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index d07109f861fb..0ca59e803865 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -385,7 +385,7 @@ def extract_fused_functions(mod): def extract_operators(mod): - """Pass to extract operator names from an IRModule. + """Pass to extract unique operator names from an IRModule. Parameters ---------- diff --git a/src/relay/analysis/extract_operators.cc b/src/relay/analysis/extract_operators.cc index 022f8b130f98..d1ffc1befb50 100644 --- a/src/relay/analysis/extract_operators.cc +++ b/src/relay/analysis/extract_operators.cc @@ -41,10 +41,12 @@ class OperatorExtractorWrapper : private ExprVisitor { private: const IRModule mod_; - // Array of unique operator names + // Array of unique operator names. Array operators; void VisitExpr_(const OpNode* n) final { + // NOTE: OpNode is visited only once for every operator kind + // regardless of how many times that op appears in the graph. this->operators.push_back(n->name); ExprVisitor::VisitExpr_(n); From 2f11222a451d9d962f00022f205386ea82183cbd Mon Sep 17 00:00:00 2001 From: An Wang Date: Mon, 13 Sep 2021 13:41:40 -0700 Subject: [PATCH 04/11] lint --- src/relay/analysis/extract_operators.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/analysis/extract_operators.cc b/src/relay/analysis/extract_operators.cc index d1ffc1befb50..51e44da2f225 100644 --- a/src/relay/analysis/extract_operators.cc +++ b/src/relay/analysis/extract_operators.cc @@ -54,7 +54,7 @@ class OperatorExtractorWrapper : private ExprVisitor { }; Array ExtractOperatorsPacked(const IRModule& mod) { - return OperatorExtractorWrapper(mod).Extract(); + return OperatorExtractorWrapper(mod).Extract(); } TVM_REGISTER_GLOBAL("relay.analysis.ExtractOperators").set_body_typed(ExtractOperatorsPacked); From 7323302e01823996f431290329cd1c8d867c57fe Mon Sep 17 00:00:00 2001 From: anwang2009 Date: Mon, 13 Sep 2021 13:46:24 -0700 Subject: [PATCH 05/11] Update tests/python/relay/test_analysis_extract_operators.py Co-authored-by: Cody Yu --- tests/python/relay/test_analysis_extract_operators.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/python/relay/test_analysis_extract_operators.py b/tests/python/relay/test_analysis_extract_operators.py index f178c90e1480..acf5e3d0401e 100644 --- a/tests/python/relay/test_analysis_extract_operators.py +++ b/tests/python/relay/test_analysis_extract_operators.py @@ -91,6 +91,4 @@ def test_extract_resnet(): if __name__ == "__main__": - test_extract_identity() - test_extract_conv_net() - test_extract_resnet() + pytest.main([__file__]) From 4e7dc0e204dcad7b6808717887e4faddf2ada757 Mon Sep 17 00:00:00 2001 From: An Wang Date: Mon, 13 Sep 2021 15:37:11 -0700 Subject: [PATCH 06/11] op freqs --- python/tvm/relay/analysis/analysis.py | 9 +-- src/relay/analysis/extract_operators.cc | 34 ++++++++--- .../relay/test_analysis_extract_operators.py | 57 ++++++++++++------- 3 files changed, 66 insertions(+), 34 deletions(-) diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index 0ca59e803865..2537c081b05a 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -384,8 +384,9 @@ def extract_fused_functions(mod): return ret -def extract_operators(mod): - """Pass to extract unique operator names from an IRModule. +def list_op_frequencies(mod): + """Pass to extract unique operator names and how frequently they appear + in an IRModule. Parameters ---------- @@ -393,8 +394,8 @@ def extract_operators(mod): Returns ------- - ret : List[str] - List of unique operator names + ret : Dict[str, int] + Dict of unique operator names to frequency """ return _ffi_api.ExtractOperators(mod) diff --git a/src/relay/analysis/extract_operators.cc b/src/relay/analysis/extract_operators.cc index 51e44da2f225..3cbf6a55fee1 100644 --- a/src/relay/analysis/extract_operators.cc +++ b/src/relay/analysis/extract_operators.cc @@ -33,27 +33,45 @@ class OperatorExtractorWrapper : private ExprVisitor { public: explicit OperatorExtractorWrapper(const IRModule& mod) : mod_(mod) {} - Array Extract() { + Map Extract() { VisitExpr(this->mod_->Lookup("main")); - return this->operators; + // Map opname_freqs; + // for (const auto& kv : operator_freqs_) { + // opname_freqs.Set(kv.first->name, kv.second); + // } + // return opname_freqs; + return operator_freqs_; } private: const IRModule mod_; - // Array of unique operator names. - Array operators; + /*! \brief Map of operator to frequency. */ + Map operator_freqs_; + + void VisitExpr_(const CallNode* n) final { + VisitExpr(n->op); + + auto op = n->op.as(); + if (op) { + auto it = operator_freqs_.find(op->name); + ICHECK(it != operator_freqs_.end()) + << "Call's OpNode must be visited and registered before access"; + operator_freqs_.Set(op->name, 1 + operator_freqs_.at(op->name)); + } + + ExprVisitor::VisitExpr_(n); + } void VisitExpr_(const OpNode* n) final { + std::cout << "here " << n->name << std::endl; // NOTE: OpNode is visited only once for every operator kind // regardless of how many times that op appears in the graph. - this->operators.push_back(n->name); - - ExprVisitor::VisitExpr_(n); + operator_freqs_.Set(n->name, 0U); } }; -Array ExtractOperatorsPacked(const IRModule& mod) { +Map ExtractOperatorsPacked(const IRModule& mod) { return OperatorExtractorWrapper(mod).Extract(); } diff --git a/tests/python/relay/test_analysis_extract_operators.py b/tests/python/relay/test_analysis_extract_operators.py index acf5e3d0401e..a62b62a66f64 100644 --- a/tests/python/relay/test_analysis_extract_operators.py +++ b/tests/python/relay/test_analysis_extract_operators.py @@ -15,9 +15,11 @@ # specific language governing permissions and limitations # under the License. """Test function extraction""" +import pytest import tvm from tvm import relay from tvm.relay.testing.resnet import get_workload +from tvm.relay.testing import run_opt_pass def get_conv_net(): @@ -58,36 +60,47 @@ def get_conv2d(): def test_extract_identity(): mod = get_conv2d() - ops = relay.analysis.extract_operators(mod) - assert len(ops) == 1 - assert ops[0] == "nn.conv2d" + op_freqs = relay.analysis.list_op_frequencies(mod) + assert len(op_freqs) == 1 + assert op_freqs["nn.conv2d"] == 1 def test_extract_conv_net(): mod = get_conv_net() - ops = relay.analysis.extract_operators(mod) - assert len(ops) == 2 - assert "add" in ops - assert "nn.conv2d" in ops + op_freqs = relay.analysis.list_op_frequencies(mod) + assert len(op_freqs) == 2 + assert op_freqs["add"] == 1 + assert op_freqs["nn.conv2d"] == 2 + + +def test_extract_fused(): + mod = get_conv_net() + mod = relay.transform.InferType()(mod) + mod = relay.transform.FuseOps(3)(mod) + + op_freqs = relay.analysis.list_op_frequencies(mod) + assert len(op_freqs) == 2 + assert op_freqs["add"] == 1 + assert op_freqs["nn.conv2d"] == 2 def test_extract_resnet(): mod, _params = get_workload() - expected_ops = [ - "nn.batch_norm", - "nn.conv2d", - "nn.relu", - "nn.max_pool2d", - "add", - "nn.global_avg_pool2d", - "nn.batch_flatten", - "nn.dense", - "nn.bias_add", - "nn.softmax", - ] - ops = relay.analysis.extract_operators(mod) - assert len(ops) == len(expected_ops) - assert all([op in ops for op in expected_ops]) + expected_op_freqs = { + "nn.batch_norm": 19, + "nn.conv2d": 21, + "nn.relu": 18, + "nn.max_pool2d": 1, + "add": 8, + "nn.global_avg_pool2d": 1, + "nn.batch_flatten": 1, + "nn.dense": 1, + "nn.bias_add": 1, + "nn.softmax": 1, + } + op_freqs = relay.analysis.list_op_frequencies(mod) + assert len(op_freqs) == len(expected_op_freqs) + assert all([op_freqs[op] == expected_op_freqs[op] for op in expected_op_freqs]) if __name__ == "__main__": From f71b55e6b19b730c6840e25f0b37f50e49dcbf95 Mon Sep 17 00:00:00 2001 From: An Wang Date: Mon, 13 Sep 2021 15:45:59 -0700 Subject: [PATCH 07/11] add comment --- python/tvm/relay/analysis/analysis.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index 2537c081b05a..9d93e86629f8 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -386,7 +386,8 @@ def extract_fused_functions(mod): def list_op_frequencies(mod): """Pass to extract unique operator names and how frequently they appear - in an IRModule. + in an IRModule. Fused functions are traversed to count the operators + that compose them. Parameters ---------- From ea4bc24da86ec1a2390258e1eab307f2500b8f0b Mon Sep 17 00:00:00 2001 From: anwang2009 Date: Mon, 13 Sep 2021 15:48:39 -0700 Subject: [PATCH 08/11] Update python/tvm/relay/analysis/analysis.py Co-authored-by: Cody Yu --- python/tvm/relay/analysis/analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index 9d93e86629f8..524f69bcdd13 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -384,7 +384,7 @@ def extract_fused_functions(mod): return ret -def list_op_frequencies(mod): +def list_op_freqs(mod): """Pass to extract unique operator names and how frequently they appear in an IRModule. Fused functions are traversed to count the operators that compose them. From ee17c5bb6cb6da78e4fbba9275dcffe115c3f9fb Mon Sep 17 00:00:00 2001 From: An Wang Date: Mon, 13 Sep 2021 15:48:48 -0700 Subject: [PATCH 09/11] oops --- src/relay/analysis/extract_operators.cc | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/relay/analysis/extract_operators.cc b/src/relay/analysis/extract_operators.cc index 3cbf6a55fee1..3558d4c374f1 100644 --- a/src/relay/analysis/extract_operators.cc +++ b/src/relay/analysis/extract_operators.cc @@ -36,11 +36,6 @@ class OperatorExtractorWrapper : private ExprVisitor { Map Extract() { VisitExpr(this->mod_->Lookup("main")); - // Map opname_freqs; - // for (const auto& kv : operator_freqs_) { - // opname_freqs.Set(kv.first->name, kv.second); - // } - // return opname_freqs; return operator_freqs_; } From 0d44b5085cb2cedbdb5cf8706039702f2d916ccb Mon Sep 17 00:00:00 2001 From: An Wang Date: Tue, 14 Sep 2021 10:53:53 -0700 Subject: [PATCH 10/11] mixedmode visitor --- src/relay/analysis/extract_operators.cc | 4 ++-- tests/python/relay/test_analysis_extract_operators.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/relay/analysis/extract_operators.cc b/src/relay/analysis/extract_operators.cc index 3558d4c374f1..36b68479ff87 100644 --- a/src/relay/analysis/extract_operators.cc +++ b/src/relay/analysis/extract_operators.cc @@ -29,7 +29,7 @@ namespace tvm { namespace relay { -class OperatorExtractorWrapper : private ExprVisitor { +class OperatorExtractorWrapper : private MixedModeVisitor { public: explicit OperatorExtractorWrapper(const IRModule& mod) : mod_(mod) {} @@ -55,7 +55,7 @@ class OperatorExtractorWrapper : private ExprVisitor { operator_freqs_.Set(op->name, 1 + operator_freqs_.at(op->name)); } - ExprVisitor::VisitExpr_(n); + MixedModeVisitor::VisitExpr_(n); } void VisitExpr_(const OpNode* n) final { diff --git a/tests/python/relay/test_analysis_extract_operators.py b/tests/python/relay/test_analysis_extract_operators.py index a62b62a66f64..5878b2a6e497 100644 --- a/tests/python/relay/test_analysis_extract_operators.py +++ b/tests/python/relay/test_analysis_extract_operators.py @@ -60,14 +60,14 @@ def get_conv2d(): def test_extract_identity(): mod = get_conv2d() - op_freqs = relay.analysis.list_op_frequencies(mod) + op_freqs = relay.analysis.list_op_freqs(mod) assert len(op_freqs) == 1 assert op_freqs["nn.conv2d"] == 1 def test_extract_conv_net(): mod = get_conv_net() - op_freqs = relay.analysis.list_op_frequencies(mod) + op_freqs = relay.analysis.list_op_freqs(mod) assert len(op_freqs) == 2 assert op_freqs["add"] == 1 assert op_freqs["nn.conv2d"] == 2 @@ -78,7 +78,7 @@ def test_extract_fused(): mod = relay.transform.InferType()(mod) mod = relay.transform.FuseOps(3)(mod) - op_freqs = relay.analysis.list_op_frequencies(mod) + op_freqs = relay.analysis.list_op_freqs(mod) assert len(op_freqs) == 2 assert op_freqs["add"] == 1 assert op_freqs["nn.conv2d"] == 2 @@ -98,7 +98,7 @@ def test_extract_resnet(): "nn.bias_add": 1, "nn.softmax": 1, } - op_freqs = relay.analysis.list_op_frequencies(mod) + op_freqs = relay.analysis.list_op_freqs(mod) assert len(op_freqs) == len(expected_op_freqs) assert all([op_freqs[op] == expected_op_freqs[op] for op in expected_op_freqs]) From 02f2f55000dc91f35d55e632fa6120558f5383a2 Mon Sep 17 00:00:00 2001 From: An Wang Date: Wed, 15 Sep 2021 09:55:07 -0700 Subject: [PATCH 11/11] oops --- src/relay/analysis/extract_operators.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/relay/analysis/extract_operators.cc b/src/relay/analysis/extract_operators.cc index 36b68479ff87..8fd0f87239ff 100644 --- a/src/relay/analysis/extract_operators.cc +++ b/src/relay/analysis/extract_operators.cc @@ -59,7 +59,6 @@ class OperatorExtractorWrapper : private MixedModeVisitor { } void VisitExpr_(const OpNode* n) final { - std::cout << "here " << n->name << std::endl; // NOTE: OpNode is visited only once for every operator kind // regardless of how many times that op appears in the graph. operator_freqs_.Set(n->name, 0U);