Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass filter to GRB model #2

Merged
merged 2 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions nmma/em/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,17 +281,16 @@ def __init__(
self.svd_lbol_model = None
elif self.interpolation_type == "tensorflow":
import tensorflow as tf
tf.get_logger().setLevel("ERROR")
from tensorflow.keras.models import load_model

# TODO: remove below 3 lines once <model>_tf.pkl files on Zenodo are updated to <model>.pkl
if not os.path.exists(modelfile):
warnings.warn(
f"Attempting to load {core_model_name}_tf.pkl. In the future, all model files will have the format <model>.pkl, regardless of --interpolation-type."
)
modelfile = os.path.join(self.svd_path, f"{core_model_name}_tf.pkl")

tf.get_logger().setLevel("ERROR")
from tensorflow.keras.models import load_model


if not local_only:
_, model_filters = get_model(
self.svd_path, f"{self.model}_tf", filters=filters
Expand Down Expand Up @@ -369,7 +368,7 @@ def observation_angle_conversion(self, parameters):
def generate_lightcurve(self, sample_times, parameters):
if self.parameter_conversion:
new_parameters = parameters.copy()
new_parameters, _ = self.parameter_conversion(new_parameters, [])
new_parameters, _ = self.parameter_conversion(new_parameters)
else:
new_parameters = parameters.copy()

Expand Down Expand Up @@ -481,10 +480,9 @@ def __repr__(self):
return self.__class__.__name__ + "(model={0})".format(self.model)

def generate_lightcurve(self, sample_times, parameters):

if self.parameter_conversion:
new_parameters = parameters.copy()
new_parameters, _ = self.parameter_conversion(new_parameters, [])
new_parameters, _ = self.parameter_conversion(new_parameters)
else:
new_parameters = parameters.copy()

Expand Down Expand Up @@ -571,7 +569,6 @@ def observation_angle_conversion(self, parameters):
return parameters

def generate_lightcurve(self, sample_times, parameters):

total_lbol = np.zeros(len(sample_times))
total_mag = {}

Expand Down
5 changes: 3 additions & 2 deletions nmma/em/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,10 +548,10 @@ def fluxDensity(t, nu, **params):
def grb_lc(t_day, Ebv, param_dict, filters=None):
day = 86400.0 # in seconds
tStart = (np.amin(t_day)) * day
tStart = max(10**(-5), tStart)
tStart = max(10**(-5)*day, tStart)
tEnd = (np.amax(t_day) + 1) * day
tnode = min(len(t_day), 201)
default_time = np.logspace(np.log10(tStart), np.log10(tEnd), base=10.0, num=tnode-1)
default_time = np.logspace(np.log10(tStart), np.log10(tEnd), base=10.0, num=tnode)
filts, lambdas = get_default_filts_lambdas(filters=filters)

nu_0s = scipy.constants.c / lambdas
Expand All @@ -571,6 +571,7 @@ def grb_lc(t_day, Ebv, param_dict, filters=None):
# output flux density is in milliJansky
try:
mJys = fluxDensity(times, nus, **param_dict)

except TimeoutError:
return t_day, np.zeros(t_day.shape), {}

Expand Down
18 changes: 10 additions & 8 deletions nmma/joint/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from ..em.model import SVDLightCurveModel, KilonovaGRBLightCurveModel
from ..em.model import SVDLightCurveModel, KilonovaGRBLightCurveModel, GRBLightCurveModel, GenericCombineLightCurveModel
from ..em.likelihood import OpticalLightCurve
from .conversion import MultimessengerConversion, MultimessengerConversionWithLambdas

Expand Down Expand Up @@ -124,7 +124,6 @@ def __init__(self, interferometers, waveform_generator,
time_marginalization=False, distance_marginalization=False,
phase_marginalization=False, distance_marginalization_lookup_table=None,
jitter_time=True, reference_frame="sky", time_reference="geocenter"):

# construct the eos prior
if with_eos:
xx = np.arange(0, Neos + 1)
Expand Down Expand Up @@ -166,20 +165,23 @@ def __init__(self, interferometers, waveform_generator,
GWLikelihood = ROQGravitationalWaveTransient(**gw_likelihood_kwargs)

# initialize the EM likelihood
if not filters:
filters = list(light_curve_data.keys())
sample_times = np.arange(tmin, tmax, 0.1)
light_curve_model_kwargs = dict(model=light_curve_model_name, sample_times=sample_times,
svd_path=light_curve_SVD_path,
parameter_conversion=parameter_conversion,
mag_ncoeff=mag_ncoeff, lbol_ncoeff=lbol_ncoeff,
interpolation_type=light_curve_interpolation_type)
interpolation_type=light_curve_interpolation_type, filters=filters)

if with_grb:
light_curve_model = KilonovaGRBLightCurveModel(sample_times=sample_times,
kilonova_kwargs=light_curve_model_kwargs,
GRB_resolution=grb_resolution)
models = []
models.append(SVDLightCurveModel(**light_curve_model_kwargs))
models.append(GRBLightCurveModel(sample_times = sample_times, resolution = grb_resolution, filters = filters, parameter_conversion = parameter_conversion))
light_curve_model = GenericCombineLightCurveModel(models = models, sample_times=sample_times)
else:
light_curve_model = SVDLightCurveModel(**light_curve_model_kwargs)
if not filters:
filters = list(light_curve_data.keys())

em_likelihood_kwargs = dict(light_curve_model=light_curve_model, filters=filters,
light_curve_data=light_curve_data, trigger_time=em_trigger_time,
error_budget=error_budget, tmin=tmin, tmax=tmax)
Expand Down