Skip to content

Commit

Permalink
[Feature] Python-based RNN Modules (#1720)
Browse files Browse the repository at this point in the history
Co-authored-by: vmoens <vincentmoens@gmail.com>
  • Loading branch information
albertbou92 and vmoens authored Dec 4, 2023
1 parent 3f2ecfc commit d432a9c
Show file tree
Hide file tree
Showing 6 changed files with 860 additions and 31 deletions.
4 changes: 4 additions & 0 deletions docs/source/reference/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,11 @@ algorithms, such as DQN, DDPG or Dreamer.
DistributionalDQNnet
DreamerActor
DuelingCnnDQNet
GRUCell
GRU
GRUModule
LSTMCell
LSTM
LSTMModule
ObsDecoder
ObsEncoder
Expand Down
207 changes: 207 additions & 0 deletions test/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
from torchrl.modules import (
CEMPlanner,
DTActor,
GRU,
GRUCell,
LSTM,
LSTMCell,
LSTMNet,
MultiAgentConvNet,
MultiAgentMLP,
Expand Down Expand Up @@ -1186,6 +1190,209 @@ def test_onlinedtactor(self, batch_dims, T=5):
assert (dtactor.log_std_max > sig.log()).all()


@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("bias", [True, False])
def test_python_lstm_cell(device, bias):
    """Check the Python ``LSTMCell`` against ``nn.LSTMCell`` step by step.

    The Python cell loads the weights of the ``torch.nn`` cell, then both are
    unrolled over the same input sequence; hidden and cell states must agree
    in shape and value after every step.
    """
    python_cell = LSTMCell(10, 20, device=device, bias=bias)
    reference_cell = nn.LSTMCell(10, 20, device=device, bias=bias)

    # Share one set of weights so that the outputs are comparable.
    python_cell.load_state_dict(reference_cell.state_dict())

    # Both cells must expose identically named and shaped parameters.
    param_pairs = zip(python_cell.named_parameters(), reference_cell.named_parameters())
    for (k1, v1), (k2, v2) in param_pairs:
        assert k1 == k2, f"Parameter names do not match: {k1} != {k2}"
        assert (
            v1.shape == v2.shape
        ), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}"

    # Unroll both cells over the sequence.
    sequence = torch.randn(2, 3, 10, device=device)
    hidden = torch.randn(3, 20, device=device)
    cell = torch.randn(3, 20, device=device)
    with torch.no_grad():
        for step_input in sequence:
            state = (hidden, cell)
            h1, c1 = python_cell(step_input, state)
            h2, c2 = reference_cell(step_input, state)

            # Shapes and values of the new states must agree at every step.
            assert h1.shape == h2.shape
            assert c1.shape == c2.shape
            torch.testing.assert_close(h1, h2)
            torch.testing.assert_close(c1, c2)

            # Feed the Python cell's states into the next step.
            hidden, cell = h1, c1


@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("bias", [True, False])
def test_python_gru_cell(device, bias):
    """Check the Python ``GRUCell`` against ``nn.GRUCell`` step by step.

    The ``torch.nn`` cell loads the weights of the Python cell, the parameters
    are verified to be identical, and both cells are then unrolled over the
    same input sequence; the hidden state must agree after every step.
    """
    gru_cell1 = GRUCell(10, 20, device=device, bias=bias)
    gru_cell2 = nn.GRUCell(10, 20, device=device, bias=bias)

    # Share one set of weights so that the outputs are comparable.
    gru_cell2.load_state_dict(gru_cell1.state_dict())

    # Make sure parameters match in name, value and shape
    for (k1, v1), (k2, v2) in zip(
        gru_cell1.named_parameters(), gru_cell2.named_parameters()
    ):
        assert k1 == k2, f"Parameter names do not match: {k1} != {k2}"
        assert (v1 == v2).all()
        assert (
            v1.shape == v2.shape
        ), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}"

    # Run loop (local is named `inputs` to avoid shadowing the builtin `input`)
    inputs = torch.randn(2, 3, 10, device=device)
    h0 = torch.zeros(3, 20, device=device)
    with torch.no_grad():
        for i in range(inputs.size()[0]):
            h1 = gru_cell1(inputs[i], h0)
            h2 = gru_cell2(inputs[i], h0)

            # Make sure the hidden states have the same shape and values
            assert h1.shape == h2.shape
            torch.testing.assert_close(h1, h2)

            # Feed the Python cell's state into the next step.
            h0 = h1


