This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.5.x] Handle fix_gamma in tensorrt subgraph conversion correctly #15874

Merged 1 commit on Aug 16, 2019
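Background for the fix: when a BatchNorm layer is created with fix_gamma=True, MXNet treats gamma as fixed at 1 during both training and inference, regardless of the value stored in the parameter dict. A conversion path that copies the stored gamma verbatim into the TensorRT engine therefore produces wrong outputs whenever that stored value is not 1. A minimal NumPy sketch of the inference arithmetic (the function and values here are illustrative, not part of the patch):

import numpy as np

# BatchNorm at inference: y = gamma * (x - moving_mean) / sqrt(moving_var + eps) + beta
def batchnorm_infer(x, gamma, beta, mean, var, eps=1e-3, fix_gamma=False):
    if fix_gamma:
        gamma = np.ones_like(gamma)  # MXNet ignores the stored gamma when fix_gamma=True
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

x = np.random.uniform(size=(1, 1, 3, 3)).astype(np.float32)
stored_gamma = np.zeros((1,), dtype=np.float32)  # as in the test below: gamma is zeros
beta = np.zeros((1,), dtype=np.float32)
mean = np.ones((1,), dtype=np.float32)
var = np.ones((1,), dtype=np.float32)

y_mxnet = batchnorm_infer(x, stored_gamma, beta, mean, var, fix_gamma=True)
y_copied_gamma = batchnorm_infer(x, stored_gamma, beta, mean, var, fix_gamma=False)
# y_copied_gamma is identically zero (gamma == 0), while y_mxnet is the
# correctly normalized output; this is the mismatch the PR eliminates.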
21 changes: 16 additions & 5 deletions src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h
@@ -33,6 +33,8 @@
 
 #include <onnx/onnx_pb.h>
 
+#include <unordered_map>
+#include <vector>
 #include <string>
 
 namespace mxnet {
@@ -72,15 +74,12 @@ typedef void (*ConverterFunction)(NodeProto *node_proto,
                                   const nnvm::IndexedGraph &ig,
                                   const array_view<IndexedGraph::NodeEntry> &inputs);
 
-
 // Forward declarations
-void ConvertConvolution(
-                        NodeProto *node_proto,
+void ConvertConvolution(NodeProto *node_proto,
                         const NodeAttrs &attrs,
                         const nnvm::IndexedGraph &ig,
                         const array_view<IndexedGraph::NodeEntry> &inputs);
 
-
 void ConvertPooling(NodeProto *node_proto,
                     const NodeAttrs &attrs,
                     const nnvm::IndexedGraph &ig,
@@ -142,7 +141,7 @@ void ConvertPad(NodeProto* node_proto,
                 const array_view<IndexedGraph::NodeEntry> &inputs);
 
 std::string ConvertNnvmGraphToOnnx(const nnvm::Graph &g,
-                                   const std::unordered_map<std::string, NDArray>* const params_map);
+                                   std::unordered_map<std::string, NDArray>* params_map);
 
 static const std::unordered_map<std::string, ConverterFunction> converter_map = {
   {"Activation", ConvertActivation},
@@ -160,6 +159,18 @@ static const std::unordered_map<std::string, ConverterFunction> converter_map =
   {"SoftmaxOutput", ConvertSoftmaxOutput}
 };
 
+typedef void (*PreprocessFunction)(const NodeAttrs &attrs,
+                                   const std::vector<nnvm::NodeEntry> &inputs,
+                                   std::unordered_map<std::string, NDArray> *params_map);
+
+void PreprocessBatchNorm(const NodeAttrs &attrs,
+                         const std::vector<nnvm::NodeEntry> &inputs,
+                         std::unordered_map<std::string, NDArray> *params_map);
+
+static const std::unordered_map<std::string, PreprocessFunction> preprocess_map = {
+  {"BatchNorm", PreprocessBatchNorm}
+};
+
 } // namespace nnvm_to_onnx
 } // namespace op
 } // namespace mxnet
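The header change mirrors the existing converter_map with a second name-keyed registry, preprocess_map, so ops that need their parameters rewritten before ONNX serialization can be looked up and handled generically. A rough Python analogue of the dispatch pattern (illustrative only; node.op_name and the hook signature are stand-ins for the C++ types above):

# Registry of parameter-preprocessing hooks, keyed by op name, analogous to the
# C++ preprocess_map declared above. Hooks mutate params_map in place.
def preprocess_batch_norm(attrs, inputs, params_map):
    pass  # the real logic lives in nnvm_to_onnx.cc below

preprocess_map = {"BatchNorm": preprocess_batch_norm}

def run_preprocess_pass(op_nodes, params_map):
    # First pass over the graph: give registered ops a chance to rewrite
    # their parameters before any node is converted.
    for node in op_nodes:
        hook = preprocess_map.get(node.op_name)
        if hook is not None:
            hook(node.attrs, node.inputs, params_map)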
29 changes: 27 additions & 2 deletions src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
@@ -54,7 +54,7 @@ namespace nnvm_to_onnx {
 
 std::string ConvertNnvmGraphToOnnx(
     const nnvm::Graph& g,
-    const std::unordered_map<std::string, NDArray>* const params_map) {
+    std::unordered_map<std::string, NDArray>* params_map) {
 
   static std::atomic_ulong subgraph_count = { 0 };
 
@@ -88,8 +88,21 @@ std::string ConvertNnvmGraphToOnnx(
   auto placeholder_shapes = GetPlaceholderShapes(shape_inputs, ig);
   auto placeholder_dtypes = GetPlaceholderDTypes(dtype_inputs, ig);
   auto output_lookup = GetOutputLookup(ig);
-  uint32_t current_input = 0;
 
+  for (uint32_t node_idx = 0; node_idx < ig.num_nodes(); ++node_idx) {
+    const IndexedGraph::Node& node = ig[node_idx];
+    const nnvm::Node* source = node.source;
+    // If this is an op
+    if (!source->is_variable()) {
+      auto mightNeedPreprocessNode = preprocess_map.find(source->op()->name);
+      // If this op has an entry in preprocess_map, run its preprocessing hook
+      if (mightNeedPreprocessNode != preprocess_map.end()) {
+        mightNeedPreprocessNode->second(source->attrs, source->inputs, params_map);
+      }
+    }
+  }
+
+  uint32_t current_input = 0;
   // Can't do a foreach over IndexedGraph since it doesn't implement begin(), etc.
   for (uint32_t node_idx = 0; node_idx < ig.num_nodes(); ++node_idx) {
     const IndexedGraph::Node& node = ig[node_idx];
@@ -630,6 +643,18 @@ void ConvertDropout(NodeProto* node_proto, const NodeAttrs& attrs,
   node_proto->set_op_type("Dropout");
 }
 
+void PreprocessBatchNorm(const NodeAttrs &attrs,
+                         const std::vector<nnvm::NodeEntry> &inputs,
+                         std::unordered_map<std::string, NDArray> *params_map) {
+  const auto& param = nnvm::get<op::BatchNormParam>(attrs.parsed);
+  if (param.fix_gamma) {
+    // If fix_gamma is set on the MXNet op, preprocess the params map so that
+    // the gamma associated with this batch norm layer is rewritten to 1.
+    std::string gammaNodeName = inputs[batchnorm::kGamma].node->attrs.name;
+    (*params_map)[gammaNodeName] = 1.0f;
+  }
+}
+
 } // namespace nnvm_to_onnx
 } // namespace op
 } // namespace mxnet
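Restating the new function's effect: when the layer was built with fix_gamma=True, the gamma entry in the parameter map is overwritten with ones before the graph is serialized, so the generated engine matches MXNet's runtime behavior (the assignment (*params_map)[gammaNodeName] = 1.0f fills the NDArray with ones). A hedged Python sketch of the same rewrite, with a plain dict standing in for the C++ params map:

import mxnet as mx

def preprocess_batch_norm(fix_gamma, gamma_name, params_map):
    # Python analogue of PreprocessBatchNorm above: force gamma to all ones
    # whenever the op was created with fix_gamma=True.
    if fix_gamma:
        params_map[gamma_name] = mx.nd.ones_like(params_map[gamma_name])

params = {"trt_bn_test_bn_gamma": mx.nd.zeros((1,))}
preprocess_batch_norm(True, "trt_bn_test_bn_gamma", params)
print(params["trt_bn_test_bn_gamma"].asnumpy())  # [1.]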
2 changes: 1 addition & 1 deletion src/operator/subgraph/tensorrt/tensorrt.cc
@@ -272,7 +272,7 @@ OpStatePtr TRTCreateState(const nnvm::NodeAttrs& attrs, Context ctx,
                << " instead of: " << max_batch_size;
     max_batch_size = in_shape[0][0];
   }
-  const auto& params_map = node_param.params_map;
+  std::unordered_map<std::string, NDArray> params_map = node_param.params_map;
   const auto& inputs_to_idx = node_param.inputs_to_idx;
   const auto& outputs_to_idx = node_param.outputs_to_idx;
   const auto& idx_g = graph.indexed_graph();
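The one-line change above follows from the new signature: ConvertNnvmGraphToOnnx now takes a mutable pointer so the preprocessing pass can rewrite gamma, so TRTCreateState hands it a local copy of the map rather than a reference into the cached node_param. A generic Python illustration of the copy-before-mutate reasoning (names are illustrative):

# Work on a private copy so the shared, cached parameter map is not mutated
# as a side effect of building one engine.
shared_params = {"trt_bn_test_bn_gamma": 0.0}   # illustrative stand-in
params_map = dict(shared_params)                # copy, as tensorrt.cc now does
params_map["trt_bn_test_bn_gamma"] = 1.0        # rewrite is local to this build
assert shared_params["trt_bn_test_bn_gamma"] == 0.0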
65 changes: 65 additions & 0 deletions tests/python/tensorrt/test_tensorrt_batchnorm.py
@@ -0,0 +1,65 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import mxnet as mx
from mxnet.test_utils import assert_almost_equal

def get_params():
    arg_params = {}
    aux_params = {}
    arg_params["trt_bn_test_conv_weight"] = mx.nd.ones((1, 1, 3, 3))
    arg_params["trt_bn_test_bn_gamma"] = mx.nd.zeros((1,))
    arg_params["trt_bn_test_bn_beta"] = mx.nd.zeros((1,))
    aux_params["trt_bn_test_bn_moving_mean"] = mx.nd.ones(1)
    aux_params["trt_bn_test_bn_moving_var"] = mx.nd.ones(1)
    return arg_params, aux_params

def get_symbol():
    data = mx.sym.Variable("data")
    conv = mx.sym.Convolution(data=data, kernel=(3,3), no_bias=True, num_filter=1, num_group=1,
                              name="trt_bn_test_conv")
    bn = mx.sym.BatchNorm(data=conv, fix_gamma=True, use_global_stats=False, name="trt_bn_test_bn")
    return bn

def test_batch_norm_runs_correctly_with_fix_gamma():
    arg_params, aux_params = get_params()
    arg_params_trt, aux_params_trt = get_params()

    sym = get_symbol()
    sym_trt = get_symbol().get_backend_symbol("TensorRT")

    mx.contrib.tensorrt.init_tensorrt_params(sym_trt, arg_params_trt, aux_params_trt)

    executor = sym.simple_bind(ctx=mx.gpu(), data=(1, 1, 3, 3), grad_req='null', force_rebind=True)
    executor.copy_params_from(arg_params, aux_params)

    executor_trt = sym_trt.simple_bind(ctx=mx.gpu(), data=(1, 1, 3, 3), grad_req='null',
                                       force_rebind=True)
    executor_trt.copy_params_from(arg_params_trt, aux_params_trt)

    input_data = mx.nd.random.uniform(low=0, high=1, shape=(1, 1, 3, 3))

    y = executor.forward(is_train=False, data=input_data)
    y_trt = executor_trt.forward(is_train=False, data=input_data)

    print(y[0].asnumpy())
    print(y_trt[0].asnumpy())
    assert_almost_equal(y[0].asnumpy(), y_trt[0].asnumpy(), 1e-4, 1e-4)

if __name__ == '__main__':
    import nose
    nose.runmodule()
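With the parameters above (gamma stored as zeros, beta zero, moving mean and variance one) and fix_gamma=True, the expected output can be worked out by hand: the convolution is a single 3x3 all-ones filter with no bias on a 3x3 input, so it reduces to a sum, and inference-mode batch norm then computes (s - 1) / sqrt(1 + eps) with gamma acting as 1. A quick NumPy cross-check, assuming MXNet's default BatchNorm eps of 1e-3 (illustrative, independent of the TensorRT path):

import numpy as np

x = np.random.uniform(size=(1, 1, 3, 3)).astype(np.float32)
s = x.sum()                    # conv: one 3x3 all-ones kernel on a 3x3 input
eps = 1e-3                     # assumed MXNet BatchNorm default
y_expected = (s - 1.0) / np.sqrt(1.0 + eps)   # gamma acts as 1, beta is 0
print(y_expected)              # both executors above should print this value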