
Commit

[rllib/tune] Fix durable trainable in trainer template, add release test (#20422)
krfricke authored and Alex committed Nov 20, 2021
1 parent b9de2a9 commit a16ea36
Showing 7 changed files with 128 additions and 20 deletions.
19 changes: 19 additions & 0 deletions python/ray/tune/ray_trial_executor.py
@@ -1,5 +1,6 @@
# coding: utf-8
import copy
import inspect
from collections import deque
from functools import partial
import logging
@@ -384,6 +385,24 @@ def _setup_remote_runner(self, trial):
kwargs["remote_checkpoint_dir"] = trial.remote_checkpoint_dir
kwargs["sync_function_tpl"] = trial.sync_function_tpl

# Throw a meaningful error if trainable does not use the
# new API
sig = inspect.signature(trial.get_trainable_cls())
try:
sig.bind_partial(**kwargs)
except Exception as e:
raise RuntimeError(
"Your trainable class does not accept a "
"`remote_checkpoint_dir` or `sync_function_tpl` argument "
"in its constructor, but you've passed a "
"`upload_dir` to your SyncConfig. Without accepting "
"these parameters and passing them to the base trainable "
"constructor in the init call, cloud checkpointing is "
"effectively disabled. To resolve this issue, add the "
"parameters to your trainable class constructor or "
"disable cloud checkpointing by setting `upload_dir=None`."
) from e

with self._change_working_directory(trial):
return full_actor_class.remote(**kwargs)

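Note: with this check in place, any class-based trainable that is run with an `upload_dir` must accept and forward `remote_checkpoint_dir` and `sync_function_tpl`. Below is a minimal sketch of a conforming trainable, assuming the `tune.Trainable` base constructor takes these keyword arguments (which this commit relies on); the class name and metric are illustrative only.

from typing import Callable, Optional

from ray import tune


class MyCloudTrainable(tune.Trainable):
    # Hypothetical example, not part of this commit.
    def __init__(self,
                 config: dict = None,
                 logger_creator: Callable = None,
                 remote_checkpoint_dir: Optional[str] = None,
                 sync_function_tpl: Optional[str] = None):
        # Forwarding the extra arguments lets the `sig.bind_partial(**kwargs)`
        # check above succeed and keeps cloud checkpointing enabled.
        super().__init__(config, logger_creator, remote_checkpoint_dir,
                         sync_function_tpl)

    def step(self):
        # Placeholder training step that reports a dummy metric.
        return {"score": 1}
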
2 changes: 2 additions & 0 deletions release/.buildkite/build_pipeline.py
@@ -159,6 +159,8 @@ def __init__(self, name: str, retry: int = 0):
"aws_no_sync_down",
"aws_ssh_sync",
"aws_durable_upload",
# "aws_durable_upload_rllib_str",
# "aws_durable_upload_rllib_trainer",
"gcp_k8s_durable_upload",
],
"~/ray/release/tune_tests/scalability_tests/tune_tests.yaml": [
17 changes: 17 additions & 0 deletions release/tune_tests/cloud_tests/app_config_ml.yaml
@@ -0,0 +1,17 @@
base_image: "anyscale/ray-ml:nightly-py37-gpu"
env_vars: {}
debian_packages:
  - curl

python:
  pip_packages:
    - pytest
    - awscli
    - gsutil
  conda_packages: []

post_build_cmds:
  - pip uninstall -y ray || true
  # Install Ray
  - pip3 install -U {{ env["RAY_WHEELS"] | default("ray") }}
  - {{ env["RAY_WHEELS_SANITY_CHECK"] | default("echo No Ray wheels sanity check") }}
20 changes: 20 additions & 0 deletions release/tune_tests/cloud_tests/tune_cloud_tests.yaml
@@ -28,6 +28,26 @@
    prepare: python wait_cluster.py 4 600
    script: python workloads/run_cloud_test.py durable_upload --bucket s3://data-test-ilr/durable_upload

- name: aws_durable_upload_rllib_str
  cluster:
    app_config: app_config_ml.yaml
    compute_template: tpl_aws_4x2.yaml

  run:
    timeout: 600
    prepare: python wait_cluster.py 4 600
    script: python workloads/run_cloud_test.py durable_upload --trainable rllib_str --bucket s3://data-test-ilr/durable_upload_rllib_str

- name: aws_durable_upload_rllib_trainer
  cluster:
    app_config: app_config_ml.yaml
    compute_template: tpl_aws_4x2.yaml

  run:
    timeout: 600
    prepare: python wait_cluster.py 4 600
    script: python workloads/run_cloud_test.py durable_upload --trainable rllib_trainer --bucket s3://data-test-ilr/durable_upload_rllib_trainer

- name: aws_no_durable_upload
  cluster:
    app_config: app_config.yaml
71 changes: 55 additions & 16 deletions release/tune_tests/cloud_tests/workloads/_tune_script.py
@@ -7,9 +7,11 @@

import ray
from ray import tune
from ray.rllib.agents import DefaultCallbacks
from ray.rllib.agents.ppo import PPOTrainer


def train(config, checkpoint_dir=None):
def fn_trainable(config, checkpoint_dir=None):
    if checkpoint_dir:
        with open(os.path.join(checkpoint_dir, "checkpoint.json"), "rt") as fp:
            state = json.load(fp)
@@ -30,6 +32,16 @@ def train(config, checkpoint_dir=None):
            internal_iter=state["internal_iter"])


class RLLibCallback(DefaultCallbacks):
    def __init__(self):
        super(RLLibCallback, self).__init__()
        self.internal_iter = 0

    def on_train_result(self, *, trainer, result: dict, **kwargs) -> None:
        result["internal_iter"] = self.internal_iter
        self.internal_iter += 1


class IndicatorCallback(tune.Callback):
    def __init__(self, indicator_file):
        self.indicator_file = indicator_file
@@ -39,35 +51,57 @@ def on_step_begin(self, iteration, trials, **info):
fp.write("1")


def run_tune(
    no_syncer: bool,
    upload_dir: Optional[str] = None,
    experiment_name: str = "cloud_test",
    indicator_file: str = "/tmp/tune_cloud_indicator",
):
    num_cpus_per_trial = int(os.environ.get("TUNE_NUM_CPUS_PER_TRIAL", "2"))
def run_tune(no_syncer: bool,
             upload_dir: Optional[str] = None,
             experiment_name: str = "cloud_test",
             indicator_file: str = "/tmp/tune_cloud_indicator",
             trainable: str = "function",
             num_cpus_per_trial: int = 2):
    if trainable == "function":
        train = fn_trainable
        config = {
            "max_iterations": 30,
            "sleep_time": 5,
            "checkpoint_freq": 2,
            "score_multiplied": tune.randint(0, 100),
        }
        kwargs = {"resources_per_trial": {"cpu": num_cpus_per_trial}}
    elif trainable == "rllib_str" or trainable == "rllib_trainer":
        if trainable == "rllib_str":
            train = "PPO"
        else:
            train = PPOTrainer

        config = {
            "env": "CartPole-v1",
            "num_workers": 1,
            "num_envs_per_worker": 1,
            "callbacks": RLLibCallback
        }
        kwargs = {
            "stop": {
                "training_iteration": 10
            },
        }
    else:
        raise RuntimeError(f"Unknown trainable: {trainable}")

    tune.run(
        train,
        name=experiment_name,
        resume="AUTO",
        num_samples=4,
        config={
            "max_iterations": 30,
            "sleep_time": 5,
            "checkpoint_freq": 2,
            "score_multiplied": tune.randint(0, 100),
        },
        config=config,
        sync_config=tune.SyncConfig(
            syncer="auto" if not no_syncer else None,
            upload_dir=upload_dir,
            sync_on_checkpoint=True,
            sync_period=0.5,
        ),
        keep_checkpoints_num=2,
        resources_per_trial={"cpu": num_cpus_per_trial},
        callbacks=[IndicatorCallback(indicator_file=indicator_file)],
        verbose=2)
        verbose=2,
        **kwargs)


if __name__ == "__main__":
@@ -88,11 +122,16 @@ def run_tune(

    args = parser.parse_args()

    trainable = str(os.environ.get("TUNE_TRAINABLE", "function"))
    num_cpus_per_trial = int(os.environ.get("TUNE_NUM_CPUS_PER_TRIAL", "2"))

    run_kwargs = dict(
        no_syncer=args.no_syncer or False,
        upload_dir=args.upload_dir or None,
        experiment_name=args.experiment_name or "cloud_test",
        indicator_file=args.indicator_file,
        trainable=trainable,
        num_cpus_per_trial=num_cpus_per_trial,
    )

    if not ray.is_initialized:
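For local debugging, the updated `run_tune` can be driven the same way the release harness does, by selecting a trainable explicitly. Below is a hedged sketch (run from the workloads/ directory, no cloud upload; the experiment name and values are placeholders):

import os

from _tune_script import run_tune  # assumes workloads/ is the working directory

# Mirror how run_cloud_test.py selects the trainable via environment variables.
os.environ["TUNE_TRAINABLE"] = "rllib_str"
os.environ["TUNE_NUM_CPUS_PER_TRIAL"] = "2"

run_tune(
    no_syncer=True,
    upload_dir=None,
    experiment_name="cloud_test_local",
    indicator_file="/tmp/tune_cloud_indicator",
    trainable=os.environ["TUNE_TRAINABLE"],
    num_cpus_per_trial=int(os.environ["TUNE_NUM_CPUS_PER_TRIAL"]))
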
8 changes: 6 additions & 2 deletions release/tune_tests/cloud_tests/workloads/run_cloud_test.py
@@ -1065,6 +1065,7 @@ def after_experiments():

    parser.add_argument(
        "variant", choices=["no_sync_down", "ssh_sync", "durable_upload"])
    parser.add_argument("--trainable", type=str, default="function")
    parser.add_argument("--bucket", type=str, default=None)
    parser.add_argument(
        "--cpus-per-trial", required=False, default=2, type=int)
@@ -1092,6 +1093,7 @@ def after_experiments():
"/tmp/release_test_out.json")

def _run_test(variant: str,
trainable: str = "function",
bucket: str = "",
cpus_per_trial: int = 2,
overwrite_tune_script: Optional[str] = None):
@@ -1100,6 +1102,7 @@ def _run_test(variant: str,
f"node {ray.util.get_node_ip_address()} with "
f"{cpus_per_trial} CPUs per trial.")

os.environ["TUNE_TRAINABLE"] = str(trainable)
os.environ["TUNE_NUM_CPUS_PER_TRIAL"] = str(cpus_per_trial)

if overwrite_tune_script:
@@ -1123,7 +1126,8 @@

    if not uses_ray_client:
        print("This test will *not* use Ray client.")
        _run_test(args.variant, args.bucket, args.cpus_per_trial)
        _run_test(args.variant, args.trainable, args.bucket,
                  args.cpus_per_trial)
    else:
        print("This test will run using Ray client.")

@@ -1146,7 +1150,7 @@ def _get_head_ip():
        _run_test_remote = ray.remote(
            resources={f"node:{ip}": 0.01}, num_cpus=0)(_run_test)
        ray.get(
            _run_test_remote.remote(args.variant, args.bucket,
            _run_test_remote.remote(args.variant, args.trainable, args.bucket,
                                    args.cpus_per_trial, remote_tune_script))

print(f"Fetching remote release test result file: {release_test_out}")
11 changes: 9 additions & 2 deletions rllib/agents/trainer_template.py
@@ -9,6 +9,7 @@
from ray.rllib.utils.annotations import override, DeveloperAPI
from ray.rllib.utils.typing import EnvConfigDict, EnvType, \
    PartialTrainerConfigDict, ResultDict, TrainerConfigDict
from ray.tune.logger import Logger

logger = logging.getLogger(__name__)

@@ -92,8 +93,14 @@ class trainer_cls(base):
        _default_config = default_config or COMMON_CONFIG
        _policy_class = default_policy

        def __init__(self, config=None, env=None, logger_creator=None):
            Trainer.__init__(self, config, env, logger_creator)
        def __init__(self,
                     config: TrainerConfigDict = None,
                     env: Union[str, EnvType, None] = None,
                     logger_creator: Callable[[], Logger] = None,
                     remote_checkpoint_dir: Optional[str] = None,
                     sync_function_tpl: Optional[str] = None):
            Trainer.__init__(self, config, env, logger_creator,
                             remote_checkpoint_dir, sync_function_tpl)

        @override(base)
        def setup(self, config: PartialTrainerConfigDict):
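With the widened constructor, trainers produced by `build_trainer` (such as `PPOTrainer`) can again be handed directly to `tune.run` together with an `upload_dir`, because Tune can now bind `remote_checkpoint_dir` and `sync_function_tpl` on them. A minimal hedged sketch follows; the bucket path and stop criterion are placeholders:

from ray import tune
from ray.rllib.agents.ppo import PPOTrainer

tune.run(
    PPOTrainer,
    config={"env": "CartPole-v1", "num_workers": 1},
    stop={"training_iteration": 2},
    sync_config=tune.SyncConfig(
        upload_dir="s3://my-bucket/ppo-durable-test"),  # placeholder bucket
    keep_checkpoints_num=2)
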
