Skip to content

Commit

Permalink
Update test/test_modules.py
Browse files Browse the repository at this point in the history
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
  • Loading branch information
albertbou92 and vmoens authored Dec 1, 2023
1 parent 22b0999 commit 18c900d
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions test/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1218,12 +1218,12 @@ def test_python_lstm_cell(device, bias):
h2, c2 = lstm_cell2(input[i], (h0, c0))

# Make sure the final hidden states have the same shape
assert h1.shape == h2.shape
assert c1.shape == c2.shape
torch.testing.assert_close(h1, h2)
torch.testing.assert_close(c1, c2)
h0 = h1
c0 = c1
assert h1.shape == h2.shape
assert c1.shape == c2.shape
torch.testing.assert_close(h1, h2)
torch.testing.assert_close(c1, c2)
h0 = h1
c0 = c1


@pytest.mark.parametrize("device", get_default_devices())
Expand Down

0 comments on commit 18c900d

Please sign in to comment.