RL renaming v2 #476

Merged
merged 6 commits on Mar 9, 2022
4 changes: 2 additions & 2 deletions docs/source/apidoc/maro.rl.rst
@@ -176,10 +176,10 @@ maro.rl.training.trainer
:undoc-members:
:show-inheritance:

maro.rl.training.trainer_manager
maro.rl.training.training_manager
--------------------------------------------------------------------------------

.. automodule:: maro.rl.training.trainer_manager
.. automodule:: maro.rl.training.training_manager
:members:
:undoc-members:
:show-inheritance:
22 changes: 12 additions & 10 deletions docs/source/key_components/rl_toolkit.rst
@@ -109,7 +109,7 @@ The above code snippet creates a ``ValueBasedPolicy`` object. Let's pay attentio
``q_net`` accepts a ``DiscreteQNet`` object, and it serves as the core part of a ``ValueBasedPolicy`` object. In
other words, ``q_net`` defines the model structure of the Q-network in the value-based policy, and further determines
the policy's behavior. ``DiscreteQNet`` is an abstract class, and ``MyQNet`` is a user-defined implementation
of ``DiscreteQNet``. It can be a simple MLP, a multihead transformer, or any other structure that the user wants.
of ``DiscreteQNet``. It can be a simple MLP, a multi-head transformer, or any other structure that the user wants.
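
As an illustrative sketch only, a user-defined ``MyQNet`` might look roughly like the following; the ``FullyConnected`` keyword arguments and the ``_get_q_values_for_all_actions`` hook name are assumptions rather than the confirmed ``DiscreteQNet`` interface, so consult ``maro.rl.model`` for the exact abstract methods.

.. code-block:: python

    import torch

    from maro.rl.model import DiscreteQNet, FullyConnected
    from maro.rl.policy import ValueBasedPolicy


    class MyQNet(DiscreteQNet):
        """Illustrative Q-network: a plain MLP that maps a state to one Q-value per action."""

        def __init__(self, state_dim: int, action_num: int) -> None:
            super(MyQNet, self).__init__(state_dim=state_dim, action_num=action_num)
            # Assumed FullyConnected keyword arguments, mirroring q_net_conf in the CIM example.
            self._mlp = FullyConnected(
                input_dim=state_dim,
                output_dim=action_num,
                hidden_dims=[256, 128, 64, 32],
                activation=torch.nn.LeakyReLU,
            )

        def _get_q_values_for_all_actions(self, states: torch.Tensor) -> torch.Tensor:
            # Assumed abstract hook; returns a [batch_size, action_num] tensor of Q-values.
            return self._mlp(states)


    # The value-based policy then simply wraps the network.
    policy = ValueBasedPolicy(name="my_policy", q_net=MyQNet(state_dim=10, action_num=4))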

MARO provides a set of abstractions of basic & commonly used PyTorch models like ``DiscreteQNet``, which enables
users to implement their own deep learning models in a handy way. They are:
@@ -131,22 +131,24 @@ The way to use these models is exactly the same as the way to use the policy mod

.. _trainer:

Trainer
Algorithm (Trainer)
-------

When introducing policies, we mentioned that policies cannot train themselves. Instead, they have to be trained
by external trainers. In MARO, a trainer represents an RL algorithm, such as DQN, actor-critic,
and so on. Trainers take interaction experiences and store them in the internal memory, and then use the experiences
by external algorithms, which are also called trainers.
In MARO, a trainer represents an RL algorithm, such as DQN, actor-critic,
and so on. These two concepts are equivalent in the MARO context.
Trainers take interaction experiences and store them in the internal memory, and then use the experiences
in the memory to train the policies. Like ``RLPolicy``, trainers are also concrete classes, which means they could
be used by configuring parameters. Currently, we have 4 trainers (algorithms) in MARO:

- ``DiscreteActorCritic``: Actor-critic algorithm for policies that generate discrete actions.
- ``DDPG``: DDPG algorithm for policies that generate continuous actions.
- ``DQN``: DQN algorithm for policies that generate discrete actions.
- ``DiscreteMADDPG``: MADDPG algorithm for policies that generate discrete actions.
- ``DiscreteActorCriticTrainer``: Actor-critic algorithm for policies that generate discrete actions.
- ``DDPGTrainer``: DDPG algorithm for policies that generate continuous actions.
- ``DQNTrainer``: DQN algorithm for policies that generate discrete actions.
- ``DiscreteMADDPGTrainer``: MADDPG algorithm for policies that generate discrete actions.

Each trainer has a corresponding ``Param`` class to manage all related parameters. For example,
``DiscreteActorCriticParams`` contains all parameters used in ``DiscreteActorCritic``:
``DiscreteActorCriticParams`` contains all parameters used in ``DiscreteActorCriticTrainer``:

.. code-block:: python

@@ -164,7 +166,7 @@ An example of creating an actor-critic trainer:

.. code-block:: python

DiscreteActorCritic(
DiscreteActorCriticTrainer(
name='ac',
params=DiscreteActorCriticParams(
device="cpu",
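
As a reviewer aid, the algorithm-class renames in this PR map onto user imports as follows (a sketch covering imports only; constructor arguments are untouched by the rename):

.. code-block:: python

    # Old (pre-PR) names -> new names; a pure rename, behavior is unchanged:
    #   DiscreteActorCritic -> DiscreteActorCriticTrainer
    #   DDPG                -> DDPGTrainer
    #   DQN                 -> DQNTrainer
    #   DiscreteMADDPG      -> DiscreteMADDPGTrainer
    from maro.rl.training.algorithms import (
        DDPGTrainer,
        DiscreteActorCriticTrainer,
        DiscreteMADDPGTrainer,
        DQNTrainer,
    )
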
6 changes: 3 additions & 3 deletions examples/rl/cim/algorithms/ac.py
@@ -8,7 +8,7 @@

from maro.rl.model import DiscretePolicyNet, FullyConnected, VNet
from maro.rl.policy import DiscretePolicyGradient
from maro.rl.training.algorithms import DiscreteActorCritic, DiscreteActorCriticParams
from maro.rl.training.algorithms import DiscreteActorCriticTrainer, DiscreteActorCriticParams


actor_net_conf = {
@@ -116,8 +116,8 @@ def get_policy(state_dim: int, action_num: int, name: str) -> DiscretePolicyGrad
return DiscretePolicyGradient(name=name, policy_net=MyActorNet(state_dim, action_num))


def get_ac(state_dim: int, name: str) -> DiscreteActorCritic:
return DiscreteActorCritic(
def get_ac(state_dim: int, name: str) -> DiscreteActorCriticTrainer:
return DiscreteActorCriticTrainer(
name=name,
params=DiscreteActorCriticParams(
device="cpu",
6 changes: 3 additions & 3 deletions examples/rl/cim/algorithms/dqn.py
@@ -9,7 +9,7 @@
from maro.rl.exploration import MultiLinearExplorationScheduler, epsilon_greedy
from maro.rl.model import DiscreteQNet, FullyConnected
from maro.rl.policy import ValueBasedPolicy
from maro.rl.training.algorithms import DQN, DQNParams
from maro.rl.training.algorithms import DQNTrainer, DQNParams

q_net_conf = {
"hidden_dims": [256, 128, 64, 32],
@@ -79,8 +79,8 @@ def get_policy(state_dim: int, action_num: int, name: str) -> ValueBasedPolicy:
)


def get_dqn(name: str) -> DQN:
return DQN(
def get_dqn(name: str) -> DQNTrainer:
return DQNTrainer(
name=name,
params=DQNParams(
device="cpu",
6 changes: 3 additions & 3 deletions examples/rl/cim/algorithms/maddpg.py
@@ -9,7 +9,7 @@

from maro.rl.model import DiscretePolicyNet, FullyConnected, MultiQNet
from maro.rl.policy import DiscretePolicyGradient
from maro.rl.training.algorithms import DiscreteMADDPG, DiscreteMADDPGParams
from maro.rl.training.algorithms import DiscreteMADDPGTrainer, DiscreteMADDPGParams


actor_net_conf = {
@@ -122,8 +122,8 @@ def get_policy(state_dim: int, action_num: int, name: str) -> DiscretePolicyGrad
return DiscretePolicyGradient(name=name, policy_net=MyActorNet(state_dim, action_num))


def get_maddpg(state_dim: int, action_dims: List[int], name: str) -> DiscreteMADDPG:
return DiscreteMADDPG(
def get_maddpg(state_dim: int, action_dims: List[int], name: str) -> DiscreteMADDPGTrainer:
return DiscreteMADDPGTrainer(
name=name,
params=DiscreteMADDPGParams(
device="cpu",
10 changes: 5 additions & 5 deletions examples/rl/main.py
@@ -3,7 +3,7 @@

import os

from maro.rl.training import TrainerManager
from maro.rl.training import TrainingManager
from maro.rl.workflows.scenario import Scenario
from maro.utils import LoggerV2

@@ -32,7 +32,7 @@
eval_point_index = 0

env_sampler = scenario.get_env_sampler(policy_creator)
trainer_manager = TrainerManager(policy_creator, trainer_creator, agent2policy, logger=logger)
training_manager = TrainingManager(policy_creator, trainer_creator, agent2policy, logger=logger)

# main loop
for ep in range(1, NUM_EPISODES + 1):
@@ -48,11 +48,11 @@
scenario.post_collect(result["info"], ep, segment)

logger.info(f"Roll-out completed for episode {ep}. Training started...")
trainer_manager.record_experiences(experiences)
trainer_manager.train()
training_manager.record_experiences(experiences)
training_manager.train_step()
if CHECKPOINT_PATH and ep % CHECKPOINT_INTERVAL == 0:
pth = os.path.join(CHECKPOINT_PATH, str(ep))
trainer_manager.save(pth)
training_manager.save(pth)
logger.info(f"All trainer states saved under {pth}")
segment += 1

6 changes: 3 additions & 3 deletions examples/rl/vm_scheduling/algorithms/ac.py
@@ -8,7 +8,7 @@

from maro.rl.model import DiscretePolicyNet, FullyConnected, VNet
from maro.rl.policy import DiscretePolicyGradient
from maro.rl.training.algorithms import DiscreteActorCritic, DiscreteActorCriticParams
from maro.rl.training.algorithms import DiscreteActorCriticTrainer, DiscreteActorCriticParams


actor_net_conf = {
@@ -123,8 +123,8 @@ def get_policy(state_dim: int, action_num: int, num_features: int, name: str) ->
return DiscretePolicyGradient(name=name, policy_net=MyActorNet(state_dim, action_num, num_features))


def get_ac(state_dim: int, num_features: int, name: str) -> DiscreteActorCritic:
return DiscreteActorCritic(
def get_ac(state_dim: int, num_features: int, name: str) -> DiscreteActorCriticTrainer:
return DiscreteActorCriticTrainer(
name=name,
params=DiscreteActorCriticParams(
device="cpu",
6 changes: 3 additions & 3 deletions examples/rl/vm_scheduling/algorithms/dqn.py
@@ -11,7 +11,7 @@
from maro.rl.exploration import MultiLinearExplorationScheduler
from maro.rl.model import DiscreteQNet, FullyConnected
from maro.rl.policy import ValueBasedPolicy
from maro.rl.training.algorithms import DQN, DQNParams
from maro.rl.training.algorithms import DQNTrainer, DQNParams


q_net_conf = {
@@ -100,8 +100,8 @@ def get_policy(state_dim: int, action_num: int, num_features: int, name: str) ->
)


def get_dqn(name: str) -> DQN:
return DQN(
def get_dqn(name: str) -> DQNTrainer:
return DQNTrainer(
name=name,
params=DQNParams(
device="cpu",
4 changes: 2 additions & 2 deletions maro/rl/distributed/abs_worker.py
@@ -9,7 +9,7 @@
from zmq.eventloop.zmqstream import ZMQStream

from maro.rl.utils.common import get_ip_address_by_hostname, string_to_bytes
from maro.utils import DummyLogger, Logger
from maro.utils import DummyLogger, LoggerV2


class AbsWorker(object):
@@ -28,7 +28,7 @@ def __init__(
idx: int,
producer_host: str,
producer_port: int,
logger: Logger = None,
logger: LoggerV2 = None,
) -> None:
super(AbsWorker, self).__init__()

2 changes: 1 addition & 1 deletion maro/rl/policy/abs_policy.py
@@ -125,7 +125,7 @@ def exploit(self) -> None:
self._is_exploring = False

@abstractmethod
def step(self, loss: torch.Tensor) -> None:
def train_step(self, loss: torch.Tensor) -> None:
"""Run a training step to update the policy according to the given loss.

Args:
2 changes: 1 addition & 1 deletion maro/rl/policy/continuous_rl_policy.py
@@ -80,7 +80,7 @@ def _post_check(self, states: torch.Tensor, actions: torch.Tensor) -> bool:
def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor:
return self._policy_net.get_actions(states, exploring)

def step(self, loss: torch.Tensor) -> None:
def train_step(self, loss: torch.Tensor) -> None:
self._policy_net.step(loss)

def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]:
4 changes: 2 additions & 2 deletions maro/rl/policy/discrete_rl_policy.py
@@ -160,7 +160,7 @@ def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tens
actions = ndarray_to_tensor(actions, self._device)
return actions.unsqueeze(1) # [B, 1]

def step(self, loss: torch.Tensor) -> None:
def train_step(self, loss: torch.Tensor) -> None:
return self._q_net.step(loss)

def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]:
@@ -226,7 +226,7 @@ def policy_net(self) -> DiscretePolicyNet:
def _get_actions_impl(self, states: torch.Tensor, exploring: bool) -> torch.Tensor:
return self._policy_net.get_actions(states, exploring)

def step(self, loss: torch.Tensor) -> None:
def train_step(self, loss: torch.Tensor) -> None:
self._policy_net.step(loss)

def get_gradients(self, loss: torch.Tensor) -> Dict[str, torch.Tensor]:
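
To illustrate the policy-level rename from ``step`` to ``train_step``, here is a minimal call-site sketch; the helper function is hypothetical, while ``train_step(loss)`` itself matches the signature shown above, and the net-level ``step`` API is unchanged.

.. code-block:: python

    import torch

    from maro.rl.policy import DiscretePolicyGradient


    def apply_policy_update(policy: DiscretePolicyGradient, loss: torch.Tensor) -> None:
        """Hypothetical helper: policies are now updated via `train_step` (previously `step`)."""
        policy.train_step(loss)  # the policy forwards the loss to its underlying net's `step`
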
10 changes: 5 additions & 5 deletions maro/rl/rollout/batch_env_sampler.py
@@ -9,7 +9,7 @@
from zmq import Context, Poller

from maro.rl.utils.common import bytes_to_pyobj, get_own_ip_address, pyobj_to_bytes
from maro.utils import DummyLogger, Logger
from maro.utils import DummyLogger, LoggerV2

from .env_sampler import ExpElement

@@ -19,10 +19,10 @@ class ParallelTaskController(object):

Args:
port (int, default=20000): Network port the controller uses to talk to the remote workers.
logger (Logger, default=None): Optional logger for logging key events.
logger (LoggerV2, default=None): Optional logger for logging key events.
"""

def __init__(self, port: int = 20000, logger: Logger = None) -> None:
def __init__(self, port: int = 20000, logger: LoggerV2 = None) -> None:
self._ip = get_own_ip_address()
self._context = Context.instance()

@@ -117,7 +117,7 @@ class BatchEnvSampler:
are received in T seconds, it will allow an additional T * grace_factor seconds to collect the remaining
results.
eval_parallelism (int, default=None): Parallelism for policy evaluation on remote workers.
logger (Logger, default=None): Optional logger for logging key events.
logger (LoggerV2, default=None): Optional logger for logging key events.
"""

def __init__(
@@ -127,7 +127,7 @@ def __init__(
min_env_samples: int = None,
grace_factor: float = None,
eval_parallelism: int = None,
logger: Logger = None,
logger: LoggerV2 = None,
) -> None:
super(BatchEnvSampler, self).__init__()
self._logger = logger if logger else DummyLogger()
6 changes: 3 additions & 3 deletions maro/rl/rollout/worker.py
@@ -5,7 +5,7 @@

from maro.rl.distributed import AbsWorker
from maro.rl.utils.common import bytes_to_pyobj, pyobj_to_bytes
from maro.utils import Logger
from maro.utils import LoggerV2

from .env_sampler import AbsEnvSampler

@@ -20,7 +20,7 @@ class RolloutWorker(AbsWorker):
for roll-out purposes.
producer_host (str): IP address of the parallel task controller host to connect to.
producer_port (int, default=20000): Port of the parallel task controller host to connect to.
logger (Logger, default=None): The logger of the workflow.
logger (LoggerV2, default=None): The logger of the workflow.
"""

def __init__(
@@ -29,7 +29,7 @@ def __init__(
env_sampler_creator: Callable[[], AbsEnvSampler],
producer_host: str,
producer_port: int = 20000,
logger: Logger = None,
logger: LoggerV2 = None,
) -> None:
super(RolloutWorker, self).__init__(
idx=idx, producer_host=producer_host, producer_port=producer_port, logger=logger,
8 changes: 4 additions & 4 deletions maro/rl/training/__init__.py
@@ -4,15 +4,15 @@
from .proxy import TrainingProxy
from .replay_memory import FIFOMultiReplayMemory, FIFOReplayMemory, RandomMultiReplayMemory, RandomReplayMemory
from .train_ops import AbsTrainOps, RemoteOps, remote
from .trainer import AbsTrainer, MultiTrainer, SingleTrainer, TrainerParams
from .trainer_manager import TrainerManager
from .trainer import AbsTrainer, MultiAgentTrainer, SingleAgentTrainer, TrainerParams
from .training_manager import TrainingManager
from .worker import TrainOpsWorker

__all__ = [
"TrainingProxy",
"FIFOMultiReplayMemory", "FIFOReplayMemory", "RandomMultiReplayMemory", "RandomReplayMemory",
"AbsTrainOps", "RemoteOps", "remote",
"AbsTrainer", "MultiTrainer", "SingleTrainer", "TrainerParams",
"TrainerManager",
"AbsTrainer", "MultiAgentTrainer", "SingleAgentTrainer", "TrainerParams",
"TrainingManager",
"TrainOpsWorker",
]
16 changes: 8 additions & 8 deletions maro/rl/training/algorithms/__init__.py
@@ -1,14 +1,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .ac import DiscreteActorCritic, DiscreteActorCriticParams
from .ddpg import DDPG, DDPGParams
from .dqn import DQN, DQNParams
from .maddpg import DiscreteMADDPG, DiscreteMADDPGParams
from .ac import DiscreteActorCriticParams, DiscreteActorCriticTrainer
from .ddpg import DDPGParams, DDPGTrainer
from .dqn import DQNParams, DQNTrainer
from .maddpg import DiscreteMADDPGParams, DiscreteMADDPGTrainer

__all__ = [
"DiscreteActorCritic", "DiscreteActorCriticParams",
"DDPG", "DDPGParams",
"DQN", "DQNParams",
"DiscreteMADDPG", "DiscreteMADDPGParams",
"DiscreteActorCriticTrainer", "DiscreteActorCriticParams",
"DDPGTrainer", "DDPGParams",
"DQNTrainer", "DQNParams",
"DiscreteMADDPGTrainer", "DiscreteMADDPGParams",
]
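
The package-level renames in ``maro.rl.training`` can be summarized by the new import surface, taken directly from the updated ``__all__`` lists above:

.. code-block:: python

    # TrainerManager -> TrainingManager; SingleTrainer -> SingleAgentTrainer; MultiTrainer -> MultiAgentTrainer.
    from maro.rl.training import (
        AbsTrainer,
        MultiAgentTrainer,
        SingleAgentTrainer,
        TrainerParams,
        TrainingManager,
    )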