Skip to content

Commit

Permalink
Merge pull request #102 from ArgonneCPAC/mc_centrals
Browse files Browse the repository at this point in the history
Add Monte Carlo generator of central galaxies
  • Loading branch information
aphearin authored Mar 4, 2025
2 parents a79c73f + e1c08bc commit 71d137c
Show file tree
Hide file tree
Showing 6 changed files with 230 additions and 28 deletions.
82 changes: 73 additions & 9 deletions diffsky/mass_functions/mc_diffmah_tpeak.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from diffmah.diffmahpop_kernels.mc_bimod_cens import mc_cenpop
from diffmah.diffmahpop_kernels.mc_bimod_sats import mc_satpop
from dsps.cosmology.defaults import DEFAULT_COSMOLOGY
from dsps.cosmology.flat_wcdm import _age_at_z_kern
from dsps.cosmology.flat_wcdm import _age_at_z_kern, age_at_z0
from jax import random as jran

from .mc_hosts import mc_host_halos_singlez
Expand Down Expand Up @@ -37,11 +37,11 @@

def mc_subhalos(
ran_key,
z_obs,
lgmp_min,
redshift,
volume_com=None,
hosts_logmh_at_z=None,
cosmo=DEFAULT_COSMOLOGY,
cosmo_params=DEFAULT_COSMOLOGY,
diffmahpop_params=DEFAULT_DIFFMAHPOP_PARAMS,
):
"""Monte Carlo realization of a subhalo catalog at a single redshift
Expand All @@ -55,7 +55,7 @@ def mc_subhalos(
Smaller values of lgmp_min produce more halos in the returned sample
A small fraction of halos will have slightly smaller masses than lgmp_min
redshift : float
z_obs : float
Redshift of the halo population
volume_com : float, optional
Expand Down Expand Up @@ -125,17 +125,15 @@ def mc_subhalos(
if hosts_logmh_at_z is None:
msg = "Must pass volume_com argument if not passing hosts_logmh_at_z"
assert volume_com is not None, msg
hosts_logmh_at_z = mc_host_halos_singlez(
host_key1, lgmp_min, redshift, volume_com
)
hosts_logmh_at_z = mc_host_halos_singlez(host_key1, lgmp_min, z_obs, volume_com)

subhalo_info = generate_subhalopop(sub_key1, hosts_logmh_at_z, lgmp_min)
subs_lgmu, subs_lgmhost, subs_host_halo_indx = subhalo_info

subs_logmh_at_z = subs_lgmu + subs_lgmhost

t_obs = _age_at_z_kern(redshift, *cosmo)
t_0 = _age_at_z_kern(0.0, *cosmo)
t_obs = _age_at_z_kern(z_obs, *cosmo_params)
t_0 = _age_at_z_kern(0.0, *cosmo_params)
lgt0 = np.log10(t_0)

n_cens = hosts_logmh_at_z.size
Expand Down Expand Up @@ -206,3 +204,69 @@ def mc_subhalos(
ult_host_indx,
)
return subcat


def mc_host_halos(
ran_key,
z_obs,
lgmp_min=None,
volume_com=None,
hosts_logmh_at_z=None,
cosmo_params=DEFAULT_COSMOLOGY,
diffmahpop_params=DEFAULT_DIFFMAHPOP_PARAMS,
):
"""Monte Carlo realization of a subhalo catalog at a single redshift"""
host_key1, host_key2 = jran.split(ran_key, 2)
if hosts_logmh_at_z is None:
msg = "Must pass volume_com argument if not passing hosts_logmh_at_z"
assert volume_com is not None, msg
msg = "Must pass lgmp_min argument if not passing hosts_logmh_at_z"
assert lgmp_min is not None, msg

hosts_logmh_at_z = mc_host_halos_singlez(host_key1, lgmp_min, z_obs, volume_com)

t_obs = _age_at_z_kern(z_obs, *cosmo_params)
t_0 = age_at_z0(*cosmo_params)
lgt0 = np.log10(t_0)

n_cens = hosts_logmh_at_z.size
halo_ids = np.arange(n_cens).astype(int)
_ZH = np.zeros(n_cens)

tarr = np.zeros(1) + 10**lgt0
mah_params = mc_cenpop(
diffmahpop_params, tarr, hosts_logmh_at_z, t_obs + _ZH, host_key2, lgt0
)[0]

host_mah_params = mah_params
lgmhost_pen_inf = np.copy(hosts_logmh_at_z)
lgmhost_ult_inf = np.copy(hosts_logmh_at_z)
t_pen_inf = mah_params.t_peak
t_ult_inf = mah_params.t_peak
upids = _ZH - 1
pen_host_indx = np.arange(n_cens)
ult_host_indx = np.arange(n_cens)

logmp0 = _log_mah_kern(mah_params, 10**lgt0, lgt0)
lgmp_t_obs = _log_mah_kern(mah_params, t_obs, lgt0)
lgmp_pen_inf = _log_mah_kern(mah_params, mah_params.t_peak, lgt0)
lgmp_ult_inf = _log_mah_kern(mah_params, mah_params.t_peak, lgt0)

subcat = SubhaloCatalog(
halo_ids,
mah_params,
host_mah_params,
logmp0,
lgmp_t_obs,
lgmp_pen_inf,
lgmp_ult_inf,
lgmhost_pen_inf,
lgmhost_ult_inf,
t_obs,
t_pen_inf,
t_ult_inf,
upids,
pen_host_indx,
ult_host_indx,
)
return subcat
26 changes: 24 additions & 2 deletions diffsky/mass_functions/tests/test_mc_diffmah_tpeak.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_mc_subhalo_catalog_singlez():
redshift = 0.5
Lbox = 25.0
volume_com = Lbox**3
args = ran_key, lgmp_min, redshift, volume_com
args = ran_key, redshift, lgmp_min, volume_com

subcat = mcd.mc_subhalos(*args)
for x in subcat:
Expand Down Expand Up @@ -67,7 +67,29 @@ def test_mc_subhalo_catalog_input_logmh_grid():

n_hosts = 250
hosts_logmh_at_z = np.linspace(lgmp_min, 15, n_hosts)
args = ran_key, lgmp_min, redshift
args = ran_key, redshift, lgmp_min
subcat = mcd.mc_subhalos(*args, hosts_logmh_at_z=hosts_logmh_at_z)
for x in subcat:
assert np.all(np.isfinite(x))


def test_mc_host_halos():

ran_key = jran.PRNGKey(0)
redshift = 0.5
Lbox = 25.0
args = ran_key, redshift

subcat = mcd.mc_host_halos(*args, lgmp_min=11, volume_com=Lbox**3)
for x in subcat:
assert np.all(np.isfinite(x))

n_gals = subcat.logmp_pen_inf.size
assert subcat.logmp_pen_inf.shape == (n_gals,)
for mah_p in subcat.mah_params:
assert mah_p.shape == (n_gals,)

n_cens = 200
hosts_logmh_at_z = np.linspace(10, 15, n_cens)
subcat = mcd.mc_host_halos(*args, hosts_logmh_at_z=hosts_logmh_at_z)
assert subcat.logmp0.size == n_cens
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def load_mc_halo_cat(seed=0):
randkey = jax.random.key(seed)

# Perform initial MC generation slightly below LGMP_MIN
raw_cat = mc_subhalos(randkey, LGMP_MIN - 0.2, Z_OBS, VOLUME)
raw_cat = mc_subhalos(randkey, Z_OBS, LGMP_MIN - 0.2, VOLUME)

cut = raw_cat.logmp_t_obs[raw_cat.ult_host_indx] >= LGMP_MIN
return recursive_namedtuple_cut_and_reindex(raw_cat, cut)
Expand Down
109 changes: 102 additions & 7 deletions diffsky/mc_diffsky.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@
from jax import random as jran
from jax import vmap

from .mass_functions.mc_diffmah_tpeak import mc_subhalos
from .mass_functions.mc_diffmah_tpeak import mc_host_halos, mc_subhalos

N_T = 100

_interp_vmap_single_t_obs = jjit(vmap(jnp.interp, in_axes=(None, None, 0)))


def mc_diffstar_galhalo_pop(
def mc_diffstar_galpop(
ran_key,
lgmp_min,
z_obs,
lgmp_min,
volume_com=None,
hosts_logmh_at_z=None,
cosmo_params=DEFAULT_COSMOLOGY,
Expand All @@ -36,14 +36,14 @@ def mc_diffstar_galhalo_pop(
----------
ran_key : jran.PRNGKey
z_obs : float
Redshift of the halo population
lgmp_min : float
Base-10 log of the halo mass competeness limit of the generated population
Smaller values of lgmp_min produce more halos in the returned sample
A small fraction of halos will have slightly smaller masses than lgmp_min
redshift : float
Redshift of the halo population
volume_com : float, optional
volume_com = Lbox**3 where Lbox is in comoving in units of Mpc/h
Default is None, in which case argument hosts_logmh_at_z must be passed
Expand All @@ -68,11 +68,106 @@ def mc_diffstar_galhalo_pop(

subcat = mc_subhalos(
mah_key,
z_obs,
lgmp_min,
volume_com=volume_com,
hosts_logmh_at_z=hosts_logmh_at_z,
cosmo_params=cosmo_params,
diffmahpop_params=DEFAULT_DIFFMAHPOP_PARAMS,
)

logmu_infall = subcat.logmp_ult_inf - subcat.logmhost_ult_inf
args = (
diffstarpop_params,
subcat.mah_params,
subcat.logmp0,
logmu_infall,
subcat.logmhost_ult_inf,
subcat.t_ult_inf,
sfh_key,
t_table,
)

_res = mcdsp.mc_diffstar_sfh_galpop(*args)
sfh_ms, sfh_q, frac_q, mc_is_q = _res[2:]
sfh_table = jnp.where(mc_is_q.reshape((-1, 1)), sfh_q, sfh_ms)
smh_table = cumulative_mstar_formed_galpop(t_table, sfh_table)

t_obs = flat_wcdm._age_at_z_kern(z_obs, *cosmo_params)

diffstar_data = dict()
diffstar_data["subcat"] = subcat
diffstar_data["t_table"] = t_table
diffstar_data["t_obs"] = t_obs
diffstar_data["sfh"] = sfh_table
diffstar_data["smh"] = smh_table
diffstar_data["mc_quenched"] = mc_is_q

diffstar_data["logsm_obs"] = _interp_vmap_single_t_obs(
t_obs, t_table, jnp.log10(diffstar_data["smh"])
)
logsfh_obs = _interp_vmap_single_t_obs(
t_obs, t_table, jnp.log10(diffstar_data["sfh"])
)
diffstar_data["logssfr_obs"] = logsfh_obs - diffstar_data["logsm_obs"]

return diffstar_data


def mc_diffstar_cenpop(
ran_key,
z_obs,
lgmp_min=None,
volume_com=None,
hosts_logmh_at_z=None,
cosmo_params=DEFAULT_COSMOLOGY,
diffstarpop_params=DEFAULT_DIFFSTARPOP_PARAMS,
n_t=N_T,
):
"""Generate a population of central galaxies with diffmah MAH and diffstar SFH
Parameters
----------
ran_key : jran.PRNGKey
z_obs : float
Redshift of the halo population
lgmp_min : float
Base-10 log of the halo mass competeness limit of the generated population
Smaller values of lgmp_min produce more halos in the returned sample
A small fraction of halos will have slightly smaller masses than lgmp_min
volume_com : float, optional
volume_com = Lbox**3 where Lbox is in comoving in units of Mpc/h
Default is None, in which case argument hosts_logmh_at_z must be passed
Larger values of volume_com produce more halos in the returned sample
hosts_logmh_at_z : ndarray, optional
Grid of host halo masses at the input redshift.
Default is None, in which case volume_com argument must be passed
and the host halo mass function will be randomly sampled.
Returns
-------
diffsky_data : dict
Diffstar galaxy population
"""

mah_key, sfh_key = jran.split(ran_key, 2)

t0 = flat_wcdm.age_at_z0(*cosmo_params)
t_table = jnp.linspace(T_TABLE_MIN, t0, n_t)

subcat = mc_host_halos(
mah_key,
z_obs,
lgmp_min=lgmp_min,
volume_com=volume_com,
hosts_logmh_at_z=hosts_logmh_at_z,
cosmo=DEFAULT_COSMOLOGY,
cosmo_params=cosmo_params,
diffmahpop_params=DEFAULT_DIFFMAHPOP_PARAMS,
)

Expand Down
31 changes: 27 additions & 4 deletions diffsky/tests/test_mc_diffsky.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,37 @@
from .. import mc_diffsky as mcd


def test_mc_diffstar_galhalo_pop():
def test_mc_diffstar_galpop():
ran_key = jran.key(0)
hosts_logmh_at_z = np.linspace(10, 15, 200)
n_cens_input = 200
hosts_logmh_at_z = np.linspace(10, 15, n_cens_input)
lgmp_min = 11.0
z_obs = 0.01
args = (ran_key, z_obs, lgmp_min)
diffsky_data = mcd.mc_diffstar_galpop(*args, hosts_logmh_at_z=hosts_logmh_at_z)

for p in diffsky_data["subcat"].mah_params:
assert np.all(np.isfinite(p))
assert np.all(np.isfinite(diffsky_data["smh"]))

n_cens_subcat = np.sum(diffsky_data["subcat"].upids == -1)
assert n_cens_subcat == n_cens_input
n_sats_subcat = np.sum(diffsky_data["subcat"].upids != -1)
assert n_sats_subcat > 0

assert diffsky_data["t_obs"] > 13.5


def test_mc_diffstar_cenpop():
ran_key = jran.key(0)
n_cens = 200
hosts_logmh_at_z = np.linspace(10, 15, n_cens)
z_obs = 0.1
args = (ran_key, lgmp_min, z_obs)
diffsky_data = mcd.mc_diffstar_galhalo_pop(*args, hosts_logmh_at_z=hosts_logmh_at_z)
args = (ran_key, z_obs)
diffsky_data = mcd.mc_diffstar_cenpop(*args, hosts_logmh_at_z=hosts_logmh_at_z)

for p in diffsky_data["subcat"].mah_params:
assert np.all(np.isfinite(p))
assert np.all(np.isfinite(diffsky_data["smh"]))

assert diffsky_data["subcat"].logmp0.size == n_cens
8 changes: 3 additions & 5 deletions docs/source/demo_diffmahpop_t_peak.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@
"source": [
"from diffsky import mc_subhalos\n",
"\n",
"lgmp_min = 10.5 # minimum halo mass\n",
"lgmp_min = 11.0 # minimum halo mass\n",
"z_obs = 0.5\n",
"Lbox_com = 200.0 # Mpc/h\n",
"Lbox_com = 100.0 # Mpc/h\n",
"volume_com = Lbox_com**3 \n",
"\n",
"subcat = mc_subhalos(ran_key, lgmp_min, z_obs, volume_com)\n",
"subcat = mc_subhalos(ran_key, z_obs, lgmp_min=lgmp_min, volume_com=volume_com)\n",
"subcat._fields"
]
},
Expand Down Expand Up @@ -109,7 +109,6 @@
"metadata": {},
"outputs": [],
"source": [
"mskm105 = np.abs(subcat.logmp_t_obs - 10.5) < 0.2\n",
"mskm115 = np.abs(subcat.logmp_t_obs - 11.5) < 0.2\n",
"mskm125 = np.abs(subcat.logmp_t_obs - 12.5) < 0.2\n",
"mskm135 = np.abs(subcat.logmp_t_obs - 13.5) < 0.2\n",
Expand All @@ -125,7 +124,6 @@
"mblue = u'#1f77b4' \n",
"mpurple = u'#9467bd' \n",
"for i in range(10):\n",
" __=ax.plot(tarr, 10**log_mah[mskm105][i], lw=0.5, color=mpurple)\n",
" __=ax.plot(tarr, 10**log_mah[mskm115][i], lw=0.5, color=mblue)\n",
" __=ax.plot(tarr, 10**log_mah[mskm125][i], lw=0.5, color=mgreen)\n",
" __=ax.plot(tarr, 10**log_mah[mskm135][i], lw=0.5, color=morange)\n",
Expand Down

0 comments on commit 71d137c

Please sign in to comment.