-
Notifications
You must be signed in to change notification settings - Fork 25
/
tests.py
121 lines (87 loc) · 3.64 KB
/
tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import unittest
import torchfile
import os.path
import sys
import numpy as np
# Native text type: ``str`` on Python 3, ``unicode`` on Python 2. Used to
# assert that values were NOT decoded into text.
unicode_type = str if sys.version_info[0] >= 3 else unicode
def make_filename(fn, directory=None):
    """Return the path of the test fixture *fn*.

    :param fn: file name inside the fixture directory.
    :param directory: directory to look in. When omitted, defaults to
        ``testfiles_x86_64`` resolved relative to this module instead of the
        current working directory, so the suite can be launched from anywhere.
    :returns: the joined path.
    """
    if directory is None:
        # NOTE(review): assumes testfiles_x86_64/ sits beside this file (the
        # old CWD-relative path only worked from the repo root) -- confirm.
        here = os.path.dirname(os.path.abspath(__file__))
        directory = os.path.join(here, 'testfiles_x86_64')
    return os.path.join(directory, fn)
def load(fn, **kwargs):
    """Load fixture *fn* with ``torchfile.load``, forwarding loader options."""
    path = make_filename(fn)
    return torchfile.load(path, **kwargs)
class TestBasics(unittest.TestCase):
    """Loading of simple tables, custom classes, tensors and functions."""

    def test_dict(self):
        # With default loader options the key is expected back as bytes.
        obj = load('hello=123.t7')
        self.assertEqual(dict(obj), {b'hello': 123})

    def test_custom_class(self):
        obj = load('custom_class.t7')
        self.assertEqual(obj.torch_typename(), b"Blah")

    def test_classnames_never_decoded(self):
        # The type name must not be decoded into text, whichever way the
        # utf8_decode_strings option is set.
        for decode in (True, False):
            obj = load('custom_class.t7', utf8_decode_strings=decode)
            self.assertNotIsInstance(
                obj.torch_typename(), unicode_type,
                msg='utf8_decode_strings=%r' % (decode,))

    def test_basic_tensors(self):
        f64 = load('doubletensor.t7')
        # assert_array_equal also checks the shape, so a result that merely
        # broadcasts against the expected array cannot pass by accident, and
        # a failure reports the mismatching elements.
        np.testing.assert_array_equal(
            f64,
            np.array([[1, 2, 3], [4, 5, 6.9]], dtype=np.float64))
        f32 = load('floattensor.t7')
        # Presumably single-precision data (per the fixture name), so the
        # sum is compared within a tolerance rather than exactly.
        self.assertAlmostEqual(f32.sum(), 12.97241666913, delta=1e-5)

    def test_function(self):
        func_with_upvals = load('function_upvals.t7')
        self.assertIsInstance(func_with_upvals, torchfile.LuaFunction)

    def test_dict_accessors(self):
        # Both item and attribute access must work; the key type follows
        # utf8_decode_strings (text when True, bytes when False).
        obj = load('hello=123.t7',
                   use_int_heuristic=True,
                   utf8_decode_strings=True)
        self.assertIsInstance(obj['hello'], int)
        self.assertIsInstance(obj.hello, int)
        obj = load('hello=123.t7',
                   use_int_heuristic=True,
                   utf8_decode_strings=False)
        self.assertIsInstance(obj[b'hello'], int)
        self.assertIsInstance(obj.hello, int)
class TestRecursiveObjects(unittest.TestCase):
    """Structures that contain references to themselves."""

    def test_recursive_class(self):
        # Attribute ``a`` of the loaded object refers back to the object.
        loaded = load('recursive_class.t7')
        self.assertEqual(loaded.a, loaded)

    def test_recursive_table(self):
        # The table has exactly one entry, and both its key and its value
        # are the table itself.
        table = load('recursive_kv_table.t7')
        (only_key,) = table.keys()
        self.assertEqual(only_key, table)
        self.assertEqual(table[only_key], table)
class TestTDS(unittest.TestCase):
    """Loading of the ``tds`` hash and vec fixtures."""

    def test_hash(self):
        table = load('tds_hash.t7')
        self.assertEqual(len(table), 3)
        for key, expected in ((1, 2), (10, 11)):
            self.assertEqual(table[key], expected)

    def test_vec(self):
        # A tds vec must come back as a list on its own, even with the list
        # heuristic switched off.
        loaded = load('tds_vec.t7', use_list_heuristic=False)
        self.assertEqual(loaded, [123, 456])
class TestHeuristics(unittest.TestCase):
    """Effect of the ``use_list_heuristic`` / ``use_int_heuristic`` options."""

    def test_list_heuristic(self):
        # list_table.t7 holds keys 1..4 (see the dict assertion below); with
        # the list heuristic on it is expected to load as a plain list.
        obj = load('list_table.t7', use_list_heuristic=True)
        self.assertEqual(obj, [b'hello', b'world', b'third item', 123])
        # With the heuristic off the same table stays a mapping.
        obj = load('list_table.t7',
                   use_list_heuristic=False,
                   use_int_heuristic=True)
        self.assertEqual(
            dict(obj),
            {1: b'hello', 2: b'world', 3: b'third item', 4: 123})

    def test_int_heuristic(self):
        obj = load('hello=123.t7', use_int_heuristic=True)
        self.assertIsInstance(obj[b'hello'], int)
        obj = load('hello=123.t7', use_int_heuristic=False)
        self.assertNotIsInstance(obj[b'hello'], int)
        obj = load('list_table.t7',
                   use_list_heuristic=False,
                   use_int_heuristic=False)
        # The keys still compare equal to 1..4 even though they are not ints.
        self.assertEqual(
            dict(obj),
            {1: b'hello', 2: b'world', 3: b'third item', 4: 123})
        # Check every key, not just whichever one happens to iterate first.
        for key in obj.keys():
            self.assertNotIsInstance(key, int)
if __name__ == '__main__':
    # Run every TestCase in this module when executed as a script.
    unittest.main(module='__main__')