diff --git a/diffsky/cosmos_utils/cosmos_mstar_model.py b/diffsky/cosmos_utils/cosmos_mstar_model.py index fd2490f..f24e055 100644 --- a/diffsky/cosmos_utils/cosmos_mstar_model.py +++ b/diffsky/cosmos_utils/cosmos_mstar_model.py @@ -36,6 +36,29 @@ def load_cosmos20_tdata( phot_keys=COSMOS_PHOT_KEYS, zlo=0.5, zhi=2.5, sm_key="lp_mass_best", **kwargs ): + """Load the COSMOS-20 dataset used to define training data for the scaling relation. + + Parameters + ---------- + phot_keys : list of strings + + zlo, zhi : floats + + sm_key : string + + Returns + ------- + photdata : namedtuple + Fields defined by PhotData at module top + Fields contain only the minimum information to train the approximate model + + Notes + ----- + This function has the following dependency: + + https://github.com/LSSTDESC/cosmos20_colors + + """ try: from cosmos20_colors import load_cosmos20 except (ImportError, ModuleNotFoundError): @@ -67,6 +90,21 @@ def load_cosmos20_tdata( @jjit def predict_logsm(params, photdata): + """Approximate model of LePhare stellar mass from COSMOS griz photometry + + Parameters + ---------- + params : namedtuple + Fields defined by DEFAULT_PARAMS at top of module + + photdata: namedtuple + Fields defined by PhotData at top of module + + Returns + ------- + logsm : ndarray, shape (n, ) + + """ logsm = ( params.b0 + params.i * photdata.i @@ -91,16 +129,33 @@ def _mae(x, y): @jjit -def loss_kern(params, photdata): +def _loss_kern(params, photdata): pred = predict_logsm(params, photdata) target = photdata.logsm return _mae(pred, target) -loss_and_grad_func = jjit(value_and_grad(loss_kern)) +loss_and_grad_func = jjit(value_and_grad(_loss_kern)) def fit_model(n_steps, photdata, params_init=DEFAULT_PARAMS, step_size=0.001): + """Find best-fitting parameters for the approximate stellar mass model + + Parameters + ---------- + n_steps : int + + photdata: namedtuple + Fields defined by PhotData at top of module + + Returns + ------- + best_fit_params : namedtuple + Fields defined by DEFAULT_PARAMS at top of module + + loss_arr : ndarray, shape (n_steps, ) + + """ loss_collector = [] opt_init, opt_update, get_params = jax_opt.adam(step_size) opt_state = opt_init(params_init) diff --git a/diffsky/cosmos_utils/tests/test_cosmos_mstar_model.py b/diffsky/cosmos_utils/tests/test_cosmos_mstar_model.py index dfa3591..eb32ac2 100644 --- a/diffsky/cosmos_utils/tests/test_cosmos_mstar_model.py +++ b/diffsky/cosmos_utils/tests/test_cosmos_mstar_model.py @@ -26,11 +26,11 @@ def test_fit_model(): photdata = get_fake_cosmos_data(ran_key) logsm = cmm.predict_logsm(cmm.DEFAULT_PARAMS, photdata) - loss_init = cmm.loss_kern(cmm.DEFAULT_PARAMS, photdata) + loss_init = cmm._loss_kern(cmm.DEFAULT_PARAMS, photdata) assert np.all(np.isfinite(logsm)) assert np.all(np.isfinite(loss_init)) p_best, loss_arr = cmm.fit_model(100, photdata, step_size=0.01) - loss_best = cmm.loss_kern(p_best, photdata) + loss_best = cmm._loss_kern(p_best, photdata) assert loss_best < loss_init