-
Notifications
You must be signed in to change notification settings - Fork 6
/
data_loader.py
32 lines (24 loc) · 1.1 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from __future__ import print_function
import tensorflow as tf
import os
import numpy as np
class ImageLoader(object):
    """Build a batched ``tf.data.Dataset`` of randomly cropped/flipped images.

    Image files are taken from the top level of ``path`` only (subdirectories
    are not descended into). They are sorted by name and sliced to
    ``[begin:end]``, so a given ``begin``/``end`` pair selects the same files
    on every run and machine (e.g. when several loaders split one directory).

    Attributes:
        path: Directory the images are read from.
        batch_size: Number of images per batch.
        crop_size: Side length of the square random crop applied per image.
        batches_per_epoch: Number of *full* batches in the selected files.
            A trailing partial batch is not counted here, although ``data``
            still yields it (``batch`` is called without dropping remainder).
        img_paths: 1-D ``tf.string`` tensor of the selected file paths
            (a plain Python list only while ``__init__`` is running).
        data: The batched dataset of augmented images.
    """

    def __init__(self, path, batch_size=128, begin=0, end=10000,
                 crop_size=32, num_parallel_calls=None):
        """Index the image files and build the input pipeline.

        Args:
            path: Directory containing the image files.
            batch_size: Images per batch.
            begin: Index (in sorted file-name order) of the first file used.
            end: One past the index of the last file used. Slice semantics
                apply, so values beyond the file count are fine.
            crop_size: Side length of the square random crop (default 32).
            num_parallel_calls: Forwarded to ``Dataset.map``. ``None`` keeps
                sequential decoding.

        Raises:
            IOError: if ``path`` cannot be listed as a directory.
        """
        self.path = path
        self.batch_size = batch_size
        self.crop_size = crop_size
        self._read_img_names(begin, end)
        # Floor division: only complete batches are counted per epoch.
        self.batches_per_epoch = len(self.img_paths) // batch_size
        self.img_paths = tf.convert_to_tensor(self.img_paths, dtype=tf.string)
        # Pipeline: file path -> decoded, augmented image -> batch.
        data = tf.data.Dataset.from_tensor_slices(self.img_paths)
        data = data.map(self._parse, num_parallel_calls=num_parallel_calls)
        self.data = data.batch(batch_size)

    def _read_img_names(self, begin, end):
        """Set ``self.img_paths`` to the sorted top-level files of ``self.path``, sliced ``[begin:end]``.

        Raises:
            IOError: if ``self.path`` does not exist or is not a directory.
        """
        # os.walk reports listing errors via its onerror callback (none here)
        # and then yields nothing, so a bare next() would raise an opaque
        # StopIteration for a bad path.
        top = next(os.walk(self.path), None)
        if top is None:
            raise IOError("cannot list image directory: %r" % (self.path,))
        _, _, file_names = top
        # Directory listing order is filesystem-dependent; sort it so the
        # [begin:end] slice is reproducible.
        selected = sorted(file_names)[begin:end]
        self.img_paths = [os.path.join(self.path, name) for name in selected]

    def _parse(self, filename):
        """Load one image file and apply random crop + horizontal flip.

        Args:
            filename: Scalar ``tf.string`` tensor holding a file path.

        Returns:
            A ``(crop_size, crop_size, 3)`` image tensor.
        """
        # TF 1.x API names (tf.read_file / tf.random_crop); under TF 2 these
        # live at tf.io.read_file / tf.image.random_crop.
        img_raw = tf.read_file(filename)
        # channels=3 requests 3-channel output. NOTE(review): decode_image
        # yields a 4-D tensor for GIFs, and random_crop fails on images
        # smaller than crop_size; confirm the directory holds only suitably
        # sized still images.
        img_decoded = tf.image.decode_image(img_raw, channels=3)
        img_cropped = tf.random_crop(
            img_decoded, [self.crop_size, self.crop_size, 3])
        return tf.image.random_flip_left_right(img_cropped)