Refactoring/mypy issues test (#1017)
Improves typing in examples and tests, working towards making mypy pass there.

Introduces the SpaceInfo utility.
dantp-ai authored Feb 6, 2024
1 parent 4756ee8 commit eb0215c
Showing 108 changed files with 1,775 additions and 1,272 deletions.
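The SpaceInfo utility mentioned in the commit message does not appear in the files excerpted below. Purely as a hedged sketch of the kind of helper this suggests (every name, field, and constructor here is an illustrative assumption, not the actual Tianshou API), the idea is to extract shapes and sizes from a Gymnasium space once, with proper types, instead of reading untyped attributes in every example script:

from dataclasses import dataclass

import gymnasium as gym


@dataclass
class SpaceInfo:
    # Shape and (for discrete spaces) the number of actions, extracted once.
    shape: tuple[int, ...]
    n: int | None = None

    @classmethod
    def from_space(cls, space: gym.Space) -> "SpaceInfo":
        if isinstance(space, gym.spaces.Discrete):
            return cls(shape=(1,), n=int(space.n))
        if isinstance(space, gym.spaces.Box):
            return cls(shape=tuple(space.shape))
        raise NotImplementedError(f"unsupported space type: {type(space)}")

A script could then write, for example, SpaceInfo.from_space(env.action_space).n instead of casting env.action_space.n by hand; again, this is only an illustration of the concept, not the API introduced by the commit.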
44 changes: 21 additions & 23 deletions examples/atari/atari_c51.py
@@ -8,15 +8,15 @@
import torch
from atari_network import C51
from atari_wrapper import make_atari_env
from torch.utils.tensorboard import SummaryWriter

from examples.common import logger_factory
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.policy import C51Policy
from tianshou.policy.base import BasePolicy
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils import TensorboardLogger, WandbLogger


def get_args():
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument("--seed", type=int, default=0)
@@ -66,7 +66,7 @@ def get_args():
return parser.parse_args()


def test_c51(args=get_args()):
def test_c51(args: argparse.Namespace = get_args()) -> None:
env, train_envs, test_envs = make_atari_env(
args.task,
args.seed,
@@ -87,7 +87,7 @@ def test_c51(args=get_args()):
net = C51(*args.state_shape, args.action_shape, args.num_atoms, args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
# define policy
policy = C51Policy(
policy: C51Policy = C51Policy(
model=net,
optim=optim,
discount_factor=args.gamma,
@@ -123,21 +123,19 @@ def test_c51(args=get_args()):

# logger
if args.logger == "wandb":
logger = WandbLogger(
save_interval=1,
name=log_name.replace(os.path.sep, "__"),
run_id=args.resume_id,
config=args,
project=args.wandb_project,
)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
if args.logger == "tensorboard":
logger = TensorboardLogger(writer)
else: # wandb
logger.load(writer)

def save_best_fn(policy):
logger_factory.logger_type = "wandb"
logger_factory.wandb_project = args.wandb_project
else:
logger_factory.logger_type = "tensorboard"

logger = logger_factory.create_logger(
log_dir=log_path,
experiment_name=log_name,
run_id=args.resume_id,
config_dict=vars(args),
)

def save_best_fn(policy: BasePolicy) -> None:
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

def stop_fn(mean_rewards: float) -> bool:
@@ -147,7 +145,7 @@ def stop_fn(mean_rewards: float) -> bool:
return mean_rewards >= 20
return False

def train_fn(epoch, env_step):
def train_fn(epoch: int, env_step: int) -> None:
# nature DQN setting, linear decay in the first 1M steps
if env_step <= 1e6:
eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final)
@@ -157,11 +155,11 @@ def train_fn(epoch, env_step):
if env_step % 1000 == 0:
logger.write("train/env_step", env_step, {"train/eps": eps})

def test_fn(epoch, env_step):
def test_fn(epoch: int, env_step: int | None) -> None:
policy.set_eps(args.eps_test)

# watch agent's performance
def watch():
def watch() -> None:
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
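In the diffs above and below, the inline WandbLogger/TensorboardLogger setup is replaced by a shared logger_factory imported from examples.common, which is not itself part of this excerpt. As a hedged illustration only (the class name, attributes, and signature are inferred from the call sites in the new code, not taken from the real module), such a factory could simply wrap the wiring the removed lines used to do inline:

import os

from torch.utils.tensorboard import SummaryWriter

from tianshou.utils import TensorboardLogger, WandbLogger


class LoggerFactory:
    def __init__(self) -> None:
        self.logger_type = "tensorboard"
        self.wandb_project: str | None = None

    def create_logger(
        self,
        log_dir: str,
        experiment_name: str,
        run_id: str | None,
        config_dict: dict,
    ) -> TensorboardLogger | WandbLogger:
        writer = SummaryWriter(log_dir)
        writer.add_text("args", str(config_dict))
        if self.logger_type == "wandb":
            # Same constructor arguments the removed per-script code passed.
            logger = WandbLogger(
                save_interval=1,
                name=experiment_name.replace(os.path.sep, "__"),
                run_id=run_id,
                config=config_dict,
                project=self.wandb_project,
            )
            logger.load(writer)
            return logger
        return TensorboardLogger(writer)


logger_factory = LoggerFactory()

With something like this in place, each example script only sets logger_type (and wandb_project for wandb) before calling create_logger, which keeps the per-script diff small.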
44 changes: 21 additions & 23 deletions examples/atari/atari_dqn.py
@@ -8,17 +8,17 @@
import torch
from atari_network import DQN
from atari_wrapper import make_atari_env
from torch.utils.tensorboard import SummaryWriter

from examples.common import logger_factory
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.policy import DQNPolicy
from tianshou.policy.base import BasePolicy
from tianshou.policy.modelbased.icm import ICMPolicy
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.discrete import IntrinsicCuriosityModule


def get_args():
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument("--seed", type=int, default=0)
@@ -83,7 +83,7 @@ def get_args():
return parser.parse_args()


def test_dqn(args=get_args()):
def test_dqn(args: argparse.Namespace = get_args()) -> None:
env, train_envs, test_envs = make_atari_env(
args.task,
args.seed,
@@ -104,7 +104,7 @@ def test_dqn(args=get_args()):
net = DQN(*args.state_shape, args.action_shape, args.device).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
# define policy
policy = DQNPolicy(
policy: DQNPolicy = DQNPolicy(
model=net,
optim=optim,
action_space=env.action_space,
@@ -158,21 +158,19 @@ def test_dqn(args=get_args()):

# logger
if args.logger == "wandb":
logger = WandbLogger(
save_interval=1,
name=log_name.replace(os.path.sep, "__"),
run_id=args.resume_id,
config=args,
project=args.wandb_project,
)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
if args.logger == "tensorboard":
logger = TensorboardLogger(writer)
else: # wandb
logger.load(writer)
logger_factory.logger_type = "wandb"
logger_factory.wandb_project = args.wandb_project
else:
logger_factory.logger_type = "tensorboard"

logger = logger_factory.create_logger(
log_dir=log_path,
experiment_name=log_name,
run_id=args.resume_id,
config_dict=vars(args),
)

def save_best_fn(policy):
def save_best_fn(policy: BasePolicy) -> None:
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

def stop_fn(mean_rewards: float) -> bool:
@@ -182,7 +180,7 @@ def stop_fn(mean_rewards: float) -> bool:
return mean_rewards >= 20
return False

def train_fn(epoch, env_step):
def train_fn(epoch: int, env_step: int) -> None:
# nature DQN setting, linear decay in the first 1M steps
if env_step <= 1e6:
eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final)
@@ -192,17 +190,17 @@ def train_fn(epoch, env_step):
if env_step % 1000 == 0:
logger.write("train/env_step", env_step, {"train/eps": eps})

def test_fn(epoch, env_step):
def test_fn(epoch: int, env_step: int | None) -> None:
policy.set_eps(args.eps_test)

def save_checkpoint_fn(epoch, env_step, gradient_step):
def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str:
# see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
ckpt_path = os.path.join(log_path, f"checkpoint_{epoch}.pth")
torch.save({"model": policy.state_dict()}, ckpt_path)
return ckpt_path

# watch agent's performance
def watch():
def watch() -> None:
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
6 changes: 4 additions & 2 deletions examples/atari/atari_dqn_hl.py
@@ -1,5 +1,6 @@
#!/usr/bin/env python3

import functools
import os

from examples.atari.atari_network import (
@@ -48,7 +49,7 @@ def main(
icm_lr_scale: float = 0.0,
icm_reward_scale: float = 0.01,
icm_forward_loss_weight: float = 0.2,
):
) -> None:
log_name = os.path.join(task, "dqn", str(experiment_config.seed), datetime_tag())

sampling_config = SamplingConfig(
@@ -102,4 +103,5 @@ def main(


if __name__ == "__main__":
logging.run_cli(main)
run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig())
logging.run_cli(run_with_default_config)
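The atari_dqn_hl.py change above stops passing main directly to logging.run_cli and instead pre-binds a default ExperimentConfig via functools.partial. run_cli itself is not shown in this diff, so the snippet below only illustrates the partial-binding pattern in isolation, with a plain function call standing in for the CLI layer:

import functools
from dataclasses import dataclass


@dataclass
class ExperimentConfig:
    seed: int = 0


def main(experiment_config: ExperimentConfig, task: str = "PongNoFrameskip-v4") -> None:
    # The real high-level main() builds and runs an experiment; printing is
    # enough here to show which arguments remain unbound after partial().
    print(task, experiment_config.seed)


# Binding a constructed default config gives the wrapped callable a concrete
# value for experiment_config, so a CLI wrapper only needs to expose the
# remaining parameters (here just `task`).
run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig())
run_with_default_config(task="BreakoutNoFrameskip-v4")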
44 changes: 21 additions & 23 deletions examples/atari/atari_fqf.py
@@ -8,16 +8,16 @@
import torch
from atari_network import DQN
from atari_wrapper import make_atari_env
from torch.utils.tensorboard import SummaryWriter

from examples.common import logger_factory
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.policy import FQFPolicy
from tianshou.policy.base import BasePolicy
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.discrete import FractionProposalNetwork, FullQuantileFunction


def get_args():
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument("--seed", type=int, default=3128)
@@ -69,7 +69,7 @@ def get_args():
return parser.parse_args()


def test_fqf(args=get_args()):
def test_fqf(args: argparse.Namespace = get_args()) -> None:
env, train_envs, test_envs = make_atari_env(
args.task,
args.seed,
@@ -99,7 +99,7 @@ def test_fqf(args=get_args()):
fraction_net = FractionProposalNetwork(args.num_fractions, net.input_dim)
fraction_optim = torch.optim.RMSprop(fraction_net.parameters(), lr=args.fraction_lr)
# define policy
policy = FQFPolicy(
policy: FQFPolicy = FQFPolicy(
model=net,
optim=optim,
fraction_model=fraction_net,
@@ -136,21 +136,19 @@ def test_fqf(args=get_args()):

# logger
if args.logger == "wandb":
logger = WandbLogger(
save_interval=1,
name=log_name.replace(os.path.sep, "__"),
run_id=args.resume_id,
config=args,
project=args.wandb_project,
)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
if args.logger == "tensorboard":
logger = TensorboardLogger(writer)
else: # wandb
logger.load(writer)

def save_best_fn(policy):
logger_factory.logger_type = "wandb"
logger_factory.wandb_project = args.wandb_project
else:
logger_factory.logger_type = "tensorboard"

logger = logger_factory.create_logger(
log_dir=log_path,
experiment_name=log_name,
run_id=args.resume_id,
config_dict=vars(args),
)

def save_best_fn(policy: BasePolicy) -> None:
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

def stop_fn(mean_rewards: float) -> bool:
@@ -160,7 +158,7 @@ def stop_fn(mean_rewards: float) -> bool:
return mean_rewards >= 20
return False

def train_fn(epoch, env_step):
def train_fn(epoch: int, env_step: int) -> None:
# nature DQN setting, linear decay in the first 1M steps
if env_step <= 1e6:
eps = args.eps_train - env_step / 1e6 * (args.eps_train - args.eps_train_final)
@@ -170,11 +168,11 @@ def train_fn(epoch, env_step):
if env_step % 1000 == 0:
logger.write("train/env_step", env_step, {"train/eps": eps})

def test_fn(epoch, env_step):
def test_fn(epoch: int, env_step: int | None) -> None:
policy.set_eps(args.eps_test)

# watch agent's performance
def watch():
def watch() -> None:
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)