-
Notifications
You must be signed in to change notification settings - Fork 160
/
test.lua
138 lines (111 loc) · 4.22 KB
/
test.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
-- Unit tests for nn.AffineGridGeneratorBHWD and nn.BilinearSamplerBHWD.
-- You can run specific tests by name like this:
-- th -lstn -e "stn.test{'AffineGridGeneratorBHWD_batch'}"
-- th -lstn -e "stn.test{'BilinearSamplerBHWD_batch', 'BilinearSamplerBHWD_single'}"
local mytester = torch.Tester()
-- nn.Jacobian / nn.SparseJacobian; both are assigned at the bottom of this
-- file once nn is loaded (sjac is not used by any test yet).
local jac
local sjac
-- largest error from jac.testJacobian that a test accepts
local precision = 1e-5
-- table of test functions, registered with mytester at the bottom of the file
local stntest = {}
--- Checks nn.AffineGridGeneratorBHWD on a batched input.
-- The input is a uniform-random tensor of shape (nframes, 2, 3); nframes and
-- the grid height/width passed to the constructor are drawn at random.
-- Asserts that the jac.testJacobian error stays below `precision` and that
-- jac.testIO reports exactly zero forward and backward error.
function stntest.AffineGridGeneratorBHWD_batch()
   local batchSize = torch.random(2, 10)
   local gridH = torch.random(2, 5)
   local gridW = torch.random(2, 5)
   -- (batch, 2, 3); presumably one affine matrix per sample -- TODO confirm
   local theta = torch.zeros(batchSize, 2, 3):uniform()
   local generator = nn.AffineGridGeneratorBHWD(gridH, gridW)

   local jacErr = jac.testJacobian(generator, theta)
   mytester:assertlt(jacErr, precision, 'error on state ')

   -- jac.testIO: forward and backward errors must both be exactly zero
   local fwdErr, bwdErr = jac.testIO(generator, theta)
   local prefix = torch.typename(generator)
   mytester:asserteq(fwdErr, 0, prefix .. ' - i/o forward err ')
   mytester:asserteq(bwdErr, 0, prefix .. ' - i/o backward err ')
end
--- Checks nn.AffineGridGeneratorBHWD on a single, unbatched input.
-- The input is a uniform-random tensor of shape (2, 3); the grid height and
-- width passed to the constructor are drawn at random.
-- Asserts that the jac.testJacobian error stays below `precision` and that
-- jac.testIO reports exactly zero forward and backward error.
function stntest.AffineGridGeneratorBHWD_single()
   local gridH = torch.random(2, 5)
   local gridW = torch.random(2, 5)
   -- (2, 3) with no batch dimension; presumably one affine matrix -- TODO confirm
   local theta = torch.zeros(2, 3):uniform()
   local generator = nn.AffineGridGeneratorBHWD(gridH, gridW)

   local jacErr = jac.testJacobian(generator, theta)
   mytester:assertlt(jacErr, precision, 'error on state ')

   -- jac.testIO: forward and backward errors must both be exactly zero
   local fwdErr, bwdErr = jac.testIO(generator, theta)
   local prefix = torch.typename(generator)
   mytester:asserteq(fwdErr, 0, prefix .. ' - i/o forward err ')
   mytester:asserteq(bwdErr, 0, prefix .. ' - i/o backward err ')
