diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index b9d167ad2d86..491c140c5cb4 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -3514,6 +3514,20 @@ def multinomial(self, inputs, input_types):
         _, indices = _expr.TupleWrapper(output, 2)
         return indices
 
+    def weight_norm(self, inputs, input_types):
+        weight_v, weight_g = inputs[0], inputs[1]
+        dim = inputs[2]
+        dtype = input_types[0]
+        order = 2.0
+        reci_order = _expr.const(1.0 / order, dtype=dtype)
+        order = _expr.const(order)
+
+        norm_v = _op.power(
+            _op.reduce.sum(_op.power(_op.abs(weight_v), order), axis=dim, exclude=True, keepdims=True),
+            reci_order,
+        )
+        return weight_g * (weight_v / norm_v)
+
     # Operator mappings
     def create_convert_map(self):
         self.convert_map = {
@@ -3781,6 +3795,7 @@ def create_convert_map(self):
             "aten::__lshift__": self.make_elemwise("left_shift"),
             "aten::__rshift__": self.make_elemwise("right_shift"),
             "aten::multinomial": self.multinomial,
+            "aten::_weight_norm": self.weight_norm,
         }
 
     def update_convert_map(self, custom_map):
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 35242fbf7dde..0035d202ded2 100755
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -5038,6 +5038,30 @@ def _test_multinomial(num_samples):
         )
 
 
+def test_weight_norm():
+    """Test for aten::_weight_norm"""
+    in_channels = 32
+    out_channels = 64
+    input_data_conv = torch.rand((1, in_channels, 32, 32)).float()
+
+    conv_wn = torch.nn.utils.weight_norm(torch.nn.Conv2d(in_channels, out_channels, kernel_size=3))
+    verify_model(conv_wn.eval().float(), input_data_conv)
+
+    conv_wn_groups = torch.nn.utils.weight_norm(
+        torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, groups=2)
+    )
+    verify_model(conv_wn_groups.eval().float(), input_data_conv)
+
+    conv_wn = torch.nn.utils.weight_norm(
+        torch.nn.Conv2d(in_channels, out_channels, kernel_size=3), dim=1
+    )
+    verify_model(conv_wn.eval().float(), input_data_conv)
+
+    linear_wn = torch.nn.utils.weight_norm(torch.nn.Linear(in_channels, out_channels))
+    input_data_linear = torch.rand((128, in_channels)).float()
+    verify_model(linear_wn.eval().float(), input_data_linear)
+
+
 @tvm.testing.uses_gpu
 def test_baddbmm():
     def test_fn(alpha, beta):
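
Note (not part of the patch): the converter relies on aten::_weight_norm computing the L2 norm of weight_v over every dimension except `dim` and then rescaling by weight_g, which is what the relay sum with axis=dim and exclude=True reproduces. Below is a minimal standalone sketch of that semantics, assuming illustrative Conv2d-style shapes; it uses PyTorch's internal torch._weight_norm only to cross-check the formula and is not TVM code.

# Standalone check of aten::_weight_norm semantics (illustrative shapes).
import torch

v = torch.randn(64, 32, 3, 3)   # weight_v of a Conv2d(32, 64, kernel_size=3)
g = torch.randn(64, 1, 1, 1)    # weight_g for dim=0 (the default)
dim = 0

# L2 norm over every axis except `dim`, kept for broadcasting.
reduce_axes = [d for d in range(v.dim()) if d != dim]
norm_v = v.abs().pow(2.0).sum(dim=reduce_axes, keepdim=True).pow(0.5)
expected = g * (v / norm_v)

# The fused ATen op the new converter maps from.
actual = torch._weight_norm(v, g, dim)

torch.testing.assert_close(actual, expected)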