data_provider.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from slim.datasets import dataset_factory as datasets

slim = tf.contrib.slim

def provide_data(split_name, batch_size, dataset_dir,
                 dataset_name='imagenet', num_readers=1, num_threads=1,
                 patch_size=128):
  """Provides batches of image data for compression.

  Args:
    split_name: Either 'train' or 'validation'.
    batch_size: The number of images in each batch.
    dataset_dir: The directory where the data can be found. If `None`, use
      default.
    dataset_name: Name of the dataset.
    num_readers: Number of dataset readers.
    num_threads: Number of prefetching threads.
    patch_size: Size of the patch to extract from the image.

  Returns:
    images: A `Tensor` of size [batch_size, patch_size, patch_size, channels].
  """
  randomize = split_name == 'train'
  dataset = datasets.get_dataset(
      dataset_name, split_name, dataset_dir=dataset_dir)
  provider = slim.dataset_data_provider.DatasetDataProvider(
      dataset,
      num_readers=num_readers,
      common_queue_capacity=5 * batch_size,
      common_queue_min=batch_size,
      shuffle=randomize)
  [image] = provider.get(['image'])

  # Sample a patch of fixed size.
  patch = tf.image.resize_image_with_crop_or_pad(image, patch_size, patch_size)
  patch.shape.assert_is_compatible_with([patch_size, patch_size, 3])

  # Preprocess the images. Make the range lie in a strictly smaller range than
  # [-1, 1], so that network outputs aren't forced to the extreme ranges.
  patch = (tf.to_float(patch) - 128.0) / 142.0

  if randomize:
    image_batch = tf.train.shuffle_batch(
        [patch],
        batch_size=batch_size,
        num_threads=num_threads,
        capacity=5 * batch_size,
        min_after_dequeue=batch_size)
  else:
    image_batch = tf.train.batch(
        [patch],
        batch_size=batch_size,
        num_threads=1,  # no threads so it's deterministic
        capacity=5 * batch_size)

  return image_batch
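

# Illustrative usage sketch (not part of the original module): shows how the
# queue-based batch tensor returned by `provide_data` is typically consumed in
# TF1-style graph code. The dataset directory below and the helper name
# `_demo_provide_data` are assumptions for illustration only.
def _demo_provide_data(dataset_dir='/tmp/imagenet-tfrecords'):
  images = provide_data('train', batch_size=32, dataset_dir=dataset_dir,
                        dataset_name='imagenet', patch_size=128)
  with tf.Session() as sess:
    # The DatasetDataProvider and tf.train.shuffle_batch both rely on queue
    # runners, which must be started before the batch can be evaluated.
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    batch = sess.run(images)  # numpy array of shape [32, 128, 128, 3]
    coord.request_stop()
    coord.join(threads)
  return batch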


def float_image_to_uint8(image):
  """Convert float image in ~[-0.9, 0.9) to [0, 255] uint8.

  Args:
    image: An image tensor. Values should be in [-0.9, 0.9).

  Returns:
    Input image cast to uint8 and with integer values in [0, 255].
  """
  image = (image * 142.0) + 128.0
  return tf.cast(image, tf.uint8)
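

# Illustrative round-trip sketch (not part of the original module): checks that
# `float_image_to_uint8` inverts the (x - 128) / 142 normalization applied in
# `provide_data`. The constant values and the helper name are assumptions.
def _demo_round_trip():
  raw = tf.constant([[[0, 128, 255]]], dtype=tf.uint8)  # dummy 1x1x3 "image"
  normalized = (tf.to_float(raw) - 128.0) / 142.0       # roughly in [-0.91, 0.9)
  restored = float_image_to_uint8(normalized)           # back to uint8 in [0, 255]
  return normalized, restored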