Skip to content

Commit

Permalink
Fix Lunar Lander Test (#177)
Browse files Browse the repository at this point in the history
The lunar lander trainer for SAC required a `policy` parameter for
`create_train_callbacks()`; without it, the trainer wouldn't run. This PR
fixes that by returning the policy from `create_actor_critic_agents()` and
passing it through.
  • Loading branch information
jaxs-ribs authored Oct 11, 2023
1 parent 5b50d8d commit e99a1ec
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions experiments/gym/train_lunar_lander.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def create_actor_critic_agents(
policy = policy.to(device)
policy_proxy = FeatureAgentProxy(policy, device=device)
ln_alpha = torch.tensor(np.log(init_alpha), requires_grad=True, device=device)
return q1, q2, policy_proxy, ln_alpha
return q1, q2, policy_proxy, ln_alpha, policy


def create_train_callbacks(
Expand Down Expand Up @@ -354,7 +354,7 @@ def create_complementary_callbacks(
)

"""Creating the actor (policy) and critics (the two Q-functions) agents """
qnet1, qnet2, agent_proxy, ln_alpha = create_actor_critic_agents(
qnet1, qnet2, agent_proxy, ln_alpha, policy = create_actor_critic_agents(
args=input_args, num_actions=number_of_actions, num_obs=number_of_obs
)

Expand All @@ -363,6 +363,7 @@ def create_complementary_callbacks(
args=input_args,
q1=qnet1,
q2=qnet2,
policy=policy,
policy_proxy=agent_proxy,
ln_alpha=ln_alpha,
env=gym_wrapper,
Expand Down

0 comments on commit e99a1ec

Please sign in to comment.