diff --git a/digits/model/forms.py b/digits/model/forms.py index 3f160e007..210513813 100644 --- a/digits/model/forms.py +++ b/digits/model/forms.py @@ -120,10 +120,10 @@ def validate_py_ext(form, field): tooltip = "If you provide a random seed, then back-to-back runs with the same model and dataset should give identical results." ) - batch_size = utils.forms.IntegerField('Batch size', + batch_size = utils.forms.MultiIntegerField('Batch size', validators = [ - validators.NumberRange(min=1), - validators.Optional(), + utils.forms.MultiNumberRange(min=1), + utils.forms.MultiOptional(), ], tooltip = "How many images to process at once. If blank, values are used from the network definition." ) @@ -153,10 +153,10 @@ def validate_solver_type(form, field): ### Learning rate - learning_rate = utils.forms.FloatField('Base Learning Rate', + learning_rate = utils.forms.MultiFloatField('Base Learning Rate', default = 0.01, validators = [ - validators.NumberRange(min=0), + utils.forms.MultiNumberRange(min=0), ], tooltip = "Affects how quickly the network learns. If you are getting NaN for your loss, you probably need to lower this value." ) diff --git a/digits/model/images/classification/test_views.py b/digits/model/images/classification/test_views.py index 2bb2421f1..b81680496 100644 --- a/digits/model/images/classification/test_views.py +++ b/digits/model/images/classification/test_views.py @@ -202,7 +202,11 @@ def create_model(cls, network=None, **kwargs): if rv.status_code != 200: print json.loads(rv.data) raise RuntimeError('Model creation failed with %s' % rv.status_code) - return json.loads(rv.data)['id'] + data = json.loads(rv.data) + if 'jobs' in data.keys(): + return [j['id'] for j in data['jobs']] + else: + return data['id'] # expect a redirect if not 300 <= rv.status_code <= 310: @@ -1123,3 +1127,15 @@ def test_python_layer(self): assert rv.status_code == 200, 'json load failed with %s' % rv.status_code content = json.loads(rv.data) assert len(content['snapshots']), 'should have at least snapshot' + +class TestSweepCreation(BaseViewsTestWithDataset): + FRAMEWORK = 'caffe' + """ + Model creation tests + """ + def test_sweep(self): + job_ids = self.create_model(json=True, learning_rate='[0.01, 0.02]', batch_size='[8, 10]') + for job_id in job_ids: + assert self.model_wait_completion(job_id) == 'Done', 'create failed' + assert self.delete_model(job_id) == 200, 'delete failed' + assert not self.model_exists(job_id), 'model exists after delete' diff --git a/digits/model/images/classification/views.py b/digits/model/images/classification/views.py index 58561158a..1b12e7fc1 100644 --- a/digits/model/images/classification/views.py +++ b/digits/model/images/classification/views.py @@ -118,150 +118,182 @@ def create(): raise werkzeug.exceptions.BadRequest( 'Unknown dataset job_id "%s"' % form.dataset.data) - job = None - try: - job = ImageClassificationModelJob( - username = utils.auth.get_username(), - name = form.model_name.data, - dataset_id = datasetJob.id(), - ) - # get handle to framework object - fw = frameworks.get_framework_by_id(form.framework.data) - - pretrained_model = None - if form.method.data == 'standard': - found = False - - # can we find it in standard networks? - network_desc = fw.get_standard_network_desc(form.standard_networks.data) - if network_desc: - found = True - network = fw.get_network_from_desc(network_desc) - - if not found: - raise werkzeug.exceptions.BadRequest( - 'Unknown standard model "%s"' % form.standard_networks.data) - elif form.method.data == 'previous': - old_job = scheduler.get_job(form.previous_networks.data) - if not old_job: + # sweeps will be a list of the the permutations of swept fields + # Get swept learning_rate + sweeps = [{'learning_rate': v} for v in form.learning_rate.data] + add_learning_rate = len(form.learning_rate.data) > 1 + + # Add swept batch_size + sweeps = [dict(s.items() + [('batch_size', bs)]) for bs in form.batch_size.data for s in sweeps[:]] + add_batch_size = len(form.batch_size.data) > 1 + n_jobs = len(sweeps) + + jobs = [] + for sweep in sweeps: + # Populate the form with swept data to be used in saving and + # launching jobs. + form.learning_rate.data = sweep['learning_rate'] + form.batch_size.data = sweep['batch_size'] + + # Augment Job Name + extra = '' + if add_learning_rate: + extra += ' learning_rate:%s' % str(form.learning_rate.data[0]) + if add_batch_size: + extra += ' batch_size:%d' % form.batch_size.data[0] + + job = None + try: + job = ImageClassificationModelJob( + username = utils.auth.get_username(), + name = form.model_name.data + extra, + dataset_id = datasetJob.id(), + ) + # get handle to framework object + fw = frameworks.get_framework_by_id(form.framework.data) + + pretrained_model = None + if form.method.data == 'standard': + found = False + + # can we find it in standard networks? + network_desc = fw.get_standard_network_desc(form.standard_networks.data) + if network_desc: + found = True + network = fw.get_network_from_desc(network_desc) + + if not found: + raise werkzeug.exceptions.BadRequest( + 'Unknown standard model "%s"' % form.standard_networks.data) + elif form.method.data == 'previous': + old_job = scheduler.get_job(form.previous_networks.data) + if not old_job: + raise werkzeug.exceptions.BadRequest( + 'Job not found: %s' % form.previous_networks.data) + + use_same_dataset = (old_job.dataset_id == job.dataset_id) + network = fw.get_network_from_previous(old_job.train_task().network, use_same_dataset) + + for choice in form.previous_networks.choices: + if choice[0] == form.previous_networks.data: + epoch = float(flask.request.form['%s-snapshot' % form.previous_networks.data]) + if epoch == 0: + pass + elif epoch == -1: + pretrained_model = old_job.train_task().pretrained_model + else: + for filename, e in old_job.train_task().snapshots: + if e == epoch: + pretrained_model = filename + break + + if pretrained_model is None: + raise werkzeug.exceptions.BadRequest( + "For the job %s, selected pretrained_model for epoch %d is invalid!" + % (form.previous_networks.data, epoch)) + if not (os.path.exists(pretrained_model)): + raise werkzeug.exceptions.BadRequest( + "Pretrained_model for the selected epoch doesn't exists. May be deleted by another user/process. Please restart the server to load the correct pretrained_model details") + break + + elif form.method.data == 'custom': + network = fw.get_network_from_desc(form.custom_network.data) + pretrained_model = form.custom_network_snapshot.data.strip() + else: raise werkzeug.exceptions.BadRequest( - 'Job not found: %s' % form.previous_networks.data) - - use_same_dataset = (old_job.dataset_id == job.dataset_id) - network = fw.get_network_from_previous(old_job.train_task().network, use_same_dataset) - - for choice in form.previous_networks.choices: - if choice[0] == form.previous_networks.data: - epoch = float(flask.request.form['%s-snapshot' % form.previous_networks.data]) - if epoch == 0: - pass - elif epoch == -1: - pretrained_model = old_job.train_task().pretrained_model - else: - for filename, e in old_job.train_task().snapshots: - if e == epoch: - pretrained_model = filename - break - - if pretrained_model is None: - raise werkzeug.exceptions.BadRequest( - "For the job %s, selected pretrained_model for epoch %d is invalid!" - % (form.previous_networks.data, epoch)) - if not (os.path.exists(pretrained_model)): - raise werkzeug.exceptions.BadRequest( - "Pretrained_model for the selected epoch doesn't exists. May be deleted by another user/process. Please restart the server to load the correct pretrained_model details") - break - - elif form.method.data == 'custom': - network = fw.get_network_from_desc(form.custom_network.data) - pretrained_model = form.custom_network_snapshot.data.strip() - else: - raise werkzeug.exceptions.BadRequest( - 'Unrecognized method: "%s"' % form.method.data) - - policy = {'policy': form.lr_policy.data} - if form.lr_policy.data == 'fixed': - pass - elif form.lr_policy.data == 'step': - policy['stepsize'] = form.lr_step_size.data - policy['gamma'] = form.lr_step_gamma.data - elif form.lr_policy.data == 'multistep': - policy['stepvalue'] = form.lr_multistep_values.data - policy['gamma'] = form.lr_multistep_gamma.data - elif form.lr_policy.data == 'exp': - policy['gamma'] = form.lr_exp_gamma.data - elif form.lr_policy.data == 'inv': - policy['gamma'] = form.lr_inv_gamma.data - policy['power'] = form.lr_inv_power.data - elif form.lr_policy.data == 'poly': - policy['power'] = form.lr_poly_power.data - elif form.lr_policy.data == 'sigmoid': - policy['stepsize'] = form.lr_sigmoid_step.data - policy['gamma'] = form.lr_sigmoid_gamma.data - else: - raise werkzeug.exceptions.BadRequest( - 'Invalid learning rate policy') - - if config_value('caffe_root')['multi_gpu']: - if form.select_gpus.data: - selected_gpus = [str(gpu) for gpu in form.select_gpus.data] - gpu_count = None - elif form.select_gpu_count.data: - gpu_count = form.select_gpu_count.data - selected_gpus = None + 'Unrecognized method: "%s"' % form.method.data) + + policy = {'policy': form.lr_policy.data} + if form.lr_policy.data == 'fixed': + pass + elif form.lr_policy.data == 'step': + policy['stepsize'] = form.lr_step_size.data + policy['gamma'] = form.lr_step_gamma.data + elif form.lr_policy.data == 'multistep': + policy['stepvalue'] = form.lr_multistep_values.data + policy['gamma'] = form.lr_multistep_gamma.data + elif form.lr_policy.data == 'exp': + policy['gamma'] = form.lr_exp_gamma.data + elif form.lr_policy.data == 'inv': + policy['gamma'] = form.lr_inv_gamma.data + policy['power'] = form.lr_inv_power.data + elif form.lr_policy.data == 'poly': + policy['power'] = form.lr_poly_power.data + elif form.lr_policy.data == 'sigmoid': + policy['stepsize'] = form.lr_sigmoid_step.data + policy['gamma'] = form.lr_sigmoid_gamma.data else: - gpu_count = 1 - selected_gpus = None - else: - if form.select_gpu.data == 'next': - gpu_count = 1 - selected_gpus = None + raise werkzeug.exceptions.BadRequest( + 'Invalid learning rate policy') + + if config_value('caffe_root')['multi_gpu']: + if form.select_gpus.data: + selected_gpus = [str(gpu) for gpu in form.select_gpus.data] + gpu_count = None + elif form.select_gpu_count.data: + gpu_count = form.select_gpu_count.data + selected_gpus = None + else: + gpu_count = 1 + selected_gpus = None else: - selected_gpus = [str(form.select_gpu.data)] - gpu_count = None - - # Python Layer File may be on the server or copied from the client. - fs.copy_python_layer_file( - bool(form.python_layer_from_client.data), - job.dir(), - (flask.request.files[form.python_layer_client_file.name] - if form.python_layer_client_file.name in flask.request.files - else ''), form.python_layer_server_file.data) - - job.tasks.append(fw.create_train_task( - job_dir = job.dir(), - dataset = datasetJob, - train_epochs = form.train_epochs.data, - snapshot_interval = form.snapshot_interval.data, - learning_rate = form.learning_rate.data, - lr_policy = policy, - gpu_count = gpu_count, - selected_gpus = selected_gpus, - batch_size = form.batch_size.data, - val_interval = form.val_interval.data, - pretrained_model= pretrained_model, - crop_size = form.crop_size.data, - use_mean = form.use_mean.data, - network = network, - random_seed = form.random_seed.data, - solver_type = form.solver_type.data, - shuffle = form.shuffle.data, + if form.select_gpu.data == 'next': + gpu_count = 1 + selected_gpus = None + else: + selected_gpus = [str(form.select_gpu.data)] + gpu_count = None + + # Python Layer File may be on the server or copied from the client. + fs.copy_python_layer_file( + bool(form.python_layer_from_client.data), + job.dir(), + (flask.request.files[form.python_layer_client_file.name] + if form.python_layer_client_file.name in flask.request.files + else ''), form.python_layer_server_file.data) + + job.tasks.append(fw.create_train_task( + job_dir = job.dir(), + dataset = datasetJob, + train_epochs = form.train_epochs.data, + snapshot_interval = form.snapshot_interval.data, + learning_rate = form.learning_rate.data[0], + lr_policy = policy, + gpu_count = gpu_count, + selected_gpus = selected_gpus, + batch_size = form.batch_size.data[0], + val_interval = form.val_interval.data, + pretrained_model= pretrained_model, + crop_size = form.crop_size.data, + use_mean = form.use_mean.data, + network = network, + random_seed = form.random_seed.data, + solver_type = form.solver_type.data, + shuffle = form.shuffle.data, + ) ) - ) - ## Save form data with the job so we can easily clone it later. - save_form_to_job(job, form) + ## Save form data with the job so we can easily clone it later. + save_form_to_job(job, form) - scheduler.add_job(job) - if request_wants_json(): - return flask.jsonify(job.json_dict()) - else: - return flask.redirect(flask.url_for('digits.model.views.show', job_id=job.id())) + jobs.append(job) + scheduler.add_job(job) + if n_jobs == 1: + if request_wants_json(): + return flask.jsonify(job.json_dict()) + else: + return flask.redirect(flask.url_for('digits.model.views.show', job_id=job.id())) + + except: + if job: + scheduler.delete_job(job) + raise + + if request_wants_json(): + return flask.jsonify(jobs=[job.json_dict() for job in jobs]) - except: - if job: - scheduler.delete_job(job) - raise + # If there are multiple jobs launched, go to the home page. + return flask.redirect('/') def show(job): """ diff --git a/digits/model/images/generic/test_views.py b/digits/model/images/generic/test_views.py index 82068365a..3aa3816e6 100644 --- a/digits/model/images/generic/test_views.py +++ b/digits/model/images/generic/test_views.py @@ -194,7 +194,11 @@ def create_model(cls, learning_rate=None, **kwargs): if rv.status_code != 200: print json.loads(rv.data) raise RuntimeError('Model creation failed with %s' % rv.status_code) - return json.loads(rv.data)['id'] + data = json.loads(rv.data) + if 'jobs' in data.keys(): + return [j['id'] for j in data['jobs']] + else: + return data['id'] # expect a redirect if not 300 <= rv.status_code <= 310: @@ -893,3 +897,15 @@ def test_infer_one_json(self): output = np.array(data['outputs']['output'][0]) assert output.shape == (1, self.CROP_SIZE, self.CROP_SIZE), \ 'shape mismatch: %s' % str(output.shape) + +class TestSweepCreation(BaseViewsTestWithDataset): + FRAMEWORK = 'caffe' + """ + Model creation tests + """ + def test_sweep(self): + job_ids = self.create_model(json=True, learning_rate='[0.01, 0.02]', batch_size='[8, 10]') + for job_id in job_ids: + assert self.model_wait_completion(job_id) == 'Done', 'create failed' + assert self.delete_model(job_id) == 200, 'delete failed' + assert not self.model_exists(job_id), 'model exists after delete' diff --git a/digits/model/images/generic/views.py b/digits/model/images/generic/views.py index 040d16bb3..4a994208f 100644 --- a/digits/model/images/generic/views.py +++ b/digits/model/images/generic/views.py @@ -83,137 +83,169 @@ def create(): raise werkzeug.exceptions.BadRequest( 'Unknown dataset job_id "%s"' % form.dataset.data) - job = None - try: - job = GenericImageModelJob( - username = utils.auth.get_username(), - name = form.model_name.data, - dataset_id = datasetJob.id(), - ) - - # get framework (hard-coded to caffe for now) - fw = frameworks.get_framework_by_id(form.framework.data) + # sweeps will be a list of the the permutations of swept fields + # Get swept learning_rate + sweeps = [{'learning_rate': v} for v in form.learning_rate.data] + add_learning_rate = len(form.learning_rate.data) > 1 + + # Add swept batch_size + sweeps = [dict(s.items() + [('batch_size', bs)]) for bs in form.batch_size.data for s in sweeps[:]] + add_batch_size = len(form.batch_size.data) > 1 + n_jobs = len(sweeps) + + jobs = [] + for sweep in sweeps: + # Populate the form with swept data to be used in saving and + # launching jobs. + form.learning_rate.data = sweep['learning_rate'] + form.batch_size.data = sweep['batch_size'] + + # Augment Job Name + extra = '' + if add_learning_rate: + extra += ' learning_rate:%s' % str(form.learning_rate.data[0]) + if add_batch_size: + extra += ' batch_size:%d' % form.batch_size.data[0] + + job = None + try: + job = GenericImageModelJob( + username = utils.auth.get_username(), + name = form.model_name.data + extra, + dataset_id = datasetJob.id(), + ) - pretrained_model = None - #if form.method.data == 'standard': - if form.method.data == 'previous': - old_job = scheduler.get_job(form.previous_networks.data) - if not old_job: + # get framework (hard-coded to caffe for now) + fw = frameworks.get_framework_by_id(form.framework.data) + + pretrained_model = None + #if form.method.data == 'standard': + if form.method.data == 'previous': + old_job = scheduler.get_job(form.previous_networks.data) + if not old_job: + raise werkzeug.exceptions.BadRequest( + 'Job not found: %s' % form.previous_networks.data) + + use_same_dataset = (old_job.dataset_id == job.dataset_id) + network = fw.get_network_from_previous(old_job.train_task().network, use_same_dataset) + + for choice in form.previous_networks.choices: + if choice[0] == form.previous_networks.data: + epoch = float(flask.request.form['%s-snapshot' % form.previous_networks.data]) + if epoch == 0: + pass + elif epoch == -1: + pretrained_model = old_job.train_task().pretrained_model + else: + for filename, e in old_job.train_task().snapshots: + if e == epoch: + pretrained_model = filename + break + + if pretrained_model is None: + raise werkzeug.exceptions.BadRequest( + "For the job %s, selected pretrained_model for epoch %d is invalid!" + % (form.previous_networks.data, epoch)) + if not (os.path.exists(pretrained_model)): + raise werkzeug.exceptions.BadRequest( + "Pretrained_model for the selected epoch doesn't exists. May be deleted by another user/process. Please restart the server to load the correct pretrained_model details") + break + + elif form.method.data == 'custom': + network = fw.get_network_from_desc(form.custom_network.data) + pretrained_model = form.custom_network_snapshot.data.strip() + else: raise werkzeug.exceptions.BadRequest( - 'Job not found: %s' % form.previous_networks.data) - - use_same_dataset = (old_job.dataset_id == job.dataset_id) - network = fw.get_network_from_previous(old_job.train_task().network, use_same_dataset) - - for choice in form.previous_networks.choices: - if choice[0] == form.previous_networks.data: - epoch = float(flask.request.form['%s-snapshot' % form.previous_networks.data]) - if epoch == 0: - pass - elif epoch == -1: - pretrained_model = old_job.train_task().pretrained_model - else: - for filename, e in old_job.train_task().snapshots: - if e == epoch: - pretrained_model = filename - break - - if pretrained_model is None: - raise werkzeug.exceptions.BadRequest( - "For the job %s, selected pretrained_model for epoch %d is invalid!" - % (form.previous_networks.data, epoch)) - if not (os.path.exists(pretrained_model)): - raise werkzeug.exceptions.BadRequest( - "Pretrained_model for the selected epoch doesn't exists. May be deleted by another user/process. Please restart the server to load the correct pretrained_model details") - break - - elif form.method.data == 'custom': - network = fw.get_network_from_desc(form.custom_network.data) - pretrained_model = form.custom_network_snapshot.data.strip() - else: - raise werkzeug.exceptions.BadRequest( - 'Unrecognized method: "%s"' % form.method.data) - - policy = {'policy': form.lr_policy.data} - if form.lr_policy.data == 'fixed': - pass - elif form.lr_policy.data == 'step': - policy['stepsize'] = form.lr_step_size.data - policy['gamma'] = form.lr_step_gamma.data - elif form.lr_policy.data == 'multistep': - policy['stepvalue'] = form.lr_multistep_values.data - policy['gamma'] = form.lr_multistep_gamma.data - elif form.lr_policy.data == 'exp': - policy['gamma'] = form.lr_exp_gamma.data - elif form.lr_policy.data == 'inv': - policy['gamma'] = form.lr_inv_gamma.data - policy['power'] = form.lr_inv_power.data - elif form.lr_policy.data == 'poly': - policy['power'] = form.lr_poly_power.data - elif form.lr_policy.data == 'sigmoid': - policy['stepsize'] = form.lr_sigmoid_step.data - policy['gamma'] = form.lr_sigmoid_gamma.data - else: - raise werkzeug.exceptions.BadRequest( - 'Invalid learning rate policy') - - if config_value('caffe_root')['multi_gpu']: - if form.select_gpu_count.data: - gpu_count = form.select_gpu_count.data - selected_gpus = None + 'Unrecognized method: "%s"' % form.method.data) + + policy = {'policy': form.lr_policy.data} + if form.lr_policy.data == 'fixed': + pass + elif form.lr_policy.data == 'step': + policy['stepsize'] = form.lr_step_size.data + policy['gamma'] = form.lr_step_gamma.data + elif form.lr_policy.data == 'multistep': + policy['stepvalue'] = form.lr_multistep_values.data + policy['gamma'] = form.lr_multistep_gamma.data + elif form.lr_policy.data == 'exp': + policy['gamma'] = form.lr_exp_gamma.data + elif form.lr_policy.data == 'inv': + policy['gamma'] = form.lr_inv_gamma.data + policy['power'] = form.lr_inv_power.data + elif form.lr_policy.data == 'poly': + policy['power'] = form.lr_poly_power.data + elif form.lr_policy.data == 'sigmoid': + policy['stepsize'] = form.lr_sigmoid_step.data + policy['gamma'] = form.lr_sigmoid_gamma.data else: - selected_gpus = [str(gpu) for gpu in form.select_gpus.data] - gpu_count = None - else: - if form.select_gpu.data == 'next': - gpu_count = 1 - selected_gpus = None + raise werkzeug.exceptions.BadRequest( + 'Invalid learning rate policy') + + if config_value('caffe_root')['multi_gpu']: + if form.select_gpu_count.data: + gpu_count = form.select_gpu_count.data + selected_gpus = None + else: + selected_gpus = [str(gpu) for gpu in form.select_gpus.data] + gpu_count = None else: - selected_gpus = [str(form.select_gpu.data)] - gpu_count = None - - # Python Layer File may be on the server or copied from the client. - fs.copy_python_layer_file( - bool(form.python_layer_from_client.data), - job.dir(), - (flask.request.files[form.python_layer_client_file.name] - if form.python_layer_client_file.name in flask.request.files - else ''), form.python_layer_server_file.data) - - job.tasks.append(fw.create_train_task( - job_dir = job.dir(), - dataset = datasetJob, - train_epochs = form.train_epochs.data, - snapshot_interval = form.snapshot_interval.data, - learning_rate = form.learning_rate.data, - lr_policy = policy, - gpu_count = gpu_count, - selected_gpus = selected_gpus, - batch_size = form.batch_size.data, - val_interval = form.val_interval.data, - pretrained_model= pretrained_model, - crop_size = form.crop_size.data, - use_mean = form.use_mean.data, - network = network, - random_seed = form.random_seed.data, - solver_type = form.solver_type.data, - shuffle = form.shuffle.data, + if form.select_gpu.data == 'next': + gpu_count = 1 + selected_gpus = None + else: + selected_gpus = [str(form.select_gpu.data)] + gpu_count = None + + # Python Layer File may be on the server or copied from the client. + fs.copy_python_layer_file( + bool(form.python_layer_from_client.data), + job.dir(), + (flask.request.files[form.python_layer_client_file.name] + if form.python_layer_client_file.name in flask.request.files + else ''), form.python_layer_server_file.data) + + job.tasks.append(fw.create_train_task( + job_dir = job.dir(), + dataset = datasetJob, + train_epochs = form.train_epochs.data, + snapshot_interval = form.snapshot_interval.data, + learning_rate = form.learning_rate.data[0], + lr_policy = policy, + gpu_count = gpu_count, + selected_gpus = selected_gpus, + batch_size = form.batch_size.data[0], + val_interval = form.val_interval.data, + pretrained_model= pretrained_model, + crop_size = form.crop_size.data, + use_mean = form.use_mean.data, + network = network, + random_seed = form.random_seed.data, + solver_type = form.solver_type.data, + shuffle = form.shuffle.data, + ) ) - ) - ## Save form data with the job so we can easily clone it later. - save_form_to_job(job, form) + ## Save form data with the job so we can easily clone it later. + save_form_to_job(job, form) - scheduler.add_job(job) - if request_wants_json(): - return flask.jsonify(job.json_dict()) - else: - return flask.redirect(flask.url_for('digits.model.views.show', job_id=job.id())) + jobs.append(job) + scheduler.add_job(job) + if n_jobs == 1: + if request_wants_json(): + return flask.jsonify(job.json_dict()) + else: + return flask.redirect(flask.url_for('digits.model.views.show', job_id=job.id())) + + except: + if job: + scheduler.delete_job(job) + raise + + if request_wants_json(): + return flask.jsonify(jobs=[job.json_dict() for job in jobs]) - except: - if job: - scheduler.delete_job(job) - raise + # If there are multiple jobs launched, go to the home page. + return flask.redirect('/') def show(job): """ diff --git a/digits/model/views.py b/digits/model/views.py index c381a0296..f01013708 100644 --- a/digits/model/views.py +++ b/digits/model/views.py @@ -160,7 +160,8 @@ def visualize_lr(): Returns a JSON object of data used to create the learning rate graph """ policy = flask.request.form['lr_policy'] - lr = float(flask.request.form['learning_rate']) + # There may be multiple lrs if the learning_rate is swept + lrs = map(float, flask.request.form['learning_rate'].split(',')) if policy == 'fixed': pass elif policy == 'step': @@ -183,26 +184,29 @@ def visualize_lr(): else: raise werkzeug.exceptions.BadRequest('Invalid policy') - data = ['Learning Rate'] - for i in xrange(101): - if policy == 'fixed': - data.append(lr) - elif policy == 'step': - data.append(lr * math.pow(gamma, math.floor(float(i)/step))) - elif policy == 'multistep': - if current_step < len(steps) and i >= steps[current_step]: - current_step += 1 - data.append(lr * math.pow(gamma, current_step)) - elif policy == 'exp': - data.append(lr * math.pow(gamma, i)) - elif policy == 'inv': - data.append(lr * math.pow(1.0 + gamma * i, -power)) - elif policy == 'poly': - data.append(lr * math.pow(1.0 - float(i)/100, power)) - elif policy == 'sigmoid': - data.append(lr / (1.0 + math.exp(gamma * (i - step)))) - - return json.dumps({'data': {'columns': [data]}}) + datalist = [] + for j, lr in enumerate(lrs): + data = ['Learning Rate %d' % j] + for i in xrange(101): + if policy == 'fixed': + data.append(lr) + elif policy == 'step': + data.append(lr * math.pow(gamma, math.floor(float(i)/step))) + elif policy == 'multistep': + if current_step < len(steps) and i >= steps[current_step]: + current_step += 1 + data.append(lr * math.pow(gamma, current_step)) + elif policy == 'exp': + data.append(lr * math.pow(gamma, i)) + elif policy == 'inv': + data.append(lr * math.pow(1.0 + gamma * i, -power)) + elif policy == 'poly': + data.append(lr * math.pow(1.0 - float(i)/100, power)) + elif policy == 'sigmoid': + data.append(lr / (1.0 + math.exp(gamma * (i - step)))) + datalist.append(data) + + return json.dumps({'data': {'columns': datalist}}) @blueprint.route('//download', methods=['GET', 'POST'], diff --git a/digits/templates/models/images/classification/new.html b/digits/templates/models/images/classification/new.html index a28faab60..8ef574f1d 100644 --- a/digits/templates/models/images/classification/new.html +++ b/digits/templates/models/images/classification/new.html @@ -152,6 +152,9 @@

Solver Options

{{form.batch_size.label}} {{form.batch_size.tooltip}} + + {{form.batch_size.small_text}} + {{form.batch_size(class='form-control', placeholder='[network defaults]')}}
@@ -162,6 +165,9 @@

Solver Options

{{form.learning_rate.label}} {{form.learning_rate.tooltip}} + + {{form.learning_rate.small_text}} + {{form.learning_rate(class='form-control learning-rate-option')}}
diff --git a/digits/templates/models/images/generic/new.html b/digits/templates/models/images/generic/new.html index 795135950..f88e99676 100644 --- a/digits/templates/models/images/generic/new.html +++ b/digits/templates/models/images/generic/new.html @@ -152,6 +152,9 @@

Solver Options

{{form.batch_size.label}} {{form.batch_size.tooltip}} + + {{form.batch_size.small_text}} + {{form.batch_size(class='form-control', placeholder='[network defaults]')}}
@@ -162,6 +165,9 @@

Solver Options

{{form.learning_rate.label}} {{form.learning_rate.tooltip}} + + {{form.learning_rate.small_text}} + {{form.learning_rate(class='form-control learning-rate-option')}}
diff --git a/digits/utils/forms.py b/digits/utils/forms.py index 2c8786457..a6b5d09a0 100644 --- a/digits/utils/forms.py +++ b/digits/utils/forms.py @@ -5,6 +5,7 @@ import wtforms from wtforms import SubmitField from wtforms import validators +from wtforms.compat import string_types from digits.utils.routing import get_request_arg @@ -225,6 +226,166 @@ def __init__(self, label='', validators=None, tooltip='', explanation_file = '', self.tooltip = Tooltip(self.id, self.short_name, tooltip) self.explanation = Explanation(self.id, self.short_name, explanation_file) +class MultiIntegerField(wtforms.Field): + """ + A text field, except all input is coerced to one of more integers. + Erroneous input is ignored and will not be accepted as a value. + """ + widget = wtforms.widgets.TextInput() + + def is_int(self, v): + try: + v = int(v) + return True + except: + return False + + def __init__(self, label='', validators=None, tooltip='', explanation_file = '', **kwargs): + super(MultiIntegerField, self).__init__(label, validators, **kwargs) + self.tooltip = Tooltip(self.id, self.short_name, tooltip + ' (accepts comma separated list)') + self.explanation = Explanation(self.id, self.short_name, explanation_file) + self.small_text = 'multiples allowed' + + def __setattr__(self, name, value): + if name == 'data': + if not isinstance(value, (list, tuple)): + value = [value] + value = [int(x) for x in value if self.is_int(x)] + if len(value) == 0: + value = [None] + self.__dict__[name] = value + + def _value(self): + if self.raw_data: + return self.raw_data[0] + return ','.join([str(x) for x in self.data if self.is_int(x)]) + + def process_formdata(self, valuelist): + if valuelist: + try: + valuelist[0] = valuelist[0].replace('[', '') + valuelist[0] = valuelist[0].replace(']', '') + valuelist[0] = valuelist[0].split(',') + self.data = [int(float(datum)) for datum in valuelist[0]] + except ValueError: + self.data = [None] + raise ValueError(self.gettext('Not a valid integer value')) + +class MultiFloatField(wtforms.Field): + """ + A text field, except all input is coerced to one of more floats. + Erroneous input is ignored and will not be accepted as a value. + """ + widget = wtforms.widgets.TextInput() + + def is_float(self, v): + try: + v = float(v) + return True + except: + return False + + def __init__(self, label='', validators=None, tooltip='', explanation_file = '', **kwargs): + super(MultiFloatField, self).__init__(label, validators, **kwargs) + self.tooltip = Tooltip(self.id, self.short_name, tooltip + ' (accepts comma separated list)') + self.explanation = Explanation(self.id, self.short_name, explanation_file) + self.small_text = 'multiples allowed' + + def __setattr__(self, name, value): + if name == 'data': + if not isinstance(value, (list, tuple)): + value = [value] + value = [float(x) for x in value if self.is_float(x)] + if len(value) == 0: + value = [None] + self.__dict__[name] = value + + def _value(self): + if self.raw_data: + return self.raw_data[0] + return ','.join([str(x) for x in self.data if self.is_float(x)]) + + def process_formdata(self, valuelist): + if valuelist: + try: + valuelist[0] = valuelist[0].replace('[', '') + valuelist[0] = valuelist[0].replace(']', '') + valuelist[0] = valuelist[0].split(',') + self.data = [float(datum) for datum in valuelist[0]] + except ValueError: + self.data = [None] + raise ValueError(self.gettext('Not a valid float value')) + + def data_array(self): + if isinstance(self.data, (list, tuple)): + return self.data + else: + return [self.data] + +class MultiNumberRange(object): + """ + Validates that a number is of a minimum and/or maximum value, inclusive. + This will work with any comparable number type, such as floats and + decimals, not just integers. + + :param min: + The minimum required value of the number. If not provided, minimum + value will not be checked. + :param max: + The maximum value of the number. If not provided, maximum value + will not be checked. + :param message: + Error message to raise in case of a validation error. Can be + interpolated using `%(min)s` and `%(max)s` if desired. Useful defaults + are provided depending on the existence of min and max. + """ + def __init__(self, min=None, max=None, message=None): + self.min = min + self.max = max + self.message = message + + def __call__(self, form, field): + fdata = field.data if isinstance(field.data, (list, tuple)) else [field.data] + for data in fdata: + if data is None or (self.min is not None and data < self.min) or \ + (self.max is not None and data > self.max): + message = self.message + if message is None: + # we use %(min)s interpolation to support floats, None, and + # Decimals without throwing a formatting exception. + if self.max is None: + message = field.gettext('Number %(data)s must be at least %(min)s.') + elif self.min is None: + message = field.gettext('Number %(data)s must be at most %(max)s.') + else: + message = field.gettext('Number %(data)s must be between %(min)s and %(max)s.') + + raise validators.ValidationError(message % dict(data=data, min=self.min, max=self.max)) + +class MultiOptional(object): + """ + Allows empty input and stops the validation chain from continuing. + + If input is empty, also removes prior errors (such as processing errors) + from the field. + + :param strip_whitespace: + If True (the default) also stop the validation chain on input which + consists of only whitespace. + """ + field_flags = ('optional', ) + + def __init__(self, strip_whitespace=True): + if strip_whitespace: + self.string_check = lambda s: s.strip() + else: + self.string_check = lambda s: s + + def __call__(self, form, field): + if not field.raw_data or isinstance(field.raw_data[0], string_types) and not self.string_check(field.raw_data[0]): + field.errors[:] = [] + raise validators.StopValidation() + ## Used to save data to populate forms when cloning def add_warning(form, warning): if not hasattr(form, 'warnings'): @@ -245,7 +406,8 @@ def iterate_over_form(job, form, function, prefix = ['form'], indent = ''): whitelist_fields = [ 'BooleanField', 'FloatField', 'HiddenField', 'IntegerField', 'RadioField', 'SelectField', 'SelectMultipleField', - 'StringField', 'TextAreaField', 'TextField'] + 'StringField', 'TextAreaField', 'TextField', + 'MultiIntegerField', 'MultiFloatField'] blacklist_fields = ['FileField', 'SubmitField'] @@ -278,7 +440,6 @@ def set_data(job, form, key, value): if isinstance(value, basestring): value = '\'' + value + '\'' - # print '\'' + key + '\': ' + str(value) +',' return False ## function to pass to iterate_over_form to get data from job