forked from facebookarchive/fb.resnet.torch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path extract-features.lua
76 lines (60 loc) · 1.8 KB
/
extract-features.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
--
-- Copyright (c) 2016, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
-- extracts features from an image using a trained model
--
require 'torch'
require 'paths'
-- Validate the command line: arg[1] is the model checkpoint and arg[2..]
-- are the images to process. Every path (model included) must be reported
-- as an existing file by paths.filep; otherwise print a message to stderr
-- and exit with status 1.
local function fail(message)
   io.stderr:write(message)
   os.exit(1)
end

if #arg < 2 then
   fail('Usage: th extract-features.lua [MODEL] [FILE]...\n')
end

for argIndex = 1, #arg do
   local path = arg[argIndex]
   if not paths.filep(path) then
      fail('file not found: ' .. path .. '\n')
   end
end
require 'cudnn'
require 'cunn'
require 'image'
local t = require '../datasets/transforms'
-- Load the trained model from the first command-line argument. cudnn and
-- cunn are required before this point, presumably so that GPU modules stored
-- in the checkpoint can be deserialized.
local model = torch.load(arg[1])

-- Remove the final fully connected layer so that forward() returns the
-- activations of the layer preceding it. Check the model's structure first
-- so that a wrong checkpoint fails with a descriptive message rather than a
-- bare "assertion failed!" or an "attempt to get length of a nil value".
assert(model.modules and #model.modules > 0,
   'expected a container model with at least one module in ' .. arg[1])
local lastLayer = model:get(#model.modules)
assert(torch.type(lastLayer) == 'nn.Linear',
   'expected the last layer of the model to be nn.Linear, got '
      .. torch.type(lastLayer))
model:remove(#model.modules)

-- Switch to evaluation (inference) mode before extracting features.
model:evaluate()
-- Per-channel (R, G, B) statistics the model was trained with; inputs must
-- be normalized the same way for the features to be meaningful.
local CHANNEL_MEAN = { 0.485, 0.456, 0.406 }
local CHANNEL_STD = { 0.229, 0.224, 0.225 }
local meanstd = { mean = CHANNEL_MEAN, std = CHANNEL_STD }

-- Preprocessing pipeline applied to every image, in order: scale (256),
-- color-normalize with the statistics above, then center-crop (224).
-- See datasets/transforms for the exact semantics of each step.
local transform = t.Compose({
   t.Scale(256),
   t.ColorNormalize(meanstd),
   t.CenterCrop(224),
})
-- Run every image named on the command line (arg[2] .. arg[#arg]) through
-- the truncated model and store one flattened feature vector per image as
-- a row of a FloatTensor, which is then saved to features.t7.

-- Lua 5.1 / LuaJIT expose only the global `unpack`; Lua 5.2+ moved it to
-- `table.unpack`. Fall back so the script works on either runtime.
local unpack = table.unpack or unpack

local numImages = #arg - 1
local features
for i = 2, #arg do
   -- load the image as a RGB float tensor with values 0..1
   local img = image.load(arg[i], 3, 'float')
   -- Scale, normalize, and crop the image
   img = transform(img)
   -- View as mini-batch of size 1
   img = img:view(1, unpack(img:size():totable()))
   -- Get the output of the layer before the (removed) fully connected layer
   local output = model:forward(img:cuda()):squeeze(1)
   if not features then
      -- Size each row by the total element count rather than the first
      -- dimension: for a 1-D output these are equal, and a multi-dimensional
      -- activation is then stored flattened instead of failing the copy.
      features = torch.FloatTensor(numImages, output:nElement()):zero()
   end
   -- copy() only requires matching element counts; it also moves the
   -- (presumably CUDA) output into the CPU FloatTensor row.
   features[i - 1]:copy(output)
end
torch.save('features.t7', features)
print('saved features to features.t7')