Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] RB MultiStep transform #2008

Merged
merged 10 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/source/reference/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -823,3 +823,11 @@ Utils
consolidate_spec
check_no_exclusive_keys
contains_lazy_spec

.. currentmodule:: torchrl.envs.transforms.rb_transforms

.. autosummary::
:toctree: generated/
:template: rl_template.rst

MultiStepTransform
44 changes: 22 additions & 22 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ def test_dqn_state_dict(self, delay_value, device, action_spec_type):
loss_fn2 = DQNLoss(actor, loss_function="l2", delay_value=delay_value)
loss_fn2.load_state_dict(sd)

@pytest.mark.parametrize("n", range(4))
@pytest.mark.parametrize("n", range(1, 4))
@pytest.mark.parametrize("delay_value", (False, True))
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical"))
Expand Down Expand Up @@ -579,7 +579,7 @@ def test_dqn_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9):

with torch.no_grad():
loss = loss_fn(td)
if n == 0:
if n == 1:
assert_allclose_td(td, ms_td.select(*td.keys(True, True)))
_loss = sum(
[item for name, item in loss.items() if name.startswith("loss")]
Expand Down Expand Up @@ -1125,7 +1125,7 @@ def test_qmixer_state_dict(self, delay_value, device, action_spec_type):
loss_fn2 = QMixerLoss(actor, mixer, loss_function="l2", delay_value=delay_value)
loss_fn2.load_state_dict(sd)

@pytest.mark.parametrize("n", range(4))
@pytest.mark.parametrize("n", range(1, 4))
@pytest.mark.parametrize("delay_value", (False, True))
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical"))
Expand Down Expand Up @@ -1158,7 +1158,7 @@ def test_qmix_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9)

