forked from VainF/Torch-Pruning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_customized_layer.py
93 lines (74 loc) · 2.97 KB
/
test_customized_layer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import sys, os
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_pruning as tp
from typing import Sequence
############
# Customize your layer
#
class CustomizedLayer(nn.Module):
    """Toy layer with per-feature parameters, used to exercise custom pruning.

    The input is divided by its L2 norm along dim 1, multiplied element-wise
    by ``scale``, shifted by ``bias`` and finally passed through a square
    ``nn.Linear``. All three learnable pieces share the same width, so pruning
    one feature means pruning ``scale``, ``bias`` and both sides of ``fc``.

    Args:
        in_dim: Feature width; used as both the input and output size of the
            layer.
    """

    def __init__(self, in_dim):
        super().__init__()
        self.in_dim = in_dim
        # torch.Tensor(n) only allocates storage and leaves it uninitialised
        # (it may contain NaN/inf), which makes the layer's output
        # nondeterministic. Start from the identity affine transform instead.
        self.scale = nn.Parameter(torch.ones(self.in_dim))
        self.bias = nn.Parameter(torch.zeros(self.in_dim))
        self.fc = nn.Linear(self.in_dim, self.in_dim)

    def forward(self, x):
        """Normalise ``x`` over dim 1, apply scale/bias, then the linear layer.

        Assumes ``x`` is shaped (batch, in_dim), as in the example script --
        TODO confirm for other callers.
        No epsilon is added to the norm, so a row whose entries are all zero
        is divided by zero and yields NaN.
        """
        norm = x.pow(2).sum(dim=1, keepdim=True).sqrt()
        x = torch.div(x, norm)
        return self.fc(x * self.scale + self.bias)

    def __repr__(self):
        return "CustomizedLayer(in_dim=%d)" % (self.in_dim)
class FullyConnectedNet(nn.Module):
    """Small MLP: Linear -> ReLU -> CustomizedLayer -> Linear.

    Model taken from https://github.com/VainF/Torch-Pruning/issues/21

    Args:
        input_size: Number of input features fed to the first linear layer.
        num_classes: Number of outputs produced by the last linear layer.
        HIDDEN_UNITS: Width shared by ``fc1``'s output, the customized layer
            and ``fc2``'s input.
    """

    def __init__(self, input_size, num_classes, HIDDEN_UNITS):
        super().__init__()
        self.fc1 = nn.Linear(input_size, HIDDEN_UNITS)
        self.customized_layer = CustomizedLayer(HIDDEN_UNITS)
        self.fc2 = nn.Linear(HIDDEN_UNITS, num_classes)

    def forward(self, x):
        """Run the three stages in sequence and return the final output."""
        hidden = self.customized_layer(F.relu(self.fc1(x)))
        return self.fc2(hidden)
############################
# Implement your pruning function for the customized layer
# You should implement the following class functions:
# 1. prune_out_channels
# 2. prune_in_channels
# 3. get_out_channels
# 4. get_in_channels
class MyPruner(tp.pruner.BasePruningFunc):
    """Pruning function for ``CustomizedLayer``.

    ``CustomizedLayer`` has a single width (``in_dim``) shared by its input
    and output, so pruning "in" channels and pruning "out" channels are the
    same operation; the ``*_in_*`` methods are aliases of the ``*_out_*`` ones.
    """

    def prune_out_channels(self, layer: CustomizedLayer, idxs: Sequence[int]) -> nn.Module:
        """Remove the features listed in ``idxs`` from ``layer`` in place.

        Args:
            layer: The customized layer to shrink.
            idxs: Indices of the features to remove.

        Returns:
            The same ``layer`` object, with ``scale``, ``bias``, ``fc`` and
            ``in_dim`` updated.
        """
        keep_idxs = sorted(set(range(layer.in_dim)) - set(idxs))
        layer.scale = torch.nn.Parameter(layer.scale.data.clone()[keep_idxs])
        layer.bias = torch.nn.Parameter(layer.bias.data.clone()[keep_idxs])
        # fc is square (in_dim x in_dim), so both of its sides lose the same
        # features.
        tp.prune_linear_in_channels(layer.fc, idxs)
        tp.prune_linear_out_channels(layer.fc, idxs)
        # Derive the new width from what was actually kept, so it always
        # matches the length of scale/bias even if idxs contains duplicates.
        layer.in_dim = len(keep_idxs)
        return layer

    def get_out_channels(self, layer):
        """Return the current feature width of ``layer``."""
        # The width is stored on the layer being pruned; MyPruner itself
        # defines no ``in_dim``, so reading it from ``self`` would fail.
        return layer.in_dim

    # identical functions
    prune_in_channels = prune_out_channels
    get_in_channels = get_out_channels
# Example network (128 inputs, 10 classes, 256 hidden units) and an empty
# dependency graph to analyse it with.
model = FullyConnectedNet(128, 10, 256)
DG = tp.DependencyGraph()

# 1. Tell the graph how to prune CustomizedLayer instances.
my_pruner = MyPruner()
DG.register_customized_layer(CustomizedLayer, my_pruner)

# 2. Build the dependency graph from an example input.
DG.build_dependency(model, example_inputs=torch.randn(1, 128))

# 3. Request the pruning group rooted at fc1's output channels; idxs lists
#    the indices of the pruned filters.
pruning_group = DG.get_pruning_group(
    model.fc1,
    tp.prune_linear_out_channels,
    idxs=[0, 1, 6],
)
print(pruning_group)

# 4. Execute the group, i.e. actually prune the model.
pruning_group.prune()
print("The pruned model:\n", model)
print("Output: ", model(torch.randn(1, 128)).shape)

# Three of the 256 hidden units were removed, so every layer touching the
# hidden dimension must now report 253.
assert model.fc1.out_features == 253
assert model.customized_layer.in_dim == 253
assert model.fc2.in_features == 253