
Commit 190e020
model downloads now return the original network definition as well as other stats stored in a json file

moved snapshot function to train

moved get_snapshot code to tasks

removed extra whitespace
Lucaszw committed Jul 7, 2016
1 parent c69d709 commit 190e020
Showing 8 changed files with 147 additions and 58 deletions.
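For context, here is a rough sketch of the kind of stats payload this commit adds to the download archive, based on the keys built in ImageModelJob.get_job_stats_as_json_string() and CaffeTrainTask.get_task_stats() in the diff below. All values shown are hypothetical, and the real stats.json is produced by json.dumps() without pretty-printing:

    # Illustrative only: key names come from the diff below, values are made up
    stats = {
        "job id": "20160707-163447-abcd",              # hypothetical job ID
        "creation time": 1467911687.0,                 # status_history[0][1] timestamp
        "username": "lucas",                           # hypothetical user
        "framework": "caffe",
        "image dimensions": [256, 256, 3],             # hypothetical dataset dims
        "image resize mode": "squash",
        "mean file": "mean.binaryproto",
        "labels file": "labels.txt",
        "snapshot file": "snapshot_iter_3240.caffemodel",
        "solver file": "solver.prototxt",
        "train_val file": "train_val.prototxt",
        "deploy file": "deploy.prototxt",
        "original file": "original.prototxt",
        "digits version": "4.0.0-rc2"                  # hypothetical version string
    }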
12 changes: 1 addition & 11 deletions digits/model/images/classification/job.py
@@ -27,17 +27,7 @@ def job_type(self):
     def download_files(self, epoch=-1):
         task = self.train_task()

-        snapshot_filename = None
-        if epoch == -1 and len(task.snapshots):
-            epoch = task.snapshots[-1][1]
-            snapshot_filename = task.snapshots[-1][0]
-        else:
-            for f, e in task.snapshots:
-                if e == epoch:
-                    snapshot_filename = f
-                    break
-            if not snapshot_filename:
-                raise ValueError('Invalid epoch')
+        snapshot_filename = task.get_snapshot(epoch)

         # get model files
         model_files = task.get_model_files()
12 changes: 1 addition & 11 deletions digits/model/images/generic/job.py
@@ -27,17 +27,7 @@ def job_type(self):
     def download_files(self, epoch=-1):
         task = self.train_task()

-        snapshot_filename = None
-        if epoch == -1 and len(task.snapshots):
-            epoch = task.snapshots[-1][1]
-            snapshot_filename = task.snapshots[-1][0]
-        else:
-            for f, e in task.snapshots:
-                if e == epoch:
-                    snapshot_filename = f
-                    break
-            if not snapshot_filename:
-                raise ValueError('Invalid epoch')
+        snapshot_filename = task.get_snapshot(epoch)

         # get model files
         model_files = task.get_model_files()
17 changes: 17 additions & 0 deletions digits/model/images/job.py
@@ -1,7 +1,11 @@
 # Copyright (c) 2014-2016, NVIDIA CORPORATION. All rights reserved.
 from __future__ import absolute_import

+import json
+import os
+
 from ..job import ModelJob
+from digits.utils import subclass, override

 # NOTE: Increment this everytime the pickled object changes
 PICKLE_VERSION = 1
@@ -17,3 +21,16 @@ def __init__(self, **kwargs):
         super(ImageModelJob, self).__init__(**kwargs)
         self.pickver_job_model_image = PICKLE_VERSION

+    @override
+    def get_job_stats_as_json_string(self,epoch=-1):
+        task = self.train_task()
+
+        stats = {
+            "job id": self.id(),
+            "creation time": self.status_history[0][1],
+            "username": self.username,
+        }
+
+        stats.update(task.get_task_stats(epoch))
+
+        return json.dumps(stats)
6 changes: 6 additions & 0 deletions digits/model/job.py
@@ -57,6 +57,12 @@ def train_task(self):
         """Return the first TrainTask for this job"""
         return [t for t in self.tasks if isinstance(t, tasks.TrainTask)][0]

+    def get_job_stats_as_json_string(self):
+        """
+        Returns stats for a job as a json string
+        """
+        return NotImplementedError()
+
     def download_files(self):
         """
         Returns a list of tuples: [(path, filename)...]
64 changes: 51 additions & 13 deletions digits/model/tasks/caffe_train.py
@@ -33,6 +33,7 @@

 # Constants
 CAFFE_SOLVER_FILE = 'solver.prototxt'
+CAFFE_ORIGINAL_FILE = 'original.prototxt'
 CAFFE_TRAIN_VAL_FILE = 'train_val.prototxt'
 CAFFE_SNAPSHOT_PREFIX = 'snapshot'
 CAFFE_DEPLOY_FILE = 'deploy.prototxt'
@@ -83,11 +84,16 @@ def __init__(self, **kwargs):
         self.solver = None

         self.solver_file = CAFFE_SOLVER_FILE
+        self.original_file = CAFFE_ORIGINAL_FILE
         self.train_val_file = CAFFE_TRAIN_VAL_FILE
         self.snapshot_prefix = CAFFE_SNAPSHOT_PREFIX
         self.deploy_file = CAFFE_DEPLOY_FILE
         self.log_file = self.CAFFE_LOG

+        self.digits_version = digits.__version__
+        self.caffe_version = config_value('caffe_root')['ver_str']
+        self.caffe_flavor = config_value('caffe_root')['flavor']
+
     def __getstate__(self):
         state = super(CaffeTrainTask, self).__getstate__()

@@ -228,6 +234,10 @@ def save_files_classification(self):
         """
         Save solver, train_val and deploy files to disk
         """
+        # Save the origin network to file:
+        with open(self.path(self.original_file), 'w') as outfile:
+            text_format.PrintMessage(self.network, outfile)
+
         network = cleanedUpClassificationNetwork(self.network, len(self.get_labels()))
         data_layers, train_val_layers, deploy_layers = filterLayersByState(network)

@@ -523,6 +533,10 @@ def save_files_generic(self):

         assert train_feature_db_path is not None, 'Training images are required'

+        # Save the origin network to file:
+        with open(self.path(self.original_file), 'w') as outfile:
+            text_format.PrintMessage(self.network, outfile)
+
         ### Split up train_val and deploy layers

         network = cleanedUpGenericNetwork(self.network)
@@ -1030,6 +1044,38 @@ def after_runtime_error(self):

     ### TrainTask overrides

+    @override
+    def get_task_stats(self,epoch=-1):
+        """
+        return a dictionary of task statistics
+        """
+
+        loc, mean_file = os.path.split(self.dataset.get_mean_file())
+
+        stats = {
+            "image dimensions": self.dataset.get_feature_dims(),
+            "mean file": mean_file,
+            "snapshot file": self.get_snapshot_filename(epoch),
+            "solver file": self.solver_file,
+            "train_val file": self.train_val_file,
+            "deploy file": self.deploy_file,
+            "framework": "caffe"
+        }
+
+        if hasattr(self,"original_file"):
+            stats.update({"original file": self.original_file})
+
+        if hasattr(self,"digits_version"):
+            stats.update({"digits version": self.digits_version})
+
+        if hasattr(self.dataset,"resize_mode"):
+            stats.update({"image resize mode": self.dataset.resize_mode})
+
+        if hasattr(self.dataset,"labels_file"):
+            stats.update({"labels file": self.dataset.labels_file})
+
+        return stats
+
     @override
     def detect_snapshots(self):
         self.snapshots = []
@@ -1316,18 +1362,7 @@ def get_net(self, epoch=None, gpu=-1):
         if not self.has_model():
             return False

-        file_to_load = None
-
-        if not epoch:
-            epoch = self.snapshots[-1][1]
-            file_to_load = self.snapshots[-1][0]
-        else:
-            for snapshot_file, snapshot_epoch in self.snapshots:
-                if snapshot_epoch == epoch:
-                    file_to_load = snapshot_file
-                    break
-            if file_to_load is None:
-                raise Exception('snapshot not found for epoch "%s"' % epoch)
+        file_to_load = self.get_snapshot(epoch)

         # check if already loaded
         if self.loaded_snapshot_file and self.loaded_snapshot_file == file_to_load \
@@ -1421,11 +1456,14 @@ def get_model_files(self):
         """
         return paths to model files
         """
-        return {
+        model_files = {
             "Solver": self.solver_file,
             "Network (train/val)": self.train_val_file,
             "Network (deploy)": self.deploy_file
         }
+        if hasattr(self,"original_file"):
+            model_files.update({"Network (original)": self.original_file})
+        return model_files

     @override
     def get_network_desc(self):
36 changes: 22 additions & 14 deletions digits/model/tasks/torch_train.py
@@ -64,6 +64,8 @@ def __init__(self, **kwargs):
         self.snapshot_prefix = TORCH_SNAPSHOT_PREFIX
         self.log_file = self.TORCH_LOG

+        self.digits_version = digits.__version__
+
     def __getstate__(self):
         state = super(TorchTrainTask, self).__getstate__()

@@ -926,23 +928,29 @@ def get_network_desc(self):
             desc = infile.read()
         return desc

-    def get_snapshot(self, epoch):
+    @override
+    def get_task_stats(self,epoch=-1):
         """
-        return snapshot file for specified epoch
+        return a dictionary of task statistics
         """
-        file_to_load = None

-        if not epoch:
-            epoch = self.snapshots[-1][1]
-            file_to_load = self.snapshots[-1][0]
-        else:
-            for snapshot_file, snapshot_epoch in self.snapshots:
-                if snapshot_epoch == epoch:
-                    file_to_load = snapshot_file
-                    break
-            if file_to_load is None:
-                raise Exception('snapshot not found for epoch "%s"' % epoch)
+        loc, mean_file = os.path.split(self.dataset.get_mean_file())
+
+        stats = {
+            "image dimensions": self.dataset.get_feature_dims(),
+            "mean file": mean_file,
+            "snapshot file": self.get_snapshot_filename(epoch),
+            "model file": self.model_file,
+            "framework": "torch"
+        }

+        if hasattr(self,"digits_version"):
+            stats.update({"digits version": self.digits_version})

-        return file_to_load
+        if hasattr(self.dataset,"resize_mode"):
+            stats.update({"image resize mode": self.dataset.resize_mode})
+
+        if hasattr(self.dataset,"labels_file"):
+            stats.update({"labels file": self.dataset.labels_file})
+
+        return stats
30 changes: 30 additions & 0 deletions digits/model/tasks/train.py
@@ -405,6 +405,31 @@ def infer_many(self, data, model_epoch=None):
         """
         return None

+    def get_snapshot(self, epoch=-1):
+        """
+        return snapshot file for specified epoch
+        """
+        snapshot_filename = None
+        if epoch == -1 or not epoch and len(self.snapshots):
+            epoch = self.snapshots[-1][1]
+            snapshot_filename = self.snapshots[-1][0]
+        else:
+            for f, e in self.snapshots:
+                if e == epoch:
+                    snapshot_filename = f
+                    break
+        if not snapshot_filename:
+            raise ValueError('Invalid epoch')
+
+        return snapshot_filename
+
+    def get_snapshot_filename(self,epoch=-1):
+        """
+        Return the filename for the specified epoch
+        """
+        path, name = os.path.split(self.get_snapshot(epoch))
+        return name
+
     def get_labels(self):
         """
         Read labels from labels_file and return them in a list
@@ -547,3 +572,8 @@ def get_network_desc(self):
         """
         raise NotImplementedError()

+    def get_task_stats(self,epoch=-1):
+        """
+        return a dictionary of task statistics
+        """
+        raise NotImplementedError()
28 changes: 19 additions & 9 deletions digits/model/views.py
@@ -6,6 +6,7 @@
 import json
 import math
 import tarfile
+import tempfile
 import zipfile

 import flask
@@ -176,7 +177,9 @@ def download(job_id, extension):
     """
     Return a tarball of all files required to run the model
    """
+
     job = scheduler.get_job(job_id)
+
     if job is None:
         raise werkzeug.exceptions.NotFound('Job not found')

@@ -189,18 +192,20 @@ def download(job_id, extension):
     elif 'snapshot_epoch' in flask.request.form:
         epoch = float(flask.request.form['snapshot_epoch'])

+    # Write the stats of the job to json,
+    # and store in tempfile (for archive)
+    stats = job.get_job_stats_as_json_string(epoch)
+    temp = tempfile.NamedTemporaryFile()
+    temp.write(stats)
+    temp.seek(0)
+
     task = job.train_task()

-    snapshot_filename = None
-    if epoch == -1 and len(task.snapshots):
-        epoch = task.snapshots[-1][1]
-        snapshot_filename = task.snapshots[-1][0]
-    else:
-        for f, e in task.snapshots:
-            if e == epoch:
-                snapshot_filename = f
-                break
-        if not snapshot_filename:
-            raise werkzeug.exceptions.BadRequest('Invalid epoch')
+
+    try:
+        snapshot_filename = task.get_snapshot(epoch)
+    except:
+        raise werkzeug.exceptions.BadRequest('Invalid epoch')

     b = io.BytesIO()
Expand All @@ -214,13 +219,18 @@ def download(job_id, extension):
with tarfile.open(fileobj=b, mode='w:%s' % mode) as tf:
for path, name in job.download_files(epoch):
tf.add(path, arcname=name)
tf.add(temp.name,arcname="stats.json")
elif extension in ['zip']:
with zipfile.ZipFile(b, 'w') as zf:
for path, name in job.download_files(epoch):
zf.write(path, arcname=name)
zf.write(temp.name,arcname="stats.json")
else:
raise werkzeug.exceptions.BadRequest('Invalid extension')

# Close and delete temporary file
temp.close()

response = flask.make_response(b.getvalue())
response.headers['Content-Disposition'] = 'attachment; filename=%s_epoch_%s.%s' % (job.id(), epoch, extension)
return response
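With these changes, the downloaded archive carries a stats.json alongside the model files. A minimal sketch of reading it back out, assuming the archive has already been saved locally as model.tar.gz (the filename and compression are assumptions; the "stats.json" arcname comes from the diff above):

    import json
    import tarfile

    # Open the downloaded model archive and parse the bundled stats file
    with tarfile.open('model.tar.gz', 'r:gz') as tf:
        stats = json.load(tf.extractfile('stats.json'))

    print(stats['framework'], stats['snapshot file'])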
