
Commit

Merge pull request #39 from BonsaiAI/adgudime/async_sac
Adgudime/async sac
AdityaGudimella authored Sep 21, 2020
2 parents 9bc8e18 + 2d260d6 commit ffb5fe0
Showing 5 changed files with 140 additions and 14 deletions.
27 changes: 17 additions & 10 deletions rllib/agents/dqn/apex.py
@@ -13,7 +13,7 @@
from ray.rllib.execution.replay_ops import StoreToReplayBuffer, Replay
from ray.rllib.execution.train_ops import UpdateTargetNetwork
from ray.rllib.execution.metric_ops import StandardMetricsReporting
-from ray.rllib.execution.replay_buffer import ReplayActor
+from ray.rllib.execution.replay_buffer import ReplayActor, VanillaReplayActor
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.actors import create_colocated

@@ -88,15 +88,21 @@ def __call__(self, item: ("ActorHandle", SampleBatchType)):
 def apex_execution_plan(workers: WorkerSet, config: dict):
     # Create a number of replay buffer actors.
     num_replay_buffer_shards = config["optimizer"]["num_replay_buffer_shards"]
-    replay_actors = create_colocated(ReplayActor, [
+    replay_actor_cls = ReplayActor if config[
+        "prioritized_replay"] else VanillaReplayActor
+    replay_actors = create_colocated(
+        replay_actor_cls,
+        [
+            num_replay_buffer_shards,
+            config["learning_starts"],
+            config["buffer_size"],
+            config["train_batch_size"],
+            config["prioritized_replay_alpha"],
+            config["prioritized_replay_beta"],
+            config["prioritized_replay_eps"],
+        ],
         num_replay_buffer_shards,
-        config["learning_starts"],
-        config["buffer_size"],
-        config["train_batch_size"],
-        config["prioritized_replay_alpha"],
-        config["prioritized_replay_beta"],
-        config["prioritized_replay_eps"],
-    ], num_replay_buffer_shards)
+    )

     # Start the learner thread.
     learner_thread = LearnerThread(workers.local_worker())
@@ -105,7 +111,8 @@ def apex_execution_plan(workers: WorkerSet, config: dict):
     # Update experience priorities post learning.
     def update_prio_and_stats(item: ("ActorHandle", dict, int)):
         actor, prio_dict, count = item
-        actor.update_priorities.remote(prio_dict)
+        if config["prioritized_replay"]:
+            actor.update_priorities.remote(prio_dict)
         metrics = _get_shared_metrics()
         # Manually update the steps trained counter since the learner thread
         # is executing outside the pipeline.
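Taken together, the two hunks above make the Ape-X execution plan degrade gracefully when prioritized replay is disabled: a uniform-sampling VanillaReplayActor shard is colocated instead of the prioritized ReplayActor, and the learner stops pushing priority updates that a vanilla buffer cannot use. A minimal sketch of the same selection logic in isolation (the helper function is hypothetical, not part of this PR; only the two actor classes and the config key come from the diff):

    # Hypothetical helper, assuming an Ape-X style trainer config dict.
    from ray.rllib.execution.replay_buffer import ReplayActor, VanillaReplayActor

    def choose_replay_actor_cls(config):
        # Prioritized replay needs the priority-tracking ReplayActor;
        # otherwise the uniform-sampling VanillaReplayActor is enough.
        return ReplayActor if config["prioritized_replay"] else VanillaReplayActor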
5 changes: 4 additions & 1 deletion rllib/agents/sac/apex.py
@@ -41,6 +41,9 @@
 # __sphinx_doc_end__
 # yapf: enable


 ApexSACTrainer = SACTrainer.with_updates(
-    name="APEX_SAC", default_config=APEX_SAC_DEFAULT_CONFIG, execution_plan=apex_execution_plan
+    name="APEX_SAC",
+    default_config=APEX_SAC_DEFAULT_CONFIG,
+    execution_plan=apex_execution_plan,
 )
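With the call reformatted to one keyword per line, ApexSACTrainer is the SAC trainer rebuilt around the Ape-X execution plan above. A hypothetical usage sketch, assuming this fork exposes the trainer as ray.rllib.agents.sac.apex.ApexSACTrainer (the module path mirrors the file location; the config values are illustrative):

    import ray
    from ray.rllib.agents.sac.apex import ApexSACTrainer  # assumed import path

    ray.init()
    trainer = ApexSACTrainer(
        env="Pendulum-v0",
        config={
            "num_workers": 8,
            # False exercises the VanillaReplayActor path added in this PR.
            "prioritized_replay": False,
        },
    )
    for _ in range(5):
        print(trainer.train()["episode_reward_mean"])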
105 changes: 105 additions & 0 deletions rllib/execution/replay_buffer.py
@@ -416,4 +416,109 @@ def stats(self, debug=False):
        return stat


# Visible for testing.
_local_vanilla_replay_buffer = None


class LocalVanillaReplayBuffer(LocalReplayBuffer):
    """A replay buffer shard.
    Ray actors are single-threaded, so for scalability multiple replay actors
    may be created to increase parallelism."""

    def __init__(
        self,
        num_shards,
        learning_starts,
        buffer_size,
        replay_batch_size,
        prioritized_replay_alpha=0.6,
        prioritized_replay_beta=0.4,
        prioritized_replay_eps=1e-6,
        multiagent_sync_replay=False,
    ):
        self.replay_starts = learning_starts // num_shards
        self.buffer_size = buffer_size // num_shards
        self.replay_batch_size = replay_batch_size
        self.prioritized_replay_beta = prioritized_replay_beta
        self.prioritized_replay_eps = prioritized_replay_eps
        self.multiagent_sync_replay = multiagent_sync_replay

        def gen_replay():
            while True:
                yield self.replay()

        ParallelIteratorWorker.__init__(self, gen_replay, False)

        def new_buffer():
            return ReplayBuffer(self.buffer_size)

        self.replay_buffers = collections.defaultdict(new_buffer)

        # Metrics
        self.add_batch_timer = TimerStat()
        self.replay_timer = TimerStat()
        self.update_priorities_timer = TimerStat()
        self.num_added = 0

        # Make externally accessible for testing.
        global _local_vanilla_replay_buffer
        _local_vanilla_replay_buffer = self
        # If set, return this instead of the usual data for testing.
        self._fake_batch = None

    @staticmethod
    def get_instance_for_testing():
        global _local_vanilla_replay_buffer
        return _local_vanilla_replay_buffer

    def replay(self):
        if self._fake_batch:
            fake_batch = SampleBatch(self._fake_batch)
            return MultiAgentBatch({DEFAULT_POLICY_ID: fake_batch}, fake_batch.count)

        if self.num_added < self.replay_starts:
            return None

        with self.replay_timer:
            samples = {}
            idxes = None
            for policy_id, replay_buffer in self.replay_buffers.items():
                if self.multiagent_sync_replay:
                    if idxes is None:
                        idxes = replay_buffer.sample_idxes(self.replay_batch_size)
                else:
                    idxes = replay_buffer.sample_idxes(self.replay_batch_size)
                (
                    obses_t,
                    actions,
                    rewards,
                    obses_tp1,
                    dones,
                ) = replay_buffer.sample_with_idxes(idxes)
                samples[policy_id] = SampleBatch(
                    {
                        "obs": obses_t,
                        "actions": actions,
                        "rewards": rewards,
                        "new_obs": obses_tp1,
                        "dones": dones,
                    }
                )
            return MultiAgentBatch(samples, self.replay_batch_size)

    def stats(self, debug=False):
        stat = {
            "add_batch_time_ms": round(1000 * self.add_batch_timer.mean, 3),
            "replay_time_ms": round(1000 * self.replay_timer.mean, 3),
        }
        for policy_id, replay_buffer in self.replay_buffers.items():
            stat.update(
                {"policy_{}".format(policy_id): replay_buffer.stats(debug=debug)}
            )
        return stat


VanillaReplayActor = ray.remote(num_cpus=0)(LocalVanillaReplayBuffer)

ReplayActor = ray.remote(num_cpus=0)(LocalReplayBuffer)
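LocalVanillaReplayBuffer mirrors LocalReplayBuffer but keeps a plain uniform ReplayBuffer per policy and returns batches without importance weights or priority bookkeeping, which is why the Ape-X plan above can skip update_priorities for it. A small sketch using the _fake_batch test hook shown above (run in-process rather than wrapped in a Ray actor; the batch contents are illustrative):

    from ray.rllib.execution.replay_buffer import LocalVanillaReplayBuffer

    buffer = LocalVanillaReplayBuffer(
        num_shards=1,
        learning_starts=1000,
        buffer_size=10000,
        replay_batch_size=32,
    )
    # replay() returns this fake batch instead of real samples (see above).
    buffer._fake_batch = {
        "obs": [[0.0]], "actions": [0], "rewards": [1.0],
        "new_obs": [[0.0]], "dones": [False],
    }
    batch = buffer.replay()  # MultiAgentBatch wrapping one SampleBatch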
15 changes: 13 additions & 2 deletions rllib/tests/agents/parameters.py
@@ -283,9 +283,20 @@ def astuple(self):
         config_updates={
             "num_workers": 8,
             "exploration_config": {"type": "StochasticSampling"},
-            "no_done_at_end": True,
+            "prioritized_replay": False,
+            "no_done_at_end": False,
         },
         n_iter=200,
         threshold=-350.,
     ),
+    TestAgentParams.for_pendulum(
+        algorithm=ContinuousActionSpaceAlgorithm.APEX_SAC,
+        config_updates={
+            "num_workers": 8,
+            "exploration_config": {"type": "StochasticSampling"},
+            "prioritized_replay": True,
+            "no_done_at_end": True
+        },
+        # TODO: Delete next line before landing PR
+        n_iter=200,
+        threshold=-350.,
+    ),
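The modified entry disables prioritized replay while the newly added APEX_SAC entry enables it, so both the VanillaReplayActor and the prioritized ReplayActor paths get learning coverage on Pendulum. An illustrative sketch (not part of the PR) of the kind of effective config such config_updates produce when merged onto the APEX-SAC defaults, using the merge_dicts helper already imported in apex.py; the import path for the defaults is assumed:

    from ray.rllib.utils import merge_dicts
    from ray.rllib.agents.sac.apex import APEX_SAC_DEFAULT_CONFIG  # assumed path

    config = merge_dicts(APEX_SAC_DEFAULT_CONFIG, {
        "num_workers": 8,
        "exploration_config": {"type": "StochasticSampling"},
        "prioritized_replay": True,
        "no_done_at_end": True,
    })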
2 changes: 1 addition & 1 deletion rllib/tests/agents/test_learning.py
@@ -77,7 +77,7 @@ def test_monotonically_improving_algorithms_can_converge_with_different_framewor
"""
learnt = False
episode_reward_mean = -float("inf")
for i in range(n_iter):
for _ in range(n_iter):
results = trainer.train()
episode_reward_mean = results["episode_reward_mean"]
if episode_reward_mean >= threshold:
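Renaming the loop variable to _ reflects that the iteration index is unused; the loop only checks whether the mean reward crosses the threshold within n_iter training iterations. For context, a sketch of the full convergence check this hunk sits in (the lines outside the hunk, the break and the final assertion, are assumptions rather than part of the diff):

    learnt = False
    episode_reward_mean = -float("inf")
    for _ in range(n_iter):
        results = trainer.train()
        episode_reward_mean = results["episode_reward_mean"]
        if episode_reward_mean >= threshold:
            learnt = True
            break
    assert learnt, "mean reward {} never reached threshold {}".format(
        episode_reward_mean, threshold)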
