Commit 8ee06e6

Merge pull request #310 from HollyLee2000/master
fix nan and inf bug in sparse learning
VainF committed Dec 14, 2023
2 parents: ae9cc41 + 5e3d525
Showing 3 changed files with 5 additions and 3 deletions.
torch_pruning/pruner/algorithms/batchnorm_scale_pruner.py: 4 changes (2 additions, 2 deletions)
@@ -101,8 +101,8 @@ def regularize(self, model, reg=None, bias=False):
                 m.weight.grad.data.add_(reg*torch.sign(m.weight.data))
         else:
             for group in self._groups:
-                group_l2norm_sq = self._l2_imp(group)
-                if group_l2norm_sq is None:
+                group_l2norm_sq = self._l2_imp(group) + 1e-9 # + 1e-9 to avoid inf
+                if group_l2norm_sq is None or torch.any(torch.isnan(group_l2norm_sq)): # avoid nan
                     continue
                 gamma = reg * (1 / group_l2norm_sq.sqrt())

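For context, here is a minimal sketch of the failure mode the new epsilon guards against (illustrative only, not part of the repository; the tensor values are made up): a group whose parameters are all zero has a squared L2 norm of 0, so 1 / group_l2norm_sq.sqrt() evaluates to inf and the regularization gradient below it blows up.

    import torch

    reg = 1e-4
    # Hypothetical squared L2 norms; the first group is entirely zero.
    group_l2norm_sq = torch.tensor([0.0, 4.0])

    gamma = reg * (1 / group_l2norm_sq.sqrt())
    print(gamma)  # inf for the zero-norm group, 5e-05 for the other

    gamma = reg * (1 / (group_l2norm_sq + 1e-9).sqrt())
    print(gamma)  # about 3.16 and 5e-05: finite after adding 1e-9

The isnan check covers the separate case where the importance computation itself produced NaN; such a group is skipped rather than regularized with a NaN gamma.
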
torch_pruning/pruner/algorithms/group_norm_pruner.py: 2 changes (2 additions, 0 deletions)
@@ -92,6 +92,8 @@ def regularize(self, model, alpha=2**4, bias=False):
         for i, group in enumerate(self._groups):
             ch_groups = self._get_channel_groups(group)
             imp = self.estimate_importance(group).sqrt()
+            if torch.any(torch.isnan(imp)): # avoid nan
+                continue
             gamma = alpha**((imp.max() - imp) / (imp.max() - imp.min()))

             # Update Gradient
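Likewise, a short sketch of why a NaN importance score must be skipped here rather than used (again illustrative, with made-up values): NaN propagates through max(), min(), and the exponent, so a single corrupted entry turns every element of gamma into NaN and would poison the whole gradient update.

    import torch

    alpha = 2**4
    # Hypothetical importance scores with one corrupted (NaN) entry.
    imp = torch.tensor([1.0, float("nan"), 9.0]).sqrt()

    gamma = alpha ** ((imp.max() - imp) / (imp.max() - imp.min()))
    print(gamma)  # all entries are NaN: one bad score poisons every channel

    if torch.any(torch.isnan(imp)):  # the guard added in this commit
        print("group skipped")
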
torch_pruning/pruner/algorithms/growing_reg_pruner.py: 2 changes (1 addition, 1 deletion)
@@ -109,7 +109,7 @@ def step(self, interactive=False):
     def regularize(self, model, bias=False):
         for i, group in enumerate(self._groups):
             group_l2norm_sq = self.estimate_importance(group)
-            if group_l2norm_sq is None:
+            if group_l2norm_sq is None or torch.any(torch.isnan(group_l2norm_sq)): # avoid nan
                 continue
             gamma = self.group_reg[group]
             for k, (dep, idxs) in enumerate(group):
