
Commit

Merge branch 'thu-ml:master' into refactoring/mypy-issues-test
dantp-ai authored Jan 11, 2024
2 parents fce2482 + 94665ac commit 0b018a1
Showing 2 changed files with 26 additions and 20 deletions.
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default.

38 changes: 22 additions & 16 deletions tianshou/trainer/base.py
@@ -4,7 +4,6 @@
 from collections import defaultdict, deque
 from collections.abc import Callable
 from dataclasses import asdict
-from typing import Any
 
 import numpy as np
 import tqdm
@@ -312,7 +311,14 @@ def __next__(self) -> EpochStats:
             while t.n < t.total and not self.stop_fn_flag:
                 train_stat: CollectStatsBase
                 if self.train_collector is not None:
-                    pbar_data_dict, train_stat, self.stop_fn_flag = self.train_step()
+                    train_stat, self.stop_fn_flag = self.train_step()
+                    pbar_data_dict = {
+                        "env_step": str(self.env_step),
+                        "rew": f"{self.last_rew:.2f}",
+                        "len": str(int(self.last_len)),
+                        "n/ep": str(train_stat.n_collected_episodes),
+                        "n/st": str(train_stat.n_collected_steps),
+                    }
                     t.update(train_stat.n_collected_steps)
                     if self.stop_fn_flag:
                         t.set_postfix(**pbar_data_dict)
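Note: every value in the new pbar_data_dict is pre-converted to a string. tqdm's set_postfix applies its own formatting to numeric values, so passing strings keeps the postfix rendered exactly as built. A minimal sketch with plain tqdm (no Tianshou types; the field names are just examples):

    import tqdm

    # Strings pass through set_postfix unchanged; raw numbers would be
    # reformatted by tqdm's own number formatter.
    with tqdm.tqdm(total=5) as t:
        for step in range(5):
            t.update(1)
            t.set_postfix(env_step=str(step), rew=f"{0.5 * step:.2f}")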
@@ -322,13 +328,12 @@ def __next__(self) -> EpochStats:
                     assert self.buffer, "No train_collector or buffer specified"
                     train_stat = CollectStatsBase(
                         n_collected_episodes=len(self.buffer),
-                        n_collected_steps=int(self._gradient_step),
                     )
                     t.update()
 
                 update_stat = self.policy_update_fn(train_stat)
                 pbar_data_dict = set_numerical_fields_to_precision(pbar_data_dict)
-                pbar_data_dict["gradient_step"] = self._gradient_step
+                pbar_data_dict["gradient_step"] = str(self._gradient_step)
 
                 t.set_postfix(**pbar_data_dict)
 
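Aside: set_numerical_fields_to_precision comes from Tianshou's trainer utilities and is not part of this diff. A hypothetical stand-in showing the behavior the call site appears to rely on (the real implementation may differ):

    def set_numerical_fields_to_precision(data: dict, precision: int = 2) -> dict:
        # Sketch only, not the actual Tianshou helper: render float values
        # with fixed precision and pass everything else through unchanged.
        return {
            key: f"{value:.{precision}f}" if isinstance(value, float) else value
            for key, value in data.items()
        }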
@@ -413,11 +418,19 @@ def test_step(self) -> tuple[CollectStats, bool]:
 
         return test_stat, stop_fn_flag
 
-    def train_step(self) -> tuple[dict[str, Any], CollectStats, bool]:
-        """Perform one training step."""
+    def train_step(self) -> tuple[CollectStats, bool]:
+        """Perform one training step.
+
+        If test_in_train and stop_fn are set, will compute the stop_fn on the mean return of the training data.
+        Then, if the stop_fn is True there, will collect test data and also compute the stop_fn of the mean return
+        on it.
+        Finally, if the latter is also True, will set should_stop_training to True.
+
+        :return: A tuple of the training stats and a boolean indicating whether to stop training.
+        """
         assert self.episode_per_test is not None
         assert self.train_collector is not None
-        stop_fn_flag = False
+        should_stop_training = False
         if self.train_fn:
             self.train_fn(self.epoch, self.env_step)
         result = self.train_collector.collect(
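The docstring above describes a two-stage early-stopping check. A schematic sketch of that control flow, where Stats, collect_train, collect_test, and stop_fn are illustrative stand-ins rather than Tianshou's actual API:

    from collections.abc import Callable
    from dataclasses import dataclass

    @dataclass
    class Stats:
        mean_return: float

    def train_step_sketch(
        collect_train: Callable[[], Stats],
        collect_test: Callable[[], Stats],
        stop_fn: Callable[[float], bool] | None,
        test_in_train: bool,
    ) -> tuple[Stats, bool]:
        train_stats = collect_train()
        should_stop_training = False
        if test_in_train and stop_fn is not None and stop_fn(train_stats.mean_return):
            # Training return cleared the threshold; confirm on test episodes.
            if stop_fn(collect_test().mean_return):
                should_stop_training = True
        return train_stats, should_stop_training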
@@ -439,13 +452,6 @@ def train_step(self) -> tuple[dict[str, Any], CollectStats, bool]:
 
         self.logger.log_train_data(asdict(result), self.env_step)
 
-        data = {
-            "env_step": str(self.env_step),
-            "rew": f"{self.last_rew:.2f}",
-            "len": str(int(self.last_len)),
-            "n/ep": str(result.n_collected_episodes),
-            "n/st": str(result.n_collected_steps),
-        }
         if (
             result.n_collected_episodes > 0
             and self.test_in_train
@@ -464,12 +470,12 @@ def train_step(self) -> tuple[dict[str, Any], CollectStats, bool]:
             )
             assert test_result.returns_stat is not None  # for mypy
             if self.stop_fn(test_result.returns_stat.mean):
-                stop_fn_flag = True
+                should_stop_training = True
                 self.best_reward = test_result.returns_stat.mean
                 self.best_reward_std = test_result.returns_stat.std
         else:
             self.policy.train()
-        return data, result, stop_fn_flag
+        return result, should_stop_training
 
     # TODO: move moving average computation and logging into its own logger
     # TODO: maybe think about a command line logger instead of always printing data dict
