scale gradient for backward pass #521

Closed
enpasos opened this issue Jan 13, 2021 · 6 comments · Fixed by #548

enpasos commented Jan 13, 2021

Question or maybe Enhancement

I'm missing a feature to scale the gradient for the backward pass (as used, e.g., in MuZero), something like
tensor * scale + stop_gradient(tensor) * (1 - scale)
I'm not sure whether the feature is actually missing or I'm simply not seeing the proper way to do it.

Workaround

I worked around it by adding an additional forward pass, keeping the relevant tensors as outputs and feeding them back into the training forward pass as "stop_gradient(tensor)" inputs. This works functionally, but comes at the cost of

  1. memory consumption on the training device (GPU memory is scarce for me)
  2. lower performance
  3. higher complexity
enpasos added the question label on Jan 13, 2021
roywei commented Jan 16, 2021

You can use block.getParameters() to get the parameters and getArray() to get the parameter value, then getGradient() to access the gradient. You can do an in-place update on the gradient value (e.g. grad.muli(scale)).

Not sure if this is what you want; if not, please provide some Python code in TF, PyTorch or MXNet so we can take a look.
Thanks!
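A minimal sketch of that suggestion, assuming DJL's Block/Parameter/NDArray API; the method name scaleAllGradients and the block/scale arguments are placeholders for illustration:

    import ai.djl.ndarray.NDArray;
    import ai.djl.nn.Block;
    import ai.djl.nn.Parameter;
    import ai.djl.util.Pair;

    // Sketch: scale all parameter gradients of a block in place.
    // Intended to be called after loss.backward(...) inside a GradientCollector scope.
    public static void scaleAllGradients(Block block, float scale) {
        for (Pair<String, Parameter> pair : block.getParameters()) {
            NDArray grad = pair.getValue().getArray().getGradient();
            grad.muli(scale); // in-place update of the gradient value
        }
    }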

enpasos commented Jan 16, 2021

Thank you very much for your reply.
I think the methods you mentioned are useful for some use cases.
For use cases where the node in question is passed through many times, however, I do not see a clever way in which those methods lead to a simple solution.

I would like to give some more information about the use case I am looking at:

MuZero use case: a Java implementation of MuZero based on DJL (with MXNet as the engine).

Need: The MuZero paper comes with Python pseudocode (see the supplementary data). The pseudocode uses this function

def scale_gradient(tensor: Any, scale):
    """Scales the gradient for the backward pass."""
    return tensor * scale + tf.stop_gradient(tensor) * (1 - scale)

to scale down the error backpropagation from the recurrently called dynamics function.

Support in the frameworks
In TensorFlow I see the function stop_gradient in the Python API.
As I am using MXNet, I searched for equivalent support in MXNet and found this.

enpasos commented Jan 17, 2021

I think I found the function in the MXNet Python API: BlockGrad

enpasos commented Jan 17, 2021

It would be great to have it in Java, too.

enpasos commented Jan 17, 2021

I'll test this

    // Calls the MXNet "stop_gradient" operator directly via the engine-specific
    // NDManager, since the operator is not exposed on the DJL NDArray API yet.
    public static NDArray stopGradient(NDArray in) {
        MxNDManager manager = (MxNDManager) in.getManager();
        MxOpParams params = new MxOpParams(); // no additional operator parameters needed
        return manager.invoke("stop_gradient", in, params);
    }
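With that helper, the scale_gradient function from the MuZero pseudocode above could be composed roughly like this (a sketch using standard DJL NDArray arithmetic; scaleGradient is just an illustrative name):

    // tensor * scale + stop_gradient(tensor) * (1 - scale)
    public static NDArray scaleGradient(NDArray tensor, float scale) {
        return tensor.mul(scale).add(stopGradient(tensor).mul(1 - scale));
    }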

enpasos commented Jan 18, 2021

The stopGradient works well for me: I could remove my workaround and thereby regained GPU memory ... enough to double the batch size.

As it is general functionality (used in MuZero, for example), it would be very useful to add it to the Java API as well, e.g. in the NDArray interface and its implementations.
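A hypothetical sketch of what such an addition to the NDArray interface could look like (method names and the default-method composition are assumptions for illustration):

    // Hypothetical additions to the NDArray interface.
    NDArray stopGradient();

    default NDArray scaleGradient(double scale) {
        // tensor * scale + stop_gradient(tensor) * (1 - scale)
        return mul(scale).add(stopGradient().mul(1 - scale));
    }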
