Fix mypy issues in tests and examples (thu-ml#1077)
Closes thu-ml#952 

- `SamplingConfig` supports `batch_size=None` (see the sketch after this list). thu-ml#1077
- Tests and examples are now covered by `mypy`. thu-ml#1077
- `NetBase` is used more consistently and is now generic, enabling stricter typing. thu-ml#1077
- `utils.net.common.Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. thu-ml#1077
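
As a rough illustration of the first item (not part of this commit's diff), here is a minimal sketch of passing `batch_size=None` to the high-level `SamplingConfig`. The surrounding field names and values are illustrative assumptions, and the exact semantics of `None` should be checked against `tianshou.highlevel.config`.

```python
# Minimal sketch (assumption: SamplingConfig lives in tianshou.highlevel.config
# and accepts these fields) showing the new batch_size=None option from thu-ml#1077.
from tianshou.highlevel.config import SamplingConfig

sampling_config = SamplingConfig(
    num_epochs=10,          # illustrative values, not taken from this commit
    step_per_epoch=10_000,  # illustrative
    batch_size=None,        # new: None is now accepted instead of a concrete int
)
```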

---------

Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
2 people authored and ZhengLi1314 committed Apr 15, 2024
1 parent a70b532 commit 3dc5426
Showing 106 changed files with 1,265 additions and 903 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -41,7 +41,7 @@ repos:
pass_filenames: false
- id: mypy
name: mypy
entry: poetry run mypy tianshou
entry: poetry run mypy tianshou examples test
# filenames should not be passed as they would collide with the config in pyproject.toml
pass_filenames: false
files: '^tianshou(/[^/]*)*/[^/]*\.py$'
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -7,6 +7,7 @@
- `Collector`s can now be closed, and their reset is more granular. #1063
- Trainers can control whether collectors should be reset prior to training. #1063
- Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063
- `SamplingConfig` supports `batch_size=None`. #1077

### Internal Improvements
- `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063
@@ -20,6 +21,8 @@ instead of just `nn.Module`. #1032
- Simplified `PGPolicy` forward by unifying the `dist_fn` interface (see associated breaking change). #1032
- Use `.mode` of distribution instead of relying on knowledge of the distribution type. #1032
- Exception no longer raised on `len` of empty `Batch`. #1084
- Tests and examples are now covered by `mypy`. #1077
- `NetBase` is used more consistently and is now generic, enabling stricter typing. #1077

### Breaking Changes

@@ -30,10 +33,10 @@ explicitly or pass `reset_before_collect=True`. #1063
- Fixed `iter(Batch(...))` which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063
- Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both
continuous and discrete cases. #1032
- `utils.net.common.Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. #1077

### Tests
- Fixed env seeding in test_sac_with_il.py so that the test doesn't fail randomly. #1081


Started after v1.0.0

15 changes: 7 additions & 8 deletions docs/02_notebooks/L0_overview.ipynb
@@ -17,14 +17,14 @@
},
{
"cell_type": "code",
"outputs": [],
"source": [
"# !pip install tianshou gym"
],
"execution_count": null,
"metadata": {
"collapsed": false
},
"execution_count": 0
"outputs": [],
"source": [
"# !pip install tianshou gym"
]
},
{
"cell_type": "markdown",
@@ -71,7 +71,7 @@
"\n",
"from tianshou.data import Collector, VectorReplayBuffer\n",
"from tianshou.env import DummyVectorEnv\n",
"from tianshou.policy import BasePolicy, PPOPolicy\n",
"from tianshou.policy import PPOPolicy\n",
"from tianshou.trainer import OnpolicyTrainer\n",
"from tianshou.utils.net.common import ActorCritic, Net\n",
"from tianshou.utils.net.discrete import Actor, Critic\n",
@@ -106,8 +106,7 @@
"\n",
"# PPO policy\n",
"dist = torch.distributions.Categorical\n",
"policy: BasePolicy\n",
"policy = PPOPolicy(\n",
"policy: PPOPolicy = PPOPolicy(\n",
" actor=actor,\n",
" critic=critic,\n",
" optim=optim,\n",
5 changes: 2 additions & 3 deletions docs/02_notebooks/L5_Collector.ipynb
@@ -60,7 +60,7 @@
"\n",
"from tianshou.data import Collector, VectorReplayBuffer\n",
"from tianshou.env import DummyVectorEnv\n",
"from tianshou.policy import BasePolicy, PGPolicy\n",
"from tianshou.policy import PGPolicy\n",
"from tianshou.utils.net.common import Net\n",
"from tianshou.utils.net.discrete import Actor"
]
@@ -87,8 +87,7 @@
"actor = Actor(net, env.action_space.n)\n",
"optim = torch.optim.Adam(actor.parameters(), lr=0.0003)\n",
"\n",
"policy: BasePolicy\n",
"policy = PGPolicy(\n",
"policy: PGPolicy = PGPolicy(\n",
" actor=actor,\n",
" optim=optim,\n",
" dist_fn=torch.distributions.Categorical,\n",
5 changes: 2 additions & 3 deletions docs/02_notebooks/L6_Trainer.ipynb
@@ -75,7 +75,7 @@
"\n",
"from tianshou.data import Collector, VectorReplayBuffer\n",
"from tianshou.env import DummyVectorEnv\n",
"from tianshou.policy import BasePolicy, PGPolicy\n",
"from tianshou.policy import PGPolicy\n",
"from tianshou.trainer import OnpolicyTrainer\n",
"from tianshou.utils.net.common import Net\n",
"from tianshou.utils.net.discrete import Actor"
@@ -110,9 +110,8 @@
"actor = Actor(net, env.action_space.n)\n",
"optim = torch.optim.Adam(actor.parameters(), lr=0.001)\n",
"\n",
"policy: BasePolicy\n",
"# We choose to use REINFORCE algorithm, also known as Policy Gradient\n",
"policy = PGPolicy(\n",
"policy: PGPolicy = PGPolicy(\n",
" actor=actor,\n",
" optim=optim,\n",
" dist_fn=torch.distributions.Categorical,\n",
5 changes: 2 additions & 3 deletions docs/02_notebooks/L7_Experiment.ipynb
@@ -73,7 +73,7 @@
"\n",
"from tianshou.data import Collector, VectorReplayBuffer\n",
"from tianshou.env import DummyVectorEnv\n",
"from tianshou.policy import BasePolicy, PPOPolicy\n",
"from tianshou.policy import PPOPolicy\n",
"from tianshou.trainer import OnpolicyTrainer\n",
"from tianshou.utils.net.common import ActorCritic, Net\n",
"from tianshou.utils.net.discrete import Actor, Critic\n",
@@ -164,8 +164,7 @@
"outputs": [],
"source": [
"dist = torch.distributions.Categorical\n",
"policy: BasePolicy\n",
"policy = PPOPolicy(\n",
"policy: PPOPolicy = PPOPolicy(\n",
" actor=actor,\n",
" critic=critic,\n",
" optim=optim,\n",
6 changes: 3 additions & 3 deletions examples/atari/atari_c51.py
@@ -9,8 +9,8 @@
from atari_network import C51
from atari_wrapper import make_atari_env

from examples.common import logger_factory
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.policy import C51Policy
from tianshou.policy.base import BasePolicy
from tianshou.trainer import OffpolicyTrainer
@@ -122,6 +122,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
log_path = os.path.join(args.logdir, log_name)

# logger
logger_factory = LoggerFactoryDefault()
if args.logger == "wandb":
logger_factory.logger_type = "wandb"
logger_factory.wandb_project = args.wandb_project
@@ -182,8 +183,7 @@ def watch() -> None:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
rew = result.returns_stat.mean
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
result.pprint_asdict()

if args.watch:
watch()
9 changes: 5 additions & 4 deletions examples/atari/atari_dqn.py
@@ -9,8 +9,8 @@
from atari_network import DQN
from atari_wrapper import make_atari_env

from examples.common import logger_factory
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.policy import DQNPolicy
from tianshou.policy.base import BasePolicy
from tianshou.policy.modelbased.icm import ICMPolicy
@@ -104,7 +104,8 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
net = DQN(*args.state_shape, args.action_shape, args.device).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
# define policy
policy: DQNPolicy = DQNPolicy(
policy: DQNPolicy | ICMPolicy
policy = DQNPolicy(
model=net,
optim=optim,
action_space=env.action_space,
@@ -157,6 +158,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
log_path = os.path.join(args.logdir, log_name)

# logger
logger_factory = LoggerFactoryDefault()
if args.logger == "wandb":
logger_factory.logger_type = "wandb"
logger_factory.wandb_project = args.wandb_project
@@ -223,8 +225,7 @@ def watch() -> None:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
rew = result.returns_stat.mean
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
result.pprint_asdict()

if args.watch:
watch()
6 changes: 3 additions & 3 deletions examples/atari/atari_fqf.py
@@ -9,8 +9,8 @@
from atari_network import DQN
from atari_wrapper import make_atari_env

from examples.common import logger_factory
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.policy import FQFPolicy
from tianshou.policy.base import BasePolicy
from tianshou.trainer import OffpolicyTrainer
@@ -135,6 +135,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
log_path = os.path.join(args.logdir, log_name)

# logger
logger_factory = LoggerFactoryDefault()
if args.logger == "wandb":
logger_factory.logger_type = "wandb"
logger_factory.wandb_project = args.wandb_project
@@ -195,8 +196,7 @@ def watch() -> None:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
rew = result.returns_stat.mean
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
result.pprint_asdict()

if args.watch:
watch()
6 changes: 3 additions & 3 deletions examples/atari/atari_iqn.py
@@ -9,8 +9,8 @@
from atari_network import DQN
from atari_wrapper import make_atari_env

from examples.common import logger_factory
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.policy import IQNPolicy
from tianshou.policy.base import BasePolicy
from tianshou.trainer import OffpolicyTrainer
@@ -132,6 +132,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
log_path = os.path.join(args.logdir, log_name)

# logger
logger_factory = LoggerFactoryDefault()
if args.logger == "wandb":
logger_factory.logger_type = "wandb"
logger_factory.wandb_project = args.wandb_project
@@ -192,8 +193,7 @@ def watch() -> None:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
rew = result.returns_stat.mean
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
result.pprint_asdict()

if args.watch:
watch()