Skip to content

Commit

Permalink
fix guard
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst committed Jul 7, 2024
1 parent 24dc649 commit b0fe51d
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion test/prototype/test_low_bit_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,12 @@ def test_optim_4bit_correctness(self, optim_name):
torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)

@pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3")
@pytest.mark.skipif(torch.cuda.get_device_capability() < (8, 9), reason="FP8 requires compute capability >= 8.9")
@parametrize("optim_name", ["AdamFp8", "AdamWFp8"])
@parametrize("device", _DEVICES)
def test_optim_fp8_smoke(self, optim_name, device):
if device == "cuda" and torch.cuda.get_device_capability() < (8, 9):
pytest.skip("FP8 requires compute capability >= 8.9")

model = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
optim = getattr(low_bit_optim, optim_name)(model.parameters())

Expand Down

0 comments on commit b0fe51d

Please sign in to comment.