Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Commit

Permalink
Merge pull request #891 from Lucaszw/downloadedModelInformation
Browse files Browse the repository at this point in the history
Get stats & original network def from "Download Model" button
  • Loading branch information
lukeyeager authored Jul 12, 2016
2 parents 34d443f + 8e64591 commit d0b407a
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 61 deletions.
12 changes: 1 addition & 11 deletions digits/model/images/classification/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,7 @@ def job_type(self):
def download_files(self, epoch=-1):
task = self.train_task()

snapshot_filename = None
if epoch == -1 and len(task.snapshots):
epoch = task.snapshots[-1][1]
snapshot_filename = task.snapshots[-1][0]
else:
for f, e in task.snapshots:
if e == epoch:
snapshot_filename = f
break
if not snapshot_filename:
raise ValueError('Invalid epoch')
snapshot_filename = task.get_snapshot(epoch)

# get model files
model_files = task.get_model_files()
Expand Down
14 changes: 14 additions & 0 deletions digits/model/images/classification/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,11 @@ def test_clone(self):
content2.pop('id')
content1.pop('directory')
content2.pop('directory')
content1.pop('creation time')
content2.pop('creation time')
content1.pop('job id')
content2.pop('job id')

assert (content1 == content2), 'job content does not match'

job1 = digits.webapp.scheduler.get_job(job1_id)
Expand All @@ -522,6 +527,15 @@ def test_save(self):
job = digits.webapp.scheduler.get_job(self.model_id)
assert job.save(), 'Job failed to save'

def test_get_snapshot(self):
job = digits.webapp.scheduler.get_job(self.model_id)
task = job.train_task()
f = task.get_snapshot(-1)

assert f, "Failed to load snapshot"
filename = task.get_snapshot_filename(-1)
assert filename, "Failed to get filename"

def test_download(self):
for extension in ['tar', 'zip', 'tar.gz', 'tar.bz2']:
yield self.check_download, extension
Expand Down
12 changes: 1 addition & 11 deletions digits/model/images/generic/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,7 @@ def job_type(self):
def download_files(self, epoch=-1):
task = self.train_task()

snapshot_filename = None
if epoch == -1 and len(task.snapshots):
epoch = task.snapshots[-1][1]
snapshot_filename = task.snapshots[-1][0]
else:
for f, e in task.snapshots:
if e == epoch:
snapshot_filename = f
break
if not snapshot_filename:
raise ValueError('Invalid epoch')
snapshot_filename = task.get_snapshot(epoch)

# get model files
model_files = task.get_model_files()
Expand Down
15 changes: 14 additions & 1 deletion digits/model/images/generic/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,11 @@ def test_clone(self):
content2.pop('id')
content1.pop('directory')
content2.pop('directory')
content1.pop('creation time')
content2.pop('creation time')
content1.pop('job id')
content2.pop('job id')

assert (content1 == content2), 'job content does not match'

job1 = digits.webapp.scheduler.get_job(job1_id)
Expand All @@ -498,6 +503,15 @@ def test_save(self):
job = digits.webapp.scheduler.get_job(self.model_id)
assert job.save(), 'Job failed to save'

def test_get_snapshot(self):
job = digits.webapp.scheduler.get_job(self.model_id)
task = job.train_task()
f = task.get_snapshot(-1)

assert f, "Failed to load snapshot"
filename = task.get_snapshot_filename(-1)
assert filename, "Failed to get filename"

