Forward and backward propagation in NTT stage #4

Open
tnn2018 opened this issue Sep 13, 2021 · 3 comments

tnn2018 commented Sep 13, 2021

Hello!
Impressive work!
I'd like to know the meaning of the parameters in this line of nt_transfer_step (ntt.py):

masked_g = grad(self.nt_transfer_loss)(student_net_params, masks, teacher_net_params, x, nn_density_level)

I noticed in your reply to another issue (Pruned weights in NTK #3) that the gradients of all parameters are updated; however, the paper says:
"Here, the backpropagated gradients flow through the fixed mask m and thus the masked-out parameters are not updated."
This confuses me.
Looking forward to your reply!
Best wishes!

@liutianlin0121
Collaborator

Thanks for your interest and this keen observation. As described in the paper, my intention was to backprop through the mask. However, as mentioned in the other issue you linked, the gradients might not be sparse. I am curious about this but haven't checked it myself (perhaps it only happens with global pruning? I am not sure). May I ask you to verify this phenomenon once more? It would be great if you could train a layerwise-sparse net using NTT and check whether the layerwise gradients are indeed sparse. If they are not, this may hint at a glitch in my implementation. Either way, I would expect that masking or not masking the gradient does not significantly change the final NTT loss.
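
For instance, a minimal sketch along these lines could be used for the check. The names follow the grad call in nt_transfer_step, and the helper itself is just something I am sketching here, not part of the repo:

import jax.numpy as jnp
from jax import tree_util

def layerwise_nonzero_fraction(grad_tree):
    # Fraction of nonzero entries in each leaf (i.e., each layer) of a gradient pytree.
    return tree_util.tree_map(lambda g: float(jnp.mean(jnp.abs(g) > 0)), grad_tree)

# masked_g = grad(self.nt_transfer_loss)(student_net_params, masks,
#                                        teacher_net_params, x, nn_density_level)
# print(layerwise_nonzero_fraction(masked_g))

If the gradients respect the layerwise masks, each reported fraction should be close to that layer's density level; fractions near 1.0 would indicate dense gradients.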


tnn2018 commented Sep 14, 2021

I have looked at the experimental results before, and the model is indeed sparse. Is it possible that the gradients are computed for all parameters, but the masked-out parameters are still not updated?
I understand your concern, but could you first explain the meaning of the parameters in the gradient call above? I still find parts of the implementation difficult and don't fully understand it. Thank you very much!


liutianlin0121 commented Sep 14, 2021

Sure! The line

masked_g = grad(self.nt_transfer_loss)(student_net_params, masks, teacher_net_params, x, nn_density_level)

computes the gradient of the NTT loss self.nt_transfer_loss (implemented here) with respect to its first argument (the student_net_params variable), evaluated at the supplied arguments (student_net_params, masks, teacher_net_params, x, and nn_density_level).

Some more detailed comments:

  • Why is the gradient taken with respect to the first argument? This is due to the API of jax.grad: by default, argnums=0, meaning the gradient is taken only with respect to the first argument.
  • What do the arguments student_net_params, masks, teacher_net_params, x, and nn_density_level mean? They are the arguments of the NTT loss self.nt_transfer_loss (implemented here). student_net_params holds the parameters of the student network, masks are the masks of the student network, teacher_net_params holds the parameters of the teacher network, x is the input features, and nn_density_level is the desired parameter density of the student network. Notably, student_net_params is still dense at this stage; its sparsity is induced by masks. Mathematically, the NTT loss is defined in Equation (14) of our paper. A toy sketch of how these arguments enter the gradient call is given right after this list.
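
To make the argnums point concrete, here is a toy example with the same argument order as nt_transfer_loss. The loss is a stand-in I made up for illustration, not the actual NTT loss:

import jax
import jax.numpy as jnp

def toy_loss(student_params, masks, teacher_params, x, density):
    # Stand-in loss: the student parameters enter only through the masked product.
    pred = x @ (student_params * masks)
    target = x @ teacher_params
    return jnp.mean((pred - target) ** 2) + 0.0 * density

student_params = jnp.ones((3, 2))
masks = jnp.array([[1., 0.], [1., 1.], [0., 1.]])
teacher_params = jnp.full((3, 2), 0.5)
x = jnp.ones((4, 3))

# jax.grad(toy_loss) is the same as jax.grad(toy_loss, argnums=0):
# only student_params is differentiated; the other arguments are treated as constants.
g = jax.grad(toy_loss)(student_params, masks, teacher_params, x, 0.5)
print(g.shape)  # (3, 2), the same shape as student_params

The returned gradient has the structure of the first argument only; masks, teacher_net_params, x, and nn_density_level are not differentiated.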

As we can see in the code, I still believe the gradient is indeed backpropagated through the mask (because the gradient is taken only with respect to the first argument, the parameters of the student network); a small self-contained check of this point is sketched below. Let me know if this helps and whether there are further questions!
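
Here is that check: if the student parameters enter the loss only via the elementwise product with the mask, the chain rule forces the gradient to be zero wherever the mask is zero. Again this uses a toy loss of my own, not the exact NTT implementation:

import jax
import jax.numpy as jnp

def toy_masked_loss(params, masks):
    # params appear only through params * masks, as in a masked forward pass.
    return jnp.sum((params * masks) ** 2)

params = jnp.arange(1.0, 7.0).reshape(3, 2)
masks = jnp.array([[1., 0.], [1., 1.], [0., 1.]])

g = jax.grad(toy_masked_loss)(params, masks)   # gradient w.r.t. params only
print(jnp.all(g[masks == 0.0] == 0.0))         # True: masked-out entries receive zero gradient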
