
[WIP] differentiable model construction #742

Open · wants to merge 4 commits into main

Conversation

@lukasheinrich (Contributor) commented Jan 18, 2020

Description

The JSON structure from which the model is constructed is essentially a collection of array data. If read literally from JSON it is static data; however, that array data might also be the output of other programs, and one would like to have a fully differentiable chain of operations throughout the JSON-like structure (@phinate is looking into this more closely).

This PR makes _create_nominal_and_modifiers in particular differentiable. Perhaps this should not be merged directly for now, but it will at least provide @phinate with a functioning fork until we figure out how to do this properly / more cleanly.

In this branch, the following works:

import pyhf
import jax
pyhf.set_backend(pyhf.tensor.jax_backend())


def model(signal, data, pars):
    m = pyhf.simplemodels.hepdata_like(
        signal_data=signal,
        bkg_data=pyhf.tensorlib.astensor([50.0, 52.0]),
        bkg_uncerts=pyhf.tensorlib.astensor([3.0, 7.0]),
    )
    # append the auxiliary data for the constraint terms
    data = pyhf.tensorlib.concatenate(
        [data, pyhf.tensorlib.astensor(m.config.auxdata)]
    )
    return m.logpdf(pars, data)[0]


# evaluate the log-likelihood for given signal yields, main data, and parameters
model(
    pyhf.tensorlib.astensor([12.0, 11.0]),
    pyhf.tensorlib.astensor([50.0, 52.0]),
    pyhf.tensorlib.astensor([1.0, 1.0, 1.0]),
)

# differentiate the log-likelihood with respect to the signal yields
jax.grad(model)(
    pyhf.tensorlib.astensor([12.0, 11.0]),
    pyhf.tensorlib.astensor([50.0, 52.0]),
    pyhf.tensorlib.astensor([1.0, 1.0, 1.0]),
)

@kratsg (Contributor) commented Jan 18, 2020

> The JSON structure from which the model is constructed is essentially a collection of array data. If read literally from JSON it is static data; however, that array data might also be the output of other programs, and one would like to have a fully differentiable chain of operations throughout the JSON-like structure (@phinate is looking into this more closely).

Not sure we'd want to do it from JSON; the whole point of JSON is that it implies static data. If you want to build a differentiable model from Python, that's a different thing to explore. Likely what you want to do is create a DifferentiableModel class and work on that instead, and call the existing ones StaticWorkspace / StaticModel classes. Because of duck typing, you only need to keep the same API.
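A rough sketch of that duck-typed split (the class layout here is purely illustrative; pyhf defines neither class, and the differentiable variant simply assumes the spec already contains backend tensors, as in the examples further down this thread):

import pyhf


class StaticModel:
    """Mirror of the existing behaviour: the spec is read once (e.g. from JSON)
    and the yields inside it are plain static lists."""

    def __init__(self, spec):
        self._model = pyhf.Model(spec)

    def logpdf(self, pars, data):
        return self._model.logpdf(pars, data)


class DifferentiableModel:
    """Hypothetical variant: constructed from tensor-valued yields (e.g. the
    output of another differentiable program), so gradients can flow through
    model construction. Duck typing keeps the logpdf API identical."""

    def __init__(self, spec_with_tensor_yields):
        self._model = pyhf.Model(spec_with_tensor_yields)

    def logpdf(self, pars, data):
        return self._model.logpdf(pars, data)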

@kratsg (Contributor) commented Jan 18, 2020

What benefit do you get from a differentiable model apart from being able to dynamically substitute in the array values?

@lukasheinrich (Contributor, Author) commented
Consider:

signal = neural_network(weights)
model = pyhf.Model({"signal": signal, ... })
logpdf = model.logpdf(pars, data)

This will allow you to compute the gradient dlogpdf / dweights.
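As an end-to-end sketch of that idea on this branch (neural_network here is a hypothetical stand-in for any differentiable map from weights to per-bin signal yields; the rest mirrors the hepdata_like example in the PR description):

import jax
import jax.numpy as jnp
import pyhf

pyhf.set_backend(pyhf.tensor.jax_backend())


def neural_network(weights):
    # hypothetical placeholder: strictly positive yields for two bins
    return jnp.exp(weights)


def logpdf_from_weights(weights, pars, data):
    signal = neural_network(weights)
    m = pyhf.simplemodels.hepdata_like(
        signal_data=signal,
        bkg_data=pyhf.tensorlib.astensor([50.0, 52.0]),
        bkg_uncerts=pyhf.tensorlib.astensor([3.0, 7.0]),
    )
    full_data = pyhf.tensorlib.concatenate(
        [data, pyhf.tensorlib.astensor(m.config.auxdata)]
    )
    return m.logpdf(pars, full_data)[0]


# gradient of the log-likelihood with respect to the network weights
dlogpdf_dweights = jax.grad(logpdf_from_weights)(
    jnp.array([2.5, 2.4]),
    jnp.array([1.0, 1.0, 1.0]),
    jnp.array([50.0, 52.0]),
)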

@lgtm-com bot commented Jul 28, 2020

This pull request introduces 1 alert when merging 5e57400 into 60488cd - view on LGTM.com

new alerts:

  • 1 for Unused import

@phinate (Contributor) commented Jul 29, 2020

@kratsg A MWE of this failing:

import pyhf
pyhf.set_backend('jax')

def from_spec(yields):

    s, b, bup, bdown = yields

    spec = {
        "channels": [
            {
                "name": "nn",
                "samples": [
                    {
                        "name": "signal",
                        "data": s,
                        "modifiers": [
                            {"name": "mu", "type": "normfactor", "data": None}
                        ],
                    },
                    {
                        "name": "bkg",
                        "data": b,
                        "modifiers": [
                            {
                                "name": "nn_histosys",
                                "type": "histosys",
                                "data": {
                                    "lo_data": bdown,
                                    "hi_data": bup,
                                },
                            }
                        ],
                    },
                ],
            },
        ],
    }

    return pyhf.Model(spec)

import jax.numpy as jnp

y = [jnp.array([ 5.,  9.,  4.]),
     jnp.array([23., 46., 23.]),
     jnp.array([24., 46., 22.]),
     jnp.array([25., 46., 22.])]

from_spec(y)

ValueError                                Traceback (most recent call last)
<ipython-input> in <module>
      4     jnp.array([24., 46., 22.]),
      5     jnp.array([25., 46., 22.])]
----> 6 from_spec(y)

<ipython-input> in from_spec(yields)
     34         }
     35 
---> 36         return pyhf.Model(spec)

~/neos/pyhf/src/pyhf/pdf.py in __init__(self, spec, batch_size, **config_kwargs)
    590         self.config = _ModelConfig(self.spec, **config_kwargs)
    591 
--> 592         mega_mods, _nominal_rates = _nominal_and_modifiers_from_spec(
    593             self.config, self.spec
    594         )

~/neos/pyhf/src/pyhf/pdf.py in _nominal_and_modifiers_from_spec(config, spec)
    153                 else [0.0] * config.channel_nbins[c]
    154             )
--> 155             mega_nom += nom
    156             defined_mods = (
    157                 {

ValueError: operands could not be broadcast together with shapes (0,) (3,)

In a more complicated setting (taking gradients), it failed with a different error when trying to concatenate 1-d tensors in _create_nominal_and_modifiers, e.g. here:

final_mods[key][s]['data']['lo_data'] = tensorlib.concatenate(

List concatenation by + is something I think this PR was meant to factor out? (It doesn't extend to tensors.)
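For illustration, a minimal standalone sketch of the semantics at play (outside pyhf; the variable names just echo the traceback above):

import numpy as np

# with plain Python lists, += concatenates
mega_nom = []
mega_nom += [0.0, 0.0, 0.0]          # -> [0.0, 0.0, 0.0]

# once the accumulator or the sample data is an array, + becomes
# element-wise addition, and a shape-(0,) array cannot be broadcast
# against a shape-(3,) one
mega_nom = np.asarray([])
nom = np.asarray([23.0, 46.0, 23.0])
try:
    mega_nom = mega_nom + nom
except ValueError as err:
    print(err)  # operands could not be broadcast together with shapes (0,) (3,)

# concatenating explicitly works for lists and tensors alike
mega_nom = np.concatenate([np.asarray([]), nom])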

@kratsg (Contributor) commented Jul 30, 2020

> List concatenation by + is something I think this PR was meant to factor out? (It doesn't extend to tensors.)

Should be fixed now. I also checked the validity of the results:

(pyhf-dev) Lord Stark:~/pyhf (diffable_json)$ python run.py 
[ 3.19467733 -0.14385454]
(pyhf-dev) Lord Stark:~/pyhf (master)$ python run.py 
[ 3.19467733 -0.14385454]

So things are at least consistent (checked in numpy; scipy.optimize.minimize hates jax for some reason in the diffable branch when you pass in jax arrays for the inits/bounds).

@lgtm-com bot commented Jul 30, 2020

This pull request introduces 1 alert when merging 3c8aec8 into 5fcf95b - view on LGTM.com

new alerts:

  • 1 for Unused import

@phinate (Contributor) commented Jul 30, 2020

Thanks for the fix, that works great!

Now we're dealing with the same errors I encountered when I forked this originally: default_backend being numpy, even when overridden by pyhf.default_backend = pyhf.tensor.jax_backend(precision='64b'). (I was doing this directly in the relevant modules before giving up and explicitly calling tensor_backend.)

A MWE for this use case, similar to the above:

import pyhf
pyhf.set_backend('jax')
pyhf.default_backend = pyhf.tensor.jax_backend(precision='64b')

Define a model that depends on a data-altering parameter, and return an eval of logpdf:

import jax.numpy as jnp

def from_spec(param):

    yields = jnp.array([[ 5.,  9.,  4.],
                        [23., 46., 23.],
                        [24., 46., 22.],
                        [25., 46., 22.]])

    s, b, bup, bdown = yields * param

    spec = {
        "channels": [
            {
                "name": "nn",
                "samples": [
                    {
                        "name": "signal",
                        "data": s,
                        "modifiers": [
                            {"name": "mu", "type": "normfactor", "data": None}
                        ],
                    },
                    {
                        "name": "bkg",
                        "data": b,
                        "modifiers": [
                            {
                                "name": "nn_histosys",
                                "type": "histosys",
                                "data": {
                                    "lo_data": bdown,
                                    "hi_data": bup,
                                },
                            }
                        ],
                    },
                ],
            },
        ],
    }

    pars = jnp.array([1.0, 1.0])
    data = jnp.array([25., 50., 23., 2.])
    return pyhf.Model(spec).logpdf(pars, data)[0]

We can then try to take the gradient of logpdf with respect to param:

from jax import value_and_grad

value_and_grad(from_spec)(4.)

This throws:


~/neos/pyhf/src/pyhf/interpolators/code0.py in __init__(self, histogramssets, subscribe)
     27         """Piecewise-linear Interpolation."""
     28         # nb: this should never be a tensor, store in default backend (e.g. numpy)
---> 29         self._histogramssets = default_backend.astensor(histogramssets)

~/neos/pyhf/src/pyhf/tensor/numpy_backend.py in astensor(self, tensor_in, dtype)
    169             raise
    170 
--> 171         return np.asarray(tensor_in, dtype=dtype)
    172 
    173     def sum(self, tensor_in, axis=None):

~/neos/env/lib/python3.8/site-packages/numpy/core/_asarray.py in asarray(a, dtype, order)
     81 
     82     """
---> 83     return array(a, dtype, copy=False, order=order)
     84 
     85 

~/neos/env/lib/python3.8/site-packages/jax/core.py in __array__(self, *args, **kw)
    448            "JAX Tracer instance; in that case, you can instead write "
    449            "`jax.device_put(x)[idx]`.")
--> 450     raise Exception(msg)

Exception: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Tracedwith
  with primal = DeviceArray([100., 184.,  88.], dtype=float64)
       tangent = Traced.

This error can occur when a JAX Tracer object is passed to a raw numpy function, or a method on a numpy.ndarray object. You might want to check that you are using `jnp` together with `import jax.numpy as jnp` rather than using `np` via `import numpy as np`. If this error arises on a line that involves array indexing, like `x[idx]`, it may be that the array being indexed `x` is a raw numpy.ndarray while the indices `idx` are a JAX Tracer instance; in that case, you can instead write `jax.device_put(x)[idx]`.

So theoretically, overriding the default backend should have worked, but numpy is still being used here. Any thoughts on why this behaviour is occurring? I never figured it out.

@kratsg (Contributor) commented Jul 31, 2020

You're hitting this limitation in jax.

[screenshot (2020-07-30) illustrating the JAX limitation]

@kratsg (Contributor) commented Jul 31, 2020

The following code works for me now:

# gotta patch stuff
import sys
from unittest.mock import patch
# let's get started
import pyhf
jax_backend = pyhf.tensor.jax_backend(precision='64b')
pyhf.set_backend(jax_backend)

import jax.numpy as jnp

@patch('pyhf.default_backend', new=jax_backend)
@patch.object(sys.modules['pyhf.interpolators.code0'], 'default_backend', new=jax_backend)
@patch.object(sys.modules['pyhf.interpolators.code1'], 'default_backend', new=jax_backend)
@patch.object(sys.modules['pyhf.interpolators.code2'], 'default_backend', new=jax_backend)
@patch.object(sys.modules['pyhf.interpolators.code4'], 'default_backend', new=jax_backend)
@patch.object(sys.modules['pyhf.interpolators.code4p'], 'default_backend', new=jax_backend)
@patch.object(sys.modules['pyhf.modifiers.shapefactor'], 'default_backend', new=jax_backend)
@patch.object(sys.modules['pyhf.modifiers.shapesys'], 'default_backend', new=jax_backend)
@patch.object(sys.modules['pyhf.modifiers.staterror'], 'default_backend', new=jax_backend)
def from_spec(param):
...
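
# The per-module patching is needed because those modules bind default_backend
# at import time, so rebinding pyhf.default_backend afterwards never reaches
# them. Roughly (the import line below is paraphrased, not copied from pyhf):
#
#     # inside e.g. pyhf/interpolators/code0.py, at import time, effectively:
#     from pyhf import default_backend      # binds the numpy backend object
#
#     # later, in user code:
#     import pyhf
#     pyhf.default_backend = pyhf.tensor.jax_backend(precision='64b')
#
# this rebinds the name in the pyhf package namespace only; code0.py still
# holds its own reference to the old numpy backend, which is why its
# astensor() call chokes on JAX tracers and why each module's default_backend
# has to be patched individually.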

@kratsg (Contributor) commented Jul 31, 2020

For the jax issue that I temporarily patched, I filed the report here: jax-ml/jax#3919

@lgtm-com bot commented Jul 31, 2020

This pull request introduces 1 alert when merging 81ee47a into 5fcf95b - view on LGTM.com

new alerts:

  • 1 for Unused import

@matthewfeickert (Member) commented
> For the jax issue that I temporarily patched, I filed the report here: google/jax#3919

...which has already been fixed by the blazing fast @jakevdp! :) This just missed being in JAX v0.1.75 (it was literally the commit on master right after the release), but that means it should be in v0.1.76, and JAX releases on a pretty impressively fast basis. I'm watching JAX releases, so I'll make a note to update the minimum versions of jax and jaxlib in setup.py

pyhf/setup.py, line 10 at 5fcf95b:

    'jax': ['jax~=0.1,>0.1.51', 'jaxlib~=0.1,>0.1.33'],

the next time they release.

@matthewfeickert added the feat/enhancement (New feature or request) label, Jul 31, 2020
@phinate (Contributor) commented Jul 31, 2020

Awesome!! Thanks guys (and @jakevdp), that was the quickest turnaround for an external bug I’ve ever seen...

@phinate (Contributor) commented Jul 31, 2020

[image: kde_working — neos output from this branch]

Now neos works directly from this branch :)

@phinate mentioned this pull request, Oct 15, 2021
@matthewfeickert changed the base branch from master to main, September 21, 2022