diff --git a/pfrl/agents/dqn.py b/pfrl/agents/dqn.py
index 3fa766577..7d92d1440 100644
--- a/pfrl/agents/dqn.py
+++ b/pfrl/agents/dqn.py
@@ -483,7 +483,7 @@ def _evaluate_model_and_update_recurrent_states(
     def batch_act(self, batch_obs: Sequence[Any]) -> Sequence[Any]:
         with torch.no_grad(), evaluating(self.model):
             batch_av = self._evaluate_model_and_update_recurrent_states(batch_obs)
-            batch_argmax = batch_av.greedy_actions.cpu().numpy()
+            batch_argmax = batch_av.greedy_actions.detach().cpu().numpy()
             if self.training:
                 batch_action = [
                     self.explorer.select_action(
diff --git a/pfrl/agents/state_q_function_actor.py b/pfrl/agents/state_q_function_actor.py
index 7a4c32531..6f2047243 100644
--- a/pfrl/agents/state_q_function_actor.py
+++ b/pfrl/agents/state_q_function_actor.py
@@ -76,7 +76,7 @@ def _send_to_learner(self, transition, stop_episode=False):
     def act(self, obs):
         with torch.no_grad(), evaluating(self.model):
             action_value = self._evaluate_model_and_update_recurrent_states([obs])
-            greedy_action = action_value.greedy_actions.cpu().numpy()[0]
+            greedy_action = action_value.greedy_actions.detach().cpu().numpy()[0]
             if self.training:
                 action = self.explorer.select_action(
                     self.t, lambda: greedy_action, action_value=action_value
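
Note (reviewer sketch, not part of the diff): both hunks insert `.detach()` before `.cpu().numpy()`. The general PyTorch behavior this guards against is that a tensor still attached to the autograd graph cannot be converted to a NumPy array directly. A minimal, self-contained illustration of that behavior follows; the variable names are hypothetical and unrelated to the PFRL code above.

import torch

# Hypothetical stand-in for a model output that is still attached to the
# autograd graph (e.g. an action-value tensor computed with grad enabled).
q_values = torch.randn(2, 4, requires_grad=True) * 2.0

# Converting a graph-attached tensor straight to NumPy raises:
#   RuntimeError: Can't call numpy() on Tensor that requires grad.
try:
    q_values.cpu().numpy()
except RuntimeError as err:
    print(err)

# .detach() returns a tensor that is cut off from the graph, so the
# .detach().cpu().numpy() chain used in the diff is safe either way.
actions = q_values.detach().cpu().numpy().argmax(axis=1)
print(actions)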