Skip to content

Commit

Permalink
Redesign experience recording logic (#453)
Browse files Browse the repository at this point in the history
* Two not important fix

* Temp draft. Prepare to WFH

* Done

* Lint

* Lint
  • Loading branch information
lihuoran authored Jan 10, 2022
1 parent dc2b9ab commit 0bdc230
Show file tree
Hide file tree
Showing 16 changed files with 186 additions and 153 deletions.
47 changes: 40 additions & 7 deletions maro/rl_v3/learning/env_sampler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from __future__ import annotations

import collections
from abc import ABCMeta, abstractmethod
from collections import defaultdict, deque
from dataclasses import dataclass
Expand Down Expand Up @@ -122,8 +125,7 @@ def switch_to_eval_mode(self) -> None:

@dataclass
class CacheElement:
"""
The data structure used to store a cached value during experience collection.
"""The data structure used to store a cached value during experience collection.
"""
tick: int
state: np.ndarray
Expand All @@ -133,15 +135,47 @@ class CacheElement:


@dataclass
class ExpElement(CacheElement):
"""
Stores the complete information for a tick. ExpElement is an extension of CacheElement.
class ExpElement:
"""Stores the complete information for a tick. ExpElement is an extension of CacheElement.
"""
tick: int
state: np.ndarray
agent_state_dict: Dict[Any, np.ndarray]
action_dict: Dict[Any, np.ndarray]
reward_dict: Dict[Any, float]
terminal_dict: Dict[Any, bool]
next_state: Optional[np.ndarray]
next_agent_state_dict: Optional[Dict[Any, np.ndarray]]

@property
def agent_names(self) -> list:
return sorted(self.agent_state_dict.keys())

@property
def num_agents(self) -> int:
return len(self.agent_state_dict)

def split_contents(self, agent2trainer: Dict[Any, str]) -> Dict[str, ExpElement]:
ret = collections.defaultdict(lambda: ExpElement(
tick=self.tick,
state=self.state,
agent_state_dict={},
action_dict={},
reward_dict={},
terminal_dict={},
next_state=self.next_state,
next_agent_state_dict=None if self.next_agent_state_dict is None else {}
))
for agent_name in self.agent_names:
trainer_name = agent2trainer[agent_name]
ret[trainer_name].agent_state_dict[agent_name] = self.agent_state_dict[agent_name]
ret[trainer_name].action_dict[agent_name] = self.action_dict[agent_name]
ret[trainer_name].reward_dict[agent_name] = self.reward_dict[agent_name]
ret[trainer_name].terminal_dict[agent_name] = self.terminal_dict[agent_name]
if self.next_agent_state_dict is not None and agent_name in self.next_agent_state_dict:
ret[trainer_name].next_agent_state_dict[agent_name] = self.next_agent_state_dict[agent_name]
return ret


class AbsEnvSampler(object, metaclass=ABCMeta):
"""
Expand Down Expand Up @@ -314,7 +348,6 @@ def sample( # TODO: check logic with qiuyang
state=cache_element.state,
agent_state_dict=cache_element.agent_state_dict,
action_dict=cache_element.action_dict,
env_action_dict=cache_element.env_action_dict,
reward_dict=reward_dict,
terminal_dict={}, # Will be processed later in `_post_polish_experiences()`
next_state=next_state,
Expand Down Expand Up @@ -343,7 +376,7 @@ def _post_polish_experiences(self, experiences: List[ExpElement]) -> List[ExpEle
for key, value in latest_agent_state_dict.items():
if key not in experiences[i].next_agent_state_dict:
experiences[i].next_agent_state_dict[key] = value
latest_agent_state_dict.update(experiences[i].next_agent_state_dict)
latest_agent_state_dict.update(experiences[i].agent_state_dict)
return experiences

def set_policy_states(self, policy_state_dict: Dict[str, object]) -> None:
Expand Down
2 changes: 1 addition & 1 deletion maro/rl_v3/model/multi_q_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch

from maro.rl_v3.utils import match_shape, SHAPE_CHECK_FLAG
from maro.rl_v3.utils import SHAPE_CHECK_FLAG, match_shape

from .abs_net import AbsNet

Expand Down
1 change: 0 additions & 1 deletion maro/rl_v3/tmp_example_multi/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from .env_sampler import CIMEnvSampler
from .policies import policy_creator, trainer_creator


if __name__ == "__main__":
run_workflow_centralized_mode(
get_env_sampler_func=lambda: CIMEnvSampler(
Expand Down
34 changes: 29 additions & 5 deletions maro/rl_v3/training/algorithms/ac.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
# Licensed under the MIT license.

import asyncio
import collections
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional

import numpy as np
import torch

from maro.rl_v3.learning import ExpElement
from maro.rl_v3.model import VNet
from maro.rl_v3.policy import DiscretePolicyGradient
from maro.rl_v3.training import AbsTrainOps, FIFOReplayMemory, SingleTrainer, TrainerParams
Expand Down Expand Up @@ -196,24 +198,46 @@ class DiscreteActorCritic(SingleTrainer):
def __init__(self, name: str, params: DiscreteActorCriticParams) -> None:
super(DiscreteActorCritic, self).__init__(name, params)
self._params = params
self._ops_params = {}
self._ops_name = f"{self._name}.ops"

self._replay_memory_dict = {}

def build(self) -> None:
self._ops_params = {
"get_policy_func": self._get_policy_func,
**self._params.extract_ops_params(),
}

self._ops = self.get_ops(self._ops_name)
self._replay_memory = FIFOReplayMemory(

self._replay_memory_dict = collections.defaultdict(lambda: FIFOReplayMemory(
capacity=self._params.replay_memory_capacity,
state_dim=self._ops.policy_state_dim,
action_dim=self._ops.policy_action_dim
)
))

def record(self, exp_element: ExpElement) -> None:
for agent_name in exp_element.agent_names:
memory = self._replay_memory_dict[agent_name]
transition_batch = TransitionBatch(
states=np.expand_dims(exp_element.agent_state_dict[agent_name], axis=0),
actions=np.expand_dims(exp_element.action_dict[agent_name], axis=0),
rewards=np.array([exp_element.reward_dict[agent_name]]),
terminals=np.array([exp_element.terminal_dict[agent_name]]),
next_states=np.expand_dims(
exp_element.next_agent_state_dict.get(agent_name, exp_element.agent_state_dict[agent_name]),
axis=0,
),
)
memory.put(transition_batch)

def _get_local_ops_by_name(self, ops_name: str) -> AbsTrainOps:
return DiscreteActorCriticOps(**self._ops_params)

def _get_batch(self, agent_name: str, batch_size: int = None) -> TransitionBatch:
return self._replay_memory_dict[agent_name].sample(batch_size if batch_size is not None else self._batch_size)

async def train_step(self):
await asyncio.gather(self._ops.set_batch(self._get_batch()))
await asyncio.gather(self._ops.update(self._params.grad_iters))
for agent_name in self._replay_memory_dict:
await asyncio.gather(self._ops.set_batch(self._get_batch(agent_name)))
await asyncio.gather(self._ops.update(self._params.grad_iters))
2 changes: 1 addition & 1 deletion maro/rl_v3/training/algorithms/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from maro.rl_v3.model import QNet
from maro.rl_v3.policy import ContinuousRLPolicy
from maro.rl_v3.training import AbsTrainOps, RandomReplayMemory, SingleTrainer, TrainerParams
from maro.rl_v3.utils import CoroutineWrapper, TransitionBatch, ndarray_to_tensor
from maro.rl_v3.utils import TransitionBatch, ndarray_to_tensor
from maro.utils import clone


Expand Down
24 changes: 23 additions & 1 deletion maro/rl_v3/training/algorithms/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
# Licensed under the MIT license.

from dataclasses import dataclass
from typing import Callable, Dict, List
from typing import Callable, Dict, List, Optional

import numpy as np
import torch

from maro.rl_v3.learning import ExpElement
from maro.rl_v3.policy import ValueBasedPolicy
from maro.rl_v3.training import AbsTrainOps, RandomReplayMemory, SingleTrainer, TrainerParams
from maro.rl_v3.utils import TransitionBatch, ndarray_to_tensor
Expand Down Expand Up @@ -140,9 +142,12 @@ class DQN(SingleTrainer):
def __init__(self, name: str, params: DQNParams) -> None:
super(DQN, self).__init__(name, params)
self._params = params
self._ops_params = {}
self._q_net_version = self._target_q_net_version = 0
self._ops_name = f"{self._name}.ops"

self._replay_memory: Optional[RandomReplayMemory] = None

def build(self) -> None:
self._ops_params = {
"get_policy_func": self._get_policy_func,
Expand All @@ -157,9 +162,26 @@ def build(self) -> None:
random_overwrite=self._params.random_overwrite
)

def record(self, exp_element: ExpElement) -> None:
for agent_name in exp_element.agent_names:
transition_batch = TransitionBatch(
states=np.expand_dims(exp_element.agent_state_dict[agent_name], axis=0),
actions=np.expand_dims(exp_element.action_dict[agent_name], axis=0),
rewards=np.array([exp_element.reward_dict[agent_name]]),
terminals=np.array([exp_element.terminal_dict[agent_name]]),
next_states=np.expand_dims(
exp_element.next_agent_state_dict.get(agent_name, exp_element.agent_state_dict[agent_name]),
axis=0,
),
)
self._replay_memory.put(transition_batch)

def _get_local_ops_by_name(self, ops_name: str) -> AbsTrainOps:
return DQNOps(**self._ops_params)

def _get_batch(self, batch_size: int = None) -> TransitionBatch:
return self._replay_memory.sample(batch_size if batch_size is not None else self._batch_size)

async def train_step(self) -> None:
for _ in range(self._params.num_epochs):
await self._ops.set_batch(self._get_batch())
Expand Down
49 changes: 47 additions & 2 deletions maro/rl_v3/training/algorithms/maddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@

import asyncio
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from maro.rl_v3.learning import ExpElement
from maro.rl_v3.model import MultiQNet
from maro.rl_v3.policy import DiscretePolicyGradient
from maro.rl_v3.training import AbsTrainOps, MultiTrainer, RandomMultiReplayMemory, TrainerParams
from maro.rl_v3.utils import MultiTransitionBatch, ndarray_to_tensor
from maro.rl_v3.utils import CoroutineWrapper, MultiTransitionBatch, RemoteObj, ndarray_to_tensor
from maro.utils import clone


Expand Down Expand Up @@ -266,6 +267,13 @@ def __init__(self, name: str, params: DiscreteMADDPGParams) -> None:
self._policy_version = self._target_policy_version = 0
self._shared_critic_ops_name = f"{self._name}.shared_critic_ops"

self._actor_ops_list = []
self._critic_ops: Union[RemoteObj, CoroutineWrapper, None] = None

self._policy2agent: Dict[str, Any] = {}

self._replay_memory: Optional[RandomMultiReplayMemory] = None

def _improve(self, batch: MultiTransitionBatch) -> None:
for ops in self._actor_ops_list:
ops.set_batch(batch)
Expand Down Expand Up @@ -314,6 +322,43 @@ def build(self) -> None:
agent_states_dims=[ops.policy_state_dim for ops in self._actor_ops_list]
)

assert len(self._agent2policy.keys()) == len(self._agent2policy.values()) # agent <=> policy
self._policy2agent = {policy_name: agent_name for agent_name, policy_name in self._agent2policy.items()}

def record(self, exp_element: ExpElement) -> None:
assert exp_element.num_agents == len(self._agent2policy.keys())

actions = []
rewards = []
agent_states = []
terminals = []
next_agent_states = []
for policy_name in self._policy_names:
agent_name = self._policy2agent[policy_name]
actions.append(np.expand_dims(exp_element.action_dict[agent_name], axis=0))
rewards.append(np.array([exp_element.reward_dict[agent_name]]))
agent_states.append(np.expand_dims(exp_element.agent_state_dict[agent_name], axis=0))
terminals.append(exp_element.terminal_dict[agent_name])
next_agent_states.append(np.expand_dims(
exp_element.next_agent_state_dict.get(agent_name, exp_element.agent_state_dict[agent_name]), axis=0
))

transition_batch = MultiTransitionBatch(
states=np.expand_dims(exp_element.state, axis=0),
actions=actions,
rewards=rewards,
next_states=np.expand_dims(
exp_element.next_state if exp_element.next_state is not None else exp_element.state, axis=0
),
agent_states=agent_states,
next_agent_states=next_agent_states,
terminals=np.array(terminals)
)
self._replay_memory.put(transition_batch)

def _get_batch(self, batch_size: int = None) -> MultiTransitionBatch:
return self._replay_memory.sample(batch_size if batch_size is not None else self._batch_size)

def _get_local_ops_by_name(self, ops_name: str) -> AbsTrainOps:
if ops_name == self._shared_critic_ops_name:
params = {
Expand Down
4 changes: 1 addition & 3 deletions maro/rl_v3/training/replay_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def put(self, transition_batch: TransitionBatch) -> None:

self._put_by_indexes(self._get_put_indexes(batch_size), transition_batch)

def _put_by_indexes(self, indexes: np.ndarray, transition_batch: TransitionBatch):
def _put_by_indexes(self, indexes: np.ndarray, transition_batch: TransitionBatch) -> None:
self._states[indexes] = transition_batch.states
self._actions[indexes] = transition_batch.actions
self._rewards[indexes] = transition_batch.rewards
Expand All @@ -164,7 +164,6 @@ def sample_by_indexes(self, indexes: np.ndarray) -> TransitionBatch:
assert all([0 <= idx < self._capacity for idx in indexes])

return TransitionBatch(
policy_name='',
states=self._states[indexes],
actions=self._actions[indexes],
rewards=self._rewards[indexes],
Expand Down Expand Up @@ -291,7 +290,6 @@ def sample_by_indexes(self, indexes: np.ndarray) -> MultiTransitionBatch:
assert all([0 <= idx < self._capacity for idx in indexes])

return MultiTransitionBatch(
policy_names=[],
states=self._states[indexes],
actions=[action[indexes] for action in self._actions],
rewards=[reward[indexes] for reward in self._rewards],
Expand Down
Loading

0 comments on commit 0bdc230

Please sign in to comment.