def test_download(self):
for extension in ['tar', 'zip', 'tar.gz', 'tar.bz2']:
yield self.check_download, extension
Expand Down Expand Up @@ -1050,4 +1064,3 @@ class TestAllInOneNetwork(BaseTestCreation, BaseTestCreated):
exclude { stage: "deploy" }
}
"""

17 changes: 17 additions & 0 deletions digits/model/images/job.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# Copyright (c) 2014-2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

import os
import datetime
from ..job import ModelJob
from digits.utils import subclass, override

# NOTE: Increment this everytime the pickled object changes
PICKLE_VERSION = 1
Expand All @@ -17,3 +20,17 @@ def __init__(self, **kwargs):
super(ImageModelJob, self).__init__(**kwargs)
self.pickver_job_model_image = PICKLE_VERSION

@override
def json_dict(self, verbose=False, epoch=-1):
d = super(ImageModelJob, self).json_dict(verbose)
task = self.train_task()
creation_time = str(datetime.datetime.fromtimestamp(self.status_history[0][1]))

d.update({
"job id": self.id(),
"creation time": creation_time,
"username": self.username,
})

d.update(task.get_task_stats(epoch))
return d
67 changes: 54 additions & 13 deletions digits/model/tasks/caffe_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

# Constants
CAFFE_SOLVER_FILE = 'solver.prototxt'
CAFFE_ORIGINAL_FILE = 'original.prototxt'
CAFFE_TRAIN_VAL_FILE = 'train_val.prototxt'
CAFFE_SNAPSHOT_PREFIX = 'snapshot'
CAFFE_DEPLOY_FILE = 'deploy.prototxt'
Expand Down Expand Up @@ -83,11 +84,16 @@ def __init__(self, **kwargs):
self.solver = None

self.solver_file = CAFFE_SOLVER_FILE
self.original_file = CAFFE_ORIGINAL_FILE
self.train_val_file = CAFFE_TRAIN_VAL_FILE
self.snapshot_prefix = CAFFE_SNAPSHOT_PREFIX
self.deploy_file = CAFFE_DEPLOY_FILE
self.log_file = self.CAFFE_LOG

self.digits_version = digits.__version__
self.caffe_version = config_value('caffe_root')['ver_str']
self.caffe_flavor = config_value('caffe_root')['flavor']

def __getstate__(self):
state = super(CaffeTrainTask, self).__getstate__()

Expand Down Expand Up @@ -228,6 +234,10 @@ def save_files_classification(self):
"""
Save solver, train_val and deploy files to disk
"""
# Save the origin network to file:
with open(self.path(self.original_file), 'w') as outfile:
text_format.PrintMessage(self.network, outfile)

network = cleanedUpClassificationNetwork(self.network, len(self.get_labels()))
data_layers, train_val_layers, deploy_layers = filterLayersByState(network)

Expand Down Expand Up @@ -523,6 +533,10 @@ def save_files_generic(self):

assert train_feature_db_path is not None, 'Training images are required'

# Save the origin network to file:
with open(self.path(self.original_file), 'w') as outfile:
text_format.PrintMessage(self.network, outfile)

### Split up train_val and deploy layers

network = cleanedUpGenericNetwork(self.network)
Expand Down Expand Up @@ -1030,6 +1044,41 @@ def after_runtime_error(self):

### TrainTask overrides

@override
def get_task_stats(self,epoch=-1):
"""
return a dictionary of task statistics
"""

loc, mean_file = os.path.split(self.dataset.get_mean_file())

stats = {
"image dimensions": self.dataset.get_feature_dims(),
"mean file": mean_file,
"snapshot file": self.get_snapshot_filename(epoch),
"solver file": self.solver_file,
"train_val file": self.train_val_file,
"deploy file": self.deploy_file,
"framework": "caffe"
}

# These attributes only available in more recent jobs:
if hasattr(self,"original_file"):
stats.update({
"caffe flavor": self.caffe_flavor,
"caffe version": self.caffe_version,
"network file": self.original_file,
"digits version": self.digits_version
})

if hasattr(self.dataset,"resize_mode"):
stats.update({"image resize mode": self.dataset.resize_mode})

if hasattr(self.dataset,"labels_file"):
stats.update({"labels file": self.dataset.labels_file})

return stats

@override
def detect_snapshots(self):
self.snapshots = []
Expand Down Expand Up @@ -1316,18 +1365,7 @@ def get_net(self, epoch=None, gpu=-1):
if not self.has_model():
return False

file_to_load = None

if not epoch:
epoch = self.snapshots[-1][1]
file_to_load = self.snapshots[-1][0]
else:
for snapshot_file, snapshot_epoch in self.snapshots:
if snapshot_epoch == epoch:
file_to_load = snapshot_file
break
if file_to_load is None:
raise Exception('snapshot not found for epoch "%s"' % epoch)
file_to_load = self.get_snapshot(epoch)

# check if already loaded
if self.loaded_snapshot_file and self.loaded_snapshot_file == file_to_load \
Expand Down Expand Up @@ -1421,11 +1459,14 @@ def get_model_files(self):
"""
return paths to model files
"""
return {
model_files = {
"Solver": self.solver_file,
"Network (train/val)": self.train_val_file,
"Network (deploy)": self.deploy_file
}
if hasattr(self,"original_file"):
model_files.update({"Network (original)": self.original_file})
return model_files

@override
def get_network_desc(self):
Expand Down
36 changes: 22 additions & 14 deletions digits/model/tasks/torch_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def __init__(self, **kwargs):
self.snapshot_prefix = TORCH_SNAPSHOT_PREFIX
self.log_file = self.TORCH_LOG

self.digits_version = digits.__version__

def __getstate__(self):
state = super(TorchTrainTask, self).__getstate__()

Expand Down Expand Up @@ -926,23 +928,29 @@ def get_network_desc(self):
desc = infile.read()
return desc

def get_snapshot(self, epoch):
@override
def get_task_stats(self,epoch=-1):
"""
return snapshot file for specified epoch
return a dictionary of task statistics
"""
file_to_load = None

if not epoch:
epoch = self.snapshots[-1][1]
file_to_load = self.snapshots[-1][0]
else:
for snapshot_file, snapshot_epoch in self.snapshots:
if snapshot_epoch == epoch:
file_to_load = snapshot_file
break
if file_to_load is None:
raise Exception('snapshot not found for epoch "%s"' % epoch)
loc, mean_file = os.path.split(self.dataset.get_mean_file())

stats = {
"image dimensions": self.dataset.get_feature_dims(),
"mean file": mean_file,
"snapshot file": self.get_snapshot_filename(epoch),
"model file": self.model_file,
"framework": "torch"
}

if hasattr(self,"digits_version"):
stats.update({"digits version": self.digits_version})

return file_to_load
if hasattr(self.dataset,"resize_mode"):
stats.update({"image resize mode": self.dataset.resize_mode})

if hasattr(self.dataset,"labels_file"):
stats.update({"labels file": self.dataset.labels_file})

return stats
34 changes: 34 additions & 0 deletions digits/model/tasks/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,35 @@ def infer_many(self, data, model_epoch=None):
"""
return None

def get_snapshot(self, epoch=-1):
"""
return snapshot file for specified epoch
"""
snapshot_filename = None

if len(self.snapshots) == 0:
return "no snapshots"

if epoch == -1 or not epoch:
epoch = self.snapshots[-1][1]
snapshot_filename = self.snapshots[-1][0]
else:
for f, e in self.snapshots:
if e == epoch:
snapshot_filename = f
break
if not snapshot_filename:
raise ValueError('Invalid epoch')

return snapshot_filename

def get_snapshot_filename(self,epoch=-1):
"""
Return the filename for the specified epoch
"""
path, name = os.path.split(self.get_snapshot(epoch))
return name

def get_labels(self):
"""
Read labels from labels_file and return them in a list
Expand Down Expand Up @@ -547,3 +576,8 @@ def get_network_desc(self):
"""
raise NotImplementedError()

def get_task_stats(self,epoch=-1):
"""
return a dictionary of task statistics
"""
raise NotImplementedError()
Loading

0 comments on commit d0b407a

Please sign in to comment.