Skip to content

Commit

Permalink
Merge pull request gammapy#4852 from fabiopintore/tagevtid
Browse files Browse the repository at this point in the history
Add flag to switch off MC identifiers in MapDatasetEventSampler
  • Loading branch information
registerrier authored Oct 26, 2023
2 parents b7de5d9 + 9f6dab0 commit a0053ed
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 13 deletions.
43 changes: 30 additions & 13 deletions gammapy/datasets/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,22 @@ class MapDatasetEventSampler:
an energy-dependent time-varying source
t_delta : `~astropy.units.Quantity`
Time interval used to sample the time-dependent source
keep_mc_id : bool
Flag to tag sampled events from a given model with a Montecarlo identifier.
Default is True. If set to False, no identifier will be assigned
"""

def __init__(
self, random_state="random-seed", oversample_energy_factor=10, t_delta=0.5 * u.s
self,
random_state="random-seed",
oversample_energy_factor=10,
t_delta=0.5 * u.s,
keep_mc_id=True,
):
self.random_state = get_random_state(random_state)
self.oversample_energy_factor = oversample_energy_factor
self.t_delta = t_delta
self.keep_mc_id = keep_mc_id

def _repr_html_(self):
try:
Expand Down Expand Up @@ -264,13 +272,14 @@ def sample_sources(self, dataset):
npred = evaluator.apply_exposure(flux)
table = self._sample_coord_time(npred, temporal_model, dataset.gti)

if len(table) == 0:
mcid = table.Column(name="MC_ID", length=0, dtype=int)
table.add_column(mcid)
if self.keep_mc_id:
if len(table) == 0:
mcid = table.Column(name="MC_ID", length=0, dtype=int)
table.add_column(mcid)

table["MC_ID"] = idx + 1
table.meta["MID{:05d}".format(idx + 1)] = idx + 1
table.meta["MMN{:05d}".format(idx + 1)] = evaluator.model.name
table["MC_ID"] = idx + 1
table.meta["MID{:05d}".format(idx + 1)] = idx + 1
table.meta["MMN{:05d}".format(idx + 1)] = evaluator.model.name

events_all.append(EventList(table))

Expand All @@ -295,13 +304,14 @@ def sample_background(self, dataset):

table = self._sample_coord_time(background, temporal_model, dataset.gti)

table["MC_ID"] = 0
table["ENERGY"] = table["ENERGY_TRUE"]
table["RA"] = table["RA_TRUE"]
table["DEC"] = table["DEC_TRUE"]

table.meta["MID{:05d}".format(0)] = 0
table.meta["MMN{:05d}".format(0)] = dataset.background_model.name
if self.keep_mc_id:
table["MC_ID"] = 0
table.meta["MID{:05d}".format(0)] = 0
table.meta["MMN{:05d}".format(0)] = dataset.background_model.name

return EventList(table)

Expand Down Expand Up @@ -387,15 +397,19 @@ def event_det_coords(observation, events):
return events

@staticmethod
def event_list_meta(dataset, observation):
def event_list_meta(dataset, observation, keep_mc_id=True):
"""Event list meta info.
Please, note that this function will be updated in the future.
Parameters
----------
dataset : `~gammapy.datasets.MapDataset`
Map dataset
observation : `~gammapy.data.Observation`
In memory observation
keep_mc_id : bool
Flag to tag sampled events from a given model with a Montecarlo identifier.
Default is True. If set to False, no identifier will be assigned
Returns
-------
Expand Down Expand Up @@ -476,7 +490,8 @@ def event_list_meta(dataset, observation):
meta["CONV_RA"] = 0
meta["CONV_DEC"] = 0

meta["NMCIDS"] = len(dataset.models)
if keep_mc_id:
meta["NMCIDS"] = len(dataset.models)

# Necessary for DataStore, but they should be ALT and AZ instead!
telescope = observation.aeff.meta["TELESCOP"]
Expand Down Expand Up @@ -555,7 +570,9 @@ def run(self, dataset, observation=None):

events = self.event_det_coords(observation, events)
events.table["EVENT_ID"] = np.arange(len(events.table))
events.table.meta.update(self.event_list_meta(dataset, observation))
events.table.meta.update(
self.event_list_meta(dataset, observation, self.keep_mc_id)
)

geom = dataset._geom
selection = geom.contains(events.map_coord(geom))
Expand Down
57 changes: 57 additions & 0 deletions gammapy/datasets/tests/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,3 +780,60 @@ def test_MC_ID_NMCID(model_alternative):
assert meta["MID00002"] == 2
assert meta["MID00003"] == 3
assert meta["NMCIDS"] == 4


@requires_data()
def test_MC_ID_flag(model_alternative):
irfs = load_irf_dict_from_file(
"$GAMMAPY_DATA/cta-1dc/caldb/data/cta/1dc/bcf/South_z20_50h/irf_file.fits"
)
livetime = 0.1 * u.hr
skydir = SkyCoord(0, 0, unit="deg", frame="galactic")
pointing = FixedPointingInfo(fixed_icrs=skydir.icrs)
obs = Observation.create(
obs_id=1001,
pointing=pointing,
livetime=livetime,
irfs=irfs,
location=LOCATION,
)

energy_axis = MapAxis.from_energy_bounds(
"1.0 TeV", "10 TeV", nbin=10, per_decade=True
)
energy_axis_true = MapAxis.from_energy_bounds(
"0.5 TeV", "20 TeV", nbin=20, per_decade=True, name="energy_true"
)
migra_axis = MapAxis.from_bounds(0.5, 2, nbin=150, node_type="edges", name="migra")

geom = WcsGeom.create(
skydir=skydir,
width=(2, 2),
binsz=0.06,
frame="icrs",
axes=[energy_axis],
)

empty = MapDataset.create(
geom,
energy_axis_true=energy_axis_true,
migra_axis=migra_axis,
name="test",
)
maker = MapDatasetMaker(selection=["exposure", "background", "psf", "edisp"])
dataset = maker.run(empty, obs)

model_alternative[0].spectral_model.parameters["amplitude"].value = 1e-16
dataset.models = model_alternative
sampler = MapDatasetEventSampler(random_state=0, keep_mc_id=False)
events = sampler.run(dataset=dataset, observation=obs)

meta = events.table.meta
assert len(events.table) == 47
assert "MC_ID" not in events.table.colnames
assert "MID00000" not in meta.keys()
assert "MMN00000" not in meta.keys()
assert "MID00001" not in meta.keys()
assert "MID00002" not in meta.keys()
assert "MID00003" not in meta.keys()
assert "NMCIDS" not in meta.keys()

0 comments on commit a0053ed

Please sign in to comment.