Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Add parameter sweep to learning rate and batch size #708

Merged
merged 2 commits into from
Apr 27, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions digits/model/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,10 @@ def validate_py_ext(form, field):
tooltip = "If you provide a random seed, then back-to-back runs with the same model and dataset should give identical results."
)

batch_size = utils.forms.IntegerField('Batch size',
batch_size = utils.forms.MultiIntegerField('Batch size',
validators = [
validators.NumberRange(min=1),
validators.Optional(),
utils.forms.MultiNumberRange(min=1),
utils.forms.MultiOptional(),
],
tooltip = "How many images to process at once. If blank, values are used from the network definition."
)
Expand Down Expand Up @@ -153,10 +153,10 @@ def validate_solver_type(form, field):

### Learning rate

learning_rate = utils.forms.FloatField('Base Learning Rate',
learning_rate = utils.forms.MultiFloatField('Base Learning Rate',
default = 0.01,
validators = [
validators.NumberRange(min=0),
utils.forms.MultiNumberRange(min=0),
],
tooltip = "Affects how quickly the network learns. If you are getting NaN for your loss, you probably need to lower this value."
)
Expand Down
18 changes: 17 additions & 1 deletion digits/model/images/classification/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,11 @@ def create_model(cls, network=None, **kwargs):
if rv.status_code != 200:
print json.loads(rv.data)
raise RuntimeError('Model creation failed with %s' % rv.status_code)
return json.loads(rv.data)['id']
data = json.loads(rv.data)
if 'jobs' in data.keys():
return [j['id'] for j in data['jobs']]
else:
return data['id']

# expect a redirect
if not 300 <= rv.status_code <= 310:
Expand Down Expand Up @@ -1123,3 +1127,15 @@ def test_python_layer(self):
assert rv.status_code == 200, 'json load failed with %s' % rv.status_code
content = json.loads(rv.data)
assert len(content['snapshots']), 'should have at least snapshot'

class TestSweepCreation(BaseViewsTestWithDataset):
    """
    Model creation tests for parameter sweeps
    """
    # The docstring has to be the first statement of the class body to be
    # picked up as __doc__, so the framework selection follows it.
    FRAMEWORK = 'caffe'

    def test_sweep(self):
        """
        Submit a single request that lists several learning rates and batch
        sizes, then check that every returned job reaches the 'Done' status,
        can be deleted, and no longer exists afterwards.
        """
        # Bracketed lists request a sweep. When the response carries a 'jobs'
        # entry, create_model() returns a list of job ids instead of one id.
        job_ids = self.create_model(json=True, learning_rate='[0.01, 0.02]', batch_size='[8, 10]')
        for job_id in job_ids:
            assert self.model_wait_completion(job_id) == 'Done', 'create failed'
            assert self.delete_model(job_id) == 200, 'delete failed'
            assert not self.model_exists(job_id), 'model exists after delete'
306 changes: 169 additions & 137 deletions digits/model/images/classification/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,150 +118,182 @@ def create():
raise werkzeug.exceptions.BadRequest(
'Unknown dataset job_id "%s"' % form.dataset.data)

job = None
try:
job = ImageClassificationModelJob(
username = utils.auth.get_username(),
name = form.model_name.data,
dataset_id = datasetJob.id(),
)
# get handle to framework object
fw = frameworks.get_framework_by_id(form.framework.data)

pretrained_model = None
if form.method.data == 'standard':
found = False

# can we find it in standard networks?
network_desc = fw.get_standard_network_desc(form.standard_networks.data)
if network_desc:
found = True
network = fw.get_network_from_desc(network_desc)

if not found:
raise werkzeug.exceptions.BadRequest(
'Unknown standard model "%s"' % form.standard_networks.data)
elif form.method.data == 'previous':
old_job = scheduler.get_job(form.previous_networks.data)
if not old_job:
# sweeps will be a list of all the combinations of the swept fields
# Get swept learning_rate
sweeps = [{'learning_rate': v} for v in form.learning_rate.data]
add_learning_rate = len(form.learning_rate.data) > 1

# Add swept batch_size
sweeps = [dict(s.items() + [('batch_size', bs)]) for bs in form.batch_size.data for s in sweeps[:]]
add_batch_size = len(form.batch_size.data) > 1
n_jobs = len(sweeps)

jobs = []
for sweep in sweeps:
# Populate the form with swept data to be used in saving and
# launching jobs.
form.learning_rate.data = sweep['learning_rate']
form.batch_size.data = sweep['batch_size']

# Augment Job Name
extra = ''
if add_learning_rate:
extra += ' learning_rate:%s' % str(form.learning_rate.data[0])
if add_batch_size:
extra += ' batch_size:%d' % form.batch_size.data[0]

job = None
try:
job = ImageClassificationModelJob(
username = utils.auth.get_username(),
name = form.model_name.data + extra,
dataset_id = datasetJob.id(),
)
# get handle to framework object
fw = frameworks.get_framework_by_id(form.framework.data)

pretrained_model = None
if form.method.data == 'standard':
found = False

# can we find it in standard networks?
network_desc = fw.get_standard_network_desc(form.standard_networks.data)
if network_desc:
found = True
network = fw.get_network_from_desc(network_desc)

if not found:
raise werkzeug.exceptions.BadRequest(
'Unknown standard model "%s"' % form.standard_networks.data)
elif form.method.data == 'previous':
old_job = scheduler.get_job(form.previous_networks.data)
if not old_job:
raise werkzeug.exceptions.BadRequest(
'Job not found: %s' % form.previous_networks.data)

use_same_dataset = (old_job.dataset_id == job.dataset_id)
network = fw.get_network_from_previous(old_job.train_task().network, use_same_dataset)

for choice in form.previous_networks.choices:
if choice[0] == form.previous_networks.data:
epoch = float(flask.request.form['%s-snapshot' % form.previous_networks.data])
if epoch == 0:
pass
elif epoch == -1:
pretrained_model = old_job.train_task().pretrained_model
else:
for filename, e in old_job.train_task().snapshots:
if e == epoch:
pretrained_model = filename
break

if pretrained_model is None:
raise werkzeug.exceptions.BadRequest(
"For the job %s, selected pretrained_model for epoch %d is invalid!"
% (form.previous_networks.data, epoch))
if not (os.path.exists(pretrained_model)):
raise werkzeug.exceptions.BadRequest(
"Pretrained_model for the selected epoch doesn't exists. May be deleted by another user/process. Please restart the server to load the correct pretrained_model details")
break

elif form.method.data == 'custom':
network = fw.get_network_from_desc(form.custom_network.data)
pretrained_model = form.custom_network_snapshot.data.strip()
else:
raise werkzeug.exceptions.BadRequest(
'Job not found: %s' % form.previous_networks.data)

use_same_dataset = (old_job.dataset_id == job.dataset_id)
network = fw.get_network_from_previous(old_job.train_task().network, use_same_dataset)

for choice in form.previous_networks.choices:
if choice[0] == form.previous_networks.data:
epoch = float(flask.request.form['%s-snapshot' % form.previous_networks.data])
if epoch == 0:
pass
elif epoch == -1:
pretrained_model = old_job.train_task().pretrained_model
else:
for filename, e in old_job.train_task().snapshots:
if e == epoch:
pretrained_model = filename
break

if pretrained_model is None:
raise werkzeug.exceptions.BadRequest(
"For the job %s, selected pretrained_model for epoch %d is invalid!"
% (form.previous_networks.data, epoch))
if not (os.path.exists(pretrained_model)):
raise werkzeug.exceptions.BadRequest(
"Pretrained_model for the selected epoch doesn't exists. May be deleted by another user/process. Please restart the server to load the correct pretrained_model details")
break

elif form.method.data == 'custom':
network = fw.get_network_from_desc(form.custom_network.data)
pretrained_model = form.custom_network_snapshot.data.strip()
else:
raise werkzeug.exceptions.BadRequest(
'Unrecognized method: "%s"' % form.method.data)

policy = {'policy': form.lr_policy.data}
if form.lr_policy.data == 'fixed':
pass
elif form.lr_policy.data == 'step':
policy['stepsize'] = form.lr_step_size.data
policy['gamma'] = form.lr_step_gamma.data
elif form.lr_policy.data == 'multistep':
policy['stepvalue'] = form.lr_multistep_values.data
policy['gamma'] = form.lr_multistep_gamma.data
elif form.lr_policy.data == 'exp':
policy['gamma'] = form.lr_exp_gamma.data
elif form.lr_policy.data == 'inv':
policy['gamma'] = form.lr_inv_gamma.data
policy['power'] = form.lr_inv_power.data
elif form.lr_policy.data == 'poly':
policy['power'] = form.lr_poly_power.data
elif form.lr_policy.data == 'sigmoid':
policy['stepsize'] = form.lr_sigmoid_step.data
policy['gamma'] = form.lr_sigmoid_gamma.data
else:
raise werkzeug.exceptions.BadRequest(
'Invalid learning rate policy')

if config_value('caffe_root')['multi_gpu']:
if form.select_gpus.data:
selected_gpus = [str(gpu) for gpu in form.select_gpus.data]
gpu_count = None
elif form.select_gpu_count.data:
gpu_count = form.select_gpu_count.data
selected_gpus = None
'Unrecognized method: "%s"' % form.method.data)

policy = {'policy': form.lr_policy.data}
if form.lr_policy.data == 'fixed':
pass
elif form.lr_policy.data == 'step':
policy['stepsize'] = form.lr_step_size.data
policy['gamma'] = form.lr_step_gamma.data
elif form.lr_policy.data == 'multistep':
policy['stepvalue'] = form.lr_multistep_values.data
policy['gamma'] = form.lr_multistep_gamma.data
elif form.lr_policy.data == 'exp':
policy['gamma'] = form.lr_exp_gamma.data
elif form.lr_policy.data == 'inv':
policy['gamma'] = form.lr_inv_gamma.data
policy['power'] = form.lr_inv_power.data
elif form.lr_policy.data == 'poly':
policy['power'] = form.lr_poly_power.data
elif form.lr_policy.data == 'sigmoid':
policy['stepsize'] = form.lr_sigmoid_step.data
policy['gamma'] = form.lr_sigmoid_gamma.data
else:
gpu_count = 1
selected_gpus = None
else:
if form.select_gpu.data == 'next':
gpu_count = 1
selected_gpus = None
raise werkzeug.exceptions.BadRequest(
'Invalid learning rate policy')

if config_value('caffe_root')['multi_gpu']:
if form.select_gpus.data:
selected_gpus = [str(gpu) for gpu in form.select_gpus.data]
gpu_count = None
elif form.select_gpu_count.data:
gpu_count = form.select_gpu_count.data
selected_gpus = None
else:
gpu_count = 1
selected_gpus = None
else:
selected_gpus = [str(form.select_gpu.data)]
gpu_count = None

# Python Layer File may be on the server or copied from the client.
fs.copy_python_layer_file(
bool(form.python_layer_from_client.data),
job.dir(),
(flask.request.files[form.python_layer_client_file.name]
if form.python_layer_client_file.name in flask.request.files
else ''), form.python_layer_server_file.data)

job.tasks.append(fw.create_train_task(
job_dir = job.dir(),
dataset = datasetJob,
train_epochs = form.train_epochs.data,
snapshot_interval = form.snapshot_interval.data,
learning_rate = form.learning_rate.data,
lr_policy = policy,
gpu_count = gpu_count,
selected_gpus = selected_gpus,
batch_size = form.batch_size.data,
val_interval = form.val_interval.data,
pretrained_model= pretrained_model,
crop_size = form.crop_size.data,
use_mean = form.use_mean.data,
network = network,
random_seed = form.random_seed.data,
solver_type = form.solver_type.data,
shuffle = form.shuffle.data,
if form.select_gpu.data == 'next':
gpu_count = 1
selected_gpus = None
else:
selected_gpus = [str(form.select_gpu.data)]
gpu_count = None

# Python Layer File may be on the server or copied from the client.
fs.copy_python_layer_file(
bool(form.python_layer_from_client.data),
job.dir(),
(flask.request.files[form.python_layer_client_file.name]
if form.python_layer_client_file.name in flask.request.files
else ''), form.python_layer_server_file.data)

job.tasks.append(fw.create_train_task(
job_dir = job.dir(),
dataset = datasetJob,
train_epochs = form.train_epochs.data,
snapshot_interval = form.snapshot_interval.data,
learning_rate = form.learning_rate.data[0],
lr_policy = policy,
gpu_count = gpu_count,
selected_gpus = selected_gpus,
batch_size = form.batch_size.data[0],
val_interval = form.val_interval.data,
pretrained_model= pretrained_model,
crop_size = form.crop_size.data,
use_mean = form.use_mean.data,
network = network,
random_seed = form.random_seed.data,
solver_type = form.solver_type.data,
shuffle = form.shuffle.data,
)
)
)

## Save form data with the job so we can easily clone it later.
save_form_to_job(job, form)
## Save form data with the job so we can easily clone it later.
save_form_to_job(job, form)

scheduler.add_job(job)
if request_wants_json():
return flask.jsonify(job.json_dict())
else:
return flask.redirect(flask.url_for('digits.model.views.show', job_id=job.id()))
jobs.append(job)
scheduler.add_job(job)
if n_jobs == 1:
if request_wants_json():
return flask.jsonify(job.json_dict())
else:
return flask.redirect(flask.url_for('digits.model.views.show', job_id=job.id()))

except:
if job:
scheduler.delete_job(job)
raise

if request_wants_json():
return flask.jsonify(jobs=[job.json_dict() for job in jobs])

except:
if job:
scheduler.delete_job(job)
raise
# If there are multiple jobs launched, go to the home page.
return flask.redirect('/')

def show(job):
"""
Expand Down
Loading