
Commit af63276
Add type annotations to funcs
dantp-ai committed Jan 14, 2024
1 parent bed23bd commit af63276
Showing 4 changed files with 25 additions and 25 deletions.
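
The change is mechanical across all four test files: helper functions gain parameter and return annotations, and BasePolicy is imported where a policy parameter needs a type. A minimal before/after sketch of the pattern (illustrative only, not part of the diff):

import argparse

# before this commit: no annotations
def get_args():
    return argparse.ArgumentParser().parse_known_args()[0]

# after this commit: same behavior, annotated signature (redefines the helper above)
def get_args() -> argparse.Namespace:
    return argparse.ArgumentParser().parse_known_args()[0]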
14 changes: 7 additions & 7 deletions test/offline/test_discrete_bcq.py
@@ -10,7 +10,7 @@
 
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
-from tianshou.policy import DiscreteBCQPolicy
+from tianshou.policy import BasePolicy, DiscreteBCQPolicy
 from tianshou.trainer import OfflineTrainer
 from tianshou.utils import TensorboardLogger
 from tianshou.utils.net.common import ActorCritic, Net
@@ -22,7 +22,7 @@
 from test.offline.gather_cartpole_data import expert_file_name, gather_data
 
 
-def get_args():
+def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
     parser.add_argument("--task", type=str, default="CartPole-v0")
     parser.add_argument("--reward-threshold", type=float, default=None)
@@ -52,7 +52,7 @@ def get_args():
     return parser.parse_known_args()[0]
 
 
-def test_discrete_bcq(args=get_args()):
+def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
     # envs
     env = gym.make(args.task)
     args.state_shape = env.observation_space.shape or env.observation_space.n
@@ -111,13 +111,13 @@ def test_discrete_bcq(args=get_args()):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer, save_interval=args.save_interval)
 
-    def save_best_fn(policy):
+    def save_best_fn(policy: BasePolicy) -> None:
         torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
 
-    def stop_fn(mean_rewards):
+    def stop_fn(mean_rewards: float) -> bool:
         return mean_rewards >= args.reward_threshold
 
-    def save_checkpoint_fn(epoch, env_step, gradient_step):
+    def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str:
         # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
         ckpt_path = os.path.join(log_path, "checkpoint.pth")
         # Example: saving by epoch num
@@ -170,7 +170,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step):
     print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}")
 
 
-def test_discrete_bcq_resume(args=get_args()):
+def test_discrete_bcq_resume(args: argparse.Namespace = get_args()) -> None:
     args.resume = True
     test_discrete_bcq(args)
 
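The new -> str return annotation on save_checkpoint_fn indicates the callback returns a string, presumably the checkpoint path given the ckpt_path line that follows; the rest of the body is truncated in the diff above. A minimal sketch of a callback with that signature (an assumption for illustration, not the repository's actual code):

import os

import torch

log_path = "log/discrete_bcq"  # illustrative; the tests build this from args

def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str:
    # Persist enough state to resume training later and return the file path.
    ckpt_path = os.path.join(log_path, "checkpoint.pth")
    torch.save({"epoch": epoch, "env_step": env_step, "gradient_step": gradient_step}, ckpt_path)
    return ckpt_path
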
10 changes: 5 additions & 5 deletions test/offline/test_discrete_cql.py
@@ -10,7 +10,7 @@
 
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
-from tianshou.policy import DiscreteCQLPolicy
+from tianshou.policy import BasePolicy, DiscreteCQLPolicy
 from tianshou.trainer import OfflineTrainer
 from tianshou.utils import TensorboardLogger
 from tianshou.utils.net.common import Net
@@ -21,7 +21,7 @@
 from test.offline.gather_cartpole_data import expert_file_name, gather_data
 
 
-def get_args():
+def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
     parser.add_argument("--task", type=str, default="CartPole-v0")
     parser.add_argument("--reward-threshold", type=float, default=None)
@@ -49,7 +49,7 @@ def get_args():
     return parser.parse_known_args()[0]
 
 
-def test_discrete_cql(args=get_args()):
+def test_discrete_cql(args: argparse.Namespace = get_args()) -> None:
     # envs
     env = gym.make(args.task)
     args.state_shape = env.observation_space.shape or env.observation_space.n
@@ -100,10 +100,10 @@ def test_discrete_cql(args=get_args()):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer)
 
-    def save_best_fn(policy):
+    def save_best_fn(policy: BasePolicy) -> None:
         torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
 
-    def stop_fn(mean_rewards):
+    def stop_fn(mean_rewards: float) -> bool:
         return mean_rewards >= args.reward_threshold
 
     result = OfflineTrainer(
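In each file, BasePolicy is imported as the parameter type of save_best_fn: every concrete policy used in these tests (DiscreteBCQPolicy, DiscreteCQLPolicy, DiscreteCRRPolicy, GAILPolicy) subclasses BasePolicy, which is itself a torch.nn.Module and therefore provides state_dict(). A standalone sketch of that contract (illustrative; log_path is a placeholder):

import os

import torch

from tianshou.policy import BasePolicy

log_path = "log"  # placeholder directory

def save_best_fn(policy: BasePolicy) -> None:
    # Any policy subclass works here: state_dict() comes from nn.Module via BasePolicy.
    torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
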
10 changes: 5 additions & 5 deletions test/offline/test_discrete_crr.py
@@ -10,7 +10,7 @@
 
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
-from tianshou.policy import DiscreteCRRPolicy
+from tianshou.policy import BasePolicy, DiscreteCRRPolicy
 from tianshou.trainer import OfflineTrainer
 from tianshou.utils import TensorboardLogger
 from tianshou.utils.net.common import ActorCritic, Net
@@ -22,7 +22,7 @@
 from test.offline.gather_cartpole_data import expert_file_name, gather_data
 
 
-def get_args():
+def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
     parser.add_argument("--task", type=str, default="CartPole-v0")
     parser.add_argument("--reward-threshold", type=float, default=None)
@@ -47,7 +47,7 @@ def get_args():
     return parser.parse_known_args()[0]
 
 
-def test_discrete_crr(args=get_args()):
+def test_discrete_crr(args: argparse.Namespace = get_args()) -> None:
     # envs
     env = gym.make(args.task)
     args.state_shape = env.observation_space.shape or env.observation_space.n
@@ -103,10 +103,10 @@ def test_discrete_crr(args=get_args()):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer)
 
-    def save_best_fn(policy):
+    def save_best_fn(policy: BasePolicy) -> None:
         torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
 
-    def stop_fn(mean_rewards):
+    def stop_fn(mean_rewards: float) -> bool:
         return mean_rewards >= args.reward_threshold
 
     result = OfflineTrainer(
16 changes: 8 additions & 8 deletions test/offline/test_gail.py
@@ -6,12 +6,12 @@
 import gymnasium as gym
 import numpy as np
 import torch
-from torch.distributions import Independent, Normal
+from torch.distributions import Distribution, Independent, Normal
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
-from tianshou.policy import GAILPolicy
+from tianshou.policy import BasePolicy, GAILPolicy
 from tianshou.trainer import OnpolicyTrainer
 from tianshou.utils import TensorboardLogger
 from tianshou.utils.net.common import ActorCritic, Net
@@ -23,7 +23,7 @@
 from test.offline.gather_pendulum_data import expert_file_name, gather_data
 
 
-def get_args():
+def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
     parser.add_argument("--task", type=str, default="Pendulum-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
@@ -65,7 +65,7 @@ def get_args():
     return parser.parse_known_args()[0]
 
 
-def test_gail(args=get_args()):
+def test_gail(args: argparse.Namespace = get_args()) -> None:
     if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
         if args.load_buffer_name.endswith(".hdf5"):
             buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
@@ -128,7 +128,7 @@ def test_gail(args=get_args()):
 
     # replace DiagGuassian with Independent(Normal) which is equivalent
    # pass *logits to be consistent with policy.forward
-    def dist(*logits):
+    def dist(*logits: torch.Tensor) -> Distribution:
         return Independent(Normal(*logits), 1)
 
     policy = GAILPolicy(
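
As the in-code comment says, Independent(Normal(loc, scale), 1) is equivalent to a diagonal Gaussian: wrapping Normal with Independent(..., 1) turns the last dimension into an event dimension, so log_prob sums over the action dimensions rather than returning per-dimension values. A standalone check of the typed helper (illustrative, not part of the diff):

import torch
from torch.distributions import Distribution, Independent, Normal

def dist(*logits: torch.Tensor) -> Distribution:
    return Independent(Normal(*logits), 1)

loc, scale = torch.zeros(4, 1), torch.ones(4, 1)  # batch of 4, 1-dimensional action
d = dist(loc, scale)
action = d.sample()                # shape (4, 1)
print(d.log_prob(action).shape)    # torch.Size([4]): one log-prob per batch element
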
@@ -165,13 +165,13 @@ def dist(*logits):
     writer = SummaryWriter(log_path)
     logger = TensorboardLogger(writer, save_interval=args.save_interval)
 
-    def save_best_fn(policy):
+    def save_best_fn(policy: BasePolicy) -> None:
         torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
 
-    def stop_fn(mean_rewards):
+    def stop_fn(mean_rewards: float) -> bool:
         return mean_rewards >= args.reward_threshold
 
-    def save_checkpoint_fn(epoch, env_step, gradient_step):
+    def save_checkpoint_fn(epoch: int, env_step: int, gradient_step: int) -> str:
         # see also: https://pytorch.org/tutorials/beginner/saving_loading_models.html
         ckpt_path = os.path.join(log_path, "checkpoint.pth")
         # Example: saving by epoch num
