diff --git a/diffsky/diagnostics/plot_delta_mag_burstiness.py b/diffsky/diagnostics/plot_delta_mag_burstiness.py index 18492b4..c7c6b27 100644 --- a/diffsky/diagnostics/plot_delta_mag_burstiness.py +++ b/diffsky/diagnostics/plot_delta_mag_burstiness.py @@ -12,7 +12,10 @@ from dsps.cosmology import DEFAULT_COSMOLOGY from dsps.cosmology.flat_wcdm import _age_at_z_kern, age_at_z0 from dsps.data_loaders import load_filter_data, load_ssp_templates -from dsps.data_loaders.retrieve_fake_fsps_data import load_fake_ssp_data +from dsps.data_loaders.retrieve_fake_fsps_data import ( + load_fake_filter_transmission_curves, + load_fake_ssp_data, +) from dsps.dust.utils import get_filter_effective_wavelength from dsps.photometry import photpop from dsps.utils import cumulative_mstar_formed @@ -46,13 +49,35 @@ mpurple, mgreen, morange = ("#9467bd", "#2ca02c", "#ff7f0e") -def get_interpolated_lsst_tcurves(ssp_wave): - tcurve_u = load_filter_data.load_transmission_curve(bn_pat="lsst_u*") - tcurve_g = load_filter_data.load_transmission_curve(bn_pat="lsst_g*") - tcurve_r = load_filter_data.load_transmission_curve(bn_pat="lsst_r*") - tcurve_i = load_filter_data.load_transmission_curve(bn_pat="lsst_i*") - tcurve_z = load_filter_data.load_transmission_curve(bn_pat="lsst_z*") - tcurve_y = load_filter_data.load_transmission_curve(bn_pat="lsst_y*") +def get_interpolated_lsst_tcurves(ssp_wave, drn_ssp_data=DEFAULT_DSPS_DRN): + try: + tcurve_u = load_filter_data.load_transmission_curve( + bn_pat="lsst_u*", drn=drn_ssp_data + ) + tcurve_g = load_filter_data.load_transmission_curve( + bn_pat="lsst_g*", drn=drn_ssp_data + ) + tcurve_r = load_filter_data.load_transmission_curve( + bn_pat="lsst_r*", drn=drn_ssp_data + ) + tcurve_i = load_filter_data.load_transmission_curve( + bn_pat="lsst_i*", drn=drn_ssp_data + ) + tcurve_z = load_filter_data.load_transmission_curve( + bn_pat="lsst_z*", drn=drn_ssp_data + ) + tcurve_y = load_filter_data.load_transmission_curve( + bn_pat="lsst_y*", drn=drn_ssp_data + ) + except (ImportError, OSError, ValueError): + _res = load_fake_filter_transmission_curves() + wave, u, g, r, i, z, y = _res + tcurve_u = load_filter_data.TransmissionCurve((wave, u)) + tcurve_g = load_filter_data.TransmissionCurve((wave, g)) + tcurve_r = load_filter_data.TransmissionCurve((wave, r)) + tcurve_i = load_filter_data.TransmissionCurve((wave, i)) + tcurve_z = load_filter_data.TransmissionCurve((wave, z)) + tcurve_y = load_filter_data.TransmissionCurve((wave, y)) tcurve_u = tcurve_u._replace( transmission=np.interp(ssp_wave, tcurve_u.wave, tcurve_u.transmission) @@ -105,7 +130,9 @@ def get_burstiness_delta_mag_quantities( ssp_data = load_fake_ssp_data() print(f"{drn_ssp_data} directory not found. Using fake SSP SEDs") - lsst_tcurves = get_interpolated_lsst_tcurves(ssp_data.ssp_wave) + lsst_tcurves = get_interpolated_lsst_tcurves( + ssp_data.ssp_wave, drn_ssp_data=drn_ssp_data + ) wave_eff_arr = get_wave_eff(lsst_tcurves, z_obs) X = np.array([ssp_data.ssp_wave] * 6) @@ -204,11 +231,13 @@ def get_burstiness_delta_mag_quantities( return mags, alt_mags, halopop, sfh_galpop, smh_galpop, mc_is_q -def plot_delta_mag_lsst_vs_logsm(z_obs, n_halos=2_000): +def plot_delta_mag_lsst_vs_logsm(z_obs, n_halos=2_000, drn_ssp_data=DEFAULT_DSPS_DRN): if not HAS_MATPLOTLIB: raise ImportError("Must have matplotlib installed to use this function") - _res = get_burstiness_delta_mag_quantities(z_obs, n_halos=n_halos) + _res = get_burstiness_delta_mag_quantities( + z_obs, n_halos=n_halos, drn_ssp_data=drn_ssp_data + ) mags, alt_mags, halopop, sfh_galpop, smh_galpop, mc_is_q = _res fig, ax = plt.subplots(1, 1) @@ -252,12 +281,17 @@ def plot_delta_mag_lsst_vs_logsm(z_obs, n_halos=2_000): def plot_delta_mag_lsst_vs_ssfr( - z_obs, figname="delta_mag_burstiness_ssfr.png", n_halos=2_000 + z_obs, + figname="delta_mag_burstiness_ssfr.png", + n_halos=2_000, + drn_ssp_data=DEFAULT_DSPS_DRN, ): if not HAS_MATPLOTLIB: raise ImportError("Must have matplotlib installed to use this function") - _res = get_burstiness_delta_mag_quantities(z_obs, n_halos=n_halos) + _res = get_burstiness_delta_mag_quantities( + z_obs, n_halos=n_halos, drn_ssp_data=drn_ssp_data + ) mags, alt_mags, halopop, sfh_galpop, smh_galpop, mc_is_q = _res ssfr_z0 = np.log10(sfh_galpop[:, -1]) - np.log10(smh_galpop[:, -1])