Skip to content

Commit 6ea8f53

Browse files
authored
Sweeps (pytorch#75)
* add zip * add e2e test, add types * add test_titan_sweep.sh back, used as default command in script
1 parent a67f411 commit 6ea8f53

File tree

2 files changed

+227
-17
lines changed

2 files changed

+227
-17
lines changed

scripts/titan_sweep.py

+136-17
Original file line numberDiff line numberDiff line change
@@ -4,30 +4,49 @@
44
#
55
"""
66
Parameter Sweep Script for W&B style configs
7+
Usage: python titan_sweep.py <json_config_file_or_json_string>
8+
79
Config format:
810
{
911
"command": "./submit.sh", # Executable to run
12+
"method": "grid", # Optional: "grid" (default) or "zip"
1013
"parameters": { # Single parameter set
1114
"param.name": {"values": [value1, value2]},
1215
"another.param": {"values": [value3, value4]}
1316
}
1417
}
1518
16-
Or
19+
This generates the following commands with the grid method:
20+
./submit.sh --param.name value1 --another.param value3
21+
./submit.sh --param.name value1 --another.param value4
22+
./submit.sh --param.name value2 --another.param value3
23+
./submit.sh --param.name value2 --another.param value4
24+
25+
And for zip method:
26+
./submit.sh --param.name value1 --another.param value3
27+
./submit.sh --param.name value2 --another.param value4
28+
29+
30+
Or specify config with multiple parameter sets:
1731
1832
{
1933
"command": "./submit.sh",
34+
"method": "zip", # Optional: "grid" (default) or "zip"
2035
"parameters": [ # Multiple parameter sets
2136
{
2237
"param.name": {"values": [value1, value2]},
2338
"another.param": {"values": [value3, value4]}
2439
},
2540
{
41+
"project.name": {"values": ["my_project"]},
2642
"param.name": {"values": [value5, value6]},
2743
"different.param": {"values": [value7]}
2844
}
2945
]
3046
}
47+
48+
Special handling for single-value parameters:
49+
If a parameter has exactly one value, that value is added to every combination in both grid and zip modes.
3150
"""
3251

3352
import json
@@ -36,9 +55,29 @@
3655
import subprocess
3756
import sys
3857
from pathlib import Path
58+
from typing import Dict, List, Union, Iterator, Tuple, Any, Optional, Literal, TypedDict
59+
60+
SearchMethodOptions = Literal["grid", "zip"]

# A single value that can be rendered as a command-line argument.
ParamValue = Union[str, int, float, bool]


class ParamValues(TypedDict):
    """Candidate values for one swept parameter: ``{"values": [...]}``."""

    values: List[ParamValue]


# One parameter set: dotted parameter name -> its candidate values, e.g.
# {"training.batch_size": {"values": [4, 8]}}.
# NOTE: this used to be re-declared further down as a TypedDict with
# ``name``/``values`` keys, which shadowed this alias and did not match the
# shape the functions in this module actually index into.
ParamConfig = Dict[str, ParamValues]

# One generated combination: ordered (parameter name, value) pairs.
CommandList = List[Tuple[str, ParamValue]]


class _SweepConfigRequired(TypedDict):
    """Keys that every sweep config must provide."""

    command: str
    # Either a single parameter set or a list of parameter sets.
    parameters: Union[ParamConfig, List[ParamConfig]]


class SweepConfig(_SweepConfigRequired, total=False):
    """Full sweep config; ``method`` may be omitted and then means "grid"."""

    method: SearchMethodOptions
76+
77+
78+
DEFAULT_CONFIG: Dict[str, Union[str, List[ParamConfig]]] = {
4179
"command": "tests/test_titan_sweep.sh",
80+
"method": "grid",
4281
"parameters": [
4382
{
4483
"training.batch_size": {"values": [4, 8]},
@@ -55,7 +94,7 @@
5594
}
5695

5796

58-
def load_config(config_source=None):
97+
def load_config(config_source: Optional[str] = None) -> Dict[str, Any]:
5998
if not config_source:
6099
return DEFAULT_CONFIG
61100
try:
@@ -68,9 +107,20 @@ def load_config(config_source=None):
68107
sys.exit(1)
69108

70109

71-
def validate_parameters(params):
110+
def validate_parameters(params: ParamConfig, method: str = "grid") -> None:
111+
"""
112+
Validate a single parameter configuration.
113+
114+
Args:
115+
params: Parameter configuration dictionary
116+
method: Combination method ("grid" or "zip")
117+
118+
Raises:
119+
ValueError: If parameters are invalid
120+
"""
72121
if not isinstance(params, dict):
73122
raise ValueError("Parameters must be a dictionary")
123+
74124
for param, param_config in params.items():
75125
if not isinstance(param, str) or "." not in param:
76126
raise ValueError(f"Invalid parameter name: {param}")
@@ -79,37 +129,106 @@ def validate_parameters(params):
79129
if not isinstance(param_config["values"], list):
80130
raise ValueError(f"Values must be a list for parameter: {param}")
81131

132+
if method == "zip":
133+
# Check length consistency excluding single-value parameters
134+
multi_value_params = {k: v for k, v in params.items() if len(v["values"]) > 1}
135+
if multi_value_params:
136+
lengths = [len(param["values"]) for param in multi_value_params.values()]
137+
if not all(length == lengths[0] for length in lengths):
138+
raise ValueError(
139+
f"For zip method, all multi-value parameters "
140+
f"must have {lengths[0]} values each"
141+
)
82142

83-
def validate_config(config):
143+
144+
def validate_config(config: Dict[str, Any]) -> None:
    """
    Check that a sweep configuration is well-formed.

    The checks run in this order: 'command' is present and a string,
    'parameters' is present, 'method' (defaulting to "grid") is one of the
    supported values, and finally every parameter set is validated.

    Args:
        config: Complete configuration dictionary

    Raises:
        ValueError: If configuration is invalid
    """
    if "command" not in config:
        raise ValueError("Missing 'command' in config")
    command = config["command"]
    if not isinstance(command, str):
        raise ValueError("'command' must be a string")
    if "parameters" not in config:
        raise ValueError("Missing 'parameters' in config")

    sweep_method = config.get("method", "grid")
    if not (sweep_method == "grid" or sweep_method == "zip"):
        raise ValueError("'method' must be either 'grid' or 'zip'")

    # Normalise the two accepted shapes (one dict, or a list of dicts) so a
    # single loop can validate them.
    parameters = config["parameters"]
    if isinstance(parameters, dict):
        candidate_sets = [parameters]
    elif isinstance(parameters, list):
        candidate_sets = parameters
    else:
        raise ValueError("'parameters' must be either dict or list of dicts")

    for candidate in candidate_sets:
        validate_parameters(candidate, sweep_method)
100173

101174

102-
def generate_command_combinations(parameters):
103-
param_names = list(parameters.keys())
104-
param_values = [parameters[param]["values"] for param in param_names]
105-
for values in product(*param_values):
106-
yield [(param_names[i], val) for i, val in enumerate(values)]
175+
def generate_command_combinations(
    parameters: ParamConfig, method: SearchMethodOptions = "grid"
) -> Iterator[CommandList]:
    """
    Generate parameter combinations based on the specified method.

    Parameters with exactly one value are "fixed": their (name, value) pair
    is placed at the front of every combination. Parameters with more than
    one value are combined either as a Cartesian product ("grid") or
    position by position ("zip"). A parameter whose value list is empty is
    neither fixed nor swept and therefore does not appear in the output.

    Args:
        parameters: Parameter configuration dictionary
        method: Combination method ("grid" or "zip")

    Yields:
        List of parameter-value tuples for each combination

    Raises:
        ValueError: If ``method`` is not "grid" or "zip", or if "zip" is
            requested and the multi-value parameters have differing numbers
            of values. Because this is a generator, the error surfaces when
            iteration starts.
    """
    if method not in ("grid", "zip"):
        raise ValueError(f"Invalid method: {method}")

    fixed = [
        (name, spec["values"][0])
        for name, spec in parameters.items()
        if len(spec["values"]) == 1
    ]
    swept = {
        name: spec["values"]
        for name, spec in parameters.items()
        if len(spec["values"]) > 1
    }

    # Nothing to sweep over: the fixed values form the one and only combination.
    if not swept:
        yield fixed
        return

    names = list(swept)
    columns = list(swept.values())

    if method == "grid":
        rows = product(*columns)
    else:
        # zip() would silently stop at the shortest list and drop the extra
        # values, so reject mismatched lengths even when the caller skipped
        # validation.
        if len({len(column) for column in columns}) > 1:
            raise ValueError(
                "For zip method, all multi-value parameters "
                "must have the same number of values"
            )
        rows = zip(*columns)

    for row in rows:
        yield fixed + list(zip(names, row))
215+
216+
217+
def run_command(command: str, command_list: CommandList) -> None:
218+
"""
219+
Execute the command with the given parameters.
220+
221+
Args:
222+
command: Base command to execute
223+
command_list: List of parameter-value tuples to append to command
224+
225+
Raises:
226+
subprocess.CalledProcessError: If command execution fails
227+
"""
110228
cmd = command.split()
111229
if isinstance(cmd, str):
112230
cmd = [command]
231+
113232
for param, value in command_list:
114233
if isinstance(value, bool):
115234
if value:
@@ -128,7 +247,7 @@ def run_command(command, command_list):
128247
print(f"Error: {e}")
129248

130249

131-
def main():
250+
def main() -> None:
132251
parser = argparse.ArgumentParser(description="Parameter sweep script")
133252
group = parser.add_mutually_exclusive_group()
134253
group.add_argument("config_file", nargs="?", help="Config file path")
@@ -141,14 +260,14 @@ def main():
141260
try:
142261
validate_config(config)
143262
command = config["command"]
263+
method = config.get("method", "grid")
144264

145-
# Handle both single dict and list of dicts
146265
param_sets = config["parameters"]
147266
if isinstance(param_sets, dict):
148267
param_sets = [param_sets]
149268

150269
for param_set in param_sets:
151-
for combo in generate_command_combinations(param_set):
270+
for combo in generate_command_combinations(param_set, method):
152271
run_command(command, combo)
153272
except Exception as e:
154273
print(f"Error: {e}")

tests/test_titan_sweep.py

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
#!/usr/bin/env python3
2+
#
3+
# Copyright (c) 2025 Graphcore Ltd. All rights reserved.
4+
#
5+
from typing import List, Dict, Any
6+
import json
7+
import tempfile
8+
import subprocess
9+
import sys
10+
import os
11+
import pytest
12+
13+
14+
@pytest.fixture
def echo_script(tmp_path) -> str:
    """
    Write an executable bash script that prints its arguments on one line.

    Args:
        tmp_path: Pytest-provided temporary directory path
    Returns:
        Path to the temporary shell script, as a string
    """
    target = tmp_path / "echo_script.sh"
    target.write_text('#!/bin/bash\necho "$@"\n')
    # Make the script executable so it can be used as the sweep command.
    target.chmod(0o755)
    return str(target)
30+
31+
32+
def run_sweep_and_capture(config_dict: Dict[str, Any]) -> List[str]:
    """
    Run the sweep script with the given config and return its stdout lines.

    The config is written to a temporary JSON file, which is passed to
    ``scripts/titan_sweep.py`` as its only argument and removed afterwards.

    Args:
        config_dict: Sweep configuration to serialise as JSON
    Returns:
        The script's stripped standard output split on newlines
    """
    # Resolve the script relative to this file instead of the current working
    # directory, so the tests also pass when pytest is started outside the
    # repository root. Assumes this module lives in <repo>/tests/.
    repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    script_path = os.path.join(repo_root, "scripts", "titan_sweep.py")

    tf = tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json")
    config_path = tf.name
    try:
        # Close the file before the subprocess reads it; the surrounding
        # try/finally removes it even if serialisation fails.
        with tf:
            json.dump(config_dict, tf)
        result = subprocess.run(
            [sys.executable, script_path, config_path],
            capture_output=True,
            text=True,
        )
        return result.stdout.strip().split("\n")
    finally:
        os.unlink(config_path)
46+
47+
48+
def test_grid_sweep(echo_script: str) -> None:
    """Grid mode runs every steps x batch_size pair, each with the fixed project."""
    expected_lines = [
        "--wandb.project my_project --training.steps 1000 --training.batch_size 16",
        "--wandb.project my_project --training.steps 1000 --training.batch_size 32",
        "--wandb.project my_project --training.steps 1000 --training.batch_size 64",
        "--wandb.project my_project --training.steps 2000 --training.batch_size 16",
        "--wandb.project my_project --training.steps 2000 --training.batch_size 32",
        "--wandb.project my_project --training.steps 2000 --training.batch_size 64",
    ]
    sweep_parameters = {
        "wandb.project": {"values": ["my_project"]},
        "training.steps": {"values": [1000, 2000]},
        "training.batch_size": {"values": [16, 32, 64]},
    }
    actual_lines = run_sweep_and_capture(
        {
            "command": echo_script,
            "method": "grid",
            "parameters": sweep_parameters,
        }
    )
    # Compare sorted so the assertion does not depend on execution order.
    assert sorted(actual_lines) == sorted(expected_lines)
69+
70+
71+
def test_zip_sweep(echo_script: str) -> None:
    """Zip mode pairs values by position and repeats the single-value parameter."""
    expected_lines = [
        "--wandb.project my_project --model.type a --learning.rate 0.1",
        "--wandb.project my_project --model.type b --learning.rate 0.2",
    ]
    sweep_parameters = {
        "wandb.project": {"values": ["my_project"]},
        "model.type": {"values": ["a", "b"]},
        "learning.rate": {"values": [0.1, 0.2]},
    }
    actual_lines = run_sweep_and_capture(
        {
            "command": echo_script,
            "method": "zip",
            "parameters": sweep_parameters,
        }
    )
    # Compare sorted so the assertion does not depend on execution order.
    assert sorted(actual_lines) == sorted(expected_lines)
88+
89+
90+
if __name__ == "__main__":
91+
pytest.main([__file__])

0 commit comments

Comments
 (0)