diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index f3aff0da1d2..1c43d536fe8 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -160,7 +160,7 @@ def _call_actor_net( log_prob_key: NestedKey, ): # TODO: extend to handle time dimension (and vmap?) - log_pi = actor_net(data.select(actor_net.in_keys)).get(log_prob_key) + log_pi = actor_net(data.select(*actor_net.in_keys)).get(log_prob_key) return log_pi