Skip to content

Commit

Permalink
Allow uers to specify the name of moving mean and variance in batch_n…
Browse files Browse the repository at this point in the history
…orm interface.
  • Loading branch information
qingqing01 committed Feb 2, 2018
1 parent 292c195 commit cbc9a59
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion python/paddle/v2/fluid/layers/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1478,7 +1478,9 @@ def batch_norm(input,
param_attr=None,
bias_attr=None,
data_layout='NCHW',
name=None):
name=None,
moving_mean_name=None,
moving_variance_name=None):
"""
This function helps create an operator to implement
the BatchNorm layer using the configurations from the input parameters.
Expand Down Expand Up @@ -1508,13 +1510,15 @@ def batch_norm(input,
attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True)

mean = helper.create_global_variable(
name=moving_mean_name,
dtype=input.dtype,
shape=param_shape,
persistable=True,
stop_gradient=True)
helper.set_variable_initializer(var=mean, initializer=Constant(0.0))

variance = helper.create_global_variable(
name=moving_variance_name,
dtype=input.dtype,
shape=param_shape,
persistable=True,
Expand Down

0 comments on commit cbc9a59

Please sign in to comment.