-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathDataProvider2D.lua
executable file
·126 lines (91 loc) · 3.72 KB
/
DataProvider2D.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
require 'nn'
require 'torchx'
require 'imtools'
--- Prototype table for 2D image data providers.
-- It starts out empty. DataProvider.create turns a data table (with `train`,
-- `test`, `image_paths`, `classes` and `opts` fields) into an instance by
-- setting this table as its `__index`, so the methods below are shared.
local DataProvider = {}
--- Build (or load from cache) a 2D image data provider.
-- If `opts.parent_dir .. '/data.t7'` exists it is loaded with torch.load.
-- Otherwise every PNG in each subdirectory of `image_dir` is loaded via
-- imtools.load_img, labelled by its parent directory name, one-hot encoded,
-- randomly split into train and test (about 5% test, at least one image),
-- and saved to that file.
-- @tparam string image_dir directory whose subdirectories each hold one class
-- @tparam table opts options table. It reads:
--   `parent_dir` (cache directory),
--   `image_sub_size` (forwarded to imtools.load_img),
--   `channel_inds_in` / `channel_inds_out` (channel-index tensors, default LongTensor{1}),
--   `rotate` (default false)
-- @treturn table provider with `train`/`test` (each holding `inds`, `images`,
--   `labels`), `image_paths`, `classes` and `opts`
function DataProvider.create(image_dir, opts)
   local save_dir = opts.parent_dir
   local data_file = save_dir .. '/data.t7'

   local data
   if paths.filep(data_file) then
      -- NOTE(review): the cache is reused as-is; image_dir and
      -- opts.image_sub_size have no effect on this path
      print('Loading data from ' .. save_dir)
      data = torch.load(data_file)
      print('Done')
   else
      data = {}
      local c = 0
      local images, image_paths, classes = {}, {}, {}
      for dir in paths.iterdirs(image_dir) do
         print('Loading images from ' .. image_dir .. '/' .. dir)
         local images_tmp, image_paths_tmp = imtools.load_img(image_dir .. '/' .. dir .. '/', 'png', opts.image_sub_size)
         for i = 1, #image_paths_tmp do
            c = c + 1
            -- prepend a singleton dim so all images can be concatenated along dim 1
            images[c] = nn.Unsqueeze(1):forward(images_tmp[i])
            image_paths[c] = image_paths_tmp[i]
            -- the class name is the image's parent directory
            -- NOTE(review): `utils` is not required in this file; assumed to be a global provided elsewhere
            local tokens = utils.split(image_paths_tmp[i], '/')
            classes[c] = tokens[#tokens - 1]
         end
      end
      -- fail with a clear message rather than inside torch.concat or the split below
      assert(c >= 2, 'DataProvider.create: need at least 2 images under ' .. image_dir .. ', found ' .. c)

      images = torch.concat(images, 1)
      images = torch.FloatTensor(images:size()):copy(images)

      local labels
      classes, labels = utils.unique(classes)
      local nImgs = images:size(1)
      local nClasses = torch.max(labels)
      local one_hot = torch.zeros(nImgs, nClasses):long()
      for i = 1, nImgs do
         one_hot[{i, labels[i]}] = 1
      end

      -- save ~5% of the data for testing (rounded half up, but at least one
      -- image so the {1, nTest} range below is never empty)
      local rand_inds = torch.randperm(nImgs):long()
      local nTest = math.max(1, math.floor(nImgs / 20 + 0.5))
      local train_inds = rand_inds[{{nTest + 1, -1}}]
      local test_inds = rand_inds[{{1, nTest}}]
      data.train = {
         inds = train_inds,
         images = images:index(1, train_inds),
         labels = one_hot:index(1, train_inds),
      }
      data.test = {
         inds = test_inds,
         images = images:index(1, test_inds),
         labels = one_hot:index(1, test_inds),
      }
      data.image_paths = image_paths
      data.classes = classes
      paths.mkdir(save_dir)
      torch.save(data_file, data)
   end

   -- the (possibly cached) data table itself becomes the provider instance
   local self = data
   self.opts = {}
   -- copy the channel-index tensors so the provider never aliases the caller's.
   -- Check for nil before calling :clone() so the default is actually reachable.
   self.opts.channel_inds_in = opts.channel_inds_in and opts.channel_inds_in:clone() or torch.LongTensor{1}
   self.opts.channel_inds_out = opts.channel_inds_out and opts.channel_inds_out:clone() or torch.LongTensor{1}
   self.opts.rotate = opts.rotate or false
   setmetatable(self, { __index = DataProvider })
   return self
end
--- Fetch a batch of input and output images from one split.
-- Selects rows `indices` of `self[train_or_test].images`, then returns the
-- channels listed in `opts.channel_inds_in` and in `opts.channel_inds_out`
-- as two separate tensors.
-- When `opts.rotate` is set, with probability 0.01 per call the whole batch
-- is augmented. Each image gets a horizontal flip with probability 0.5 and a
-- uniformly random rotation in [0, 2*pi). The same transform is applied to
-- its input and output versions.
-- @tparam torch.LongTensor indices row indices into the split
-- @tparam string train_or_test split name, 'train' or 'test'
-- @treturn torch.Tensor images_in
-- @treturn torch.Tensor images_out
function DataProvider:getImages(indices, train_or_test)
   -- index() returns a freshly allocated tensor, so select the rows once and
   -- take both channel subsets from it; no extra clone() is needed
   local batch = self[train_or_test].images:index(1, indices)
   local images_in = batch:index(2, self.opts.channel_inds_in)
   local images_out = batch:index(2, self.opts.channel_inds_out)

   -- NOTE(review): augmentation is gated once per batch at 1% probability;
   -- confirm this is intended rather than a per-image probability
   if self.opts.rotate and torch.rand(1)[1] < 0.01 then
      -- NOTE(review): `image` is not required in this file; assumed loaded elsewhere
      for i = 1, images_in:size(1) do
         local rad = (torch.rand(1) * 2 * math.pi)[1]
         local flip = torch.rand(1)[1] > 0.5
         if flip then
            images_in[i] = image.hflip(images_in[i])
            images_out[i] = image.hflip(images_out[i])
         end
         images_in[i] = image.rotate(images_in[i], rad)
         images_out[i] = image.rotate(images_out[i], rad)
      end
   end
   return images_in, images_out
end
--- Return the label rows `indices` of one split.
-- The rows come from `self[train_or_test].labels`, which DataProvider.create
-- builds as one-hot rows. They are cast to the tensor type of the test
-- split's labels.
-- @tparam torch.LongTensor indices row indices into the split
-- @tparam string train_or_test split name, 'train' or 'test'
-- @treturn torch.Tensor selected label rows
function DataProvider:getLabels(indices, train_or_test)
   local split_labels = self[train_or_test].labels
   local selected = split_labels:index(1, indices)
   return selected:typeAs(self.test.labels)
end
return DataProvider