From fc411dc6fad14909ed17ce8c39d621d4587441bc Mon Sep 17 00:00:00 2001 From: Ashutosh Parkhi <86472128+ashutosh-arm@users.noreply.github.com> Date: Mon, 8 Aug 2022 10:16:33 +0100 Subject: [PATCH] Pass that removes reshapes post LowerTE (#12215) Introduces a Pass for removing intermediate reshapes post LowerTE() in AOT compiler. This commit adds pass specific tests and updates usmp generated workspace pools due to reduction in number of allocations post reshape removals. Note: this pass at present does not support first reshape appearing in the graph. If seen as a useful case, it can be added in the future. --- include/tvm/relay/transform.h | 11 + src/relay/backend/aot_executor_codegen.cc | 7 +- .../transforms/remove_standalone_reshapes.cc | 120 ++++++++ .../contrib/test_cmsisnn/test_conv2d.py | 1 + .../contrib/test_cmsisnn/test_pooling.py | 8 +- .../test_cmsisnn/test_remove_reshapes.py | 169 ++++++++++++ .../contrib/test_ethosu/test_networks.py | 10 +- tests/python/relay/aot/test_crt_aot.py | 2 +- tests/python/relay/aot/test_crt_aot_usmp.py | 36 +-- .../test_pass_remove_standalone_reshapes.py | 260 ++++++++++++++++++ 10 files changed, 597 insertions(+), 27 deletions(-) create mode 100644 src/relay/transforms/remove_standalone_reshapes.cc create mode 100644 tests/python/contrib/test_cmsisnn/test_remove_reshapes.py create mode 100644 tests/python/relay/backend/test_pass_remove_standalone_reshapes.py diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index f60912fb012e..b37d0f83adf3 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -585,6 +585,17 @@ TVM_DLL Pass CapturePostDfsIndexInSpans(); * expr->memory_scope and annotates expressions by VirtualDevice with required memory_scope */ TVM_DLL Pass AnnotateMemoryScope(CompilationConfig config); + +/*! + * \brief Removes non-fused reshapes after lowering the graph. + * InferType() cannot be invoked after calling this pass as it removes reshapes from the call + * graph. Many targets only need buffer addresses irrespective of the shapes of them. This makes + * reshapes symbolic once the graph has been lowered. Reshape removal results into smaller code + * size and reduced buffer allocations. It opens up opportunities of operator fusion in the target + * backend. Thus, consequently, it improves the performance of the inference. + */ +TVM_DLL Pass RemoveStandaloneReshapes(); + } // namespace transform /*! diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index b380f7b7c8b8..6a9cadb6f770 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -1096,6 +1096,12 @@ class AOTExecutorCodegen : public MixedModeVisitor { tec::UpdateFunctionMetadata(func, this->function_metadata_, workspace_byte_alignment); })(mod); + transform::PassContext pass_ctx = transform::PassContext::Current(); + bool enable_remove_reshapes = + pass_ctx->GetConfig("relay.remove_standalone_reshapes.enable", Bool(true)).value(); + if (enable_remove_reshapes) { + lowered_mod = transform::RemoveStandaloneReshapes()(lowered_mod); + } auto lowered_main = lowered_mod->Lookup("main"); auto lowered_main_func = GetRef(lowered_main.as()); @@ -1203,7 +1209,6 @@ class AOTExecutorCodegen : public MixedModeVisitor { // Parallel for loops are not supported in AoT codegen. lowered_mod = tir::transform::ConvertForLoopsToSerial()(lowered_mod); - transform::PassContext pass_ctx = transform::PassContext::Current(); bool enable_usmp = pass_ctx->GetConfig(kUSMPEnableOption, Bool(false)).value(); if (enable_usmp) { lowered_mod = PlanMemoryWithUSMP(lowered_mod); diff --git a/src/relay/transforms/remove_standalone_reshapes.cc b/src/relay/transforms/remove_standalone_reshapes.cc new file mode 100644 index 000000000000..28924e8bdfed --- /dev/null +++ b/src/relay/transforms/remove_standalone_reshapes.cc @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relay/transforms/remove_standalone_reshapes.cc + * \brief This file contains the Relay pass for removing unfused reshapes from lowered graph. + */ + +#include +#include + +#include "../op/call/call.h" +#include "../op/memory/on_device.h" + +namespace tvm { +namespace relay { + +TVM_REGISTER_PASS_CONFIG_OPTION("relay.remove_standalone_reshapes.enable", Bool); +/*! Removes reshapes right after LowerTE. Removes preceding on_device calls + * while removing reshapes. + */ +class RemoveStandaloneReshapesMutator : public MixedModeMutator { + public: + explicit RemoveStandaloneReshapesMutator(IRModule& mod) : ir_module_(mod) {} + + using MixedModeMutator::VisitExpr_; + + /*! * \brief Generated map of let variables to preceding CallLowered */ + Expr VisitExpr_(const LetNode* let) final { + Let ret_let; + Var var = Downcast(this->Mutate(let->var)); + auto value = this->Mutate(let->value); + if (auto* on_device_call = value.as()) { + OnDeviceProps on_device_props = GetOnDeviceProps(on_device_call); + if (on_device_props.body.defined() && on_device_props.body->IsInstance()) { + const Call call_lowered = Downcast(on_device_props.body); + if (call_lowered.defined() && call_lowered->op.same_as(CallLoweredOp())) { + let_var_to_call_lowered_.Set(var, call_lowered); + } + } + } + auto body = this->Mutate(let->body); + return WithFields(GetRef(let), var, value, body); + } + + /*! * \brief Returns preceding CallLowered when call is a CallLowered(Reshape) */ + Expr Rewrite_(const CallNode* call, const Expr& post) final { + /* + %1 = call_lowered(@tvmgen_default_non_reshape_function, %input, ...); + let %x: = on_device(%1, ...); + %2 = (%x,); + %3 = call_lowered(@tvmgen_default_fused_reshape, %2, ..., + "relay_attrs"=__dict__="relay.reshape_only"=1, ...); + */ + const CallNode* post_call = post.as(); + CallLoweredProps call_lowered_props = GetCallLoweredProps(post_call); + if (call_lowered_props.lowered_func.defined() && IsReshapeOnly(call_lowered_props)) { + if (!call_lowered_props.arguments.empty() && + call_lowered_props.arguments[0]->IsInstance()) { + Var var = Downcast(call_lowered_props.arguments[0]); + if (var.defined() && let_var_to_call_lowered_.find(var) != let_var_to_call_lowered_.end()) { + return let_var_to_call_lowered_[var]; + } + } + } + + return post; + } + + private: + /*! \brief Map of LetNode's var to previous call_lowered. */ + Map let_var_to_call_lowered_; + /*! \brief Module that contains global reshape functions. */ + IRModule& ir_module_; +}; + +namespace transform { + +Pass RemoveStandaloneReshapes() { + auto pass_func = [=](IRModule mod, const PassContext& pass_ctx) { + VLOG(1) << "RemoveStandaloneReshapes before:" << std::endl << PrettyPrint(mod); + RemoveStandaloneReshapesMutator remove_reshapes_mutator(mod); + Function main_func = Downcast(mod->Lookup("main")); + Expr new_main_body = remove_reshapes_mutator.VisitExpr(main_func->body); + if (!new_main_body.same_as(main_func->body)) { + auto main_var = mod->GetGlobalVar("main"); + auto new_main_func = Function(main_func->params, new_main_body, main_func->ret_type, + main_func->type_params, main_func->attrs); + mod->Update(main_var, new_main_func); + } + Array entry_functions{"main"}; + mod = RemoveUnusedFunctions(entry_functions)(mod); + + VLOG(1) << "RemoveStandaloneReshapes after:" << std::endl << PrettyPrint(mod); + return mod; + }; + return tvm::transform::CreateModulePass(pass_func, 0, "RemoveStandaloneReshapes", {}); +} + +TVM_REGISTER_GLOBAL("relay._transform.RemoveStandaloneReshapes") + .set_body_typed(RemoveStandaloneReshapes); + +} // namespace transform +} // namespace relay +} // namespace tvm diff --git a/tests/python/contrib/test_cmsisnn/test_conv2d.py b/tests/python/contrib/test_cmsisnn/test_conv2d.py index 623f5c0fc0d7..502743387bfa 100644 --- a/tests/python/contrib/test_cmsisnn/test_conv2d.py +++ b/tests/python/contrib/test_cmsisnn/test_conv2d.py @@ -669,6 +669,7 @@ def test_relay_conv2d_cmsisnn_depthwise_int8( cmsisnn_func = cmsisnn_tir_mod["tvmgen_default_cmsis_nn_main_0"] call_extern = None + # This happens when context buffer is init in case depthM != 1 if isinstance(cmsisnn_func.body, tvm.tir.stmt.Evaluate): call_extern = cmsisnn_func.body.value else: diff --git a/tests/python/contrib/test_cmsisnn/test_pooling.py b/tests/python/contrib/test_cmsisnn/test_pooling.py index e96f397c04da..29140ad2e656 100644 --- a/tests/python/contrib/test_cmsisnn/test_pooling.py +++ b/tests/python/contrib/test_cmsisnn/test_pooling.py @@ -15,14 +15,18 @@ # specific language governing permissions and limitations # under the License. -"""CMSIS-NN integration tests: Conv2D""" +"""CMSIS-NN integration tests: Pooling""" import numpy as np import pytest import tvm from tvm import relay from tvm.relay.op.contrib import cmsisnn -from tvm.testing.aot import AOTTestModel, compile_and_run, generate_ref_data +from tvm.testing.aot import ( + generate_ref_data, + AOTTestModel, + compile_and_run, +) from tvm.micro.testing.aot_test_utils import AOT_USMP_CORSTONE300_RUNNER from .utils import ( make_module, diff --git a/tests/python/contrib/test_cmsisnn/test_remove_reshapes.py b/tests/python/contrib/test_cmsisnn/test_remove_reshapes.py new file mode 100644 index 000000000000..8b33a8a90b76 --- /dev/null +++ b/tests/python/contrib/test_cmsisnn/test_remove_reshapes.py @@ -0,0 +1,169 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""CMSIS-NN integration tests: Reshape removal""" +import numpy as np +import pytest +import tvm +from tvm import relay +from tvm.relay.op.contrib import cmsisnn + +from tvm.testing.aot import ( + generate_ref_data, + AOTTestModel, + compile_models, + run_and_check, +) +from tvm.micro.testing.aot_test_utils import AOT_USMP_CORSTONE300_RUNNER +from .utils import ( + make_module, + get_range_for_dtype_str, + get_same_padding, + make_qnn_relu, + assert_partitioned_function, +) + + +def make_model( + pool_op, + shape=(1, 28, 28, 12), + pool_size=(3, 3), + strides=(2, 2), + padding="VALID", + dtype="int8", + scale=1, + zero_point=-33, + relu_type="RELU", + layout="NHWC", + input_op=None, +): + """Return a model and any parameters it may have, + all parameters are defaulted to known good values + """ + if input_op: + op = input_op + else: + op = relay.var("input", shape=shape, dtype=dtype) + pad_ = (0, 0, 0, 0) + if padding == "SAME": + dilation = (1, 1) + pad_ = get_same_padding((shape[1], shape[2]), pool_size, dilation, strides) + op = relay.nn.pad( + op, + pad_width=[(0, 0), (pad_[0], pad_[2]), (pad_[1], pad_[3]), (0, 0)], + pad_value=zero_point, + pad_mode="constant", + ) + if pool_op.__name__ == relay.nn.avg_pool2d.__name__: + op = relay.cast(op, "int32") + op = pool_op( + op, pool_size=pool_size, strides=strides, padding=pad_, ceil_mode=True, layout=layout + ) + if pool_op.__name__ == relay.nn.avg_pool2d.__name__: + op = relay.cast(op, dtype) + op = make_qnn_relu(op, relu_type, scale, zero_point, dtype) + return op + + +@tvm.testing.requires_cmsisnn +@pytest.mark.parametrize("padding", ["SAME", "VALID"]) +def test_reshape_removal(padding): + """Tests reshape is removed from the network""" + interface_api = "c" + use_unpacked_api = True + test_runner = AOT_USMP_CORSTONE300_RUNNER + + in_shape = (1, 28, 28, 12) + pool_size = (3, 3) + strides = (2, 2) + relu_type = "NONE" + zero_point, scale = (-34, 0.0256) + + max_pool = make_model( + pool_op=relay.nn.max_pool2d, + shape=in_shape, + pool_size=pool_size, + strides=strides, + padding=padding, + scale=scale, + zero_point=zero_point, + relu_type=relu_type, + ) + new_shape = (1, 28, 28, 3) if padding == "VALID" else (1, 30, 30, 3) + reshape = relay.reshape(max_pool, newshape=new_shape) + + model = make_model( + pool_op=relay.nn.avg_pool2d, + shape=new_shape, + pool_size=pool_size, + strides=strides, + padding=padding, + scale=scale, + zero_point=zero_point, + relu_type=relu_type, + input_op=reshape, + ) + orig_mod = make_module(model) + + cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod) + + # validate pattern matching + assert_partitioned_function(orig_mod, cmsisnn_mod) + + # generate reference output + rng = np.random.default_rng(12345) + in_min, in_max = get_range_for_dtype_str("int8") + inputs = {"input": rng.integers(in_min, high=in_max, size=in_shape, dtype="int8")} + output_list = generate_ref_data(orig_mod["main"], inputs, params=None) + + # validate presence of depthwise convolution + compiled_models = compile_models( + AOTTestModel( + module=cmsisnn_mod, + inputs=inputs, + outputs=output_list, + params=None, + output_tolerance=1, + ), + interface_api, + use_unpacked_api, + pass_config=test_runner.pass_config, + ) + + main_mod = None + for target, mod in compiled_models[0].executor_factory.lowered_ir_mods.items(): + if target.kind.name == "c": + main_mod = mod + + # when padding="SAME", extra padding is introduced which causes Reshape to be fused with the + # Pad. RemoveReshapes pass cannot remove a fused Reshape. Whereas padding="VALID" doesn't need + # an extra Pad layer. In this case, the pass removes the Reshape from the graph. + reshapes_present = any(["reshape" in gv.name_hint for gv in main_mod.get_global_vars()]) + check_reshapes = reshapes_present if padding == "SAME" else not reshapes_present + expected_reshapes = "a" if padding == "SAME" else "No" + assert check_reshapes, "Expeting {} reshape layer(s).".format(expected_reshapes) + + # validate the output + run_and_check( + models=compiled_models, + runner=test_runner, + interface_api=interface_api, + ) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/contrib/test_ethosu/test_networks.py b/tests/python/contrib/test_ethosu/test_networks.py index 02643f6c1ded..2b4ffd96caef 100644 --- a/tests/python/contrib/test_ethosu/test_networks.py +++ b/tests/python/contrib/test_ethosu/test_networks.py @@ -45,12 +45,12 @@ "accel_type, model_url, workspace_size", [ ("ethos-u65-256", MOBILENET_V1_URL, 1793376), - ("ethos-u65-256", MOBILENET_V2_URL, 2218160), + ("ethos-u65-256", MOBILENET_V2_URL, 2217152), ("ethos-u55-256", MOBILENET_V1_URL, 1793376), - ("ethos-u55-256", MOBILENET_V2_URL, 2218160), - ("ethos-u55-128", MOBILENET_V2_URL, 2218160), - ("ethos-u55-64", MOBILENET_V2_URL, 2218160), - ("ethos-u55-32", MOBILENET_V2_URL, 2218160), + ("ethos-u55-256", MOBILENET_V2_URL, 2217152), + ("ethos-u55-128", MOBILENET_V2_URL, 2217152), + ("ethos-u55-64", MOBILENET_V2_URL, 2217152), + ("ethos-u55-32", MOBILENET_V2_URL, 2217152), ], ) def test_networks_without_usmp(accel_type, model_url, workspace_size): diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 987d425aa63d..edf23ff22781 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -998,7 +998,7 @@ def test_workspace_calculation_cmsis_nn(): ): lib = tvm.relay.build(mod, target, executor=executor, runtime=runtime, params=params) mlf_memory_map = mlf._build_function_memory_map(lib.function_metadata) - assert mlf_memory_map["main"][0]["workspace_size_bytes"] == 14384 + assert mlf_memory_map["main"][0]["workspace_size_bytes"] == 14256 def test_aot_codegen_checks_returns(): diff --git a/tests/python/relay/aot/test_crt_aot_usmp.py b/tests/python/relay/aot/test_crt_aot_usmp.py index 724932183a54..b79350d172ac 100644 --- a/tests/python/relay/aot/test_crt_aot_usmp.py +++ b/tests/python/relay/aot/test_crt_aot_usmp.py @@ -105,24 +105,24 @@ def test_synthetic(interface_api, use_unpacked_api, test_runner): "workspace_byte_alignment,constant_byte_alignment," "main_workspace_size,main_constant_size,usmp_algo", [ - (8, 8, 17280, 948, "greedy_by_conflicts"), - (16, 8, 17280, 948, "greedy_by_conflicts"), - (256, 8, 17792, 948, "greedy_by_conflicts"), - (8, 16, 17280, 956, "greedy_by_conflicts"), - (16, 16, 17280, 956, "greedy_by_conflicts"), - (256, 16, 17792, 956, "greedy_by_conflicts"), - (8, 256, 17280, 1804, "greedy_by_conflicts"), - (16, 256, 17280, 1804, "greedy_by_conflicts"), - (256, 256, 17792, 1804, "greedy_by_conflicts"), - (8, 8, 22032, 948, "greedy_by_size"), - (16, 8, 22032, 948, "greedy_by_size"), - (256, 8, 22976, 948, "greedy_by_size"), - (8, 16, 22032, 956, "greedy_by_size"), - (16, 16, 22032, 956, "greedy_by_size"), - (256, 16, 22976, 956, "greedy_by_size"), - (8, 256, 22032, 1804, "greedy_by_size"), - (16, 256, 22032, 1804, "greedy_by_size"), - (256, 256, 22976, 1804, "greedy_by_size"), + (8, 8, 14208, 948, "greedy_by_conflicts"), + (16, 8, 14208, 948, "greedy_by_conflicts"), + (256, 8, 14720, 948, "greedy_by_conflicts"), + (8, 16, 14208, 956, "greedy_by_conflicts"), + (16, 16, 14208, 956, "greedy_by_conflicts"), + (256, 16, 14720, 956, "greedy_by_conflicts"), + (8, 256, 14208, 1804, "greedy_by_conflicts"), + (16, 256, 14208, 1804, "greedy_by_conflicts"), + (256, 256, 14720, 1804, "greedy_by_conflicts"), + (8, 8, 18576, 948, "greedy_by_size"), + (16, 8, 18576, 948, "greedy_by_size"), + (256, 8, 19392, 948, "greedy_by_size"), + (8, 16, 18576, 956, "greedy_by_size"), + (16, 16, 18576, 956, "greedy_by_size"), + (256, 16, 19392, 956, "greedy_by_size"), + (8, 256, 18576, 1804, "greedy_by_size"), + (16, 256, 18576, 1804, "greedy_by_size"), + (256, 256, 19392, 1804, "greedy_by_size"), (8, 8, 11424, 948, "hill_climb"), (16, 8, 11424, 948, "hill_climb"), (256, 8, 11920, 948, "hill_climb"), diff --git a/tests/python/relay/backend/test_pass_remove_standalone_reshapes.py b/tests/python/relay/backend/test_pass_remove_standalone_reshapes.py new file mode 100644 index 000000000000..2113ae7b5c72 --- /dev/null +++ b/tests/python/relay/backend/test_pass_remove_standalone_reshapes.py @@ -0,0 +1,260 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Exercises the RemoveStandaloneReshapes pass. + +import tvm +from tvm import relay +from tvm.relay.expr_functor import ExprMutator +import tvm.testing +from tvm.script import tir as T + + +HOST_DEVICE = tvm.device("cpu") +HOST_TARGET = tvm.target.Target("llvm") + +CPU_DEVICE = tvm.device("cpu") +CPU_TARGET = tvm.target.Target("llvm").with_host(HOST_TARGET) + +CPU = tvm.target.VirtualDevice(CPU_DEVICE, CPU_TARGET) # device_type=1 + + +RemoveStandaloneReshapes = tvm._ffi.get_global_func("relay._transform.RemoveStandaloneReshapes") + + +class MarkReshapeOnlyMutator(ExprMutator): + """A pass for marking call_lowered as ReshapeOnly where reshapes exist unfused""" + + def __init__(self): + ExprMutator.__init__(self) + + def visit_call(self, call): + if isinstance(call.args[0], tvm.ir.GlobalVar) and "reshape" in call.args[0].name_hint: + # attrs = {"relay_attrs" : {"relay.reshape_only" : 1}} + dict_attrs = tvm.ir.make_node("DictAttrs", **{"relay.reshape_only": 1}) + attrs = tvm.ir.make_node( + "relay.attrs.CallLoweredAttrs", **{"metadata": {"relay_attrs": dict_attrs}} + ) + return relay.Call(call.op, call.args, attrs) + return super().visit_call(call) + + +# Reshape should not be removed if its the first layer in the network +def test_first_reshape(): + mod = tvm.ir.IRModule() + + @T.prim_func + def reshape_primfunc(a: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + D = T.match_buffer(d, [128, 128]) + + for i, j in T.grid(128, 128): + D[i, j] = A[i, j] + + metatable = {"VirtualDevice": [CPU]} + reshape_ty = relay.FuncType( + [ + relay.TensorType((128, 128), "float32"), + ], + relay.TensorType((128, 128), "float32"), + ) + + reshape_gv = relay.GlobalVar("reshape", type_annot=reshape_ty) + mod[reshape_gv] = reshape_primfunc + mod = tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x {virtual_device=meta[VirtualDevice][0]}: Tensor[(128, 128), float32], + virtual_device=meta[VirtualDevice][0]) { + %1 = call_lowered(@reshape, (%x,) ); + let %x_14: Tensor[(128, 128), float32] = on_device(%1, virtual_device=meta[VirtualDevice][0], constrain_result=True); + %x_14 + } + """, + "from_string", + mod, + metatable, + ) + + mod["main"] = MarkReshapeOnlyMutator().visit(mod["main"]) + mod = RemoveStandaloneReshapes()(mod) + reshapes_present = any(["reshape" in gv.name_hint for gv in mod.get_global_vars()]) + assert reshapes_present, "Reshape should have been removed." + return + + +# When reshape layer is the last one in the network +def test_last_reshape(): + mod = tvm.ir.IRModule() + + @T.prim_func + def mul_primfunc(a: T.handle, b: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + D = T.match_buffer(d, [128, 128]) + + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + D[vi, vj] = A[vi, vk] * B[vj, vk] + + @T.prim_func + def reshape_primfunc(a: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + D = T.match_buffer(d, [128, 128]) + + for i, j in T.grid(128, 128): + D[i, j] = A[i, j] + + metatable = {"VirtualDevice": [CPU]} + mul_ty = relay.FuncType( + [ + relay.TensorType((128, 128), "float32"), + relay.TensorType((128, 128), "float32"), + relay.TensorType((128, 128), "float32"), + ], + relay.TensorType((128, 128), "float32"), + ) + + mul_gv = relay.GlobalVar("multiply", type_annot=mul_ty) + mod[mul_gv] = mul_primfunc + reshape_ty = relay.FuncType( + [ + relay.TensorType((128, 128), "float32"), + ], + relay.TensorType((128, 128), "float32"), + ) + + reshape_gv = relay.GlobalVar("reshape", type_annot=reshape_ty) + mod[reshape_gv] = reshape_primfunc + mod = tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x {virtual_device=meta[VirtualDevice][0]}: Tensor[(128, 128), float32], + %y {virtual_device=meta[VirtualDevice][0]}: Tensor[(128, 128), float32], + %z {virtual_device=meta[VirtualDevice][0]}: Tensor[(128, 128), float32], + virtual_device=meta[VirtualDevice][0]) { + %0 = call_lowered(@multiply, (%x, %y, %z)); + let %x_12: Tensor[(128, 128), float32] = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True); + %1 = call_lowered(@reshape, (%x_12,) ); + let %x_14: Tensor[(128, 128), float32] = on_device(%1, virtual_device=meta[VirtualDevice][0], constrain_result=True); + %x_14 + } + """, + "from_string", + mod, + metatable, + ) + + # Expected main: + ##[version = "0.0.5"] + # def @main(%x /* ty=Tensor[(128, 128), float32] */) -> Tensor[(128, 128), float32] { + # %0 = (%x, %y, %z); + # %1 = call_lowered(@multiply, %0); + # let %x_12: Tensor[(128, 128), float32] = on_device(%1, constrain_result=True); + # let %x_14: Tensor[(128, 128), float32] = on_device(%1, constrain_result=True); + # %x_14 + # } + + mod["main"] = MarkReshapeOnlyMutator().visit(mod["main"]) + mod = RemoveStandaloneReshapes()(mod) + reshapes_present = any(["reshape" in gv.name_hint for gv in mod.get_global_vars()]) + assert not reshapes_present, "Reshape should have been removed." + return + + +# When reshape layer is not marked as reshape_only +def test_fused_reshape(): + mod = tvm.ir.IRModule() + + @T.prim_func + def mul_primfunc(a: T.handle, b: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + D = T.match_buffer(d, [128, 128]) + + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + D[vi, vj] = A[vi, vk] * B[vj, vk] + + @T.prim_func + def fused_reshape_primfunc(a: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + D = T.match_buffer(d, [128, 128]) + + for i, j in T.grid(128, 128): + D[i, j] = A[i, j] + + metatable = {"VirtualDevice": [CPU]} + mul_ty = relay.FuncType( + [ + relay.TensorType((128, 128), "float32"), + relay.TensorType((128, 128), "float32"), + relay.TensorType((128, 128), "float32"), + ], + relay.TensorType((128, 128), "float32"), + ) + + mul_gv = relay.GlobalVar("multiply", type_annot=mul_ty) + mod[mul_gv] = mul_primfunc + reshape_ty = relay.FuncType( + [ + relay.TensorType((128, 128), "float32"), + ], + relay.TensorType((128, 128), "float32"), + ) + + reshape_gv = relay.GlobalVar("fused_reshape", type_annot=reshape_ty) + mod[reshape_gv] = fused_reshape_primfunc + mod = tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x {virtual_device=meta[VirtualDevice][0]}: Tensor[(128, 128), float32], + %y {virtual_device=meta[VirtualDevice][0]}: Tensor[(128, 128), float32], + %z {virtual_device=meta[VirtualDevice][0]}: Tensor[(128, 128), float32], + virtual_device=meta[VirtualDevice][0]) { + %0 = call_lowered(@multiply, (%x, %y, %z)); + let %x_12: Tensor[(128, 128), float32] = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True); + %1 = call_lowered(@fused_reshape, (%x_12,) ); + let %x_14: Tensor[(128, 128), float32] = on_device(%1, virtual_device=meta[VirtualDevice][0], constrain_result=True); + %x_14 + } + """, + "from_string", + mod, + metatable, + ) + + # Expected main: + ##[version = "0.0.5"] + # def @main(%x /* ty=Tensor[(128, 128), float32] */) -> Tensor[(128, 128), float32] { + # %0 = (%x, %y, %z); + # %1 = call_lowered(@multiply, %0); + # let %x_12: Tensor[(128, 128), float32] = on_device(%1, constrain_result=True); + # let %x_14: Tensor[(128, 128), float32] = on_device(%1, constrain_result=True); + # %x_14 + # } + + mod = RemoveStandaloneReshapes()(mod) + reshapes_present = any(["reshape" in gv.name_hint for gv in mod.get_global_vars()]) + assert reshapes_present, "Reshape should have been removed." + return + + +if __name__ == "__main__": + tvm.testing.main()