Skip to content

Commit

Permalink
Rewrite test to use fake filter curves if dsps data directory is not …
Browse files Browse the repository at this point in the history
…detected
  • Loading branch information
aphearin committed Feb 13, 2025
1 parent 4e798d5 commit dcca869
Showing 1 changed file with 47 additions and 13 deletions.
60 changes: 47 additions & 13 deletions diffsky/diagnostics/plot_delta_mag_burstiness.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
from dsps.cosmology import DEFAULT_COSMOLOGY
from dsps.cosmology.flat_wcdm import _age_at_z_kern, age_at_z0
from dsps.data_loaders import load_filter_data, load_ssp_templates
from dsps.data_loaders.retrieve_fake_fsps_data import load_fake_ssp_data
from dsps.data_loaders.retrieve_fake_fsps_data import (
load_fake_filter_transmission_curves,
load_fake_ssp_data,
)
from dsps.dust.utils import get_filter_effective_wavelength
from dsps.photometry import photpop
from dsps.utils import cumulative_mstar_formed
Expand Down Expand Up @@ -46,13 +49,35 @@
mpurple, mgreen, morange = ("#9467bd", "#2ca02c", "#ff7f0e")


def get_interpolated_lsst_tcurves(ssp_wave):
tcurve_u = load_filter_data.load_transmission_curve(bn_pat="lsst_u*")
tcurve_g = load_filter_data.load_transmission_curve(bn_pat="lsst_g*")
tcurve_r = load_filter_data.load_transmission_curve(bn_pat="lsst_r*")
tcurve_i = load_filter_data.load_transmission_curve(bn_pat="lsst_i*")
tcurve_z = load_filter_data.load_transmission_curve(bn_pat="lsst_z*")
tcurve_y = load_filter_data.load_transmission_curve(bn_pat="lsst_y*")
def get_interpolated_lsst_tcurves(ssp_wave, drn_ssp_data=DEFAULT_DSPS_DRN):
try:
tcurve_u = load_filter_data.load_transmission_curve(
bn_pat="lsst_u*", drn=drn_ssp_data
)
tcurve_g = load_filter_data.load_transmission_curve(
bn_pat="lsst_g*", drn=drn_ssp_data
)
tcurve_r = load_filter_data.load_transmission_curve(
bn_pat="lsst_r*", drn=drn_ssp_data
)
tcurve_i = load_filter_data.load_transmission_curve(
bn_pat="lsst_i*", drn=drn_ssp_data
)
tcurve_z = load_filter_data.load_transmission_curve(
bn_pat="lsst_z*", drn=drn_ssp_data
)
tcurve_y = load_filter_data.load_transmission_curve(
bn_pat="lsst_y*", drn=drn_ssp_data
)
except (ImportError, OSError, ValueError):
_res = load_fake_filter_transmission_curves()
wave, u, g, r, i, z, y = _res
tcurve_u = load_filter_data.TransmissionCurve((wave, u))
tcurve_g = load_filter_data.TransmissionCurve((wave, g))
tcurve_r = load_filter_data.TransmissionCurve((wave, r))
tcurve_i = load_filter_data.TransmissionCurve((wave, i))
tcurve_z = load_filter_data.TransmissionCurve((wave, z))
tcurve_y = load_filter_data.TransmissionCurve((wave, y))

tcurve_u = tcurve_u._replace(
transmission=np.interp(ssp_wave, tcurve_u.wave, tcurve_u.transmission)
Expand Down Expand Up @@ -105,7 +130,9 @@ def get_burstiness_delta_mag_quantities(
ssp_data = load_fake_ssp_data()
print(f"{drn_ssp_data} directory not found. Using fake SSP SEDs")

lsst_tcurves = get_interpolated_lsst_tcurves(ssp_data.ssp_wave)
lsst_tcurves = get_interpolated_lsst_tcurves(
ssp_data.ssp_wave, drn_ssp_data=drn_ssp_data
)
wave_eff_arr = get_wave_eff(lsst_tcurves, z_obs)

X = np.array([ssp_data.ssp_wave] * 6)
Expand Down Expand Up @@ -204,11 +231,13 @@ def get_burstiness_delta_mag_quantities(
return mags, alt_mags, halopop, sfh_galpop, smh_galpop, mc_is_q


def plot_delta_mag_lsst_vs_logsm(z_obs, n_halos=2_000):
def plot_delta_mag_lsst_vs_logsm(z_obs, n_halos=2_000, drn_ssp_data=DEFAULT_DSPS_DRN):
if not HAS_MATPLOTLIB:
raise ImportError("Must have matplotlib installed to use this function")

_res = get_burstiness_delta_mag_quantities(z_obs, n_halos=n_halos)
_res = get_burstiness_delta_mag_quantities(
z_obs, n_halos=n_halos, drn_ssp_data=drn_ssp_data
)
mags, alt_mags, halopop, sfh_galpop, smh_galpop, mc_is_q = _res

fig, ax = plt.subplots(1, 1)
Expand Down Expand Up @@ -252,12 +281,17 @@ def plot_delta_mag_lsst_vs_logsm(z_obs, n_halos=2_000):


def plot_delta_mag_lsst_vs_ssfr(
z_obs, figname="delta_mag_burstiness_ssfr.png", n_halos=2_000
z_obs,
figname="delta_mag_burstiness_ssfr.png",
n_halos=2_000,
drn_ssp_data=DEFAULT_DSPS_DRN,
):
if not HAS_MATPLOTLIB:
raise ImportError("Must have matplotlib installed to use this function")

_res = get_burstiness_delta_mag_quantities(z_obs, n_halos=n_halos)
_res = get_burstiness_delta_mag_quantities(
z_obs, n_halos=n_halos, drn_ssp_data=drn_ssp_data
)
mags, alt_mags, halopop, sfh_galpop, smh_galpop, mc_is_q = _res

ssfr_z0 = np.log10(sfh_galpop[:, -1]) - np.log10(smh_galpop[:, -1])
Expand Down

0 comments on commit dcca869

Please sign in to comment.