Pass that removes reshapes post LowerTE (#12215)
Introduces a pass for removing intermediate reshapes after
LowerTE() in the AOT compiler. This commit adds pass-specific
tests and updates the USMP-generated workspace pools to reflect
the reduced number of allocations after reshape removal.

Note: at present this pass does not handle a reshape that appears
first in the graph. If that proves to be a useful case, support
can be added in the future.
ashutosh-arm authored Aug 8, 2022
1 parent c4aab62 commit fc411dc
Showing 10 changed files with 597 additions and 27 deletions.
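
For context, the new behavior is gated by the `relay.remove_standalone_reshapes.enable` pass-context option introduced in this commit (enabled by default). A minimal sketch of opting out during an AOT build — `mod` and `params` are placeholders for an existing Relay module and its parameters:

```python
import tvm
from tvm import relay

# Setting the option to False keeps standalone reshapes in the lowered graph.
with tvm.transform.PassContext(
    opt_level=3, config={"relay.remove_standalone_reshapes.enable": False}
):
    lib = relay.build(
        mod,
        target="c",
        executor=relay.backend.Executor("aot"),
        runtime=relay.backend.Runtime("crt"),
        params=params,
    )
```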
11 changes: 11 additions & 0 deletions include/tvm/relay/transform.h
@@ -585,6 +585,17 @@ TVM_DLL Pass CapturePostDfsIndexInSpans();
* expr->memory_scope and annotates expressions by VirtualDevice with required memory_scope
*/
TVM_DLL Pass AnnotateMemoryScope(CompilationConfig config);

/*!
* \brief Removes non-fused reshapes after lowering the graph.
* InferType() cannot be invoked after calling this pass as it removes reshapes from the call
* graph. Many targets only need buffer addresses, irrespective of their shapes; this makes
* reshapes symbolic once the graph has been lowered. Reshape removal results in smaller code
* size and fewer buffer allocations, and opens up opportunities for operator fusion in the
* target backend. Consequently, it improves inference performance.
*/
TVM_DLL Pass RemoveStandaloneReshapes();

} // namespace transform

/*!
7 changes: 6 additions & 1 deletion src/relay/backend/aot_executor_codegen.cc
@@ -1096,6 +1096,12 @@ class AOTExecutorCodegen : public MixedModeVisitor {
tec::UpdateFunctionMetadata(func, this->function_metadata_, workspace_byte_alignment);
})(mod);

transform::PassContext pass_ctx = transform::PassContext::Current();
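// The new pass is enabled by default; it can be turned off via the
// relay.remove_standalone_reshapes.enable PassContext option.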
bool enable_remove_reshapes =
pass_ctx->GetConfig<Bool>("relay.remove_standalone_reshapes.enable", Bool(true)).value();
if (enable_remove_reshapes) {
lowered_mod = transform::RemoveStandaloneReshapes()(lowered_mod);
}
auto lowered_main = lowered_mod->Lookup("main");
auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());

@@ -1203,7 +1209,6 @@ class AOTExecutorCodegen : public MixedModeVisitor {
// Parallel for loops are not supported in AoT codegen.
lowered_mod = tir::transform::ConvertForLoopsToSerial()(lowered_mod);

transform::PassContext pass_ctx = transform::PassContext::Current();
bool enable_usmp = pass_ctx->GetConfig<Bool>(kUSMPEnableOption, Bool(false)).value();
if (enable_usmp) {
lowered_mod = PlanMemoryWithUSMP(lowered_mod);
120 changes: 120 additions & 0 deletions src/relay/transforms/remove_standalone_reshapes.cc
@@ -0,0 +1,120 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/relay/transforms/remove_standalone_reshapes.cc
* \brief This file contains the Relay pass for removing unfused reshapes from the lowered graph.
*/

#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>

#include "../op/call/call.h"
#include "../op/memory/on_device.h"

namespace tvm {
namespace relay {

TVM_REGISTER_PASS_CONFIG_OPTION("relay.remove_standalone_reshapes.enable", Bool);
/*! Removes standalone reshapes right after LowerTE. Preceding on_device calls are removed
* along with the reshapes.
*/
class RemoveStandaloneReshapesMutator : public MixedModeMutator {
public:
explicit RemoveStandaloneReshapesMutator(IRModule& mod) : ir_module_(mod) {}

using MixedModeMutator::VisitExpr_;

/*! \brief Builds a map from let variables to their preceding CallLowered calls. */
Expr VisitExpr_(const LetNode* let) final {
Let ret_let;
Var var = Downcast<Var>(this->Mutate(let->var));
auto value = this->Mutate(let->value);
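// Record let-bound vars whose value is on_device(call_lowered(...)); a later
// reshape consuming such a var can then be redirected to this producer call.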
if (auto* on_device_call = value.as<CallNode>()) {
OnDeviceProps on_device_props = GetOnDeviceProps(on_device_call);
if (on_device_props.body.defined() && on_device_props.body->IsInstance<CallNode>()) {
const Call call_lowered = Downcast<Call>(on_device_props.body);
if (call_lowered.defined() && call_lowered->op.same_as(CallLoweredOp())) {
let_var_to_call_lowered_.Set(var, call_lowered);
}
}
}
auto body = this->Mutate(let->body);
return WithFields(GetRef<Let>(let), var, value, body);
}

/*! \brief Returns the preceding CallLowered when the call is a CallLowered(Reshape). */
Expr Rewrite_(const CallNode* call, const Expr& post) final {
/*
%1 = call_lowered(@tvmgen_default_non_reshape_function, %input, ...);
let %x: = on_device(%1, ...);
%2 = (%x,);
%3 = call_lowered(@tvmgen_default_fused_reshape, %2, ...,
"relay_attrs"=__dict__="relay.reshape_only"=1, ...);
*/
const CallNode* post_call = post.as<CallNode>();
CallLoweredProps call_lowered_props = GetCallLoweredProps(post_call);
if (call_lowered_props.lowered_func.defined() && IsReshapeOnly(call_lowered_props)) {
if (!call_lowered_props.arguments.empty() &&
call_lowered_props.arguments[0]->IsInstance<VarNode>()) {
Var var = Downcast<Var>(call_lowered_props.arguments[0]);
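// If the reshape's input comes from a recorded call_lowered, return that
// call directly, eliding the reshape.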
if (var.defined() && let_var_to_call_lowered_.find(var) != let_var_to_call_lowered_.end()) {
return let_var_to_call_lowered_[var];
}
}
}

return post;
}

private:
/*! \brief Map of LetNode's var to previous call_lowered. */
Map<Var, Call> let_var_to_call_lowered_;
/*! \brief Module that contains global reshape functions. */
IRModule& ir_module_;
};

namespace transform {

Pass RemoveStandaloneReshapes() {
auto pass_func = [=](IRModule mod, const PassContext& pass_ctx) {
VLOG(1) << "RemoveStandaloneReshapes before:" << std::endl << PrettyPrint(mod);
RemoveStandaloneReshapesMutator remove_reshapes_mutator(mod);
Function main_func = Downcast<Function>(mod->Lookup("main"));
Expr new_main_body = remove_reshapes_mutator.VisitExpr(main_func->body);
if (!new_main_body.same_as(main_func->body)) {
auto main_var = mod->GetGlobalVar("main");
auto new_main_func = Function(main_func->params, new_main_body, main_func->ret_type,
main_func->type_params, main_func->attrs);
mod->Update(main_var, new_main_func);
}
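// Lowered reshape functions left unreferenced by the rewrite are pruned here.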
Array<runtime::String> entry_functions{"main"};
mod = RemoveUnusedFunctions(entry_functions)(mod);

VLOG(1) << "RemoveStandaloneReshapes after:" << std::endl << PrettyPrint(mod);
return mod;
};
return tvm::transform::CreateModulePass(pass_func, 0, "RemoveStandaloneReshapes", {});
}

TVM_REGISTER_GLOBAL("relay._transform.RemoveStandaloneReshapes")
.set_body_typed(RemoveStandaloneReshapes);

} // namespace transform
} // namespace relay
} // namespace tvm
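
Since the pass is also exported through the FFI registry (see `TVM_REGISTER_GLOBAL` above), it can be fetched and applied to an already-lowered module; a rough sketch, noting that the Python-side wrapper is not part of this diff:

```python
import tvm

# Fetch the registered packed function; calling it returns the module pass.
remove_reshapes = tvm.get_global_func("relay._transform.RemoveStandaloneReshapes")()

# Illustrative only: `lowered_mod` must already be post-LowerTE, and InferType()
# must not be run afterwards (see the header comment in transform.h).
# lowered_mod = remove_reshapes(lowered_mod)
```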
1 change: 1 addition & 0 deletions tests/python/contrib/test_cmsisnn/test_conv2d.py
@@ -669,6 +669,7 @@ def test_relay_conv2d_cmsisnn_depthwise_int8(

cmsisnn_func = cmsisnn_tir_mod["tvmgen_default_cmsis_nn_main_0"]
call_extern = None
# This branching handles the context buffer that is initialized when depthM != 1
if isinstance(cmsisnn_func.body, tvm.tir.stmt.Evaluate):
call_extern = cmsisnn_func.body.value
else:
8 changes: 6 additions & 2 deletions tests/python/contrib/test_cmsisnn/test_pooling.py
@@ -15,14 +15,18 @@
# specific language governing permissions and limitations
# under the License.

"""CMSIS-NN integration tests: Conv2D"""
"""CMSIS-NN integration tests: Pooling"""
import numpy as np
import pytest
import tvm
from tvm import relay
from tvm.relay.op.contrib import cmsisnn

from tvm.testing.aot import AOTTestModel, compile_and_run, generate_ref_data
from tvm.testing.aot import (
generate_ref_data,
AOTTestModel,
compile_and_run,
)
from tvm.micro.testing.aot_test_utils import AOT_USMP_CORSTONE300_RUNNER
from .utils import (
make_module,
169 changes: 169 additions & 0 deletions tests/python/contrib/test_cmsisnn/test_remove_reshapes.py
@@ -0,0 +1,169 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""CMSIS-NN integration tests: Reshape removal"""
import numpy as np
import pytest
import tvm
from tvm import relay
from tvm.relay.op.contrib import cmsisnn

from tvm.testing.aot import (
generate_ref_data,
AOTTestModel,
compile_models,
run_and_check,
)
from tvm.micro.testing.aot_test_utils import AOT_USMP_CORSTONE300_RUNNER
from .utils import (
make_module,
get_range_for_dtype_str,
get_same_padding,
make_qnn_relu,
assert_partitioned_function,
)


def make_model(
pool_op,
shape=(1, 28, 28, 12),
pool_size=(3, 3),
strides=(2, 2),
padding="VALID",
dtype="int8",
scale=1,
zero_point=-33,
relu_type="RELU",
layout="NHWC",
input_op=None,
):
"""Return a model and any parameters it may have,
all parameters are defaulted to known good values
"""
if input_op:
op = input_op
else:
op = relay.var("input", shape=shape, dtype=dtype)
pad_ = (0, 0, 0, 0)
if padding == "SAME":
dilation = (1, 1)
pad_ = get_same_padding((shape[1], shape[2]), pool_size, dilation, strides)
op = relay.nn.pad(
op,
pad_width=[(0, 0), (pad_[0], pad_[2]), (pad_[1], pad_[3]), (0, 0)],
pad_value=zero_point,
pad_mode="constant",
)
if pool_op.__name__ == relay.nn.avg_pool2d.__name__:
op = relay.cast(op, "int32")
op = pool_op(
op, pool_size=pool_size, strides=strides, padding=pad_, ceil_mode=True, layout=layout
)
if pool_op.__name__ == relay.nn.avg_pool2d.__name__:
op = relay.cast(op, dtype)
op = make_qnn_relu(op, relu_type, scale, zero_point, dtype)
return op


@tvm.testing.requires_cmsisnn
@pytest.mark.parametrize("padding", ["SAME", "VALID"])
def test_reshape_removal(padding):
"""Tests reshape is removed from the network"""
interface_api = "c"
use_unpacked_api = True
test_runner = AOT_USMP_CORSTONE300_RUNNER

in_shape = (1, 28, 28, 12)
pool_size = (3, 3)
strides = (2, 2)
relu_type = "NONE"
zero_point, scale = (-34, 0.0256)

max_pool = make_model(
pool_op=relay.nn.max_pool2d,
shape=in_shape,
pool_size=pool_size,
strides=strides,
padding=padding,
scale=scale,
zero_point=zero_point,
relu_type=relu_type,
)
new_shape = (1, 28, 28, 3) if padding == "VALID" else (1, 30, 30, 3)
reshape = relay.reshape(max_pool, newshape=new_shape)

model = make_model(
pool_op=relay.nn.avg_pool2d,
shape=new_shape,
pool_size=pool_size,
strides=strides,
padding=padding,
scale=scale,
zero_point=zero_point,
relu_type=relu_type,
input_op=reshape,
)
orig_mod = make_module(model)

cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)

# validate pattern matching
assert_partitioned_function(orig_mod, cmsisnn_mod)

# generate reference output
rng = np.random.default_rng(12345)
in_min, in_max = get_range_for_dtype_str("int8")
inputs = {"input": rng.integers(in_min, high=in_max, size=in_shape, dtype="int8")}
output_list = generate_ref_data(orig_mod["main"], inputs, params=None)

# compile the model so the lowered IRModule can be inspected
compiled_models = compile_models(
AOTTestModel(
module=cmsisnn_mod,
inputs=inputs,
outputs=output_list,
params=None,
output_tolerance=1,
),
interface_api,
use_unpacked_api,
pass_config=test_runner.pass_config,
)

main_mod = None
for target, mod in compiled_models[0].executor_factory.lowered_ir_mods.items():
if target.kind.name == "c":
main_mod = mod

# When padding="SAME", extra padding is introduced, which causes the Reshape to be fused
# with the Pad. The RemoveStandaloneReshapes pass cannot remove a fused Reshape, whereas
# padding="VALID" needs no extra Pad layer, so the pass removes the Reshape from the graph.
reshapes_present = any("reshape" in gv.name_hint for gv in main_mod.get_global_vars())
check_reshapes = reshapes_present if padding == "SAME" else not reshapes_present
expected_reshapes = "a" if padding == "SAME" else "no"
assert check_reshapes, "Expecting {} reshape layer(s).".format(expected_reshapes)

# validate the output
run_and_check(
models=compiled_models,
runner=test_runner,
interface_api=interface_api,
)


if __name__ == "__main__":
tvm.testing.main()
10 changes: 5 additions & 5 deletions tests/python/contrib/test_ethosu/test_networks.py
@@ -45,12 +45,12 @@
"accel_type, model_url, workspace_size",
[
("ethos-u65-256", MOBILENET_V1_URL, 1793376),
("ethos-u65-256", MOBILENET_V2_URL, 2218160),
("ethos-u65-256", MOBILENET_V2_URL, 2217152),
("ethos-u55-256", MOBILENET_V1_URL, 1793376),
("ethos-u55-256", MOBILENET_V2_URL, 2218160),
("ethos-u55-128", MOBILENET_V2_URL, 2218160),
("ethos-u55-64", MOBILENET_V2_URL, 2218160),
("ethos-u55-32", MOBILENET_V2_URL, 2218160),
("ethos-u55-256", MOBILENET_V2_URL, 2217152),
("ethos-u55-128", MOBILENET_V2_URL, 2217152),
("ethos-u55-64", MOBILENET_V2_URL, 2217152),
("ethos-u55-32", MOBILENET_V2_URL, 2217152),
],
)
def test_networks_without_usmp(accel_type, model_url, workspace_size):
2 changes: 1 addition & 1 deletion tests/python/relay/aot/test_crt_aot.py
@@ -998,7 +998,7 @@ def test_workspace_calculation_cmsis_nn():
):
lib = tvm.relay.build(mod, target, executor=executor, runtime=runtime, params=params)
mlf_memory_map = mlf._build_function_memory_map(lib.function_metadata)
assert mlf_memory_map["main"][0]["workspace_size_bytes"] == 14384
assert mlf_memory_map["main"][0]["workspace_size_bytes"] == 14256


def test_aot_codegen_checks_returns():