-
Notifications
You must be signed in to change notification settings - Fork 6.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RLlib] Preparatory PR for multi-agent multi-GPU learner (alpha-star style) #03 #21652
Changes from 1 commit
a9f3098
5fe33ee
2f2c546
21efe03
05802c9
6616496
59c8a33
2312bcd
b9b8e98
3a1b7f0
3ad6097
e3c9222
fb01568
526e0e6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -228,10 +228,18 @@ def validate_config(self, config: TrainerConfigDict) -> None: | |
"`NoFilter` for ES!") | ||
|
||
@override(Trainer) | ||
def _init(self, config, env_creator): | ||
def setup(self, config): | ||
# Call super's validation method. | ||
self.validate_config(config) | ||
|
||
# Generate `self.env_creator` callable to create an env instance. | ||
self._get_env_creator_from_env_id(self._env_id) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could we change the name of this function to
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks for the catch. I'll check. ... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
# Generate the local env. | ||
env_context = EnvContext(config["env_config"] or {}, worker_index=0) | ||
env = env_creator(env_context) | ||
env = self.env_creator(env_context) | ||
|
||
self.callbacks = config.get("callbacks")() | ||
|
||
self._policy_class = get_policy_class(config) | ||
self.policy = self._policy_class( | ||
obs_space=env.observation_space, | ||
|
@@ -247,8 +255,8 @@ def _init(self, config, env_creator): | |
|
||
# Create the actors. | ||
logger.info("Creating actors.") | ||
self._workers = [ | ||
Worker.remote(config, {}, env_creator, noise_id, idx + 1) | ||
self.workers = [ | ||
Worker.remote(config, {}, self.env_creator, noise_id, idx + 1) | ||
for idx in range(config["num_workers"]) | ||
] | ||
|
||
|
@@ -333,7 +341,7 @@ def step_attempt(self): | |
# Now sync the filters | ||
FilterManager.synchronize({ | ||
DEFAULT_POLICY_ID: self.policy.observation_filter | ||
}, self._workers) | ||
}, self.workers) | ||
|
||
info = { | ||
"weights_norm": np.square(theta).sum(), | ||
|
@@ -375,7 +383,7 @@ def _sync_weights_to_workers(self, *, worker_set=None, workers=None): | |
@override(Trainer) | ||
def cleanup(self): | ||
# workaround for https://github.com/ray-project/ray/issues/1516 | ||
for w in self._workers: | ||
for w in self.workers: | ||
w.__ray_terminate__.remote() | ||
|
||
def _collect_results(self, theta_id, min_episodes, min_timesteps): | ||
|
@@ -386,7 +394,7 @@ def _collect_results(self, theta_id, min_episodes, min_timesteps): | |
"Collected {} episodes {} timesteps so far this iter".format( | ||
num_episodes, num_timesteps)) | ||
rollout_ids = [ | ||
worker.do_rollouts.remote(theta_id) for worker in self._workers | ||
worker.do_rollouts.remote(theta_id) for worker in self.workers | ||
] | ||
# Get the results of the rollouts. | ||
for result in ray.get(rollout_ids): | ||
|
@@ -413,4 +421,4 @@ def __setstate__(self, state): | |
self.policy.observation_filter = state["filter"] | ||
FilterManager.synchronize({ | ||
DEFAULT_POLICY_ID: self.policy.observation_filter | ||
}, self._workers) | ||
}, self.workers) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for whatever reason this function isn't available on master, but I also didn't find its definition in this pr diff, but only when checking out this branch. Strange :(
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could be a mistake by me when splitting my local branch (which contained more changes).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed.