Skip to content

Commit

Permalink
Add docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
aphearin committed Feb 25, 2025
1 parent 0fe5e2f commit 584219c
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 4 deletions.
59 changes: 57 additions & 2 deletions diffsky/cosmos_utils/cosmos_mstar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,29 @@
def load_cosmos20_tdata(
phot_keys=COSMOS_PHOT_KEYS, zlo=0.5, zhi=2.5, sm_key="lp_mass_best", **kwargs
):
"""Load the COSMOS-20 dataset used to define training data for the scaling relation.
Parameters
----------
phot_keys : list of strings
zlo, zhi : floats
sm_key : string
Returns
-------
photdata : namedtuple
Fields defined by PhotData at module top
Fields contain only the minimum information to train the approximate model
Notes
-----
This function has the following dependency:
https://github.com/LSSTDESC/cosmos20_colors
"""
try:
from cosmos20_colors import load_cosmos20
except (ImportError, ModuleNotFoundError):
Expand Down Expand Up @@ -67,6 +90,21 @@ def load_cosmos20_tdata(

@jjit
def predict_logsm(params, photdata):
"""Approximate model of LePhare stellar mass from COSMOS griz photometry
Parameters
----------
params : namedtuple
Fields defined by DEFAULT_PARAMS at top of module
photdata: namedtuple
Fields defined by PhotData at top of module
Returns
-------
logsm : ndarray, shape (n, )
"""
logsm = (
params.b0
+ params.i * photdata.i
Expand All @@ -91,16 +129,33 @@ def _mae(x, y):


@jjit
def loss_kern(params, photdata):
def _loss_kern(params, photdata):
pred = predict_logsm(params, photdata)
target = photdata.logsm
return _mae(pred, target)


loss_and_grad_func = jjit(value_and_grad(loss_kern))
loss_and_grad_func = jjit(value_and_grad(_loss_kern))


def fit_model(n_steps, photdata, params_init=DEFAULT_PARAMS, step_size=0.001):
"""Find best-fitting parameters for the approximate stellar mass model
Parameters
----------
n_steps : int
photdata: namedtuple
Fields defined by PhotData at top of module
Returns
-------
best_fit_params : namedtuple
Fields defined by DEFAULT_PARAMS at top of module
loss_arr : ndarray, shape (n_steps, )
"""
loss_collector = []
opt_init, opt_update, get_params = jax_opt.adam(step_size)
opt_state = opt_init(params_init)
Expand Down
4 changes: 2 additions & 2 deletions diffsky/cosmos_utils/tests/test_cosmos_mstar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ def test_fit_model():
photdata = get_fake_cosmos_data(ran_key)

logsm = cmm.predict_logsm(cmm.DEFAULT_PARAMS, photdata)
loss_init = cmm.loss_kern(cmm.DEFAULT_PARAMS, photdata)
loss_init = cmm._loss_kern(cmm.DEFAULT_PARAMS, photdata)
assert np.all(np.isfinite(logsm))
assert np.all(np.isfinite(loss_init))

p_best, loss_arr = cmm.fit_model(100, photdata, step_size=0.01)

loss_best = cmm.loss_kern(p_best, photdata)
loss_best = cmm._loss_kern(p_best, photdata)
assert loss_best < loss_init

0 comments on commit 584219c

Please sign in to comment.