[Bug] AutoSlim: different checkpoints have the same size (#193)
* fix: checkpoints split from the AutoSlim supernet all have the same model size

* chore: pre-commit

* chore: pre-commit

Co-authored-by: Lance(Yongle) Wang <lance.wang@vastaitech.com>
Hiwyl and Lance(Yongle) Wang authored Jul 5, 2022
1 parent 3cc359e commit 1abad08
Showing 1 changed file with 30 additions and 21 deletions.
mmrazor/models/pruners/structure_pruning.py (30 additions, 21 deletions)
```diff
@@ -126,10 +126,11 @@ def prepare_from_supernet(self, supernet):
         for name, module in supernet.model.named_modules():
             if isinstance(module, nn.GroupNorm):
                 min_required_version = '1.6.0'
-                assert digit_version(torch.__version__) >= digit_version(
-                    min_required_version
-                ), f'Requires pytorch>={min_required_version} to auto-trace' \
-                    f'GroupNorm correctly.'
+                assert digit_version(
+                    torch.__version__
+                ) >= digit_version(min_required_version), (
+                    f'Requires pytorch>={min_required_version} to auto-trace '
+                    f'GroupNorm correctly.')
             if hasattr(module, 'weight'):
                 # trace shared modules
                 module.cnt = 0
```
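The gate exists because string comparison of version numbers is unreliable. mmcv's `digit_version` (which this file already uses) converts a version string into a tuple of integers so the comparison is numeric. A quick illustration, assuming the `mmcv.utils` import path:

```python
import torch
from mmcv.utils import digit_version  # helper mmrazor uses in this file

# digit_version parses a version string into a tuple of ints, so the
# comparison is numeric rather than lexicographic.
assert digit_version('1.10.0') > digit_version('1.6.0')
assert '1.10.0' < '1.6.0'  # the naive string comparison gets this wrong

print(digit_version(torch.__version__))  # e.g. (1, 13, 0, 0)
```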
```diff
@@ -407,13 +408,14 @@ def make_same_out_channel_groups(self, node2parents, name2module):
         same_in_channel_groups, same_out_channel_groups = {}, {}
         for node_name, parents_name in node2parents.items():
             parser = self.find_make_group_parser(node_name, name2module)
-            idx, same_in_channel_groups, same_out_channel_groups = \
-                parser(self,
-                       node_name=node_name,
-                       parents_name=parents_name,
-                       group_idx=idx,
-                       same_in_channel_groups=same_in_channel_groups,
-                       same_out_channel_groups=same_out_channel_groups)
+            idx, same_in_channel_groups, same_out_channel_groups = parser(
+                self,
+                node_name=node_name,
+                parents_name=parents_name,
+                group_idx=idx,
+                same_in_channel_groups=same_in_channel_groups,
+                same_out_channel_groups=same_out_channel_groups,
+            )
 
         groups = dict()
         idx = 0
```
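The reshaped call follows a fold-style contract: every parser returned by `find_make_group_parser` consumes the running group index and both group dictionaries and hands back the updated triple. A hypothetical parser with that shape (the name and grouping rule are invented for illustration, not taken from mmrazor):

```python
# Hypothetical parser honoring the contract visible in the diff: nodes
# whose outputs are merged element-wise must share one out-channel group.
def eltwise_add_parser(pruner, node_name, parents_name, group_idx,
                       same_in_channel_groups, same_out_channel_groups):
    same_out_channel_groups[group_idx] = list(parents_name)
    return group_idx + 1, same_in_channel_groups, same_out_channel_groups
```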
```diff
@@ -455,23 +457,29 @@ def add_pruning_attrs(self, module):
         if isinstance(module, nn.Conv2d):
             module.register_buffer(
                 'in_mask',
-                module.weight.new_ones((1, module.in_channels, 1, 1), ))
+                module.weight.new_ones((1, module.in_channels, 1, 1), ),
+            )
             module.register_buffer(
                 'out_mask',
-                module.weight.new_ones((1, module.out_channels, 1, 1), ))
+                module.weight.new_ones((1, module.out_channels, 1, 1), ),
+            )
             module.forward = self.modify_conv_forward(module)
         if isinstance(module, nn.Linear):
             module.register_buffer(
-                'in_mask', module.weight.new_ones((1, module.in_features), ))
+                'in_mask',
+                module.weight.new_ones((1, module.in_features), ),
+            )
             module.register_buffer(
-                'out_mask', module.weight.new_ones((1, module.out_features), ))
+                'out_mask',
+                module.weight.new_ones((1, module.out_features), ),
+            )
             module.forward = self.modify_fc_forward(module)
-        if (isinstance(module, _BatchNorm)
-                or isinstance(module, _InstanceNorm)
-                or isinstance(module, GroupNorm)):
+        if isinstance(module, _BatchNorm) or isinstance(
+                module, _InstanceNorm) or isinstance(module, GroupNorm):
             module.register_buffer(
                 'out_mask',
-                module.weight.new_ones((1, len(module.weight), 1, 1), ))
+                module.weight.new_ones((1, len(module.weight), 1, 1), ),
+            )
 
     def find_node_parents(self, paths):
         """Find the parent node of a node.
```
```diff
@@ -565,11 +573,12 @@ def deploy_subnet(self, supernet, channel_cfg):
                 if getattr(module, 'groups', in_channels) > 1:
                     module.groups = in_channels
 
-                module.weight = nn.Parameter(temp_weight.data)
+                module.weight = nn.Parameter(temp_weight.data.clone())
                 module.weight.requires_grad = requires_grad
 
                 if hasattr(module, 'bias') and module.bias is not None:
-                    module.bias = nn.Parameter(module.bias.data[:out_channels])
+                    module.bias = nn.Parameter(
+                        module.bias.data[:out_channels].clone())
                     module.bias.requires_grad = requires_grad
 
                 if hasattr(module, 'running_mean'):
```
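This hunk is the actual bug fix. `temp_weight.data` and `module.bias.data[:out_channels]` are views into the supernet's parameters, and `torch.save` serializes a tensor's entire underlying storage, so every exported subnet checkpoint came out as large as the full supernet. Calling `.clone()` gives each subnet parameter its own, correctly sized storage. A self-contained demonstration of the effect (tensor sizes chosen arbitrarily):

```python
import io
import torch

def saved_size(t):
    # Measure how many bytes torch.save actually writes for a tensor.
    buf = io.BytesIO()
    torch.save(t, buf)
    return len(buf.getvalue())

full = torch.randn(1024, 1024)  # stand-in for a supernet weight

view = full[:16]           # a view: still backed by all 1024 * 1024 floats
owned = full[:16].clone()  # a clone: owns only the 16 * 1024 slice

print(saved_size(view))   # roughly 4 MB, same as the full tensor
print(saved_size(owned))  # roughly 64 KB, proportional to the subnet
```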
