[WIP] differentiable model construction #742
base: main
Conversation
not sure we'd want to do it from JSON. The whole point of JSON is that it implies static data. If you want to build a differentiable model from Python, that's a different thing to explore. Likely, what you want to do is create a
What benefit do you get from a differentiable model apart from being able to dynamically substitute in the array values?
Consider:
this will allow you to compute the gradient
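A minimal sketch of the point being made here (the function and values below are hypothetical, not pyhf code): if the model-construction step is traceable by JAX, `jax.grad` can push gradients from an objective all the way back to the raw array values that came in through the spec.

```python
import jax
import jax.numpy as jnp


def build_and_evaluate(nominal):
    # stand-in for constructing expected event rates from spec arrays;
    # any traceable chain of jnp operations would work here
    expected = jnp.asarray(nominal) * 2.0
    return jnp.sum(expected)


# gradient of the objective with respect to the input array values
grad_fn = jax.grad(build_and_evaluate)
print(grad_fn(jnp.array([1.0, 2.0, 3.0])))  # [2. 2. 2.]
```

If construction instead went through a static (numpy) path, the trace would break and no gradient would flow.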
This pull request introduces 1 alert when merging 5e57400 into 60488cd - view on LGTM.com new alerts:
@kratsg A MWE of this failing:
In a more complicated setting (taking gradients), it failed with a different error when trying to concatenate 1-d tensors at Line 214 in 5e57400.
List concatenation by
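For illustration, a small sketch of joining 1-d tensors (the arrays here are made up, not the ones at Line 214): under a JAX trace the safe way to join 1-d tensors is `jnp.concatenate`, which returns a single 1-d tensor rather than a Python list.

```python
import jax.numpy as jnp

a = jnp.array([1.0, 2.0])
b = jnp.array([3.0, 4.0])

# joins two 1-d tensors into one 1-d tensor, and stays traceable
joined = jnp.concatenate([a, b])
print(joined)  # [1. 2. 3. 4.]
```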
Should be fixed now. I also checked the validity of the results,
so things are at least consistent (checked in numpy -- scipy.optimize.minimize hates jax for some reason in the diffable branch when you pass in jax arrays for inits/bounds).
This pull request introduces 1 alert when merging 3c8aec8 into 5fcf95b - view on LGTM.com new alerts:
Thanks for the fix, that works great! Now we're dealing with the same errors I encountered when I forked this originally. A MWE for this use case, similar to the above:
Define a model that depends on a data-altering parameter, and return an eval of
We can then try to take the gradient of
This throws:
So theoretically, overriding the default backend should have worked, but numpy is still being used here. Any thoughts on why this behaviour is occurring? I never figured it out.
The following code works for me now:

```python
# gotta patch stuff
import sys
from unittest.mock import patch

# let's get started
import pyhf

jax_backend = pyhf.tensor.jax_backend(precision='64b')
pyhf.set_backend(jax_backend)

import jax.numpy as jnp


@patch('pyhf.default_backend', new=jax_backend)
@patch.object(sys.modules['pyhf.interpolators.code0'], 'default_backend', new=jax_backend)
@patch.object(sys.modules['pyhf.interpolators.code1'], 'default_backend', new=jax_backend)
@patch.object(sys.modules['pyhf.interpolators.code2'], 'default_backend', new=jax_backend)
@patch.object(sys.modules['pyhf.interpolators.code4'], 'default_backend', new=jax_backend)
@patch.object(sys.modules['pyhf.interpolators.code4p'], 'default_backend', new=jax_backend)
@patch.object(sys.modules['pyhf.modifiers.shapefactor'], 'default_backend', new=jax_backend)
@patch.object(sys.modules['pyhf.modifiers.shapesys'], 'default_backend', new=jax_backend)
@patch.object(sys.modules['pyhf.modifiers.staterror'], 'default_backend', new=jax_backend)
def from_spec(param):
    ...
```
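As a side note, a minimal self-contained sketch of what those `patch` decorators are doing (using a throwaway module here, not pyhf): `unittest.mock.patch.object` swaps out a module-level attribute for the duration of the decorated call and restores the original afterwards, which is why each submodule's `default_backend` has to be patched individually.

```python
import sys
import types
from unittest.mock import patch

# build a throwaway module with a module-level attribute, mimicking
# how pyhf submodules hold their own reference to default_backend
mod = types.ModuleType('demo_mod')
mod.default_backend = 'numpy'
sys.modules['demo_mod'] = mod


@patch.object(sys.modules['demo_mod'], 'default_backend', new='jax')
def inside():
    # while the decorated function runs, the attribute is replaced
    return sys.modules['demo_mod'].default_backend


print(inside())                                 # 'jax' while patched
print(sys.modules['demo_mod'].default_backend)  # back to 'numpy' afterwards
```

This also explains why `pyhf.set_backend` alone wasn't enough: modules that imported `default_backend` at import time keep their own stale reference.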
For the jax issue that I temporarily patched, I filed the report here: jax-ml/jax#3919
This pull request introduces 1 alert when merging 81ee47a into 5fcf95b - view on LGTM.com new alerts:
...which has already been fixed by the blazing-fast @jakevdp! :) This just missed being in the JAX version pinned at Line 10 in 5fcf95b, so it should land the next time they release.
Awesome!! Thanks guys (and @jakevdp), that was the quickest turnaround for an external bug I've ever seen...
Description
The JSON structure from which the model is constructed is essentially a collection of array data. If read literally from JSON it is static, but that array data might also be the output of other programs, and one would like a fully differentiable chain of operations throughout the JSON-like structure (@phinate is looking into this more closely).
This PR makes `_create_nominal_and_modifiers` in particular differentiable. Perhaps for now this might not be merged in directly, but it will provide @phinate with a functioning fork until we figure out how to do this properly / more cleanly. In this branch this works: