diff --git a/test/test_cost.py b/test/test_cost.py
index 12c2a6b6d6f..8bae683c5d5 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -426,11 +426,13 @@ def _create_seq_mock_data_dqn(
         )
         return td

-    @pytest.mark.parametrize("delay_value", (False, True))
+    @pytest.mark.parametrize(
+        "delay_value,double_dqn", ([False, False], [True, False], [True, True])
+    )
     @pytest.mark.parametrize("device", get_default_devices())
     @pytest.mark.parametrize("action_spec_type", ("one_hot", "categorical"))
     @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None])
-    def test_dqn(self, delay_value, device, action_spec_type, td_est):
+    def test_dqn(self, delay_value, double_dqn, device, action_spec_type, td_est):
         torch.manual_seed(self.seed)
         actor = self._create_mock_actor(
             action_spec_type=action_spec_type, device=device
@@ -438,7 +440,9 @@ def test_dqn(self, delay_value, device, action_spec_type, td_est):
         td = self._create_mock_data_dqn(
             action_spec_type=action_spec_type, device=device
         )
-        loss_fn = DQNLoss(actor, loss_function="l2", delay_value=delay_value)
+        loss_fn = DQNLoss(
+            actor, loss_function="l2", delay_value=delay_value, double_dqn=double_dqn
+        )
         if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
             with pytest.raises(NotImplementedError):
                 loss_fn.make_value_estimator(td_est)
diff --git a/torchrl/envs/libs/robohive.py b/torchrl/envs/libs/robohive.py
index a3ee1dfa893..7ce0938facb 100644
--- a/torchrl/envs/libs/robohive.py
+++ b/torchrl/envs/libs/robohive.py
@@ -68,7 +68,7 @@ class RoboHiveEnv(GymEnv, metaclass=_RoboHiveBuild):

     Args:
         env_name (str): the environment name to build.
-        read_info (bool, optional): whether the the info should be parsed.
+        read_info (bool, optional): whether the info should be parsed.
             Defaults to ``True``.
         device (torch.device, optional): the device on which the input/output
             are expected. Defaults to torch default device.
diff --git a/torchrl/modules/distributions/discrete.py b/torchrl/modules/distributions/discrete.py
index bb98b1412a8..d73457b2261 100644
--- a/torchrl/modules/distributions/discrete.py
+++ b/torchrl/modules/distributions/discrete.py
@@ -163,7 +163,7 @@ class MaskedCategorical(D.Categorical):
             must be taken into account. Exclusive with ``mask``.
         neg_inf (float, optional): The log-probability value allocated to
             invalid (out-of-mask) indices. Defaults to -inf.
-        padding_value: The padding value in the then mask tensor when
+        padding_value: The padding value in the mask tensor. When
             sparse_mask == True, the padding_value will be ignored.

     >>> torch.manual_seed(0)
@@ -314,7 +314,7 @@ class MaskedOneHotCategorical(MaskedCategorical):
             must be taken into account. Exclusive with ``mask``.
         neg_inf (float, optional): The log-probability value allocated to
             invalid (out-of-mask) indices. Defaults to -inf.
-        padding_value: The padding value in the then mask tensor when
+        padding_value: The padding value in the mask tensor. When
             sparse_mask == True, the padding_value will be ignored.
         grad_method (ReparamGradientStrategy, optional): strategy to gather
             reparameterized samples.
diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py
index 74b95d86daf..59c3f32697f 100644
--- a/torchrl/objectives/dqn.py
+++ b/torchrl/objectives/dqn.py
@@ -184,6 +184,8 @@ def __init__(
             delay_value = False
         super().__init__()
         self._in_keys = None
+        if double_dqn and not delay_value:
+            raise ValueError("double_dqn=True requires delay_value=True.")
         self.double_dqn = double_dqn
         self._set_deprecated_ctor_keys(priority=priority_key)
         self.delay_value = delay_value
@@ -310,7 +312,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
         pred_val = td_copy.get(self.tensor_keys.action_value)

         if self.action_space == "categorical":
-            if action.shape != pred_val.shape:
+            if action.ndim != pred_val.ndim:
                 # unsqueeze the action if it lacks on trailing singleton dim
                 action = action.unsqueeze(-1)
             pred_val_index = torch.gather(pred_val, -1, index=action).squeeze(-1)
@@ -321,17 +323,20 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
         if self.double_dqn:
             step_td = step_mdp(td_copy, keep_other=False)
             step_td_copy = step_td.clone(False)
-
-            with self.target_value_network_params.to_module(self.value_network):
+            # Use online network to compute the action
+            with self.value_network_params.data.to_module(self.value_network):
                 self.value_network(step_td)
+            next_action = step_td.get(self.tensor_keys.action)

+            # Use target network to compute the values
             with self.target_value_network_params.to_module(self.value_network):
                 self.value_network(step_td_copy)
-            next_action = step_td.get(self.tensor_keys.action).to(torch.float)
-            next_pred_val = step_td_copy.get(self.tensor_keys.action_value)
+            next_pred_val = step_td_copy.get(self.tensor_keys.action_value)
+
             if self.action_space == "categorical":
-                if next_action.shape != next_pred_val.shape:
-                    next_action = action.unsqueeze(-1)
+                if next_action.ndim != next_pred_val.ndim:
+                    # unsqueeze the action if it lacks on trailing singleton dim
+                    next_action = next_action.unsqueeze(-1)
                 next_value = torch.gather(next_pred_val, -1, index=next_action)
             else:
                 next_value = (next_pred_val * next_action).sum(-1, keepdim=True)
@@ -404,9 +409,9 @@ class _AcceptedKeys:
                 Defaults to ``"td_error"``.
             reward (NestedKey): The input tensordict key where the reward is expected.
                 Defaults to ``"reward"``.
-            done (NestedKey): The input tensordict key where the the flag if a trajectory is done is expected.
+            done (NestedKey): The input tensordict key where the flag if a trajectory is done is expected.
                 Defaults to ``"done"``.
-            terminated (NestedKey): The input tensordict key where the the flag if a trajectory is done is expected.
+            terminated (NestedKey): The input tensordict key where the flag if a trajectory is done is expected.
                 Defaults to ``"terminated"``.
             steps_to_next_obs (NestedKey): The input tensordict key where the steps_to_next_obs is exptected.
                 Defaults to ``"steps_to_next_obs"``.
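Usage note (not part of the diff): a minimal sketch of how the new behavior is exercised, mirroring the test change above. The actor construction below is illustrative only; the spec, observation size, and network are assumptions, not taken from the patch. The one constraint introduced by the patch is that double_dqn=True must be paired with delay_value=True, otherwise DQNLoss now raises a ValueError.

    import torch
    from torch import nn
    from torchrl.data import OneHotDiscreteTensorSpec
    from torchrl.modules import QValueActor
    from torchrl.objectives import DQNLoss

    # Hypothetical 4-action Q-value actor over an 8-dim observation.
    action_spec = OneHotDiscreteTensorSpec(4)
    value_net = nn.Linear(8, 4)
    actor = QValueActor(value_net, in_keys=["observation"], spec=action_spec)

    # double_dqn=True requires delay_value=True (enforced in __init__ above);
    # action selection then uses the online params, value estimation the target params.
    loss_fn = DQNLoss(actor, loss_function="l2", delay_value=True, double_dqn=True)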