-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtrain_functions.lua
99 lines (85 loc) · 3.37 KB
/
train_functions.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
-- Announce through the project logger that this file is being loaded.
log("Loading Train Functions ... ")
--- Run the training loop for `config.nIter` iterations.
--
-- The animation features are computed once from `model.animationNN` and
-- reused for every batch. Each iteration:
--   1. loads one training batch via `GetAUniformImageBatch` and logs how long
--      that took (measured with `os.clock`, i.e. CPU time),
--   2. sets `model.learningRate` from `model:LearningRateComp(iter)` and runs
--      `model:TrainOneBatch`,
--   3. every 10 iterations forces a garbage collection and prints its duration,
--   4. every `config.nDisplay` iterations logs the training accuracy and loss,
--   5. every `config.nEval` iterations evaluates one test batch and logs the
--      result,
--   6. every `config.saveModelIter` iterations saves the model to
--      `config.logDirectory` as `Model_iter_<iter>.t7`.
--
-- Side effects: sets `config.testing = false` and writes the current
-- iteration number to the global `tt`.
function train()
  config.testing = false

  local nPerBatch = config.batchSize
  local animationFeats = GetAnimationFeatures(model.animationNN)

  for iter = 1, config.nIter do
    -- `tt` is a global; it is only written here, so it is presumably read by
    -- code outside this function.
    tt = iter

    ---- load one training batch and time the load
    local loadStart = os.clock()
    local imageFeats, trainTarget = GetAUniformImageBatch(nPerBatch, {
      viewpoint = true,
      test = false,
      spline = false,
    })
    local trainInput = { imageFeats, animationFeats }
    local loadElapsed = os.clock() - loadStart
    log('loading time :' .. tostring(loadElapsed))

    ---- one optimisation step
    model.learningRate = model:LearningRateComp(iter)
    local trainAcc, trainLoss = model:TrainOneBatch(trainInput, trainTarget)

    ---- periodic full garbage collection, timed
    if (iter % 10) == 0 then
      local gcStart = os.clock()
      collectgarbage()
      local gcElapsed = os.clock() - gcStart
      print("garbage collection :", gcElapsed)
    end

    ---- periodic progress report
    if (iter % config.nDisplay) == 0 then
      log(string.format('Iter = %d | Train Accuracy = %f | Train Loss = %f\n',
        iter, trainAcc, trainLoss))
    end

    ---- periodic evaluation on a single test batch
    if (iter % config.nEval) == 0 then
      -- NOTE(review): unlike the training input, the test input is passed to
      -- the model as returned by the loader, without the animation features
      -- alongside it — confirm EvaluateOneBatch expects that.
      local testInput, testTarget = GetAUniformImageBatch(nPerBatch, {
        viewpoint = true,
        test = true,
        spline = false,
      })
      local testAcc, testLoss = model:EvaluateOneBatch(testInput, testTarget)
      log(string.format('Testing ---------> Iter = %d | Test Accuracy = %f | Test Loss = %f\n',
        iter, testAcc, testLoss))
    end

    ---- periodic checkpoint
    if (iter % config.saveModelIter) == 0 then
      local checkpointName = 'Model_iter_' .. iter .. '.t7'
      local checkpointPath = paths.concat(config.logDirectory, checkpointName)
      log('Saving NN model in ----> ' .. checkpointPath .. '\n')
      model:SaveModel(checkpointPath)
    end
  end
end
---------------------------------------------------------
--- Evaluate the model on `config.nIter` deterministic test batches.
--
-- Each iteration loads one batch via `GetAnImageBatch` (with `test = true`
-- and `deterministic = true`), logs the load time (measured with `os.clock`,
-- i.e. CPU time), runs `model:EvaluateOneBatch`, and updates:
--   * a running mean of the per-batch accuracy,
--   * a cumulative per-class tensor of size `config.nCategories x 2`
--     (presumably per-class correct/total counts — confirm against
--     `EvaluateOneBatch`),
--   * the list of per-batch `[target | prediction]` tensors.
-- Every 10 iterations a full garbage collection is forced and its duration
-- printed.
--
-- Side effects: sets `config.testing = true` and writes the current
-- iteration number to the global `tt`.
--
-- @return meanAcc          running mean of batch accuracies (0 if no batches)
-- @return per_class_cum    sum of the per-class tensors of every batch
-- @return all_predictions  per-batch `torch.cat(target, predictions, 2)`
--                          results stacked along dimension 1, or nil when
--                          `config.nIter` is 0
function test()
  config.testing = true
  ----------------------------
  local batchSize = config.batchSize
  local meanAcc = 0
  local per_class_cum = torch.Tensor(config.nCategories, 2):fill(0)
  -- Per-batch prediction tensors; concatenated once after the loop so the
  -- accumulated history is not re-copied on every iteration.
  local prediction_chunks = {}

  for iter = 1, config.nIter do
    -- `tt` is a global; it is only written here, so it is presumably read by
    -- code outside this function.
    tt = iter

    ---- load one batch
    local tic = os.clock()
    local TeInput, TeTarget = GetAnImageBatch(batchSize, {
      viewpoint = true,
      test = true,
      deterministic = true,
      spline = false,
    })
    local toc = os.clock() - tic
    log('loading time :' .. tostring(toc))

    if (iter % 10) == 0 then
      local gcTic = os.clock()
      collectgarbage()
      local gcToc = os.clock() - gcTic
      print("garbage collection :", gcToc)
    end

    -- The second result (loss) and any results after `predicts` are not used.
    local acc, _, per_class, predicts = model:EvaluateOneBatch(TeInput, TeTarget)

    -- Incremental mean: equivalent to averaging all batch accuracies so far.
    meanAcc = ((iter - 1) * meanAcc + acc) / iter
    per_class_cum = per_class_cum + per_class
    log(('Iter = %d | Current Test Accuracy = %f | Average Test Accuracy = %f\n'):format(iter, acc, meanAcc))

    -- Targets and predictions side by side (joined along dimension 2).
    prediction_chunks[#prediction_chunks + 1] = torch.cat(TeTarget, predicts, 2)
  end

  local all_predictions
  if #prediction_chunks > 0 then
    all_predictions = torch.cat(prediction_chunks, 1)
  end

  -- Previously these aggregates were computed and then dropped; returning
  -- them lets callers inspect or persist the evaluation results.
  return meanAcc, per_class_cum, all_predictions
end