with torch.no_grad():
loss = loss_fn(td)
if n == 0:
if n == 1:
assert_allclose_td(td, ms_td.select(*td.keys(True, True)))
_loss = sum(
[item for name, item in loss.items() if name.startswith("loss")]
Expand Down Expand Up @@ -1801,7 +1801,7 @@ def test_ddpg_separate_losses(
raise NotImplementedError(k)
loss_fn.zero_grad()

@pytest.mark.parametrize("n", list(range(4)))
@pytest.mark.parametrize("n", range(1, 4))
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("delay_actor,delay_value", [(False, False), (True, True)])
def test_ddpg_batcher(self, n, delay_actor, delay_value, device, gamma=0.9):
Expand Down Expand Up @@ -1832,7 +1832,7 @@ def test_ddpg_batcher(self, n, delay_actor, delay_value, device, gamma=0.9):

with torch.no_grad():
loss = loss_fn(td)
if n == 0:
if n == 1:
assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
_loss = sum(
[item for name, item in loss.items() if name.startswith("loss_")]
Expand Down Expand Up @@ -2433,7 +2433,7 @@ def test_td3_separate_losses(
loss_fn.zero_grad()

@pytest.mark.skipif(not _has_functorch, reason="functorch not installed")
@pytest.mark.parametrize("n", list(range(4)))
@pytest.mark.parametrize("n", range(1, 4))
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("delay_actor,delay_qvalue", [(False, False), (True, True)])
@pytest.mark.parametrize("policy_noise", [0.1, 1.0])
Expand Down Expand Up @@ -2479,7 +2479,7 @@ def test_td3_batcher(
np.random.seed(0)
loss = loss_fn(td)

if n == 0:
if n == 1:
assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
_loss = sum(
[item for name, item in loss.items() if name.startswith("loss_")]
Expand Down Expand Up @@ -3228,7 +3228,7 @@ def test_sac_separate_losses(
raise NotImplementedError(k)
loss_fn.zero_grad()

@pytest.mark.parametrize("n", list(range(4)))
@pytest.mark.parametrize("n", range(1, 4))
@pytest.mark.parametrize("delay_value", (True, False))
@pytest.mark.parametrize("delay_actor", (True, False))
@pytest.mark.parametrize("delay_qvalue", (True, False))
Expand Down Expand Up @@ -3292,7 +3292,7 @@ def test_sac_batcher(
torch.manual_seed(0) # log-prob is computed with a random action
np.random.seed(0)
loss = loss_fn(td)
if n == 0:
if n == 1:
assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
_loss = sum(
[item for name, item in loss.items() if name.startswith("loss_")]
Expand Down Expand Up @@ -3927,7 +3927,7 @@ def test_discrete_sac_state_dict(
)
loss_fn2.load_state_dict(sd)

@pytest.mark.parametrize("n", list(range(4)))
@pytest.mark.parametrize("n", range(1, 4))
@pytest.mark.parametrize("delay_qvalue", (True, False))
@pytest.mark.parametrize("num_qvalue", [2])
@pytest.mark.parametrize("device", get_default_devices())
Expand Down Expand Up @@ -3983,7 +3983,7 @@ def test_discrete_sac_batcher(
torch.manual_seed(0) # log-prob is computed with a random action
np.random.seed(0)
loss = loss_fn(td)
if n == 0:
if n == 1:
assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
_loss = sum(
[item for name, item in loss.items() if name.startswith("loss_")]
Expand Down Expand Up @@ -4871,7 +4871,7 @@ def test_redq_batched(self, delay_qvalue, num_qvalue, device, td_est):
# TODO: find a way to compare the losses: problem is that we sample actions either sequentially or in batch,
# so setting seed has little impact

@pytest.mark.parametrize("n", list(range(4)))
@pytest.mark.parametrize("n", range(1, 4))
@pytest.mark.parametrize("delay_qvalue", (True, False))
@pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
@pytest.mark.parametrize("device", get_default_devices())
Expand Down Expand Up @@ -4914,7 +4914,7 @@ def test_redq_batcher(self, n, delay_qvalue, num_qvalue, device, gamma=0.9):
torch.manual_seed(0) # log-prob is computed with a random action
np.random.seed(0)
loss = loss_fn(td)
if n == 0:
if n == 1:
assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
_loss = sum(
[item for name, item in loss.items() if name.startswith("loss_")]
Expand Down Expand Up @@ -5482,7 +5482,7 @@ def test_cql_state_dict(
)
loss_fn2.load_state_dict(sd)

@pytest.mark.parametrize("n", list(range(4)))
@pytest.mark.parametrize("n", range(1, 4))
@pytest.mark.parametrize("delay_actor", (True, False))
@pytest.mark.parametrize("delay_qvalue", (True, False))
@pytest.mark.parametrize("max_q_backup", [True, False])
Expand Down Expand Up @@ -5537,7 +5537,7 @@ def test_cql_batcher(
torch.manual_seed(0) # log-prob is computed with a random action
np.random.seed(0)
loss = loss_fn(td)
if n == 0:
if n == 1:
assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
_loss = sum(
[item for name, item in loss.items() if name.startswith("loss_")]
Expand Down Expand Up @@ -5843,7 +5843,7 @@ def test_dcql_state_dict(self, delay_value, device, action_spec_type):
loss_fn2 = DiscreteCQLLoss(actor, loss_function="l2", delay_value=delay_value)
loss_fn2.load_state_dict(sd)

@pytest.mark.parametrize("n", range(4))
@pytest.mark.parametrize("n", range(1, 4))
@pytest.mark.parametrize("delay_value", (False, True))
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical"))
Expand Down Expand Up @@ -5874,7 +5874,7 @@ def test_dcql_batcher(self, n, delay_value, device, action_spec_type, gamma=0.9)

with torch.no_grad():
loss = loss_fn(td)
if n == 0:
if n == 1:
assert_allclose_td(td, ms_td.select(*td.keys(True, True)))
_loss = sum([item for key, item in loss.items() if key.startswith("loss_")])
_loss_ms = sum(
Expand Down Expand Up @@ -9356,7 +9356,7 @@ def test_iql_separate_losses(self, separate_losses):
raise NotImplementedError(k)
loss_fn.zero_grad()

@pytest.mark.parametrize("n", list(range(4)))
@pytest.mark.parametrize("n", range(1, 4))
@pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
@pytest.mark.parametrize("temperature", [0.0, 0.1, 1.0, 10.0])
@pytest.mark.parametrize("expectile", [0.1, 0.5, 1.0])
Expand Down Expand Up @@ -9407,7 +9407,7 @@ def test_iql_batcher(
torch.manual_seed(0) # log-prob is computed with a random action
np.random.seed(0)
loss = loss_fn(td)
if n == 0:
if n == 1:
assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
_loss = sum(
[item for name, item in loss.items() if name.startswith("loss_")]
Expand Down Expand Up @@ -10168,7 +10168,7 @@ def test_discrete_iql_separate_losses(self, separate_losses):
raise NotImplementedError(k)
loss_fn.zero_grad()

@pytest.mark.parametrize("n", list(range(4)))
@pytest.mark.parametrize("n", range(1, 4))
@pytest.mark.parametrize("num_qvalue", [1, 2, 4, 8])
@pytest.mark.parametrize("temperature", [0.0, 0.1, 1.0, 10.0])
@pytest.mark.parametrize("expectile", [0.1, 0.5])
Expand Down Expand Up @@ -10219,7 +10219,7 @@ def test_discrete_iql_batcher(
torch.manual_seed(0) # log-prob is computed with a random action
np.random.seed(0)
loss = loss_fn(td)
if n == 0:
if n == 1:
assert_allclose_td(td, ms_td.select(*list(td.keys(True, True))))
_loss = sum(
[item for name, item in loss.items() if name.startswith("loss_")]
Expand Down
47 changes: 13 additions & 34 deletions test/test_postprocs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torchrl.data.postprocs.postprocs import MultiStep


@pytest.mark.parametrize("n", range(13))
@pytest.mark.parametrize("n", range(1, 14))
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("key", ["observation", "pixels", "observation_whatever"])
def test_multistep(n, key, device, T=11):
Expand Down Expand Up @@ -58,7 +58,7 @@ def test_multistep(n, key, device, T=11):

assert ms_tensordict.get("done").max() == 1

if n == 0:
if n == 1:
assert_allclose_td(
tensordict, ms_tensordict.select(*list(tensordict.keys(True, True)))
)
Expand All @@ -76,20 +76,18 @@ def test_multistep(n, key, device, T=11):
)

# check that next obs is properly replaced, or that it is terminated
next_obs = ms_tensordict.get(key)[:, (1 + ms.n_steps) :]
true_next_obs = ms_tensordict.get(("next", key))[:, : -(1 + ms.n_steps)]
next_obs = ms_tensordict.get(key)[:, (ms.n_steps) :]
true_next_obs = ms_tensordict.get(("next", key))[:, : -(ms.n_steps)]
terminated = ~ms_tensordict.get("nonterminal")
assert (
(next_obs == true_next_obs).all(-1) | terminated[:, (1 + ms.n_steps) :]
).all()
assert ((next_obs == true_next_obs).all(-1) | terminated[:, (ms.n_steps) :]).all()

# test gamma computation
torch.testing.assert_close(
ms_tensordict.get("gamma"), ms.gamma ** ms_tensordict.get("steps_to_next_obs")
)

# test reward
if n > 0:
if n > 1:
assert (
ms_tensordict.get(("next", "reward"))
!= ms_tensordict.get(("next", "original_reward"))
Expand All @@ -105,36 +103,17 @@ def test_multistep(n, key, device, T=11):
@pytest.mark.parametrize(
"batch_size",
[
[
4,
],
[4],
[],
[
1,
],
[1],
[2, 3],
],
)
@pytest.mark.parametrize(
"T",
[
10,
1,
2,
],
)
@pytest.mark.parametrize(
"obs_dim",
[
[
1,
],
[],
],
)
@pytest.mark.parametrize("T", [10, 1, 2])
@pytest.mark.parametrize("obs_dim", [[1], []])
@pytest.mark.parametrize("unsq_reward", [True, False])
@pytest.mark.parametrize("last_done", [True, False])
@pytest.mark.parametrize("n_steps", [3, 1, 0])
@pytest.mark.parametrize("n_steps", [4, 2, 1])
def test_mutistep_cattrajs(
batch_size, T, obs_dim, unsq_reward, last_done, device, n_steps
):
Expand Down Expand Up @@ -166,7 +145,7 @@ def test_mutistep_cattrajs(
)
ms = MultiStep(0.98, n_steps)
tdm = ms(td)
if n_steps == 0:
if n_steps == 1:
# n_steps = 1 has no effect (single-step return leaves the data unchanged)
for k in td["next"].keys():
assert (tdm["next", k] == td["next", k]).all()
Expand All @@ -179,7 +158,7 @@ def test_mutistep_cattrajs(
if unsq_reward:
done = done.squeeze(-1)
for t in range(T):
idx = t + n_steps
idx = t + n_steps - 1
while (done[..., t:idx].any() and idx > t) or idx > done.shape[-1] - 1:
idx = idx - 1
next_obs.append(obs[..., idx])
Expand Down
Loading
Loading