High-Level API #970

Merged · 100 commits · Nov 8, 2023

Changes from 94 commits (100 commits total)
42fc181
Add dev dependencies jsonargparse and docstring_parser
opcode81 Sep 20, 2023
a54aade
Addition of dataclasses based config for scripts, major refactoring
Jul 26, 2023
16ed5fd
Initial high-level interfaces, demonstrated in mujoco_ppo_hl
opcode81 Sep 19, 2023
25c6bbd
Ignore D106: Missing docstring in public nested class
opcode81 Sep 20, 2023
2a1cc6b
Enable ruff setting ignore-init-module-imports
opcode81 Sep 20, 2023
316eb3c
Add SAC high-level interface
opcode81 Sep 20, 2023
997b520
Refactoring, dropping package config
opcode81 Sep 20, 2023
adc3240
Remove LoggerConfig
opcode81 Sep 20, 2023
d26b8cb
Use experiment-specific config in mujoco_sac_hl, adding auto-alpha
opcode81 Sep 20, 2023
8ec4200
Move RLSamplingConfig to separate module config, fixing cyclic import
opcode81 Sep 20, 2023
3fd60f9
Unify PPO configuration objects, use experiment-specific configuration
opcode81 Sep 20, 2023
4d53d34
Ignore Ruff rule RET505, because it sacrifices visual discernability
opcode81 Sep 25, 2023
37dc07e
Add high-level experiment builder interface
opcode81 Sep 21, 2023
367778d
Improve high-level policy parametrisation
opcode81 Sep 25, 2023
6a73938
WandbLogger: Use less restrictive type annotation for config
opcode81 Sep 26, 2023
e993425
Add high-level API support for TD3
opcode81 Sep 26, 2023
38cf982
Disable Ruff rule D205 (blank-line-after-summary)
opcode81 Sep 26, 2023
d4e604b
Move parameter transformation directly into parameter objects,
opcode81 Sep 26, 2023
5bcf514
Add alternative functional interface for environment creation
Sep 27, 2023
78b6dd1
Adapt class naming scheme
opcode81 Sep 27, 2023
acd89fa
Remove parameter transformers from config object state,
opcode81 Sep 27, 2023
cd79cf8
Add A2C high-level API
opcode81 Sep 28, 2023
e0e7349
Add base class BaseActor with method get_preprocess_net for high-leve…
opcode81 Sep 28, 2023
6b6d9ea
Add support for discrete PPO
opcode81 Sep 28, 2023
2671580
Add DDPG high-level API and MuJoCo example
opcode81 Oct 3, 2023
de70147
Add string module from sensAI
opcode81 Oct 3, 2023
ce26e25
Handle ruff complaints in string module
opcode81 Oct 3, 2023
58bd20f
Add logging module
opcode81 Oct 3, 2023
9f0a410
Log full experiment configuration, adding string representations to r…
opcode81 Oct 3, 2023
8f67c2e
Disable numba DEBUG logs
opcode81 Oct 3, 2023
358978c
Add ToStringMixin to further high-level parameter classes
opcode81 Oct 5, 2023
1cba589
Add DQN support in high-level API
opcode81 Oct 5, 2023
b54fcd1
Change high-level DQN interface to expect an actor instead of a critic,
opcode81 Oct 5, 2023
50ac385
Add some basic tests for high-level experiment builder API
opcode81 Oct 5, 2023
d269063
Remove 'RL' prefix from class names
opcode81 Oct 6, 2023
837ff13
Reorder ExperimentBuilder args (EnvFactory first)
opcode81 Oct 6, 2023
a8dc75f
ExperimentBuilder: Allow experiment_config and sampling_config to be …
opcode81 Oct 6, 2023
1243894
Add DistributionFunctionFactory subclasses for discrete/continuous de…
opcode81 Oct 6, 2023
7ed6c1d
Remove obsolete module highlevel.utils
opcode81 Oct 9, 2023
e671632
Make mypy ignore copied util modules string & logging
opcode81 Oct 9, 2023
a161a9c
Improve type annotations, fix type issues and add checks
opcode81 Oct 9, 2023
22dfc4e
Fix type annotations of dist_fn
opcode81 Oct 9, 2023
4e93c12
Remove obsolete configuration files
opcode81 Oct 9, 2023
6bb3abb
Support PG/Reinforce in high-level API
opcode81 Oct 10, 2023
1bb52a6
Simplify critic/agent with optimizer generation
opcode81 Oct 10, 2023
a8ea680
Fix ruff type comparison complaint
opcode81 Oct 10, 2023
73a6d15
Log Environments
opcode81 Oct 10, 2023
383a4a6
Support NPG in high-level API and add example mujoco_npg_hl
opcode81 Oct 10, 2023
7af836b
Support TRPO in high-level API and add example mujoco_trpo_hl
opcode81 Oct 10, 2023
17ef4dd
Support REDQ in high-level API
opcode81 Oct 10, 2023
305b30a
Simplify parameter transformers by applying ParamTransformerChangeValue
opcode81 Oct 10, 2023
799beb7
Support discrete SAC in high-level API
opcode81 Oct 10, 2023
c7d0b6b
Simplify agent factories by making better use of base classes
opcode81 Oct 10, 2023
213e08a
Add method get_output_dim to BaseActor
opcode81 Oct 11, 2023
a8a367c
Support IQN in high-level API
opcode81 Oct 11, 2023
686fd55
Extend tests, fixing some default behaviour
opcode81 Oct 11, 2023
ee3813b
Ignore temp scripts and temp folder
opcode81 Oct 12, 2023
f6d4977
Reify policy persistence, introducing Wold representation
opcode81 Oct 11, 2023
3691ed2
Support obs_rms persistence for MuJoCo by adding a general mechanism
opcode81 Oct 12, 2023
ba80329
Add FileLoggerContext
opcode81 Oct 12, 2023
76e8702
Improve persistence handling
opcode81 Oct 12, 2023
023b33c
Make mypy happy
opcode81 Oct 13, 2023
3bba192
Add experiment result
opcode81 Oct 13, 2023
fc695a5
Use logging to report trainer epoch status
opcode81 Oct 13, 2023
90eaacb
PolicyWrapperFactory: Remove unnecessary input type variable
opcode81 Oct 16, 2023
97e21b5
Remove obsolete mixin, improve class names
opcode81 Oct 16, 2023
4b270ea
Add documentation, improve structure of 'module' package
opcode81 Oct 16, 2023
8304878
Add generalised DQN network representation, adding specialised class …
opcode81 Oct 16, 2023
ae48506
DQNExperimentBuilder: Use IntermediateModuleFactory instead of ActorF…
opcode81 Oct 16, 2023
d84e936
Apply centrally defined callbacks
opcode81 Oct 16, 2023
e63d8d4
Use ToStringMixin in dataclasses to detect recurring objects in large…
opcode81 Oct 17, 2023
ff451f8
Add documentation to parameters, improve factorisation
opcode81 Oct 17, 2023
c7d0cbb
Experiment: Fix return type annotation, remove unused type arguments
opcode81 Oct 18, 2023
80b1b1f
World.restore_path: Add value check
opcode81 Oct 18, 2023
ed06ab7
Handle obs_norm setting in MuJoCo envs
opcode81 Oct 18, 2023
41bd463
Allow to configure activation function in default networks
opcode81 Oct 18, 2023
9c5ee55
Merge remote-tracking branch 'origin/master' into feat/high-level-api
opcode81 Oct 18, 2023
cc6f016
miniblock: Fix type annotation of linear_layer
opcode81 Oct 18, 2023
193be9a
Add 'stdout' to spelling dictionary
opcode81 Oct 18, 2023
bbfad01
Improve docstrings
opcode81 Oct 18, 2023
89ce40e
Docs: Add tianshou.highlevel to docs build via auto-generated .rst files
opcode81 Oct 18, 2023
6cbee18
Change interface of EnvFactory to ensure that configuration
opcode81 Oct 18, 2023
7437131
Fix tianshou.highlevel depending on jsonargparse
opcode81 Oct 19, 2023
f7f2064
ExperimentConfig: Improve docstrings, remove obsolete item 'render'
opcode81 Oct 20, 2023
b5a8915
Revert to simplified environment factory, removing unnecessary config…
opcode81 Oct 24, 2023
58466eb
Keep all ExperimentBuilder tests in one place
opcode81 Oct 24, 2023
dd4a0eb
Fix: Add MujocoEnvObsRmsPersistence only if obs_norm is enabled
opcode81 Oct 24, 2023
96298ea
Add convenient construction mechanisms for Environments
opcode81 Oct 24, 2023
da2194e
Force kwargs in PolicyWrapperFactoryIntrinsicCuriosity init
opcode81 Oct 26, 2023
3cd6dcc
BaseTrainer: Remove info on default values from docstrings
opcode81 Oct 26, 2023
d684dae
Change default number of environments (train=#CPUs, test=1)
opcode81 Oct 26, 2023
c613557
Apply datetime_tag() in high-level examples
opcode81 Oct 26, 2023
86cca8f
Add comment explaining use of _logFormat
opcode81 Oct 26, 2023
a3dbe90
Allow to configure the policy persistence mode, adding a new mode
opcode81 Oct 26, 2023
5952993
Add option to disable file logging
opcode81 Oct 27, 2023
fdb0eba
Depend on sensAI instead of copying its utils (logging, string)
opcode81 Oct 27, 2023
5c8d57a
Fix index error in call to _with_critic_factory_default
opcode81 Nov 6, 2023
7e6d3d6
Rename class ActorCriticModuleOpt -> ActorCriticOpt
opcode81 Nov 6, 2023
ac672f6
Add docstring for ActorFactoryTransientStorageDecorator
opcode81 Nov 6, 2023
dae4000
Revert "Depend on sensAI instead of copying its utils (logging, string)"
opcode81 Nov 7, 2023
6 changes: 5 additions & 1 deletion .gitignore
@@ -151,4 +151,8 @@ wandb/
videos/

# might be needed for IDE plugins that can't read ruff config
.flake8
.flake8

# temporary scripts (for ad-hoc testing), temp folder
/temp
/temp*.py
2 changes: 2 additions & 0 deletions docs/.gitignore
@@ -0,0 +1,2 @@
# auto-generated content
/api/tianshou.highlevel
121 changes: 121 additions & 0 deletions docs/autogen_rst.py
@@ -0,0 +1,121 @@
import logging
import os
import shutil
from pathlib import Path

log = logging.getLogger(os.path.basename(__file__))


def module_template(module_qualname: str) -> str:
    module_name = module_qualname.split(".")[-1]
    title = module_name.replace("_", r"\_")
    return f"""{title}
{"=" * len(title)}

.. automodule:: {module_qualname}
    :members:
    :undoc-members:
"""


def package_template(package_qualname: str) -> str:
    package_name = package_qualname.split(".")[-1]
    title = package_name.replace("_", r"\_")
    return f"""{title}
{"=" * len(title)}

.. automodule:: {package_qualname}
    :members:
    :undoc-members:

.. toctree::
    :glob:

    {package_name}/*
"""


def index_template(package_name: str) -> str:
    title = package_name
    return f"""{title}
{"=" * len(title)}

.. automodule:: {package_name}
    :members:
    :undoc-members:

.. toctree::
    :glob:

    *
"""


def write_to_file(content: str, path: str) -> None:
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w") as f:
        f.write(content)
    os.chmod(path, 0o666)


def make_rst(src_root, rst_root, clean=False, overwrite=False, package_prefix=""):
    """Creates/updates documentation in the form of rst files for modules and packages.

    Does not delete rst files for packages or modules that have been removed or renamed;
    these should be deleted by hand.

    This function should be executed from the project's top-level directory.

    :param src_root: path to the library base directory, typically "src/<library_name>"
    :param rst_root: path to the directory in which the rst files are to be placed
    :param clean: whether to completely clean the target directory beforehand, removing any existing rst files
    :param overwrite: whether to overwrite existing rst files; use with caution, as this discards
        all manual changes to the documentation files
    :param package_prefix: a prefix to prepend to each module name (for the case where src_root is not
        the base package), which, if non-empty, should end with a "."
    """
    rst_root = os.path.abspath(rst_root)

    if clean and os.path.isdir(rst_root):
        shutil.rmtree(rst_root)

    base_package_name = package_prefix + os.path.basename(src_root)
    write_to_file(index_template(base_package_name), os.path.join(rst_root, "index.rst"))

    for root, dirnames, filenames in os.walk(src_root):
        if os.path.basename(root).startswith("_"):
            continue
        base_package_relpath = os.path.relpath(root, start=src_root)
        base_package_qualname = package_prefix + os.path.relpath(
            root,
            start=os.path.dirname(src_root),
        ).replace(os.path.sep, ".")

        for dirname in dirnames:
            if not dirname.startswith("_"):
                package_qualname = f"{base_package_qualname}.{dirname}"
                package_rst_path = os.path.join(rst_root, base_package_relpath, f"{dirname}.rst")
                log.info(f"Writing package documentation to {package_rst_path}")
                write_to_file(package_template(package_qualname), package_rst_path)

        for filename in filenames:
            base_name, ext = os.path.splitext(filename)
            if ext == ".py" and not filename.startswith("_"):
                module_qualname = f"{base_package_qualname}.{base_name}"

                module_rst_path = os.path.join(rst_root, base_package_relpath, f"{base_name}.rst")
                if os.path.exists(module_rst_path) and not overwrite:
                    log.debug(f"{module_rst_path} already exists, skipping it")
                    continue

                log.info(f"Writing module documentation to {module_rst_path}")
                write_to_file(module_template(module_qualname), module_rst_path)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    docs_root = Path(__file__).parent
    make_rst(
        docs_root / ".." / "tianshou" / "highlevel",
        docs_root / "api" / "tianshou.highlevel",
        clean=True,
        package_prefix="tianshou.",
    )
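To illustrate what the generator emits, here is a quick standalone check of `module_template` (the function is reproduced from the script above; the module name used is just an example):

```python
# Reproduction of module_template from autogen_rst.py, to show its rst output.
def module_template(module_qualname: str) -> str:
    module_name = module_qualname.split(".")[-1]
    title = module_name.replace("_", r"\_")
    return f"""{title}
{"=" * len(title)}

.. automodule:: {module_qualname}
    :members:
    :undoc-members:
"""

rst = module_template("tianshou.highlevel.config")
# The title line is the bare module name, underlined with '=' of equal length,
# followed by an automodule directive for the fully qualified name.
lines = rst.splitlines()
assert lines[0] == "config"
assert lines[1] == "======"
assert ".. automodule:: tianshou.highlevel.config" in lines
```

The underline length must match the title length exactly, since Sphinx warns on mismatched section underlines.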
1 change: 1 addition & 0 deletions docs/index.rst
@@ -118,6 +118,7 @@ Tianshou is still under development, you can also check out the documents in sta
api/tianshou.trainer
api/tianshou.exploration
api/tianshou.utils
api/tianshou.highlevel/index


.. toctree::
Expand Down
8 changes: 8 additions & 0 deletions docs/spelling_wordlist.txt
@@ -180,3 +180,11 @@ params
inplace
deepcopy
Gaussian
stdout
parallelization
minibatch
minibatches
MLP
backpropagation
dataclass
superset
33 changes: 33 additions & 0 deletions examples/atari/atari_callbacks.py
@@ -0,0 +1,33 @@
from tianshou.highlevel.trainer import (
    TrainerEpochCallbackTest,
    TrainerEpochCallbackTrain,
    TrainingContext,
)
from tianshou.policy import DQNPolicy


class TestEpochCallbackDQNSetEps(TrainerEpochCallbackTest):
    def __init__(self, eps_test: float):
        self.eps_test = eps_test

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        policy: DQNPolicy = context.policy
        policy.set_eps(self.eps_test)


class TrainEpochCallbackNatureDQNEpsLinearDecay(TrainerEpochCallbackTrain):
    def __init__(self, eps_train: float, eps_train_final: float):
        self.eps_train = eps_train
        self.eps_train_final = eps_train_final

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        policy: DQNPolicy = context.policy
        logger = context.logger
        # Nature DQN setting: linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = self.eps_train - env_step / 1e6 * (self.eps_train - self.eps_train_final)
        else:
            eps = self.eps_train_final
        policy.set_eps(eps)
        if env_step % 1000 == 0:
            logger.write("train/env_step", env_step, {"train/eps": eps})
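The linear decay schedule used in the train callback can be sanity-checked in isolation. A standalone sketch of the same formula, detached from the callback class (the default values mirror the example scripts below):

```python
def nature_dqn_eps(env_step: int, eps_train: float = 1.0, eps_train_final: float = 0.05) -> float:
    """Linear decay from eps_train to eps_train_final over the first 1M steps, then constant."""
    if env_step <= 1e6:
        return eps_train - env_step / 1e6 * (eps_train - eps_train_final)
    return eps_train_final

# At step 0 exploration is maximal; halfway through the decay window it is
# the midpoint of the two endpoints; past 1M steps it stays at the final value.
assert nature_dqn_eps(0) == 1.0
assert abs(nature_dqn_eps(500_000) - 0.525) < 1e-9
assert nature_dqn_eps(2_000_000) == 0.05
```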
105 changes: 105 additions & 0 deletions examples/atari/atari_dqn_hl.py
@@ -0,0 +1,105 @@
#!/usr/bin/env python3

import os

from examples.atari.atari_callbacks import (
    TestEpochCallbackDQNSetEps,
    TrainEpochCallbackNatureDQNEpsLinearDecay,
)
from examples.atari.atari_network import (
    IntermediateModuleFactoryAtariDQN,
    IntermediateModuleFactoryAtariDQNFeatures,
)
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
    DQNExperimentBuilder,
    ExperimentConfig,
)
from tianshou.highlevel.params.policy_params import DQNParams
from tianshou.highlevel.params.policy_wrapper import (
    PolicyWrapperFactoryIntrinsicCuriosity,
)
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag


def main(
    experiment_config: ExperimentConfig,
    task: str = "PongNoFrameskip-v4",
    scale_obs: int = 0,
    eps_test: float = 0.005,
    eps_train: float = 1.0,
    eps_train_final: float = 0.05,
    buffer_size: int = 100000,
    lr: float = 0.0001,
    gamma: float = 0.99,
    n_step: int = 3,
    target_update_freq: int = 500,
    epoch: int = 100,
    step_per_epoch: int = 100000,
    step_per_collect: int = 10,
    update_per_step: float = 0.1,
    batch_size: int = 32,
    training_num: int = 10,
    test_num: int = 10,
    frames_stack: int = 4,
    save_buffer_name: str | None = None,  # TODO support?
    icm_lr_scale: float = 0.0,
    icm_reward_scale: float = 0.01,
    icm_forward_loss_weight: float = 0.2,
):
    log_name = os.path.join(task, "dqn", str(experiment_config.seed), datetime_tag())

    sampling_config = SamplingConfig(
        num_epochs=epoch,
        step_per_epoch=step_per_epoch,
        batch_size=batch_size,
        num_train_envs=training_num,
        num_test_envs=test_num,
        buffer_size=buffer_size,
        step_per_collect=step_per_collect,
        update_per_step=update_per_step,
        repeat_per_collect=None,
        replay_buffer_stack_num=frames_stack,
        replay_buffer_ignore_obs_next=True,
        replay_buffer_save_only_last_obs=True,
    )

    env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs)

    builder = (
        DQNExperimentBuilder(env_factory, experiment_config, sampling_config)
        .with_dqn_params(
            DQNParams(
                discount_factor=gamma,
                estimation_step=n_step,
                lr=lr,
                target_update_freq=target_update_freq,
            ),
        )
        .with_model_factory(IntermediateModuleFactoryAtariDQN())
        .with_trainer_epoch_callback_train(
            TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
        )
        .with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test))
        .with_trainer_stop_callback(AtariStopCallback(task))
    )
    if icm_lr_scale > 0:
        builder.with_policy_wrapper_factory(
            PolicyWrapperFactoryIntrinsicCuriosity(
                feature_net_factory=IntermediateModuleFactoryAtariDQNFeatures(),
                hidden_sizes=[512],
                lr=lr,
                lr_scale=icm_lr_scale,
                reward_scale=icm_reward_scale,
                forward_loss_weight=icm_forward_loss_weight,
            ),
        )

    experiment = builder.build()
    experiment.run(log_name)


if __name__ == "__main__":
    logging.run_cli(main)

Review discussion on the PolicyWrapperFactoryIntrinsicCuriosity call:

Collaborator: I think it would be better to use kwargs explicitly in most methods that have more than, say, 3-4 parameters. This makes the logic easier to understand and prevents difficult-to-debug mistakes when the user accidentally switches the argument order. You could use (*, <args>) in the constructor of PolicyWrapperFactoryIntrinsicCuriosity and other places where it's appropriate to ensure that this is happening, similar to how I've done it for the policies in the last PR.

Collaborator: Also, how about using an alias type IntrinsicCuriosity for better readability?

Author (@opcode81, Oct 26, 2023): Agreed, forced kwargs can improve readability (and potentially avoid errors) here; added in commit da2194e. As far as the shorter aliases are concerned, I am still not entirely convinced they are needed. I would never use them personally, because they're just not easily discovered via auto-completion, and you lose the semantics. What do the others think?
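The forced-kwargs suggestion from the review thread (adopted in commit da2194e) boils down to a bare `*` marker in the signature. A minimal sketch with a hypothetical class and illustrative parameter names, not the actual tianshou implementation:

```python
class CuriosityConfigSketch:
    """Hypothetical config class demonstrating keyword-only parameters via the bare * marker."""

    def __init__(self, *, lr_scale: float, reward_scale: float, forward_loss_weight: float):
        self.lr_scale = lr_scale
        self.reward_scale = reward_scale
        self.forward_loss_weight = forward_loss_weight


# Positional calls now fail loudly instead of silently mixing up the scale parameters:
try:
    CuriosityConfigSketch(0.1, 0.01, 0.2)
except TypeError as e:
    print("rejected positional args:", e)

# Keyword calls remain valid, and the argument order no longer matters:
cfg = CuriosityConfigSketch(reward_scale=0.01, lr_scale=0.1, forward_loss_weight=0.2)
assert cfg.lr_scale == 0.1
```

The safety benefit is largest when several parameters share a type, as here, where all three are floats and a swapped pair would otherwise go unnoticed.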
97 changes: 97 additions & 0 deletions examples/atari/atari_iqn_hl.py
@@ -0,0 +1,97 @@
#!/usr/bin/env python3

import os
from collections.abc import Sequence

from examples.atari.atari_callbacks import (
    TestEpochCallbackDQNSetEps,
    TrainEpochCallbackNatureDQNEpsLinearDecay,
)
from examples.atari.atari_network import (
    IntermediateModuleFactoryAtariDQN,
)
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
    ExperimentConfig,
    IQNExperimentBuilder,
)
from tianshou.highlevel.params.policy_params import IQNParams
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag


def main(
    experiment_config: ExperimentConfig,
    task: str = "PongNoFrameskip-v4",
    scale_obs: int = 0,
    eps_test: float = 0.005,
    eps_train: float = 1.0,
    eps_train_final: float = 0.05,
    buffer_size: int = 100000,
    lr: float = 0.0001,
    gamma: float = 0.99,
    sample_size: int = 32,
    online_sample_size: int = 8,
    target_sample_size: int = 8,
    num_cosines: int = 64,
    hidden_sizes: Sequence[int] = (512,),
    n_step: int = 3,
    target_update_freq: int = 500,
    epoch: int = 100,
    step_per_epoch: int = 100000,
    step_per_collect: int = 10,
    update_per_step: float = 0.1,
    batch_size: int = 32,
    training_num: int = 10,
    test_num: int = 10,
    frames_stack: int = 4,
    save_buffer_name: str | None = None,  # TODO support?
):
    log_name = os.path.join(task, "iqn", str(experiment_config.seed), datetime_tag())

    sampling_config = SamplingConfig(
        num_epochs=epoch,
        step_per_epoch=step_per_epoch,
        batch_size=batch_size,
        num_train_envs=training_num,
        num_test_envs=test_num,
        buffer_size=buffer_size,
        step_per_collect=step_per_collect,
        update_per_step=update_per_step,
        repeat_per_collect=None,
        replay_buffer_stack_num=frames_stack,
        replay_buffer_ignore_obs_next=True,
        replay_buffer_save_only_last_obs=True,
    )

    env_factory = AtariEnvFactory(task, experiment_config.seed, frames_stack, scale=scale_obs)

    experiment = (
        IQNExperimentBuilder(env_factory, experiment_config, sampling_config)
        .with_iqn_params(
            IQNParams(
                discount_factor=gamma,
                estimation_step=n_step,
                lr=lr,
                sample_size=sample_size,
                online_sample_size=online_sample_size,
                target_update_freq=target_update_freq,
                target_sample_size=target_sample_size,
                hidden_sizes=hidden_sizes,
                num_cosines=num_cosines,
            ),
        )
        .with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(features_only=True))
        .with_trainer_epoch_callback_train(
            TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
        )
        .with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test))
        .with_trainer_stop_callback(AtariStopCallback(task))
        .build()
    )
    experiment.run(log_name)


if __name__ == "__main__":
    logging.run_cli(main)