# model_hook.py
# -*- coding: utf-8 -*-
import time
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn

from module_madd import compute_module_madd


class CModelHook(object):
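    """Instrument a model for per-layer profiling.

    Registers statistics buffers on every leaf module and wraps each leaf
    module class's __call__ so that a single forward pass records input and
    output shapes, parameter count, inference memory, multiply-adds (MAdd)
    and wall-clock duration.
    """
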
    def __init__(self, model, input_size):
        assert isinstance(model, nn.Module)
        assert isinstance(input_size, (list, tuple))

        self._model = model
        self._input_size = input_size
        self._origin_call = dict()  # original __call__ of each leaf module class
        self._hook_model()

        # One profiling pass over a dummy input fills the buffers.
        x = torch.rand(1, *self._input_size)
        self._model.eval()
        self._model(x)

    @staticmethod
    def _register_buffer(module):
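        # Statistics buffers live only on leaf modules (no children);
        # container modules are skipped.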
        assert isinstance(module, nn.Module)
        if len(list(module.children())) > 0:
            return

        module.register_buffer('input_shape', torch.zeros(3).int())
        module.register_buffer('output_shape', torch.zeros(3).int())
        module.register_buffer('parameter_quantity', torch.zeros(1).long())
        module.register_buffer('inference_memory', torch.zeros(1).float())
        module.register_buffer('MAdd', torch.zeros(1).long())
        module.register_buffer('duration', torch.zeros(1).float())

    def _sub_module_call_hook(self):
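        # wrap_call runs in place of a leaf module's original __call__: it
        # times the forward pass and writes shapes, parameter count, memory
        # and MAdd estimates into the buffers set up by _register_buffer.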
        def wrap_call(module, *inputs, **kwargs):
            assert module.__class__ in self._origin_call

            # Time the original forward call.
            start = time.time()
            output = self._origin_call[module.__class__](module, *inputs, **kwargs)
            end = time.time()
            module.duration = torch.from_numpy(
                np.array([end - start], dtype=np.float32))

            # Record input/output shapes with the batch dimension stripped
            # (assumes the module returns a single tensor).
            module.input_shape = torch.from_numpy(
                np.array(inputs[0].size()[1:], dtype=np.int32))
            module.output_shape = torch.from_numpy(
                np.array(output.size()[1:], dtype=np.int32))

            # Count this module's own parameters.
            parameter_quantity = 0
            for name, p in module._parameters.items():
                parameter_quantity += (0 if p is None else torch.numel(p.data))
            module.parameter_quantity = torch.from_numpy(
                np.array([parameter_quantity], dtype=np.int64))

            # Memory for the output activation in MB (4 bytes per float32
            # element); parameter memory is deliberately excluded.
            inference_memory = 1
            for s in output.size()[1:]:
                inference_memory *= s
            inference_memory = inference_memory * 4 / (1024 ** 2)
            module.inference_memory = torch.from_numpy(
                np.array([inference_memory], dtype=np.float32))

            if len(inputs) == 1:
                madd = compute_module_madd(module, inputs[0], output)
            elif len(inputs) > 1:
                madd = compute_module_madd(module, inputs, output)
            else:  # no positional inputs; cannot estimate
                madd = 0
            module.MAdd = torch.from_numpy(
                np.array([madd], dtype=np.int64))
            return output
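
        # Patch __call__ at the *class* level, once per leaf module class;
        # the original is saved in _origin_call so wrap_call can delegate
        # to it (note this affects every instance of the patched class).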
        for module in self._model.modules():
            if len(list(module.children())) == 0 and module.__class__ not in self._origin_call:
                self._origin_call[module.__class__] = module.__class__.__call__
                module.__class__.__call__ = wrap_call

    def _hook_model(self):
        self._model.apply(self._register_buffer)
        self._sub_module_call_hook()

    @staticmethod
    def _retrieve_leaf_modules(model):
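        # A leaf module is one without children; only leaves carry the
        # statistics buffers, so only they are reported.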
        leaf_modules = []
        for name, m in model.named_modules():
            if len(list(m.children())) == 0:
                leaf_modules.append((name, m))
        return leaf_modules

    def retrieve_leaf_modules(self):
        return OrderedDict(self._retrieve_leaf_modules(self._model))


def main():
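    # Minimal usage sketch (not part of the original stub): assumes
    # torchvision is installed and that module_madd supports the layer
    # types in ResNet-18; any nn.Module works the same way.
    import torchvision.models as models

    hook = CModelHook(models.resnet18(), (3, 224, 224))
    for name, module in hook.retrieve_leaf_modules().items():
        print(name,
              int(module.parameter_quantity.item()),
              int(module.MAdd.item()),
              float(module.duration.item()))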


if __name__ == "__main__":
    main()