-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprune_utils.py
87 lines (73 loc) · 2.85 KB
/
prune_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from pruners import *
from models import masked_layers
from tqdm import tqdm
import torch
import numpy as np
def load_pruner(method):
    r"""Resolve a pruning-method name to its pruner class.

    Args:
        method: one of ``'rand'``, ``'mag'``, ``'snip'``, ``'grasp'``,
            ``'synflow'``.

    Returns:
        The pruner class registered under ``method``.

    Raises:
        KeyError: if ``method`` is not a recognized pruning method.
    """
    if method == 'rand':
        return Rand
    elif method == 'mag':
        return Mag
    elif method == 'snip':
        return SNIP
    elif method == 'grasp':
        return GraSP
    elif method == 'synflow':
        return SynFlow
    # Same failure mode as the original dict lookup.
    raise KeyError(method)
def masks(module):
    r"""Yield every mask buffer registered on *module*.

    A buffer counts as a mask when its registered name contains the
    substring ``"mask"`` (e.g. ``weight_mask``, ``bias_mask``).
    """
    return (buf for name, buf in module.named_buffers() if "mask" in name)
def prunable(module, batchnorm):
    r"""Return True when *module* is a layer type that supports pruning.

    Linear and Conv2d masked layers are always prunable; the masked
    BatchNorm variants count only when ``batchnorm`` is True.
    """
    candidates = (masked_layers.Linear, masked_layers.Conv2d)
    if batchnorm:
        candidates = candidates + (masked_layers.BatchNorm1d, masked_layers.BatchNorm2d)
    return isinstance(module, candidates)
def masked_parameters(model, bias=False, batchnorm=False):
    r"""Yield ``(mask, parameter)`` pairs for every prunable layer in *model*.

    Bias parameters are skipped unless ``bias`` is True; batchnorm layers
    are considered only when ``batchnorm`` is True. Masks are paired with
    parameters positionally within each module.
    """
    for module in model.modules():
        if not prunable(module, batchnorm):
            continue
        for mask, param in zip(masks(module), module.parameters(recurse=False)):
            # Equivalent to: param is not module.bias or bias is True
            if bias is True or param is not module.bias:
                yield mask, param
def prune_loop(model, loss, pruner, dataloader, device, sparsity, schedule, scope, epochs,
               reinitialize=False, train_mode=False, shuffle=False, invert=False, return_stats=False, set_pruned_params_to_zero=False):
    r"""Applies score mask loop iteratively to a final sparsity level.

    Args:
        model: network whose prunable layers carry mask buffers.
        loss: loss function handed to the pruner's scoring pass.
        pruner: pruner instance (see ``load_pruner``).
        dataloader: data used by score-based pruners.
        device: device on which scoring runs.
        sparsity: target fraction of parameters to keep (density).
        schedule: ``'exponential'`` or ``'linear'`` sparsity schedule.
        scope: masking scope forwarded to ``pruner.mask``.
        epochs: number of iterative prune steps.
        reinitialize: re-run the model's weight initialization after pruning.
        train_mode: score with the model in train mode instead of eval mode.
        shuffle: randomly shuffle the computed masks.
        invert: invert the scores before masking.
        return_stats: return ``pruner.meta_stats()`` instead of final counts.
        set_pruned_params_to_zero: also zero the pruned weight values.

    Returns:
        ``(remaining_params, total_params)`` — from ``pruner.meta_stats()``
        when ``return_stats`` is True, otherwise the final counts from
        ``pruner.stats()`` with ``remaining_params`` cast to int.

    Raises:
        ValueError: if ``schedule`` is not a recognized schedule name.
        SystemExit: if the achieved sparsity deviates from the target by
            30 or more parameters.
    """
    # Set model to train or eval mode
    model.train()
    if not train_mode:
        model.eval()
    # Prune model
    for epoch in tqdm(range(epochs)):
        pruner.score(model, loss, dataloader, device)
        if schedule == 'exponential':
            sparse = sparsity**((epoch + 1) / epochs)
        elif schedule == 'linear':
            sparse = 1.0 - (1.0 - sparsity)*((epoch + 1) / epochs)
        else:
            # Previously an unknown schedule fell through and crashed later
            # with a NameError on `sparse`; fail fast with a clear message.
            raise ValueError("Unknown schedule '{}'; expected 'exponential' or 'linear'".format(schedule))
        # Invert scores
        if invert:
            pruner.invert()
        pruner.mask(sparse, scope)
        # Reinitialize weights
        if reinitialize:
            model._initialize_weights()
        # Shuffle masks
        if shuffle:
            pruner.shuffle()
        # Set pruned params to zero
        if set_pruned_params_to_zero:
            pruner.set_zeros()
    # Confirm sparsity level
    remaining_params, total_params = pruner.stats()
    print('Prune at {}% sparsity level: Total Params: {}; Remaining Params: {}'.format(round((1 - sparsity)*100), total_params, int(remaining_params)))
    if np.abs(remaining_params - total_params*sparsity) >= 30:
        print("ERROR: {} prunable parameters remaining, expected {}".format(remaining_params, total_params*sparsity))
        # `quit()` only exists when the `site` module is loaded and exits
        # with status 0; raise SystemExit(1) so the failure is reported as
        # an error in every interpreter configuration.
        raise SystemExit(1)
    if return_stats:
        remaining_params, total_params = pruner.meta_stats()
        return remaining_params, total_params
    return int(remaining_params), total_params