end
--- Jacobian checks for nn.BilinearSamplerBHWD on a batched input.
-- The module is fed a table {images, grids}. Images have shape
-- (nframes, height, width, channels) and grids have shape
-- (nframes, height, width, 2); all sizes are drawn at random.
-- jac.testJacobian is handed one tensor at a time. To allow that, this
-- instance's updateOutput/updateGradInput are overridden so that the tested
-- tensor fills one table entry while the other entry is held fixed.
function stntest.BilinearSamplerBHWD_batch()
   local nframes = torch.random(2, 10)
   local height = torch.random(1, 5)
   local width = torch.random(1, 5)
   local channels = torch.random(1, 6)
   local inputImages = torch.zeros(nframes, height, width, channels):uniform()
   local grids = torch.zeros(nframes, height, width, 2):uniform()
   local sampler = nn.BilinearSamplerBHWD()

   -- keep the class implementations reachable under private names
   sampler._updateOutput = sampler.updateOutput
   sampler._updateGradInput = sampler.updateGradInput

   -- Make `sampler` accept a single tensor that becomes entry `slot` of the
   -- {images, grids} table, with `fixed` filling the other entry; the
   -- returned gradient is the one for that same entry.
   local function bindSlot(slot, fixed)
      local function pack(x)
         if slot == 1 then
            return {x, fixed}
         end
         return {fixed, x}
      end
      function sampler:updateOutput(input)
         return self:_updateOutput(pack(input))
      end
      function sampler:updateGradInput(input, gradOutput)
         self:_updateGradInput(pack(input), gradOutput)
         return self.gradInput[slot]
      end
   end

   -- w.r.t. the images (first table entry), grids held fixed
   bindSlot(1, grids)
   local imagesErr = jac.testJacobian(sampler, inputImages)
   mytester:assertlt(imagesErr, precision, 'error on state ')

   -- w.r.t. the grids (second table entry), images held fixed
   bindSlot(2, inputImages)
   local gridsErr = jac.testJacobian(sampler, grids)
   mytester:assertlt(gridsErr, precision, 'error on state ')
end
--- Jacobian checks for nn.BilinearSamplerBHWD on a single, unbatched input.
-- The module is fed a table {images, grids}. Images have shape
-- (height, width, channels) and grids have shape (height, width, 2); all
-- sizes are drawn at random.
-- jac.testJacobian is handed one tensor at a time. To allow that, this
-- instance's updateOutput/updateGradInput are overridden so that the tested
-- tensor fills one table entry while the other entry is held fixed.
function stntest.BilinearSamplerBHWD_single()
   local height = torch.random(1, 5)
   local width = torch.random(1, 5)
   local channels = torch.random(1, 6)
   local inputImages = torch.zeros(height, width, channels):uniform()
   local grids = torch.zeros(height, width, 2):uniform()
   local sampler = nn.BilinearSamplerBHWD()

   -- keep the class implementations reachable under private names
   sampler._updateOutput = sampler.updateOutput
   sampler._updateGradInput = sampler.updateGradInput

   -- Make `sampler` accept a single tensor that becomes entry `slot` of the
   -- {images, grids} table, with `fixed` filling the other entry; the
   -- returned gradient is the one for that same entry.
   local function bindSlot(slot, fixed)
      local function pack(x)
         if slot == 1 then
            return {x, fixed}
         end
         return {fixed, x}
      end
      function sampler:updateOutput(input)
         return self:_updateOutput(pack(input))
      end
      function sampler:updateGradInput(input, gradOutput)
         self:_updateGradInput(pack(input), gradOutput)
         return self.gradInput[slot]
      end
   end

   -- w.r.t. the images (first table entry), grids held fixed
   bindSlot(1, grids)
   local imagesErr = jac.testJacobian(sampler, inputImages)
   mytester:assertlt(imagesErr, precision, 'error on state ')

   -- w.r.t. the grids (second table entry), images held fixed
   bindSlot(2, inputImages)
   local gridsErr = jac.testJacobian(sampler, grids)
   mytester:assertlt(gridsErr, precision, 'error on state ')
end
mytester:add(stntest)

-- If nn is not loaded yet, this file is presumably being run directly rather
-- than required from the package: load nn and run every test immediately.
local standalone = not nn
if standalone then
   require 'nn'
end
jac = nn.Jacobian
sjac = nn.SparseJacobian

if standalone then
   mytester:run()
else
   --- Run the tests registered via mytester:add(stntest).
   -- @param tests optional test name(s) forwarded to mytester:run
   --   (e.g. {'AffineGridGeneratorBHWD_batch'}); nil runs all tests.
   -- @param seed optional number used to seed both math.random and torch's
   --   default generator; defaults to os.time(). It is printed so that a
   --   failing run can be reproduced by passing the same seed back in.
   -- @return the torch.Tester instance used to run the tests
   function stn.test(tests, seed)
      seed = seed or os.time()
      print('stn.test seed: ' .. seed)
      math.randomseed(seed)
      -- The tests draw their sizes and data through torch.random and
      -- Tensor:uniform, which use torch's generator, not math.random,
      -- so that generator must be seeded explicitly as well.
      torch.manualSeed(seed)
      mytester:run(tests)
      return mytester
   end
end