From b9da8d22c2502f796fa9af31d77cfc193b52816c Mon Sep 17 00:00:00 2001 From: Daniel Plop <1534513+dantp-ai@users.noreply.github.com> Date: Wed, 3 Apr 2024 18:07:51 +0200 Subject: [PATCH] Fix mypy issues in tests and examples (#1077) Closes #952 - `SamplingConfig` supports `batch_size=None`. #1077 - tests and examples are covered by `mypy`. #1077 - `NetBase` is more used, stricter typing by making it generic. #1077 - `utils.net.common.Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. #1077 --------- Co-authored-by: Michael Panchenko --- .pre-commit-config.yaml | 2 +- CHANGELOG.md | 5 +- docs/02_notebooks/L0_overview.ipynb | 15 +- docs/02_notebooks/L5_Collector.ipynb | 5 +- docs/02_notebooks/L6_Trainer.ipynb | 5 +- docs/02_notebooks/L7_Experiment.ipynb | 5 +- examples/atari/atari_c51.py | 6 +- examples/atari/atari_dqn.py | 9 +- examples/atari/atari_fqf.py | 6 +- examples/atari/atari_iqn.py | 6 +- examples/atari/atari_network.py | 90 +++-- examples/atari/atari_ppo.py | 18 +- examples/atari/atari_ppo_hl.py | 6 +- examples/atari/atari_qrdqn.py | 16 +- examples/atari/atari_rainbow.py | 6 +- examples/atari/atari_sac.py | 11 +- examples/atari/atari_sac_hl.py | 7 +- examples/atari/atari_wrapper.py | 179 ++++++---- examples/box2d/acrobot_dualdqn.py | 7 +- examples/box2d/bipedal_bdq.py | 27 +- examples/box2d/bipedal_hardcore_sac.py | 17 +- examples/box2d/lunarlander_dqn.py | 13 +- examples/box2d/mcc_sac.py | 10 +- examples/common.py | 3 - examples/discrete/discrete_dqn.py | 7 +- examples/inverse/irl_gail.py | 7 +- examples/mujoco/analysis.py | 3 +- examples/mujoco/fetch_her_ddpg.py | 58 ++-- examples/mujoco/gen_json.py | 3 +- examples/mujoco/mujoco_a2c.py | 3 +- examples/mujoco/mujoco_a2c_hl.py | 2 +- examples/mujoco/mujoco_ddpg.py | 10 +- examples/mujoco/mujoco_env.py | 16 +- examples/mujoco/mujoco_npg.py | 3 +- examples/mujoco/mujoco_npg_hl.py | 2 +- examples/mujoco/mujoco_ppo.py | 3 +- examples/mujoco/mujoco_redq.py | 11 +- examples/mujoco/mujoco_reinforce.py | 3 +- examples/mujoco/mujoco_sac.py | 13 +- examples/mujoco/mujoco_td3.py | 13 +- examples/mujoco/mujoco_trpo.py | 3 +- examples/mujoco/mujoco_trpo_hl.py | 2 +- examples/mujoco/plotter.py | 59 ++-- examples/mujoco/tools.py | 20 +- examples/offline/atari_bcq.py | 20 +- examples/offline/atari_cql.py | 28 +- examples/offline/atari_crr.py | 24 +- examples/offline/atari_il.py | 18 +- .../offline/convert_rl_unplugged_atari.py | 17 +- examples/offline/d4rl_bcq.py | 8 +- examples/offline/d4rl_cql.py | 15 +- examples/offline/d4rl_il.py | 4 +- examples/offline/d4rl_td3_bc.py | 8 +- examples/vizdoom/env.py | 53 ++- examples/vizdoom/replay.py | 6 +- examples/vizdoom/vizdoom_c51.py | 8 +- examples/vizdoom/vizdoom_ppo.py | 18 +- poetry.lock | 127 ++++--- pyproject.toml | 7 +- test/base/env.py | 2 +- test/base/test_batch.py | 27 +- test/base/test_buffer.py | 108 +++--- test/base/test_collector.py | 317 ++++++++++-------- test/base/test_env.py | 94 ++++-- test/base/test_env_finite.py | 124 ++++--- test/base/test_policy.py | 15 +- test/base/test_returns.py | 38 ++- test/base/test_stats.py | 11 +- test/continuous/test_ddpg.py | 6 +- test/continuous/test_ppo.py | 4 +- test/continuous/test_redq.py | 9 +- test/continuous/test_sac_with_il.py | 14 +- test/continuous/test_td3.py | 10 +- test/continuous/test_trpo.py | 2 +- test/discrete/test_a2c_with_il.py | 13 +- test/discrete/test_bdq.py | 15 +- test/discrete/test_c51.py | 7 +- test/discrete/test_dqn.py | 9 +- test/discrete/test_drqn.py | 3 +- test/discrete/test_fqf.py | 5 +- 
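The commit description above notes that `SamplingConfig` now accepts `batch_size=None`. A minimal sketch of what that looks like with the high-level API, assuming the field names used by the `*_hl.py` example scripts elsewhere in this patch; all numeric values are placeholders:

```python
from tianshou.highlevel.config import SamplingConfig

# On-policy style configuration; batch_size=None is now an accepted value
# (see the changelog entry added by this patch). The other numbers are
# illustrative only and not taken from any particular example script.
sampling_config = SamplingConfig(
    num_epochs=100,
    step_per_epoch=30000,
    step_per_collect=80,
    repeat_per_collect=1,
    num_train_envs=16,
    num_test_envs=10,
    buffer_size=4096,
    batch_size=None,  # None is passed through to the trainer (commonly: use the full buffer as one batch)
)
```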
test/discrete/test_iqn.py | 3 +- test/discrete/test_pg.py | 6 +- test/discrete/test_ppo.py | 2 +- test/discrete/test_qrdqn.py | 7 +- test/discrete/test_rainbow.py | 7 +- test/discrete/test_sac.py | 7 +- test/modelbased/test_dqn_icm.py | 7 +- test/modelbased/test_ppo_icm.py | 2 +- test/modelbased/test_psrl.py | 5 +- test/offline/gather_cartpole_data.py | 7 +- test/offline/gather_pendulum_data.py | 6 +- test/offline/test_bcq.py | 4 +- test/offline/test_cql.py | 11 +- test/offline/test_discrete_bcq.py | 5 +- test/offline/test_discrete_cql.py | 9 +- test/offline/test_discrete_crr.py | 9 +- test/offline/test_gail.py | 15 +- test/offline/test_td3_bc.py | 8 +- test/pettingzoo/pistonball.py | 17 +- test/pettingzoo/pistonball_continuous.py | 27 +- test/pettingzoo/tic_tac_toe.py | 16 +- tianshou/data/batch.py | 4 +- tianshou/highlevel/config.py | 7 +- tianshou/highlevel/module/actor.py | 14 +- tianshou/highlevel/module/critic.py | 6 +- tianshou/utils/net/common.py | 48 +-- 106 files changed, 1265 insertions(+), 903 deletions(-) delete mode 100644 examples/common.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 54b9a9794..aa00e7474 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -41,7 +41,7 @@ repos: pass_filenames: false - id: mypy name: mypy - entry: poetry run mypy tianshou + entry: poetry run mypy tianshou examples test # filenames should not be passed as they would collide with the config in pyproject.toml pass_filenames: false files: '^tianshou(/[^/]*)*/[^/]*\.py$' diff --git a/CHANGELOG.md b/CHANGELOG.md index 52a91b34a..126f81a89 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ - `Collector`s can now be closed, and their reset is more granular. #1063 - Trainers can control whether collectors should be reset prior to training. #1063 - Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063 +- `SamplingConfig` supports `batch_size=None`. #1077 ### Internal Improvements - `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063 @@ -20,6 +21,8 @@ instead of just `nn.Module`. #1032 - Simplified `PGPolicy` forward by unifying the `dist_fn` interface (see associated breaking change). #1032 - Use `.mode` of distribution instead of relying on knowledge of the distribution type. #1032 - Exception no longer raised on `len` of empty `Batch`. #1084 +- tests and examples are covered by `mypy`. #1077 +- `NetBase` is more used, stricter typing by making it generic. #1077 ### Breaking Changes @@ -30,10 +33,10 @@ expicitly or pass `reset_before_collect=True` . #1063 - Fixed `iter(Batch(...)` which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063 - Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both continuous and discrete cases. #1032 +- `utils.net.common.Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. #1077 ### Tests - Fixed env seeding it test_sac_with_il.py so that the test doesn't fail randomly. 
#1081 Started after v1.0.0 - diff --git a/docs/02_notebooks/L0_overview.ipynb b/docs/02_notebooks/L0_overview.ipynb index 37cba0be5..59d6fd207 100644 --- a/docs/02_notebooks/L0_overview.ipynb +++ b/docs/02_notebooks/L0_overview.ipynb @@ -17,14 +17,14 @@ }, { "cell_type": "code", - "outputs": [], - "source": [ - "# !pip install tianshou gym" - ], + "execution_count": null, "metadata": { "collapsed": false }, - "execution_count": 0 + "outputs": [], + "source": [ + "# !pip install tianshou gym" + ] }, { "cell_type": "markdown", @@ -71,7 +71,7 @@ "\n", "from tianshou.data import Collector, VectorReplayBuffer\n", "from tianshou.env import DummyVectorEnv\n", - "from tianshou.policy import BasePolicy, PPOPolicy\n", + "from tianshou.policy import PPOPolicy\n", "from tianshou.trainer import OnpolicyTrainer\n", "from tianshou.utils.net.common import ActorCritic, Net\n", "from tianshou.utils.net.discrete import Actor, Critic\n", @@ -106,8 +106,7 @@ "\n", "# PPO policy\n", "dist = torch.distributions.Categorical\n", - "policy: BasePolicy\n", - "policy = PPOPolicy(\n", + "policy: PPOPolicy = PPOPolicy(\n", " actor=actor,\n", " critic=critic,\n", " optim=optim,\n", diff --git a/docs/02_notebooks/L5_Collector.ipynb b/docs/02_notebooks/L5_Collector.ipynb index 3e91e0f43..7da98a5cf 100644 --- a/docs/02_notebooks/L5_Collector.ipynb +++ b/docs/02_notebooks/L5_Collector.ipynb @@ -60,7 +60,7 @@ "\n", "from tianshou.data import Collector, VectorReplayBuffer\n", "from tianshou.env import DummyVectorEnv\n", - "from tianshou.policy import BasePolicy, PGPolicy\n", + "from tianshou.policy import PGPolicy\n", "from tianshou.utils.net.common import Net\n", "from tianshou.utils.net.discrete import Actor" ] @@ -87,8 +87,7 @@ "actor = Actor(net, env.action_space.n)\n", "optim = torch.optim.Adam(actor.parameters(), lr=0.0003)\n", "\n", - "policy: BasePolicy\n", - "policy = PGPolicy(\n", + "policy: PGPolicy = PGPolicy(\n", " actor=actor,\n", " optim=optim,\n", " dist_fn=torch.distributions.Categorical,\n", diff --git a/docs/02_notebooks/L6_Trainer.ipynb b/docs/02_notebooks/L6_Trainer.ipynb index d0f1ebf4d..75aea471c 100644 --- a/docs/02_notebooks/L6_Trainer.ipynb +++ b/docs/02_notebooks/L6_Trainer.ipynb @@ -75,7 +75,7 @@ "\n", "from tianshou.data import Collector, VectorReplayBuffer\n", "from tianshou.env import DummyVectorEnv\n", - "from tianshou.policy import BasePolicy, PGPolicy\n", + "from tianshou.policy import PGPolicy\n", "from tianshou.trainer import OnpolicyTrainer\n", "from tianshou.utils.net.common import Net\n", "from tianshou.utils.net.discrete import Actor" @@ -110,9 +110,8 @@ "actor = Actor(net, env.action_space.n)\n", "optim = torch.optim.Adam(actor.parameters(), lr=0.001)\n", "\n", - "policy: BasePolicy\n", "# We choose to use REINFORCE algorithm, also known as Policy Gradient\n", - "policy = PGPolicy(\n", + "policy: PGPolicy = PGPolicy(\n", " actor=actor,\n", " optim=optim,\n", " dist_fn=torch.distributions.Categorical,\n", diff --git a/docs/02_notebooks/L7_Experiment.ipynb b/docs/02_notebooks/L7_Experiment.ipynb index 46c065c75..9a97b20cb 100644 --- a/docs/02_notebooks/L7_Experiment.ipynb +++ b/docs/02_notebooks/L7_Experiment.ipynb @@ -73,7 +73,7 @@ "\n", "from tianshou.data import Collector, VectorReplayBuffer\n", "from tianshou.env import DummyVectorEnv\n", - "from tianshou.policy import BasePolicy, PPOPolicy\n", + "from tianshou.policy import PPOPolicy\n", "from tianshou.trainer import OnpolicyTrainer\n", "from tianshou.utils.net.common import ActorCritic, Net\n", "from tianshou.utils.net.discrete 
import Actor, Critic\n", @@ -164,8 +164,7 @@ "outputs": [], "source": [ "dist = torch.distributions.Categorical\n", - "policy: BasePolicy\n", - "policy = PPOPolicy(\n", + "policy: PPOPolicy = PPOPolicy(\n", " actor=actor,\n", " critic=critic,\n", " optim=optim,\n", diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 1946cc790..c6fe6dd04 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -9,8 +9,8 @@ from atari_network import C51 from atari_wrapper import make_atari_env -from examples.common import logger_factory from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import C51Policy from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer @@ -122,6 +122,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project @@ -182,8 +183,7 @@ def watch() -> None: print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result.returns_stat.mean - print(f"Mean reward (over {result['n/ep']} episodes): {rew}") + result.pprint_asdict() if args.watch: watch() diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index c669fa714..765463cd3 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -9,8 +9,8 @@ from atari_network import DQN from atari_wrapper import make_atari_env -from examples.common import logger_factory from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DQNPolicy from tianshou.policy.base import BasePolicy from tianshou.policy.modelbased.icm import ICMPolicy @@ -104,7 +104,8 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: net = DQN(*args.state_shape, args.action_shape, args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy - policy: DQNPolicy = DQNPolicy( + policy: DQNPolicy | ICMPolicy + policy = DQNPolicy( model=net, optim=optim, action_space=env.action_space, @@ -157,6 +158,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project @@ -223,8 +225,7 @@ def watch() -> None: print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result.returns_stat.mean - print(f"Mean reward (over {result['n/ep']} episodes): {rew}") + result.pprint_asdict() if args.watch: watch() diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 31adf9efd..f616a6838 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -9,8 +9,8 @@ from atari_network import DQN from atari_wrapper import make_atari_env -from examples.common import logger_factory from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import FQFPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer @@ -135,6 +135,7 @@ def test_fqf(args: 
argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project @@ -195,8 +196,7 @@ def watch() -> None: print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result.returns_stat.mean - print(f"Mean reward (over {result['n/ep']} episodes): {rew}") + result.pprint_asdict() if args.watch: watch() diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index d9832b9ea..911069400 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -9,8 +9,8 @@ from atari_network import DQN from atari_wrapper import make_atari_env -from examples.common import logger_factory from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import IQNPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer @@ -132,6 +132,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project @@ -192,8 +193,7 @@ def watch() -> None: print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result.returns_stat.mean - print(f"Mean reward (over {result['n/ep']} episodes): {rew}") + result.pprint_asdict() if args.watch: watch() diff --git a/examples/atari/atari_network.py b/examples/atari/atari_network.py index 2b8288a7c..ea900e975 100644 --- a/examples/atari/atari_network.py +++ b/examples/atari/atari_network.py @@ -14,6 +14,7 @@ IntermediateModule, IntermediateModuleFactory, ) +from tianshou.utils.net.common import NetBase from tianshou.utils.net.discrete import Actor, NoisyLinear @@ -24,7 +25,7 @@ def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0. class ScaledObsInputModule(torch.nn.Module): - def __init__(self, module: torch.nn.Module, denom: float = 255.0) -> None: + def __init__(self, module: NetBase, denom: float = 255.0) -> None: super().__init__() self.module = module self.denom = denom @@ -42,11 +43,11 @@ def forward( return self.module.forward(obs / self.denom, state, info) -def scale_obs(module: nn.Module, denom: float = 255.0) -> nn.Module: +def scale_obs(module: NetBase, denom: float = 255.0) -> ScaledObsInputModule: return ScaledObsInputModule(module, denom=denom) -class DQN(nn.Module): +class DQN(NetBase[Any]): """Reference: Human-level control through deep reinforcement learning. 
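The `DQN` class whose diff begins here now subclasses `NetBase[Any]` and, as the hunks further down show, replaces the `output_dim` argument with `output_dim_added_layer`. A sketch of constructing it as a feature extractor after this change; the frame shape and action count are placeholders, and the import path assumes the repository root is on `sys.path`, as in the high-level examples:

```python
import torch

from examples.atari.atari_network import DQN, layer_init

c, h, w = 4, 84, 84  # stacked Atari frames, placeholder shape
feature_net = DQN(
    c=c,
    h=h,
    w=w,
    action_shape=6,               # placeholder; unused when only features are extracted
    device="cpu",
    features_only=True,           # combined with output_dim_added_layer, as in atari_ppo.py / atari_sac.py
    output_dim_added_layer=512,   # replaces the former output_dim argument
    layer_init=layer_init,
)
features, state = feature_net(torch.zeros(1, c, h, w))
print(features.shape)  # torch.Size([1, 512])
```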
For advanced usage (how to customize the network), please refer to @@ -58,12 +59,17 @@ def __init__( c: int, h: int, w: int, - action_shape: Sequence[int], + action_shape: Sequence[int] | int, device: str | int | torch.device = "cpu", features_only: bool = False, - output_dim: int | None = None, + output_dim_added_layer: int | None = None, layer_init: Callable[[nn.Module], nn.Module] = lambda x: x, ) -> None: + # TODO: Add docstring + if features_only and output_dim_added_layer is not None: + raise ValueError( + "Should not provide explicit output dimension using `output_dim_added_layer` when `features_only` is true.", + ) super().__init__() self.device = device self.net = nn.Sequential( @@ -76,32 +82,33 @@ def __init__( nn.Flatten(), ) with torch.no_grad(): - self.output_dim = int(np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:])) + base_cnn_output_dim = int(np.prod(self.net(torch.zeros(1, c, h, w)).shape[1:])) if not features_only: + action_dim = int(np.prod(action_shape)) self.net = nn.Sequential( self.net, - layer_init(nn.Linear(self.output_dim, 512)), + layer_init(nn.Linear(base_cnn_output_dim, 512)), nn.ReLU(inplace=True), - layer_init(nn.Linear(512, int(np.prod(action_shape)))), + layer_init(nn.Linear(512, action_dim)), ) - self.output_dim = np.prod(action_shape) - elif output_dim is not None: + self.output_dim = action_dim + elif output_dim_added_layer is not None: self.net = nn.Sequential( self.net, - layer_init(nn.Linear(self.output_dim, output_dim)), + layer_init(nn.Linear(base_cnn_output_dim, output_dim_added_layer)), nn.ReLU(inplace=True), ) - self.output_dim = output_dim + else: + self.output_dim = base_cnn_output_dim def forward( self, obs: np.ndarray | torch.Tensor, state: Any | None = None, info: dict[str, Any] | None = None, + **kwargs: Any, ) -> tuple[torch.Tensor, Any]: r"""Mapping: s -> Q(s, \*).""" - if info is None: - info = {} obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32) return self.net(obs), state @@ -122,7 +129,7 @@ def __init__( num_atoms: int = 51, device: str | int | torch.device = "cpu", ) -> None: - self.action_num = np.prod(action_shape) + self.action_num = int(np.prod(action_shape)) super().__init__(c, h, w, [self.action_num * num_atoms], device) self.num_atoms = num_atoms @@ -131,10 +138,9 @@ def forward( obs: np.ndarray | torch.Tensor, state: Any | None = None, info: dict[str, Any] | None = None, + **kwargs: Any, ) -> tuple[torch.Tensor, Any]: r"""Mapping: x -> Z(x, \*).""" - if info is None: - info = {} obs, state = super().forward(obs) obs = obs.view(-1, self.num_atoms).softmax(dim=-1) obs = obs.view(-1, self.action_num, self.num_atoms) @@ -161,10 +167,10 @@ def __init__( is_noisy: bool = True, ) -> None: super().__init__(c, h, w, action_shape, device, features_only=True) - self.action_num = np.prod(action_shape) + self.action_num = int(np.prod(action_shape)) self.num_atoms = num_atoms - def linear(x, y): + def linear(x: int, y: int) -> NoisyLinear | nn.Linear: if is_noisy: return NoisyLinear(x, y, noisy_std) return nn.Linear(x, y) @@ -188,10 +194,9 @@ def forward( obs: np.ndarray | torch.Tensor, state: Any | None = None, info: dict[str, Any] | None = None, + **kwargs: Any, ) -> tuple[torch.Tensor, Any]: r"""Mapping: x -> Z(x, \*).""" - if info is None: - info = {} obs, state = super().forward(obs) q = self.Q(obs) q = q.view(-1, self.action_num, self.num_atoms) @@ -214,14 +219,15 @@ class QRDQN(DQN): def __init__( self, + *, c: int, h: int, w: int, - action_shape: Sequence[int], + action_shape: Sequence[int] | int, 
num_quantiles: int = 200, device: str | int | torch.device = "cpu", ) -> None: - self.action_num = np.prod(action_shape) + self.action_num = int(np.prod(action_shape)) super().__init__(c, h, w, [self.action_num * num_quantiles], device) self.num_quantiles = num_quantiles @@ -230,10 +236,9 @@ def forward( obs: np.ndarray | torch.Tensor, state: Any | None = None, info: dict[str, Any] | None = None, + **kwargs: Any, ) -> tuple[torch.Tensor, Any]: r"""Mapping: x -> Z(x, \*).""" - if info is None: - info = {} obs, state = super().forward(obs) obs = obs.view(-1, self.action_num, self.num_quantiles) return obs, state @@ -242,21 +247,28 @@ def forward( class ActorFactoryAtariDQN(ActorFactory): def __init__( self, - hidden_size: int | Sequence[int], - scale_obs: bool, - features_only: bool, + scale_obs: bool = True, + features_only: bool = False, + output_dim_added_layer: int | None = None, ) -> None: - self.hidden_size = hidden_size + self.output_dim_added_layer = output_dim_added_layer self.scale_obs = scale_obs self.features_only = features_only def create_module(self, envs: Environments, device: TDevice) -> Actor: + c, h, w = envs.get_observation_shape() # type: ignore # only right shape is a sequence of length 3 + action_shape = envs.get_action_shape() + if isinstance(action_shape, np.int64): + action_shape = int(action_shape) + net: DQN | ScaledObsInputModule net = DQN( - *envs.get_observation_shape(), - envs.get_action_shape(), + c=c, + h=h, + w=w, + action_shape=action_shape, device=device, features_only=self.features_only, - output_dim=self.hidden_size, + output_dim_added_layer=self.output_dim_added_layer, layer_init=layer_init, ) if self.scale_obs: @@ -270,9 +282,19 @@ def __init__(self, features_only: bool = False, net_only: bool = False) -> None: self.net_only = net_only def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule: + obs_shape = envs.get_observation_shape() + if isinstance(obs_shape, int): + obs_shape = [obs_shape] + assert len(obs_shape) == 3 + c, h, w = obs_shape + action_shape = envs.get_action_shape() + if isinstance(action_shape, np.int64): + action_shape = int(action_shape) dqn = DQN( - *envs.get_observation_shape(), - envs.get_action_shape(), + c=c, + h=h, + w=w, + action_shape=action_shape, device=device, features_only=self.features_only, ).to(device) diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 86f54d4d7..612b54008 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -8,11 +8,11 @@ import torch from atari_network import DQN, layer_init, scale_obs from atari_wrapper import make_atari_env -from torch.distributions import Categorical, Distribution +from torch.distributions import Categorical from torch.optim.lr_scheduler import LambdaLR -from examples.common import logger_factory from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import ICMPolicy, PPOPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OnpolicyTrainer @@ -115,7 +115,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: args.action_shape, device=args.device, features_only=True, - output_dim=args.hidden_size, + output_dim_added_layer=args.hidden_size, layer_init=layer_init, ) if args.scale_obs: @@ -131,15 +131,11 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) - # define policy - def 
dist(logits: torch.Tensor) -> Distribution: - return Categorical(logits=logits) - policy: PPOPolicy = PPOPolicy( actor=actor, critic=critic, optim=optim, - dist_fn=dist, + dist_fn=Categorical, discount_factor=args.gamma, gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm, @@ -167,7 +163,7 @@ def dist(logits: torch.Tensor) -> Distribution: device=args.device, ) icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) - policy = ICMPolicy( + policy: ICMPolicy = ICMPolicy( # type: ignore[no-redef] policy=policy, model=icm_net, optim=icm_optim, @@ -200,6 +196,7 @@ def dist(logits: torch.Tensor) -> Distribution: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project @@ -252,8 +249,7 @@ def watch() -> None: print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result.returns_stat.mean - print(f"Mean reward (over {result['n/ep']} episodes): {rew}") + result.pprint_asdict() if args.watch: watch() diff --git a/examples/atari/atari_ppo_hl.py b/examples/atari/atari_ppo_hl.py index b492b9c84..53393b05a 100644 --- a/examples/atari/atari_ppo_hl.py +++ b/examples/atari/atari_ppo_hl.py @@ -34,7 +34,7 @@ def main( step_per_collect: int = 1000, repeat_per_collect: int = 4, batch_size: int = 256, - hidden_sizes: int | Sequence[int] = 512, + hidden_sizes: Sequence[int] = (512,), training_num: int = 10, test_num: int = 10, rew_norm: bool = False, @@ -93,7 +93,7 @@ def main( else None, ), ) - .with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs, features_only=True)) + .with_actor_factory(ActorFactoryAtariDQN(scale_obs=scale_obs, features_only=True)) .with_critic_factory_use_actor() .with_epoch_stop_callback(AtariEpochStopCallback(task)) ) @@ -101,7 +101,7 @@ def main( builder.with_policy_wrapper_factory( PolicyWrapperFactoryIntrinsicCuriosity( feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(), - hidden_sizes=[hidden_sizes], + hidden_sizes=hidden_sizes, lr=lr, lr_scale=icm_lr_scale, reward_scale=icm_reward_scale, diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index cef1c4247..7d6330ee1 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -9,8 +9,8 @@ from atari_network import QRDQN from atari_wrapper import make_atari_env -from examples.common import logger_factory from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import QRDQNPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer @@ -82,7 +82,15 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # define model - net = QRDQN(*args.state_shape, args.action_shape, args.num_quantiles, args.device) + c, h, w = args.state_shape + net = QRDQN( + c=c, + h=h, + w=w, + action_shape=args.action_shape, + num_quantiles=args.num_quantiles, + device=args.device, + ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy policy: QRDQNPolicy = QRDQNPolicy( @@ -118,6 +126,7 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = 
args.wandb_project @@ -178,8 +187,7 @@ def watch() -> None: print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result.returns_stat.mean - print(f"Mean reward (over {result['n/ep']} episodes): {rew}") + result.pprint_asdict() if args.watch: watch() diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index dbea00688..86e7fe0e1 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -9,8 +9,8 @@ from atari_network import Rainbow from atari_wrapper import make_atari_env -from examples.common import logger_factory from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import C51Policy, RainbowPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer @@ -152,6 +152,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project @@ -222,8 +223,7 @@ def watch() -> None: print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result.returns_stat.mean - print(f"Mean reward (over {result['n/ep']} episodes): {rew}") + result.pprint_asdict() if args.watch: watch() diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index 7dc60c0e8..d5edf1a9a 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -9,8 +9,8 @@ from atari_network import DQN from atari_wrapper import make_atari_env -from examples.common import logger_factory from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DiscreteSACPolicy, ICMPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer @@ -108,7 +108,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: args.action_shape, device=args.device, features_only=True, - output_dim=args.hidden_size, + output_dim_added_layer=args.hidden_size, ) actor = Actor(net, args.action_shape, device=args.device, softmax_output=False) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) @@ -124,7 +124,8 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = (target_entropy, log_alpha, alpha_optim) - policy: DiscreteSACPolicy = DiscreteSACPolicy( + policy: DiscreteSACPolicy | ICMPolicy + policy = DiscreteSACPolicy( actor=actor, actor_optim=actor_optim, critic=critic1, @@ -182,6 +183,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project @@ -234,8 +236,7 @@ def watch() -> None: print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result.returns_stat.mean - print(f"Mean reward (over {result['n/ep']} episodes): {rew}") + result.pprint_asdict() if args.watch: watch() diff --git a/examples/atari/atari_sac_hl.py 
b/examples/atari/atari_sac_hl.py index 127156777..dd49f49a7 100644 --- a/examples/atari/atari_sac_hl.py +++ b/examples/atari/atari_sac_hl.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import os +from collections.abc import Sequence from examples.atari.atari_network import ( ActorFactoryAtariDQN, @@ -39,7 +40,7 @@ def main( step_per_collect: int = 10, update_per_step: float = 0.1, batch_size: int = 64, - hidden_size: int = 512, + hidden_sizes: Sequence[int] = (512,), training_num: int = 10, test_num: int = 10, frames_stack: int = 4, @@ -80,7 +81,7 @@ def main( estimation_step=n_step, ), ) - .with_actor_factory(ActorFactoryAtariDQN(hidden_size, scale_obs=False, features_only=True)) + .with_actor_factory(ActorFactoryAtariDQN(scale_obs=False, features_only=True)) .with_common_critic_factory_use_actor() .with_epoch_stop_callback(AtariEpochStopCallback(task)) ) @@ -88,7 +89,7 @@ def main( builder.with_policy_wrapper_factory( PolicyWrapperFactoryIntrinsicCuriosity( feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(), - hidden_sizes=[hidden_size], + hidden_sizes=hidden_sizes, lr=actor_lr, lr_scale=icm_lr_scale, reward_scale=icm_reward_scale, diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index a2fdcca1f..db1b6b2c3 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -3,12 +3,14 @@ import logging import warnings from collections import deque +from typing import Any, SupportsFloat import cv2 import gymnasium as gym import numpy as np from gymnasium import Env +from tianshou.env import BaseVectorEnv from tianshou.highlevel.env import ( EnvFactoryRegistered, EnvMode, @@ -26,7 +28,7 @@ log = logging.getLogger(__name__) -def _parse_reset_result(reset_result): +def _parse_reset_result(reset_result: tuple) -> tuple[tuple, dict, bool]: contains_info = ( isinstance(reset_result, tuple) and len(reset_result) == 2 @@ -46,22 +48,20 @@ class NoopResetEnv(gym.Wrapper): :param int noop_max: the maximum value of no-ops to run. """ - def __init__(self, env, noop_max=30) -> None: + def __init__(self, env: gym.Env, noop_max: int = 30) -> None: super().__init__(env) self.noop_max = noop_max self.noop_action = 0 + assert hasattr(env.unwrapped, "get_action_meanings") assert env.unwrapped.get_action_meanings()[0] == "NOOP" - def reset(self, **kwargs): + def reset(self, **kwargs: Any) -> tuple[Any, dict[str, Any]]: _, info, return_info = _parse_reset_result(self.env.reset(**kwargs)) - if hasattr(self.unwrapped.np_random, "integers"): - noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) - else: - noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) + noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) for _ in range(noops): step_result = self.env.step(self.noop_action) if len(step_result) == 4: - obs, rew, done, info = step_result + obs, rew, done, info = step_result # type: ignore[unreachable] # mypy doesn't know that Gym version <0.26 has only 4 items (no truncation) else: obs, rew, term, trunc, info = step_result done = term or trunc @@ -69,7 +69,7 @@ def reset(self, **kwargs): obs, info, _ = _parse_reset_result(self.env.reset()) if return_info: return obs, info - return obs + return obs, {} class MaxAndSkipEnv(gym.Wrapper): @@ -79,34 +79,35 @@ class MaxAndSkipEnv(gym.Wrapper): :param int skip: number of `skip`-th frame. 
""" - def __init__(self, env, skip=4) -> None: + def __init__(self, env: gym.Env, skip: int = 4) -> None: super().__init__(env) self._skip = skip - def step(self, action): + def step(self, action: Any) -> tuple[Any, float, bool, bool, dict[str, Any]]: """Step the environment with the given action. Repeat action, sum reward, and max over last observations. """ - obs_list, total_reward = [], 0.0 + obs_list = [] + total_reward = 0.0 new_step_api = False for _ in range(self._skip): step_result = self.env.step(action) if len(step_result) == 4: - obs, reward, done, info = step_result + obs, reward, done, info = step_result # type: ignore[unreachable] # mypy doesn't know that Gym version <0.26 has only 4 items (no truncation) else: obs, reward, term, trunc, info = step_result done = term or trunc new_step_api = True obs_list.append(obs) - total_reward += reward + total_reward += float(reward) if done: break max_frame = np.max(obs_list[-2:], axis=0) if new_step_api: return max_frame, total_reward, term, trunc, info - return max_frame, total_reward, done, info + return max_frame, total_reward, done, info.get("TimeLimit.truncated", False), info class EpisodicLifeEnv(gym.Wrapper): @@ -117,25 +118,26 @@ class EpisodicLifeEnv(gym.Wrapper): :param gym.Env env: the environment to wrap. """ - def __init__(self, env) -> None: + def __init__(self, env: gym.Env) -> None: super().__init__(env) self.lives = 0 self.was_real_done = True self._return_info = False - def step(self, action): + def step(self, action: Any) -> tuple[Any, float, bool, bool, dict[str, Any]]: step_result = self.env.step(action) if len(step_result) == 4: - obs, reward, done, info = step_result + obs, reward, done, info = step_result # type: ignore[unreachable] # mypy doesn't know that Gym version <0.26 has only 4 items (no truncation) new_step_api = False else: obs, reward, term, trunc, info = step_result done = term or trunc new_step_api = True - + reward = float(reward) self.was_real_done = done # check current lives, make loss of life terminal, then update lives to # handle bonus lives + assert hasattr(self.env.unwrapped, "ale") lives = self.env.unwrapped.ale.lives() if 0 < lives < self.lives: # for Qbert sometimes we stay in lives == 0 condition for a few @@ -146,9 +148,9 @@ def step(self, action): self.lives = lives if new_step_api: return obs, reward, term, trunc, info - return obs, reward, done, info + return obs, reward, done, info.get("TimeLimit.truncated", False), info - def reset(self, **kwargs): + def reset(self, **kwargs: Any) -> tuple[Any, dict[str, Any]]: """Calls the Gym environment reset, only when lives are exhausted. This way all states are still reachable even though lives are episodic, and @@ -160,10 +162,11 @@ def reset(self, **kwargs): # no-op step to advance from terminal/lost life state step_result = self.env.step(0) obs, info = step_result[0], step_result[-1] + assert hasattr(self.env.unwrapped, "ale") self.lives = self.env.unwrapped.ale.lives() if self._return_info: return obs, info - return obs + return obs, {} class FireResetEnv(gym.Wrapper): @@ -174,15 +177,16 @@ class FireResetEnv(gym.Wrapper): :param gym.Env env: the environment to wrap. 
""" - def __init__(self, env) -> None: + def __init__(self, env: gym.Env) -> None: super().__init__(env) + assert hasattr(env.unwrapped, "get_action_meanings") assert env.unwrapped.get_action_meanings()[1] == "FIRE" assert len(env.unwrapped.get_action_meanings()) >= 3 - def reset(self, **kwargs): + def reset(self, **kwargs: Any) -> tuple[Any, dict]: _, _, return_info = _parse_reset_result(self.env.reset(**kwargs)) obs = self.env.step(1)[0] - return (obs, {}) if return_info else obs + return obs, {} class WarpFrame(gym.ObservationWrapper): @@ -191,17 +195,24 @@ class WarpFrame(gym.ObservationWrapper): :param gym.Env env: the environment to wrap. """ - def __init__(self, env) -> None: + def __init__(self, env: gym.Env) -> None: super().__init__(env) self.size = 84 + obs_space = env.observation_space + obs_space_dtype: type[np.floating[Any]] | type[np.integer[Any]] + if np.issubdtype(type(obs_space.dtype), np.integer): + obs_space_dtype = np.integer + elif np.issubdtype(type(obs_space.dtype), np.floating): + obs_space_dtype = np.floating + assert isinstance(obs_space, gym.spaces.Box) self.observation_space = gym.spaces.Box( - low=np.min(env.observation_space.low), - high=np.max(env.observation_space.high), + low=np.min(obs_space.low), + high=np.max(obs_space.high), shape=(self.size, self.size), - dtype=env.observation_space.dtype, + dtype=obs_space_dtype, ) - def observation(self, frame): + def observation(self, frame: np.ndarray) -> np.ndarray: """Returns the current observation from a frame.""" frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) return cv2.resize(frame, (self.size, self.size), interpolation=cv2.INTER_AREA) @@ -213,20 +224,22 @@ class ScaledFloatFrame(gym.ObservationWrapper): :param gym.Env env: the environment to wrap. """ - def __init__(self, env) -> None: + def __init__(self, env: gym.Env) -> None: super().__init__(env) - low = np.min(env.observation_space.low) - high = np.max(env.observation_space.high) + obs_space = env.observation_space + assert isinstance(obs_space, gym.spaces.Box) + low = np.min(obs_space.low) + high = np.max(obs_space.high) self.bias = low self.scale = high - low self.observation_space = gym.spaces.Box( low=0.0, high=1.0, - shape=env.observation_space.shape, + shape=obs_space.shape, dtype=np.float32, ) - def observation(self, observation): + def observation(self, observation: np.ndarray) -> np.ndarray: return (observation - self.bias) / self.scale @@ -236,13 +249,13 @@ class ClipRewardEnv(gym.RewardWrapper): :param gym.Env env: the environment to wrap. """ - def __init__(self, env) -> None: + def __init__(self, env: gym.Env) -> None: super().__init__(env) self.reward_range = (-1, 1) - def reward(self, reward): + def reward(self, reward: SupportsFloat) -> int: """Bin reward to {+1, 0, -1} by its sign. Note: np.sign(0) == 0.""" - return np.sign(reward) + return np.sign(float(reward)) class FrameStack(gym.Wrapper): @@ -252,50 +265,69 @@ class FrameStack(gym.Wrapper): :param int n_frames: the number of frames to stack. 
""" - def __init__(self, env, n_frames) -> None: + def __init__(self, env: gym.Env, n_frames: int) -> None: super().__init__(env) - self.n_frames = n_frames - self.frames = deque([], maxlen=n_frames) - shape = (n_frames, *env.observation_space.shape) + self.n_frames: int = n_frames + self.frames: deque[tuple[Any, ...]] = deque([], maxlen=n_frames) + obs_space = env.observation_space + obs_space_shape = env.observation_space.shape + assert obs_space_shape is not None + shape = (n_frames, *obs_space_shape) + assert isinstance(env.observation_space, gym.spaces.Box) + obs_space_dtype: type[np.floating[Any]] | type[np.integer[Any]] + if np.issubdtype(type(obs_space.dtype), np.integer): + obs_space_dtype = np.integer + elif np.issubdtype(type(obs_space.dtype), np.floating): + obs_space_dtype = np.floating self.observation_space = gym.spaces.Box( low=np.min(env.observation_space.low), high=np.max(env.observation_space.high), shape=shape, - dtype=env.observation_space.dtype, + dtype=obs_space_dtype, ) - def reset(self, **kwargs): + def reset(self, **kwargs: Any) -> tuple[np.ndarray, dict]: obs, info, return_info = _parse_reset_result(self.env.reset(**kwargs)) for _ in range(self.n_frames): self.frames.append(obs) - return (self._get_ob(), info) if return_info else self._get_ob() + return (self._get_ob(), info) if return_info else (self._get_ob(), {}) - def step(self, action): + def step(self, action: Any) -> tuple[np.ndarray, float, bool, bool, dict[str, Any]]: step_result = self.env.step(action) + done: bool if len(step_result) == 4: - obs, reward, done, info = step_result + obs, reward, done, info = step_result # type: ignore[unreachable] # mypy doesn't know that Gym version <0.26 has only 4 items (no truncation) new_step_api = False else: obs, reward, term, trunc, info = step_result new_step_api = True self.frames.append(obs) + reward = float(reward) if new_step_api: return self._get_ob(), reward, term, trunc, info - return self._get_ob(), reward, done, info + return self._get_ob(), reward, done, info.get("TimeLimit.truncated", False), info - def _get_ob(self): + def _get_ob(self) -> np.ndarray: # the original wrapper use `LazyFrames` but since we use np buffer, # it has no effect return np.stack(self.frames, axis=0) def wrap_deepmind( - env: Env, - episode_life=True, - clip_rewards=True, - frame_stack=4, - scale=False, - warp_frame=True, + env: gym.Env, + episode_life: bool = True, + clip_rewards: bool = True, + frame_stack: int = 4, + scale: bool = False, + warp_frame: bool = True, +) -> ( + MaxAndSkipEnv + | EpisodicLifeEnv + | FireResetEnv + | WarpFrame + | ScaledFloatFrame + | ClipRewardEnv + | FrameStack ): """Configure environment for DeepMind-style Atari. 
@@ -311,29 +343,34 @@ def wrap_deepmind( """ env = NoopResetEnv(env, noop_max=30) env = MaxAndSkipEnv(env, skip=4) + assert hasattr(env.unwrapped, "get_action_meanings") # for mypy + + wrapped_env: MaxAndSkipEnv | EpisodicLifeEnv | FireResetEnv | WarpFrame | ScaledFloatFrame | ClipRewardEnv | FrameStack = ( + env + ) if episode_life: - env = EpisodicLifeEnv(env) + wrapped_env = EpisodicLifeEnv(wrapped_env) if "FIRE" in env.unwrapped.get_action_meanings(): - env = FireResetEnv(env) + wrapped_env = FireResetEnv(wrapped_env) if warp_frame: - env = WarpFrame(env) + wrapped_env = WarpFrame(wrapped_env) if scale: - env = ScaledFloatFrame(env) + wrapped_env = ScaledFloatFrame(wrapped_env) if clip_rewards: - env = ClipRewardEnv(env) + wrapped_env = ClipRewardEnv(wrapped_env) if frame_stack: - env = FrameStack(env, frame_stack) - return env + wrapped_env = FrameStack(wrapped_env, frame_stack) + return wrapped_env def make_atari_env( - task, - seed, - training_num, - test_num, + task: str, + seed: int, + training_num: int, + test_num: int, scale: int | bool = False, frame_stack: int = 4, -): +) -> tuple[Env, BaseVectorEnv, BaseVectorEnv]: """Wrapper function for Atari env. If EnvPool is installed, it will automatically switch to EnvPool's Atari env. @@ -360,7 +397,7 @@ def __init__( envpool_factory = None if use_envpool_if_available: if envpool_is_available: - envpool_factory = self.EnvPoolFactory(self) + envpool_factory = self.EnvPoolFactoryAtari(self) log.info("Using envpool, because it available") else: log.info("Not using envpool, because it is not available") @@ -371,7 +408,7 @@ def __init__( envpool_factory=envpool_factory, ) - def create_env(self, mode: EnvMode) -> Env: + def create_env(self, mode: EnvMode) -> gym.Env: env = super().create_env(mode) is_train = mode == EnvMode.TRAIN return wrap_deepmind( @@ -382,7 +419,7 @@ def create_env(self, mode: EnvMode) -> Env: scale=self.scale, ) - class EnvPoolFactory(EnvPoolFactory): + class EnvPoolFactoryAtari(EnvPoolFactory): """Atari-specific envpool creation. Since envpool internally handles the functions that are implemented through the wrappers in `wrap_deepmind`, it sets the creation keyword arguments accordingly. 
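A recurring fix in this file is narrowing types for `mypy` before touching attributes that only exist on specific spaces or on ALE environments: `assert isinstance(obs_space, gym.spaces.Box)` before reading `.low`/`.high`, and `assert hasattr(env.unwrapped, "ale")` or `"get_action_meanings"` before calling Atari-specific methods. A minimal standalone wrapper illustrating the same pattern (a toy example, not part of the patch; it assumes a bounded `Box` observation space):

```python
import gymnasium as gym
import numpy as np


class RescaleToUnitInterval(gym.ObservationWrapper):
    """Toy wrapper using the same mypy-friendly narrowing as WarpFrame/ScaledFloatFrame above."""

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)
        obs_space = env.observation_space
        # mypy only knows Space[Any] here, so assert the concrete type before
        # using Box-specific attributes such as .low, .high and .shape
        assert isinstance(obs_space, gym.spaces.Box)
        self.low = float(np.min(obs_space.low))
        self.high = float(np.max(obs_space.high))
        self.observation_space = gym.spaces.Box(
            low=0.0,
            high=1.0,
            shape=obs_space.shape,
            dtype=np.float32,
        )

    def observation(self, observation: np.ndarray) -> np.ndarray:
        # assumes finite bounds, i.e. high > low
        return ((observation - self.low) / (self.high - self.low)).astype(np.float32)
```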
@@ -416,7 +453,7 @@ def __init__(self, task: str) -> None: def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool: env = context.envs.env - if env.spec.reward_threshold: + if env.spec and env.spec.reward_threshold: return mean_rewards >= env.spec.reward_threshold if "Pong" in self.task: return mean_rewards >= 20 diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index ae5223a01..ad53b16da 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -1,7 +1,6 @@ import argparse import os import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -51,7 +50,7 @@ def get_args() -> argparse.Namespace: def test_dqn(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape @@ -69,8 +68,8 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: Q_param = {"hidden_sizes": args.dueling_q_hidden_sizes} V_param = {"hidden_sizes": args.dueling_v_hidden_sizes} net = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, dueling_param=(Q_param, V_param), diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index 16dfaf097..f52f6d5c1 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -35,7 +35,7 @@ def get_args() -> argparse.Namespace: parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--target-update-freq", type=int, default=1000) - parser.add_argument("--epoch", type=int, default=1000) + parser.add_argument("--epoch", type=int, default=25) parser.add_argument("--step-per-epoch", type=int, default=80000) parser.add_argument("--step-per-collect", type=int, default=16) parser.add_argument("--update-per-step", type=float, default=0.0625) @@ -57,11 +57,14 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) env = ContinuousToDiscrete(env, args.action_per_branch) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.num_branches = ( - args.action_shape if isinstance(args.action_shape, int) else args.action_shape[0] - ) + assert isinstance(env.action_space, gym.spaces.MultiDiscrete) + assert isinstance( + env.observation_space, + gym.spaces.Box, + ) # BipedalWalker-v3 has `Box` observation space by design + args.state_shape = env.observation_space.shape + args.action_shape = env.action_space.shape + args.num_branches = args.action_shape[0] print("Observations shape:", args.state_shape) print("Num branches:", args.num_branches) @@ -98,11 +101,11 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: device=args.device, ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy = BranchingDQNPolicy( + policy: BranchingDQNPolicy = BranchingDQNPolicy( model=net, optim=optim, discount_factor=args.gamma, - action_space=env.action_space, + action_space=env.action_space, # type: ignore[arg-type] # TODO: should `BranchingDQNPolicy` support also `MultiDiscrete` action spaces? 
target_update_freq=args.target_update_freq, ) # collector @@ -125,7 +128,9 @@ def save_best_fn(policy: BasePolicy) -> None: torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) def stop_fn(mean_rewards: float) -> bool: - return mean_rewards >= getattr(env.spec.reward_threshold) + if env.spec and env.spec.reward_threshold: + return mean_rewards >= env.spec.reward_threshold + return False def train_fn(epoch: int, env_step: int) -> None: # exp decay eps = max(args.eps_train * (1 - args.eps_decay) ** env_step, args.eps_test) @@ -145,14 +150,14 @@ def test_fn(epoch: int, env_step: int | None) -> None: episode_per_test=args.test_num, batch_size=args.batch_size, update_per_step=args.update_per_step, - # stop_fn=stop_fn, + stop_fn=stop_fn, train_fn=train_fn, test_fn=test_fn, save_best_fn=save_best_fn, logger=logger, ).run() - # assert stop_fn(result.best_reward) + assert stop_fn(result.best_reward) if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index e7186915b..2c071bc1c 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -108,13 +108,18 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model - net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = ActorProb(net_a, args.action_shape, device=args.device, unbounded=True).to(args.device) + net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + actor = ActorProb( + preprocess_net=net_a, + action_shape=args.action_shape, + device=args.device, + unbounded=True, + ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net_c1 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, @@ -123,8 +128,8 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) net_c2 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 46e98cc62..47ba9d102 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -14,6 +14,7 @@ from tianshou.trainer import OffpolicyTrainer from tianshou.utils import TensorboardLogger from tianshou.utils.net.common import Net +from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: @@ -50,9 +51,11 @@ def get_args() -> argparse.Namespace: def test_dqn(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] + assert isinstance(env.action_space, gym.spaces.Discrete) + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + args.max_action = space_info.action_info.max_action # train_envs = gym.make(args.task) # you can also use tianshou.env.SubprocVectorEnv train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in 
range(args.training_num)]) @@ -67,8 +70,8 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: Q_param = {"hidden_sizes": args.dueling_q_hidden_sizes} V_param = {"hidden_sizes": args.dueling_v_hidden_sizes} net = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, dueling_param=(Q_param, V_param), diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 858c70834..5c093093e 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -66,12 +66,12 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = ActorProb(net, args.action_shape, device=args.device, unbounded=True).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net_c1 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, @@ -79,8 +79,8 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: critic1 = Critic(net_c1, device=args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) net_c2 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, diff --git a/examples/common.py b/examples/common.py deleted file mode 100644 index a86115c2a..000000000 --- a/examples/common.py +++ /dev/null @@ -1,3 +0,0 @@ -from tianshou.highlevel.logger import LoggerFactoryDefault - -logger_factory = LoggerFactoryDefault() diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index 2e9697adb..4f1a82b12 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -3,6 +3,7 @@ from torch.utils.tensorboard import SummaryWriter import tianshou as ts +from tianshou.utils.space_info import SpaceInfo def main() -> None: @@ -26,8 +27,10 @@ def main() -> None: # Note: You can easily define other networks. 
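Several example scripts in this patch (for instance `lunarlander_dqn.py` just above and `discrete_dqn.py` here) replace the old `env.observation_space.shape or env.observation_space.n` idiom with the typed `SpaceInfo` helper. A short usage sketch on a stock Gymnasium task; CartPole is only an example:

```python
import gymnasium as gym

from tianshou.utils.space_info import SpaceInfo

env = gym.make("CartPole-v1")
# the assert narrows the action-space type for mypy, mirroring the diffs above
assert isinstance(env.action_space, gym.spaces.Discrete)
space_info = SpaceInfo.from_env(env)
state_shape = space_info.observation_info.obs_shape
action_shape = space_info.action_info.action_shape
print(state_shape, action_shape)  # e.g. (4,) and 2 for CartPole
```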
# See https://tianshou.readthedocs.io/en/master/01_tutorials/00_dqn.html#build-the-network env = gym.make(task, render_mode="human") - state_shape = env.observation_space.shape or env.observation_space.n - action_shape = env.action_space.shape or env.action_space.n + assert isinstance(env.action_space, gym.spaces.Discrete) + space_info = SpaceInfo.from_env(env) + state_shape = space_info.observation_info.obs_shape + action_shape = space_info.action_info.action_shape net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128]) optim = torch.optim.Adam(net.parameters(), lr=lr) diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 3ee3709bd..2d013a01b 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -120,7 +120,12 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: activation=nn.Tanh, device=args.device, ) - actor = ActorProb(net_a, args.action_shape, unbounded=True, device=args.device).to(args.device) + actor = ActorProb( + preprocess_net=net_a, + action_shape=args.action_shape, + unbounded=True, + device=args.device, + ).to(args.device) net_c = Net( args.state_shape, hidden_sizes=args.hidden_sizes, diff --git a/examples/mujoco/analysis.py b/examples/mujoco/analysis.py index 9c2983e04..b881cdd34 100755 --- a/examples/mujoco/analysis.py +++ b/examples/mujoco/analysis.py @@ -3,13 +3,14 @@ import argparse import re from collections import defaultdict +from os import PathLike import numpy as np from tabulate import tabulate from tools import csv2numpy, find_all_files, group_files -def numerical_analysis(root_dir, xlim, norm=False): +def numerical_analysis(root_dir: str | PathLike, xlim: float, norm: bool = False) -> None: file_pattern = re.compile(r".*/test_reward_\d+seeds.csv$") norm_group_pattern = re.compile(r"(/|^)\w+?\-v(\d|$)") output_group_pattern = re.compile(r".*?(?=(/|^)\w+?\-v\d)") diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index 0804b1a34..be6594aa7 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -9,7 +9,6 @@ import gymnasium as gym import numpy as np import torch -from torch.utils.tensorboard import SummaryWriter from tianshou.data import ( @@ -19,14 +18,16 @@ ReplayBuffer, VectorReplayBuffer, ) +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.env import ShmemVectorEnv, TruncatedAsTerminated from tianshou.exploration import GaussianNoise from tianshou.policy import DDPGPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer -from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils.net.common import Net, get_dict_state_decorator from tianshou.utils.net.continuous import Actor, Critic +from tianshou.env.venvs import BaseVectorEnv +from tianshou.utils.space_info import ActionSpaceInfo def get_args() -> argparse.Namespace: @@ -77,7 +78,11 @@ def get_args() -> argparse.Namespace: return parser.parse_args() -def make_fetch_env(task, training_num, test_num): +def make_fetch_env( + task: str, + training_num: int, + test_num: int, +) -> tuple[gym.Env, BaseVectorEnv, BaseVectorEnv]: env = TruncatedAsTerminated(gym.make(task)) train_envs = ShmemVectorEnv( [lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(training_num)], @@ -96,33 +101,44 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == 
"wandb": - logger = WandbLogger( - save_interval=1, - name=log_name.replace(os.path.sep, "__"), - run_id=args.resume_id, - config=args, - project=args.wandb_project, - ) - writer = SummaryWriter(log_path) - writer.add_text("args", str(args)) - if args.logger == "tensorboard": - logger = TensorboardLogger(writer) - else: # wandb - logger.load(writer) + logger_factory.logger_type = "wandb" + logger_factory.wandb_project = args.wandb_project + else: + logger_factory.logger_type = "tensorboard" + + logger = logger_factory.create_logger( + log_dir=log_path, + experiment_name=log_name, + run_id=args.resume_id, + config_dict=vars(args), + ) env, train_envs, test_envs = make_fetch_env(args.task, args.training_num, args.test_num) + # The method HER works with goal-based environments + if not isinstance(env.observation_space, gym.spaces.Dict): + raise ValueError( + "`env.observation_space` must be of type `gym.spaces.Dict`. Make sure you're using a goal-based environment like `FetchReach-v2`.", + ) + if not hasattr(env, "compute_reward"): + raise ValueError( + "Atrribute `compute_reward` not found in `env`. " + "HER-based algorithms typically require this attribute. Make sure you're using a goal-based environment like `FetchReach-v2`.", + ) args.state_shape = { "observation": env.observation_space["observation"].shape, "achieved_goal": env.observation_space["achieved_goal"].shape, "desired_goal": env.observation_space["desired_goal"].shape, } - args.action_shape = env.action_space.shape or env.action_space.n - args.max_action = env.action_space.high[0] + action_info = ActionSpaceInfo.from_space(env.action_space) + args.action_shape = action_info.action_shape + args.max_action = action_info.max_action + args.exploration_noise = args.exploration_noise * args.max_action print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) - print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high)) + print("Action range:", action_info.min_action, action_info.max_action) # seed np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -170,7 +186,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: print("Loaded agent from: ", args.resume_path) # collector - def compute_reward_fn(ag: np.ndarray, g: np.ndarray): + def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray: return env.compute_reward(ag, g, {}) buffer: VectorReplayBuffer | ReplayBuffer | HERReplayBuffer | HERVectorReplayBuffer @@ -225,7 +241,7 @@ def save_best_fn(policy: BasePolicy) -> None: test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) - print(collector_stats) + collector_stats.pprint_asdict() if __name__ == "__main__": diff --git a/examples/mujoco/gen_json.py b/examples/mujoco/gen_json.py index b41b06b15..54ee3bb90 100755 --- a/examples/mujoco/gen_json.py +++ b/examples/mujoco/gen_json.py @@ -4,9 +4,10 @@ import json import os import sys +from os import PathLike -def merge(rootdir): +def merge(rootdir: str | PathLike[str]) -> None: """format: $rootdir/$algo/*.csv.""" result = [] for path, _, filenames in os.walk(rootdir): diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index 6caac9898..ea6ab8f24 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -12,8 +12,8 @@ from torch.distributions import Distribution, Independent, Normal from torch.optim.lr_scheduler import LambdaLR -from examples.common import logger_factory from 
tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import A2CPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OnpolicyTrainer @@ -182,6 +182,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py index fec2e264f..187208757 100644 --- a/examples/mujoco/mujoco_a2c_hl.py +++ b/examples/mujoco/mujoco_a2c_hl.py @@ -30,7 +30,7 @@ def main( step_per_epoch: int = 30000, step_per_collect: int = 80, repeat_per_collect: int = 1, - batch_size: int | None = None, + batch_size: int = 16, training_num: int = 16, test_num: int = 10, rew_norm: bool = True, diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index 04fc0109b..ceac47604 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -9,9 +9,9 @@ import torch from mujoco_env import make_mujoco_env -from examples.common import logger_factory from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.exploration import GaussianNoise +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DDPGPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer @@ -83,14 +83,14 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net_a, args.action_shape, max_action=args.max_action, device=args.device).to( args.device, ) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net_c = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, @@ -115,6 +115,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: print("Loaded agent from: ", args.resume_path) # collector + buffer: VectorReplayBuffer | ReplayBuffer if args.training_num > 1: buffer = VectorReplayBuffer(args.buffer_size, len(train_envs)) else: @@ -130,6 +131,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py index dacf91548..fa972ac31 100644 --- a/examples/mujoco/mujoco_env.py +++ b/examples/mujoco/mujoco_env.py @@ -1,6 +1,8 @@ import logging import pickle +from gymnasium import Env + from tianshou.env import BaseVectorEnv, VectorEnvNormObs from tianshou.highlevel.env import ( ContinuousEnvironments, @@ -22,7 +24,13 @@ log = logging.getLogger(__name__) -def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool): +def make_mujoco_env( + task: str, + seed: int, + num_train_envs: int, + num_test_envs: int, + obs_norm: bool, +) -> tuple[Env, BaseVectorEnv, BaseVectorEnv]: """Wrapper 
function for Mujoco env. If EnvPool is installed, it will automatically switch to EnvPool's Mujoco env. @@ -41,16 +49,16 @@ class MujocoEnvObsRmsPersistence(Persistence): def persist(self, event: PersistEvent, world: World) -> None: if event != PersistEvent.PERSIST_POLICY: - return + return # type: ignore[unreachable] # since PersistEvent has only one member, mypy infers that line is unreachable obs_rms = world.envs.train_envs.get_obs_rms() path = world.persist_path(self.FILENAME) log.info(f"Saving environment obs_rms value to {path}") with open(path, "wb") as f: pickle.dump(obs_rms, f) - def restore(self, event: RestoreEvent, world: World): + def restore(self, event: RestoreEvent, world: World) -> None: if event != RestoreEvent.RESTORE_POLICY: - return + return # type: ignore[unreachable] path = world.restore_path(self.FILENAME) log.info(f"Restoring environment obs_rms value from {path}") with open(path, "rb") as f: diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index e8ee97cae..8a379da92 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -12,8 +12,8 @@ from torch.distributions import Distribution, Independent, Normal from torch.optim.lr_scheduler import LambdaLR -from examples.common import logger_factory from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import NPGPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OnpolicyTrainer @@ -179,6 +179,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py index 6ab0eb891..18360f779 100644 --- a/examples/mujoco/mujoco_npg_hl.py +++ b/examples/mujoco/mujoco_npg_hl.py @@ -32,7 +32,7 @@ def main( step_per_epoch: int = 30000, step_per_collect: int = 1024, repeat_per_collect: int = 1, - batch_size: int | None = None, + batch_size: int = 16, training_num: int = 16, test_num: int = 10, rew_norm: bool = True, diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index 218b95d07..00042884f 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -12,8 +12,8 @@ from torch.distributions import Distribution, Independent, Normal from torch.optim.lr_scheduler import LambdaLR -from examples.common import logger_factory from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import PPOPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OnpolicyTrainer @@ -187,6 +187,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index 66c9f7db6..b300e498a 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -9,8 +9,8 @@ import torch from mujoco_env import make_mujoco_env -from examples.common import logger_factory from tianshou.data import Collector, ReplayBuffer, 
VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import REDQPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer @@ -86,7 +86,7 @@ def test_redq(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = ActorProb( net_a, args.action_shape, @@ -96,12 +96,12 @@ def test_redq(args: argparse.Namespace = get_args()) -> None: ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - def linear(x, y): + def linear(x: int, y: int) -> EnsembleLinear: return EnsembleLinear(args.ensemble_size, x, y) net_c = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, @@ -159,6 +159,7 @@ def linear(x, y): log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index 06e2bc173..109f1cc46 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -12,8 +12,8 @@ from torch.distributions import Distribution, Independent, Normal from torch.optim.lr_scheduler import LambdaLR -from examples.common import logger_factory from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import PGPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OnpolicyTrainer @@ -159,6 +159,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index a09118979..2058a71e9 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -9,8 +9,8 @@ import torch from mujoco_env import make_mujoco_env -from examples.common import logger_factory from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import SACPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer @@ -83,7 +83,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = ActorProb( net_a, args.action_shape, @@ -93,15 +93,15 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net_c1 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, ) net_c2 = Net( - args.state_shape, - 
args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, @@ -153,6 +153,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 057905c64..30e7539c1 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -9,9 +9,9 @@ import torch from mujoco_env import make_mujoco_env -from examples.common import logger_factory from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer from tianshou.exploration import GaussianNoise +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import TD3Policy from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer @@ -88,21 +88,21 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net_a, args.action_shape, max_action=args.max_action, device=args.device).to( args.device, ) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net_c1 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, ) net_c2 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, @@ -151,6 +151,7 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index c17ba6c14..219593343 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -12,8 +12,8 @@ from torch.distributions import Distribution, Independent, Normal from torch.optim.lr_scheduler import LambdaLR -from examples.common import logger_factory from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import TRPOPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OnpolicyTrainer @@ -184,6 +184,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py index 2f9a77748..f113645b1 100644 --- a/examples/mujoco/mujoco_trpo_hl.py +++ b/examples/mujoco/mujoco_trpo_hl.py @@ -32,7 +32,7 @@ def main( step_per_epoch: int = 30000, step_per_collect: int = 1024, repeat_per_collect: int = 1, - batch_size: int | None = None, + batch_size: int = 16, training_num: int = 16, test_num: int = 10, rew_norm: 
bool = True, diff --git a/examples/mujoco/plotter.py b/examples/mujoco/plotter.py index 60840decf..5e2f9e016 100755 --- a/examples/mujoco/plotter.py +++ b/examples/mujoco/plotter.py @@ -3,6 +3,7 @@ import argparse import os import re +from typing import Any, Literal import matplotlib.pyplot as plt import matplotlib.ticker as mticker @@ -10,7 +11,12 @@ from tools import csv2numpy, find_all_files, group_files -def smooth(y, radius, mode="two_sided", valid_only=False): +def smooth( + y: np.ndarray, + radius: int, + mode: Literal["two_sided", "causal"] = "two_sided", + valid_only: bool = False, +) -> np.ndarray: """Smooth signal y, where radius is determines the size of the window. mode='twosided': @@ -19,7 +25,6 @@ def smooth(y, radius, mode="two_sided", valid_only=False): average over the window [max(index - radius, 0), index] valid_only: put nan in entries where the full-sized window is not available """ - assert mode in ("two_sided", "causal") if len(y) < 2 * radius + 1: return np.ones_like(y) * y.mean() if mode == "two_sided": @@ -88,23 +93,25 @@ def smooth(y, radius, mode="two_sided", valid_only=False): def plot_ax( - ax, - file_lists, - legend_pattern=".*", - xlabel=None, - ylabel=None, - title=None, - xlim=None, - xkey="env_step", - ykey="reward", - smooth_radius=0, - shaded_std=True, - legend_outside=False, -): - def legend_fn(x): + ax: plt.Axes, + file_lists: list[str], + legend_pattern: str = ".*", + xlabel: str | None = None, + ylabel: str | None = None, + title: str = "", + xlim: float | None = None, + xkey: str = "env_step", + ykey: str = "reward", + smooth_radius: int = 0, + shaded_std: bool = True, + legend_outside: bool = False, +) -> None: + def legend_fn(x: str) -> str: # return os.path.split(os.path.join( # args.root_dir, x))[0].replace('/', '_') + " (10)" - return re.search(legend_pattern, x).group(0) + match = re.search(legend_pattern, x) + assert match is not None # for mypy + return match.group(0) legneds = map(legend_fn, file_lists) # sort filelist according to legends @@ -139,15 +146,15 @@ def legend_fn(x): def plot_figure( - file_lists, - group_pattern=None, - fig_length=6, - fig_width=6, - sharex=False, - sharey=False, - title=None, - **kwargs, -): + file_lists: list[str], + group_pattern: str | None = None, + fig_length: int = 6, + fig_width: int = 6, + sharex: bool = False, + sharey: bool = False, + title: str = "", + **kwargs: Any, +) -> None: if not group_pattern: fig, ax = plt.subplots(figsize=(fig_length, fig_width)) plot_ax(ax, file_lists, title=title, **kwargs) diff --git a/examples/mujoco/tools.py b/examples/mujoco/tools.py index be289e33a..e0db8162b 100755 --- a/examples/mujoco/tools.py +++ b/examples/mujoco/tools.py @@ -5,13 +5,16 @@ import os import re from collections import defaultdict +from os import PathLike +from re import Pattern +from typing import Any import numpy as np import tqdm from tensorboard.backend.event_processing import event_accumulator -def find_all_files(root_dir, pattern): +def find_all_files(root_dir: str | PathLike[str], pattern: str | Pattern[str]) -> list: """Find all files under root_dir according to relative pattern.""" file_list = [] for dirname, _, files in os.walk(root_dir): @@ -22,7 +25,7 @@ def find_all_files(root_dir, pattern): return file_list -def group_files(file_list, pattern): +def group_files(file_list: list[str], pattern: str | Pattern[str]) -> dict[str, list]: res = defaultdict(list) for f in file_list: match = re.search(pattern, f) @@ -31,7 +34,7 @@ def group_files(file_list, pattern): return res -def 
csv2numpy(csv_file): +def csv2numpy(csv_file: str) -> dict[Any, np.ndarray]: csv_dict = defaultdict(list) with open(csv_file) as f: for row in csv.DictReader(f): @@ -40,7 +43,10 @@ def csv2numpy(csv_file): return {k: np.array(v) for k, v in csv_dict.items()} -def convert_tfevents_to_csv(root_dir, refresh=False): +def convert_tfevents_to_csv( + root_dir: str | PathLike[str], + refresh: bool = False, +) -> dict[str, list]: """Recursively convert test/reward from all tfevent file under root_dir to csv. This function assumes that there is at most one tfevents file in each directory @@ -81,7 +87,11 @@ def convert_tfevents_to_csv(root_dir, refresh=False): return result -def merge_csv(csv_files, root_dir, remove_zero=False): +def merge_csv( + csv_files: dict[str, list], + root_dir: str | PathLike[str], + remove_zero: bool = False, +) -> None: """Merge result in csv_files into a single csv file.""" assert len(csv_files) > 0 if remove_zero: diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index 766662a07..1fc0dc7e3 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -9,12 +9,13 @@ import numpy as np import torch +from gymnasium.spaces import Discrete from examples.atari.atari_network import DQN from examples.atari.atari_wrapper import make_atari_env -from examples.common import logger_factory from examples.offline.utils import load_buffer from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DiscreteBCQPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OfflineTrainer @@ -82,8 +83,9 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: scale=args.scale_obs, frame_stack=args.frames_stack, ) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + assert isinstance(env.action_space, Discrete) + args.state_shape = env.observation_space.shape + args.action_shape = int(env.action_space.n) # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) @@ -91,8 +93,13 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model + assert args.state_shape is not None + assert len(args.state_shape) == 3 + c, h, w = args.state_shape feature_net = DQN( - *args.state_shape, + c, + h, + w, args.action_shape, device=args.device, features_only=True, @@ -157,6 +164,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project @@ -185,9 +193,7 @@ def watch() -> None: print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - pprint.pprint(result) - rew = result.returns_stat.mean - print(f"Mean reward (over {result.n_collected_episodes} episodes): {rew}") + result.pprint_asdict() if args.watch: watch() diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index 5f1afcdcd..40d91c1bb 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -6,18 +6,21 @@ import pickle import pprint import sys +from collections.abc import Sequence import numpy as np import torch +from 
gymnasium.spaces import Discrete from examples.atari.atari_network import QRDQN from examples.atari.atari_wrapper import make_atari_env -from examples.common import logger_factory from examples.offline.utils import load_buffer from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DiscreteCQLPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OfflineTrainer +from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: @@ -80,8 +83,13 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: scale=args.scale_obs, frame_stack=args.frames_stack, ) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + assert isinstance(env.action_space, Discrete) + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + assert isinstance(args.state_shape, Sequence) + assert len(args.state_shape) == 3, "state shape must have only 3 dimensions." + c, h, w = args.state_shape # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) @@ -89,7 +97,14 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net = QRDQN(*args.state_shape, args.action_shape, args.num_quantiles, args.device) + net = QRDQN( + c=c, + h=h, + w=w, + action_shape=args.action_shape, + num_quantiles=args.num_quantiles, + device=args.device, + ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy policy: DiscreteCQLPolicy = DiscreteCQLPolicy( @@ -133,6 +148,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project @@ -161,9 +177,7 @@ def watch() -> None: print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - pprint.pprint(result) - rew = result.returns_stat.mean - print(f"Mean reward (over {result.n_collected_episodes} episodes): {rew}") + result.pprint_asdict() if args.watch: watch() diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index 97622b6d5..a4b31c4fb 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -9,17 +9,19 @@ import numpy as np import torch +from gymnasium.spaces import Discrete from examples.atari.atari_network import DQN from examples.atari.atari_wrapper import make_atari_env -from examples.common import logger_factory from examples.offline.utils import load_buffer from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import DiscreteCRRPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OfflineTrainer from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.discrete import Actor, Critic +from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: @@ -82,8 +84,10 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: scale=args.scale_obs, frame_stack=args.frames_stack, ) - args.state_shape = 
env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + assert isinstance(env.action_space, Discrete) + space_info = SpaceInfo.from_env(env) + args.state_shape = env.observation_space.shape + args.action_shape = space_info.action_info.action_shape # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) @@ -91,8 +95,13 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model + assert args.state_shape is not None + assert len(args.state_shape) == 3 + c, h, w = args.state_shape feature_net = DQN( - *args.state_shape, + c, + h, + w, args.action_shape, device=args.device, features_only=True, @@ -107,7 +116,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: critic = Critic( feature_net, hidden_sizes=args.hidden_sizes, - last_size=np.prod(args.action_shape), + last_size=int(np.prod(args.action_shape)), device=args.device, ).to(args.device) actor_critic = ActorCritic(actor, critic) @@ -156,6 +165,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project @@ -183,9 +193,7 @@ def watch() -> None: print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - pprint.pprint(result) - rew = result.returns_stat.mean - print(f"Mean reward (over {result.n_collected_episodes} episodes): {rew}") + result.pprint_asdict() if args.watch: watch() diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index 615d38ec0..bb7822ea9 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -12,12 +12,13 @@ from examples.atari.atari_network import DQN from examples.atari.atari_wrapper import make_atari_env -from examples.common import logger_factory from examples.offline.utils import load_buffer from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import ImitationPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OfflineTrainer +from tianshou.utils.space_info import SpaceInfo def get_args() -> argparse.Namespace: @@ -73,8 +74,12 @@ def test_il(args: argparse.Namespace = get_args()) -> None: scale=args.scale_obs, frame_stack=args.frames_stack, ) - args.state_shape = env.observation_space.shape or env.observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + space_info = SpaceInfo.from_env(env) + args.state_shape = space_info.observation_info.obs_shape + args.action_shape = space_info.action_info.action_shape + assert isinstance(args.state_shape, list | tuple) # isinstance() rejects parameterized generics such as list[int] + assert len(args.state_shape) == 3 + c, h, w = args.state_shape # should be N_FRAMES x H x W print("Observations shape:", args.state_shape) print("Actions shape:", args.action_shape) @@ -82,7 +87,7 @@ def test_il(args: argparse.Namespace = get_args()) -> None: np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net = DQN(*args.state_shape, args.action_shape, device=args.device).to(args.device) + net = DQN(c, h, w, args.action_shape, device=args.device).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) # define policy
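The explicit `c, h, w` unpacking in the offline Atari scripts above exists so the type checker sees three concrete integers before they reach the convolutional `DQN` network from `examples/atari/atari_network.py`. A small standalone illustration of the same check; the shape value is a placeholder, and the `DQN` call is left commented out because it depends on the examples package:

state_shape = (4, 84, 84)  # N_FRAMES x H x W, e.g. obtained via SpaceInfo
assert len(state_shape) == 3, "expected a channel-first image observation"
c, h, w = state_shape  # three plain ints, so mypy accepts DQN(c, h, w, ...)
# feature_net = DQN(c, h, w, action_shape, device="cpu", features_only=True)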
policy: ImitationPolicy = ImitationPolicy(actor=net, optim=optim, action_space=env.action_space) @@ -117,6 +122,7 @@ def test_il(args: argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project @@ -144,9 +150,7 @@ def watch() -> None: print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - pprint.pprint(result) - rew = result.returns_stat.mean - print(f"Mean reward (over {result.n_collected_episodes} episodes): {rew}") + result.pprint_asdict() if args.watch: watch() diff --git a/examples/offline/convert_rl_unplugged_atari.py b/examples/offline/convert_rl_unplugged_atari.py index a28a35e5f..1afd721a5 100755 --- a/examples/offline/convert_rl_unplugged_atari.py +++ b/examples/offline/convert_rl_unplugged_atari.py @@ -28,10 +28,11 @@ clipping. """ import os -from argparse import ArgumentParser +from argparse import ArgumentParser, Namespace import h5py import numpy as np +import numpy.typing as npt import requests import tensorflow as tf from tqdm import tqdm @@ -172,7 +173,7 @@ def _tf_example_to_tianshou_batch(tf_example: tf.train.Example) -> Batch: # Adapted From https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51 -def download(url: str, fname: str, chunk_size=1024): +def download(url: str, fname: str, chunk_size: int | None = 1024) -> None: resp = requests.get(url, stream=True) total = int(resp.headers.get("content-length", 0)) if os.path.exists(fname): @@ -192,11 +193,11 @@ def download(url: str, fname: str, chunk_size=1024): def process_shard(url: str, fname: str, ofname: str, maxsize: int = 500000) -> None: download(url, fname) - obs = np.ndarray((maxsize, 4, 84, 84), dtype="uint8") - act = np.ndarray((maxsize,), dtype="int64") - rew = np.ndarray((maxsize,), dtype="float32") - done = np.ndarray((maxsize,), dtype="bool") - obs_next = np.ndarray((maxsize, 4, 84, 84), dtype="uint8") + obs: npt.NDArray[np.uint8] = np.ndarray((maxsize, 4, 84, 84), dtype="uint8") + act: npt.NDArray[np.int64] = np.ndarray((maxsize,), dtype="int64") + rew: npt.NDArray[np.float32] = np.ndarray((maxsize,), dtype="float32") + done: npt.NDArray[np.bool_] = np.ndarray((maxsize,), dtype="bool") + obs_next: npt.NDArray[np.uint8] = np.ndarray((maxsize, 4, 84, 84), dtype="uint8") i = 0 file_ds = tf.data.TFRecordDataset(fname, compression_type="GZIP") for example in file_ds: @@ -238,7 +239,7 @@ def process_dataset( process_shard(url, filepath, ofname) -def main(args) -> None: +def main(args: Namespace) -> None: if args.task not in ALL_GAMES: raise KeyError(f"`{args.task}` is not in the list of games.") fn = _filename(args.run_id, args.shard_id, total_num_shards=args.total_num_shards) diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index 7c275b555..80b233cb7 100644 --- a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -109,15 +109,15 @@ def test_bcq() -> None: actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net_c1 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, ) net_c2 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, diff 
--git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index 0e9fe62cd..7ca8ae2fb 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -4,7 +4,6 @@ import datetime import os import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -220,7 +219,7 @@ def get_args() -> argparse.Namespace: def test_cql() -> None: args = get_args() env = gym.make(args.task) - env.action_space = cast(gym.spaces.Box, env.action_space) + assert isinstance(env.action_space, gym.spaces.Box) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape @@ -245,8 +244,8 @@ def test_cql() -> None: # model # actor network net_a = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, ) @@ -261,15 +260,15 @@ def test_cql() -> None: # critic network net_c1 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, ) net_c2 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index b7153ed11..c2152a711 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -83,8 +83,8 @@ def test_il() -> None: # model net = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, ) diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index b2a0b24c8..4d6159ff5 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -117,15 +117,15 @@ def test_td3_bc() -> None: # critic network net_c1 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, ) net_c2 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, diff --git a/examples/vizdoom/env.py b/examples/vizdoom/env.py index f5c974fa0..2869acd1a 100644 --- a/examples/vizdoom/env.py +++ b/examples/vizdoom/env.py @@ -1,9 +1,12 @@ import os +from collections.abc import Sequence +from typing import Any import cv2 import gymnasium as gym import numpy as np import vizdoom as vzd +from numpy.typing import NDArray from tianshou.env import ShmemVectorEnv @@ -13,7 +16,7 @@ envpool = None -def normal_button_comb(): +def normal_button_comb() -> list: actions = [] m_forward = [[0.0], [1.0]] t_left_right = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]] @@ -23,7 +26,7 @@ def normal_button_comb(): return actions -def battle_button_comb(): +def battle_button_comb() -> list: actions = [] m_forward_backward = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]] m_left_right = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]] @@ -41,7 +44,13 @@ def battle_button_comb(): class Env(gym.Env): - def __init__(self, cfg_path, frameskip=4, res=(4, 40, 60), save_lmp=False) -> None: + def __init__( + self, + cfg_path: str, + frameskip: int = 4, + res: Sequence[int] = (4, 40, 60), + save_lmp: bool = False, + ) -> None: super().__init__() self.save_lmp = save_lmp 
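The VizDoom `Env` wrapper being updated here adopts the Gymnasium API, in which `reset` returns an observation/info pair and `step` returns a five-tuple with separate `terminated` and `truncated` flags. A minimal sketch of how a caller consumes that contract; `CartPole-v1` is only a placeholder environment:

import gymnasium as gym

env = gym.make("CartPole-v1")
obs, info = env.reset(seed=0)  # reset -> (observation, info)
terminated = truncated = False
while not (terminated or truncated):
    # step -> (obs, reward, terminated, truncated, info)
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()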
self.health_setting = "battle" in cfg_path @@ -62,7 +71,7 @@ def __init__(self, cfg_path, frameskip=4, res=(4, 40, 60), save_lmp=False) -> No self.spec = gym.envs.registration.EnvSpec("vizdoom-v0") self.count = 0 - def get_obs(self): + def get_obs(self) -> None: state = self.game.get_state() if state is None: return @@ -70,7 +79,11 @@ def get_obs(self): self.obs_buffer[:-1] = self.obs_buffer[1:] self.obs_buffer[-1] = cv2.resize(obs, (self.res[-1], self.res[-2])) - def reset(self): + def reset( + self, + seed: int | None = None, + options: dict[str, Any] | None = None, + ) -> tuple[NDArray[np.uint8], dict[str, Any]]: if self.save_lmp: self.game.new_episode(f"lmps/episode_{self.count}.lmp") else: @@ -81,9 +94,9 @@ def reset(self): self.health = self.game.get_game_variable(vzd.GameVariable.HEALTH) self.killcount = self.game.get_game_variable(vzd.GameVariable.KILLCOUNT) self.ammo2 = self.game.get_game_variable(vzd.GameVariable.AMMO2) - return self.obs_buffer + return self.obs_buffer, {"TimeLimit.truncated": False} - def step(self, action): + def step(self, action: int) -> tuple[NDArray[np.uint8], float, bool, bool, dict[str, Any]]: self.game.make_action(self.available_actions[action], self.skip) reward = 0.0 self.get_obs() @@ -105,17 +118,27 @@ def step(self, action): elif self.game.is_episode_finished(): done = True info["TimeLimit.truncated"] = True - return self.obs_buffer, reward, done, info + return self.obs_buffer, reward, done, info.get("TimeLimit.truncated", False), info - def render(self): + def render(self) -> None: pass - def close(self): + def close(self) -> None: self.game.close() -def make_vizdoom_env(task, frame_skip, res, save_lmp, seed, training_num, test_num): - test_num = min(os.cpu_count() - 1, test_num) +def make_vizdoom_env( + task: str, + frame_skip: int, + res: tuple[int], + save_lmp: bool = False, + seed: int | None = None, + training_num: int = 10, + test_num: int = 10, +) -> tuple[Env, ShmemVectorEnv, ShmemVectorEnv]: + cpu_count = os.cpu_count() + if cpu_count is not None: + test_num = min(cpu_count - 1, test_num) if envpool is not None: task_id = "".join([i.capitalize() for i in task.split("_")]) + "-v1" lmp_save_dir = "lmps/" if save_lmp else "" @@ -167,9 +190,11 @@ def make_vizdoom_env(task, frame_skip, res, save_lmp, seed, training_num, test_n # env = Env("maps/D1_basic.cfg", 4, (4, 84, 84)) env = Env("maps/D3_battle.cfg", 4, (4, 84, 84)) print(env.available_actions) + assert isinstance(env.action_space, gym.spaces.Discrete) action_num = env.action_space.n - obs = env.reset() - print(env.spec.reward_threshold) + obs, _ = env.reset() + if env.spec: + print(env.spec.reward_threshold) print(obs.shape, action_num) for _ in range(4000): obs, rew, terminated, truncated, info = env.step(0) diff --git a/examples/vizdoom/replay.py b/examples/vizdoom/replay.py index 4437a08ba..45f9df671 100755 --- a/examples/vizdoom/replay.py +++ b/examples/vizdoom/replay.py @@ -1,4 +1,5 @@ # import cv2 +import os import sys import time @@ -6,7 +7,10 @@ import vizdoom as vzd -def main(cfg_path="maps/D3_battle.cfg", lmp_path="test.lmp") -> None: +def main( + cfg_path: str = os.path.join("maps", "D3_battle.cfg"), + lmp_path: str = os.path.join("test.lmp"), +) -> None: game = vzd.DoomGame() game.load_config(cfg_path) game.set_screen_format(vzd.ScreenFormat.CRCGCB) diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 9ba82a10f..62daaf64f 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -9,8 +9,8 @@ from env import 
make_vizdoom_env from network import C51 -from examples.common import logger_factory from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import C51Policy from tianshou.policy.base import BasePolicy from tianshou.trainer import OffpolicyTrainer @@ -130,6 +130,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project @@ -188,10 +189,7 @@ def watch() -> None: print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result.returns_stat.mean - lens = result.lens_stat.mean * args.skip_num - print(f"Mean reward (over {result.n_collected_episodes} episodes): {rew}") - print(f"Mean length (over {result.n_collected_episodes} episodes): {lens}") + result.pprint_asdict() if args.watch: watch() diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index 5c2a9e1f7..7476d4f26 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -8,11 +8,11 @@ import torch from env import make_vizdoom_env from network import DQN -from torch.distributions import Categorical, Distribution +from torch.distributions import Categorical from torch.optim.lr_scheduler import LambdaLR -from examples.common import logger_factory from tianshou.data import Collector, VectorReplayBuffer +from tianshou.highlevel.logger import LoggerFactoryDefault from tianshou.policy import ICMPolicy, PPOPolicy from tianshou.policy.base import BasePolicy from tianshou.trainer import OnpolicyTrainer @@ -136,15 +136,11 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) - # define policy - def dist(logits: torch.Tensor) -> Distribution: - return Categorical(logits=logits) - policy: PPOPolicy = PPOPolicy( actor=actor, critic=critic, optim=optim, - dist_fn=dist, + dist_fn=Categorical, discount_factor=args.gamma, gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm, @@ -177,7 +173,7 @@ def dist(logits: torch.Tensor) -> Distribution: device=args.device, ) icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr) - policy = ICMPolicy( + policy: ICMPolicy = ICMPolicy( # type: ignore[no-redef] policy=policy, model=icm_net, optim=icm_optim, @@ -210,6 +206,7 @@ def dist(logits: torch.Tensor) -> Distribution: log_path = os.path.join(args.logdir, log_name) # logger + logger_factory = LoggerFactoryDefault() if args.logger == "wandb": logger_factory.logger_type = "wandb" logger_factory.wandb_project = args.wandb_project @@ -254,10 +251,7 @@ def watch() -> None: print("Testing agent ...") test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) - rew = result.returns_stat.mean - lens = result.lens_stat.mean * args.skip_num - print(f"Mean reward (over {result.n_collected_episodes} episodes): {rew}") - print(f"Mean length (over {result.n_collected_episodes} episodes): {lens}") + result.pprint_asdict() if args.watch: watch() diff --git a/poetry.lock b/poetry.lock index 0aa79b7e7..a6fbf3229 100644 --- a/poetry.lock +++ b/poetry.lock @@ -980,6 +980,13 @@ files = [ {file = "dm_tree-0.1.8-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:fa42a605d099ee7d41ba2b5fb75e21423951fd26e5d50583a00471238fb3021d"}, {file = "dm_tree-0.1.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:83b7764de0d855338abefc6e3ee9fe40d301668310aa3baea3f778ff051f4393"}, {file = "dm_tree-0.1.8-cp311-cp311-win_amd64.whl", hash = "sha256:a5d819c38c03f0bb5b3b3703c60e4b170355a0fc6b5819325bf3d4ceb3ae7e80"}, + {file = "dm_tree-0.1.8-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ea9e59e0451e7d29aece402d9f908f2e2a80922bcde2ebfd5dcb07750fcbfee8"}, + {file = "dm_tree-0.1.8-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:94d3f0826311f45ee19b75f5b48c99466e4218a0489e81c0f0167bda50cacf22"}, + {file = "dm_tree-0.1.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:435227cf3c5dc63f4de054cf3d00183790bd9ead4c3623138c74dde7f67f521b"}, + {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09964470f76a5201aff2e8f9b26842976de7889300676f927930f6285e256760"}, + {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:75c5d528bb992981c20793b6b453e91560784215dffb8a5440ba999753c14ceb"}, + {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0a94aba18a35457a1b5cd716fd7b46c5dafdc4cf7869b4bae665b91c4682a8e"}, + {file = "dm_tree-0.1.8-cp312-cp312-win_amd64.whl", hash = "sha256:96a548a406a6fb15fe58f6a30a57ff2f2aafbf25f05afab00c8f5e5977b6c715"}, {file = "dm_tree-0.1.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8c60a7eadab64c2278861f56bca320b2720f163dca9d7558103c3b77f2416571"}, {file = "dm_tree-0.1.8-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:af4b3d372f2477dcd89a6e717e4a575ca35ccc20cc4454a8a4b6f8838a00672d"}, {file = "dm_tree-0.1.8-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de287fabc464b8734be251e46e06aa9aa1001f34198da2b6ce07bd197172b9cb"}, @@ -1083,43 +1090,6 @@ packaging = "*" types-protobuf = ">=3.17.3" typing-extensions = "*" -[[package]] -name = "etils" -version = "1.6.0" -description = "Collection of common python utils" -optional = true -python-versions = ">=3.10" -files = [ - {file = "etils-1.6.0-py3-none-any.whl", hash = "sha256:3da192b057929f2511f9ef713cee7d9c498e741740f8b2a9c0f6392d787201d4"}, - {file = "etils-1.6.0.tar.gz", hash = "sha256:c635fbd02a79fed4ad76825d31306b581d22b40671721daa8bc279cf6333e48a"}, -] - -[package.dependencies] -fsspec = {version = "*", optional = true, markers = "extra == \"epath\""} -importlib_resources = {version = "*", optional = true, markers = "extra == \"epath\""} -typing_extensions = {version = "*", optional = true, markers = "extra == \"epy\""} -zipp = {version = "*", optional = true, markers = "extra == \"epath\""} - -[package.extras] -all = ["etils[array-types]", "etils[eapp]", "etils[ecolab]", "etils[edc]", "etils[enp]", "etils[epath-gcs]", "etils[epath-s3]", "etils[epath]", "etils[epy]", "etils[etqdm]", "etils[etree-dm]", "etils[etree-jax]", "etils[etree-tf]", "etils[etree]"] -array-types = ["etils[enp]"] -dev = ["chex", "dataclass_array", "optree", "pyink", "pylint (>=2.6.0)", "pytest", "pytest-subtests", "pytest-xdist", "torch"] -docs = ["etils[all,dev]", "sphinx-apitree[ext]"] -eapp = ["absl-py", "etils[epy]", "simple_parsing"] -ecolab = ["etils[enp]", "etils[epy]", "etils[etree]", "jupyter", "mediapy", "numpy", "packaging", "protobuf"] -edc = ["etils[epy]"] -enp = ["etils[epy]", "numpy"] -epath = ["etils[epy]", "fsspec", "importlib_resources", "typing_extensions", "zipp"] -epath-gcs = 
["etils[epath]", "gcsfs"] -epath-s3 = ["etils[epath]", "s3fs"] -epy = ["typing_extensions"] -etqdm = ["absl-py", "etils[epy]", "tqdm"] -etree = ["etils[array-types]", "etils[enp]", "etils[epy]", "etils[etqdm]"] -etree-dm = ["dm-tree", "etils[etree]"] -etree-jax = ["etils[etree]", "jax[cpu]"] -etree-tf = ["etils[etree]", "tensorflow"] -lazy-imports = ["etils[ecolab]"] - [[package]] name = "executing" version = "2.0.1" @@ -2688,42 +2658,40 @@ files = [ [[package]] name = "mujoco" -version = "3.1.1" +version = "2.3.7" description = "MuJoCo Physics Simulator" optional = true python-versions = ">=3.8" files = [ - {file = "mujoco-3.1.1-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:be7aa04f8c91bc77fea6574c80154e62973fda0a959a6add4c9bc426db0ea9de"}, - {file = "mujoco-3.1.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e35a60ade27b8e074ad7f08496e4a9101da9d358401bcbb08610dcf5066c3622"}, - {file = "mujoco-3.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f450b46802fca047e2d19ce8adefa9f4a1787273a27511d76ef717eafaf18d8b"}, - {file = "mujoco-3.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51ac0f9df06e612ee628c571bab0320dc7721b7732e8c025a2289fda17f98a47"}, - {file = "mujoco-3.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:d78a07fd18ae82a4cd4628e062fff1224220a7d86749c02170a0ea8e356c7442"}, - {file = "mujoco-3.1.1-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:34a61d8c1631aa6d85252b04b01fdc98bf7d6829e1aab08182069f29af02617e"}, - {file = "mujoco-3.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:34f2b63b9f7e76b10a9a82d085d2637ecccf6f2b2df177d7bc3d16b6857af861"}, - {file = "mujoco-3.1.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:537e6ca9b0896865a8c30da6060b158299450776cd8e5796fd23c1fc54d26aa5"}, - {file = "mujoco-3.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3aee8a9af27f5443a0c6fc09dd2384ebb3e2774928fda7213ca9809e552e0010"}, - {file = "mujoco-3.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:431fdb51194f5a6dc1b3c2d625410d7468c40ec1091ac4e4e23081ace47d9a15"}, - {file = "mujoco-3.1.1-cp312-cp312-macosx_10_16_x86_64.whl", hash = "sha256:53ca08b1af724104ceeb307b47131e5f244ebb35ff5b5b38cf4f5f3b6b662b9f"}, - {file = "mujoco-3.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f5e6502c1ba6902c276d384fe7dee8a99ca570ef187dc122c60692baf0f068cb"}, - {file = "mujoco-3.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:267458ff744cb1a2265ce2cf3f81ecb096883b2003a647de2d9177bb606514bb"}, - {file = "mujoco-3.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5731c8e6efb739312ece205fa6932d76e8d6ecd78a19c78da58e58b2abe5b591"}, - {file = "mujoco-3.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:0037ea34af70a5553cf516027e76d3f91b13389a4b01679d5d77d8ea0bc4aaf7"}, - {file = "mujoco-3.1.1-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:70a440463d7ec272085a16057115bd3e2c74c4e91773f4fc809a40edca2b4546"}, - {file = "mujoco-3.1.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b2f471410896a23a0325b240ab535ea6ba170af1a044ff82f6ac34fb5e17f7d6"}, - {file = "mujoco-3.1.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50930f8bddb81f23b7c01d2beee9b01bb52827f0413c53dd2ff0b0220688e4a3"}, - {file = "mujoco-3.1.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31aa58202baeafa9f95dac65dc19c7c04b6b5079eaed65113c66235d08a49a98"}, - {file = 
"mujoco-3.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:d867792d0ca21720337e17e9dda67ada16d03bdc2c300082140aca7d1a2d01f0"}, - {file = "mujoco-3.1.1-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:f9d2e6e3cd857662e1eac7b7ff68074b329ab99bda9c0a5020e2aeb242db00e1"}, - {file = "mujoco-3.1.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ec29474314726a71f60ed2fa519a9f8df332ae23b638368a7833c851ce0fe500"}, - {file = "mujoco-3.1.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3195aa1bfb96cfce4aaf116baf8b74aee7e479cb3c2427ede4d6f9ad91f7c107"}, - {file = "mujoco-3.1.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd0ebcfc7f4771aeedb5e66321c00e9c8c4393834722385b4a23401f1eee3e8f"}, - {file = "mujoco-3.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:0e76ebd3030aa32fd755e4ec0c1db069ad0a0fb86184b80c12fe5f2ef822bc56"}, - {file = "mujoco-3.1.1.tar.gz", hash = "sha256:1121273de2fbf4ed309e5944a3db39d01f385b220d20e78c460ec4efc06945b3"}, + {file = "mujoco-2.3.7-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:e8714a5ff6a1561b364b7b4648d4c0c8d13e751874cf7401c309b9d23fa9598b"}, + {file = "mujoco-2.3.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a934315f858a4e0c4b90a682fde519471cfdd7baa64435179da8cd20d4ae3f99"}, + {file = "mujoco-2.3.7-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:36513024330f88b5f9a43558efef5692b33599bffd5141029b690a27918ffcbe"}, + {file = "mujoco-2.3.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6d4eede8ba8210fbd3d3cd1dbf69e24dd1541aa74c5af5b8adbbbf65504b6dba"}, + {file = "mujoco-2.3.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab85fafc9d5a091c712947573b7e694512d283876bf7f33ae3f8daad3a20c0db"}, + {file = "mujoco-2.3.7-cp310-cp310-win_amd64.whl", hash = "sha256:f8b7e13fef8c813d91b78f975ed0815157692777907ffa4b4be53a4edb75019b"}, + {file = "mujoco-2.3.7-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:779520216f72a8e370e3f0cdd71b45c3b7384c63331a3189194c930a3e7cff5c"}, + {file = "mujoco-2.3.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9d4018053879016282d27ab7a91e292c72d44efb5a88553feacfe5b843dde103"}, + {file = "mujoco-2.3.7-cp311-cp311-macosx_11_0_x86_64.whl", hash = "sha256:3149b16b8122ee62642474bfd2871064e8edc40235471cf5d84be3569afc0312"}, + {file = "mujoco-2.3.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c08660a8d52ef3efde76095f0991e807703a950c1e882d2bcd984b9a846626f7"}, + {file = "mujoco-2.3.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:426af8965f8636d94a0f75740c3024a62b3e585020ee817ef5208ec844a1ad94"}, + {file = "mujoco-2.3.7-cp311-cp311-win_amd64.whl", hash = "sha256:215415a8e98a4b50625beae859079d5e0810b2039e50420f0ba81763c34abb59"}, + {file = "mujoco-2.3.7-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:8b78d14f4c60cea3c58e046bd4de453fb5b9b33aca6a25fc91d39a53f3a5342a"}, + {file = "mujoco-2.3.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5c6f5a51d6f537a4bf294cf73816f3a6384573f8f10a5452b044df2771412a96"}, + {file = "mujoco-2.3.7-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:ea8911e6047f92d7d775701f37e4c093971b6def3160f01d0b6926e29a7e962e"}, + {file = "mujoco-2.3.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7473a3de4dd1a8762d569ffb139196b4c5e7eca27d256df97b6cd4c66d2a09b2"}, + {file = "mujoco-2.3.7-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:40e7e2d8f93d2495ec74efec84e5118ecc6e1d85157a844789c73c9ac9a4e28e"}, + {file = "mujoco-2.3.7-cp38-cp38-win_amd64.whl", hash = "sha256:720bc228a2023b3b0ed6af78f5b0f8ea36867be321d473321555c57dbf6e4e5b"}, + {file = "mujoco-2.3.7-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:855e79686366442aa410246043b44f7d842d3900d68fe7e37feb42147db9d707"}, + {file = "mujoco-2.3.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:98947f4a742d34d36f3c3f83e9167025bb0414bbaa4bd859b0673bdab9959963"}, + {file = "mujoco-2.3.7-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:d42818f2ee5d1632dbce31d136ed5ff868db54b04e4e9aca0c5a3ac329f8a90f"}, + {file = "mujoco-2.3.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9237e1ba14bced9449c31199e6d5be49547f3a4c99bc83b196af7ca45fd73b83"}, + {file = "mujoco-2.3.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b728ea638245b150e2650c5433e6952e0ed3798c63e47e264574270caea2a3"}, + {file = "mujoco-2.3.7-cp39-cp39-win_amd64.whl", hash = "sha256:9c721a5042b99d948d5f0296a534bcce3f142c777c4d7642f503a539513f3912"}, + {file = "mujoco-2.3.7.tar.gz", hash = "sha256:422041f1ce37c6d151fbced1048df626837e94fe3cd9f813585907046336a7d0"}, ] [package.dependencies] absl-py = "*" -etils = {version = "*", extras = ["epath"]} glfw = "*" numpy = "*" pyopengl = "*" @@ -3271,6 +3239,7 @@ optional = false python-versions = ">=3" files = [ {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:64335a8088e2b9d196ae8665430bc6a2b7e6ef2eb877a9c735c804bd4ff6467c"}, + {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux2014_aarch64.whl", hash = "sha256:211a63e7b30a9d62f1a853e19928fbb1a750e3f17a13a3d1f98ff0ced19478dd"}, {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-win_amd64.whl", hash = "sha256:1b2e317e437433753530792f13eece58f0aec21a2b05903be7bffe58a606cbd1"}, ] @@ -4220,6 +4189,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -5935,6 +5905,31 @@ files = [ {file = "types_python_dateutil-2.8.19.14-py3-none-any.whl", hash = "sha256:f977b8de27787639986b4e28963263fd0e5158942b3ecef91b9335c130cb1ce9"}, ] +[[package]] +name = "types-requests" +version = "2.31.0.20240311" +description = "Typing stubs for requests" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-requests-2.31.0.20240311.tar.gz", hash = "sha256:b1c1b66abfb7fa79aae09097a811c4aa97130eb8831c60e47aee4ca344731ca5"}, + {file = "types_requests-2.31.0.20240311-py3-none-any.whl", hash = 
"sha256:47872893d65a38e282ee9f277a4ee50d1b28bd592040df7d1fdaffdf3779937d"}, +] + +[package.dependencies] +urllib3 = ">=2" + +[[package]] +name = "types-tabulate" +version = "0.9.0.20240106" +description = "Typing stubs for tabulate" +optional = false +python-versions = ">=3.8" +files = [ + {file = "types-tabulate-0.9.0.20240106.tar.gz", hash = "sha256:c9b6db10dd7fcf55bd1712dd3537f86ddce72a08fd62bb1af4338c7096ce947e"}, + {file = "types_tabulate-0.9.0.20240106-py3-none-any.whl", hash = "sha256:0378b7b6fe0ccb4986299496d027a6d4c218298ecad67199bbd0e2d7e9d335a1"}, +] + [[package]] name = "typing-extensions" version = "4.8.0" @@ -6228,4 +6223,4 @@ vizdoom = ["vizdoom"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "0d4ff98ed02fe3f34c0d05b5d175822a82ac08a5ed52e57b7f847a48c302add6" +content-hash = "06b9166b2e752fbab564cbc0dbce226844c26dd2b59f9f7e95104570e377c43b" diff --git a/pyproject.toml b/pyproject.toml index 7a795004a..813cbd335 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,8 @@ envpool = { version = "^0.8.2", optional = true, markers = "sys_platform != 'da gymnasium-robotics = { version = "*", optional = true } imageio = { version = ">=2.14.1", optional = true } jsonargparse = {version = "^4.24.1", optional = true} -mujoco = { version = ">=2.1.5", optional = true } +# we need <3 b/c of https://github.com/Farama-Foundation/Gymnasium/issues/749 +mujoco = { version = ">=2.1.5, <3", optional = true } mujoco-py = { version = ">=2.1,<2.2", optional = true } opencv_python = { version = "*", optional = true } pybullet = { version = "*", optional = true } @@ -111,6 +112,8 @@ sphinx-togglebutton = "^0.3.2" sphinx-toolbox = "^3.5.0" sphinxcontrib-bibtex = "*" sphinxcontrib-spelling = "^8.0.0" +types-requests = "^2.31.0.20240311" +types-tabulate = "^0.9.0.20240106" wandb = "^0.12.0" [tool.mypy] @@ -219,6 +222,6 @@ doc-clean = "rm -rf docs/_build" doc-generate-files = ["_autogen_rst", "_jb_generate_toc", "_jb_generate_config"] doc-spellcheck = "sphinx-build -W -b spelling docs docs/_build" doc-build = ["doc-generate-files", "doc-spellcheck", "_sphinx_build"] -_mypy = "mypy tianshou" +_mypy = "mypy tianshou test examples" _mypy_nb = "nbqa mypy docs" type-check = ["_mypy", "_mypy_nb"] diff --git a/test/base/env.py b/test/base/env.py index c05c98718..2a7b09278 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -137,7 +137,7 @@ def do_sleep(self) -> None: sleep_time *= self.sleep time.sleep(sleep_time) - def step(self, action: np.ndarray | int): + def step(self, action: np.ndarray | int): # type: ignore[no-untyped-def] # cf. 
issue #1080 self.steps += 1 if self._md_action and isinstance(action, np.ndarray): action = action[0] diff --git a/test/base/test_batch.py b/test/base/test_batch.py index aaacffdd4..f11a8d60e 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -2,6 +2,7 @@ import pickle import sys from itertools import starmap +from typing import cast import networkx as nx import numpy as np @@ -135,13 +136,13 @@ def test_batch() -> None: assert batch_slice.a.c == batch2.a.c assert batch_slice.a.d.e == batch2.a.d.e batch2.a.d.f = {} - batch2_sum = (batch2 + 1.0) * 2 + batch2_sum = (batch2 + 1.0) * 2 # type: ignore # __add__ supports Number as input type assert batch2_sum.a.b == (batch2.a.b + 1.0) * 2 assert batch2_sum.a.c == (batch2.a.c + 1.0) * 2 assert batch2_sum.a.d.e == (batch2.a.d.e + 1.0) * 2 assert batch2_sum.a.d.f.is_empty() with pytest.raises(TypeError): - batch2 += [1] + batch2 += [1] # type: ignore # error is raised explicitly batch3 = Batch(a={"c": np.zeros(1), "d": Batch(e=np.array([0.0]), f=np.array([3.0]))}) batch3.a.d[0] = {"e": 4.0} assert batch3.a.d.e[0] == 4.0 @@ -160,7 +161,11 @@ def test_batch() -> None: batch5 = Batch(a=np.array([{"index": 0}])) assert isinstance(batch5.a, Batch) assert np.allclose(batch5.a.index, [0]) + # We use setattr b/c the setattr of Batch will actually change the type of the field that is being set! + # However, mypy would not understand this, and rightly expect that batch.b = some_array would lead to + # batch.b being an array (which it is not, it's turned into a Batch instead) batch5.b = np.array([{"index": 1}]) + batch5.b = cast(Batch, batch5.b) assert isinstance(batch5.b, Batch) assert np.allclose(batch5.b.index, [1]) @@ -215,7 +220,7 @@ def test_batch_over_batch() -> None: batch5[:, 3] with pytest.raises(IndexError): batch5[:, :, -1] - batch5[:, -1] += 1 + batch5[:, -1] += np.int_(1) assert np.allclose(batch5.a, [1, 3]) assert np.allclose(batch5.b.c.squeeze(), [[0, 1]] * 3) with pytest.raises(ValueError): @@ -251,7 +256,7 @@ def test_batch_cat_and_stack() -> None: assert np.allclose(ans.b, np.concatenate([a.b, b.b, a.b])) assert ans.a.t.is_empty() - assert b1.stack_([b2]) is None + b1.stack_([b2]) assert isinstance(b1.a.d.e, np.ndarray) assert b1.a.d.e.ndim == 2 @@ -350,13 +355,15 @@ def test_batch_cat_and_stack() -> None: # test with illegal input format with pytest.raises(ValueError): - Batch.cat([[Batch(a=1)], [Batch(a=1)]]) + Batch.cat([[Batch(a=1)], [Batch(a=1)]]) # type: ignore # cat() tested with invalid inp with pytest.raises(ValueError): - Batch.stack([[Batch(a=1)], [Batch(a=1)]]) + Batch.stack([[Batch(a=1)], [Batch(a=1)]]) # type: ignore # stack() tested with invalid inp # exceptions - assert Batch.cat([]).is_empty() - assert Batch.stack([]).is_empty() + batch_cat: Batch = Batch.cat([]) + assert batch_cat.is_empty() + batch_stack: Batch = Batch.stack([]) + assert batch_stack.is_empty() b1 = Batch(e=[4, 5], d=6) b2 = Batch(e=[4, 6]) with pytest.raises(ValueError): @@ -548,8 +555,8 @@ def test_batch_empty() -> None: def test_batch_standard_compatibility() -> None: batch = Batch(a=np.array([[1.0, 2.0], [3.0, 4.0]]), b=Batch(), c=np.array([5.0, 6.0])) batch_mean = np.mean(batch) - assert isinstance(batch_mean, Batch) - assert sorted(batch_mean.keys()) == ["a", "b", "c"] + assert isinstance(batch_mean, Batch) # type: ignore # mypy doesn't know but it works, cf. 
`batch.rst` + assert sorted(batch_mean.keys()) == ["a", "b", "c"] # type: ignore with pytest.raises(TypeError): len(batch_mean) assert np.all(batch_mean.a == np.mean(batch.a, axis=0)) diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 0806a750f..31265f664 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -5,6 +5,7 @@ import h5py import numpy as np +import numpy.typing as npt import pytest import torch @@ -27,7 +28,7 @@ from test.base.env import MoveToRightEnv, MyGoalEnv -def test_replaybuffer(size=10, bufsize=20) -> None: +def test_replaybuffer(size: int = 10, bufsize: int = 20) -> None: env = MoveToRightEnv(size) buf = ReplayBuffer(bufsize) buf.update(buf) @@ -52,6 +53,7 @@ def test_replaybuffer(size=10, bufsize=20) -> None: assert buf.act.dtype == int assert buf.act.shape == (bufsize, 1) data, indices = buf.sample(bufsize * 2) + assert isinstance(data, Batch) assert (indices < len(buf)).all() assert (data.obs < size).all() assert (data.done >= 0).all() @@ -139,7 +141,7 @@ def test_replaybuffer(size=10, bufsize=20) -> None: assert np.all(b.next(np.array([0, 1, 2, 3])) == [0, 2, 2, 3]) -def test_ignore_obs_next(size=10) -> None: +def test_ignore_obs_next(size: int = 10) -> None: # Issue 82 buf = ReplayBuffer(size, ignore_obs_next=True) for i in range(size): @@ -164,11 +166,19 @@ def test_ignore_obs_next(size=10) -> None: assert isinstance(data, Batch) assert isinstance(data2, Batch) assert np.allclose(indices, orig) + assert hasattr(data.obs_next, "mask") and hasattr( + data2.obs_next, + "mask", + ), "Both `data.obs_next` and `data2.obs_next` must have attribute `mask`." assert np.allclose(data.obs_next.mask, data2.obs_next.mask) assert np.allclose(data.obs_next.mask, [0, 2, 3, 3, 5, 6, 6, 8, 9, 9]) buf.stack_num = 4 data = buf[indices] data2 = buf[indices] + assert hasattr(data.obs_next, "mask") and hasattr( + data2.obs_next, + "mask", + ), "Both `data.obs_next` and `data2.obs_next` must have attribute `mask`." 
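+    # both sampled batches expose the `mask` field (checked above), so their values can be compared directly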
assert np.allclose(data.obs_next.mask, data2.obs_next.mask) assert np.allclose( data.obs_next.mask, @@ -187,9 +197,9 @@ def test_ignore_obs_next(size=10) -> None: ], ), ) - assert np.allclose(data.info["if"], data2.info["if"]) + assert np.allclose(data["info"]["if"], data2["info"]["if"]) assert np.allclose( - data.info["if"], + data["info"]["if"], np.array( [ [0, 0, 0, 0], @@ -208,7 +218,7 @@ def test_ignore_obs_next(size=10) -> None: assert data.obs_next -def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3) -> None: +def test_stack(size: int = 5, bufsize: int = 9, stack_num: int = 4, cached_num: int = 3) -> None: env = MoveToRightEnv(size) buf = ReplayBuffer(bufsize, stack_num=stack_num) buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True) @@ -279,7 +289,7 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3) -> None: buf[bufsize * 2] -def test_priortized_replaybuffer(size=32, bufsize=15) -> None: +def test_priortized_replaybuffer(size: int = 32, bufsize: int = 15) -> None: env = MoveToRightEnv(size) buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5) buf2 = PrioritizedVectorReplayBuffer(bufsize, buffer_num=3, alpha=0.5, beta=0.5) @@ -319,21 +329,24 @@ def test_priortized_replaybuffer(size=32, bufsize=15) -> None: assert np.allclose(buf.weight[indices], np.abs(-data.weight / 2) ** buf._alpha) # check multi buffer's data assert np.allclose(buf2[np.arange(buf2.maxsize)].weight, 1) - batch, indices = buf2.sample(10) - buf2.update_weight(indices, batch.weight * 0) + batch_sample, indices = buf2.sample(10) + buf2.update_weight(indices, batch_sample.weight * 0) weight = buf2[np.arange(buf2.maxsize)].weight + assert isinstance(weight, np.ndarray) mask = np.isin(np.arange(buf2.maxsize), indices) - assert np.all(weight[mask] == weight[mask][0]) - assert np.all(weight[~mask] == weight[~mask][0]) - assert weight[~mask][0] < weight[mask][0] - assert weight[mask][0] <= 1 + selected_weight = weight[mask] + unselected_weight = weight[~mask] + assert np.all(selected_weight == selected_weight[0]) + assert np.all(unselected_weight == unselected_weight[0]) + assert unselected_weight[0] < selected_weight[0] + assert selected_weight[0] <= 1 -def test_herreplaybuffer(size=10, bufsize=100, sample_sz=4) -> None: +def test_herreplaybuffer(size: int = 10, bufsize: int = 100, sample_sz: int = 4) -> None: env_size = size env = MyGoalEnv(env_size, array_state=True) - def compute_reward_fn(ag, g): + def compute_reward_fn(ag: np.ndarray, g: np.ndarray) -> np.ndarray: return env.compute_reward_fn(ag, g, {}) buf = HERReplayBuffer(bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8) @@ -368,7 +381,7 @@ def compute_reward_fn(ag, g): assert len(buf) == min(bufsize, i + 1) assert len(buf2) == min(bufsize, 3 * (i + 1)) - batch, indices = buf.sample(sample_sz) + batch_sample, indices = buf.sample(sample_sz) # Check that goals are the same for the episode (only 1 ep in buffer) tmp_indices = indices.copy() @@ -398,7 +411,7 @@ def compute_reward_fn(ag, g): tmp_indices = buf.next(tmp_indices) # Test vector buffer - batch, indices = buf2.sample(sample_sz) + batch_sample, indices = buf2.sample(sample_sz) # Check that goals are the same for the episode (only 1 ep in buffer) tmp_indices = indices.copy() @@ -431,9 +444,6 @@ def compute_reward_fn(ag, g): bufsize = 15 env = MyGoalEnv(env_size, array_state=False) - def compute_reward_fn(ag, g): - return env.compute_reward_fn(ag, g, {}) - buf = HERReplayBuffer(bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8) buf._index = 
5 # shifted start index buf.future_p = 1 @@ -454,11 +464,11 @@ def compute_reward_fn(ag, g): ) buf.add(batch) obs = obs_next - batch, indices = buf.sample(0) - assert np.all(buf[:5].obs.desired_goal == buf[0].obs.desired_goal) - assert np.all(buf[5:10].obs.desired_goal == buf[5].obs.desired_goal) - assert np.all(buf[10:].obs.desired_goal == buf[0].obs.desired_goal) # (same ep) - assert np.all(buf[0].obs.desired_goal != buf[5].obs.desired_goal) # (diff ep) + batch_sample, indices = buf.sample(0) + assert np.all(buf.obs.desired_goal[:5] == buf.obs.desired_goal[0]) + assert np.all(buf.obs.desired_goal[5:10] == buf.obs.desired_goal[5]) + assert np.all(buf.obs.desired_goal[10:] == buf.obs.desired_goal[0]) # (same ep) + assert np.all(buf.obs.desired_goal[0] != buf.obs.desired_goal[5]) # (diff ep) # Another test case for cycled indices env_size = 99 @@ -508,8 +518,8 @@ def test_update() -> None: assert len(buf1) > len(buf2) buf2.update(buf1) assert len(buf1) == len(buf2) - assert (buf2[0].obs == buf1[1].obs).all() - assert (buf2[-1].obs == buf1[0].obs).all() + assert (buf2.obs[0] == buf1.obs[1]).all() + assert (buf2.obs[-1] == buf1.obs[0]).all() b = CachedReplayBuffer(ReplayBuffer(10), 4, 5) with pytest.raises(NotImplementedError): b.update(b) @@ -524,11 +534,11 @@ def test_segtree() -> None: assert np.all([tree[i] == 0.0 for i in range(actual_len)]) with pytest.raises(IndexError): tree[actual_len] - naive = np.zeros([actual_len]) + naive = np.zeros(actual_len) for _ in range(1000): # random choose a place to perform single update - index = np.random.randint(actual_len) - value = np.random.rand() + index: int | np.ndarray = np.random.randint(actual_len) + value: float | np.ndarray = np.random.rand() naive[index] = value tree[index] = value for i in range(actual_len): @@ -605,10 +615,10 @@ def test_segtree() -> None: tree = SegmentTree(size) tree[np.arange(size)] = naive - def sample_npbuf(): + def sample_npbuf() -> np.ndarray: return np.random.choice(size, bsz, p=naive / naive.sum()) - def sample_tree(): + def sample_tree() -> int | np.ndarray: scalar = np.random.rand(bsz) * tree.reduce() return tree.get_prefix_sum_idx(scalar) @@ -699,19 +709,19 @@ def test_hdf5() -> None: assert isinstance(buffers[k].get(0, "info"), Batch) assert isinstance(_buffers[k].get(0, "info"), Batch) for k in ["array"]: - assert np.all(buffers[k][:].info.number.n == _buffers[k][:].info.number.n) - assert np.all(buffers[k][:].info.extra == _buffers[k][:].info.extra) + assert np.all(buffers[k][:]["info"].number.n == _buffers[k][:]["info"].number.n) + assert np.all(buffers[k][:]["info"]["extra"] == _buffers[k][:]["info"]["extra"]) # raise exception when value cannot be pickled data = {"not_supported": lambda x: x * x} grp = h5py.Group with pytest.raises(NotImplementedError): - to_hdf5(data, grp) + to_hdf5(data, grp) # type: ignore # ndarray with data type not supported by HDF5 that cannot be pickled data = {"not_supported": np.array(lambda x: x * x)} grp = h5py.Group with pytest.raises(RuntimeError): - to_hdf5(data, grp) + to_hdf5(data, grp) # type: ignore def test_replaybuffermanager() -> None: @@ -860,7 +870,7 @@ def test_replaybuffermanager() -> None: assert np.all(ptr == [10]) assert np.all(ep_idx == [13]) assert np.allclose(buf.unfinished_index(), [4]) - indices = sorted(buf.sample_indices(0)) + indices = np.array(sorted(buf.sample_indices(0))) assert np.allclose(indices, np.arange(len(buf))) assert np.allclose( buf.prev(indices), @@ -913,8 +923,8 @@ def test_replaybuffermanager() -> None: ], ) # corner case: list, int 
and -1 - assert buf.prev(-1) == buf.prev([buf.maxsize - 1])[0] - assert buf.next(-1) == buf.next([buf.maxsize - 1])[0] + assert buf.prev(-1) == buf.prev(np.array([buf.maxsize - 1]))[0] + assert buf.next(-1) == buf.next(np.array([buf.maxsize - 1]))[0] batch = buf._meta batch.info = np.ones(buf.maxsize) buf.set_batch(batch) @@ -1131,10 +1141,12 @@ def test_multibuf_stack() -> None: ], ), buf4.done assert np.allclose(buf4.unfinished_index(), [10, 15, 20]) - indices = sorted(buf4.sample_indices(0)) + indices = np.array(sorted(buf4.sample_indices(0))) assert np.allclose(indices, [*list(range(bufsize)), 9, 10, 14, 15, 19, 20]) + cur_obs = buf4[indices].obs + assert isinstance(cur_obs, np.ndarray) assert np.allclose( - buf4[indices].obs[..., 0], + cur_obs[..., 0], [ [11, 11, 11, 12], [11, 11, 12, 13], @@ -1153,8 +1165,10 @@ def test_multibuf_stack() -> None: [11, 11, 11, 12], ], ) + next_obs = buf4[indices].obs_next + assert isinstance(next_obs, np.ndarray) assert np.allclose( - buf4[indices].obs_next[..., 0], + next_obs[..., 0], [ [11, 11, 12, 13], [11, 12, 13, 14], @@ -1182,7 +1196,7 @@ def test_multibuf_stack() -> None: buf.stack_num = 2 indices = buf5.sample_indices(0) assert np.allclose(sorted(indices), [0, 1, 2, 5, 6, 7, 10, 15, 20]) - batch, _ = buf5.sample(0) + batch_sample, _ = buf5.sample(0) # test Atari with CachedReplayBuffer, save_only_last_obs + ignore_obs_next buf6 = CachedReplayBuffer( ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True, ignore_obs_next=True), @@ -1285,7 +1299,7 @@ def test_multibuf_hdf5() -> None: def test_from_data() -> None: - obs_data = np.ndarray((10, 3, 3), dtype="uint8") + obs_data: npt.NDArray[np.uint8] = np.ndarray((10, 3, 3), dtype="uint8") for i in range(10): obs_data[i] = i * np.ones((3, 3), dtype="uint8") obs_next_data = np.zeros_like(obs_data) @@ -1303,11 +1317,15 @@ def test_from_data() -> None: buf = ReplayBuffer.from_data(obs, act, rew, terminated, truncated, done, obs_next) assert len(buf) == 10 batch = buf[3] - assert np.array_equal(batch.obs, 3 * np.ones((3, 3), dtype="uint8")) + cur_obs = batch.obs + assert isinstance(cur_obs, np.ndarray) + assert np.array_equal(cur_obs, 3 * np.ones((3, 3), dtype="uint8")) assert batch.act == 3 assert batch.rew == 3.0 assert not batch.done - assert np.array_equal(batch.obs_next, 4 * np.ones((3, 3), dtype="uint8")) + next_obs = batch.obs_next + assert isinstance(next_obs, np.ndarray) + assert np.array_equal(next_obs, 4 * np.ones((3, 3), dtype="uint8")) os.remove(path) diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 6bc1703f6..6baa6abf3 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -1,3 +1,6 @@ +from collections.abc import Callable, Sequence +from typing import Any + import gymnasium as gym import numpy as np import pytest @@ -12,8 +15,10 @@ ReplayBuffer, VectorReplayBuffer, ) +from tianshou.data.batch import BatchProtocol +from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol from tianshou.env import DummyVectorEnv, SubprocVectorEnv -from tianshou.policy import BasePolicy +from tianshou.policy import BasePolicy, TrainingStats try: import envpool @@ -30,9 +35,9 @@ class MaxActionPolicy(BasePolicy): def __init__( self, action_space: gym.spaces.Space | None = None, - dict_state=False, - need_state=True, - action_shape=None, + dict_state: bool = False, + need_state: bool = True, + action_shape: Sequence[int] | int | None = None, ) -> None: """Mock policy for testing, will always return an array of ones of the shape of the 
action space. Note that this doesn't make much sense for discrete action space (the output is then intepreted as @@ -48,20 +53,32 @@ def __init__( self.need_state = need_state self.action_shape = action_shape - def forward(self, batch, state=None): + def forward( + self, + batch: ObsBatchProtocol, + state: dict | BatchProtocol | np.ndarray | None = None, + **kwargs: Any, + ) -> Batch: if self.need_state: if state is None: state = np.zeros((len(batch.obs), 2)) - else: - state += 1 + elif isinstance(state, np.ndarray | BatchProtocol): + state += np.int_(1) + elif isinstance(state, dict) and state.get("hidden") is not None: + state["hidden"] += np.int_(1) if self.dict_state: - action_shape = self.action_shape if self.action_shape else len(batch.obs["index"]) + if self.action_shape: + action_shape = self.action_shape + elif isinstance(batch.obs, BatchProtocol): + action_shape = len(batch.obs["index"]) + else: + action_shape = len(batch.obs) return Batch(act=np.ones(action_shape), state=state) action_shape = self.action_shape if self.action_shape else len(batch.obs) return Batch(act=np.ones(action_shape), state=state) - def learn(self): - pass + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TrainingStats: + raise NotImplementedError def test_collector() -> None: @@ -90,7 +107,9 @@ def test_collector() -> None: # Making one more step results in obs_next=1 # The final 0 in the buffer.obs is because the buffer is initialized with zeros and the direct attr access assert np.allclose(c_single_env.buffer.obs[:4, 0], [0, 1, 0, 0]) - assert np.allclose(c_single_env.buffer[:].obs_next[..., 0], [1, 2, 1]) + obs_next = c_single_env.buffer[:].obs_next[..., 0] + assert isinstance(obs_next, np.ndarray) + assert np.allclose(obs_next, [1, 2, 1]) keys = np.zeros(100) keys[:3] = 1 assert np.allclose(c_single_env.buffer.info["key"], keys) @@ -110,7 +129,9 @@ def test_collector() -> None: c_single_env.collect(n_episode=3) assert len(c_single_env.buffer) == 8 assert np.allclose(c_single_env.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0]) - assert np.allclose(c_single_env.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2]) + obs_next = c_single_env.buffer[:].obs_next[..., 0] + assert isinstance(obs_next, np.ndarray) + assert np.allclose(obs_next, [1, 2, 1, 2, 1, 2, 1, 2]) assert np.allclose(c_single_env.buffer.info["key"][:8], 1) for e in c_single_env.buffer.info["env"][:8]: assert isinstance(e, MoveToRightEnv) @@ -131,7 +152,9 @@ def test_collector() -> None: valid_indices = [0, 1, 25, 26, 50, 51, 75, 76] obs[valid_indices] = [0, 1, 0, 1, 0, 1, 0, 1] assert np.allclose(c_subproc_venv_4_envs.buffer.obs[:, 0], obs) - assert np.allclose(c_subproc_venv_4_envs.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2]) + obs_next = c_subproc_venv_4_envs.buffer[:].obs_next[..., 0] + assert isinstance(obs_next, np.ndarray) + assert np.allclose(obs_next, [1, 2, 1, 2, 1, 2, 1, 2]) keys = np.zeros(100) keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1] assert np.allclose(c_subproc_venv_4_envs.buffer.info["key"], keys) @@ -153,8 +176,10 @@ def test_collector() -> None: valid_indices = [2, 3, 27, 52, 53, 77, 78, 79] obs[valid_indices] = [0, 1, 2, 2, 3, 2, 3, 4] assert np.allclose(c_subproc_venv_4_envs.buffer.obs[:, 0], obs) + obs_next = c_subproc_venv_4_envs.buffer[:].obs_next[..., 0] + assert isinstance(obs_next, np.ndarray) assert np.allclose( - c_subproc_venv_4_envs.buffer[:].obs_next[..., 0], + obs_next, [1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5], ) keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 
1] @@ -204,9 +229,12 @@ def test_collector() -> None: with pytest.raises(TypeError): c_dummy_venv_4_envs.collect() + def get_env_factory(i: int, t: str) -> Callable[[], NXEnv]: + return lambda: NXEnv(i, t) + # test NXEnv for obs_type in ["array", "object"]: - envs = SubprocVectorEnv([lambda i=x, t=obs_type: NXEnv(i, t) for x in [5, 10, 15, 20]]) + envs = SubprocVectorEnv([get_env_factory(i=i, t=obs_type) for i in [5, 10, 15, 20]]) c_suproc_new = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4)) c_suproc_new.reset() c_suproc_new.collect(n_step=6) @@ -214,46 +242,55 @@ def test_collector() -> None: @pytest.fixture() -def get_AsyncCollector(): +def async_collector_and_env_lens() -> tuple[AsyncCollector, list[int]]: env_lens = [2, 3, 4, 5] env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0.001, random_sleep=True) for i in env_lens] venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1) policy = MaxActionPolicy() bufsize = 60 - c1 = AsyncCollector( + async_collector = AsyncCollector( policy, venv, VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4), ) - c1.reset() - return c1, env_lens + async_collector.reset() + return async_collector, env_lens class TestAsyncCollector: - def test_collect_without_argument_gives_error(self, get_AsyncCollector): - c1, env_lens = get_AsyncCollector + def test_collect_without_argument_gives_error( + self, + async_collector_and_env_lens: tuple[AsyncCollector, list[int]], + ) -> None: + c1, env_lens = async_collector_and_env_lens with pytest.raises(TypeError): c1.collect() - def test_collect_one_episode_async(self, get_AsyncCollector): - c1, env_lens = get_AsyncCollector + def test_collect_one_episode_async( + self, + async_collector_and_env_lens: tuple[AsyncCollector, list[int]], + ) -> None: + c1, env_lens = async_collector_and_env_lens result = c1.collect(n_episode=1) assert result.n_collected_episodes >= 1 def test_enough_episodes_two_collection_cycles_n_episode_without_reset( self, - get_AsyncCollector, - ): - c1, env_lens = get_AsyncCollector + async_collector_and_env_lens: tuple[AsyncCollector, list[int]], + ) -> None: + c1, env_lens = async_collector_and_env_lens n_episode = 2 result_c1 = c1.collect(n_episode=n_episode, reset_before_collect=False) assert result_c1.n_collected_episodes >= n_episode result_c2 = c1.collect(n_episode=n_episode, reset_before_collect=False) assert result_c2.n_collected_episodes >= n_episode - def test_enough_episodes_two_collection_cycles_n_episode_with_reset(self, get_AsyncCollector): - c1, env_lens = get_AsyncCollector + def test_enough_episodes_two_collection_cycles_n_episode_with_reset( + self, + async_collector_and_env_lens: tuple[AsyncCollector, list[int]], + ) -> None: + c1, env_lens = async_collector_and_env_lens n_episode = 2 result_c1 = c1.collect(n_episode=n_episode, reset_before_collect=True) assert result_c1.n_collected_episodes >= n_episode @@ -262,9 +299,9 @@ def test_enough_episodes_two_collection_cycles_n_episode_with_reset(self, get_As def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collection_cycles_n_episode( self, - get_AsyncCollector, - ): - c1, env_lens = get_AsyncCollector + async_collector_and_env_lens: tuple[AsyncCollector, list[int]], + ) -> None: + c1, env_lens = async_collector_and_env_lens ptr = [0, 0, 0, 0] bufsize = 60 for n_episode in tqdm.trange(1, 30, desc="test async n_episode"): @@ -284,9 +321,9 @@ def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collecti def 
test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collection_cycles_n_step( self, - get_AsyncCollector, - ): - c1, env_lens = get_AsyncCollector + async_collector_and_env_lens: tuple[AsyncCollector, list[int]], + ) -> None: + c1, env_lens = async_collector_and_env_lens bufsize = 60 ptr = [0, 0, 0, 0] for n_step in tqdm.trange(1, 15, desc="test async n_step"): @@ -303,17 +340,15 @@ def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collecti assert np.all(buf.obs[indices].reshape(count, env_len) == seq) assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1) - @pytest.mark.parametrize("gym_reset_kwargs", [None, {}]) def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collection_cycles_first_n_episode_then_n_step( self, - get_AsyncCollector, - gym_reset_kwargs, - ): - c1, env_lens = get_AsyncCollector + async_collector_and_env_lens: tuple[AsyncCollector, list[int]], + ) -> None: + c1, env_lens = async_collector_and_env_lens bufsize = 60 ptr = [0, 0, 0, 0] for n_episode in tqdm.trange(1, 30, desc="test async n_episode"): - result = c1.collect(n_episode=n_episode, gym_reset_kwargs=gym_reset_kwargs) + result = c1.collect(n_episode=n_episode) assert result.n_collected_episodes >= n_episode # check buffer data, obs and obs_next, env_id for i, count in enumerate(np.bincount(result.lens, minlength=6)[2:]): @@ -328,7 +363,7 @@ def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collecti assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1) # test async n_step, for now the buffer should be full of data, thus no bincount stuff as above for n_step in tqdm.trange(1, 15, desc="test async n_step"): - result = c1.collect(n_step=n_step, gym_reset_kwargs=gym_reset_kwargs) + result = c1.collect(n_step=n_step) assert result.n_collected_steps >= n_step for i in range(4): env_len = i + 2 @@ -371,9 +406,11 @@ def test_collector_with_dict_state() -> None: batch, _ = c1.buffer.sample(10) c0.buffer.update(c1.buffer) assert len(c0.buffer) in [42, 43] + cur_obs = c0.buffer[:].obs + assert isinstance(cur_obs, Batch) if len(c0.buffer) == 42: assert np.all( - c0.buffer[:].obs.index[..., 0] + cur_obs.index[..., 0] == [ 0, 1, @@ -418,10 +455,10 @@ def test_collector_with_dict_state() -> None: 3, 4, ], - ), c0.buffer[:].obs.index[..., 0] + ), cur_obs.index[..., 0] else: assert np.all( - c0.buffer[:].obs.index[..., 0] + cur_obs.index[..., 0] == [ 0, 1, @@ -467,7 +504,7 @@ def test_collector_with_dict_state() -> None: 3, 4, ], - ), c0.buffer[:].obs.index[..., 0] + ), cur_obs.index[..., 0] c2 = Collector( policy, envs, @@ -512,96 +549,100 @@ def test_collector_with_multi_agent() -> None: c_single_env.buffer.update(c_multi_env_ma.buffer) assert len(c_single_env.buffer) in [42, 43] if len(c_single_env.buffer) == 42: - multi_env_returns = [ - 0, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 1, - 0, - 1, - 0, - 1, - 0, - 1, - 0, - 1, - 0, - 0, - 1, - 0, - 0, - 1, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 1, - ] + multi_env_returns = np.array( + [ + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + ], + ) else: - multi_env_returns = [ - 0, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 1, - 0, - 1, - 0, - 1, - 0, - 1, - 0, - 0, - 1, - 0, - 0, - 1, - 0, - 0, - 1, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 1, - 
0, - 0, - 0, - 0, - 1, - ] + multi_env_returns = np.array( + [ + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 1, + ], + ) assert np.all(c_single_env.buffer[:].rew == [[x] * 4 for x in multi_env_returns]) assert np.all(c_single_env.buffer[:].done == multi_env_returns) c2 = Collector( @@ -656,7 +697,9 @@ def test_collector_with_atari_setting() -> None: obs = np.zeros_like(c2.buffer.obs) obs[np.arange(8)] = reference_obs[[0, 1, 2, 3, 4, 0, 1, 2], -1] assert np.all(c2.buffer.obs == obs) - assert np.allclose(c2.buffer[:].obs_next, reference_obs[[1, 2, 3, 4, 4, 1, 2, 2], -1]) + obs_next = c2.buffer[:].obs_next + assert isinstance(obs_next, np.ndarray) + assert np.allclose(obs_next, reference_obs[[1, 2, 3, 4, 4, 1, 2, 2], -1]) # atari multi buffer env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0, array_state=True) for i in [2, 3, 4, 5]] @@ -881,7 +924,7 @@ def test_collector_envpool_gym_reset_return_info() -> None: assert np.allclose(c0.buffer.info["env_id"], env_ids) -def test_collector_with_vector_env(): +def test_collector_with_vector_env() -> None: env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in [1, 8, 9, 10]] dum = DummyVectorEnv(env_fns) @@ -905,7 +948,7 @@ def test_collector_with_vector_env(): assert np.array_equal(np.array([1, 1, 1, 8, 1, 9, 1, 10]), c4r.lens) -def test_async_collector_with_vector_env(): +def test_async_collector_with_vector_env() -> None: env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in [1, 8, 9, 10]] dum = DummyVectorEnv(env_fns) diff --git a/test/base/test_env.py b/test/base/test_env.py index f1571ca8a..a476ec5a9 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -1,5 +1,7 @@ import sys import time +from collections.abc import Callable +from typing import Any, Literal import gymnasium as gym import numpy as np @@ -17,6 +19,7 @@ VectorEnvNormObs, ) from tianshou.env.gym_wrappers import TruncatedAsTerminated +from tianshou.env.venvs import BaseVectorEnv from tianshou.utils import RunningMeanStd if __name__ == "__main__": @@ -30,7 +33,7 @@ envpool = None -def has_ray(): +def has_ray() -> bool: try: import ray # noqa: F401 @@ -39,7 +42,7 @@ def has_ray(): return False -def recurse_comp(a, b): +def recurse_comp(a: np.ndarray | list | tuple | dict, b: Any) -> np.bool_ | bool | None: try: if isinstance(a, np.ndarray): if a.dtype == object: @@ -53,7 +56,7 @@ def recurse_comp(a, b): return False -def test_async_env(size=10000, num=8, sleep=0.1) -> None: +def test_async_env(size: int = 10000, num: int = 8, sleep: float = 0.1) -> None: # simplify the test case, just keep stepping env_fns = [ lambda i=i: MoveToRightEnv(size=i, sleep=sleep, random_sleep=True) @@ -106,7 +109,12 @@ def test_async_env(size=10000, num=8, sleep=0.1) -> None: assert spent_time < 6.0 * sleep * num / (num + 1) -def test_async_check_id(size=100, num=4, sleep=0.2, timeout=0.7) -> None: +def test_async_check_id( + size: int = 100, + num: int = 4, + sleep: float = 0.2, + timeout: float = 0.7, +) -> None: env_fns = [ lambda: MoveToRightEnv(size=size, sleep=sleep * 2), lambda: MoveToRightEnv(size=size, sleep=sleep * 3), @@ -154,7 +162,7 @@ def test_async_check_id(size=100, num=4, sleep=0.2, timeout=0.7) -> None: assert total_pass >= 2 -def test_vecenv(size=10, num=8, sleep=0.001) -> None: +def test_vecenv(size: int = 10, num: int = 8, sleep: float = 0.001) -> None: env_fns = [ lambda i=i: 
MoveToRightEnv(size=i, sleep=sleep, recurse_state=True) for i in range(size, size + num) @@ -169,11 +177,10 @@ def test_vecenv(size=10, num=8, sleep=0.001) -> None: for v in venv: v.seed(0) action_list = [1] * 5 + [0] * 10 + [1] * 20 - o = [v.reset()[0] for v in venv] for a in action_list: o = [] for v in venv: - A, B, C, D, E = v.step([a] * num) + A, B, C, D, E = v.step(np.array([a] * num)) if sum(C + D): A, _ = v.reset(np.where(C + D)[0]) o.append([A, B, C, D, E]) @@ -184,19 +191,19 @@ def test_vecenv(size=10, num=8, sleep=0.001) -> None: assert recurse_comp(infos[0], info) if __name__ == "__main__": - t = [0] * len(venv) + t = [0.0] * len(venv) for i, e in enumerate(venv): t[i] = time.time() e.reset() for a in action_list: - done = e.step([a] * num)[2] + done = e.step(np.array([a] * num))[2] if sum(done) > 0: e.reset(np.where(done)[0]) t[i] = time.time() - t[i] for i, v in enumerate(venv): print(f"{type(v)}: {t[i]:.6f}s") - def assert_get(v, expected): + def assert_get(v: BaseVectorEnv, expected: list) -> None: assert v.get_env_attr("size") == expected assert v.get_env_attr("size", id=0) == [expected[0]] assert v.get_env_attr("size", id=[0, 1, 2]) == expected[:3] @@ -223,20 +230,24 @@ def test_attr_unwrapped() -> None: train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1")]) train_envs.set_env_attr("test_attribute", 1337) assert train_envs.get_env_attr("test_attribute") == [1337] - assert hasattr(train_envs.workers[0].env, "test_attribute") - assert hasattr(train_envs.workers[0].env.unwrapped, "test_attribute") + # mypy doesn't know but BaseVectorEnv takes the reserved keys in gym.Env (one of which is env) + assert hasattr(train_envs.workers[0].env, "test_attribute") # type: ignore + assert hasattr(train_envs.workers[0].env.unwrapped, "test_attribute") # type: ignore def test_env_obs_dtype() -> None: + def create_env(i: int, t: str) -> Callable[[], NXEnv]: + return lambda: NXEnv(i, t) + for obs_type in ["array", "object"]: - envs = SubprocVectorEnv([lambda i=x, t=obs_type: NXEnv(i, t) for x in [5, 10, 15, 20]]) + envs = SubprocVectorEnv([create_env(x, obs_type) for x in [5, 10, 15, 20]]) obs, info = envs.reset() assert obs.dtype == object - obs = envs.step([1, 1, 1, 1])[0] + obs = envs.step(np.array([1, 1, 1, 1]))[0] assert obs.dtype == object -def test_env_reset_optional_kwargs(size=10000, num=8) -> None: +def test_env_reset_optional_kwargs(size: int = 10000, num: int = 8) -> None: env_fns = [lambda i=i: MoveToRightEnv(size=i) for i in range(size, size + num)] test_cls = [DummyVectorEnv, SubprocVectorEnv, ShmemVectorEnv] if has_ray(): @@ -262,20 +273,25 @@ def test_venv_wrapper_gym(num_envs: int = 4) -> None: assert obs.shape[0] == len(info) == num_envs -def run_align_norm_obs(raw_env, train_env, test_env, action_list): - def reset_result_to_obs(reset_result): +def run_align_norm_obs( + raw_env: DummyVectorEnv, + train_env: VectorEnvNormObs, + test_env: VectorEnvNormObs, + action_list: list[np.ndarray], +) -> None: + def reset_result_to_obs(reset_result: tuple[np.ndarray, dict | list[dict]]) -> np.ndarray: """Extract observation from reset result (result is possibly a tuple containing info).""" if isinstance(reset_result, tuple) and len(reset_result) == 2: obs, _ = reset_result else: - obs = reset_result + obs = reset_result # type: ignore return obs eps = np.finfo(np.float32).eps.item() raw_reset_result = raw_env.reset() train_reset_result = train_env.reset() - initial_raw_obs = reset_result_to_obs(raw_reset_result) - initial_train_obs = reset_result_to_obs(train_reset_result) + 
initial_raw_obs = reset_result_to_obs(raw_reset_result) # type: ignore + initial_train_obs = reset_result_to_obs(train_reset_result) # type: ignore raw_obs, train_obs = [initial_raw_obs], [initial_train_obs] for action in action_list: step_result = raw_env.step(action) @@ -283,22 +299,22 @@ def reset_result_to_obs(reset_result): obs, rew, terminated, truncated, info = step_result done = np.logical_or(terminated, truncated) else: - obs, rew, done, info = step_result + obs, rew, done, info = step_result # type: ignore raw_obs.append(obs) if np.any(done): reset_result = raw_env.reset(np.where(done)[0]) - obs = reset_result_to_obs(reset_result) + obs = reset_result_to_obs(reset_result) # type: ignore raw_obs.append(obs) step_result = train_env.step(action) if len(step_result) == 5: obs, rew, terminated, truncated, info = step_result done = np.logical_or(terminated, truncated) else: - obs, rew, done, info = step_result + obs, rew, done, info = step_result # type: ignore train_obs.append(obs) if np.any(done): reset_result = train_env.reset(np.where(done)[0]) - obs = reset_result_to_obs(reset_result) + obs = reset_result_to_obs(reset_result) # type: ignore train_obs.append(obs) ref_rms = RunningMeanStd() for ro, to in zip(raw_obs, train_obs, strict=True): @@ -310,7 +326,7 @@ def reset_result_to_obs(reset_result): assert np.allclose(ref_rms.mean, test_env.get_obs_rms().mean) assert np.allclose(ref_rms.var, test_env.get_obs_rms().var) reset_result = test_env.reset() - obs = reset_result_to_obs(reset_result) + obs = reset_result_to_obs(reset_result) # type: ignore test_obs = [obs] for action in action_list: step_result = test_env.step(action) @@ -318,11 +334,11 @@ def reset_result_to_obs(reset_result): obs, rew, terminated, truncated, info = step_result done = np.logical_or(terminated, truncated) else: - obs, rew, done, info = step_result + obs, rew, done, info = step_result # type: ignore test_obs.append(obs) if np.any(done): reset_result = test_env.reset(np.where(done)[0]) - obs = reset_result_to_obs(reset_result) + obs = reset_result_to_obs(reset_result) # type: ignore test_obs.append(obs) for ro, to in zip(raw_obs, test_obs, strict=True): no = (ro - ref_rms.mean) / np.sqrt(ref_rms.var + eps) @@ -349,16 +365,18 @@ def __init__(self) -> None: self.action_space = gym.spaces.Box(low=-1.0, high=2.0, shape=(4,), dtype=np.float32) self.observation_space = gym.spaces.Discrete(2) - def step(self, act): + def step(self, act: Any) -> tuple[Any, Literal[-1], Literal[False], Literal[True], dict]: return self.observation_space.sample(), -1, False, True, {} bsz = 10 action_per_branch = [4, 6, 10, 7] env = DummyEnv() + assert isinstance(env.action_space, gym.spaces.Box) original_act = env.action_space.high # convert continous to multidiscrete action space # with different action number per dimension env_m = ContinuousToDiscrete(env, action_per_branch) + assert isinstance(env_m.action_space, gym.spaces.MultiDiscrete) # check conversion is working properly for one action np.testing.assert_allclose(env_m.action(env_m.action_space.nvec - 1), original_act) # check conversion is working properly for a batch of actions @@ -369,8 +387,12 @@ def step(self, act): # convert multidiscrete with different action number per # dimension to discrete action space env_d = MultiDiscreteToDiscrete(env_m) + assert isinstance(env_d.action_space, gym.spaces.Discrete) # check conversion is working properly for one action - np.testing.assert_allclose(env_d.action(env_d.action_space.n - 1), env_m.action_space.nvec - 1) + 
np.testing.assert_allclose( + env_d.action(np.array(env_d.action_space.n - 1)), + env_m.action_space.nvec - 1, + ) # check conversion is working properly for a batch of actions np.testing.assert_allclose( env_d.action(np.array([env_d.action_space.n - 1] * bsz)), @@ -386,6 +408,7 @@ def step(self, act): assert truncated +# TODO: old gym envs are no longer supported! Replace by Ant-v4 and fix assoticiated tests @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") def test_venv_wrapper_envpool() -> None: raw = envpool.make_gymnasium("Ant-v3", num_envs=4) @@ -404,9 +427,16 @@ def test_venv_wrapper_envpool_gym_reset_return_info() -> None: ) obs, info = env.reset() assert obs.shape[0] == num_envs - for _, v in info.items(): - if not isinstance(v, dict): - assert v.shape[0] == num_envs + # This is not actually unreachable b/c envpool does not return info in the right format + if isinstance(info, dict): # type: ignore[unreachable] + for _, v in info.items(): # type: ignore[unreachable] + if not isinstance(v, dict): + assert v.shape[0] == num_envs + else: + for _info in info: + for _, v in _info.items(): + if not isinstance(v, dict): + assert v.shape[0] == num_envs if __name__ == "__main__": diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py index 651e77082..657100554 100644 --- a/test/base/test_env_finite.py +++ b/test/base/test_env_finite.py @@ -2,32 +2,37 @@ import copy from collections import Counter +from collections.abc import Callable, Iterator, Sequence +from typing import Any, cast import gymnasium as gym import numpy as np +import torch from gymnasium.spaces import Box from torch.utils.data import DataLoader, Dataset, DistributedSampler from tianshou.data import Batch, Collector +from tianshou.data.types import BatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.env import BaseVectorEnv, DummyVectorEnv, SubprocVectorEnv +from tianshou.env.utils import ENV_TYPE, gym_new_venv_step_type from tianshou.policy import BasePolicy class DummyDataset(Dataset): - def __init__(self, length) -> None: + def __init__(self, length: int) -> None: self.length = length self.episodes = [3 * i % 5 + 1 for i in range(self.length)] - def __getitem__(self, index): + def __getitem__(self, index: int) -> tuple[int, int]: assert 0 <= index < self.length return index, self.episodes[index] - def __len__(self): + def __len__(self) -> int: return self.length class FiniteEnv(gym.Env): - def __init__(self, dataset, num_replicas, rank) -> None: + def __init__(self, dataset: Dataset, num_replicas: int | None, rank: int | None) -> None: self.dataset = dataset self.num_replicas = num_replicas self.rank = rank @@ -36,9 +41,14 @@ def __init__(self, dataset, num_replicas, rank) -> None: sampler=DistributedSampler(dataset, num_replicas, rank), batch_size=None, ) - self.iterator = None - - def reset(self): + self.iterator: Iterator | None = None + + def reset( + self, + *, + seed: int | None = None, + options: dict[str, Any] | None = None, + ) -> tuple[Any, dict[str, Any]]: if self.iterator is None: self.iterator = iter(self.loader) try: @@ -49,7 +59,7 @@ def reset(self): self.iterator = None return None, {} - def step(self, action): + def step(self, action: int) -> tuple[int, float, bool, bool, dict[str, Any]]: self.current_step += 1 assert self.current_step <= self.step_count return ( @@ -62,58 +72,64 @@ def step(self, action): class FiniteVectorEnv(BaseVectorEnv): - def __init__(self, env_fns, **kwargs) -> None: + def __init__(self, env_fns: 
Sequence[Callable[[], ENV_TYPE]], **kwargs: Any) -> None: super().__init__(env_fns, **kwargs) - self._alive_env_ids = set() + self._alive_env_ids: set[int] = set() self._reset_alive_envs() - self._default_obs = self._default_info = None + self._default_obs: np.ndarray | None = None + self._default_info: dict | None = None + self.tracker: MetricTracker - def _reset_alive_envs(self): + def _reset_alive_envs(self) -> None: if not self._alive_env_ids: # starting or running out self._alive_env_ids = set(range(self.env_num)) # to workaround with tianshou's buffer and batch - def _set_default_obs(self, obs): + def _set_default_obs(self, obs: np.ndarray) -> None: if obs is not None and self._default_obs is None: self._default_obs = copy.deepcopy(obs) - def _set_default_info(self, info): + def _set_default_info(self, info: dict) -> None: if info is not None and self._default_info is None: self._default_info = copy.deepcopy(info) - def _get_default_obs(self): + def _get_default_obs(self) -> np.ndarray | None: return copy.deepcopy(self._default_obs) - def _get_default_info(self): + def _get_default_info(self) -> dict | None: return copy.deepcopy(self._default_info) # END - def reset(self, env_id=None): + def reset( + self, + env_id: int | list[int] | np.ndarray | None = None, + **kwargs: Any, + ) -> tuple[np.ndarray, np.ndarray]: env_id = self._wrap_id(env_id) self._reset_alive_envs() # ask super to reset alive envs and remap to current index request_id = list(filter(lambda i: i in self._alive_env_ids, env_id)) - obs = [None] * len(env_id) - infos = [None] * len(env_id) + obs_list: list[np.ndarray | None] = [None] * len(env_id) + infos: list[dict | None] = [None] * len(env_id) id2idx = {i: k for k, i in enumerate(env_id)} if request_id: for k, o, info in zip(request_id, *super().reset(request_id), strict=True): - obs[id2idx[k]] = o + obs_list[id2idx[k]] = o infos[id2idx[k]] = info - for i, o in zip(env_id, obs, strict=True): + for i, o in zip(env_id, obs_list, strict=True): if o is None and i in self._alive_env_ids: self._alive_env_ids.remove(i) # fill empty observation with default(fake) observation - for o in obs: + for o in obs_list: self._set_default_obs(o) - for i in range(len(obs)): - if obs[i] is None: - obs[i] = self._get_default_obs() + for i in range(len(obs_list)): + if obs_list[i] is None: + obs_list[i] = self._get_default_obs() if infos[i] is None: infos[i] = self._get_default_info() @@ -121,26 +137,34 @@ def reset(self, env_id=None): self.reset() raise StopIteration - return np.stack(obs), np.array(infos) + obs_list = cast(list[np.ndarray], obs_list) + infos = cast(list[dict], infos) + + return np.stack(obs_list), np.array(infos) - def step(self, action, id=None): - id = self._wrap_id(id) - id2idx = {i: k for k, i in enumerate(id)} - request_id = list(filter(lambda i: i in self._alive_env_ids, id)) - result = [[None, 0.0, False, False, None] for _ in range(len(id))] + def step( + self, + action: np.ndarray | torch.Tensor | None, + id: int | list[int] | np.ndarray | None = None, + ) -> gym_new_venv_step_type: + ids: list[int] | np.ndarray = self._wrap_id(id) + id2idx = {i: k for k, i in enumerate(ids)} + request_id = list(filter(lambda i: i in self._alive_env_ids, ids)) + result: list[list] = [[None, 0.0, False, False, None] for _ in range(len(ids))] # ask super to step alive envs and remap to current index + assert action is not None if request_id: valid_act = np.stack([action[id2idx[i]] for i in request_id]) - for i, r in zip( + for i, (r_obs, r_reward, r_term, r_trunc, r_info) in zip( 
request_id, zip(*super().step(valid_act, request_id), strict=True), strict=True, ): - result[id2idx[i]] = r + result[id2idx[i]] = [r_obs, r_reward, r_term, r_trunc, r_info] # logging - for i, r in zip(id, result, strict=True): + for i, r in zip(ids, result, strict=True): if i in self._alive_env_ids: self.tracker.log(*r) @@ -153,7 +177,18 @@ def step(self, action, id=None): if result[i][-1] is None: result[i][-1] = self._get_default_info() - return list(map(np.stack, zip(*result, strict=True))) + obs_list, rew_list, term_list, trunc_list, info_list = zip(*result, strict=True) + try: + obs_stack = np.stack(obs_list) + except ValueError: # different len(obs) + obs_stack = np.array(obs_list, dtype=object) + return ( + obs_stack, + np.stack(rew_list), + np.stack(term_list), + np.stack(trunc_list), + np.stack(info_list), + ) class FiniteDummyVectorEnv(FiniteVectorEnv, DummyVectorEnv): @@ -168,23 +203,28 @@ class AnyPolicy(BasePolicy): def __init__(self) -> None: super().__init__(action_space=Box(-1, 1, (1,))) - def forward(self, batch, state=None): + def forward( + self, + batch: ObsBatchProtocol, + state: dict | BatchProtocol | np.ndarray | None = None, + **kwargs: Any, + ) -> Batch: return Batch(act=np.stack([1] * len(batch))) - def learn(self, batch): + def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> None: pass -def _finite_env_factory(dataset, num_replicas, rank): +def _finite_env_factory(dataset: Dataset, num_replicas: int, rank: int) -> Callable[[], FiniteEnv]: return lambda: FiniteEnv(dataset, num_replicas, rank) class MetricTracker: def __init__(self) -> None: - self.counter = Counter() - self.finished = set() + self.counter: Counter = Counter() + self.finished: set[int] = set() - def log(self, obs, rew, terminated, truncated, info): + def log(self, obs: Any, rew: float, terminated: bool, truncated: bool, info: dict) -> None: assert rew == 1.0 done = terminated or truncated index = info["sample"] @@ -193,7 +233,7 @@ def log(self, obs, rew, terminated, truncated, info): self.finished.add(index) self.counter[index] += 1 - def validate(self): + def validate(self) -> None: assert len(self.finished) == 100 for k, v in self.counter.items(): assert v == k * 3 % 5 + 1 diff --git a/test/base/test_policy.py b/test/base/test_policy.py index 0c51f847c..7c3aacc07 100644 --- a/test/base/test_policy.py +++ b/test/base/test_policy.py @@ -4,7 +4,7 @@ import torch from torch.distributions import Categorical, Distribution, Independent, Normal -from tianshou.policy import PPOPolicy +from tianshou.policy import BasePolicy, PPOPolicy from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.continuous import ActorProb, Critic from tianshou.utils.net.discrete import Actor @@ -12,13 +12,15 @@ obs_shape = (5,) -def _to_hashable(x: np.ndarray | int): +def _to_hashable(x: np.ndarray | int) -> int | tuple[list]: return x if isinstance(x, int) else tuple(x.tolist()) @pytest.fixture(params=["continuous", "discrete"]) -def policy(request): +def policy(request: pytest.FixtureRequest) -> PPOPolicy: action_type = request.param + action_space: gym.spaces.Box | gym.spaces.Discrete + actor: Actor | ActorProb if action_type == "continuous": action_space = gym.spaces.Box(low=-1, high=1, shape=(3,)) actor = ActorProb( @@ -36,7 +38,7 @@ def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.n), action_shape=action_space.n, ) - dist_fn = lambda logits: Categorical(logits=logits) + 
dist_fn = Categorical else: raise ValueError(f"Unknown action type: {action_type}") @@ -47,7 +49,8 @@ def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: actor_critic = ActorCritic(actor, critic) optim = torch.optim.Adam(actor_critic.parameters(), lr=1e-3) - policy: PPOPolicy = PPOPolicy( + policy: BasePolicy + policy = PPOPolicy( actor=actor, critic=critic, dist_fn=dist_fn, @@ -60,7 +63,7 @@ def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: class TestPolicyBasics: - def test_get_action(self, policy) -> None: + def test_get_action(self, policy: PPOPolicy) -> None: sample_obs = torch.randn(obs_shape) policy.deterministic_eval = False actions = [policy.compute_action(sample_obs) for _ in range(10)] diff --git a/test/base/test_returns.py b/test/base/test_returns.py index 2dbf47c29..23f50fb22 100644 --- a/test/base/test_returns.py +++ b/test/base/test_returns.py @@ -4,10 +4,11 @@ import torch from tianshou.data import Batch, ReplayBuffer, to_numpy +from tianshou.data.types import BatchWithReturnsProtocol from tianshou.policy import BasePolicy -def compute_episodic_return_base(batch, gamma): +def compute_episodic_return_base(batch: Batch, gamma: float) -> Batch: returns = np.zeros_like(batch.rew) last = 0 for i in reversed(range(len(batch.rew))): @@ -19,7 +20,7 @@ def compute_episodic_return_base(batch, gamma): return batch -def test_episodic_returns(size=2560) -> None: +def test_episodic_returns(size: int = 2560) -> None: fn = BasePolicy.compute_episodic_return buf = ReplayBuffer(20) batch = Batch( @@ -34,7 +35,7 @@ def test_episodic_returns(size=2560) -> None: }, ), ) - for b in batch: + for b in iter(batch): b.obs = b.act = 1 buf.add(b) returns, _ = fn(batch, buf, buf.sample_indices(0), gamma=0.1, gae_lambda=1) @@ -46,7 +47,7 @@ def test_episodic_returns(size=2560) -> None: truncated=np.array([0, 0, 0, 0, 0, 0, 0.0]), rew=np.array([7, 6, 1, 2, 3, 4, 5.0]), ) - for b in batch: + for b in iter(batch): b.obs = b.act = 1 buf.add(b) returns, _ = fn(batch, buf, buf.sample_indices(0), gamma=0.1, gae_lambda=1) @@ -58,7 +59,7 @@ def test_episodic_returns(size=2560) -> None: truncated=np.array([0, 0, 0, 0, 0, 0, 0]), rew=np.array([7, 6, 1, 2, 3, 4, 5.0]), ) - for b in batch: + for b in iter(batch): b.obs = b.act = 1 buf.add(b) returns, _ = fn(batch, buf, buf.sample_indices(0), gamma=0.1, gae_lambda=1) @@ -118,7 +119,7 @@ def test_episodic_returns(size=2560) -> None: }, ), ) - for b in batch: + for b in iter(batch): b.obs = b.act = 1 buf.add(b) v = np.array([2.0, 3.0, 4, -1, 5.0, 6.0, 7, -2, 8.0, 9.0, 10, -3]) @@ -148,15 +149,15 @@ def test_episodic_returns(size=2560) -> None: truncated=np.zeros(size), rew=np.random.random(size), ) - for b in batch: + for b in iter(batch): b.obs = b.act = 1 buf.add(b) indices = buf.sample_indices(0) - def vanilla(): + def vanilla() -> Batch: return compute_episodic_return_base(batch, gamma=0.1) - def optimized(): + def optimized() -> tuple[np.ndarray, np.ndarray]: return fn(batch, buf, indices, gamma=0.1, gae_lambda=1.0) cnt = 3000 @@ -164,17 +165,22 @@ def optimized(): print("GAE optim ", timeit(optimized, setup=optimized, number=cnt)) -def target_q_fn(buffer, indices): +def target_q_fn(buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: # return the next reward indices = buffer.next(indices) return torch.tensor(-buffer.rew[indices], dtype=torch.float32) -def target_q_fn_multidim(buffer, indices): +def target_q_fn_multidim(buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: return target_q_fn(buffer, 
indices).unsqueeze(1).repeat(1, 51) -def compute_nstep_return_base(nstep, gamma, buffer, indices): +def compute_nstep_return_base( + nstep: int, + gamma: float, + buffer: ReplayBuffer, + indices: np.ndarray, +) -> np.ndarray: returns = np.zeros_like(indices, dtype=float) buf_len = len(buffer) for i in range(len(indices)): @@ -195,7 +201,7 @@ def compute_nstep_return_base(nstep, gamma, buffer, indices): return returns -def test_nstep_returns(size=10000) -> None: +def test_nstep_returns(size: int = 10000) -> None: buf = ReplayBuffer(10) for i in range(12): buf.add( @@ -273,7 +279,7 @@ def test_nstep_returns(size=10000) -> None: assert np.allclose(returns_multidim, returns[:, np.newaxis]) -def test_nstep_returns_with_timelimit(size=10000) -> None: +def test_nstep_returns_with_timelimit(size: int = 10000) -> None: buf = ReplayBuffer(10) for i in range(12): buf.add( @@ -366,10 +372,10 @@ def test_nstep_returns_with_timelimit(size=10000) -> None: ) batch, indices = buf.sample(256) - def vanilla(): + def vanilla() -> np.ndarray: return compute_nstep_return_base(3, 0.1, buf, indices) - def optimized(): + def optimized() -> BatchWithReturnsProtocol: return BasePolicy.compute_nstep_return( batch, buf, diff --git a/test/base/test_stats.py b/test/base/test_stats.py index 537519287..9776374ba 100644 --- a/test/base/test_stats.py +++ b/test/base/test_stats.py @@ -13,7 +13,8 @@ class TestStats: @staticmethod def test_training_stats_wrapper() -> None: train_stats = TrainingStats(train_time=1.0) - train_stats.loss_field = 12 + + setattr(train_stats, "loss_field", 12) # noqa: B010 wrapped_train_stats = DummyTrainingStatsWrapper(train_stats, dummy_field=42) @@ -37,4 +38,12 @@ def test_training_stats_wrapper() -> None: # existing fields, wrapped and not-wrapped, can be mutated wrapped_train_stats.loss_field = 13 wrapped_train_stats.dummy_field = 43 + assert hasattr( + wrapped_train_stats.wrapped_stats, + "loss_field", + ), "Attribute `loss_field` not found in `wrapped_train_stats.wrapped_stats`." + assert hasattr( + wrapped_train_stats, + "loss_field", + ), "Attribute `loss_field` not found in `wrapped_train_stats`." 
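+        # `loss_field` was set via setattr, so it is exposed both on the wrapper and on the wrapped stats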
assert wrapped_train_stats.wrapped_stats.loss_field == wrapped_train_stats.loss_field == 13 diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py index e2de17e85..a17c3b513 100644 --- a/test/continuous/test_ddpg.py +++ b/test/continuous/test_ddpg.py @@ -73,14 +73,14 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net, args.action_shape, max_action=args.max_action, device=args.device).to( args.device, ) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index 38ddbe8f0..5a522dedb 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -84,10 +84,10 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = ActorProb(net, args.action_shape, unbounded=True, device=args.device).to(args.device) critic = Critic( - Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device), + Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device), device=args.device, ).to(args.device) actor_critic = ActorCritic(actor, critic) diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index a180e44bb..697b59e98 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -1,7 +1,6 @@ import argparse import os import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -58,7 +57,7 @@ def get_args() -> argparse.Namespace: def test_redq(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - env.action_space = cast(gym.spaces.Box, env.action_space) + assert isinstance(env.action_space, gym.spaces.Box) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape @@ -80,7 +79,7 @@ def test_redq(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = ActorProb( net, args.action_shape, @@ -94,8 +93,8 @@ def linear(x: int, y: int) -> nn.Module: return EnsembleLinear(args.ensemble_size, x, y) net_c = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py index c347f75b0..fd5b15a9f 100644 --- a/test/continuous/test_sac_with_il.py +++ b/test/continuous/test_sac_with_il.py @@ -81,12 +81,12 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed + args.training_num) # model - net = 
Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = ActorProb(net, args.action_shape, device=args.device, unbounded=True).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net_c1 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, @@ -94,8 +94,8 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None: critic1 = Critic(net_c1, device=args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) net_c2 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, @@ -110,7 +110,7 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None: alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr) args.alpha = (target_entropy, log_alpha, alpha_optim) - policy: BasePolicy = SACPolicy( + policy: SACPolicy = SACPolicy( actor=actor, actor_optim=actor_optim, critic=critic1, @@ -176,7 +176,7 @@ def stop_fn(mean_rewards: float) -> bool: device=args.device, ).to(args.device) optim = torch.optim.Adam(il_actor.parameters(), lr=args.il_lr) - il_policy: BasePolicy = ImitationPolicy( + il_policy: ImitationPolicy = ImitationPolicy( actor=il_actor, optim=optim, action_space=env.action_space, diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index fb1e28a83..ea55da052 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -75,14 +75,14 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net, args.action_shape, max_action=args.max_action, device=args.device).to( args.device, ) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net_c1 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, @@ -90,8 +90,8 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: critic1 = Critic(net_c1, device=args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) net_c2 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index 9de81283c..ae788d1cc 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -106,7 +106,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy: BasePolicy = TRPOPolicy( + policy: TRPOPolicy = TRPOPolicy( actor=actor, critic=critic, optim=optim, diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index f51a8d75a..f60857ea4 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -84,12 +84,13 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None: 
np.random.seed(args.seed) torch.manual_seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net, args.action_shape, device=args.device).to(args.device) critic = Critic(net, device=args.device).to(args.device) optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr) dist = torch.distributions.Categorical - policy: A2CPolicy = A2CPolicy( + policy: BasePolicy + policy = A2CPolicy( actor=actor, critic=critic, optim=optim, @@ -153,11 +154,11 @@ def stop_fn(mean_rewards: float) -> bool: # here we define an imitation collector with a trivial policy # if args.task == 'CartPole-v0': # env.spec.reward_threshold = 190 # lower the goal - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - net = Actor(net, args.action_shape, device=args.device).to(args.device) - optim = torch.optim.Adam(net.parameters(), lr=args.il_lr) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + actor = Actor(net, args.action_shape, device=args.device).to(args.device) + optim = torch.optim.Adam(actor.parameters(), lr=args.il_lr) il_policy: ImitationPolicy = ImitationPolicy( - actor=net, + actor=actor, optim=optim, action_space=env.action_space, ) diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py index 295c6b378..c3f6afe3a 100644 --- a/test/discrete/test_bdq.py +++ b/test/discrete/test_bdq.py @@ -52,12 +52,19 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) env = ContinuousToDiscrete(env, args.action_per_branch) - args.state_shape = env.observation_space.shape or env.observation_space.n + if isinstance(env.observation_space, gym.spaces.Box): + args.state_shape = env.observation_space.shape + elif isinstance(env.observation_space, gym.spaces.Discrete): + args.state_shape = int(env.observation_space.n) + assert isinstance(env.action_space, gym.spaces.MultiDiscrete) args.num_branches = env.action_space.shape[0] if args.reward_threshold is None: default_reward_threshold = {"Pendulum-v0": -250, "Pendulum-v1": -250} - args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold) + args.reward_threshold = default_reward_threshold.get( + args.task, + env.spec.reward_threshold if env.spec else None, + ) print("Observations shape:", args.state_shape) print("Num branches:", args.num_branches) @@ -96,7 +103,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: model=net, optim=optim, discount_factor=args.gamma, - action_space=env.action_space, + action_space=env.action_space, # type: ignore[arg-type] # TODO: should `BranchingDQNPolicy` support also `MultiDiscrete` action spaces? 
target_update_freq=args.target_update_freq, ) # collector @@ -145,7 +152,7 @@ def stop_fn(mean_rewards: float) -> bool: test_envs.seed(args.seed) test_collector.reset() collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) - print(collector_stats) + collector_stats.pprint_asdict() if __name__ == "__main__": diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index c9731ed3b..483aca9c6 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -2,7 +2,6 @@ import os import pickle import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -64,7 +63,7 @@ def get_args() -> argparse.Namespace: def test_c51(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape @@ -86,8 +85,8 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model net = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, softmax=True, diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index de598e1d2..6c588839f 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -1,7 +1,6 @@ import argparse import os import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -58,7 +57,7 @@ def get_args() -> argparse.Namespace: def test_dqn(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape @@ -81,14 +80,14 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: # Q_param = V_param = {"hidden_sizes": [128]} # model net = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, # dueling=(Q_param, V_param), ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: BasePolicy = DQNPolicy( + policy: DQNPolicy = DQNPolicy( model=net, optim=optim, discount_factor=args.gamma, diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 03ece9bde..8bca5c131 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -1,7 +1,6 @@ import argparse import os import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -51,7 +50,7 @@ def get_args() -> argparse.Namespace: def test_drqn(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index fa7a4ca4a..f1af574f7 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -1,7 +1,6 @@ import argparse import os import pprint -from typing import cast import gymnasium as 
gym import numpy as np @@ -64,7 +63,7 @@ def get_args() -> argparse.Namespace: def test_fqf(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: @@ -101,7 +100,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: optim = torch.optim.Adam(net.parameters(), lr=args.lr) fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim) fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), lr=args.fraction_lr) - policy: BasePolicy = FQFPolicy( + policy: FQFPolicy = FQFPolicy( model=net, optim=optim, fraction_model=fraction_net, diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 1f75ab516..87e7398b9 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -1,7 +1,6 @@ import argparse import os import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -64,7 +63,7 @@ def get_args() -> argparse.Namespace: def test_iqn(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) space_info = SpaceInfo.from_env(env) - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape if args.reward_threshold is None: diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py index 795086d1b..95db43c23 100644 --- a/test/discrete/test_pg.py +++ b/test/discrete/test_pg.py @@ -68,15 +68,15 @@ def test_pg(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model net = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, softmax=True, ).to(args.device) optim = torch.optim.Adam(net.parameters(), lr=args.lr) dist_fn = torch.distributions.Categorical - policy: BasePolicy = PGPolicy( + policy: PGPolicy = PGPolicy( actor=net, optim=optim, dist_fn=dist_fn, diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 63ef55122..132cbea5a 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -80,7 +80,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor: nn.Module critic: nn.Module if torch.cuda.is_available(): diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index b3e42d5e5..879717a75 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -1,7 +1,6 @@ import argparse import os import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -55,7 +54,7 @@ def get_args() -> argparse.Namespace: def test_qrdqn(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape @@ -81,8 +80,8 @@ def test_qrdqn(args: 
argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model net = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, softmax=False, diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index a38433ba4..c7035345e 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -2,7 +2,6 @@ import os import pickle import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -63,7 +62,7 @@ def get_args() -> argparse.Namespace: def test_rainbow(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape @@ -92,8 +91,8 @@ def noisy_linear(x: int, y: int) -> NoisyLinear: return NoisyLinear(x, y, args.noisy_std) net = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, softmax=True, diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py index 9d5e27be6..b2f466f3d 100644 --- a/test/discrete/test_sac.py +++ b/test/discrete/test_sac.py @@ -1,7 +1,6 @@ import argparse import os import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -54,7 +53,7 @@ def get_args() -> argparse.Namespace: def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape @@ -77,10 +76,10 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: # model obs_dim = space_info.observation_info.obs_dim action_dim = space_info.action_info.action_dim - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net, args.action_shape, softmax_output=False, device=args.device).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) - net_c1 = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net_c1 = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) critic1 = Critic(net_c1, last_size=action_dim, device=args.device).to(args.device) critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr) net_c2 = Net(obs_dim, hidden_sizes=args.hidden_sizes, device=args.device) diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index 2aa5b4e9f..9ca9c7055 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -1,7 +1,6 @@ import argparse import os import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -73,7 +72,7 @@ def get_args() -> argparse.Namespace: def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape @@ -98,8 +97,8 @@ def test_dqn_icm(args: 
argparse.Namespace = get_args()) -> None: # Q_param = V_param = {"hidden_sizes": [128]} # model net = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, # dueling=(Q_param, V_param), diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py index 3095f7cc9..ebf93cd5a 100644 --- a/test/modelbased/test_ppo_icm.py +++ b/test/modelbased/test_ppo_icm.py @@ -99,7 +99,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = Actor(net, args.action_shape, device=args.device).to(args.device) critic = Critic(net, device=args.device).to(args.device) actor_critic = ActorCritic(actor, critic) diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index 2994b11dd..72742b785 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -122,8 +122,9 @@ def stop_fn(mean_rewards: float) -> bool: # Let's watch its performance! policy.eval() test_envs.seed(args.seed) - result = test_collector.collect(n_episode=args.test_num, render=args.render) - print(f"Final reward: {result.rew_mean}, length: {result.len_mean}") + test_collector.reset() + stats = test_collector.collect(n_episode=args.test_num, render=args.render) + stats.pprint_asdict() elif env.spec.reward_threshold: assert result.best_reward >= env.spec.reward_threshold diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index 61450a932..e8411b2dd 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -1,7 +1,6 @@ import argparse import os import pickle -from typing import cast import gymnasium as gym import numpy as np @@ -61,7 +60,7 @@ def get_args() -> argparse.Namespace: def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: args = get_args() env = gym.make(args.task) - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape @@ -85,8 +84,8 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: test_envs.seed(args.seed) # model net = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, softmax=False, diff --git a/test/offline/gather_pendulum_data.py b/test/offline/gather_pendulum_data.py index e7a76221a..bc46ce4da 100644 --- a/test/offline/gather_pendulum_data.py +++ b/test/offline/gather_pendulum_data.py @@ -91,7 +91,7 @@ def gather_data() -> VectorReplayBuffer: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) actor = ActorProb( net, args.action_shape, @@ -100,8 +100,8 @@ def gather_data() -> VectorReplayBuffer: ).to(args.device) actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net_c = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, 
device=args.device, diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index 425f70b25..1839d863a 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -114,8 +114,8 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None: actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr) net_c = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py index 5ce5b406d..1e31b1feb 100644 --- a/test/offline/test_cql.py +++ b/test/offline/test_cql.py @@ -3,7 +3,6 @@ import os import pickle import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -81,7 +80,7 @@ def test_cql(args: argparse.Namespace = get_args()) -> None: else: buffer = gather_data() env = gym.make(args.task) - env.action_space = cast(gym.spaces.Box, env.action_space) + assert isinstance(env.action_space, gym.spaces.Box) space_info = SpaceInfo.from_env(env) @@ -110,8 +109,8 @@ def test_cql(args: argparse.Namespace = get_args()) -> None: # model # actor network net_a = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, ) @@ -126,8 +125,8 @@ def test_cql(args: argparse.Namespace = get_args()) -> None: # critic network net_c = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 81fc899d4..77790808b 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -2,7 +2,6 @@ import os import pickle import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -57,7 +56,7 @@ def get_args() -> argparse.Namespace: def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: # envs env = gym.make(args.task) - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape @@ -73,7 +72,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, args.hidden_sizes[0], device=args.device) + net = Net(state_shape=args.state_shape, action_shape=args.hidden_sizes[0], device=args.device) policy_net = Actor( net, args.action_shape, diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py index 3d8bb4c39..7323eac13 100644 --- a/test/offline/test_discrete_cql.py +++ b/test/offline/test_discrete_cql.py @@ -2,7 +2,6 @@ import os import pickle import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -54,7 +53,7 @@ def get_args() -> argparse.Namespace: def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: # envs env = gym.make(args.task) - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape @@ -71,8 +70,8 @@ def 
test_discrete_cql(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) # model net = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, softmax=False, @@ -80,7 +79,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: ) optim = torch.optim.Adam(net.parameters(), lr=args.lr) - policy: BasePolicy = DiscreteCQLPolicy( + policy: DiscreteCQLPolicy = DiscreteCQLPolicy( model=net, optim=optim, action_space=env.action_space, diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py index beca5467f..b3cb64616 100644 --- a/test/offline/test_discrete_crr.py +++ b/test/offline/test_discrete_crr.py @@ -2,7 +2,6 @@ import os import pickle import pprint -from typing import cast import gymnasium as gym import numpy as np @@ -52,7 +51,7 @@ def get_args() -> argparse.Namespace: def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: # envs env = gym.make(args.task) - env.action_space = cast(gym.spaces.Discrete, env.action_space) + assert isinstance(env.action_space, gym.spaces.Discrete) space_info = SpaceInfo.from_env(env) args.state_shape = space_info.observation_info.obs_shape args.action_shape = space_info.action_info.action_shape @@ -68,10 +67,10 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: torch.manual_seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, args.hidden_sizes[0], device=args.device) + net = Net(state_shape=args.state_shape, action_shape=args.hidden_sizes[0], device=args.device) actor = Actor( - net, - args.action_shape, + preprocess_net=net, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, softmax_output=False, diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index 68fab728f..256140c41 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -97,12 +97,17 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: train_envs.seed(args.seed) test_envs.seed(args.seed) # model - net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) - actor = ActorProb(net, args.action_shape, max_action=args.max_action, device=args.device).to( + net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) + actor = ActorProb( + preprocess_net=net, + action_shape=args.action_shape, + max_action=args.max_action, + device=args.device, + ).to( args.device, ) critic = Critic( - Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device), + Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device), device=args.device, ).to(args.device) actor_critic = ActorCritic(actor, critic) @@ -115,7 +120,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: # discriminator disc_net = Critic( Net( - args.state_shape, + state_shape=args.state_shape, action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, activation=torch.nn.Tanh, @@ -137,7 +142,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: loc, scale = loc_scale return Independent(Normal(loc, scale), 1) - policy: BasePolicy = GAILPolicy( + policy: GAILPolicy = GAILPolicy( actor=actor, critic=critic, optim=optim, diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py index 961af2ab3..18778563c 100644 --- a/test/offline/test_td3_bc.py +++ b/test/offline/test_td3_bc.py @@ -114,15 +114,15 @@ def 
test_td3_bc(args: argparse.Namespace = get_args()) -> None: # critic network net_c1 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, ) net_c2 = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, concat=True, device=args.device, diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index 990bf4694..7b3fb4dfc 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -8,7 +8,7 @@ from pettingzoo.butterfly import pistonball_v6 from torch.utils.tensorboard import SummaryWriter -from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data import Collector, InfoStats, VectorReplayBuffer from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager @@ -68,7 +68,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def get_env(args: argparse.Namespace = get_args()): +def get_env(args: argparse.Namespace = get_args()) -> PettingZooEnv: return PettingZooEnv(pistonball_v6.env(continuous=False, n_pistons=args.n_pistons)) @@ -76,7 +76,7 @@ def get_agents( args: argparse.Namespace = get_args(), agents: list[BasePolicy] | None = None, optims: list[torch.optim.Optimizer] | None = None, -) -> tuple[BasePolicy, list[torch.optim.Optimizer], list]: +) -> tuple[BasePolicy, list[torch.optim.Optimizer] | None, list]: env = get_env() observation_space = ( env.observation_space["observation"] @@ -91,8 +91,8 @@ def get_agents( for _ in range(args.n_pistons): # model net = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, ).to(args.device) @@ -116,7 +116,7 @@ def train_agent( args: argparse.Namespace = get_args(), agents: list[BasePolicy] | None = None, optims: list[torch.optim.Optimizer] | None = None, -) -> tuple[dict, BasePolicy]: +) -> tuple[InfoStats, BasePolicy]: train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)]) test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)]) # seed @@ -154,7 +154,7 @@ def train_fn(epoch: int, env_step: int) -> None: def test_fn(epoch: int, env_step: int | None) -> None: [agent.set_eps(args.eps_test) for agent in policy.policies.values()] - def reward_metric(rews): + def reward_metric(rews: np.ndarray) -> np.ndarray: return rews[:, 0] # trainer @@ -191,5 +191,4 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non [agent.set_eps(args.eps_test) for agent in policy.policies.values()] collector = Collector(policy, env, exploration_noise=True) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews[:, 0].mean()}, length: {lens.mean()}") + result.pprint_asdict() diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 14b5aacca..54d606602 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -12,6 +12,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data.stats import InfoStats from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import 
PettingZooEnv from tianshou.policy import BasePolicy, MultiAgentPolicyManager, PPOPolicy @@ -131,7 +132,7 @@ def get_args() -> argparse.Namespace: return parser.parse_known_args()[0] -def get_env(args: argparse.Namespace = get_args()): +def get_env(args: argparse.Namespace = get_args()) -> PettingZooEnv: return PettingZooEnv(pistonball_v6.env(continuous=True, n_pistons=args.n_pistons)) @@ -139,7 +140,7 @@ def get_agents( args: argparse.Namespace = get_args(), agents: list[BasePolicy] | None = None, optims: list[torch.optim.Optimizer] | None = None, -) -> tuple[BasePolicy, list[torch.optim.Optimizer], list]: +) -> tuple[BasePolicy, list[torch.optim.Optimizer] | None, list]: env = get_env() observation_space = ( env.observation_space["observation"] @@ -186,10 +187,10 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: return Independent(Normal(loc, scale), 1) agent: PPOPolicy = PPOPolicy( - actor, - critic, - optim, - dist, + actor=actor, + critic=critic, + optim=optim, + dist_fn=dist, discount_factor=args.gamma, max_grad_norm=args.max_grad_norm, eps_clip=args.eps_clip, @@ -208,7 +209,12 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: agents.append(agent) optims.append(optim) - policy = MultiAgentPolicyManager(agents, env, action_scaling=True, action_bound_method="clip") + policy = MultiAgentPolicyManager( + policies=agents, + env=env, + action_scaling=True, + action_bound_method="clip", + ) return policy, optims, env.agents @@ -216,7 +222,7 @@ def train_agent( args: argparse.Namespace = get_args(), agents: list[BasePolicy] | None = None, optims: list[torch.optim.Optimizer] | None = None, -) -> tuple[dict, BasePolicy]: +) -> tuple[InfoStats, BasePolicy]: train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)]) test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)]) # seed @@ -248,7 +254,7 @@ def save_best_fn(policy: BasePolicy) -> None: def stop_fn(mean_rewards: float) -> bool: return False - def reward_metric(rews): + def reward_metric(rews: np.ndarray) -> np.ndarray: return rews[:, 0] # trainer @@ -281,5 +287,4 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non policy.eval() collector = Collector(policy, env) collector_result = collector.collect(n_episode=1, render=args.render) - rews, lens = collector_result["rews"], collector_result["lens"] - print(f"Final reward: {rews[:, 0].mean()}, length: {lens.mean()}") + collector_result.pprint_asdict() diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index e1559b113..da580f358 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -10,6 +10,7 @@ from torch.utils.tensorboard import SummaryWriter from tianshou.data import Collector, VectorReplayBuffer +from tianshou.data.stats import InfoStats from tianshou.env import DummyVectorEnv from tianshou.env.pettingzoo_env import PettingZooEnv from tianshou.policy import BasePolicy, DQNPolicy, MultiAgentPolicyManager, RandomPolicy @@ -18,7 +19,7 @@ from tianshou.utils.net.common import Net -def get_env(render_mode: str | None = None): +def get_env(render_mode: str | None = None) -> PettingZooEnv: return PettingZooEnv(tictactoe_v3.env(render_mode=render_mode)) @@ -95,7 +96,7 @@ def get_agents( agent_learn: BasePolicy | None = None, agent_opponent: BasePolicy | None = None, optim: torch.optim.Optimizer | None = None, -) -> tuple[BasePolicy, torch.optim.Optimizer, list]: +) -> tuple[BasePolicy, torch.optim.Optimizer | None, list]: env = 
get_env() observation_space = ( env.observation_space["observation"] @@ -107,8 +108,8 @@ def get_agents( if agent_learn is None: # model net = Net( - args.state_shape, - args.action_shape, + state_shape=args.state_shape, + action_shape=args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device, ).to(args.device) @@ -145,7 +146,7 @@ def train_agent( agent_learn: BasePolicy | None = None, agent_opponent: BasePolicy | None = None, optim: torch.optim.Optimizer | None = None, -) -> tuple[dict, BasePolicy]: +) -> tuple[InfoStats, BasePolicy]: train_envs = DummyVectorEnv([get_env for _ in range(args.training_num)]) test_envs = DummyVectorEnv([get_env for _ in range(args.test_num)]) # seed @@ -193,7 +194,7 @@ def train_fn(epoch: int, env_step: int) -> None: def test_fn(epoch: int, env_step: int | None) -> None: policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test) - def reward_metric(rews): + def reward_metric(rews: np.ndarray) -> np.ndarray: return rews[:, args.agent_id - 1] # trainer @@ -230,5 +231,4 @@ def watch( policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test) collector = Collector(policy, env, exploration_noise=True) result = collector.collect(n_episode=1, render=args.render) - rews, lens = result["rews"], result["lens"] - print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}") + result.pprint_asdict() diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index ffa748882..5c7fa036e 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -3,6 +3,7 @@ from collections.abc import Collection, Iterable, Iterator, Sequence from copy import deepcopy from numbers import Number +from types import EllipsisType from typing import ( Any, Protocol, @@ -17,7 +18,8 @@ import numpy as np import torch -IndexType = slice | int | np.ndarray | list[int] +_SingleIndexType = slice | int | EllipsisType +IndexType = np.ndarray | _SingleIndexType | list[_SingleIndexType] | tuple[_SingleIndexType, ...] TBatch = TypeVar("TBatch", bound="BatchProtocol") arr_type = torch.Tensor | np.ndarray diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py index 48dde374d..43a4db2e9 100644 --- a/tianshou/highlevel/config.py +++ b/tianshou/highlevel/config.py @@ -37,12 +37,17 @@ class SamplingConfig(ToStringMixin): an explanation of epoch semantics. """ - batch_size: int = 64 + batch_size: int | None = 64 """for off-policy algorithms, this is the number of environment steps/transitions to sample from the buffer for a gradient update; for on-policy algorithms, its use is algorithm-specific. On-policy algorithms use the full buffer that was collected in the preceding collection step but they may use this parameter to perform the gradient update using mini-batches of this size (causing the gradient to be less accurate, a form of regularization). + + ``batch_size=None`` means that the full buffer is used for the gradient update. This doesn't + make much sense for off-policy algorithms and is not recommended then. For on-policy or offline algorithms, + this means that the full buffer is used for the gradient update (no mini-batching), and + may make sense in some cases. 
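To make the `batch_size=None` semantics described above concrete, a hypothetical usage sketch (it assumes the remaining `SamplingConfig` fields keep their defaults; nothing below is taken from the patch itself):

```python
from tianshou.highlevel.config import SamplingConfig

# Off-policy style: sample mini-batches of 64 transitions per gradient step.
minibatch_cfg = SamplingConfig(batch_size=64)

# On-policy/offline style: use the full collected buffer for each gradient
# update, i.e. no mini-batching.
full_buffer_cfg = SamplingConfig(batch_size=None)
```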
""" num_train_envs: int = -1 diff --git a/tianshou/highlevel/module/actor.py b/tianshou/highlevel/module/actor.py index faaaa68d0..867ece17a 100644 --- a/tianshou/highlevel/module/actor.py +++ b/tianshou/highlevel/module/actor.py @@ -147,14 +147,14 @@ def __init__(self, hidden_sizes: Sequence[int], activation: ModuleType = nn.ReLU def create_module(self, envs: Environments, device: TDevice) -> BaseActor: net_a = Net( - envs.get_observation_shape(), + state_shape=envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, activation=self.activation, device=device, ) return continuous.Actor( - net_a, - envs.get_action_shape(), + preprocess_net=net_a, + action_shape=envs.get_action_shape(), hidden_sizes=(), device=device, ).to(device) @@ -182,14 +182,14 @@ def __init__( def create_module(self, envs: Environments, device: TDevice) -> BaseActor: net_a = Net( - envs.get_observation_shape(), + state_shape=envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, activation=self.activation, device=device, ) actor = continuous.ActorProb( - net_a, - envs.get_action_shape(), + preprocess_net=net_a, + action_shape=envs.get_action_shape(), unbounded=self.unbounded, device=device, conditioned_sigma=self.conditioned_sigma, @@ -216,7 +216,7 @@ def __init__( def create_module(self, envs: Environments, device: TDevice) -> BaseActor: net_a = Net( - envs.get_observation_shape(), + state_shape=envs.get_observation_shape(), hidden_sizes=self.hidden_sizes, activation=self.activation, device=device, diff --git a/tianshou/highlevel/module/critic.py b/tianshou/highlevel/module/critic.py index 6d3a7b107..f1984e4d7 100644 --- a/tianshou/highlevel/module/critic.py +++ b/tianshou/highlevel/module/critic.py @@ -120,7 +120,7 @@ def create_module( ) -> nn.Module: action_shape = envs.get_action_shape() if use_action else 0 net_c = Net( - envs.get_observation_shape(), + state_shape=envs.get_observation_shape(), action_shape=action_shape, hidden_sizes=self.hidden_sizes, concat=use_action, @@ -146,7 +146,7 @@ def create_module( ) -> nn.Module: action_shape = envs.get_action_shape() if use_action else 0 net_c = Net( - envs.get_observation_shape(), + state_shape=envs.get_observation_shape(), action_shape=action_shape, hidden_sizes=self.hidden_sizes, concat=use_action, @@ -275,7 +275,7 @@ def linear_layer(x: int, y: int) -> EnsembleLinear: action_shape = envs.get_action_shape() if use_action else 0 net_c = Net( - envs.get_observation_shape(), + state_shape=envs.get_observation_shape(), action_shape=action_shape, hidden_sizes=self.hidden_sizes, concat=use_action, diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index dabe24e75..eceee100f 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Sequence -from typing import Any, TypeAlias, TypeVar, no_type_check +from typing import Any, Generic, TypeAlias, TypeVar, cast, no_type_check import numpy as np import torch @@ -140,20 +140,23 @@ def forward(self, obs: np.ndarray | torch.Tensor) -> torch.Tensor: return self.model(obs) -class NetBase(nn.Module, ABC): +TRecurrentState = TypeVar("TRecurrentState", bound=Any) + + +class NetBase(nn.Module, Generic[TRecurrentState], ABC): """Interface for NNs used in policies.""" @abstractmethod def forward( self, obs: np.ndarray | torch.Tensor, - state: Any = None, - **kwargs: Any, - ) -> tuple[torch.Tensor, Any]: + state: TRecurrentState | None = None, + info: dict[str, Any] | None = None, + ) -> 
tuple[torch.Tensor, TRecurrentState | None]: pass -class Net(NetBase): +class Net(NetBase[Any]): """Wrapper of MLP to support more specific DRL usage. For advanced usage (how to customize the network), please refer to @@ -259,13 +262,13 @@ def forward( self, obs: np.ndarray | torch.Tensor, state: Any = None, - **kwargs: Any, + info: dict[str, Any] | None = None, ) -> tuple[torch.Tensor, Any]: """Mapping: obs -> flatten (inside MLP)-> logits. :param obs: :param state: unused and returned as is - :param kwargs: unused + :param info: unused """ logits = self.model(obs) batch_size = logits.shape[0] @@ -284,7 +287,7 @@ def forward( return logits, state -class Recurrent(NetBase): +class Recurrent(NetBase[RecurrentStateBatch]): """Simple Recurrent network based on LSTM. For advanced usage (how to customize the network), please refer to @@ -313,9 +316,9 @@ def __init__( def forward( self, obs: np.ndarray | torch.Tensor, - state: RecurrentStateBatch | dict[str, torch.Tensor] | None = None, - **kwargs: Any, - ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + state: RecurrentStateBatch | None = None, + info: dict[str, Any] | None = None, + ) -> tuple[torch.Tensor, RecurrentStateBatch]: """Mapping: obs -> flatten -> logits. In the evaluation mode, `obs` should be with shape ``[bsz, dim]``; in the @@ -324,7 +327,7 @@ def forward( :param obs: :param state: either None or a dict with keys 'hidden' and 'cell' - :param kwargs: unused + :param info: unused :return: predicted action, next state as dict with keys 'hidden' and 'cell' """ # Note: the original type of state is Batch but it might also be a dict @@ -357,10 +360,16 @@ def forward( ) obs = self.fc2(obs[:, -1]) # please ensure the first dim is batch size: [bsz, len, ...] - return obs, { - "hidden": hidden.transpose(0, 1).detach(), - "cell": cell.transpose(0, 1).detach(), - } + rnn_state_batch = cast( + RecurrentStateBatch, + Batch( + { + "hidden": hidden.transpose(0, 1).detach(), + "cell": cell.transpose(0, 1).detach(), + }, + ), + ) + return obs, rnn_state_batch class ActorCritic(nn.Module): @@ -439,7 +448,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class BranchingNet(NetBase): +# TODO: fix docstring +class BranchingNet(NetBase[Any]): """Branching dual Q network. Network for the BranchingDQNPolicy, it uses a common network module, a value module @@ -539,7 +549,7 @@ def forward( self, obs: np.ndarray | torch.Tensor, state: Any = None, - **kwargs: Any, + info: dict[str, Any] | None = None, ) -> tuple[torch.Tensor, Any]: """Mapping: obs -> model -> logits.""" common_out = self.common(obs)
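As a closing illustration of the now-generic `NetBase` interface, here is a minimal sketch of a stateless subclass following the new `forward(obs, state, info)` contract; the class name and layer sizes are invented for illustration and are not part of the patch:

```python
from typing import Any

import numpy as np
import torch
from torch import nn

from tianshou.utils.net.common import NetBase


class FlatQNet(NetBase[Any]):
    """Hypothetical stateless Q-network following the new forward() contract."""

    def __init__(self, obs_dim: int, act_dim: int) -> None:
        super().__init__()
        self.model = nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU(), nn.Linear(64, act_dim))

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        state: Any = None,
        info: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, Any]:
        obs = torch.as_tensor(obs, dtype=torch.float32)
        # A stateless net simply returns the (unused) state it was given.
        return self.model(obs), state
```

A recurrent module would instead subclass `NetBase[RecurrentStateBatch]` and return its hidden/cell tensors wrapped in a `Batch`, as the patched `Recurrent.forward` above does.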