Pass that removes reshapes post LowerTE (#12215)
Introduces a pass that removes intermediate reshapes after LowerTE() in the AOT compiler. The commit adds pass-specific tests and updates the USMP-generated workspace pools to reflect the reduced number of allocations after reshape removal. Note: the pass does not currently handle a reshape that appears first in the graph; support can be added later if that turns out to be a useful case.
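For orientation, the diff registers a pass-config option (`relay.remove_standalone_reshapes.enable`) and an FFI global (`relay._transform.RemoveStandaloneReshapes`). Below is a minimal sketch of driving the pass directly from Python against those registrations; `lowered_mod` is a placeholder for an IRModule that has already been through LowerTE(), and this standalone invocation is illustrative only, not part of the commit.

```python
# Illustrative sketch only: `lowered_mod` is assumed to be a post-LowerTE IRModule.
import tvm

# FFI global registered by this commit; calling it constructs the module pass.
remove_reshapes = tvm.get_global_func("relay._transform.RemoveStandaloneReshapes")

# Pass-config option registered by this commit; how the compilation pipeline
# consumes this flag is not shown in this diff.
with tvm.transform.PassContext(
    opt_level=3, config={"relay.remove_standalone_reshapes.enable": True}
):
    lowered_mod = remove_reshapes()(lowered_mod)  # apply the pass to the module
```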
1 parent c4aab62, commit fc411dc
Showing 10 changed files with 597 additions and 27 deletions.
src/relay/transforms/remove_standalone_reshapes.cc
@@ -0,0 +1,120 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
/*!
 * \file src/relay/transforms/remove_standalone_reshapes.cc
 * \brief This file contains the Relay pass for removing unfused reshapes from lowered graph.
 */

#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>

#include "../op/call/call.h"
#include "../op/memory/on_device.h"

namespace tvm {
namespace relay {

TVM_REGISTER_PASS_CONFIG_OPTION("relay.remove_standalone_reshapes.enable", Bool);
/*! Removes reshapes right after LowerTE. Removes preceding on_device calls
 * while removing reshapes.
 */
class RemoveStandaloneReshapesMutator : public MixedModeMutator {
 public:
  explicit RemoveStandaloneReshapesMutator(IRModule& mod) : ir_module_(mod) {}

  using MixedModeMutator::VisitExpr_;

  /*! * \brief Generated map of let variables to preceding CallLowered */
  Expr VisitExpr_(const LetNode* let) final {
    Let ret_let;
    Var var = Downcast<Var>(this->Mutate(let->var));
    auto value = this->Mutate(let->value);
    if (auto* on_device_call = value.as<CallNode>()) {
      OnDeviceProps on_device_props = GetOnDeviceProps(on_device_call);
      if (on_device_props.body.defined() && on_device_props.body->IsInstance<CallNode>()) {
        const Call call_lowered = Downcast<Call>(on_device_props.body);
        if (call_lowered.defined() && call_lowered->op.same_as(CallLoweredOp())) {
          let_var_to_call_lowered_.Set(var, call_lowered);
        }
      }
    }
    auto body = this->Mutate(let->body);
    return WithFields(GetRef<Let>(let), var, value, body);
  }

  /*! * \brief Returns preceding CallLowered when call is a CallLowered(Reshape) */
  Expr Rewrite_(const CallNode* call, const Expr& post) final {
    /*
      %1 = call_lowered(@tvmgen_default_non_reshape_function, %input, ...);
      let %x: = on_device(%1, ...);
      %2 = (%x,);
      %3 = call_lowered(@tvmgen_default_fused_reshape, %2, ...,
           "relay_attrs"=__dict__="relay.reshape_only"=1, ...);
    */
    const CallNode* post_call = post.as<CallNode>();
    CallLoweredProps call_lowered_props = GetCallLoweredProps(post_call);
    if (call_lowered_props.lowered_func.defined() && IsReshapeOnly(call_lowered_props)) {
      if (!call_lowered_props.arguments.empty() &&
          call_lowered_props.arguments[0]->IsInstance<VarNode>()) {
        Var var = Downcast<Var>(call_lowered_props.arguments[0]);
        if (var.defined() && let_var_to_call_lowered_.find(var) != let_var_to_call_lowered_.end()) {
          return let_var_to_call_lowered_[var];
        }
      }
    }

    return post;
  }

 private:
  /*! \brief Map of LetNode's var to previous call_lowered. */
  Map<Var, Call> let_var_to_call_lowered_;
  /*! \brief Module that contains global reshape functions. */
  IRModule& ir_module_;
};

namespace transform {

Pass RemoveStandaloneReshapes() {
  auto pass_func = [=](IRModule mod, const PassContext& pass_ctx) {
    VLOG(1) << "RemoveStandaloneReshapes before:" << std::endl << PrettyPrint(mod);
    RemoveStandaloneReshapesMutator remove_reshapes_mutator(mod);
    Function main_func = Downcast<Function>(mod->Lookup("main"));
    Expr new_main_body = remove_reshapes_mutator.VisitExpr(main_func->body);
    if (!new_main_body.same_as(main_func->body)) {
      auto main_var = mod->GetGlobalVar("main");
      auto new_main_func = Function(main_func->params, new_main_body, main_func->ret_type,
                                    main_func->type_params, main_func->attrs);
      mod->Update(main_var, new_main_func);
    }
    Array<runtime::String> entry_functions{"main"};
    mod = RemoveUnusedFunctions(entry_functions)(mod);

    VLOG(1) << "RemoveStandaloneReshapes after:" << std::endl << PrettyPrint(mod);
    return mod;
  };
  return tvm::transform::CreateModulePass(pass_func, 0, "RemoveStandaloneReshapes", {});
}

TVM_REGISTER_GLOBAL("relay._transform.RemoveStandaloneReshapes")
    .set_body_typed(RemoveStandaloneReshapes);

}  // namespace transform
}  // namespace relay
}  // namespace tvm
169 changes: 169 additions & 0 deletions
tests/python/contrib/test_cmsisnn/test_remove_reshapes.py
@@ -0,0 +1,169 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""CMSIS-NN integration tests: Reshape removal"""
import numpy as np
import pytest
import tvm
from tvm import relay
from tvm.relay.op.contrib import cmsisnn

from tvm.testing.aot import (
    generate_ref_data,
    AOTTestModel,
    compile_models,
    run_and_check,
)
from tvm.micro.testing.aot_test_utils import AOT_USMP_CORSTONE300_RUNNER
from .utils import (
    make_module,
    get_range_for_dtype_str,
    get_same_padding,
    make_qnn_relu,
    assert_partitioned_function,
)


def make_model(
    pool_op,
    shape=(1, 28, 28, 12),
    pool_size=(3, 3),
    strides=(2, 2),
    padding="VALID",
    dtype="int8",
    scale=1,
    zero_point=-33,
    relu_type="RELU",
    layout="NHWC",
    input_op=None,
):
    """Return a model and any parameters it may have,
    all parameters are defaulted to known good values
    """
    if input_op:
        op = input_op
    else:
        op = relay.var("input", shape=shape, dtype=dtype)
    pad_ = (0, 0, 0, 0)
    if padding == "SAME":
        dilation = (1, 1)
        pad_ = get_same_padding((shape[1], shape[2]), pool_size, dilation, strides)
        op = relay.nn.pad(
            op,
            pad_width=[(0, 0), (pad_[0], pad_[2]), (pad_[1], pad_[3]), (0, 0)],
            pad_value=zero_point,
            pad_mode="constant",
        )
    if pool_op.__name__ == relay.nn.avg_pool2d.__name__:
        op = relay.cast(op, "int32")
    op = pool_op(
        op, pool_size=pool_size, strides=strides, padding=pad_, ceil_mode=True, layout=layout
    )
    if pool_op.__name__ == relay.nn.avg_pool2d.__name__:
        op = relay.cast(op, dtype)
    op = make_qnn_relu(op, relu_type, scale, zero_point, dtype)
    return op


@tvm.testing.requires_cmsisnn
@pytest.mark.parametrize("padding", ["SAME", "VALID"])
def test_reshape_removal(padding):
    """Tests reshape is removed from the network"""
    interface_api = "c"
    use_unpacked_api = True
    test_runner = AOT_USMP_CORSTONE300_RUNNER

    in_shape = (1, 28, 28, 12)
    pool_size = (3, 3)
    strides = (2, 2)
    relu_type = "NONE"
    zero_point, scale = (-34, 0.0256)

    max_pool = make_model(
        pool_op=relay.nn.max_pool2d,
        shape=in_shape,
        pool_size=pool_size,
        strides=strides,
        padding=padding,
        scale=scale,
        zero_point=zero_point,
        relu_type=relu_type,
    )
    new_shape = (1, 28, 28, 3) if padding == "VALID" else (1, 30, 30, 3)
    reshape = relay.reshape(max_pool, newshape=new_shape)

    model = make_model(
        pool_op=relay.nn.avg_pool2d,
        shape=new_shape,
        pool_size=pool_size,
        strides=strides,
        padding=padding,
        scale=scale,
        zero_point=zero_point,
        relu_type=relu_type,
        input_op=reshape,
    )
    orig_mod = make_module(model)

    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)

    # validate pattern matching
    assert_partitioned_function(orig_mod, cmsisnn_mod)

    # generate reference output
    rng = np.random.default_rng(12345)
    in_min, in_max = get_range_for_dtype_str("int8")
    inputs = {"input": rng.integers(in_min, high=in_max, size=in_shape, dtype="int8")}
    output_list = generate_ref_data(orig_mod["main"], inputs, params=None)

    # compile the models for the AOT test runner
    compiled_models = compile_models(
        AOTTestModel(
            module=cmsisnn_mod,
            inputs=inputs,
            outputs=output_list,
            params=None,
            output_tolerance=1,
        ),
        interface_api,
        use_unpacked_api,
        pass_config=test_runner.pass_config,
    )

    main_mod = None
    for target, mod in compiled_models[0].executor_factory.lowered_ir_mods.items():
        if target.kind.name == "c":
            main_mod = mod

    # when padding="SAME", extra padding is introduced which causes Reshape to be fused with the
    # Pad. RemoveReshapes pass cannot remove a fused Reshape. Whereas padding="VALID" doesn't need
    # an extra Pad layer. In this case, the pass removes the Reshape from the graph.
    reshapes_present = any(["reshape" in gv.name_hint for gv in main_mod.get_global_vars()])
    check_reshapes = reshapes_present if padding == "SAME" else not reshapes_present
    expected_reshapes = "a" if padding == "SAME" else "No"
    assert check_reshapes, "Expecting {} reshape layer(s).".format(expected_reshapes)

    # validate the output
    run_and_check(
        models=compiled_models,
        runner=test_runner,
        interface_api=interface_api,
    )


if __name__ == "__main__":
    tvm.testing.main()