-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathBatchDatsetReader.py
88 lines (75 loc) · 3.07 KB
/
BatchDatsetReader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""
Code ideas from https://github.com/Newmu/dcgan and the TensorFlow MNIST dataset reader.
"""
import numpy as np
from scipy import misc
from skimage import color
class BatchDatset:
    """
    In-memory image dataset reader that serves sequential or random batches.

    All files are read and transformed once at construction time; batches are
    then sliced out of the resulting array. Each batch is returned as a pair
    ``(first_channel, images)`` where ``first_channel`` is channel 0 of every
    image with the channel axis kept (size 1).
    """

    # Class-level defaults. __init__/_read_images rebind files, images and
    # image_options on every instance, so the mutable objects below are never
    # mutated through an instance created the normal way.
    files = []
    images = []
    image_options = {}
    batch_offset = 0
    epochs_completed = 0

    def __init__(self, records_list, image_options=None):
        """
        Initialize a generic file reader with batching for a list of files.

        :param records_list: list of files to read
        :param image_options: A dictionary of options for modifying the output image
            Available options:
            resize = True/ False
            resize_size = #size of output image
            color=LAB, RGB, HSV
            ``None`` (the default) means "no options".
        """
        # A fresh dict per call: a ``{}`` default would be shared by every
        # instance constructed without explicit options.
        if image_options is None:
            image_options = {}
        print("Initializing Batch Dataset Reader...")
        print(image_options)
        self.files = records_list
        self.image_options = image_options
        self.batch_offset = 0
        self.epochs_completed = 0
        self._read_images()

    def _read_images(self):
        """Load and transform every file in ``self.files`` into ``self.images``."""
        self.images = np.array([self._transform(filename) for filename in self.files])
        print(self.images.shape)

    @staticmethod
    def _ensure_three_channels(image):
        """Return ``image`` unchanged if it has a channel axis, else stack it to (h, w, 3)."""
        if image.ndim < 3:
            # Replicate along a new LAST axis so the result is (h, w, 3);
            # stacking along the first axis would give (3, h, w).
            return np.stack([image] * 3, axis=-1)
        return image

    @staticmethod
    def _split_first_channel(images):
        """Return ``(channel 0 with a trailing axis of size 1, images)`` for a 4-D batch."""
        return np.expand_dims(images[:, :, :, 0], axis=3), images

    def _transform(self, filename):
        """
        Read one image file and apply the configured resize / color options.

        :param filename: path of the image to read
        :return: the transformed image as a numpy array
        :raises Exception: any error from reading or transforming is reported
            on stdout and then re-raised unchanged.
        """
        # Bound before the try so the error report below still works when
        # imread itself is what failed.
        image = None
        try:
            # NOTE(review): scipy.misc.imread / imresize are no longer shipped
            # by newer SciPy releases; this code requires an old SciPy (with
            # Pillow) until it is ported to another image library.
            image = self._ensure_three_channels(misc.imread(filename))

            if self.image_options.get("resize", False):
                resize_size = int(self.image_options["resize_size"])
                transformed = misc.imresize(image, [resize_size, resize_size])
            else:
                transformed = image

            # "RGB" (or no option) leaves the pixel values untouched.
            option = self.image_options.get("color", False)
            if option == "LAB":
                transformed = color.rgb2lab(transformed)
            elif option == "HSV":
                transformed = color.rgb2hsv(transformed)
        except Exception:
            shape = image.shape if image is not None else "unknown"
            print("Error reading file: %s of shape %s" % (filename, str(shape)))
            raise
        return np.array(transformed)

    def get_records(self):
        """Return the full array of transformed images."""
        return self.images

    def reset_batch_offset(self, offset=0):
        """Set the position from which the next sequential batch is taken."""
        self.batch_offset = offset

    def next_batch(self, batch_size):
        """
        Return the next ``batch_size`` images in sequence.

        When the requested batch would run past the end of the data, the epoch
        counter is incremented, the images are reshuffled, and the batch is
        taken from the start of the reshuffled data.

        :param batch_size: number of images requested
        :return: tuple ``(first_channel, images)``
        """
        start = self.batch_offset
        self.batch_offset += batch_size
        if self.batch_offset > self.images.shape[0]:
            # Finished epoch
            self.epochs_completed += 1
            print("****************** Epochs completed: " + str(self.epochs_completed) + "******************")
            # Shuffle the data; the permutation must be applied, otherwise
            # every epoch replays the images in the same order.
            perm = np.arange(self.images.shape[0])
            np.random.shuffle(perm)
            self.images = self.images[perm]
            # Start next epoch
            start = 0
            self.batch_offset = batch_size
        end = self.batch_offset
        return self._split_first_channel(self.images[start:end])

    def get_random_batch(self, batch_size):
        """
        Return ``batch_size`` images drawn uniformly at random with replacement.

        Does not touch ``batch_offset`` or ``epochs_completed``.

        :param batch_size: number of images requested
        :return: tuple ``(first_channel, images)``
        """
        indexes = np.random.randint(0, self.images.shape[0], size=batch_size)
        return self._split_first_channel(self.images[indexes])