[Pytorch][Relay] aten::_weight_norm implementation (#13661)
Add implementation for PyTorch weight normalization
valmat07 authored Dec 27, 2022
1 parent e268014 commit 7a38477
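For context, a minimal sketch (not part of this commit) of what aten::_weight_norm computes: torch.nn.utils.weight_norm reparameterizes a weight as w = g * v / ||v||_2, with the L2 norm reduced over every dimension of v except `dim`. The tensor shapes below are illustrative.

import torch

# Illustrative shapes: a Conv2d-style weight (out_ch, in_ch, kH, kW).
v = torch.randn(64, 32, 3, 3)   # direction tensor
g = torch.randn(64, 1, 1, 1)    # per-output-channel magnitude
dim = 0

# L2 norm of v over every dimension except `dim`, kept broadcastable.
reduce_dims = [d for d in range(v.dim()) if d != dim]
norm_v = v.pow(2.0).sum(dim=reduce_dims, keepdim=True).pow(0.5)
manual = g * (v / norm_v)

# aten::_weight_norm is the op that weight_norm-wrapped modules dispatch to
# when traced; the two results should match.
print(torch.allclose(torch._weight_norm(v, g, dim), manual, atol=1e-6))

The Relay converter added below implements the same formula with _op.power and _op.reduce.sum, reducing over all axes except `dim`.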
Showing 2 changed files with 39 additions and 0 deletions.
15 changes: 15 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
@@ -3514,6 +3514,20 @@ def multinomial(self, inputs, input_types):
_, indices = _expr.TupleWrapper(output, 2)
return indices

    def weight_norm(self, inputs, input_types):
        # aten::_weight_norm(v, g, dim) computes g * v / ||v||_2, where the L2 norm
        # is reduced over every dimension of v except `dim`.
        weight_v, weight_g = inputs[0], inputs[1]
        dim = inputs[2]
        dtype = input_types[0]
        order = 2.0
        reci_order = _expr.const(1.0 / order, dtype=dtype)
        order = _expr.const(order)

        norm_v = _op.power(
            _op.reduce.sum(
                _op.power(_op.abs(weight_v), order), axis=dim, exclude=True, keepdims=True
            ),
            reci_order,
        )
        return weight_g * (weight_v / norm_v)

# Operator mappings
def create_convert_map(self):
self.convert_map = {
@@ -3781,6 +3795,7 @@ def create_convert_map(self):
"aten::__lshift__": self.make_elemwise("left_shift"),
"aten::__rshift__": self.make_elemwise("right_shift"),
"aten::multinomial": self.multinomial,
"aten::_weight_norm": self.weight_norm,
}

def update_convert_map(self, custom_map):
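For reference, a minimal sketch (not part of this commit, and assuming a traced module with an arbitrary input name "input0") of how the new converter can be exercised directly through relay.frontend.from_pytorch; the test added below performs the same comparison via verify_model.

import torch
from tvm import relay

# A weight-normalized Conv2d; tracing records the aten::_weight_norm op.
model = torch.nn.utils.weight_norm(torch.nn.Conv2d(32, 64, kernel_size=3)).eval()
data = torch.rand(1, 32, 32, 32)
scripted = torch.jit.trace(model, data)

# shape_list: (input_name, shape) pairs; the name "input0" is arbitrary.
mod, params = relay.frontend.from_pytorch(scripted, [("input0", list(data.shape))])
print(mod)  # with this commit, the import succeeds instead of reporting an unsupported op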
24 changes: 24 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
@@ -5038,6 +5038,30 @@ def _test_multinomial(num_samples):
)


def test_weight_norm():
"""Test for atten::_weight_norm"""
in_channels = 32
out_channels = 64
input_data_conv = torch.rand((1, in_channels, 32, 32)).float()

conv_wn = torch.nn.utils.weight_norm(torch.nn.Conv2d(in_channels, out_channels, kernel_size=3))
verify_model(conv_wn.eval().float(), input_data_conv)

conv_wn_groups = torch.nn.utils.weight_norm(
torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, groups=2)
)
verify_model(conv_wn_groups.eval().float(), input_data_conv)

conv_wn = torch.nn.utils.weight_norm(
torch.nn.Conv2d(in_channels, out_channels, kernel_size=3), dim=1
)
verify_model(conv_wn.eval().float(), input_data_conv)

linear_wn = torch.nn.utils.weight_norm(torch.nn.Linear(in_channels, out_channels))
input_data_linear = torch.rand((128, in_channels)).float()
verify_model(linear_wn.eval().float(), input_data_linear)


@tvm.testing.uses_gpu
def test_baddbmm():
def test_fn(alpha, beta):
