fix: stitch_pars bug in stitching values other than fixed values (#991)
* Fix bug in stitch_pars to use stitch_with argument
* Refactor stitch_pars into the _make_stitch_pars function
* Add test for stitch_pars
kratsg authored Jul 27, 2020
1 parent 27f35e9 commit 60488cd
Showing 2 changed files with 55 additions and 5 deletions.
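
For context on the first bullet of the commit message: the old closure captured fixed_values and never used its stitch_with argument, so passing alternative values to stitch had no effect. Below is a minimal, self-contained sketch of that behavior using plain Python lists and simple concatenation in place of pyhf tensors and _TensorViewer; the function names are illustrative only.

fixed_values = [10, 40, 50]

def stitch_pars_buggy(pars, stitch_with=fixed_values):
    # Bug: stitch_with is accepted but ignored; the captured fixed_values are always used.
    return list(fixed_values) + list(pars)

def stitch_pars_fixed(pars, stitch_with=fixed_values):
    # Fix: stitch_with is used, defaulting to fixed_values.
    return list(stitch_with) + list(pars)

pars = [20, 30, 60]
print(stitch_pars_buggy(pars, stitch_with=[0, 0, 0]))  # [10, 40, 50, 20, 30, 60] (override ignored)
print(stitch_pars_fixed(pars, stitch_with=[0, 0, 0]))  # [0, 0, 0, 20, 30, 60] (override respected)
print(stitch_pars_fixed(pars))                         # [10, 40, 50, 20, 30, 60] (default unchanged)
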
31 changes: 27 additions & 4 deletions src/pyhf/optimize/common.py
@@ -3,6 +3,31 @@
from ..tensor.common import _TensorViewer


def _make_stitch_pars(tv=None, fixed_values=None):
    """
    Construct a callable to stitch fixed parameter values into the unfixed parameters. See :func:`shim`.

    This is extracted out to be unit-tested for proper behavior.

    If ``tv`` or ``fixed_values`` are not provided, this returns the identity callable.

    Args:
        tv (~pyhf.tensor.common._TensorViewer): tensor viewer instance
        fixed_values (`list`): default set of values to stitch parameters with

    Returns:
        callable (`func`): a callable that takes nuisance parameter values as input
    """
    if tv is None or fixed_values is None:
        return lambda pars, stitch_with=None: pars

    def stitch_pars(pars, stitch_with=fixed_values):
        tb, _ = get_backend()
        return tv.stitch([tb.astensor(stitch_with, dtype='float'), pars])

    return stitch_pars


def _get_tensor_shim():
"""
A shim-retriever to lazy-retrieve the necessary shims as needed.
@@ -96,15 +121,13 @@ def shim(

        tv = _TensorViewer([fixed_idx, variable_idx])
        # NB: this is a closure, tensorlib needs to be accessed at a different point in time
-        def stitch_pars(pars, stitch_with=fixed_values):
-            tb, _ = get_backend()
-            return tv.stitch([tb.astensor(fixed_values, dtype='float'), pars])
+        stitch_pars = _make_stitch_pars(tv, fixed_values)

    else:
        variable_init = init_pars
        variable_bounds = par_bounds
        minimizer_fixed_vals = fixed_vals
-        stitch_pars = lambda pars, stitch_with=None: pars
+        stitch_pars = _make_stitch_pars()

    objective_and_grad = _get_tensor_shim()(
        objective,
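
Conceptually, tv.stitch in the code above scatters two sets of values into one full-length parameter vector: the fixed values go to the fixed indices and the passed-in parameters go to the variable indices. A rough plain-Python equivalent, using the same indices and values as the test added below (the stitch helper here is illustrative, not pyhf API):

def stitch(fixed_idx, variable_idx, fixed_vals, variable_vals):
    # Place each value at its target index in the full parameter vector.
    full = [None] * (len(fixed_idx) + len(variable_idx))
    for idx, val in zip(fixed_idx, fixed_vals):
        full[idx] = val
    for idx, val in zip(variable_idx, variable_vals):
        full[idx] = val
    return full

print(stitch([0, 3, 4], [1, 2, 5], [10, 40, 50], [20, 30, 60]))  # [10, 20, 30, 40, 50, 60]
print(stitch([0, 3, 4], [1, 2, 5], [0, 0, 0], [20, 30, 60]))     # [0, 20, 30, 0, 0, 60]
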
29 changes: 28 additions & 1 deletion tests/test_optim.py
@@ -1,6 +1,7 @@
import pyhf
from pyhf.optimize.mixins import OptimizerMixin
-from pyhf.optimize.common import _get_tensor_shim
+from pyhf.optimize.common import _get_tensor_shim, _make_stitch_pars
+from pyhf.tensor.common import _TensorViewer
import pytest
from scipy.optimize import minimize, OptimizeResult
import iminuit
@@ -409,3 +410,29 @@ def test_get_tensor_shim(monkeypatch):
        _get_tensor_shim()

    assert 'No optimizer shim for fake_backend.' == str(excinfo.value)


def test_stitch_pars(backend):
    tb, _ = backend

    passthrough = _make_stitch_pars()
    pars = ['a', 'b', 1.0, 2.0, object()]
    assert passthrough(pars) == pars

    fixed_idx = [0, 3, 4]
    variable_idx = [1, 2, 5]
    fixed_vals = [10, 40, 50]
    variable_vals = [20, 30, 60]
    tv = _TensorViewer([fixed_idx, variable_idx])
    stitch_pars = _make_stitch_pars(tv, fixed_vals)

    pars = tb.astensor(variable_vals)
    assert tb.tolist(stitch_pars(pars)) == [10, 20, 30, 40, 50, 60]
    assert tb.tolist(stitch_pars(pars, stitch_with=tb.zeros(3))) == [
        0,
        20,
        30,
        0,
        0,
        60,
    ]
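
The same checks the new test makes can also be run directly; a sketch assuming the NumPy backend and a pyhf version that includes this commit:

import pyhf
from pyhf.optimize.common import _make_stitch_pars
from pyhf.tensor.common import _TensorViewer

pyhf.set_backend("numpy")
tb, _ = pyhf.get_backend()

tv = _TensorViewer([[0, 3, 4], [1, 2, 5]])
stitch_pars = _make_stitch_pars(tv, [10, 40, 50])

pars = tb.astensor([20, 30, 60])
print(tb.tolist(stitch_pars(pars)))  # [10.0, 20.0, 30.0, 40.0, 50.0, 60.0]
print(tb.tolist(stitch_pars(pars, stitch_with=tb.zeros(3))))  # [0.0, 20.0, 30.0, 0.0, 0.0, 60.0]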
