Skip to content

Commit

Permalink
Merge pull request #91 from ArgonneCPAC/zero_effect_params
Browse files Browse the repository at this point in the history
Introduce ZERODUST_AVPOP_PARAMS for `avpop_mono.py`
  • Loading branch information
aphearin authored Feb 14, 2025
2 parents dc09d21 + 2eb64ee commit 6737086
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 3 deletions.
10 changes: 10 additions & 0 deletions diffsky/burstpop/fburstpop.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""

from collections import OrderedDict, namedtuple
from copy import deepcopy

from jax import jit as jjit
from jax import numpy as jnp
Expand Down Expand Up @@ -47,6 +48,15 @@
DEFAULT_FBURSTPOP_PARAMS = FburstPopParams(**DEFAULT_FBURSTPOP_PDICT)
FBURSTPOP_PBOUNDS = FburstPopParams(**FBURSTPOP_BOUNDS_PDICT)

_EPS = 0.1
ZEROBURST_FBURSTPOP_PARAMS = deepcopy(DEFAULT_FBURSTPOP_PARAMS)
ZEROBURST_FBURSTPOP_PARAMS = ZEROBURST_FBURSTPOP_PARAMS._replace(
lgfburst_logsm_ylo_q=_LGFBURST_BOUNDS[0] + _EPS,
lgfburst_logsm_ylo_ms=_LGFBURST_BOUNDS[0] + _EPS,
lgfburst_logsm_yhi_q=_LGFBURST_BOUNDS[0] + _EPS,
lgfburst_logsm_yhi_ms=_LGFBURST_BOUNDS[0] + _EPS,
)


@jjit
def get_lgfburst_from_fburstpop_u_params(fburstpop_u_params, logsm, logssfr):
Expand Down
10 changes: 10 additions & 0 deletions diffsky/burstpop/freqburst_mono.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""

from collections import OrderedDict, namedtuple
from copy import deepcopy

from dsps.utils import _inverse_sigmoid, _sigmoid
from jax import jit as jjit
Expand Down Expand Up @@ -46,6 +47,15 @@
DEFAULT_FREQBURST_PARAMS = FreqburstParams(**DEFAULT_FREQBURST_PDICT)
FREQBURST_PBOUNDS = FreqburstParams(**FREQBURST_PBOUNDS_PDICT)

_EPS = 0.2
ZEROBURST_FREQBURST_PARAMS = deepcopy(DEFAULT_FREQBURST_PARAMS)
ZEROBURST_FREQBURST_PARAMS = ZEROBURST_FREQBURST_PARAMS._replace(
sufqb_logsm_ylo_q=U_BOUNDS[0] + _EPS,
sufqb_logsm_ylo_ms=U_BOUNDS[0] + _EPS,
sufqb_logsm_yhi_q=U_BOUNDS[0] + _EPS,
sufqb_logsm_yhi_ms=U_BOUNDS[0] + _EPS,
)


