-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcritical.lua
130 lines (88 loc) · 3.36 KB
/
critical.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
-- 1. Choose two feature vectors from the training points and binary-search between them to find
-- the critical point (where the predicted hard label changes). Each step passes the midpoint to
-- the featureTolabel function to obtain its soft and hard labels.
require 'image'
require 'cudnn'
require 'cunn'
local c = require 'trepl.colorize'
-- The classifier head is loaded from disk and prepared lazily, once, and then
-- reused: featureTolabel is called many times per pair of training points, and
-- re-deserializing the whole network on every call dominated the runtime.
local FEATURE_MODEL_PATH = "logs/vgg/trainedModel.net"
local cachedClassifier = nil

--- Map a feature vector to soft and hard class labels.
-- On first use, loads the network from FEATURE_MODEL_PATH, takes its module 54
-- (presumably the classifier head that consumes the extracted features -- TODO
-- confirm), appends an nn.SoftMax, moves it to the GPU and puts it in
-- evaluation mode. Later calls reuse that prepared module.
-- @param featureVector tensor whose elements are fed to the classifier as a
--   single-row batch (1 x nElement). NOTE(review): presumably a 512-element
--   CudaTensor, matching the stored training features -- confirm.
-- @return table `{softLabels, hardLabel}`: softLabels is a freshly allocated
--   (numClasses x 1) tensor of SoftMax outputs; hardLabel is the 1-based index
--   of its largest entry (the last such index when several entries tie).
function featureTolabel(featureVector)
  if cachedClassifier == nil then
    local model = torch.load(FEATURE_MODEL_PATH)
    -- model definition should set numInputDims
    -- hacking around it for the moment
    local views = model:findModules('nn.View')
    if #views > 0 then
      views[1].numInputDims = 3
    end
    local classifier = model:get(54)
    classifier:add(nn.SoftMax())
    classifier:cuda()
    classifier:evaluate()
    cachedClassifier = classifier
  end

  local numFeatures = featureVector:nElement()
  local scores = cachedClassifier:forward(featureVector:view(1, numFeatures))

  -- torch.reshape copies into a new tensor, so the returned soft labels do not
  -- alias the cached module's output buffer (which the next forward overwrites).
  local numClasses = scores:nElement()
  local softLabels = torch.reshape(scores, numClasses, 1)

  -- Hard label = index of the maximum score; keep the last index on ties.
  local maxValue = torch.max(softLabels)
  local hardLabel = 1
  for i = 1, numClasses do
    if softLabels[i][1] == maxValue then
      hardLabel = i
    end
  end

  return {softLabels, hardLabel}
end
------------------------------------------------------------------------------------------------
-- Script entry point: for every pair among the first `trainSize` training points whose hard
-- labels differ, bisect the segment between their feature vectors to locate the point where
-- the predicted label changes, then save those critical points with their soft labels.

local opt = lapp[[
--trainSize (default 100) size of training set
]]
if #arg < 1 then
  io.stderr:write('Usage: th critical.lua --trainSize <size of training set>\n')
  os.exit(1)
end

-- points is a table. Each row of the table has three components: featureTensor, softLabels,
-- hardLabel. Use points[i][1] to get the feature vector of the ith training point.
local points = torch.load('trainFeature_learntLabels.dat')

-- Never read past the end of the stored training set.
local length = math.min(opt.trainSize, #points)
print('length', length)

local maxIterations = 10

--- Bisect the segment between two feature vectors towards the boundary of class `labelA`.
-- After each step, the `featureA` end holds a point classified as `labelA` and the
-- `featureB` end holds a point classified otherwise. Runs exactly `iterations` halving steps.
-- The input tensors are never modified: every midpoint is a freshly allocated tensor.
-- @param featureA feature vector on the `labelA` side
-- @param featureB feature vector on the other side
-- @param labelA hard label of featureA
-- @param iterations number of halving steps
-- @return the last midpoint and its soft labels (both nil when iterations < 1)
local function bisectBoundary(featureA, featureB, labelA, iterations)
  local mid, midSoft
  for _ = 1, iterations do
    mid = (featureA + featureB):mul(0.5)
    -- featureTolabel returns {softLabels, hardLabel}; call it once per midpoint.
    local result = featureTolabel(mid)
    midSoft = result[1]
    if result[2] ~= labelA then
      featureB = mid
    else
      featureA = mid
    end
  end
  return mid, midSoft
end

local output = {}
local k = 0
print(c.blue '==>' ..' calculating critical points ')
for i = 1, length do
  local feature_x = points[i][1]
  local hardlabel_x = points[i][3]
  for j = i + 1, length do
    print(k)
    if hardlabel_x ~= points[j][3] then
      k = k + 1
      local criticalPoint, criticalSoft =
        bisectBoundary(feature_x, points[j][1], hardlabel_x, maxIterations)
      output[#output + 1] = {criticalPoint, criticalSoft}
    end
  end
end
print(c.blue '==>' ..' saving fature vectors of critical points')
torch.save ('criticalPoints_feature.dat', output)
print('finish saving')