
Commit ba6897d
[Quality] Fix low/high in SOTA implementations (#2266)
vmoens authored Jul 3, 2024
1 parent 79fa8bf commit ba6897d
Showing 17 changed files with 34 additions and 34 deletions.
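
Every file below applies the same rename: the "min"/"max" entries of the distribution_kwargs dict passed next to TanhNormal or TanhDelta become "low"/"high". For orientation, here is a minimal sketch of the updated spelling inside a ProbabilisticActor; the spec shape, the tiny network, and the observation key are illustrative placeholders, not code taken from this commit.

    import torch
    from tensordict import TensorDict
    from tensordict.nn import NormalParamExtractor, TensorDictModule
    from torchrl.data import BoundedTensorSpec
    from torchrl.modules import ProbabilisticActor, TanhNormal

    # Hypothetical 3-dimensional action space bounded in [-1, 1].
    action_spec = BoundedTensorSpec(low=-1.0, high=1.0, shape=(3,))

    # A small policy head emitting the loc/scale parameters read by TanhNormal.
    policy_net = torch.nn.Sequential(torch.nn.LazyLinear(6), NormalParamExtractor())
    policy_module = TensorDictModule(
        policy_net, in_keys=["observation"], out_keys=["loc", "scale"]
    )

    actor = ProbabilisticActor(
        module=policy_module,
        spec=action_spec,
        in_keys=["loc", "scale"],
        distribution_class=TanhNormal,
        distribution_kwargs={
            "low": action_spec.space.low,    # was "min" before this commit
            "high": action_spec.space.high,  # was "max" before this commit
        },
        return_log_prob=True,
    )

    # Sample actions for a batch of 4 dummy observations.
    td = actor(TensorDict({"observation": torch.randn(4, 8)}, batch_size=[4]))
    print(td["action"].shape)  # torch.Size([4, 3])
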
4 changes: 2 additions & 2 deletions sota-implementations/a2c/utils_atari.py
@@ -100,8 +100,8 @@ def make_ppo_modules_pixels(proof_environment):
     num_outputs = proof_environment.action_spec.shape
     distribution_class = TanhNormal
     distribution_kwargs = {
-        "min": proof_environment.action_spec.space.low,
-        "max": proof_environment.action_spec.space.high,
+        "low": proof_environment.action_spec.space.low,
+        "high": proof_environment.action_spec.space.high,
     }
 
     # Define input keys
4 changes: 2 additions & 2 deletions sota-implementations/a2c/utils_mujoco.py
@@ -57,8 +57,8 @@ def make_ppo_models_state(proof_environment):
     num_outputs = proof_environment.action_spec.shape[-1]
     distribution_class = TanhNormal
     distribution_kwargs = {
-        "min": proof_environment.action_spec.space.low,
-        "max": proof_environment.action_spec.space.high,
+        "low": proof_environment.action_spec.space.low,
+        "high": proof_environment.action_spec.space.high,
         "tanh_loc": False,
     }
 
4 changes: 2 additions & 2 deletions sota-implementations/cql/utils.py
@@ -208,8 +208,8 @@ def make_cql_model(cfg, train_env, eval_env, device="cpu"):
         spec=action_spec,
         distribution_class=TanhNormal,
         distribution_kwargs={
-            "min": action_spec.space.low[len(train_env.batch_size) :],
-            "max": action_spec.space.high[
+            "low": action_spec.space.low[len(train_env.batch_size) :],
+            "high": action_spec.space.high[
                 len(train_env.batch_size) :
             ],  # remove batch-size
             "tanh_loc": False,
6 changes: 3 additions & 3 deletions sota-implementations/decision_transformer/utils.py
@@ -357,7 +357,7 @@ def make_odt_model(cfg):
         ],
     )
     dist_class = TanhNormal
-    dist_kwargs = {"min": -1.0, "max": 1.0, "tanh_loc": False, "upscale": 5.0}
+    dist_kwargs = {"low": -1.0, "high": 1.0, "tanh_loc": False, "upscale": 5.0}
 
     actor = ProbabilisticActor(
         spec=action_spec,
@@ -409,8 +409,8 @@ def make_dt_model(cfg):
     )
     dist_class = TanhDelta
     dist_kwargs = {
-        "min": action_spec.space.low,
-        "max": action_spec.space.high,
+        "low": action_spec.space.low,
+        "high": action_spec.space.high,
     }
 
     actor = ProbabilisticActor(
4 changes: 2 additions & 2 deletions sota-implementations/iql/utils.py
@@ -213,8 +213,8 @@ def make_iql_model(cfg, train_env, eval_env, device="cpu"):
         spec=action_spec,
         distribution_class=TanhNormal,
         distribution_kwargs={
-            "min": action_spec.space.low,
-            "max": action_spec.space.high,
+            "low": action_spec.space.low,
+            "high": action_spec.space.high,
             "tanh_loc": False,
         },
         default_interaction_type=ExplorationType.RANDOM,
4 changes: 2 additions & 2 deletions sota-implementations/multiagent/maddpg_iddpg.py
@@ -96,8 +96,8 @@ def train(cfg: "DictConfig"):  # noqa: F821
         out_keys=[env.action_key],
         distribution_class=TanhDelta,
         distribution_kwargs={
-            "min": env.unbatched_action_spec[("agents", "action")].space.low,
-            "max": env.unbatched_action_spec[("agents", "action")].space.high,
+            "low": env.unbatched_action_spec[("agents", "action")].space.low,
+            "high": env.unbatched_action_spec[("agents", "action")].space.high,
         },
         return_log_prob=False,
     )
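
Several of the changed files (maddpg_iddpg.py above, the decision transformer utils, and the competitive-DDPG tutorial further down) pass the renamed bounds to TanhDelta, the deterministic head used for DDPG-style policies. Below is a small standalone sketch of that distribution with the new keywords, assuming it accepts a plain tensor parameter as in the actors.py docstring later in this diff; the (2, 3) shape is an illustrative agents-by-action-dim choice.

    import torch
    from torchrl.modules import TanhDelta

    # Hypothetical parameter tensor: 2 agents, 3-dimensional actions.
    param = torch.randn(2, 3)

    # TanhDelta squashes the unbounded parameter into [low, high]; after this
    # commit the bounds are spelled low/high instead of min/max.
    dist = TanhDelta(param, low=-1.0, high=1.0)

    action = dist.sample()  # deterministic: the squashed parameter itself
    assert action.shape == param.shape
    assert action.min() >= -1.0 and action.max() <= 1.0
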
4 changes: 2 additions & 2 deletions sota-implementations/multiagent/mappo_ippo.py
@@ -97,8 +97,8 @@ def train(cfg: "DictConfig"):  # noqa: F821
         out_keys=[env.action_key],
         distribution_class=TanhNormal,
         distribution_kwargs={
-            "min": env.unbatched_action_spec[("agents", "action")].space.low,
-            "max": env.unbatched_action_spec[("agents", "action")].space.high,
+            "low": env.unbatched_action_spec[("agents", "action")].space.low,
+            "high": env.unbatched_action_spec[("agents", "action")].space.high,
         },
         return_log_prob=True,
     )
4 changes: 2 additions & 2 deletions sota-implementations/multiagent/sac.py
@@ -101,8 +101,8 @@ def train(cfg: "DictConfig"):  # noqa: F821
         out_keys=[env.action_key],
         distribution_class=TanhNormal,
         distribution_kwargs={
-            "min": env.unbatched_action_spec[("agents", "action")].space.low,
-            "max": env.unbatched_action_spec[("agents", "action")].space.high,
+            "low": env.unbatched_action_spec[("agents", "action")].space.low,
+            "high": env.unbatched_action_spec[("agents", "action")].space.high,
         },
         return_log_prob=True,
     )
4 changes: 2 additions & 2 deletions sota-implementations/ppo/utils_atari.py
@@ -100,8 +100,8 @@ def make_ppo_modules_pixels(proof_environment):
     num_outputs = proof_environment.action_spec.shape
     distribution_class = TanhNormal
     distribution_kwargs = {
-        "min": proof_environment.action_spec.space.low,
-        "max": proof_environment.action_spec.space.high,
+        "low": proof_environment.action_spec.space.low,
+        "high": proof_environment.action_spec.space.high,
     }
 
     # Define input keys
4 changes: 2 additions & 2 deletions sota-implementations/ppo/utils_mujoco.py
@@ -52,8 +52,8 @@ def make_ppo_models_state(proof_environment):
     num_outputs = proof_environment.action_spec.shape[-1]
     distribution_class = TanhNormal
     distribution_kwargs = {
-        "min": proof_environment.action_spec.space.low,
-        "max": proof_environment.action_spec.space.high,
+        "low": proof_environment.action_spec.space.low,
+        "high": proof_environment.action_spec.space.high,
         "tanh_loc": False,
     }
 
4 changes: 2 additions & 2 deletions sota-implementations/redq/utils.py
@@ -477,8 +477,8 @@ def make_redq_model(
 
     dist_class = TanhNormal
     dist_kwargs = {
-        "min": action_spec.space.low,
-        "max": action_spec.space.high,
+        "low": action_spec.space.low,
+        "high": action_spec.space.high,
         "tanh_loc": tanh_loc,
     }
 
4 changes: 2 additions & 2 deletions sota-implementations/sac/utils.py
@@ -174,8 +174,8 @@ def make_sac_agent(cfg, train_env, eval_env, device):
 
     dist_class = TanhNormal
     dist_kwargs = {
-        "min": action_spec.space.low,
-        "max": action_spec.space.high,
+        "low": action_spec.space.low,
+        "high": action_spec.space.high,
         "tanh_loc": False,
     }
 
2 changes: 1 addition & 1 deletion torchrl/modules/distributions/continuous.py
@@ -655,7 +655,7 @@ def _warn_minmax(self):
         warnings.warn(
             f"the min / max keyword arguments are deprecated in favor of low / high in {type(self).__name__} "
             f"and will be removed entirely in v0.6. ",
-            DeprecationWarning,
+            category=DeprecationWarning,
         )
 
     def __init__(
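
The continuous.py hunk above is the deprecation machinery itself: passing the old min/max keywords still works in this release but routes through _warn_minmax, and the one-line change makes the warning category an explicit keyword (category=DeprecationWarning). A quick sketch of how that surfaces, assuming the old keyword path remains accepted here as the warning text says it will be until v0.6:

    import warnings

    import torch
    from torchrl.modules import TanhNormal

    loc, scale = torch.zeros(3), torch.ones(3)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Old spelling: still accepted, but deprecated in favor of low/high.
        dist = TanhNormal(loc, scale, min=-1.0, max=1.0)

    print([w.category.__name__ for w in caught])  # expect ['DeprecationWarning']
    print(dist.sample().shape)                    # torch.Size([3])
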
4 changes: 2 additions & 2 deletions torchrl/modules/tensordict_module/actors.py
@@ -1810,8 +1810,8 @@ class DecisionTransformerInferenceWrapper(TensorDictModuleWrapper):
         ...     out_keys=["param"])
         >>> dist_class = TanhDelta
         >>> dist_kwargs = {
-        ...     "min": -1.0,
-        ...     "max": 1.0,
+        ...     "low": -1.0,
+        ...     "high": 1.0,
         ...     }
         >>> actor = ProbabilisticActor(
         ...     in_keys=["param"],
4 changes: 2 additions & 2 deletions tutorials/sphinx-tutorials/coding_ppo.py
@@ -431,8 +431,8 @@
     in_keys=["loc", "scale"],
     distribution_class=TanhNormal,
     distribution_kwargs={
-        "min": env.action_spec.space.low,
-        "max": env.action_spec.space.high,
+        "low": env.action_spec.space.low,
+        "high": env.action_spec.space.high,
     },
     return_log_prob=True,
     # we'll need the log-prob for the numerator of the importance weights
4 changes: 2 additions & 2 deletions tutorials/sphinx-tutorials/multiagent_competitive_ddpg.py
@@ -486,8 +486,8 @@
         out_keys=[(group, "action")],
         distribution_class=TanhDelta,
         distribution_kwargs={
-            "min": env.full_action_spec[group, "action"].space.low,
-            "max": env.full_action_spec[group, "action"].space.high,
+            "low": env.full_action_spec[group, "action"].space.low,
+            "high": env.full_action_spec[group, "action"].space.high,
         },
         return_log_prob=False,
     )
4 changes: 2 additions & 2 deletions tutorials/sphinx-tutorials/multiagent_ppo.py
@@ -450,8 +450,8 @@
     out_keys=[env.action_key],
     distribution_class=TanhNormal,
     distribution_kwargs={
-        "min": env.unbatched_action_spec[env.action_key].space.low,
-        "max": env.unbatched_action_spec[env.action_key].space.high,
+        "low": env.unbatched_action_spec[env.action_key].space.low,
+        "high": env.unbatched_action_spec[env.action_key].space.high,
     },
     return_log_prob=True,
     log_prob_key=("agents", "sample_log_prob"),
