Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Apr 24, 2024
1 parent 515ae33 commit e798882
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 8 deletions.
1 change: 1 addition & 0 deletions tutorials/sphinx-tutorials/coding_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@

try:
multiprocessing.set_start_method("spawn" if is_sphinx else "fork")
mp_context = "fork"
except RuntimeError:
# If we can't set the method globally we can still run the parallel env with "fork"
# This will fail on windows! Use "spawn" and put the script within `if __name__ == "__main__"`
Expand Down
2 changes: 1 addition & 1 deletion tutorials/sphinx-tutorials/dqn_with_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@
# either by passing a string or an action-spec. This allows us to use
# Categorical (sometimes called "sparse") encoding or the one-hot version of it.
#
qval = QValueModule(action_space=env.action_spec)
qval = QValueModule(action_space=None, spec=env.action_spec)

######################################################################
# .. note::
Expand Down
4 changes: 1 addition & 3 deletions tutorials/sphinx-tutorials/getting-started-1.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,7 @@

policy = TensorDictSequential(
value_net, # writes action values in our tensordict
QValueModule(
action_space=env.action_spec
), # Reads the "action_value" entry by default
QValueModule(spec=env.action_spec), # Reads the "action_value" entry by default
)

###################################
Expand Down
2 changes: 1 addition & 1 deletion tutorials/sphinx-tutorials/getting-started-5.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@

value_mlp = MLP(out_features=env.action_spec.shape[-1], num_cells=[64, 64])
value_net = Mod(value_mlp, in_keys=["observation"], out_keys=["action_value"])
policy = Seq(value_net, QValueModule(env.action_spec))
policy = Seq(value_net, QValueModule(spec=env.action_spec))
exploration_module = EGreedyModule(
env.action_spec, annealing_num_steps=100_000, eps_init=0.5
)
Expand Down
4 changes: 1 addition & 3 deletions tutorials/sphinx-tutorials/torchrl_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,7 @@
# we can just generate a random action:


policy = TensorDictModule(
functools.partial(env.action_spec.rand, env=env), in_keys=[], out_keys=["action"]
)
policy = TensorDictModule(env.action_spec.rand, in_keys=[], out_keys=["action"])


policy(reset_data)
Expand Down

0 comments on commit e798882

Please sign in to comment.