Fixed conversion of Batch Norm when training=true (#1249)
Signed-off-by: Tom Wildenhain <tomwi@microsoft.com>
TomWildenhain-Microsoft authored Dec 30, 2020
1 parent 1c9c02d commit cf8c953
Showing 2 changed files with 50 additions and 4 deletions.
21 changes: 21 additions & 0 deletions tests/test_backend.py
@@ -2221,6 +2221,27 @@ def func(x):
            return tf.identity(y, name=_TFOUTPUT)
        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04)

    @check_opset_min_version(7, "batchnorm")
    def test_fused_batchnorm_training(self):
        x_shape = [1, 28, 28, 2]
        x_dtype = np.float32
        scale_dtype = np.float32
        scale_shape = [2]
        # only NHWC is supported on CPU for tensorflow
        data_format = "NHWC"
        x_val = np.random.random_sample(x_shape).astype(x_dtype)
        scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
        offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)
        def func(x):
            scale = tf.constant(scale_val, name='scale')
            offset = tf.constant(offset_val, name='offset')
            epsilon = 0.001
            y, _, _ = fused_batch_norm(
                x, scale, offset, mean=None, variance=None,
                epsilon=epsilon, data_format=data_format, is_training=True)
            return tf.identity(y, name=_TFOUTPUT)
        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04)

    @check_opset_min_version(7, "batchnorm")
    @check_tf_min_version("1.13")
    def test_batchnorm(self):
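For reference, here is a minimal NumPy sketch of the behavior the new test exercises: in training mode, fused batch norm normalizes each channel with statistics computed from the batch itself (over the N, H, W axes for NHWC data) rather than with supplied running mean/variance estimates. The function and variable names below are illustrative and not part of the test suite:

    import numpy as np

    def fused_batchnorm_training_reference(x, scale, offset, epsilon=0.001):
        # Training-mode batch norm for NHWC input: normalize with batch statistics per channel.
        mean = x.mean(axis=(0, 1, 2))      # per-channel mean over N, H, W
        var = x.var(axis=(0, 1, 2))        # population variance (divide by count, not count - 1)
        return scale * (x - mean) / np.sqrt(var + epsilon) + offset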
33 changes: 29 additions & 4 deletions tf2onnx/onnx_opset/nn.py
@@ -784,24 +784,49 @@ def version_6(cls, ctx, node, **kwargs):

        conv_convert_inputs(ctx, node, with_kernel=False)

        inp_shape = ctx.get_shape(node.input[0])
        inp_rank = len(inp_shape) if inp_shape is not None else None
        scale_shape = ctx.get_shape(node.input[1])
        mean_shape = ctx.get_shape(node.input[3])
        var_shape = ctx.get_shape(node.input[4])
        val_type = utils.map_onnx_to_numpy_type(ctx.get_dtype(node.input[1]))

        if node.get_attr_value('is_training', 1) == 1:
        is_training = node.get_attr_value('is_training', True)

        if is_training and node.get_attr_value('exponential_avg_factor', 1.0) == 1.0:
            # Sometimes TF uses a BatchNorm op with training = True and exponential_avg_factor = 1.0
            # to perform layer mean/variance normalization. In such cases, the mean/var are computed from the input.
            # TF allows mean/variance to be excluded only if is_training and exponential_avg_factor == 1.0
            utils.make_sure(inp_rank is not None, "Cannot convert node %s of type %s with input of unknown rank.",
                            node.name, tf_type)
            dims = [0] + list(range(2, inp_rank))
            avg = ctx.make_node("ReduceMean", [node.input[0]], attr={'axes': dims, 'keepdims': True}).output[0]
            avg_squeezed = GraphBuilder(ctx).make_squeeze({"data": avg, "axes": dims})
            sub = ctx.make_node("Sub", [node.input[0], avg]).output[0]
            var_squeezed = ctx.make_node("ReduceSumSquare", [sub], attr={'axes': dims, 'keepdims': False}).output[0]

            inp_shape = ctx.make_node("Shape", [node.input[0]]).output[0]
            dims_const = ctx.make_const(utils.make_name("axes_const"), np.array(dims, dtype=np.int64)).output[0]
            reduce_dims = ctx.make_node("Gather", [inp_shape, dims_const]).output[0]
            dims_product = ctx.make_node("ReduceProd", [reduce_dims], attr={'axes': [0], 'keepdims': False})
            cnt_float = ctx.make_node("Cast", [dims_product.output[0]], attr={'to': ctx.get_dtype(node.input[0])})

            pop_var_squeezed = ctx.make_node("Div", [var_squeezed, cnt_float.output[0]]).output[0]
            ctx.replace_inputs(node, node.input[:3] + [avg_squeezed, pop_var_squeezed])
        else:
            logger.warning("Node %s of type %s has is_training set to true, which is not supported. "
                           "Please re-save the model with training set to false.",
                           node.name, tf_type)
            # As long as the mean/variance estimates are provided, we should be OK
            is_training = False

        if mean_shape != scale_shape and all(d >= 0 for d in scale_shape):
        if not is_training and mean_shape != scale_shape and all(d >= 0 for d in scale_shape):
            new_mean_value = np.array(np.resize(node.inputs[3].get_tensor_value(as_list=False), scale_shape),
                                      dtype=val_type)
            new_mean_node_name = utils.make_name(node.name)
            ctx.make_const(new_mean_node_name, new_mean_value)
            ctx.replace_input(node, node.input[3], new_mean_node_name, 3)

        if var_shape != scale_shape and all(d >= 0 for d in scale_shape):
        if not is_training and var_shape != scale_shape and all(d >= 0 for d in scale_shape):
            new_var_value = np.array(np.resize(node.inputs[4].get_tensor_value(as_list=False), scale_shape),
                                     dtype=val_type)
            new_val_node_name = utils.make_name(node.name)
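The added is_training branch above builds this statistics computation directly in the ONNX graph (ReduceMean -> Sub -> ReduceSumSquare -> Div by the element count). Below is a rough NumPy equivalent of the subgraph it emits, assuming the input has already been transposed to NCHW by conv_convert_inputs and has rank 4; the helper name is illustrative:

    import numpy as np

    def batch_stats_like_converted_graph(x_nchw):
        # dims = [0] + list(range(2, rank)): every axis except the channel axis (1)
        dims = (0, 2, 3)
        avg = x_nchw.mean(axis=dims, keepdims=True)         # ReduceMean, keepdims=True
        sub = x_nchw - avg                                  # Sub
        sum_sq = np.sum(sub * sub, axis=dims)               # ReduceSumSquare, keepdims=False
        count = np.prod([x_nchw.shape[d] for d in dims])    # Shape -> Gather -> ReduceProd -> Cast
        pop_var = sum_sq / count                            # Div: population variance
        mean = np.squeeze(avg, axis=dims)                   # squeeze to shape (C,), like avg_squeezed
        return mean, pop_var                                # replace the node's mean/variance inputs

Dividing by the full element count (rather than count - 1) yields the population variance, which is what the resulting BatchNormalization node uses for normalization.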
