call evaluation_hooks in Evaluator #122

Merged
76 changes: 54 additions & 22 deletions pfrl/experiments/evaluation_hooks.py
@@ -12,26 +12,39 @@
class EvaluationHook(object, metaclass=ABCMeta):
"""Hook function that will be called after evaluation.

This class is for clarifying the interface required for EvaluationHook functions.
You don't need to inherit this class to define your own hooks. Any callable that
accepts (env, agent, evaluator, step, eval_score) as arguments can be used as an
evaluation hook.

Note that:
- ``step`` is the current training step, not the number of evaluations so far.
- ``train_agent_async`` DOES NOT support EvaluationHook.
Every evaluation hook function must inherit this class.

Attributes:
support_train_agent (bool):
Set to ``True`` if the hook can be used in
pfrl.experiments.train_agent.train_agent_with_evaluation.
support_train_agent_batch (bool):
Set to ``True`` if the hook can be used in
pfrl.experiments.train_agent_batch.train_agent_batch_with_evaluation.
support_train_agent_async (bool):
Set to ``True`` if the hook can be used in
pfrl.experiments.train_agent_async.train_agent_async.
"""

support_train_agent = False
support_train_agent_batch = False
support_train_agent_async = False

@abstractmethod
def __call__(self, env, agent, evaluator, step, eval_score):
def __call__(self, env, agent, evaluator, step, eval_stats, agent_stats, env_stats):
"""Call the hook.

Args:
env: Environment.
agent: Agent.
evaluator: Evaluator.
step: Current timestep.
eval_score: Evaluation score at t=`step`.
step: Current timestep. (Not the number of evaluations so far)
eval_stats (dict): Last evaluation stats from
pfrl.experiments.evaluator.eval_performance().
agent_stats (List of pairs): Last agent stats from
agent.get_statistics().
env_stats: Last environment stats from
env.get_statistics().
"""
raise NotImplementedError
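
For readers implementing their own hook against this new interface, a minimal sketch might look like the following. The class name and the printed message are hypothetical and purely illustrative; the __call__ signature, the support_* attributes, and the "mean" key of eval_stats are the ones defined in this diff.

from pfrl.experiments.evaluation_hooks import EvaluationHook


class MeanScoreLoggerHook(EvaluationHook):
    """Hypothetical hook that prints the mean evaluation return."""

    # Declare which training loops this hook may be passed to.
    support_train_agent = True
    support_train_agent_batch = True
    support_train_agent_async = False

    def __call__(
        self, env, agent, evaluator, step, eval_stats, agent_stats, env_stats
    ):
        # eval_stats comes from pfrl.experiments.evaluator.eval_performance()
        # and contains aggregate statistics such as "mean".
        print("step={}: mean eval return={}".format(step, eval_stats["mean"]))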

@@ -42,26 +55,45 @@ class OptunaPrunerHook(EvaluationHook):
Optuna regards trials that raise `optuna.TrialPruned` as unpromising and prunes
them at an early stage of training.

Note that:
- ``step`` is the current training step, not the number of evaluations so far.
- ``train_agent_async`` DOES NOT support EvaluationHook.
- This hook stops a trial by raising an exception, but re-raising errors across
processes is not straightforward.
Note that this hook does not support
pfrl.experiments.train_agent_async.train_agent_async.
Optuna detects the pruning signal via the `optuna.TrialPruned` exception, but the async
training mode doesn't re-raise exceptions from subprocesses. (See: pfrl.utils.async_.py)

Args:
trial (optuna.Trial): Current trial.
Raises:
optuna.TrialPruned: Raised when the trial should be pruned immediately.
Note that you don't need to handle this exception yourself, since Optuna
catches `optuna.TrialPruned` and stops the trial properly.
"""

support_train_agent = True
support_train_agent_batch = True
support_train_agent_async = False # unsupported

def __init__(self, trial):
if not _optuna_available:
raise RuntimeError("OptunaPrunerHook requires optuna installed.")
self.trial = trial

def __call__(self, env, agent, evaluator, step, eval_score):
self.trial.report(eval_score, step)
def __call__(self, env, agent, evaluator, step, eval_stats, agent_stats, env_stats):
"""Call the hook.

Args:
env: Environment.
agent: Agent.
evaluator: Evaluator.
step: Current timestep. (Not the number of evaluations so far)
eval_stats (dict): Last evaluation stats from
pfrl.experiments.evaluator.eval_performance().
agent_stats (List of pairs): Last agent stats from
agent.get_statistics().
env_stats: Last environment stats from
env.get_statistics().

Raises:
optuna.TrialPruned: Raised when the trial should be pruned immediately.
Note that you don't need to handle this exception yourself, since Optuna
catches `optuna.TrialPruned` and stops the trial properly.
"""
score = eval_stats["mean"]
self.trial.report(score, step)
if self.trial.should_prune():
raise optuna.TrialPruned()
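
As a usage sketch, the hook is constructed with an Optuna trial and passed through the new evaluation_hooks argument. Only OptunaPrunerHook, the evaluation_hooks argument, and the "eval_score" history key are taken from this diff; the agent/env factory is a hypothetical placeholder, the keyword values are arbitrary, and unpacking the return value assumes train_agent_with_evaluation returns (agent, eval_stats_history) as in recent pfrl versions.

import optuna

import pfrl
from pfrl.experiments.evaluation_hooks import OptunaPrunerHook


def objective(trial):
    # Hypothetical user-supplied factory; replace with your own setup code.
    agent, env, eval_env = make_agent_and_envs(trial)

    _, eval_stats_history = pfrl.experiments.train_agent_with_evaluation(
        agent=agent,
        env=env,
        eval_env=eval_env,
        steps=100000,
        eval_n_steps=None,       # evaluate by episodes rather than by steps
        eval_n_episodes=10,
        eval_interval=10000,
        outdir="results",
        # Reports eval_stats["mean"] to Optuna after every evaluation and
        # raises optuna.TrialPruned when the trial should stop early.
        evaluation_hooks=[OptunaPrunerHook(trial=trial)],
    )
    # train_agent() appends agent statistics plus an "eval_score" entry per evaluation.
    return eval_stats_history[-1]["eval_score"]


study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=20)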
32 changes: 32 additions & 0 deletions pfrl/experiments/evaluator.py
@@ -395,6 +395,9 @@ class Evaluator(object):
outdir (str): Path to a directory to save things.
max_episode_len (int): Maximum length of episodes used in evaluations.
step_offset (int): Offset of steps used to schedule evaluations.
evaluation_hooks (Sequence): Sequence of
pfrl.experiments.evaluation_hooks.EvaluationHook objects. They are
called after each evaluation.
save_best_so_far_agent (bool): If set to True, after each evaluation,
if the score (= mean of returns in evaluation episodes) exceeds
the best-so-far score, the current agent is saved.
@@ -411,6 +414,7 @@ def __init__(
outdir,
max_episode_len=None,
step_offset=0,
evaluation_hooks=(),
save_best_so_far_agent=True,
logger=None,
use_tensorboard=False,
@@ -432,6 +436,7 @@ def __init__(
self.max_episode_len = max_episode_len
self.step_offset = step_offset
self.prev_eval_t = self.step_offset - self.step_offset % self.eval_interval
self.evaluation_hooks = evaluation_hooks
self.save_best_so_far_agent = save_best_so_far_agent
self.logger = logger or logging.getLogger(__name__)
self.env_get_stats = getattr(self.env, "get_statistics", lambda: [])
@@ -480,6 +485,17 @@ def evaluate_and_update_max_score(self, t, episodes):
if self.use_tensorboard:
record_tb_stats(self.tb_writer, agent_stats, eval_stats, env_stats, t)

for hook in self.evaluation_hooks:
hook(
env=self.env,
agent=self.agent,
evaluator=self,
step=t,
eval_stats=eval_stats,
agent_stats=agent_stats,
env_stats=env_stats,
)

if mean > self.max_score:
self.logger.info("The best score is updated %s -> %s", self.max_score, mean)
self.max_score = mean
@@ -505,6 +521,9 @@ class AsyncEvaluator(object):
outdir (str): Path to a directory to save things.
max_episode_len (int): Maximum length of episodes used in evaluations.
step_offset (int): Offset of steps used to schedule evaluations.
evaluation_hooks (Sequence): Sequence of
pfrl.experiments.evaluation_hooks.EvaluationHook objects. They are
called after each evaluation.
save_best_so_far_agent (bool): If set to True, after each evaluation,
if the score (= mean return of evaluation episodes) exceeds
the best-so-far score, the current agent is saved.
@@ -518,6 +537,7 @@ def __init__(
outdir,
max_episode_len=None,
step_offset=0,
evaluation_hooks=(),
save_best_so_far_agent=True,
logger=None,
):
@@ -533,6 +553,7 @@ def __init__(
self.outdir = outdir
self.max_episode_len = max_episode_len
self.step_offset = step_offset
self.evaluation_hooks = evaluation_hooks
self.save_best_so_far_agent = save_best_so_far_agent
self.logger = logger or logging.getLogger(__name__)

@@ -595,6 +616,17 @@ def evaluate_and_update_max_score(self, t, episodes, env, agent):
if self.record_tb_stats_queue is not None:
self.record_tb_stats_queue.put([agent_stats, eval_stats, env_stats, t])

for hook in self.evaluation_hooks:
hook(
env=env,
agent=agent,
evaluator=self,
step=t,
eval_stats=eval_stats,
agent_stats=agent_stats,
env_stats=env_stats,
)

with self._max_score.get_lock():
if mean > self._max_score.value:
self.logger.info(
17 changes: 10 additions & 7 deletions pfrl/experiments/train_agent.py
@@ -32,7 +32,6 @@ def train_agent(
evaluator=None,
successful_score=None,
step_hooks=(),
evaluation_hooks=(),
logger=None,
):

@@ -84,8 +83,6 @@
eval_stats = dict(stats)
eval_stats["eval_score"] = eval_score
eval_stats_history.append(eval_stats)
for hook in evaluation_hooks:
hook(env, agent, evaluator, t, eval_score)
if (
successful_score is not None
and evaluator.max_score >= successful_score
@@ -153,9 +150,9 @@ def train_agent_with_evaluation(
step_hooks (Sequence): Sequence of callable objects that accepts
(env, agent, step) as arguments. They are called every step.
See pfrl.experiments.hooks.
evaluation_hooks (Sequence): Sequence of callable objects that accepts
(env, agent, evaluator, step, eval_score) as arguments. They are
called every evaluation. See pfrl.experiments.evaluation_hooks.
evaluation_hooks (Sequence): Sequence of
pfrl.experiments.evaluation_hooks.EvaluationHook objects. They are
called after each evaluation.
save_best_so_far_agent (bool): If set to True, after each evaluation
phase, if the score (= mean return of evaluation episodes) exceeds
the best-so-far score, the current agent is saved.
@@ -168,6 +165,12 @@

logger = logger or logging.getLogger(__name__)

for hook in evaluation_hooks:
if not hook.support_train_agent:
raise ValueError(
"{} does not support train_agent_with_evaluation().".format(hook)
)

os.makedirs(outdir, exist_ok=True)

if eval_env is None:
@@ -185,6 +188,7 @@
max_episode_len=eval_max_episode_len,
env=eval_env,
step_offset=step_offset,
evaluation_hooks=evaluation_hooks,
save_best_so_far_agent=save_best_so_far_agent,
use_tensorboard=use_tensorboard,
logger=logger,
@@ -201,7 +205,6 @@
evaluator=evaluator,
successful_score=successful_score,
step_hooks=step_hooks,
evaluation_hooks=evaluation_hooks,
logger=logger,
)

9 changes: 9 additions & 0 deletions pfrl/experiments/train_agent_async.py
@@ -162,6 +162,7 @@ def train_agent_async(
agent=None,
make_agent=None,
global_step_hooks=[],
evaluation_hooks=(),
save_best_so_far_agent=True,
use_tensorboard=False,
logger=None,
@@ -193,6 +194,9 @@
global_step_hooks (list): List of callable objects that accepts
(env, agent, step) as arguments. They are called every global
step. See pfrl.experiments.hooks.
evaluation_hooks (Sequence): Sequence of
pfrl.experiments.evaluation_hooks.EvaluationHook objects. They are
called after each evaluation.
save_best_so_far_agent (bool): If set to True, after each evaluation,
if the score (= mean return of evaluation episodes) exceeds
the best-so-far score, the current agent is saved.
@@ -213,6 +217,10 @@

logger = logger or logging.getLogger(__name__)

for hook in evaluation_hooks:
if not hook.support_train_agent_async:
raise ValueError("{} does not support train_agent_async().".format(hook))

# Prevent numpy from using multiple threads
os.environ["OMP_NUM_THREADS"] = "1"

@@ -252,6 +260,7 @@
outdir=outdir,
max_episode_len=max_episode_len,
step_offset=step_offset,
evaluation_hooks=evaluation_hooks,
save_best_so_far_agent=save_best_so_far_agent,
logger=logger,
)
20 changes: 10 additions & 10 deletions pfrl/experiments/train_agent_batch.py
@@ -19,7 +19,6 @@ def train_agent_batch(
evaluator=None,
successful_score=None,
step_hooks=(),
evaluation_hooks=(),
return_window_size=100,
logger=None,
):
@@ -41,9 +40,6 @@
step_hooks (Sequence): Sequence of callable objects that accepts
(env, agent, step) as arguments. They are called every step.
See pfrl.experiments.hooks.
evaluation_hooks (Sequence): Sequence of callable objects that accepts
(env, agent, evaluator, step, eval_score) as arguments. They are
called every evaluation. See pfrl.experiments.evaluation_hooks.
logger (logging.Logger): Logger used in this function.
Returns:
List of evaluation episode stats dict.
@@ -130,8 +126,6 @@
eval_stats = dict(agent.get_statistics())
eval_stats["eval_score"] = eval_score
eval_stats_history.append(eval_stats)
for hook in evaluation_hooks:
hook(env, agent, evaluator, t, eval_score)
if (
successful_score is not None
and evaluator.max_score >= successful_score
@@ -206,9 +200,9 @@ def train_agent_batch_with_evaluation(
step_hooks (Sequence): Sequence of callable objects that accepts
(env, agent, step) as arguments. They are called every step.
See pfrl.experiments.hooks.
evaluation_hooks (Sequence): Sequence of callable objects that accepts
(env, agent, evaluator, step, eval_score) as arguments. They are
called every evaluation. See pfrl.experiments.evaluation_hooks.
evaluation_hooks (Sequence): Sequence of
pfrl.experiments.evaluation_hooks.EvaluationHook objects. They are
called after each evaluation.
save_best_so_far_agent (bool): If set to True, after each evaluation,
if the score (= mean return of evaluation episodes) exceeds
the best-so-far score, the current agent is saved.
@@ -221,6 +215,12 @@

logger = logger or logging.getLogger(__name__)

for hook in evaluation_hooks:
if not hook.support_train_agent_batch:
raise ValueError(
"{} does not support train_agent_batch_with_evaluation().".format(hook)
)

os.makedirs(outdir, exist_ok=True)

if eval_env is None:
Expand All @@ -238,6 +238,7 @@ def train_agent_batch_with_evaluation(
max_episode_len=eval_max_episode_len,
env=eval_env,
step_offset=step_offset,
evaluation_hooks=evaluation_hooks,
save_best_so_far_agent=save_best_so_far_agent,
use_tensorboard=use_tensorboard,
logger=logger,
@@ -256,7 +257,6 @@
return_window_size=return_window_size,
log_interval=log_interval,
step_hooks=step_hooks,
evaluation_hooks=evaluation_hooks,
logger=logger,
)

21 changes: 14 additions & 7 deletions tests/experiments_tests/test_evaluation_hooks.py
@@ -16,11 +16,15 @@ def test_dont_prune(self):
agent = Mock()
evaluator = Mock()
step = 42
eval_score = 3.14
eval_stats = {"mean": 3.14}
agent_stats = [("dummy", 2.7)]
env_stats = []

optuna_pruner_hook(env, agent, evaluator, step, eval_score)
optuna_pruner_hook(
env, agent, evaluator, step, eval_stats, agent_stats, env_stats
)

trial.report.assert_called_once_with(eval_score, step)
trial.report.assert_called_once_with(eval_stats["mean"], step)

def test_should_prune(self):
trial = Mock()
@@ -31,10 +35,13 @@ def test_should_prune(self):
agent = Mock()
evaluator = Mock()
step = 42
eval_score = 3.14
eval_stats = {"mean": 3.14}
agent_stats = [("dummy", 2.7)]
env_stats = []

with self.assertRaises(optuna.TrialPruned):
optuna_pruner_hook(env, agent, evaluator, step, eval_score)
optuna_pruner_hook(
env, agent, evaluator, step, eval_stats, agent_stats, env_stats
)

trial.report.assert_called()
trial.report.assert_called_once_with(eval_score, step)
trial.report.assert_called_once_with(eval_stats["mean"], step)