[RLlib] Preparatory PR for multi-agent multi-GPU learner (alpha-star style) #03 #21652

Merged
2 changes: 1 addition & 1 deletion rllib/agents/a3c/a2c.py
@@ -17,7 +17,7 @@
A3C_CONFIG,
{
"rollout_fragment_length": 20,
"min_iter_time_s": 10,
"min_time_s_per_reporting": 10,
"sample_async": False,

# A2C supports microbatching, in which we accumulate gradients over
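The microbatching mentioned in the comment above accumulates gradients over several small sample batches before applying a single optimizer update. A minimal, framework-agnostic sketch of that idea (illustrative only, not RLlib's actual A2C execution plan; `grad_fn` and the toy loss are assumptions):

```python
# Toy sketch of microbatch gradient accumulation: sum gradients over several
# small batches, then apply one update, emulating a large train batch.
import numpy as np

def accumulate_and_apply(grad_fn, microbatches, params, lr=0.01):
    """Average gradients over all microbatches, then take a single step."""
    accumulated = np.zeros_like(params)
    for batch in microbatches:
        accumulated += grad_fn(params, batch)
    return params - lr * accumulated / len(microbatches)

# Dummy quadratic loss 0.5 * ||params - mean(batch)||^2, so grad = params - mean(batch).
grad_fn = lambda p, b: p - b.mean(axis=0)
params = np.zeros(4)
microbatches = [np.random.randn(10, 4) for _ in range(3)]
params = accumulate_and_apply(grad_fn, microbatches, params)
```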
4 changes: 2 additions & 2 deletions rllib/agents/a3c/a3c.py
@@ -39,8 +39,8 @@
"entropy_coeff": 0.01,
# Entropy coefficient schedule
"entropy_coeff_schedule": None,
# Min time per iteration
"min_iter_time_s": 5,
# Min time per reporting
"min_time_s_per_reporting": 5,
# Workers sample async. Note that this increases the effective
# rollout_fragment_length by up to 5x due to async buffering of batches.
"sample_async": True,
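For reference, a minimal usage sketch of the renamed key, mirroring how the test files in this PR construct trainers (assumes the Trainer entry points as used elsewhere in this diff):

```python
import ray
from ray.rllib.agents import a3c

ray.init()
config = {
    "num_workers": 1,
    # Formerly `min_iter_time_s`: report results at most once per 5 seconds.
    "min_time_s_per_reporting": 5,
}
trainer = a3c.A2CTrainer(env="CartPole-v0", config=config)
results = trainer.train()
print(results)
```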
4 changes: 2 additions & 2 deletions rllib/agents/a3c/tests/test_a2c.py
@@ -35,7 +35,7 @@ def test_a2c_compilation(self):
trainer.stop()

def test_a2c_exec_impl(ray_start_regular):
config = {"min_iter_time_s": 0}
config = {"min_time_s_per_reporting": 0}
for _ in framework_iterator(config):
trainer = a3c.A2CTrainer(env="CartPole-v0", config=config)
results = trainer.train()
@@ -46,7 +46,7 @@ def test_a2c_exec_impl(ray_start_regular):

def test_a2c_exec_impl_microbatch(ray_start_regular):
config = {
"min_iter_time_s": 0,
"min_time_s_per_reporting": 0,
"microbatch_size": 10,
}
for _ in framework_iterator(config):
2 changes: 1 addition & 1 deletion rllib/agents/a3c/tests/test_a3c.py
@@ -51,7 +51,7 @@ def test_a3c_entropy_coeff_schedule(self):
config["timesteps_per_iteration"] = 20
# 0 metrics reporting delay, this makes sure timestep,
# which entropy coeff depends on, is updated after each worker rollout.
config["min_iter_time_s"] = 0
config["min_time_s_per_reporting"] = 0
# Initial lr, doesn't really matter because of the schedule below.
config["entropy_coeff"] = 0.01
schedule = [
15 changes: 11 additions & 4 deletions rllib/agents/ars/ars.py
@@ -228,10 +228,17 @@ def validate_config(self, config: TrainerConfigDict) -> None:
"`NoFilter` for ARS!")

@override(Trainer)
def _init(self, config, env_creator):
def setup(self, config):
# Validate our config dict.
self.validate_config(config)

# Generate `self.env_creator` callable to create an env instance.
self._get_env_creator_from_env_id(self._env_id)
Member: For whatever reason this function isn't available on master, and I couldn't find its definition in this PR diff either, only when checking out this branch. Strange :(

Contributor Author: Could be a mistake by me when splitting up my local branch (which contained more changes).

Contributor Author: Fixed.

# Generate the local env.
env_context = EnvContext(config["env_config"] or {}, worker_index=0)
env = env_creator(env_context)
env = self.env_creator(env_context)

self.callbacks = config.get("callbacks")()

self._policy_class = get_policy_class(config)
self.policy = self._policy_class(env.observation_space,
@@ -250,7 +257,7 @@ def _init(self, config, env_creator):
# Create the actors.
logger.info("Creating actors.")
self.workers = [
Worker.remote(config, env_creator, noise_id, idx + 1)
Worker.remote(config, self.env_creator, noise_id, idx + 1)
for idx in range(config["num_workers"])
]

@@ -375,7 +382,7 @@ def compute_single_action(self, observation, *args, **kwargs):
return action[0], [], {}
return action[0]

@Deprecated(new="compute_single_action", error=False)
@Deprecated(new="compute_single_action", error=True)
def compute_action(self, observation, *args, **kwargs):
return self.compute_single_action(observation, *args, **kwargs)

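The `error=False` to `error=True` switch above turns calls to the old `compute_action` API from a deprecation warning into a hard error. A hedged sketch of that semantics (an illustrative decorator, not RLlib's actual `Deprecated` implementation):

```python
import functools
import warnings

def deprecated(new: str, error: bool):
    """Warn when the old API is used, or raise if `error=True`."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            msg = f"`{func.__name__}` is deprecated; use `{new}` instead."
            if error:
                raise DeprecationWarning(msg)
            warnings.warn(msg, DeprecationWarning)
            return func(*args, **kwargs)
        return wrapper
    return decorator

@deprecated(new="compute_single_action", error=True)
def compute_action(observation):
    ...
```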
3 changes: 2 additions & 1 deletion rllib/agents/ars/tests/test_ars.py
@@ -22,7 +22,8 @@ def test_ars_compilation(self):
config["model"]["fcnet_hiddens"] = [10]
config["model"]["fcnet_activation"] = None
config["noise_size"] = 2500000
# Test eval workers ("normal" Trainer eval WorkerSet, unusual for ARS).
# Test eval workers ("normal" WorkerSet, unlike ARS' list of
# RolloutWorkers used for collecting train batches).
config["evaluation_interval"] = 1
config["evaluation_num_workers"] = 1

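The test above exercises the "normal" evaluation WorkerSet for ARS, which otherwise keeps its own plain list of workers for collecting train batches. A minimal sketch of enabling those eval workers (assumes the `ARSTrainer` / `DEFAULT_CONFIG` entry points of the time):

```python
import ray
from ray.rllib.agents import ars

ray.init()
config = ars.DEFAULT_CONFIG.copy()
config["num_workers"] = 2
# Run one evaluation round per training iteration on a dedicated eval worker.
config["evaluation_interval"] = 1
config["evaluation_num_workers"] = 1
trainer = ars.ARSTrainer(config=config, env="CartPole-v0")
print(trainer.train())
```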
2 changes: 1 addition & 1 deletion rllib/agents/ddpg/apex.py
@@ -38,7 +38,7 @@
"target_network_update_freq": 500000,
"timesteps_per_iteration": 25000,
"worker_side_prioritization": True,
"min_iter_time_s": 30,
"min_time_s_per_reporting": 30,
},
_allow_unknown_configs=True,
)
4 changes: 2 additions & 2 deletions rllib/agents/ddpg/ddpg.py
@@ -171,8 +171,8 @@
"num_workers": 0,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
# Prevent iterations from going lower than this time span
"min_iter_time_s": 1,
# Prevent reporting frequency from going lower than this time span.
"min_time_s_per_reporting": 1,
})
# __sphinx_doc_end__
# yapf: enable
2 changes: 1 addition & 1 deletion rllib/agents/ddpg/tests/test_apex_ddpg.py
@@ -20,7 +20,7 @@ def test_apex_ddpg_compilation_and_per_worker_epsilon_values(self):
config["num_workers"] = 2
config["prioritized_replay"] = True
config["timesteps_per_iteration"] = 100
config["min_iter_time_s"] = 1
config["min_time_s_per_reporting"] = 1
config["learning_starts"] = 0
config["optimizer"]["num_replay_buffer_shards"] = 1
num_iterations = 1
2 changes: 1 addition & 1 deletion rllib/agents/ddpg/tests/test_ddpg.py
@@ -154,7 +154,7 @@ def test_ddpg_loss_function(self):
config["actor_hiddens"] = [10]
config["critic_hiddens"] = [10]
# Make sure, timing differences do not affect trainer.train().
config["min_iter_time_s"] = 0
config["min_time_s_per_reporting"] = 0
config["timesteps_per_iteration"] = 100

map_ = {
2 changes: 1 addition & 1 deletion rllib/agents/dqn/apex.py
@@ -78,7 +78,7 @@
"timesteps_per_iteration": 25000,
"exploration_config": {"type": "PerWorkerEpsilonGreedy"},
"worker_side_prioritization": True,
"min_iter_time_s": 30,
"min_time_s_per_reporting": 30,
# If set, this will fix the ratio of replayed from a buffer and learned
# on timesteps to sampled from an environment and stored in the replay
# buffer timesteps. Otherwise, replay will proceed as fast as possible.
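The truncated comment above describes fixing the ratio of replayed-and-trained timesteps to newly sampled timesteps. A toy illustration of keeping such a ratio on target (names here are illustrative, not RLlib's actual config keys or internals):

```python
def replay_steps_for(ratio: float, newly_sampled: int, debt: float = 0.0):
    """Return how many replayed timesteps keep replay:sample close to `ratio`,
    carrying fractional leftovers in `debt` across calls."""
    debt += ratio * newly_sampled
    n = int(debt)
    return n, debt - n

# E.g. with a target ratio of 4 replayed timesteps per sampled timestep:
n, debt = replay_steps_for(4.0, newly_sampled=25)
assert n == 100
```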
4 changes: 2 additions & 2 deletions rllib/agents/dqn/simple_q.py
@@ -103,8 +103,8 @@
# to increase if your environment is particularly slow to sample, or if
# you"re using the Async or Ape-X optimizers.
"num_workers": 0,
# Prevent iterations from going lower than this time span.
"min_iter_time_s": 1,
# Prevent reporting frequency from going lower than this time span.
"min_time_s_per_reporting": 1,
})
# __sphinx_doc_end__
# yapf: enable
6 changes: 3 additions & 3 deletions rllib/agents/dqn/tests/test_apex_dqn.py
@@ -24,7 +24,7 @@ def test_apex_zero_workers(self):
config["learning_starts"] = 1000
config["prioritized_replay"] = True
config["timesteps_per_iteration"] = 100
config["min_iter_time_s"] = 1
config["min_time_s_per_reporting"] = 1
config["optimizer"]["num_replay_buffer_shards"] = 1
for _ in framework_iterator(config):
trainer = apex.ApexTrainer(config=config, env="CartPole-v0")
@@ -41,7 +41,7 @@ def test_apex_dqn_compilation_and_per_worker_epsilon_values(self):
config["learning_starts"] = 1000
config["prioritized_replay"] = True
config["timesteps_per_iteration"] = 100
config["min_iter_time_s"] = 1
config["min_time_s_per_reporting"] = 1
config["optimizer"]["num_replay_buffer_shards"] = 1

for _ in framework_iterator(config, with_eager_tracing=True):
@@ -81,7 +81,7 @@ def test_apex_lr_schedule(self):
config["timesteps_per_iteration"] = 10
# 0 metrics reporting delay, this makes sure timestep,
# which lr depends on, is updated after each worker rollout.
config["min_iter_time_s"] = 0
config["min_time_s_per_reporting"] = 0
config["optimizer"]["num_replay_buffer_shards"] = 1
# This makes sure learning schedule is checked every 10 timesteps.
config["optimizer"]["max_weight_sync_delay"] = 10
24 changes: 16 additions & 8 deletions rllib/agents/es/es.py
@@ -228,10 +228,18 @@ def validate_config(self, config: TrainerConfigDict) -> None:
"`NoFilter` for ES!")

@override(Trainer)
def _init(self, config, env_creator):
def setup(self, config):
# Call super's validation method.
self.validate_config(config)

# Generate `self.env_creator` callable to create an env instance.
self._get_env_creator_from_env_id(self._env_id)
Member: Could we change the name of this function to _set_env_creator_from_env_id? It sets the self.env_creator attribute but doesn't return anything. Alternatively, it could return the env creator and let the caller set the attribute:

    self.env_creator = self._get_env_creator_from_env_id(self._env_id)

Contributor Author: Thanks for the catch, I'll check. ...

Contributor Author: Done.

# Generate the local env.
env_context = EnvContext(config["env_config"] or {}, worker_index=0)
env = env_creator(env_context)
env = self.env_creator(env_context)

self.callbacks = config.get("callbacks")()

self._policy_class = get_policy_class(config)
self.policy = self._policy_class(
obs_space=env.observation_space,
@@ -247,8 +255,8 @@ def _init(self, config, env_creator):

# Create the actors.
logger.info("Creating actors.")
self._workers = [
Worker.remote(config, {}, env_creator, noise_id, idx + 1)
self.workers = [
Worker.remote(config, {}, self.env_creator, noise_id, idx + 1)
for idx in range(config["num_workers"])
]

@@ -333,7 +341,7 @@ def step_attempt(self):
# Now sync the filters
FilterManager.synchronize({
DEFAULT_POLICY_ID: self.policy.observation_filter
}, self._workers)
}, self.workers)

info = {
"weights_norm": np.square(theta).sum(),
@@ -375,7 +383,7 @@ def _sync_weights_to_workers(self, *, worker_set=None, workers=None):
@override(Trainer)
def cleanup(self):
# workaround for https://github.com/ray-project/ray/issues/1516
for w in self._workers:
for w in self.workers:
w.__ray_terminate__.remote()

def _collect_results(self, theta_id, min_episodes, min_timesteps):
@@ -386,7 +394,7 @@ def _collect_results(self, theta_id, min_episodes, min_timesteps):
"Collected {} episodes {} timesteps so far this iter".format(
num_episodes, num_timesteps))
rollout_ids = [
worker.do_rollouts.remote(theta_id) for worker in self._workers
worker.do_rollouts.remote(theta_id) for worker in self.workers
]
# Get the results of the rollouts.
for result in ray.get(rollout_ids):
Expand All @@ -413,4 +421,4 @@ def __setstate__(self, state):
self.policy.observation_filter = state["filter"]
FilterManager.synchronize({
DEFAULT_POLICY_ID: self.policy.observation_filter
}, self._workers)
}, self.workers)
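Following the naming discussion in the review thread above, a hedged sketch of the "getter returns, caller assigns" alternative (`_get_env_creator_from_env_id` here is illustrative, not RLlib's actual implementation):

```python
from typing import Callable
import gym

class TrainerSketch:
    def _get_env_creator_from_env_id(self, env_id: str) -> Callable:
        """Return (rather than set) the env-creator callable."""
        def creator(env_config):
            return gym.make(env_id)
        return creator

    def setup(self, config: dict) -> None:
        # The caller assigns the attribute explicitly, as suggested in review.
        self.env_creator = self._get_env_creator_from_env_id(config["env"])
```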
2 changes: 1 addition & 1 deletion rllib/agents/impala/impala.py
@@ -50,7 +50,7 @@
#
"rollout_fragment_length": 50,
"train_batch_size": 500,
"min_iter_time_s": 10,
"min_time_s_per_reporting": 10,
"num_workers": 2,
# Number of GPUs the learner should use.
"num_gpus": 1,
7 changes: 6 additions & 1 deletion rllib/agents/mock.py
@@ -27,7 +27,12 @@ def get_default_config(cls) -> TrainerConfigDict:
def default_resource_request(cls, config):
return None

def _init(self, config, env_creator):
@override(Trainer)
def setup(self, config):
# Call super's setup().
super().setup(config)

# Add needed properties.
self.info = None
self.restored = False

6 changes: 4 additions & 2 deletions rllib/agents/ppo/appo.py
@@ -54,7 +54,7 @@
# == IMPALA optimizer params (see documentation in impala.py) ==
"rollout_fragment_length": 50,
"train_batch_size": 500,
"min_iter_time_s": 10,
"min_time_s_per_reporting": 10,
"num_workers": 2,
"num_gpus": 0,
"num_multi_gpu_tower_stacks": 1,
@@ -132,5 +132,7 @@ def get_default_policy_class(self, config: PartialTrainerConfigDict) -> \
from ray.rllib.agents.ppo.appo_torch_policy import \
AsyncPPOTorchPolicy
return AsyncPPOTorchPolicy
else:
elif config["framework"] == "tf":
return AsyncPPOTFPolicy
elif config["framework"] in ["tf2", "tfe"]:
return AsyncPPOTFPolicy.as_eager()
2 changes: 1 addition & 1 deletion rllib/agents/ppo/tests/test_appo.py
@@ -81,7 +81,7 @@ def test_appo_entropy_coeff_schedule(self):
config["timesteps_per_iteration"] = 20
# 0 metrics reporting delay, this makes sure timestep,
# which entropy coeff depends on, is updated after each worker rollout.
config["min_iter_time_s"] = 0
config["min_time_s_per_reporting"] = 0
# Initial lr, doesn't really matter because of the schedule below.
config["entropy_coeff"] = 0.01
schedule = [
4 changes: 2 additions & 2 deletions rllib/agents/qmix/qmix.py
@@ -92,8 +92,8 @@
"num_workers": 0,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
# Prevent iterations from going lower than this time span
"min_iter_time_s": 1,
# Prevent reporting frequency from going lower than this time span.
"min_time_s_per_reporting": 1,

# === Model ===
"model": {
4 changes: 2 additions & 2 deletions rllib/agents/sac/sac.py
@@ -156,8 +156,8 @@
"num_cpus_per_worker": 1,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
# Prevent iterations from going lower than this time span.
"min_iter_time_s": 1,
# Prevent reporting frequency from going lower than this time span.
"min_time_s_per_reporting": 1,

# Whether the loss should be calculated deterministically (w/o the
# stochastic action sampling step). True only useful for cont. actions and
2 changes: 1 addition & 1 deletion rllib/agents/sac/tests/test_sac.py
@@ -157,7 +157,7 @@ def test_sac_loss_function(self):
config["Q_model"]["fcnet_hiddens"] = [10]
config["policy_model"]["fcnet_hiddens"] = [10]
# Make sure, timing differences do not affect trainer.train().
config["min_iter_time_s"] = 0
config["min_time_s_per_reporting"] = 0
# Test SAC with Simplex action space.
config["env_config"] = {"simplex_actions": True}

4 changes: 2 additions & 2 deletions rllib/agents/slateq/slateq.py
@@ -124,8 +124,8 @@
"num_workers": 0,
# Whether to compute priorities on workers.
"worker_side_prioritization": False,
# Prevent iterations from going lower than this time span
"min_iter_time_s": 1,
# Prevent reporting frequency from going lower than this time span.
"min_time_s_per_reporting": 1,

# === SlateQ specific options ===
# Learning method used by the slateq policy. Choose from: RANDOM,