-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathimg_helpers.py
54 lines (42 loc) · 1.66 KB
/
img_helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import os
from PIL import Image
from glob import glob
import tensorflow as tf
def load_dataset(data_path, batch_size, scale_size, split=None, is_grayscale=False, seed=None):
    """Build a TF1 queue-based input pipeline that yields shuffled image batches.

    Args:
        data_path: Dataset root directory. Its basename selects the dataset;
            only 'CelebA' is accepted, and images are read from
            ``<data_path>/splits/<split>``.
        batch_size: Number of images per batch.
        scale_size: Side length of the square images produced by the final
            nearest-neighbour resize.
        split: Name of the sub-directory under ``splits``. A falsy value is
            rejected.
        is_grayscale: If true, the decoded RGB image is converted to a single
            channel.
        seed: Optional seed forwarded to the shuffling queue.

    Returns:
        A float tensor of shape
        ``[batch_size, scale_size, scale_size, channels]`` where ``channels``
        is 1 for grayscale and 3 otherwise. Values are only cast with
        ``tf.to_float``; no rescaling is applied here.

    Raises:
        Exception: If the dataset name is not recognised, or no split is given.
        IndexError: If the split directory contains no ``.jpg``/``.png`` files.
    """
    dataset_name = os.path.basename(data_path)
    if dataset_name not in ['CelebA']:
        raise Exception('[!] Caution! Unknown dataset name.')
    if not split:
        # Previously reported as "unknown dataset", which hid the real cause.
        raise Exception(
            '[!] Caution! A split name is required for dataset {}.'.format(dataset_name))
    data_path = os.path.join(data_path, 'splits', split)

    # Prefer jpg; fall back to png only when no jpg files exist.
    # Sorting makes the file order (and the image probed for its size below)
    # independent of the order in which the filesystem lists entries.
    paths = []
    image_ext = None
    for ext in ["jpg", "png"]:
        paths = sorted(glob("{}/*.{}".format(data_path, ext)))
        if paths:
            image_ext = ext
            break

    if not paths:
        # Same exception type that `paths[0]` used to raise, but with a
        # message that says what is wrong; raised before any graph ops exist.
        raise IndexError(
            '[!] No jpg or png images found in {}'.format(data_path))

    tf_decode = tf.image.decode_png if image_ext == 'png' else tf.image.decode_jpeg

    # The static shape is taken from the first file only.
    # NOTE(review): this assumes every image in the split has the same size.
    with Image.open(paths[0]) as img:
        w, h = img.size
    shape = [h, w, 1] if is_grayscale else [h, w, 3]

    # The filename queue is deliberately not shuffled (shuffle=False);
    # randomisation happens in shuffle_batch below.
    filename_queue = tf.train.string_input_producer(list(paths), shuffle=False, seed=seed)
    reader = tf.WholeFileReader()
    _, data = reader.read(filename_queue)
    image = tf_decode(data, channels=3)

    if is_grayscale:
        image = tf.image.rgb_to_grayscale(image)
    image.set_shape(shape)

    min_after_dequeue = 5000
    capacity = min_after_dequeue + 3 * batch_size

    # The seed has to be given to the op that actually shuffles; without it
    # the `seed` argument of this function had no effect on batch order.
    queue = tf.train.shuffle_batch(
        [image], batch_size=batch_size,
        num_threads=4, capacity=capacity,
        min_after_dequeue=min_after_dequeue, seed=seed,
        name='synthetic_inputs')

    if dataset_name in ['CelebA']:
        # Crop a 128x128 window at offset (row 50, column 25) before resizing.
        queue = tf.image.crop_to_bounding_box(queue, 50, 25, 128, 128)
    queue = tf.image.resize_nearest_neighbor(queue, [scale_size, scale_size])

    return tf.to_float(queue)