Skip to content

Commit

Permalink
[Bugfix] Fix batch_norm
Browse files Browse the repository at this point in the history
  • Loading branch information
liquanfeng committed May 15, 2023
1 parent 5566c3e commit 74c82ce
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1353,18 +1353,16 @@ def batch_norm(self, inputs, input_types):

channels = self.infer_shape(data)

if isinstance(inputs[1], _expr.Expr) and isinstance(inputs[2], _expr.Expr):
scale = center = True
weight = inputs[1]
beta = inputs[2]
gamma = weight
scale = isinstance(inputs[1], _expr.Expr)
if scale:
gamma = inputs[1]
else:
scale = center = False

if not scale:
gamma = _create_typed_const(np.ones([int(channels[1])]), data_type)

if not center:
center = isinstance(inputs[2], _expr.Expr)
if center:
beta = inputs[2]
else:
beta = _create_typed_const(np.zeros([int(channels[1])]), data_type)

moving_mean = inputs[3]
Expand Down

0 comments on commit 74c82ce

Please sign in to comment.