This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Add TensorFlow slim pruner #3614

Merged
merged 4 commits into microsoft:master from slim-pruner-tf on May 27, 2021

Conversation

liuzhe-lz
Contributor

No description provided.

@ultmaster ultmaster requested review from J-shang and linbinskn May 14, 2021 08:34
assert len(weights) == 1, f'Bad weights: {[w.name for w in wrapper.layer.weights]}'
weight_list.append(tf.math.abs(weights[0].read_value()))
all_bn_weights = tf.concat(weight_list, 0)
k = int(all_bn_weights.shape[0] * (1 - pruner.wrappers[0].config['sparsity']))
Contributor

Why use shape[0] here, which represents only the first dimension of all_bn_weights? If we want to pick the largest (1 - sparsity) fraction of values from the weights, shouldn't we use the total element count of all_bn_weights instead of shape[0]?

Contributor
@linbinskn linbinskn May 14, 2021

The current implementation is correct since this pruner only prunes BatchNormalization layers....

Contributor Author

I agree it looks strange, but that's how the PyTorch version works.
What's your suggestion?

Contributor

I think this is because SlimPruner is a channel-level pruner, so by default we use shape[0]. Maybe we can refactor if we support more kinds of channel scaling factors, but for now it is good enough.

Contributor

I see. SlimPruner prunes the BatchNorm operator, whose number of weight parameters equals the channel count. In this case, all weight parameters lie along one dimension, and pruning each of them is equivalent to pruning a convolution channel. So the current implementation is actually enough.
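To make the shape[0] point concrete, here is a minimal sketch (illustrative only, not the pruner's actual code; layer sizes and variable names are made up) of computing a global threshold over concatenated BN scale factors when every per-layer weight is a 1-D gamma vector:

import tensorflow as tf

# Each BatchNormalization layer stores its scale factor gamma as a 1-D
# tensor of length `channels`, so after concatenation shape[0] is exactly
# the total number of channels, i.e. the total element count.
gammas = [
    tf.random.uniform([64]),   # BN layer with 64 channels
    tf.random.uniform([128]),  # BN layer with 128 channels
]
all_bn_weights = tf.concat([tf.math.abs(g) for g in gammas], 0)

sparsity = 0.5
k = int(all_bn_weights.shape[0] * (1 - sparsity))
# Keep the k largest scale factors; the smallest kept value acts as the
# global threshold shared by all BN layers.
global_threshold = tf.math.top_k(all_bn_weights, k).values.numpy()[-1]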


assert weight is not None
if wrapper.masks.get(weight_name) is not None:
weight *= wrapper.masks[weight_name]
Contributor

I can't fully understand why the mask is multiplied in here. Maybe it is used to prune the layer step by step.

Contributor

Yes, slim is iterative pruning.
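A rough sketch of the iterative behavior being discussed (the names mirror the snippet above but this is not the wrapper's real API): channels zeroed out in an earlier pruning step are masked before the next threshold is computed, so they stay pruned across iterations.

import tensorflow as tf

def apply_previous_mask(weight, old_mask):
    # Zero out channels that were already pruned in earlier iterations.
    if old_mask is not None:
        weight = weight * old_mask
    return weight

gamma = tf.constant([0.9, 0.1, 0.5, 0.05])
previous_mask = tf.constant([1.0, 0.0, 1.0, 1.0])  # channel 1 pruned earlier
print(apply_previous_mask(gamma, previous_mask).numpy())  # [0.9 0.  0.5 0.05]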

@@ -3,15 +3,15 @@
from nni.compression.tensorflow import Pruner

__all__ = [
'OneshotPruner',
Contributor

Maybe we can keep this for users who want to write their own pruner?

Contributor Author

If we export it then it's a public API and we need to maintain backward compatibility.
I think it's an internal helper and I don't want to make it a formal API.
If a user really wants it, they should copy the code.
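For context, a minimal illustration of the export question (the module contents here are made up, not the file's actual listing): only names in __all__ form the module's public surface, e.g. for `from module import *`, so dropping OneshotPruner from the list keeps it an internal helper without a compatibility promise.

from nni.compression.tensorflow import Pruner

__all__ = [
    'SlimPruner',  # public, must stay backward compatible
    # 'OneshotPruner' intentionally not exported: internal helper,
    # free to change without breaking users.
]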

assert len(weights) == 1, f'Bad weights: {[w.name for w in wrapper.layer.weights]}'
weight_list.append(tf.math.abs(weights[0].read_value()))
all_bn_weights = tf.concat(weight_list, 0)
k = int(all_bn_weights.shape[0] * (1 - pruner.wrappers[0].config['sparsity']))
Contributor

This place is a little different from the PyTorch implementation and may lead to different thresholds given the same weights and sparsity.

k = int(all_bn_weights.shape[0] * pruner.config_list[0]['sparsity'])

Contributor Author

The top-k API is different between PyTorch and TensorFlow. If we don't use 1 - sparsity, the logic will be much more complicated and hard to read.
Maybe we can alter the rounding behavior to make them more similar, but I guess the original implementation does not care about it at all...

Contributor

Yes, maybe altering the rounding behavior is better. Otherwise each layer may differ by up to two pruned channels between the PyTorch and TensorFlow versions, and I'm afraid the final effect may then show a bigger gap.

Contributor

Maybe we can also use -tf.nn.top_k(-weights, k, sorted=False).values.numpy() to get the minimum k values.
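A quick illustrative comparison of the two approaches in this thread (toy values, not taken from the PR): keeping the largest 1 - sparsity fraction via top_k, versus negating the tensor so top_k returns the smallest sparsity fraction directly. With plain int() truncation the two k values can disagree by one element, which is the rounding gap mentioned above.

import tensorflow as tf

weights = tf.constant([0.9, 0.1, 0.5, 0.05, 0.7, 0.3, 0.2])
sparsity = 0.4

# Current approach: keep the largest (1 - sparsity) fraction.
k_keep = int(weights.shape[0] * (1 - sparsity))   # floor(7 * 0.6) = 4
kept = tf.math.top_k(weights, k_keep).values.numpy()

# Suggested alternative: fetch the smallest `sparsity` fraction by negation.
k_prune = int(weights.shape[0] * sparsity)        # floor(7 * 0.4) = 2
pruned = -tf.nn.top_k(-weights, k_prune, sorted=False).values.numpy()

print(kept)    # four largest values
print(pruned)  # two smallest values; note 4 + 2 != 7 due to truncation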




filters = weight.shape[0]
num_prune = int(filters * sparsity)
if filters >= 2 and num_prune >= 1:
Contributor
@linbinskn linbinskn May 25, 2021

Does it really involve corner cases? Why not compare the weight against self.global_threshold directly instead of doing this branch check?
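A sketch of the direct-threshold alternative suggested here (it assumes a precomputed global_threshold as in the earlier snippet; the function name is illustrative, not the pruner's actual method): compare each BN scale factor against the global threshold to produce the mask, with no per-layer branch on filter or prune counts.

import tensorflow as tf

def calc_mask(weight, global_threshold):
    # 1.0 where the channel survives, 0.0 where it is pruned
    return tf.cast(tf.math.abs(weight) > global_threshold, weight.dtype)

gamma = tf.constant([0.9, 0.1, 0.5, 0.05])
print(calc_mask(gamma, global_threshold=0.2).numpy())  # [1. 0. 1. 0.]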

@linbinskn
Contributor

Please resolve conflicts.

@ultmaster ultmaster merged commit e349b44 into microsoft:master May 27, 2021
@liuzhe-lz liuzhe-lz deleted the slim-pruner-tf branch June 17, 2021 03:27

Successfully merging this pull request may close these issues.

Support one tensorflow algorithm for model compression
4 participants