assemble_data.py
#!/usr/bin/env python
"""
Form a subset of the Flickr Style data, download images to dirname, and write
Caffe ImagesDataLayer training file.
"""
import os
try:
    from urllib.request import urlretrieve  # Python 3
except ImportError:
    from urllib import urlretrieve  # Python 2
import hashlib
import argparse
import numpy as np
import pandas as pd
from skimage import io
import multiprocessing

# Flickr returns a special placeholder image if the requested photo is
# unavailable; its SHA-1 hash is used below to detect failed downloads.
MISSING_IMAGE_SHA1 = '6a92790b1c2a301c6e7ddef645dca1f53ea97ac2'

example_dirname = os.path.abspath(os.path.dirname(__file__))
training_dirname = '/home/imatge/datasets/flickr_style'

def download_image(args_tuple):
    "For use with multiprocessing map. Returns True on success, False on failure."
    try:
        url, filename = args_tuple
        if not os.path.exists(filename):
            urlretrieve(url, filename)
        # Hash the downloaded bytes to detect Flickr's "photo unavailable"
        # placeholder image.
        with open(filename, 'rb') as f:
            assert hashlib.sha1(f.read()).hexdigest() != MISSING_IMAGE_SHA1
        # Make sure the image actually decodes.
        test_read_image = io.imread(filename)
        return True
    except KeyboardInterrupt:
        raise Exception()  # multiprocessing doesn't catch keyboard exceptions
    except Exception:
        return False
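
# Example (hypothetical URL and path):
#   download_image(('http://example.com/photo.jpg', '/tmp/photo.jpg'))
# returns True on success, and False if the download failed, the placeholder
# image was returned, or the file could not be decoded.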

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Download a subset of Flickr Style to a directory')
    parser.add_argument(
        '-s', '--seed', type=int, default=0,
        help="random seed")
    parser.add_argument(
        '-i', '--images', type=int, default=-1,
        help="number of images to use (-1 for all [default])",
    )
    parser.add_argument(
        '-w', '--workers', type=int, default=-1,
        help="number of worker processes; a value <= 0 uses (cpu_count + value) "
             "workers, e.g. -1 uses all cores but one [default: -1]."
    )
    parser.add_argument(
        '-l', '--labels', type=int, default=0,
        help="if set to a positive value, only sample images from the first "
             "that many labels."
    )
    args = parser.parse_args()
    np.random.seed(args.seed)

    # Read data, shuffle order, and subsample.
    csv_filename = os.path.join(example_dirname, 'flickr_style.csv.gz')
    df = pd.read_csv(csv_filename, index_col=0, compression='gzip')
    df = df.iloc[np.random.permutation(df.shape[0])]
    if args.labels > 0:
        df = df.loc[df['label'] < args.labels]
    if args.images > 0 and args.images < df.shape[0]:
        df = df.iloc[:args.images]

    # Make directory for images and derive local filenames from the URLs.
    images_dirname = os.path.join(training_dirname, 'images')
    if not os.path.exists(images_dirname):
        os.makedirs(images_dirname)
    df['image_filename'] = [
        os.path.join(images_dirname, _.split('/')[-1]) for _ in df['image_url']
    ]

    # Download images in parallel.
    num_workers = args.workers
    if num_workers <= 0:
        # Non-positive values mean "all cores plus that (negative) number".
        num_workers = multiprocessing.cpu_count() + num_workers
    print('Downloading {} images with {} workers...'.format(
        df.shape[0], num_workers))
    pool = multiprocessing.Pool(processes=num_workers)
    map_args = zip(df['image_url'], df['image_filename'])
    results = pool.map(download_image, map_args)

    # Only keep rows with valid images, and write out training file lists.
    df = df[results]
    for split in ['train', 'test']:
        split_df = df[df['_split'] == split]
        filename = os.path.join(training_dirname, '{}.txt'.format(split))
        split_df[['image_filename', 'label']].to_csv(
            filename, sep=' ', header=False, index=False)
    print('Writing train/test for {} successfully downloaded images.'.format(
        df.shape[0]))
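
# After a run, training_dirname (the hardcoded /home/imatge/datasets/flickr_style
# above) is expected to contain:
#   images/     the downloaded JPEGs
#   train.txt   space-separated "<image_filename> <label>" lines, the format
#   test.txt    the docstring refers to as Caffe's ImagesDataLayer input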