
Pruning Tutorial Not Working Correctly #1054

Closed

glenn-jocher opened this issue Jul 5, 2020 · 6 comments

Comments

@glenn-jocher

glenn-jocher commented Jul 5, 2020

Hello! I tried to reproduce the pruning tutorial (https://pytorch.org/tutorials/intermediate/pruning_tutorial.html) with our YOLOv5 repo https://github.com/ultralytics/yolov5 and ran into problems. Code to reproduce:

import torch
import torch.nn.utils.prune as prune

model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)

prune.global_unstructured(model.parameters, pruning_method=prune.L1Unstructured, amount=0.3)

Output:

Traceback (most recent call last):
  File "/Users/glennjocher/opt/anaconda3/envs/env1/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3331, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-25-63574f84c909>", line 1, in <module>
    prune.global_unstructured(model.parameters, pruning_method=prune.L1Unstructured, amount=0.3)
  File "/Users/glennjocher/opt/anaconda3/envs/env1/lib/python3.7/site-packages/torch/nn/utils/prune.py", line 1016, in global_unstructured
    assert isinstance(parameters, Iterable)
AssertionError

I also tried passing model.parameters() with parentheses:

prune.global_unstructured(model.parameters(), pruning_method=prune.L1Unstructured, amount=0.3)

Output:

Traceback (most recent call last):
  File "/Users/glennjocher/Library/Application Support/JetBrains/PyCharmCE2020.1/scratches/prune.py", line 6, in <module>
    prune.global_unstructured(model.parameters(), pruning_method=prune.L1Unstructured, amount=0.3)
  File "/Users/glennjocher/opt/anaconda3/envs/env1/lib/python3.7/site-packages/torch/nn/utils/prune.py", line 1019, in global_unstructured
    t = torch.nn.utils.parameters_to_vector([getattr(*p) for p in parameters])
  File "/Users/glennjocher/opt/anaconda3/envs/env1/lib/python3.7/site-packages/torch/nn/utils/prune.py", line 1019, in <listcomp>
    t = torch.nn.utils.parameters_to_vector([getattr(*p) for p in parameters])
TypeError: getattr expected at most 3 arguments, got 32

The model itself is fine. I can inspect it and see the shape of the parameters, for example:

[x.shape for x in model.parameters()][:10]
Out[26]:
[torch.Size([32, 12, 3, 3]),
 torch.Size([32]),
 torch.Size([32]),
 torch.Size([64, 32, 3, 3]),
 torch.Size([64]),
 torch.Size([64]),
 torch.Size([32, 64, 1, 1]),
 torch.Size([32]),
 torch.Size([32]),
 torch.Size([32, 64, 1, 1])]
@mickypaganini
Contributor

mickypaganini commented Jul 14, 2020

That's not how global_unstructured works. Please see the docs or the tutorial.

Specifically, parameters_to_prune contains tuples of (Module, param_string_name), not the parameter tensors themselves; passing model.parameters() is why both calls above fail.
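
For illustration, a minimal sketch of that format with a toy two-layer model (not the YOLOv5 architecture):

import torch.nn as nn
import torch.nn.utils.prune as prune

# toy model, only to illustrate the (module, parameter_name) tuple format
toy = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 32, 3))

parameters_to_prune = [
    (toy[0], "weight"),  # weights of the first conv layer
    (toy[2], "weight"),  # weights of the second conv layer
]

prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.3)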

@mickypaganini
Contributor

I'm not familiar with your architecture, so you'll have to decide which parameters it makes sense to pool together and compare via global magnitude-based pruning; but let's assume, just for the sake of this simple example, that you only want to consider the convolutional layers identified by the logic of my if-statement below [if those aren't the weights you care about, please feel free to modify that logic as you wish].

Now, those layers happen to come with two parameters: "weight" and "bias". Let's say you are interested in the weights [if you care about the biases too, feel free to add them in as well in the parameters_to_prune]. Alright, how do we tell global_unstructured to prune those weights in a global manner? We do so by constructing parameters_to_prune as required by that function [again, see the docs and tutorial linked above].

# collect (module, name) pairs: leaf modules whose name ends in 'conv'
parameters_to_prune = [
    (v, "weight")
    for k, v in dict(model.named_modules()).items()
    if ((len(list(v.children())) == 0) and (k.endswith('conv')))
]

# now you can use global_unstructured pruning
prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.3)

To check that that succeeded, you can now look at the global sparsity across those layers, which should be 30%, as well as the individual per-layer sparsity:

# global sparsity
nparams = 0
pruned = 0
for k, v in dict(model.named_modules()).items():
    if ((len(list(v.children())) == 0) and (k.endswith('conv'))):
        nparams += v.weight.nelement()
        pruned += torch.sum(v.weight == 0).item()
print('Global sparsity across the pruned layers: {:.2f}%'.format(100. * pruned / float(nparams)))
# ^^ should be 30%

# local sparsity
for k, v in dict(model.named_modules()).items():
    if ((len(list(v.children())) == 0) and (k.endswith('conv'))):
        print(
            "Sparsity in {}: {:.2f}%".format(
                k,
                100. * float(torch.sum(v.weight == 0))
                / float(v.weight.nelement())
            )
        )
# ^^ will be different for each layer

@glenn-jocher
Author

glenn-jocher commented Jul 14, 2020

@mickypaganini thank you, I've updated our tutorial ultralytics/yolov5#304 with your comments. I finally got this to work, and we are seeing the correct global sparsity after pruning. Some people were confused that the model size was not correspondingly reduced on saving, but I explained that this is expected, since only individual elements within the weight matrices are pruned, not entire channels.
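
For anyone who wants the zeros baked in before saving, a minimal sketch, assuming the parameters_to_prune list and imports from the comments above: pruning reparametrizes each module, leaving a weight_orig parameter plus a binary weight_mask buffer (so the checkpoint does not shrink), and prune.remove() collapses them back into a plain weight tensor:

# make the pruning permanent; until prune.remove() is called, each pruned
# module stores weight_orig plus weight_mask instead of a single weight
for module, name in parameters_to_prune:
    prune.remove(module, name)  # folds the mask into weight for good

torch.save(model.state_dict(), 'yolov5s_pruned.pt')  # hypothetical filename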

@mickypaganini
Contributor

@glenn-jocher yep, we are just inserting zeros into the tensors in the right places (whether elements, in unstructured pruning, or channels, in structured pruning), but the size of those tensors isn't changing, unless you represent them in coordinate representation (see here for an example that uses .to_sparse(): #605 (comment))
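
A minimal sketch of that coordinate representation, with arbitrary sizes:

import torch

dense = torch.zeros(1000, 1000)
dense[::10, ::10] = 1.0            # only 1% of entries are nonzero
sparse = dense.to_sparse()         # COO format stores indices + values only

print(dense.nelement())            # 1000000 elements stored densely
print(sparse.values().nelement())  # 10000 nonzero values actually stored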

@soumendukrg

Has anyone tried structured pruning and thinning on YOLOv5 models, which could actually delete filters/channels, reduce MAC operations, and thereby reduce inference time?
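
For reference, torch.nn.utils.prune does ship structured pruning, but it only masks whole channels to zero; a minimal sketch (shapes chosen arbitrarily) showing that tensor sizes, and hence MACs, are unchanged without a separate thinning step:

import torch.nn as nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(64, 128, 3)
# zero out 30% of output channels, ranked by L2 norm along dim 0 (the filter axis)
prune.ln_structured(conv, name='weight', amount=0.3, n=2, dim=0)
# entire filters are masked to zero, but the tensor shape is unchanged
print(conv.weight.shape)  # still torch.Size([128, 64, 3, 3])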

@PraveenMNaik

@glenn-jocher, I tried pruning the model, but I am unable to find the pruned model after the procedure. Where can I find the pruned model (.pt)? Please help.
