# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Benchmark dataset utilities.
"""
from abc import abstractmethod
import pickle as cPickle
import os
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.python.platform import gfile
import preprocessing
IMAGENET_NUM_TRAIN_IMAGES = 1281167
IMAGENET_NUM_VAL_IMAGES = 50000


def create_dataset(data_dir, data_name):
  """Create a Dataset instance based on data_dir and data_name."""
  supported_datasets = {
      'synthetic': SyntheticData,
      'imagenet': ImagenetData,
      'cifar10': Cifar10Data,
  }
  if not data_dir:
    data_name = 'synthetic'
  if data_name is None:
    for supported_name in supported_datasets:
      if supported_name in data_dir:
        data_name = supported_name
        break
  if data_name is None:
    raise ValueError('Could not identify name of dataset. '
                     'Please specify with --data_name option.')
  if data_name not in supported_datasets:
    raise ValueError('Unknown dataset. Must be one of %s' %
                     ', '.join(sorted(supported_datasets.keys())))
  return supported_datasets[data_name](data_dir)
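

# A minimal usage sketch (the paths below are hypothetical, not part of this
# module): the dataset name is inferred from data_dir when no name is given,
# and a missing data_dir falls back to synthetic data.
#
#   dataset = create_dataset('/tmp/imagenet', None)  # -> ImagenetData
#   dataset = create_dataset('/tmp/cifar10', None)   # -> Cifar10Data
#   dataset = create_dataset(None, None)             # -> SyntheticData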


class Dataset(object):
  """Abstract base class for CNN benchmark datasets."""

  def __init__(self, name, height=None, width=None, depth=None, data_dir=None,
               queue_runner_required=False, num_classes=1000):
    self.name = name
    self.height = height
    self.width = width
    self.depth = depth or 3
    self.data_dir = data_dir
    self._queue_runner_required = queue_runner_required
    self._num_classes = num_classes

  def tf_record_pattern(self, subset):
    return os.path.join(self.data_dir, '%s-*-of-*' % subset)

  def reader(self):
    return tf.TFRecordReader()

  @property
  def num_classes(self):
    return self._num_classes

  @num_classes.setter
  def num_classes(self, val):
    self._num_classes = val

  @abstractmethod
  def num_examples_per_epoch(self, subset):
    pass

  def __str__(self):
    return self.name

  def get_image_preprocessor(self):
    return None

  def queue_runner_required(self):
    return self._queue_runner_required

  def use_synthetic_gpu_images(self):
    return False
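

# Sketch of how a new dataset could plug in (hypothetical class, not part of
# the benchmark): a subclass passes its geometry to Dataset.__init__ and
# overrides num_examples_per_epoch, plus get_image_preprocessor when it needs
# real input processing.
#
#   class MyTfRecordData(Dataset):
#
#     def __init__(self, data_dir=None):
#       super(MyTfRecordData, self).__init__('my_dataset', 224, 224,
#                                            data_dir=data_dir)
#
#     def num_examples_per_epoch(self, subset='train'):
#       return 100000 if subset == 'train' else 10000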


class ImagenetData(Dataset):
  """Configuration for Imagenet dataset."""

  def __init__(self, data_dir=None):
    if data_dir is None:
      raise ValueError('Data directory not specified')
    super(ImagenetData, self).__init__('imagenet', 300, 300, data_dir=data_dir)

  def num_examples_per_epoch(self, subset='train'):
    if subset == 'train':
      return IMAGENET_NUM_TRAIN_IMAGES
    elif subset == 'validation':
      return IMAGENET_NUM_VAL_IMAGES
    else:
      raise ValueError('Invalid data subset "%s"' % subset)

  def get_image_preprocessor(self):
    return preprocessing.RecordInputImagePreprocessor


class SyntheticData(Dataset):
  """Configuration for synthetic dataset."""

  def __init__(self, unused_data_dir):
    super(SyntheticData, self).__init__('synthetic')

  def get_image_preprocessor(self):
    return preprocessing.SyntheticImagePreprocessor

  def use_synthetic_gpu_images(self):
    return True


class Cifar10Data(Dataset):
  """Configuration for the CIFAR-10 dataset.

  Loads the entire dataset into host memory.
  """

  def __init__(self, data_dir=None):
    if data_dir is None:
      raise ValueError('Data directory not specified')
    super(Cifar10Data, self).__init__('cifar10', 32, 32, data_dir=data_dir,
                                      queue_runner_required=True,
                                      num_classes=10)

  def read_data_files(self, subset='train'):
    """Reads data files and returns images and labels as numpy arrays."""
    if subset == 'train':
      filenames = [os.path.join(self.data_dir, 'data_batch_%d' % i)
                   for i in xrange(1, 6)]
    elif subset == 'validation':
      filenames = [os.path.join(self.data_dir, 'test_batch')]
    else:
      raise ValueError('Invalid data subset "%s"' % subset)
    inputs = []
    for filename in filenames:
      # The CIFAR-10 batches are binary pickle files, so they must be opened
      # in 'rb' mode; reading them as text breaks under Python 3.
      with gfile.Open(filename, 'rb') as f:
        # python2 cPickle.load does not accept the encoding parameter.
        encoding = {} if six.PY2 else {'encoding': 'bytes'}
        inputs.append(cPickle.load(f, **encoding))
    # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
    # input format. Keys are bytes because the batches were pickled under
    # python2; b'data' also matches the str key under python2.
    all_images = np.concatenate(
        [each_input[b'data'] for each_input in inputs]).astype(np.float32)
    all_labels = np.concatenate(
        [each_input[b'labels'] for each_input in inputs])
    return all_images, all_labels
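
  # Layout note (informational sketch; the names below are illustrative, not
  # used by the benchmark): each row of all_images is a flat 3072-vector in
  # channel-major (CHW) order, so a consumer can recover an NHWC batch with:
  #
  #   images = all_images.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)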

  def num_examples_per_epoch(self, subset='train'):
    if subset == 'train':
      return 50000
    elif subset == 'validation':
      return 10000
    else:
      raise ValueError('Invalid data subset "%s"' % subset)

  def get_image_preprocessor(self):
    return preprocessing.RecordInputImagePreprocessor
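

# A small smoke test (a hedged sketch: SyntheticData needs no files on disk,
# so this only exercises the dispatch logic in create_dataset and the Dataset
# accessors; it assumes tensorflow and preprocessing import cleanly).
if __name__ == '__main__':
  dataset = create_dataset(None, None)  # No data_dir -> synthetic fallback.
  print('name: %s' % dataset)
  print('num_classes: %d' % dataset.num_classes)
  print('uses synthetic GPU images: %s' % dataset.use_synthetic_gpu_images())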