@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("batch_first", [True, False])
@pytest.mark.parametrize("dropout", [0.0, 0.5])
@pytest.mark.parametrize("num_layers", [1, 2])
def test_python_lstm(device, bias, dropout, batch_first, num_layers):
    """Check the Python ``LSTM`` against ``nn.LSTM`` on a full sequence.

    Both modules share the same weights and are run on the same input, first
    without and then with explicit initial ``(h0, c0)`` states. Output and
    final-state shapes must always match; values are compared only when
    ``dropout == 0.0``.

    NOTE(review): ``dropout`` is not forwarded to either constructor in this
    test, so it only gates the value comparisons below - confirm whether it
    was meant to be passed to the modules.
    """
    B = 5
    T = 3
    lstm1 = LSTM(
        input_size=10,
        hidden_size=20,
        num_layers=num_layers,
        device=device,
        bias=bias,
        batch_first=batch_first,
    )
    lstm2 = nn.LSTM(
        input_size=10,
        hidden_size=20,
        num_layers=num_layers,
        device=device,
        bias=bias,
        batch_first=batch_first,
    )

    # Share one set of weights so that the outputs are comparable.
    lstm2.load_state_dict(lstm1.state_dict())

    # Make sure parameters match
    for (k1, v1), (k2, v2) in zip(lstm1.named_parameters(), lstm2.named_parameters()):
        assert k1 == k2, f"Parameter names do not match: {k1} != {k2}"
        assert (
            v1.shape == v2.shape
        ), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}"

    # Local is named `inputs` to avoid shadowing the builtin `input`.
    if batch_first:
        inputs = torch.randn(B, T, 10, device=device)
    else:
        inputs = torch.randn(T, B, 10, device=device)

    # Initial states use the same batch size as the input.
    h0 = torch.randn(num_layers, B, 20, device=device)
    c0 = torch.randn(num_layers, B, 20, device=device)

    # Test without hidden states
    with torch.no_grad():
        output1, (h1, c1) = lstm1(inputs)
        output2, (h2, c2) = lstm2(inputs)

    assert h1.shape == h2.shape
    assert c1.shape == c2.shape
    assert output1.shape == output2.shape
    if dropout == 0.0:
        torch.testing.assert_close(output1, output2)
        torch.testing.assert_close(h1, h2)
        torch.testing.assert_close(c1, c2)

    # Test with hidden states. The reference call must go through ``lstm2``;
    # calling ``lstm1`` twice would only compare the module with itself.
    with torch.no_grad():
        output1, (h1, c1) = lstm1(inputs, (h0, c0))
        output2, (h2, c2) = lstm2(inputs, (h0, c0))

    assert h1.shape == h2.shape
    assert c1.shape == c2.shape
    assert output1.shape == output2.shape
    if dropout == 0.0:
        torch.testing.assert_close(output1, output2)
        torch.testing.assert_close(h1, h2)
        torch.testing.assert_close(c1, c2)


