Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix batchNorm cpp example (#6454)
Browse files Browse the repository at this point in the history
  • Loading branch information
vsooda authored and piiswrong committed May 26, 2017
1 parent 775073b commit 57693ba
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
10 changes: 6 additions & 4 deletions cpp-package/example/inception_bn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@

using namespace mxnet::cpp;

static const Symbol BN_BETA;
static const Symbol BN_GAMMA;

Symbol ConvFactoryBN(Symbol data, int num_filter,
Shape kernel, Shape stride, Shape pad,
const std::string & name,
Expand All @@ -23,7 +20,12 @@ Symbol ConvFactoryBN(Symbol data, int num_filter,
Symbol conv = Convolution("conv_" + name + suffix, data,
conv_w, conv_b, kernel,
num_filter, stride, Shape(1, 1), pad);
Symbol bn = BatchNorm("bn_" + name + suffix, conv, Symbol(), Symbol(), Symbol(), Symbol());
std::string name_suffix = name + suffix;
Symbol gamma(name_suffix + "_gamma");
Symbol beta(name_suffix + "_beta");
Symbol mmean(name_suffix + "_mmean");
Symbol mvar(name_suffix + "_mvar");
Symbol bn = BatchNorm("bn_" + name + suffix, conv, gamma, beta, mmean, mvar);
return Activation("relu_" + name + suffix, bn, "relu");
}

Expand Down
21 changes: 14 additions & 7 deletions cpp-package/example/resnet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@ Symbol ConvolutionNoBias(const std::string& symbol_name,
.CreateSymbol(symbol_name);
}

static const Symbol BN_BETA;
static const Symbol BN_GAMMA;

Symbol getConv(const std::string & name, Symbol data,
int num_filter,
Shape kernel, Shape stride, Shape pad,
Expand All @@ -48,8 +45,13 @@ Symbol getConv(const std::string & name, Symbol data,
kernel, num_filter, stride, Shape(1, 1),
pad, 1, 512);

Symbol bn = BatchNorm(name + "_bn", conv, Symbol(), Symbol(), Symbol(),
Symbol(), 2e-5, bn_momentum, false);
Symbol gamma(name + "_gamma");
Symbol beta(name + "_beta");
Symbol mmean(name + "_mmean");
Symbol mvar(name + "_mvar");

Symbol bn = BatchNorm(name + "_bn", conv, gamma,
beta, mmean, mvar, 2e-5, bn_momentum, false);

if (with_relu) {
return Activation(name + "_relu", bn, "relu");
Expand Down Expand Up @@ -109,8 +111,13 @@ Symbol ResNetSymbol(int num_class, int num_level = 3, int num_block = 9,
Symbol data = Symbol::Variable("data");
Symbol data_label = Symbol::Variable("data_label");

Symbol zscore = BatchNorm("zscore", data, Symbol(), Symbol(), Symbol(),
Symbol(), 0.001, bn_momentum);
Symbol gamma("gamma");
Symbol beta("beta");
Symbol mmean("mmean");
Symbol mvar("mvar");

Symbol zscore = BatchNorm("zscore", data, gamma,
beta, mmean, mvar, 0.001, bn_momentum);

Symbol conv = getConv("conv0", zscore, num_filter,
Shape(3, 3), Shape(1, 1), Shape(1, 1),
Expand Down

0 comments on commit 57693ba

Please sign in to comment.