This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Get stats & original network def from "Download Model" button #891

Merged
merged 1 commit on Jul 12, 2016
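For context: the diffs below add a get_task_stats() method to the train tasks and fold its output into ImageModelJob.json_dict(), so the "Download Model" button can report job statistics alongside the original network definition. As a rough illustration only (the keys come from the diffs in this PR; the values here are invented), a Caffe classification job might report something like:

# Illustrative sketch only -- keys taken from get_task_stats()/json_dict() below,
# values are made up for the example.
stats = {
    "job id": "20160712-093215-1a2b",
    "creation time": "2016-07-12 09:32:15",
    "username": "digits-user",
    "framework": "caffe",
    "image dimensions": [256, 256, 3],
    "image resize mode": "squash",
    "mean file": "mean.binaryproto",
    "labels file": "labels.txt",
    "snapshot file": "snapshot_iter_630.caffemodel",
    "solver file": "solver.prototxt",
    "train_val file": "train_val.prototxt",
    "deploy file": "deploy.prototxt",
    "network file": "original.prototxt",
    "caffe flavor": "NVIDIA",
    "caffe version": "0.15.9",
    "digits version": "4.0.0",
}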
12 changes: 1 addition & 11 deletions digits/model/images/classification/job.py
@@ -27,17 +27,7 @@ def job_type(self):
def download_files(self, epoch=-1):
task = self.train_task()

snapshot_filename = None
if epoch == -1 and len(task.snapshots):
epoch = task.snapshots[-1][1]
snapshot_filename = task.snapshots[-1][0]
else:
for f, e in task.snapshots:
if e == epoch:
snapshot_filename = f
break
if not snapshot_filename:
raise ValueError('Invalid epoch')
snapshot_filename = task.get_snapshot(epoch)

# get model files
model_files = task.get_model_files()
14 changes: 14 additions & 0 deletions digits/model/images/classification/test_views.py
@@ -508,6 +508,11 @@ def test_clone(self):
content2.pop('id')
content1.pop('directory')
content2.pop('directory')
content1.pop('creation time')
content2.pop('creation time')
content1.pop('job id')
content2.pop('job id')

assert (content1 == content2), 'job content does not match'

job1 = digits.webapp.scheduler.get_job(job1_id)
@@ -522,6 +527,15 @@ def test_save(self):
job = digits.webapp.scheduler.get_job(self.model_id)
assert job.save(), 'Job failed to save'

def test_get_snapshot(self):
job = digits.webapp.scheduler.get_job(self.model_id)
task = job.train_task()
f = task.get_snapshot(-1)

assert f, "Failed to load snapshot"
filename = task.get_snapshot_filename(-1)
assert filename, "Failed to get filename"

def test_download(self):
for extension in ['tar', 'zip', 'tar.gz', 'tar.bz2']:
yield self.check_download, extension
12 changes: 1 addition & 11 deletions digits/model/images/generic/job.py
@@ -27,17 +27,7 @@ def job_type(self):
def download_files(self, epoch=-1):
task = self.train_task()

snapshot_filename = None
if epoch == -1 and len(task.snapshots):
epoch = task.snapshots[-1][1]
snapshot_filename = task.snapshots[-1][0]
else:
for f, e in task.snapshots:
if e == epoch:
snapshot_filename = f
break
if not snapshot_filename:
raise ValueError('Invalid epoch')
snapshot_filename = task.get_snapshot(epoch)

# get model files
model_files = task.get_model_files()
15 changes: 14 additions & 1 deletion digits/model/images/generic/test_views.py
@@ -483,6 +483,11 @@ def test_clone(self):
content2.pop('id')
content1.pop('directory')
content2.pop('directory')
content1.pop('creation time')
content2.pop('creation time')
content1.pop('job id')
content2.pop('job id')

assert (content1 == content2), 'job content does not match'

job1 = digits.webapp.scheduler.get_job(job1_id)
@@ -498,6 +503,15 @@ def test_save(self):
job = digits.webapp.scheduler.get_job(self.model_id)
assert job.save(), 'Job failed to save'

def test_get_snapshot(self):
job = digits.webapp.scheduler.get_job(self.model_id)
task = job.train_task()
f = task.get_snapshot(-1)

assert f, "Failed to load snapshot"
filename = task.get_snapshot_filename(-1)
assert filename, "Failed to get filename"

def test_download(self):
for extension in ['tar', 'zip', 'tar.gz', 'tar.bz2']:
yield self.check_download, extension
@@ -1050,4 +1064,3 @@ class TestAllInOneNetwork(BaseTestCreation, BaseTestCreated):
exclude { stage: "deploy" }
}
"""

17 changes: 17 additions & 0 deletions digits/model/images/job.py
@@ -1,7 +1,10 @@
# Copyright (c) 2014-2016, NVIDIA CORPORATION. All rights reserved.
from __future__ import absolute_import

import os
import datetime
from ..job import ModelJob
from digits.utils import subclass, override

# NOTE: Increment this everytime the pickled object changes
PICKLE_VERSION = 1
@@ -17,3 +20,17 @@ def __init__(self, **kwargs):
super(ImageModelJob, self).__init__(**kwargs)
self.pickver_job_model_image = PICKLE_VERSION

@override
def json_dict(self, verbose=False, epoch=-1):
d = super(ImageModelJob, self).json_dict(verbose)
task = self.train_task()
creation_time = str(datetime.datetime.fromtimestamp(self.status_history[0][1]))

d.update({
"job id": self.id(),
"creation time": creation_time,
"username": self.username,
})

Contributor commented: nice!

d.update(task.get_task_stats(epoch))
return d
67 changes: 54 additions & 13 deletions digits/model/tasks/caffe_train.py
@@ -33,6 +33,7 @@

# Constants
CAFFE_SOLVER_FILE = 'solver.prototxt'
CAFFE_ORIGINAL_FILE = 'original.prototxt'
CAFFE_TRAIN_VAL_FILE = 'train_val.prototxt'
CAFFE_SNAPSHOT_PREFIX = 'snapshot'
CAFFE_DEPLOY_FILE = 'deploy.prototxt'
@@ -83,11 +84,16 @@ def __init__(self, **kwargs):
self.solver = None

self.solver_file = CAFFE_SOLVER_FILE
self.original_file = CAFFE_ORIGINAL_FILE
self.train_val_file = CAFFE_TRAIN_VAL_FILE
self.snapshot_prefix = CAFFE_SNAPSHOT_PREFIX
self.deploy_file = CAFFE_DEPLOY_FILE
self.log_file = self.CAFFE_LOG

self.digits_version = digits.__version__
self.caffe_version = config_value('caffe_root')['ver_str']
self.caffe_flavor = config_value('caffe_root')['flavor']

def __getstate__(self):
state = super(CaffeTrainTask, self).__getstate__()

@@ -228,6 +234,10 @@ def save_files_classification(self):
"""
Save solver, train_val and deploy files to disk
"""
# Save the original network to file:
with open(self.path(self.original_file), 'w') as outfile:
text_format.PrintMessage(self.network, outfile)

network = cleanedUpClassificationNetwork(self.network, len(self.get_labels()))
data_layers, train_val_layers, deploy_layers = filterLayersByState(network)

@@ -523,6 +533,10 @@ def save_files_generic(self):

assert train_feature_db_path is not None, 'Training images are required'

# Save the original network to file:
with open(self.path(self.original_file), 'w') as outfile:
text_format.PrintMessage(self.network, outfile)

### Split up train_val and deploy layers

network = cleanedUpGenericNetwork(self.network)
@@ -1030,6 +1044,41 @@ def after_runtime_error(self):

### TrainTask overrides

@override
def get_task_stats(self,epoch=-1):
"""
return a dictionary of task statistics
"""

loc, mean_file = os.path.split(self.dataset.get_mean_file())

stats = {
"image dimensions": self.dataset.get_feature_dims(),
"mean file": mean_file,
"snapshot file": self.get_snapshot_filename(epoch),
"solver file": self.solver_file,
"train_val file": self.train_val_file,
"deploy file": self.deploy_file,
"framework": "caffe"
Member commented: Can we add "caffe_version" and "caffe_flavor" here? Or I guess "caffe version" and "caffe flavor", since you've gone with spaces.

}

# These attributes only available in more recent jobs:
if hasattr(self,"original_file"):
stats.update({
"caffe flavor": self.caffe_flavor,
"caffe version": self.caffe_version,
"network file": self.original_file,
"digits version": self.digits_version
})

if hasattr(self.dataset,"resize_mode"):
stats.update({"image resize mode": self.dataset.resize_mode})

if hasattr(self.dataset,"labels_file"):
stats.update({"labels file": self.dataset.labels_file})

return stats

@override
def detect_snapshots(self):
self.snapshots = []
@@ -1316,18 +1365,7 @@ def get_net(self, epoch=None, gpu=-1):
if not self.has_model():
return False

file_to_load = None

if not epoch:
epoch = self.snapshots[-1][1]
file_to_load = self.snapshots[-1][0]
else:
for snapshot_file, snapshot_epoch in self.snapshots:
if snapshot_epoch == epoch:
file_to_load = snapshot_file
break
if file_to_load is None:
raise Exception('snapshot not found for epoch "%s"' % epoch)
file_to_load = self.get_snapshot(epoch)

# check if already loaded
if self.loaded_snapshot_file and self.loaded_snapshot_file == file_to_load \
@@ -1421,11 +1459,14 @@ def get_model_files(self):
"""
return paths to model files
"""
return {
model_files = {
"Solver": self.solver_file,
"Network (train/val)": self.train_val_file,
"Network (deploy)": self.deploy_file
}
if hasattr(self,"original_file"):
model_files.update({"Network (original)": self.original_file})
return model_files

@override
def get_network_desc(self):
36 changes: 22 additions & 14 deletions digits/model/tasks/torch_train.py
@@ -64,6 +64,8 @@ def __init__(self, **kwargs):
self.snapshot_prefix = TORCH_SNAPSHOT_PREFIX
self.log_file = self.TORCH_LOG

self.digits_version = digits.__version__

def __getstate__(self):
state = super(TorchTrainTask, self).__getstate__()

@@ -926,23 +928,29 @@ def get_network_desc(self):
desc = infile.read()
return desc

def get_snapshot(self, epoch):
@override
def get_task_stats(self,epoch=-1):
"""
return snapshot file for specified epoch
return a dictionary of task statistics
"""
file_to_load = None

if not epoch:
epoch = self.snapshots[-1][1]
file_to_load = self.snapshots[-1][0]
else:
for snapshot_file, snapshot_epoch in self.snapshots:
if snapshot_epoch == epoch:
file_to_load = snapshot_file
break
if file_to_load is None:
raise Exception('snapshot not found for epoch "%s"' % epoch)
loc, mean_file = os.path.split(self.dataset.get_mean_file())

stats = {
"image dimensions": self.dataset.get_feature_dims(),
"mean file": mean_file,
"snapshot file": self.get_snapshot_filename(epoch),
"model file": self.model_file,
"framework": "torch"
}

if hasattr(self,"digits_version"):
stats.update({"digits version": self.digits_version})

return file_to_load
if hasattr(self.dataset,"resize_mode"):
stats.update({"image resize mode": self.dataset.resize_mode})

if hasattr(self.dataset,"labels_file"):
stats.update({"labels file": self.dataset.labels_file})
Contributor commented: A very picky reviewer might argue that this method looks for the most part like the CaffeTrainTask one and there is an opportunity to factor out some code by moving the common bits up in the class hierarchy into TrainTask. It's totally fine this way though :-)

return stats
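Picking up on the contributor's note above: a minimal sketch of how the shared statistics could be hoisted into TrainTask, with the framework subclasses only adding their own keys. This is a hypothetical refactoring for illustration, not part of this PR; method and attribute names follow the existing code, and imports (os, and override from digits.utils) are assumed to match the existing modules.

# Hypothetical refactoring sketch (not part of this PR): move the fields common
# to both frameworks into TrainTask.get_task_stats() and let subclasses extend it.
class TrainTask(object):  # sketch only -- TrainTask already exists in train.py
    def get_task_stats(self, epoch=-1):
        """
        return a dictionary of task statistics common to all frameworks
        """
        _, mean_file = os.path.split(self.dataset.get_mean_file())
        stats = {
            "image dimensions": self.dataset.get_feature_dims(),
            "mean file": mean_file,
            "snapshot file": self.get_snapshot_filename(epoch),
        }
        # optional dataset attributes, as in the per-framework implementations
        if hasattr(self.dataset, "resize_mode"):
            stats["image resize mode"] = self.dataset.resize_mode
        if hasattr(self.dataset, "labels_file"):
            stats["labels file"] = self.dataset.labels_file
        return stats

class TorchTrainTask(TrainTask):
    @override
    def get_task_stats(self, epoch=-1):
        # start from the common fields, then add the Torch-specific ones
        stats = super(TorchTrainTask, self).get_task_stats(epoch)
        stats.update({
            "model file": self.model_file,
            "framework": "torch",
        })
        if hasattr(self, "digits_version"):
            stats["digits version"] = self.digits_version
        return stats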
34 changes: 34 additions & 0 deletions digits/model/tasks/train.py
@@ -405,6 +405,35 @@ def infer_many(self, data, model_epoch=None):
"""
return None

def get_snapshot(self, epoch=-1):
"""
return snapshot file for specified epoch
"""
snapshot_filename = None

if len(self.snapshots) == 0:
return "no snapshots"

if epoch == -1 or not epoch:
epoch = self.snapshots[-1][1]
snapshot_filename = self.snapshots[-1][0]
else:
for f, e in self.snapshots:
if e == epoch:
snapshot_filename = f
break
if not snapshot_filename:
raise ValueError('Invalid epoch')

return snapshot_filename

def get_snapshot_filename(self,epoch=-1):
"""
Return the filename for the specified epoch
"""
path, name = os.path.split(self.get_snapshot(epoch))
return name

def get_labels(self):
"""
Read labels from labels_file and return them in a list
@@ -547,3 +576,8 @@ def get_network_desc(self):
"""
raise NotImplementedError()

def get_task_stats(self,epoch=-1):
"""
return a dictionary of task statistics
"""
raise NotImplementedError()