@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("batch_first", [True, False])
@pytest.mark.parametrize("dropout", [0.0, 0.5])
@pytest.mark.parametrize("num_layers", [1, 2])
def test_python_gru(device, bias, dropout, batch_first, num_layers):
    """Check the Python ``GRU`` against ``nn.GRU`` on a full sequence.

    Both modules are built from the same arguments and share weights. They
    are run without and then with an explicit initial hidden state; shapes
    must always match, and values are compared when ``dropout == 0.0``.
    """
    B = 5
    T = 3

    # Same constructor arguments for the Python and the reference module.
    shared_kwargs = {
        "input_size": 10,
        "hidden_size": 20,
        "num_layers": num_layers,
        "device": device,
        "bias": bias,
        "batch_first": batch_first,
    }
    gru1 = GRU(**shared_kwargs)
    gru2 = nn.GRU(**shared_kwargs)
    gru2.load_state_dict(gru1.state_dict())

    # Parameters must agree in name, value and shape.
    for (k1, v1), (k2, v2) in zip(gru1.named_parameters(), gru2.named_parameters()):
        assert k1 == k2, f"Parameter names do not match: {k1} != {k2}"
        torch.testing.assert_close(v1, v2)
        assert (
            v1.shape == v2.shape
        ), f"Parameter shapes do not match: {k1} shape {v1.shape} != {k2} shape {v2.shape}"

    # Put the batch dimension first or second depending on ``batch_first``.
    leading_dims = (B, T) if batch_first else (T, B)
    sequence = torch.randn(*leading_dims, 10, device=device)
    initial_hidden = torch.randn(num_layers, B, 20, device=device)

    compare_values = dropout == 0.0

    # Test without hidden states
    with torch.no_grad():
        output1, h1 = gru1(sequence)
        output2, h2 = gru2(sequence)

    assert h1.shape == h2.shape
    assert output1.shape == output2.shape
    if compare_values:
        torch.testing.assert_close(output1, output2)
        torch.testing.assert_close(h1, h2)

    # Test with hidden states
    with torch.no_grad():
        output1, h1 = gru1(sequence, initial_hidden)
        output2, h2 = gru2(sequence, initial_hidden)

    assert h1.shape == h2.shape
    assert output1.shape == output2.shape
    if compare_values:
        torch.testing.assert_close(output1, output2)
        torch.testing.assert_close(h1, h2)


if __name__ == "__main__":
    # Forward any command-line options argparse does not recognise to pytest.
    cli_parser = argparse.ArgumentParser()
    args, unknown = cli_parser.parse_known_args()
    pytest_argv = [__file__, "--capture", "no", "--exitfirst", *unknown]
    pytest.main(pytest_argv)
33 changes: 25 additions & 8 deletions test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1672,7 +1672,8 @@ def test_noncontiguous(self):
lstm_module(padded)

@pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]])
def test_singel_step(self, shape):
@pytest.mark.parametrize("python_based", [True, False])
def test_single_step(self, shape, python_based):
td = TensorDict(
{
"observation": torch.zeros(*shape, 3),
Expand All @@ -1686,6 +1687,7 @@ def test_singel_step(self, shape):
batch_first=True,
in_keys=["observation", "hidden0", "hidden1"],
out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
python_based=python_based,
)
td = lstm_module(td)
td_next = step_mdp(td, keep_other=True)
Expand All @@ -1697,7 +1699,8 @@ def test_singel_step(self, shape):

@pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]])
@pytest.mark.parametrize("t", [1, 10])
def test_single_step_vs_multi(self, shape, t):
@pytest.mark.parametrize("python_based", [True, False])
def test_single_step_vs_multi(self, shape, t, python_based):
td = TensorDict(
{
"observation": torch.arange(t, dtype=torch.float32)
Expand All @@ -1713,6 +1716,7 @@ def test_single_step_vs_multi(self, shape, t):
batch_first=True,
in_keys=["observation", "hidden0", "hidden1"],
out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
python_based=python_based,
)
lstm_module_ms = lstm_module_ss.set_recurrent_mode()
lstm_module_ms(td)
Expand All @@ -1732,7 +1736,8 @@ def test_single_step_vs_multi(self, shape, t):
)

@pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]])
def test_multi_consecutive(self, shape):
@pytest.mark.parametrize("python_based", [False, True])
def test_multi_consecutive(self, shape, python_based):
t = 20
td = TensorDict(
{
Expand All @@ -1754,6 +1759,7 @@ def test_multi_consecutive(self, shape):
batch_first=True,
in_keys=["observation", "hidden0", "hidden1"],
out_keys=["intermediate", ("next", "hidden0"), ("next", "hidden1")],
python_based=python_based,
)
lstm_module_ms = lstm_module_ss.set_recurrent_mode()
lstm_module_ms(td)
Expand All @@ -1769,11 +1775,13 @@ def test_multi_consecutive(self, shape):
lstm_module_ss(td_ss)
td_ss = step_mdp(td_ss, keep_other=True)
td_ss["observation"][:] = _t + 1
# TODO: investigate why this assertion was reported to fail when python_based is True.
torch.testing.assert_close(
td_ss["intermediate"], td["intermediate"][..., -1, :]
)

def test_lstm_parallel_env(self):
@pytest.mark.parametrize("python_based", [True, False])
def test_lstm_parallel_env(self, python_based):
from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv

device = "cuda" if torch.cuda.device_count() else "cpu"
Expand All @@ -1785,6 +1793,7 @@ def test_lstm_parallel_env(self):
in_key="observation",
out_key="features",
device=device,
python_based=python_based,
)

def create_transformed_env():
Expand Down Expand Up @@ -1938,7 +1947,8 @@ def test_noncontiguous(self):
gru_module(padded)

@pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]])
def test_singel_step(self, shape):
@pytest.mark.parametrize("python_based", [True, False])
def test_single_step(self, shape, python_based):
td = TensorDict(
{
"observation": torch.zeros(*shape, 3),
Expand All @@ -1952,6 +1962,7 @@ def test_singel_step(self, shape):
batch_first=True,
in_keys=["observation", "hidden"],
out_keys=["intermediate", ("next", "hidden")],
python_based=python_based,
)
td = gru_module(td)
td_next = step_mdp(td, keep_other=True)
Expand All @@ -1961,7 +1972,8 @@ def test_singel_step(self, shape):

@pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]])
@pytest.mark.parametrize("t", [1, 10])
def test_single_step_vs_multi(self, shape, t):
@pytest.mark.parametrize("python_based", [True, False])
def test_single_step_vs_multi(self, shape, t, python_based):
td = TensorDict(
{
"observation": torch.arange(t, dtype=torch.float32)
Expand All @@ -1977,6 +1989,7 @@ def test_single_step_vs_multi(self, shape, t):
batch_first=True,
in_keys=["observation", "hidden"],
out_keys=["intermediate", ("next", "hidden")],
python_based=python_based,
)
gru_module_ms = gru_module_ss.set_recurrent_mode()
gru_module_ms(td)
Expand All @@ -1994,7 +2007,8 @@ def test_single_step_vs_multi(self, shape, t):
torch.testing.assert_close(td_ss["hidden"], td["next", "hidden"][..., -1, :, :])

@pytest.mark.parametrize("shape", [[], [2], [2, 3], [2, 3, 4]])
def test_multi_consecutive(self, shape):
@pytest.mark.parametrize("python_based", [True, False])
def test_multi_consecutive(self, shape, python_based):
t = 20
td = TensorDict(
{
Expand All @@ -2016,6 +2030,7 @@ def test_multi_consecutive(self, shape):
batch_first=True,
in_keys=["observation", "hidden"],
out_keys=["intermediate", ("next", "hidden")],
python_based=python_based,
)
gru_module_ms = gru_module_ss.set_recurrent_mode()
gru_module_ms(td)
Expand All @@ -2035,7 +2050,8 @@ def test_multi_consecutive(self, shape):
td_ss["intermediate"], td["intermediate"][..., -1, :]
)

def test_gru_parallel_env(self):
@pytest.mark.parametrize("python_based", [True, False])
def test_gru_parallel_env(self, python_based):
from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv

device = "cuda" if torch.cuda.device_count() else "cpu"
Expand All @@ -2047,6 +2063,7 @@ def test_gru_parallel_env(self):
in_key="observation",
out_key="features",
device=device,
python_based=python_based,
)

def create_transformed_env():
Expand Down
4 changes: 4 additions & 0 deletions torchrl/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,12 @@
DistributionalQValueModule,
EGreedyModule,
EGreedyWrapper,
GRU,
GRUCell,
GRUModule,
LMHeadActorValueOperator,
LSTM,
LSTMCell,
LSTMModule,
OrnsteinUhlenbeckProcessWrapper,
ProbabilisticActor,
Expand Down
Loading

0 comments on commit d432a9c

Please sign in to comment.