
Commit 5b8a24d

black formatted test_upsample.py
HolyWu committed Apr 30, 2024
1 parent f135a5d commit 5b8a24d
Showing 1 changed file with 33 additions and 11 deletions.
44 changes: 33 additions & 11 deletions tests/py/dynamo/conversion/test_upsample.py
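This diff is a pure formatting pass over lines that exceed black's default 88-character limit. A minimal sketch of reproducing the wrapping with black's Python API (running the black CLI on the file, as the commit title suggests, is the usual workflow; that and the snippet below are assumptions, not part of this commit):

import black

# One of the over-long signatures touched by this commit (about 90 characters,
# longer than black's default 88-character line length).
src = (
    "def test_upsample_linear1d(self, input_shape, output_size, align_corners, scale_factors):\n"
    "    pass\n"
)

# With the default mode, black wraps the signature across three lines,
# matching the hunks shown below.
formatted = black.format_str(src, mode=black.Mode())
print(formatted)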
@@ -18,13 +18,17 @@ class TestUpsampleConverter(DispatchTestCase):
             ((2,), None, False, (1.5,)),
         ]
     )
-    def test_upsample_linear1d(self, input_shape, output_size, align_corners, scale_factors):
+    def test_upsample_linear1d(
+        self, input_shape, output_size, align_corners, scale_factors
+    ):
         class Upsample(torch.nn.Module):
             def __init__(self):
                 super().__init__()

             def forward(self, input):
-                return torch.ops.aten.upsample_linear1d.vec(input, output_size, align_corners, scale_factors)
+                return torch.ops.aten.upsample_linear1d.vec(
+                    input, output_size, align_corners, scale_factors
+                )

         input = [torch.randn([1, 1] + list(input_shape))]
         self.run_test(Upsample(), input)
@@ -41,13 +45,17 @@ def forward(self, input):
             ((2, 2), None, False, (1.5, 1.5)),
         ]
     )
-    def test_upsample_bilinear2d(self, input_shape, output_size, align_corners, scale_factors):
+    def test_upsample_bilinear2d(
+        self, input_shape, output_size, align_corners, scale_factors
+    ):
         class Upsample(torch.nn.Module):
             def __init__(self):
                 super().__init__()

             def forward(self, input):
-                return torch.ops.aten.upsample_bilinear2d.vec(input, output_size, align_corners, scale_factors)
+                return torch.ops.aten.upsample_bilinear2d.vec(
+                    input, output_size, align_corners, scale_factors
+                )

         input = [torch.randn([1, 1] + list(input_shape))]
         self.run_test(Upsample(), input)
@@ -64,13 +72,17 @@ def forward(self, input):
             ((2, 2, 2), None, False, (1.5, 1.5, 1.5)),
         ]
     )
-    def test_upsample_trilinear3d(self, input_shape, output_size, align_corners, scale_factors):
+    def test_upsample_trilinear3d(
+        self, input_shape, output_size, align_corners, scale_factors
+    ):
         class Upsample(torch.nn.Module):
             def __init__(self):
                 super().__init__()

             def forward(self, input):
-                return torch.ops.aten.upsample_trilinear3d.vec(input, output_size, align_corners, scale_factors)
+                return torch.ops.aten.upsample_trilinear3d.vec(
+                    input, output_size, align_corners, scale_factors
+                )

         input = [torch.randn([1, 1] + list(input_shape))]
         self.run_test(Upsample(), input)
@@ -87,13 +99,17 @@ def forward(self, input):
             ((2, 2), None, False, (1.5, 1.5)),
         ]
     )
-    def test_upsample_bicubic2d(self, input_shape, output_size, align_corners, scale_factors):
+    def test_upsample_bicubic2d(
+        self, input_shape, output_size, align_corners, scale_factors
+    ):
         class Upsample(torch.nn.Module):
             def __init__(self):
                 super().__init__()

             def forward(self, input):
-                return torch.ops.aten.upsample_bicubic2d.vec(input, output_size, align_corners, scale_factors)
+                return torch.ops.aten.upsample_bicubic2d.vec(
+                    input, output_size, align_corners, scale_factors
+                )

         input = [torch.randn([1, 1] + list(input_shape))]
         self.run_test(Upsample(), input)
@@ -112,7 +128,9 @@ def __init__(self):
                 super().__init__()

             def forward(self, input):
-                return torch.ops.aten.upsample_nearest1d.vec(input, output_size, scale_factors)
+                return torch.ops.aten.upsample_nearest1d.vec(
+                    input, output_size, scale_factors
+                )

         input = [torch.randn([1, 1] + list(input_shape))]
         self.run_test(Upsample(), input)
@@ -131,7 +149,9 @@ def __init__(self):
                 super().__init__()

             def forward(self, input):
-                return torch.ops.aten.upsample_nearest2d.vec(input, output_size, scale_factors)
+                return torch.ops.aten.upsample_nearest2d.vec(
+                    input, output_size, scale_factors
+                )

         input = [torch.randn([1, 1] + list(input_shape))]
         self.run_test(Upsample(), input)
@@ -150,7 +170,9 @@ def __init__(self):
                 super().__init__()

             def forward(self, input):
-                return torch.ops.aten.upsample_nearest3d.vec(input, output_size, scale_factors)
+                return torch.ops.aten.upsample_nearest3d.vec(
+                    input, output_size, scale_factors
+                )

         input = [torch.randn([1, 1] + list(input_shape))]
         self.run_test(Upsample(), input)
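For readability, here is a sketch of how one of the reformatted tests reads in full after this change. The imports and the DispatchTestCase harness are assumptions based on the surrounding test suite, which this diff does not show; only the test body itself comes from the hunks above:

import torch
from parameterized import parameterized

# Assumed import; the actual harness module path in the test suite may differ.
from .harness import DispatchTestCase


class TestUpsampleConverter(DispatchTestCase):
    @parameterized.expand(
        [
            ((2, 2), None, False, (1.5, 1.5)),
        ]
    )
    def test_upsample_bilinear2d(
        self, input_shape, output_size, align_corners, scale_factors
    ):
        class Upsample(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, input):
                return torch.ops.aten.upsample_bilinear2d.vec(
                    input, output_size, align_corners, scale_factors
                )

        input = [torch.randn([1, 1] + list(input_shape))]
        self.run_test(Upsample(), input)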
