-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathIMDBAddImageNet.m
129 lines (112 loc) · 4.25 KB
/
IMDBAddImageNet.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
function [ imdb, select ] = IMDBAddImageNet( imdb, path, task_num, varargin )
% IMDBADDIMAGENET Add to imdb the ImageNet (ILSVRC 2012) dataset.
% Input:
%   IMDB      imdb struct to append to; imdb.images.{name,label,set,task}
%             are created (empty) if imdb has no 'images' field
%   PATH      struct generated by GETPATH()
%   TASK_NUM  the path number (see CNN_CUSTOMTRAIN) that this dataset corresponds to
% Options:
%   'partial'     0 (default) keeps every sample. For partial > 0, e.g. 0.3,
%                 only a random ceil(partial * #samples) of each split is kept
%                 (capped at #samples).
%   'trainval'    splits to add: 1 for train, 2 for val. Default [1 2].
%   'randstream'  RandStream used for the partial subsampling; when empty
%                 (default) the global random stream is used.
% Output:
%   IMDB      input imdb with the new images appended. Labels are stored as
%             a cell array of indices into PATH.files_imgtrain_synsets;
%             set is 1 for train and 2 for val; task is TASK_NUM.
%   SELECT    struct; for each requested split it has a field ('train' /
%             'val') holding the indices, into the full cached list of that
%             split, of the samples that were kept.
%
% The full per-split name/label lists are saved to PATH.path_imgIMDB.train
% and PATH.path_imgIMDB.val the first time they are built, and are loaded
% from those files on later calls.
%
% Authors: Zhizhong Li
%
% See the COPYING file.
opts.partial = 0;       % for >0 partial, e.g. 0.3, only include that much portion of # samples.
opts.trainval = [1 2];  % 1 for train, 2 for val. By default include train+val.
opts.randstream = [];   % use randstream if provided
opts = vl_argparse(opts, varargin);

n_synset = numel(path.files_imgtrain_synsets);
if ~isfield(imdb, 'images')
  imdb.images.name = [];
  imdb.images.label = [];
  imdb.images.set = [];
  imdb.images.task = [];
end
% Always define the second output, even if no split is requested.
select = struct();

% ---- training split ----
if ismember(1, opts.trainval)
  if exist(path.path_imgIMDB.train, 'file')
    imgimdb = load(path.path_imgIMDB.train);
    name = imgimdb.name;
    label = imgimdb.label;
  else
    name = cell(n_synset, 1);
    label = cell(n_synset, 1);
    % dir through all synset directories; an image's label is the index of
    % its directory in path.files_imgtrain_synsets
    for i = 1:n_synset
      synset_name = path.files_imgtrain_synsets{i};
      images_file = dir(fullfile(path.path_imgtrain, synset_name, '*.JPEG'));
      name{i} = strcat(synset_name, '/', {images_file.name});
      label{i} = i * ones(numel(name{i}), 1);
    end
    % concat into column cell arrays
    name = reshape([name{:}], [], 1);
    label = num2cell(cell2mat(label));
    % cache the full (not subsampled) lists
    save(path.path_imgIMDB.train, 'name', 'label');
  end
  img_set = ones(numel(label), 1);
  % only using part of it (identity selection when opts.partial == 0)
  select.train = partialIndices(numel(name), opts.partial, opts.randstream);
  name = name(select.train);
  label = label(select.train);
  img_set = img_set(select.train);
else
  name = {}; label = {}; img_set = [];
end

% ---- validation split ----
if ismember(2, opts.trainval)
  if exist(path.path_imgIMDB.val, 'file')
    imgimdb = load(path.path_imgIMDB.val);
    val_name = imgimdb.val_name;
    val_label = imgimdb.val_label;
  else
    % translate the official metadata order (i.e. validation label) to the
    % dir result order (i.e. Alexnet's order)
    meta = load(path.file_imgmeta);
    % numeric part of each WNID (drops the leading character), in metadata order
    wnids = cellfun(@(x) str2double(x(2:end)), {meta.synsets.WNID});
    % official ground-truth labels are used as indices into the metadata order
    val_officiallabel = load(path.file_imgvalgt);
    val_wnid = num2cell(reshape(wnids(val_officiallabel), [], 1));
    dirwnid = cellfun(@(x) str2double(x(2:end)), path.files_imgtrain_synsets, 'UniformOutput', false);
    wniddict = containers.Map(dirwnid, 1:numel(dirwnid));
    val_label = values(wniddict, val_wnid); % kept as a cell array, no cell2mat
    % then get the file names
    % NOTE(review): assumes dir() order matches the ground-truth file order — confirm
    val_files = dir(fullfile(path.path_imgval, '*.JPEG'));
    val_name = strcat(path.path_imgval_relative_to_imgtrain, {val_files.name}');
    assert(numel(val_name) == numel(val_label), ...
      'IMDBAddImageNet: %d validation images but %d validation labels.', ...
      numel(val_name), numel(val_label));
    % cache the full (not subsampled) lists
    save(path.path_imgIMDB.val, 'val_name', 'val_label');
  end
  val_set = 2 * ones(numel(val_label), 1);
  % only using part of it; sized by the VALIDATION count, not the train count
  select.val = partialIndices(numel(val_name), opts.partial, opts.randstream);
  val_name = val_name(select.val);
  val_label = val_label(select.val);
  val_set = val_set(select.val);
  name = [ name; val_name ];
  label = [ label; val_label ];
  img_set = [ img_set; val_set ];
end

imdb.images.name = [ imdb.images.name; name ];
imdb.images.label = [ imdb.images.label; label ];
imdb.images.set = [ imdb.images.set; img_set ];
imdb.images.task = [ imdb.images.task; task_num * ones(size(img_set, 1), 1) ];

function idx = partialIndices(n, partial, randstream)
% PARTIALINDICES Indices of the samples kept out of N samples.
%   A zero PARTIAL returns 1:N in order. Otherwise returns the first
%   min(N, ceil(N * PARTIAL)) entries of a random permutation of 1:N, drawn
%   from RANDSTREAM when it is non-empty and from the global stream otherwise.
if partial
  if isempty(randstream)
    idx = randperm(n);
  else
    idx = randperm(randstream, n);
  end
  idx = idx(1:min(n, ceil(n * partial)));
else
  idx = 1:n;
end