
Commit

fix
vmoens committed Dec 6, 2023
1 parent d0eaef3 commit ffd9518
Showing 4 changed files with 24 additions and 15 deletions.
10 changes: 7 additions & 3 deletions test/test_cost.py
@@ -426,19 +426,23 @@ def _create_seq_mock_data_dqn(
)
return td

@pytest.mark.parametrize("delay_value", (False, True))
@pytest.mark.parametrize(
"delay_value,double_dqn", ([False, False], [True, False], [True, True])
)
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical"))
@pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
def test_dqn(self, delay_value, device, action_spec_type, td_est):
def test_dqn(self, delay_value, double_dqn, device, action_spec_type, td_est):
torch.manual_seed(self.seed)
actor = self._create_mock_actor(
action_spec_type=action_spec_type, device=device
)
td = self._create_mock_data_dqn(
action_spec_type=action_spec_type, device=device
)
loss_fn = DQNLoss(actor, loss_function="l2", delay_value=delay_value)
loss_fn = DQNLoss(
actor, loss_function="l2", delay_value=delay_value, double_dqn=double_dqn
)
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
loss_fn.make_value_estimator(td_est)
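
For context, a minimal sketch of what the new parametrization exercises: building a DQNLoss with the double_dqn flag. DQNLoss, QValueActor and OneHotDiscreteTensorSpec are real torchrl classes, but the toy module and constructor arguments below are assumptions for illustration, not the test's helper code.

from torch import nn
from torchrl.data import OneHotDiscreteTensorSpec
from torchrl.modules import QValueActor
from torchrl.objectives import DQNLoss

n_obs, n_act = 4, 3
# Toy Q-network mapping an observation to one value per action.
actor = QValueActor(
    module=nn.Linear(n_obs, n_act),
    in_keys=["observation"],
    spec=OneHotDiscreteTensorSpec(n_act),
)
# delay_value=True keeps a target network; double_dqn=True switches on the
# double-DQN target computation changed in torchrl/objectives/dqn.py below.
loss_fn = DQNLoss(actor, loss_function="l2", delay_value=True, double_dqn=True)
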
2 changes: 1 addition & 1 deletion torchrl/envs/libs/robohive.py
@@ -68,7 +68,7 @@ class RoboHiveEnv(GymEnv, metaclass=_RoboHiveBuild):
Args:
env_name (str): the environment name to build.
read_info (bool, optional): whether the the info should be parsed.
read_info (bool, optional): whether the info should be parsed.
Defaults to ``True``.
device (torch.device, optional): the device on which the input/output
are expected. Defaults to torch default device.
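
A hedged illustration of the arguments documented above; the task name is a placeholder, not a real RoboHive task id, and RoboHive itself must be installed.

from torchrl.envs.libs.robohive import RoboHiveEnv

# "SomeRoboHiveTask-v0" is hypothetical; substitute an installed RoboHive task.
env = RoboHiveEnv("SomeRoboHiveTask-v0", read_info=True)
td = env.reset()  # returns a TensorDict with the initial observation
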
4 changes: 2 additions & 2 deletions torchrl/modules/distributions/discrete.py
@@ -163,7 +163,7 @@ class MaskedCategorical(D.Categorical):
must be taken into account. Exclusive with ``mask``.
neg_inf (float, optional): The log-probability value allocated to
invalid (out-of-mask) indices. Defaults to -inf.
padding_value: The padding value in the then mask tensor when
padding_value: The padding value in the mask tensor. When
sparse_mask == True, the padding_value will be ignored.
>>> torch.manual_seed(0)
@@ -314,7 +314,7 @@ class MaskedOneHotCategorical(MaskedCategorical):
must be taken into account. Exclusive with ``mask``.
neg_inf (float, optional): The log-probability value allocated to
invalid (out-of-mask) indices. Defaults to -inf.
padding_value: The padding value in the then mask tensor when
padding_value: The padding value in the mask tensor. When
sparse_mask == True, the padding_value will be ignored.
grad_method (ReparamGradientStrategy, optional): strategy to gather
reparameterized samples.
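
A short usage sketch of the mask semantics described in the docstrings above; the tensors are illustrative only, and masked-out indices receive the log-probability given by neg_inf.

import torch
from torchrl.modules import MaskedCategorical

torch.manual_seed(0)
logits = torch.randn(4)
mask = torch.tensor([True, False, True, True])  # index 1 is masked out

dist = MaskedCategorical(logits=logits, mask=mask)
action = dist.sample()            # never samples index 1
log_prob = dist.log_prob(action)  # invalid indices would get neg_inf
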
23 changes: 14 additions & 9 deletions torchrl/objectives/dqn.py
@@ -184,6 +184,8 @@ def __init__(
delay_value = False
super().__init__()
self._in_keys = None
if double_dqn and not delay_value:
raise ValueError("double_dqn=True requires delay_value=True.")
self.double_dqn = double_dqn
self._set_deprecated_ctor_keys(priority=priority_key)
self.delay_value = delay_value
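
A quick sketch of the failure mode the new check guards against, reusing the hypothetical actor from the sketch above:

from torchrl.objectives import DQNLoss

try:
    DQNLoss(actor, loss_function="l2", delay_value=False, double_dqn=True)
except ValueError as err:
    print(err)  # double_dqn=True requires delay_value=True.
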
@@ -310,7 +312,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
pred_val = td_copy.get(self.tensor_keys.action_value)

if self.action_space == "categorical":
if action.shape != pred_val.shape:
if action.ndim != pred_val.ndim:
# unsqueeze the action if it lacks a trailing singleton dim
action = action.unsqueeze(-1)
pred_val_index = torch.gather(pred_val, -1, index=action).squeeze(-1)
@@ -321,17 +323,20 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
if self.double_dqn:
step_td = step_mdp(td_copy, keep_other=False)
step_td_copy = step_td.clone(False)

with self.target_value_network_params.to_module(self.value_network):
# Use online network to compute the action
with self.value_network_params.data.to_module(self.value_network):
self.value_network(step_td)
next_action = step_td.get(self.tensor_keys.action)

# Use target network to compute the values
with self.target_value_network_params.to_module(self.value_network):
self.value_network(step_td_copy)
next_action = step_td.get(self.tensor_keys.action).to(torch.float)
next_pred_val = step_td_copy.get(self.tensor_keys.action_value)
next_pred_val = step_td_copy.get(self.tensor_keys.action_value)

if self.action_space == "categorical":
if next_action.shape != next_pred_val.shape:
next_action = action.unsqueeze(-1)
if next_action.ndim != next_pred_val.ndim:
# unsqueeze the action if it lacks a trailing singleton dim
next_action = next_action.unsqueeze(-1)
next_value = torch.gather(next_pred_val, -1, index=next_action)
else:
next_value = (next_pred_val * next_action).sum(-1, keepdim=True)
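
The reworked block above selects the next action with the online network and evaluates it with the target network. A plain-PyTorch sketch of that double-DQN rule, with illustrative shapes and names rather than the module's internals:

import torch

batch, n_act = 8, 4
q_online_next = torch.randn(batch, n_act)  # Q_online(s', .)
q_target_next = torch.randn(batch, n_act)  # Q_target(s', .)
reward = torch.randn(batch, 1)
not_terminated = torch.ones(batch, 1)
gamma = 0.99

# Online network chooses the action, target network evaluates it.
next_action = q_online_next.argmax(dim=-1, keepdim=True)
next_value = torch.gather(q_target_next, dim=-1, index=next_action)
target = reward + gamma * not_terminated * next_value
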
@@ -404,9 +409,9 @@ class _AcceptedKeys:
Defaults to ``"td_error"``.
reward (NestedKey): The input tensordict key where the reward is expected.
Defaults to ``"reward"``.
done (NestedKey): The input tensordict key where the the flag if a trajectory is done is expected.
done (NestedKey): The input tensordict key where the flag if a trajectory is done is expected.
Defaults to ``"done"``.
terminated (NestedKey): The input tensordict key where the the flag if a trajectory is done is expected.
terminated (NestedKey): The input tensordict key where the flag if a trajectory is done is expected.
Defaults to ``"terminated"``.
steps_to_next_obs (NestedKey): The input tensordict key where the steps_to_next_obs is expected.
Defaults to ``"steps_to_next_obs"``.
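
The keys listed above can be remapped through the LossModule set_keys helper; the custom names below are hypothetical:

loss_fn.set_keys(reward="my_reward", done="my_done", terminated="my_terminated")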
