Skip to content

Commit

Permalink
Update test/test_modules.py
Browse files Browse the repository at this point in the history
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
  • Loading branch information
albertbou92 and vmoens authored Dec 1, 2023
1 parent 9e84a04 commit 5331647
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions test/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1209,9 +1209,9 @@ def test_python_lstm_cell(device, bias):
), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}"

# Run loop
input = torch.randn(2, 3, 10).to(device)
h0 = torch.randn(3, 20).to(device)
c0 = torch.randn(3, 20).to(device)
input = torch.randn(2, 3, 10, device=device)
h0 = torch.randn(3, 20, device=device)
c0 = torch.randn(3, 20, device=device)
with torch.no_grad():
for i in range(input.size()[0]):
h1, c1 = lstm_cell1(input[i], (h0, c0))
Expand Down

0 comments on commit 5331647

Please sign in to comment.