From 6f22bb76907e9aca494e31045a9ddb36e0d0fe2e Mon Sep 17 00:00:00 2001
From: muupan
Date: Wed, 14 Oct 2020 11:14:37 +0900
Subject: [PATCH 1/2] Detach greedy_actions before calling cpu()

---
 pfrl/agents/dqn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pfrl/agents/dqn.py b/pfrl/agents/dqn.py
index 3fa766577..7d92d1440 100644
--- a/pfrl/agents/dqn.py
+++ b/pfrl/agents/dqn.py
@@ -483,7 +483,7 @@ def _evaluate_model_and_update_recurrent_states(
     def batch_act(self, batch_obs: Sequence[Any]) -> Sequence[Any]:
         with torch.no_grad(), evaluating(self.model):
             batch_av = self._evaluate_model_and_update_recurrent_states(batch_obs)
-            batch_argmax = batch_av.greedy_actions.cpu().numpy()
+            batch_argmax = batch_av.greedy_actions.detach().cpu().numpy()
         if self.training:
             batch_action = [
                 self.explorer.select_action(

From 70caf994c7e096656863e47d586c8073080df7d0 Mon Sep 17 00:00:00 2001
From: muupan
Date: Wed, 14 Oct 2020 11:47:52 +0900
Subject: [PATCH 2/2] Fix StateQFunctionActor as well

---
 pfrl/agents/state_q_function_actor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pfrl/agents/state_q_function_actor.py b/pfrl/agents/state_q_function_actor.py
index 7a4c32531..6f2047243 100644
--- a/pfrl/agents/state_q_function_actor.py
+++ b/pfrl/agents/state_q_function_actor.py
@@ -76,7 +76,7 @@ def _send_to_learner(self, transition, stop_episode=False):
     def act(self, obs):
         with torch.no_grad(), evaluating(self.model):
             action_value = self._evaluate_model_and_update_recurrent_states([obs])
-            greedy_action = action_value.greedy_actions.cpu().numpy()[0]
+            greedy_action = action_value.greedy_actions.detach().cpu().numpy()[0]
         if self.training:
             action = self.explorer.select_action(
                 self.t, lambda: greedy_action, action_value=action_value
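
Note (not part of the patch): a minimal sketch of the general PyTorch behavior that motivates inserting .detach() before .cpu().numpy(). The tensor names below are hypothetical stand-ins, not pfrl code; the point is only that Tensor.numpy() raises a RuntimeError whenever the tensor is still attached to the autograd graph, so detaching first makes the conversion unconditionally safe regardless of how greedy_actions was computed.

    import torch

    # Stand-in Q-values that are still attached to the autograd graph.
    q_values = torch.randn(4, 3, requires_grad=True)
    greedy_q = q_values.max(dim=1).values  # max is differentiable, so this still requires grad

    try:
        greedy_q.cpu().numpy()  # RuntimeError: can't call numpy() on a tensor that requires grad
    except RuntimeError as err:
        print(err)

    actions = greedy_q.detach().cpu().numpy()  # detaching first makes the conversion safe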