@jjit
def double_sigmoid_monotonic(u_params, x, y, x0, y0, xk, yk, z_bounds):
Expand Down
39 changes: 36 additions & 3 deletions diffsky/burstpop/tests/test_fburstpop.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""
"""

import numpy as np
from jax import random as jran

from ..fburstpop import (
DEFAULT_FBURSTPOP_PARAMS,
DEFAULT_FBURSTPOP_U_PARAMS,
FBURSTPOP_PBOUNDS,
ZEROBURST_FBURSTPOP_PARAMS,
get_bounded_fburstpop_params,
get_lgfburst_from_fburstpop_params,
get_lgfburst_from_fburstpop_u_params,
Expand All @@ -15,6 +18,14 @@
TOL = 1e-2


def test_default_params_are_in_bounds():

gen = zip(DEFAULT_FBURSTPOP_PARAMS, DEFAULT_FBURSTPOP_PARAMS._fields)
for val, key in gen:
bound = getattr(FBURSTPOP_PBOUNDS, key)
assert bound[0] < val < bound[1]


def test_param_u_param_names_propagate_properly():
gen = zip(DEFAULT_FBURSTPOP_U_PARAMS._fields, DEFAULT_FBURSTPOP_PARAMS._fields)
for u_key, key in gen:
Expand Down Expand Up @@ -61,9 +72,7 @@ def test_get_lgfburst_from_fburstpop_u_params_fails_when_passing_params():

try:
get_lgfburst_from_fburstpop_u_params(DEFAULT_FBURSTPOP_PARAMS, logsm, logssfr)
raise NameError(
"get_lgfburst_from_fburstpop_u_params should not accept params"
)
raise NameError("get_lgfburst_from_fburstpop_u_params should not accept params")
except AttributeError:
pass

Expand Down Expand Up @@ -112,3 +121,27 @@ def test_get_bursty_age_weights_pop_u_param_inversion():
assert np.all(np.isfinite(gal_lgfburst_u))

assert np.allclose(gal_lgfburst, gal_lgfburst_u, rtol=1e-4)


def test_zeroburst_params_are_in_bounds():

gen = zip(ZEROBURST_FBURSTPOP_PARAMS, ZEROBURST_FBURSTPOP_PARAMS._fields)
for val, key in gen:
bound = getattr(FBURSTPOP_PBOUNDS, key)
assert bound[0] < val < bound[1]


def test_zeroburst_params_produce_zero_burstiness():
ran_key = jran.key(0)
sm_key, ssfr_key = jran.split(ran_key, 2)
n_gals = 2_000
logsmarr = jran.uniform(sm_key, minval=5, maxval=13, shape=(n_gals,))
logssfr = jran.uniform(ssfr_key, minval=-14, maxval=-5, shape=(n_gals,))

fb = 10 ** get_lgfburst_from_fburstpop_params(
ZEROBURST_FBURSTPOP_PARAMS, logsmarr, logssfr
)
assert fb.shape == (n_gals,)
assert np.all(np.isfinite(fb))
assert np.all(fb > 0.0)
assert np.all(fb < 1e-4)
46 changes: 46 additions & 0 deletions diffsky/burstpop/tests/test_freqburst_mono.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from ..freqburst_mono import (
DEFAULT_FREQBURST_PARAMS,
DEFAULT_FREQBURST_U_PARAMS,
FREQBURST_PBOUNDS,
SUFQB_BOUNDS,
ZEROBURST_FREQBURST_PARAMS,
FreqburstUParams,
get_bounded_freqburst_params,
get_freqburst_from_freqburst_params,
Expand All @@ -20,6 +22,14 @@
EPSILON = 1e-5


def test_default_params_are_in_bounds():

gen = zip(DEFAULT_FREQBURST_PARAMS, DEFAULT_FREQBURST_PARAMS._fields)
for val, key in gen:
bound = getattr(FREQBURST_PBOUNDS, key)
assert bound[0] < val < bound[1]


def test_param_u_param_names_propagate_properly():
gen = zip(DEFAULT_FREQBURST_U_PARAMS._fields, DEFAULT_FREQBURST_PARAMS._fields)
for u_key, key in gen:
Expand Down Expand Up @@ -160,3 +170,39 @@ def test_get_freqburst_from_freqburst_params_is_monotonic_with_logsm_and_logssfr
assert np.all(fqb >= 0.0)
assert np.all(fqb <= fqb_max)
assert np.all(np.diff(fqb) >= -EPSILON)


def test_zeroburst_params_are_in_bounds():

gen = zip(ZEROBURST_FREQBURST_PARAMS, ZEROBURST_FREQBURST_PARAMS._fields)
for val, key in gen:
bound = getattr(FREQBURST_PBOUNDS, key)
assert bound[0] < val < bound[1]


def test_zeroburst_params_are_invertible():
u_params = get_unbounded_freqburst_params(ZEROBURST_FREQBURST_PARAMS)
params = get_bounded_freqburst_params(u_params)
for p, p_orig in zip(ZEROBURST_FREQBURST_PARAMS, params):
assert np.all(np.isfinite(p))
assert np.allclose(p, p_orig, rtol=1e-4)


def test_zeroburst_params_produce_zero_burstiness():
ran_key = jran.key(0)
sm_key, ssfr_key = jran.split(ran_key, 2)
n_gals = 2_000
logsmarr = jran.uniform(sm_key, minval=5, maxval=13, shape=(n_gals,))
logssfr = jran.uniform(ssfr_key, minval=-14, maxval=-5, shape=(n_gals,))

fqb = get_freqburst_from_freqburst_params(
ZEROBURST_FREQBURST_PARAMS, logsmarr, logssfr
)
assert fqb.shape == (n_gals,)
assert np.all(np.isfinite(fqb))
assert np.all(fqb >= 0.0)
assert np.all(fqb < 0.01)

sufq_max = SUFQB_BOUNDS[1]
fqb_max = nn.softplus(sufq_max)
assert np.all(fqb <= fqb_max)
16 changes: 16 additions & 0 deletions diffsky/dustpop/avpop_mono.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""

from collections import OrderedDict, namedtuple
from copy import deepcopy

from jax import jit as jjit
from jax import nn
Expand Down Expand Up @@ -66,6 +67,21 @@
DEFAULT_AVPOP_PARAMS = AvPopParams(**DEFAULT_AVPOP_PDICT)
AVPOP_PBOUNDS = AvPopParams(**AVPOP_PBOUNDS_PDICT)

_EPS = 0.2
_EPS2 = 0.05
ZERODUST_AVPOP_PARAMS = deepcopy(DEFAULT_AVPOP_PARAMS)
ZERODUST_AVPOP_PARAMS = ZERODUST_AVPOP_PARAMS._replace(
suav_logsm_ylo_q_z_ylo=U_BOUNDS[0] + _EPS,
suav_logsm_ylo_ms_z_ylo=U_BOUNDS[0] + _EPS,
suav_logsm_yhi_q_z_ylo=U_BOUNDS[0] + _EPS,
suav_logsm_yhi_ms_z_ylo=U_BOUNDS[0] + _EPS,
suav_logsm_ylo_q_z_yhi=U_BOUNDS[0] + _EPS,
suav_logsm_ylo_ms_z_yhi=U_BOUNDS[0] + _EPS,
suav_logsm_yhi_q_z_yhi=U_BOUNDS[0] + _EPS,
suav_logsm_yhi_ms_z_yhi=U_BOUNDS[0] + _EPS,
delta_suav_age=DELTA_SUAV_AGE_BOUNDS[0] + _EPS2,
)


@jjit
def double_sigmoid_monotonic(u_params, x, y, x0, y0, xk, yk, z_bounds):
Expand Down
33 changes: 33 additions & 0 deletions diffsky/dustpop/tests/test_avpop_mono.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
DEFAULT_AVPOP_U_PARAMS,
DELTA_SUAV_AGE_BOUNDS,
SUAV_BOUNDS,
ZERODUST_AVPOP_PARAMS,
AvPopUParams,
get_av_from_avpop_params_galpop,
get_av_from_avpop_params_singlegal,
Expand Down Expand Up @@ -244,3 +245,35 @@ def test_get_av_is_always_within_bounds_for_all_u_params():

assert np.all(av >= av_min), (av.min(), av_min)
assert np.all(av <= av_max), (av.max(), av_max)


def test_zerodust_params_are_in_bounds():

gen = zip(ZERODUST_AVPOP_PARAMS, ZERODUST_AVPOP_PARAMS._fields)
for val, key in gen:
bound = getattr(AVPOP_PBOUNDS, key)
assert bound[0] < val < bound[1]


def test_zerodust_params_are_invertible():
u_params = get_unbounded_avpop_params(ZERODUST_AVPOP_PARAMS)
params = get_bounded_avpop_params(u_params)
for p, p_orig in zip(ZERODUST_AVPOP_PARAMS, params):
assert np.all(np.isfinite(p))
assert np.allclose(p, p_orig, rtol=1e-4)


def test_av_is_finite_and_tiny_for_zerodust_params():
ran_key = jran.PRNGKey(0)
logsm_key, logssfr_key, z_key = jran.split(ran_key, 3)
n_gals = 500
logsm = jran.uniform(logsm_key, minval=5, maxval=13, shape=(n_gals,))
logssfr = jran.uniform(logssfr_key, minval=-14, maxval=-6, shape=(n_gals,))
redshift = jran.uniform(z_key, minval=0, maxval=10, shape=(n_gals,))

av = get_av_from_avpop_params_galpop(
ZERODUST_AVPOP_PARAMS, logsm, logssfr, redshift, LGAGE_GYR
)
assert av.shape == (n_gals, N_AGE)
assert np.all(np.isfinite(av))
assert np.all(av < 0.05)

0 comments on commit 6737086

Please sign in to comment.