Skip to content

Commit

Permalink
add cell tests
Browse files Browse the repository at this point in the history
  • Loading branch information
albertbou92 committed Nov 29, 2023
1 parent d144f14 commit a876428
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 1 deletion.
63 changes: 63 additions & 0 deletions test/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
MultiAgentConvNet,
MultiAgentMLP,
OnlineDTActor,
PythonGRUCell,
PythonLSTMCell,
QMixer,
SafeModule,
TanhModule,
Expand Down Expand Up @@ -1186,6 +1188,67 @@ def test_onlinedtactor(self, batch_dims, T=5):
assert (dtactor.log_std_max > sig.log()).all()


@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("bias", [True, False])
def test_python_lstm_cell(device, bias):

lstm_cell1 = PythonLSTMCell(10, 20, device=device, bias=bias)
lstm_cell2 = nn.LSTMCell(10, 20, device=device, bias=bias)

# Make sure parameters match
for (k1, v1), (k2, v2) in zip(
lstm_cell1.named_parameters(), lstm_cell2.named_parameters()
):
assert k1 == k2, f"Parameter names do not match: {k1} != {k2}"
assert (
v1.shape == v2.shape
), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}"

# Run loop
input = torch.randn(2, 3, 10).to(device)
hx1 = torch.randn(3, 20).to(device)
cx1 = torch.randn(3, 20).to(device)
hx2 = torch.randn(3, 20).to(device)
cx2 = torch.randn(3, 20).to(device)
with torch.no_grad():
for i in range(input.size()[0]):
hx1, cx1 = lstm_cell1(input[i], (hx1, cx1))
hx2, cx2 = lstm_cell2(input[i], (hx2, cx2))

# Make sure the final hidden states have the same shape
assert hx1.shape == hx2.shape
assert cx1.shape == cx2.shape


@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("bias", [True, False])
def test_python_gru_cell(device, bias):

gru_cell1 = PythonGRUCell(10, 20, device=device, bias=bias)
gru_cell2 = nn.GRUCell(10, 20, device=device, bias=bias)

# Make sure parameters match
for (k1, v1), (k2, v2) in zip(
gru_cell1.named_parameters(), gru_cell2.named_parameters()
):
assert k1 == k2, f"Parameter names do not match: {k1} != {k2}"
assert (
v1.shape == v2.shape
), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}"

# Run loop
input = torch.randn(2, 3, 10).to(device)
hx1 = torch.randn(3, 20).to(device)
hx2 = torch.randn(3, 20).to(device)
with torch.no_grad():
for i in range(input.size()[0]):
hx1 = gru_cell1(input[i], hx1)
hx2 = gru_cell2(input[i], hx2)

# Make sure the final hidden states have the same shape
assert hx1.shape == hx2.shape


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
2 changes: 2 additions & 0 deletions torchrl/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@
LSTMModule,
OrnsteinUhlenbeckProcessWrapper,
ProbabilisticActor,
PythonGRUCell,
PythonLSTMCell,
QValueActor,
QValueHook,
QValueModule,
Expand Down
2 changes: 1 addition & 1 deletion torchrl/modules/tensordict_module/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,6 @@
SafeProbabilisticModule,
SafeProbabilisticTensorDictSequential,
)
from .rnn import GRUModule, LSTMModule
from .rnn import GRUModule, LSTMModule, PythonGRUCell, PythonLSTMCell
from .sequence import SafeSequential
from .world_models import WorldModelWrapper

0 comments on commit a876428

Please sign in to comment.