Conversation
assert len(weights) == 1, f'Bad weights: {[w.name for w in wrapper.layer.weights]}'
weight_list.append(tf.math.abs(weights[0].read_value()))
all_bn_weights = tf.concat(weight_list, 0)
k = int(all_bn_weights.shape[0] * (1 - pruner.wrappers[0].config['sparsity']))
Why use `shape[0]` here, which represents the first dimension of `all_bn_weights`? If we want to pick the (1 - sparsity) largest values from the weight, shouldn't we use the count of all elements in `all_bn_weights` instead of `shape[0]`?
The current implementation is correct since this pruner only prunes BatchNormalization layers....
I agree it looks strange, but that's how the PyTorch version works.
What's your suggestion?
I think this is because SlimPruner is a channel-level pruner, so by default we use `shape[0]`. Maybe we can refactor if we support more kinds of channel scaling factors, but for now it is good enough.
I see, SlimPruner is for pruning the BatchNorm operator, whose number of weight parameters equals the number of channels. In this case, all weight parameters lie in one dimension, and pruning each of them is equivalent to pruning a convolution channel. So the current implementation is actually enough.
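To make this concrete, here is a minimal sketch (not NNI code; the layer shape is just an example) showing that for a Keras `BatchNormalization` layer the scale weight `gamma` is a 1-D vector with one entry per channel, so `shape[0]` and the total element count coincide:

```python
import tensorflow as tf

# For BatchNormalization the trainable scale (gamma) has one entry per channel,
# so its first dimension is also its total element count.
bn = tf.keras.layers.BatchNormalization()
bn.build(input_shape=(None, 32, 32, 64))      # 64 channels (example value)
gamma = bn.gamma
print(gamma.shape)                             # (64,)
assert gamma.shape[0] == int(tf.size(gamma))   # identical for a 1-D weight
```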
assert weight is not None
if wrapper.masks.get(weight_name) is not None:
    weight *= wrapper.masks[weight_name]
I can't fully understand why we multiply by the mask here. Maybe it is used to prune the layer step by step.
Yes, slim is iterative pruning.
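As a toy illustration of the iterative behavior (plain NumPy, not the NNI implementation): multiplying by the existing mask zeroes previously pruned entries, so they always fall below the newly computed threshold and pruning stays cumulative across iterations.

```python
import numpy as np

weight = np.array([0.9, 0.1, 0.5, 0.05])
mask = np.array([1.0, 0.0, 1.0, 1.0])     # entry 1 was pruned in the last round

masked = weight * mask                     # previously pruned entry becomes 0
threshold = np.sort(np.abs(masked))[1]     # this round prunes one more entry
new_mask = (np.abs(masked) > threshold).astype(float)
print(new_mask)                            # [1. 0. 1. 0.] -- the old zero is preserved
```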
@@ -3,15 +3,15 @@
from nni.compression.tensorflow import Pruner


__all__ = [
    'OneshotPruner',
Maybe we can keep this for users who want to write their own pruner?
If we export it then it's a public API and we need to maintain backward compatibility.
I think it's an internal helper and I don't want to make it a formal API.
If a user really wants it, they should copy the code.
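For context, a small sketch of how `__all__` shapes the module's public surface; the commented-out name below is a hypothetical internal helper used only to illustrate the point.

```python
from nni.compression.tensorflow import Pruner

# Only names listed in __all__ are exported by a wildcard import and are
# treated as the public, backward-compatible API of this module.
__all__ = [
    'OneshotPruner',
    # '_MaskerHelper',  # hypothetical internal helper, intentionally not exported
]
```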
assert len(weights) == 1, f'Bad weights: {[w.name for w in wrapper.layer.weights]}'
weight_list.append(tf.math.abs(weights[0].read_value()))
all_bn_weights = tf.concat(weight_list, 0)
k = int(all_bn_weights.shape[0] * (1 - pruner.wrappers[0].config['sparsity']))
This place is a little different from the PyTorch implementation and may lead to a different threshold with the same weights and sparsity.
k = int(all_bn_weights.shape[0] * pruner.config_list[0]['sparsity'])
The top-k API is different between PyTorch and TensorFlow. If we don't use `1 -`, the logic will be much more complicated and hard to read.
Maybe we can alter the rounding behavior to make them more similar, but I guess the original implementation does not care about it at all...
Yes, maybe altering the rounding behavior is better. Otherwise each layer may differ by up to two pruned channels between PyTorch and TensorFlow, and then I'm afraid the final effect may have a bigger gap.
Maybe we can also use `-tf.nn.top_k(-weights, k, sorted=False).values.numpy()` to get the minimum k values.
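A short sketch of the two ways to derive a pruning threshold in TensorFlow, given that `tf.nn.top_k` has no `largest=False` switch like `torch.topk`; the tensor size and sparsity value below are made up for illustration.

```python
import tensorflow as tf

weights = tf.abs(tf.random.normal([64]))
sparsity = 0.5

# Current approach: keep the largest (1 - sparsity) fraction of values;
# the smallest kept value serves as the threshold.
k_keep = int(weights.shape[0] * (1 - sparsity))
threshold_keep = tf.reduce_min(tf.nn.top_k(weights, k_keep).values)

# Reviewer's alternative: negate to fetch the k smallest values, mirroring
# torch.topk(..., largest=False); the largest of them is the threshold.
k_prune = int(weights.shape[0] * sparsity)
smallest = -tf.nn.top_k(-weights, k_prune, sorted=False).values
threshold_prune = tf.reduce_max(smallest)
```

Because `int()` truncates, `k_keep` and `k_prune` can disagree by one for some sparsities, which is exactly the rounding concern raised above.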
filters = weight.shape[0]
num_prune = int(filters * sparsity)
if filters >= 2 and num_prune >= 1:
Does it really contain corner cases? Why not threshold the weight with `self.global_threshold` directly instead of doing such branch identification?
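A minimal sketch of what that suggestion could look like, assuming `self.global_threshold` holds the threshold computed over all concatenated BN weights; the function name and signature are illustrative only.

```python
import tensorflow as tf

def calc_mask_with_global_threshold(weight, global_threshold):
    # Keep entries whose magnitude exceeds the global threshold,
    # with no per-layer branching on filter or prune counts.
    return tf.cast(tf.abs(weight) > global_threshold, weight.dtype)

mask = calc_mask_with_global_threshold(tf.constant([0.9, 0.01, 0.5]), 0.1)
print(mask)  # [1. 0. 1.]
```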
Please resolve conflicts.