diff --git a/cpp-package/example/inception_bn.cpp b/cpp-package/example/inception_bn.cpp index 6c0754e60d00..5db4f81b0e07 100644 --- a/cpp-package/example/inception_bn.cpp +++ b/cpp-package/example/inception_bn.cpp @@ -11,9 +11,6 @@ using namespace mxnet::cpp; -static const Symbol BN_BETA; -static const Symbol BN_GAMMA; - Symbol ConvFactoryBN(Symbol data, int num_filter, Shape kernel, Shape stride, Shape pad, const std::string & name, @@ -23,7 +20,12 @@ Symbol ConvFactoryBN(Symbol data, int num_filter, Symbol conv = Convolution("conv_" + name + suffix, data, conv_w, conv_b, kernel, num_filter, stride, Shape(1, 1), pad); - Symbol bn = BatchNorm("bn_" + name + suffix, conv, Symbol(), Symbol(), Symbol(), Symbol()); + std::string name_suffix = name + suffix; + Symbol gamma(name_suffix + "_gamma"); + Symbol beta(name_suffix + "_beta"); + Symbol mmean(name_suffix + "_mmean"); + Symbol mvar(name_suffix + "_mvar"); + Symbol bn = BatchNorm("bn_" + name + suffix, conv, gamma, beta, mmean, mvar); return Activation("relu_" + name + suffix, bn, "relu"); } diff --git a/cpp-package/example/resnet.cpp b/cpp-package/example/resnet.cpp index ace3459d4bd7..5521567e119d 100644 --- a/cpp-package/example/resnet.cpp +++ b/cpp-package/example/resnet.cpp @@ -35,9 +35,6 @@ Symbol ConvolutionNoBias(const std::string& symbol_name, .CreateSymbol(symbol_name); } -static const Symbol BN_BETA; -static const Symbol BN_GAMMA; - Symbol getConv(const std::string & name, Symbol data, int num_filter, Shape kernel, Shape stride, Shape pad, @@ -48,8 +45,13 @@ Symbol getConv(const std::string & name, Symbol data, kernel, num_filter, stride, Shape(1, 1), pad, 1, 512); - Symbol bn = BatchNorm(name + "_bn", conv, Symbol(), Symbol(), Symbol(), - Symbol(), 2e-5, bn_momentum, false); + Symbol gamma(name + "_gamma"); + Symbol beta(name + "_beta"); + Symbol mmean(name + "_mmean"); + Symbol mvar(name + "_mvar"); + + Symbol bn = BatchNorm(name + "_bn", conv, gamma, + beta, mmean, mvar, 2e-5, bn_momentum, false); if (with_relu) { return Activation(name + "_relu", bn, "relu"); @@ -109,8 +111,13 @@ Symbol ResNetSymbol(int num_class, int num_level = 3, int num_block = 9, Symbol data = Symbol::Variable("data"); Symbol data_label = Symbol::Variable("data_label"); - Symbol zscore = BatchNorm("zscore", data, Symbol(), Symbol(), Symbol(), - Symbol(), 0.001, bn_momentum); + Symbol gamma("gamma"); + Symbol beta("beta"); + Symbol mmean("mmean"); + Symbol mvar("mvar"); + + Symbol zscore = BatchNorm("zscore", data, gamma, + beta, mmean, mvar, 0.001, bn_momentum); Symbol conv = getConv("conv0", zscore, num_filter, Shape(3, 3), Shape(1, 1), Shape(1, 1),