Skip to content

Commit

Permalink
chore: revert layer_norm test
Browse files Browse the repository at this point in the history
  • Loading branch information
peri044 committed Apr 16, 2024
1 parent b0e92d8 commit d78a846
Showing 1 changed file with 49 additions and 0 deletions.
49 changes: 49 additions & 0 deletions tests/py/dynamo/conversion/test_layer_norm_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,31 @@ def forward(self, x):
inputs,
)

def test_layernorm_with_dynamic_shape(self):
    """Convert ``aten.layer_norm`` with a dynamic batch dimension.

    Builds a module that normalizes over the trailing (3, 224, 224) dims
    and runs it through the dynamic-shape conversion harness with a batch
    range of 1..2.
    """

    class LayerNorm(torch.nn.Module):
        def forward(self, x):
            # normalized_shape is declared as SymInt[] (an int list) in the
            # aten schema — pass a plain Python list, not a Tensor.
            return torch.ops.aten.layer_norm.default(
                x,
                [3, 224, 224],
                torch.ones((3, 224, 224)),   # weight (gamma)
                torch.zeros((3, 224, 224)),  # bias (beta)
                1e-05,                       # eps
                True,                        # cudnn_enable
            )

    input_specs = [
        Input(
            shape=(-1, 3, 224, 224),
            dtype=torch.float32,
            # (min, opt, max) shapes: batch ranges over 1..2.
            shape_ranges=[((1, 3, 224, 224), (1, 3, 224, 224), (2, 3, 224, 224))],
        ),
    ]

    self.run_test_with_dynamic_shape(
        LayerNorm(),
        input_specs,
    )


class TestNativeLayerNormConverter(DispatchTestCase):
def test_layer_norm(self):
Expand All @@ -43,6 +68,30 @@ def forward(self, x):
inputs,
)

def test_layernorm_with_dynamic_shape(self):
    """Convert ``aten.native_layer_norm`` with a dynamic batch dimension.

    The op returns (output, mean, rstd); only the normalized output
    (index 0) is exercised here. Uses the same (3, 224, 224) normalized
    shape and 1..2 batch range as the ``layer_norm`` variant above.
    """

    class LayerNorm(torch.nn.Module):
        def forward(self, x):
            # normalized_shape is declared as SymInt[] (an int list) in the
            # aten schema — pass a plain Python list, not a Tensor.
            return torch.ops.aten.native_layer_norm.default(
                x,
                [3, 224, 224],
                torch.ones((3, 224, 224)),   # weight (gamma)
                torch.zeros((3, 224, 224)),  # bias (beta)
                1e-05,                       # eps
            )[0]

    input_specs = [
        Input(
            shape=(-1, 3, 224, 224),
            dtype=torch.float32,
            # (min, opt, max) shapes: batch ranges over 1..2.
            shape_ranges=[((1, 3, 224, 224), (1, 3, 224, 224), (2, 3, 224, 224))],
        ),
    ]

    self.run_test_with_dynamic_shape(
        LayerNorm(),
        input_specs,
    )


# Allow running this test module directly; delegates to torch's test runner.
if __name__ == "__main__":
run_tests()

0 comments on commit d78a846

Please sign in to comment.