4
4
#
5
5
"""
6
6
Parameter Sweep Script for W&B style configs
7
+ Usage: python titan_sweep.py <json_config_file_or_json_string>
8
+
7
9
Config format:
8
10
{
9
11
"command": "./submit.sh", # Executable to run
12
+ "method": "grid", # Optional: "grid" (default) or "zip"
10
13
"parameters": { # Single parameter set
11
14
"param.name": {"values": [value1, value2]},
12
15
"another.param": {"values": [value3, value4]}
13
16
}
14
17
}
15
18
16
- Or
19
+ This will generate the following combinations for grid method:
20
+ ./submit.sh --param.name value1 --another.param value3
21
+ ./submit.sh --param.name value1 --another.param value4
22
+ ./submit.sh --param.name value2 --another.param value3
23
+ ./submit.sh --param.name value2 --another.param value4
24
+
25
+ And for zip method:
26
+ ./submit.sh --param.name value1 --another.param value3
27
+ ./submit.sh --param.name value2 --another.param value4
28
+
29
+
30
+ Or specify config with multiple parameter sets:
17
31
18
32
{
19
33
"command": "./submit.sh",
34
+ "method": "zip", # Optional: "grid" (default) or "zip"
20
35
"parameters": [ # Multiple parameter sets
21
36
{
22
37
"param.name": {"values": [value1, value2]},
23
38
"another.param": {"values": [value3, value4]}
24
39
},
25
40
{
41
+ "project.name": {"values": ["my_project"]},
26
42
"param.name": {"values": [value5, value6]},
27
43
"different.param": {"values": [value7]}
28
44
}
29
45
]
30
46
}
47
+
48
+ Special handling for single-value parameters:
49
+ If any parameter has a single value, it is added to all combinations in both grid and zip modes.
31
50
"""
32
51
33
52
import json
36
55
import subprocess
37
56
import sys
38
57
from pathlib import Path
58
+ from typing import Dict , List , Union , Iterator , Tuple , Any , Optional , Literal , TypedDict
59
+
60
+ SearchMethodOptions = Literal ["grid" , "zip" ]
61
+
62
+ # Type aliases for better readability
63
+ ParamConfig = Dict [str , Dict [Literal ["values" ], List [Any ]]]
64
+ CommandList = List [Tuple [str , Union [str , int , float , bool ]]]
65
+
39
66
40
- DEFAULT_CONFIG = {
67
+ class SweepConfig (TypedDict ):
68
+ command : str
69
+ method : SearchMethodOptions
70
+ parameters : List [ParamConfig ]
71
+
72
+
73
+ class ParamConfig (TypedDict ):
74
+ name : str
75
+ values : List [str | float | int | bool ]
76
+
77
+
78
+ DEFAULT_CONFIG : Dict [str , Union [str , List [ParamConfig ]]] = {
41
79
"command" : "tests/test_titan_sweep.sh" ,
80
+ "method" : "grid" ,
42
81
"parameters" : [
43
82
{
44
83
"training.batch_size" : {"values" : [4 , 8 ]},
55
94
}
56
95
57
96
58
- def load_config (config_source = None ):
97
+ def load_config (config_source : Optional [ str ] = None ) -> Dict [ str , Any ] :
59
98
if not config_source :
60
99
return DEFAULT_CONFIG
61
100
try :
@@ -68,9 +107,20 @@ def load_config(config_source=None):
68
107
sys .exit (1 )
69
108
70
109
71
- def validate_parameters (params ):
110
+ def validate_parameters (params : ParamConfig , method : str = "grid" ) -> None :
111
+ """
112
+ Validate a single parameter configuration.
113
+
114
+ Args:
115
+ params: Parameter configuration dictionary
116
+ method: Combination method ("grid" or "zip")
117
+
118
+ Raises:
119
+ ValueError: If parameters are invalid
120
+ """
72
121
if not isinstance (params , dict ):
73
122
raise ValueError ("Parameters must be a dictionary" )
123
+
74
124
for param , param_config in params .items ():
75
125
if not isinstance (param , str ) or "." not in param :
76
126
raise ValueError (f"Invalid parameter name: { param } " )
@@ -79,37 +129,106 @@ def validate_parameters(params):
79
129
if not isinstance (param_config ["values" ], list ):
80
130
raise ValueError (f"Values must be a list for parameter: { param } " )
81
131
132
+ if method == "zip" :
133
+ # Check length consistency excluding single-value parameters
134
+ multi_value_params = {k : v for k , v in params .items () if len (v ["values" ]) > 1 }
135
+ if multi_value_params :
136
+ lengths = [len (param ["values" ]) for param in multi_value_params .values ()]
137
+ if not all (length == lengths [0 ] for length in lengths ):
138
+ raise ValueError (
139
+ f"For zip method, all multi-value parameters "
140
+ f"must have { lengths [0 ]} values each"
141
+ )
82
142
83
- def validate_config (config ):
143
+
144
+ def validate_config (config : Dict [str , Any ]) -> None :
145
+ """
146
+ Validate the complete configuration.
147
+
148
+ Args:
149
+ config: Complete configuration dictionary
150
+
151
+ Raises:
152
+ ValueError: If configuration is invalid
153
+ """
84
154
if "command" not in config :
85
155
raise ValueError ("Missing 'command' in config" )
86
156
if not isinstance (config ["command" ], str ):
87
157
raise ValueError ("'command' must be a string" )
88
158
if "parameters" not in config :
89
159
raise ValueError ("Missing 'parameters' in config" )
90
160
91
- # Handle both single dict and list of dicts
161
+ method = config .get ("method" , "grid" )
162
+ if method not in ["grid" , "zip" ]:
163
+ raise ValueError ("'method' must be either 'grid' or 'zip'" )
164
+
92
165
param_sets = config ["parameters" ]
93
166
if isinstance (param_sets , dict ):
94
- validate_parameters (param_sets )
167
+ validate_parameters (param_sets , method )
95
168
elif isinstance (param_sets , list ):
96
169
for param_set in param_sets :
97
- validate_parameters (param_set )
170
+ validate_parameters (param_set , method )
98
171
else :
99
172
raise ValueError ("'parameters' must be either dict or list of dicts" )
100
173
101
174
102
- def generate_command_combinations (parameters ):
103
- param_names = list (parameters .keys ())
104
- param_values = [parameters [param ]["values" ] for param in param_names ]
105
- for values in product (* param_values ):
106
- yield [(param_names [i ], val ) for i , val in enumerate (values )]
175
+ def generate_command_combinations (
176
+ parameters : ParamConfig , method : SearchMethodOptions = "grid"
177
+ ) -> Iterator [CommandList ]:
178
+ """
179
+ Generate parameter combinations based on the specified method.
180
+
181
+ Args:
182
+ parameters: Parameter configuration dictionary
183
+ method: Combination method ("grid" or "zip")
184
+
185
+ Yields:
186
+ List of parameter-value tuples for each combination
187
+ """
188
+ # Get single-value parameters first
189
+ single_values = [
190
+ (k , v ["values" ][0 ]) for k , v in parameters .items () if len (v ["values" ]) == 1
191
+ ]
192
+
193
+ # Get multi-value parameters
194
+ multi_value_params = {k : v for k , v in parameters .items () if len (v ["values" ]) > 1 }
107
195
196
+ if method == "grid" :
197
+ if not multi_value_params :
198
+ yield single_values
199
+ return
108
200
109
- def run_command (command , command_list ):
201
+ param_names = list (multi_value_params .keys ())
202
+ param_values = [multi_value_params [param ]["values" ] for param in param_names ]
203
+ for values in product (* param_values ):
204
+ yield single_values + [(param_names [i ], val ) for i , val in enumerate (values )]
205
+ elif method == "zip" :
206
+ if not multi_value_params :
207
+ yield single_values
208
+ return
209
+
210
+ names = list (multi_value_params .keys ())
211
+ for values in zip (* (v ["values" ] for v in multi_value_params .values ())):
212
+ yield single_values + list (zip (names , values ))
213
+ else :
214
+ raise ValueError (f"Invalid method: { method } " )
215
+
216
+
217
+ def run_command (command : str , command_list : CommandList ) -> None :
218
+ """
219
+ Execute the command with the given parameters.
220
+
221
+ Args:
222
+ command: Base command to execute
223
+ command_list: List of parameter-value tuples to append to command
224
+
225
+ Raises:
226
+ subprocess.CalledProcessError: If command execution fails
227
+ """
110
228
cmd = command .split ()
111
229
if isinstance (cmd , str ):
112
230
cmd = [command ]
231
+
113
232
for param , value in command_list :
114
233
if isinstance (value , bool ):
115
234
if value :
@@ -128,7 +247,7 @@ def run_command(command, command_list):
128
247
print (f"Error: { e } " )
129
248
130
249
131
- def main ():
250
+ def main () -> None :
132
251
parser = argparse .ArgumentParser (description = "Parameter sweep script" )
133
252
group = parser .add_mutually_exclusive_group ()
134
253
group .add_argument ("config_file" , nargs = "?" , help = "Config file path" )
@@ -141,14 +260,14 @@ def main():
141
260
try :
142
261
validate_config (config )
143
262
command = config ["command" ]
263
+ method = config .get ("method" , "grid" )
144
264
145
- # Handle both single dict and list of dicts
146
265
param_sets = config ["parameters" ]
147
266
if isinstance (param_sets , dict ):
148
267
param_sets = [param_sets ]
149
268
150
269
for param_set in param_sets :
151
- for combo in generate_command_combinations (param_set ):
270
+ for combo in generate_command_combinations (param_set , method ):
152
271
run_command (command , combo )
153
272
except Exception as e :
154
273
print (f"Error: { e } " )
0 commit comments