
Commit 003eac8 (parent: fa91016)

Support 'name:cnt' accelerators spec in YAML (#396)

* Support 'name:cnt' accelerators spec in YAML
* Fixes #373: 'sky start/down' should error out

15 files changed: +104, -63 lines
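In short, the `accelerators` field (in task YAML and in `sky.Resources`) now accepts a compact `<name>:<cnt>` string alongside the existing name-only and dict forms. A minimal sketch of the equivalence, mirroring the docstring updated in sky/resources.py below:

```python
import sky

# These three forms are equivalent after this commit; each resolves to
# the accelerator requirement {'V100': 1}.
r1 = sky.Resources(accelerators='V100')       # bare name; count defaults to 1
r2 = sky.Resources(accelerators='V100:1')     # new '<name>:<cnt>' string form
r3 = sky.Resources(accelerators={'V100': 1})  # original dict form
assert r1.accelerators == r2.accelerators == r3.accelerators == {'V100': 1}
```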

README.md (+1, -2)

@@ -14,8 +14,7 @@ sky launch -c mycluster hello_sky.yaml
 ```yaml
 # hello_sky.yaml
 resources:
-  accelerators:
-    K80: 4
+  accelerators: K80:4
 
 setup: |
   # Typical use: pip install -r requirements.txt

docs/source/getting-started/quickstart.rst (+9, -11)

@@ -37,22 +37,20 @@ requiring an NVIDIA Tesla K80 GPU on AWS. (More example yaml files can be found
    # hello_sky.yaml
 
    resources:
-      # Optional; if left out, pick from the available clouds.
-      cloud: aws
+     # Optional; if left out, pick from the available clouds.
+     cloud: aws
 
-      # Get more GPUs with
-      # accelerators:
-      #   K80: 8
-      accelerators: K80
+     # Get 1 K80 GPU. Use <name>:<n> to get more (e.g., "K80:8").
+     accelerators: K80
 
    setup: |
-      # Typical use: pip install -r requirements.txt
-      echo "running setup"
+     # Typical use: pip install -r requirements.txt
+     echo "running setup"
 
    run: |
-      # Typical use: make use of resources, such as running training.
-      echo "hello sky!"
-      conda env list
+     # Typical use: make use of resources, such as running training.
+     echo "hello sky!"
+     conda env list
 
 Sky handles selecting an appropriate VM based on user-specified resource
 constraints, launching the cluster on an appropriate cloud provider, and

docs/source/getting-started/tutorial.rst (+17, -19)

@@ -14,28 +14,27 @@ and run command:
    name: huggingface
 
    resources:
-      accelerators:
-         V100: 4
+     accelerators: V100:4
 
    setup: |
-      git clone https://github.com/huggingface/transformers/
-      cd transformers
-      pip3 install .
-      cd examples/pytorch/text-classification
-      pip3 install -r requirements.txt
+     git clone https://github.com/huggingface/transformers/
+     cd transformers
+     pip3 install .
+     cd examples/pytorch/text-classification
+     pip3 install -r requirements.txt
 
    run: |
-      cd transformers/examples/pytorch/text-classification
-      python3 run_glue.py \
-         --model_name_or_path bert-base-cased \
-         --dataset_name imdb \
-         --do_train \
-         --max_seq_length 128 \
-         --per_device_train_batch_size 32 \
-         --learning_rate 2e-5 \
-         --max_steps 50 \
-         --output_dir /tmp/imdb/ --overwrite_output_dir \
-         --fp16
+     cd transformers/examples/pytorch/text-classification
+     python3 run_glue.py \
+       --model_name_or_path bert-base-cased \
+       --dataset_name imdb \
+       --do_train \
+       --max_seq_length 128 \
+       --per_device_train_batch_size 32 \
+       --learning_rate 2e-5 \
+       --max_steps 50 \
+       --output_dir /tmp/imdb/ --overwrite_output_dir \
+       --fp16
 
 
 We can launch training by running:

@@ -93,4 +92,3 @@ If we wish to view the output for each run after it has completed we can use:
 
    $ # Cancel job job3 (ID: 3)
    $ sky cancel lm-cluster 3
-

docs/source/reference/interactive-nodes.rst (+1, -3)

@@ -69,8 +69,7 @@ By default, interactive clusters are a single node. If you require a cluster wit
 
    num_nodes: 16
    resources:
-     accelerators:
-       V100: 8
+     accelerators: V100:8
 
 .. code-block:: console
 

@@ -81,4 +80,3 @@ To log in to the head node:
 .. code-block:: console
 
    $ ssh my-cluster
-

docs/source/reference/yaml-spec.rst (+4, -3)

@@ -23,10 +23,11 @@ describe all fields available.
   resources:
     cloud: aws  # A cloud (optional) can be specified, if desired.
 
-    # Accelerator requirements (optional) can be specified, use sky show-gpus
+    # Accelerator requirements (optional) can be specified, use `sky show-gpus`
     # to view available accelerator configurations.
-    accelerators:
-      V100: 4  # Specify the accelerator type and the count per node.
+    # This specifies the accelerator type and the count per node. Format:
+    # <name>:<cnt> or <name> (short for a count of 1).
+    accelerators: V100:4
 
     # Accelerator arguments (optional) provides additional metadata for some
     # accelerators, such as the TensorFlow version for TPUs.

examples/huggingface_glue_imdb_app.yaml (+1, -2)

@@ -3,8 +3,7 @@ name: huggingface
 resources:
   accelerators: V100
   # The above is a shorthand for <name>: <count=1>. Same as:
-  # accelerators:
-  #   V100: 1
+  # accelerators: V100:1
 
 # The setup command. Will be run under the working directory.
 setup: |

examples/job_queue/job.yaml (+1, -2)

@@ -9,8 +9,7 @@
 name: job
 
 resources:
-  accelerators:
-    K80: 0.5
+  accelerators: K80:0.5
 
 setup: |
   echo "running setup"

examples/job_queue/job_gpu.yaml (+1, -3)

@@ -9,8 +9,7 @@
 name: job
 
 resources:
-  accelerators:
-    K80: 0.5
+  accelerators: K80:0.5
 
 # setup: |
 #   conda create -n test python=3.7 -y

@@ -25,4 +24,3 @@ run: |
   echo "started"
   python -u -c "import torch; a = torch.randn(10000, 10000).cuda(); b = torch.randn(10000, 10000).cuda(); [print((a @ b).sum()) for _ in range(10000000000)]"
   echo "ended"
-

examples/job_queue/job_multinode.yaml (+1, -2)

@@ -10,8 +10,7 @@
 name: job_multinode
 
 resources:
-  accelerators:
-    K80: 0.5
+  accelerators: K80:0.5
 
 num_nodes: 2
 

examples/many_gpu_vms.yaml (+1, -2)

@@ -2,8 +2,7 @@ name: many_gpu_vms
 
 resources:
   cloud: aws
-  accelerators:
-    V100: 8
+  accelerators: V100:8
   # use_spot: true
 
 num_nodes: 16

examples/ray_tune_app.yaml (+1, -2)

@@ -1,7 +1,6 @@
 resources:
   cloud: aws
-  accelerators:
-    V100: 1
+  accelerators: V100
 
 num_nodes: 2
 

examples/resnet_distributed_torch.yaml (+1, -2)

@@ -2,8 +2,7 @@ name: resnet-distributed-app
 
 
 resources:
-  accelerators:
-    V100: 1
+  accelerators: V100
 
 num_nodes: 2
 

sky/cli.py (+2, -5)

@@ -904,7 +904,7 @@ def stop(
 
 
 @cli.command(cls=_DocumentedCodeCommand)
-@click.argument('clusters', nargs=-1, required=False)
+@click.argument('clusters', nargs=-1, required=True)
 def start(clusters: Tuple[str]):
     """Restart cluster(s).
 

@@ -1038,10 +1038,7 @@ def down(
       sky down -a
 
     """
-    names = clusters
-    if not all and not names:
-        return
-    _terminate_or_stop_clusters(names, apply_to_all=all, terminate=True)
+    _terminate_or_stop_clusters(clusters, apply_to_all=all, terminate=True)
 
 
 def _terminate_or_stop_clusters(names: Tuple[str], apply_to_all: Optional[bool],
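The `start` change implements the #373 fix at the argument-parsing layer: click enforces `required=True` even for variadic (`nargs=-1`) arguments, so `sky start` with no clusters now errors out instead of silently doing nothing. The `down` change removes the early silent return, presumably deferring the empty-input case to `_terminate_or_stop_clusters`. A standalone sketch (not Sky's actual CLI) of the click behavior relied on here:

```python
import click


@click.command()
@click.argument('clusters', nargs=-1, required=True)
def start(clusters):
    """With required=True, invoking this command with no arguments makes
    click exit with a "Missing argument 'CLUSTERS...'" usage error."""
    click.echo(f"Restarting: {', '.join(clusters)}")


if __name__ == '__main__':
    start()  # `python start.py` fails; `python start.py c1 c2` succeeds
```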

sky/resources.py (+22, -5)

@@ -21,13 +21,15 @@ class Resources:
     Examples:
 
         # Fully specified cloud and instance type (is_launchable() is True).
-        sky.Resources(clouds.AWS(), 'p3.2xlarge'),
-        sky.Resources(clouds.GCP(), 'n1-standard-16'),
+        sky.Resources(clouds.AWS(), 'p3.2xlarge')
+        sky.Resources(clouds.GCP(), 'n1-standard-16')
         sky.Resources(clouds.GCP(), 'n1-standard-8', 'V100')
 
         # Specifying required resources; Sky decides the cloud/instance type.
-        sky.Resources(accelerators='V100'),
-        sky.Resources(clouds.GCP(), accelerators={'V100': 1}),
+        # The below are equivalent:
+        sky.Resources(accelerators='V100')
+        sky.Resources(accelerators='V100:1')
+        sky.Resources(accelerators={'V100': 1})
 
         # TODO:
         sky.Resources(requests={'mem': '16g', 'cpu': 8})

@@ -48,7 +50,22 @@ def __init__(
                 'If instance_type is specified, must specify the cloud'
         if accelerators is not None:
             if isinstance(accelerators, str):  # Convert to Dict[str, int].
-                accelerators = {accelerators: 1}
+                if ':' not in accelerators:
+                    accelerators = {accelerators: 1}
+                else:
+                    splits = accelerators.split(':')
+                    parse_error = ('The "accelerators" field as a str '
+                                   'should be <name> or <name>:<cnt>. '
+                                   f'Found: {accelerators!r}')
+                    if len(splits) != 2:
+                        raise ValueError(parse_error)
+                    try:
+                        accelerators = {splits[0]: int(splits[1])}
+                    except ValueError:
+                        try:
+                            accelerators = {splits[0]: float(splits[1])}
+                        except ValueError:
+                            raise ValueError(parse_error) from None
             assert len(accelerators) == 1, accelerators
 
         acc, _ = list(accelerators.items())[0]
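For reference, the new parsing rule pulled out as a standalone function (an illustrative sketch; the real logic lives inline in `Resources.__init__` above, and `parse_accelerators` is a hypothetical name):

```python
from typing import Dict, Union


def parse_accelerators(spec: str) -> Dict[str, Union[int, float]]:
    if ':' not in spec:
        return {spec: 1}  # bare name, e.g. 'V100', means a count of 1
    splits = spec.split(':')
    parse_error = ValueError('The "accelerators" field as a str '
                             'should be <name> or <name>:<cnt>. '
                             f'Found: {spec!r}')
    if len(splits) != 2:
        raise parse_error
    name, cnt = splits
    try:
        return {name: int(cnt)}  # e.g. 'V100:4' -> {'V100': 4}
    except ValueError:
        try:
            # Fractional counts are allowed, e.g. 'K80:0.5' for GPU sharing.
            return {name: float(cnt)}
        except ValueError:
            raise parse_error from None


assert parse_accelerators('V100') == {'V100': 1}
assert parse_accelerators('V100:4') == {'V100': 4}
assert parse_accelerators('K80:0.5') == {'K80': 0.5}
```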

tests/test_optimizer_dryruns.py (+41)

@@ -1,10 +1,21 @@
 import pytest
+import tempfile
+import textwrap
 
 import sky
 from sky import clouds
 from sky import exceptions
 
 
+def _test_parse_accelerators(spec, expected_accelerators):
+    with tempfile.NamedTemporaryFile('w') as f:
+        f.write(spec)
+        f.flush()
+        with sky.Dag():
+            task = sky.Task.from_yaml(f.name)
+        assert list(task.resources)[0].accelerators == expected_accelerators
+
+
 # Monkey-patching is required because in the test environment, no cloud is
 # enabled. The optimizer checks the environment to find enabled clouds, and
 # only generates plans within these clouds. The tests assume that all three

@@ -133,3 +144,33 @@ def test_instance_type_matches_accelerators(monkeypatch):
         sky.Resources(sky.AWS(),
                       instance_type='p3.16xlarge',
                       accelerators={'V100': 1}))
+
+
+def test_parse_accelerators_from_yaml():
+    spec = textwrap.dedent("""\
+      resources:
+        accelerators: V100""")
+    _test_parse_accelerators(spec, {'V100': 1})
+
+    spec = textwrap.dedent("""\
+      resources:
+        accelerators: V100:4""")
+    _test_parse_accelerators(spec, {'V100': 4})
+
+    spec = textwrap.dedent("""\
+      resources:
+        accelerators: V100:0.5""")
+    _test_parse_accelerators(spec, {'V100': 0.5})
+
+    spec = textwrap.dedent("""\
+      resources:
+        accelerators: \"V100: 0.5\"""")
+    _test_parse_accelerators(spec, {'V100': 0.5})
+
+    # Invalid.
+    spec = textwrap.dedent("""\
+      resources:
+        accelerators: \"V100: expected_a_float_here\"""")
+    with pytest.raises(ValueError) as e:
+        _test_parse_accelerators(spec, None)
+    assert 'The "accelerators" field as a str ' in str(e.value)
