Skip to content

Commit

Permalink
more test for aten.linear
Browse files Browse the repository at this point in the history
  • Loading branch information
yaochengji committed Feb 3, 2025
1 parent 1327602 commit 10c124d
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions torchax/test/test_core_aten_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4490,15 +4490,25 @@ def test_aten_einsum(self):
rtol=1e-2, check_dtype=True)

def test_aten_linear(self):
# with bias
args = (
torch.randn((2, 2), dtype=torch.float16),
torch.randn((2, 2), dtype=torch.float16),
torch.randn((2, 4), dtype=torch.float16),
torch.randn((2, 4), dtype=torch.float16),
torch.randn((2, ), dtype=torch.float16),
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.linear, args, kwargs, atol=1e-2,
rtol=1e-2, check_dtype=True)

# without bias
args = (
torch.randn((2, 4), dtype=torch.float16),
torch.randn((2, 4), dtype=torch.float16),
)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.linear, args, kwargs, atol=1e-2,
rtol=1e-2, check_dtype=True)


if __name__ == "__main__":
test_base.main()

0 comments on commit 10c124d

Please sign in to comment.