Merge pull request #389 from Aske-Rosted/working_base
Working base
Aske-Rosted authored Feb 21, 2023
2 parents 405a1cd + 695bd90 commit 983818a
Showing 7 changed files with 116 additions and 21 deletions.
21 changes: 19 additions & 2 deletions src/graphnet/data/dataconverter.py
@@ -427,11 +427,15 @@ def _extract_data(self, fileset: FileSet) -> List[OrderedDict]:
while i3_file_io.more():
try:
frame = i3_file_io.pop_physics()
except: # noqa: E722
except Exception as e:
if "I3" in str(e):
continue
if self._skip_frame(frame):
continue

# Extract data from I3Frame
# Try to extract data from I3Frame
results = self._extractors(frame)

data_dict = OrderedDict(zip(self._table_names, results))

# If an I3GenericExtractor is used, we want each automatically
@@ -546,3 +550,16 @@ def _get_output_file(self, input_file: str) -> str:
re.sub(r"\.i3\..*", "", basename) + "." + self.file_suffix,
)
return output_file

def _skip_frame(self, frame: "icetray.I3Frame") -> bool:
"""Check if frame should be skipped.
Args:
frame: I3Frame to check.
Returns:
True if frame is a null split frame, else False.
"""
if frame["I3EventHeader"].sub_event_stream == "NullSplit":
return True
return False
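
For context, the new _skip_frame hook drops frames belonging to the NullSplit sub-event stream before extraction. A minimal, self-contained sketch of that check, using a stand-in object in place of a real icetray.I3Frame (illustration only, not part of the diff):

from types import SimpleNamespace


def skip_null_split(frame) -> bool:
    """Mirror of _skip_frame: True for frames on the NullSplit stream."""
    return frame["I3EventHeader"].sub_event_stream == "NullSplit"


# Stand-in frames: dicts mapping keys to objects with a sub_event_stream attribute.
frames = [
    {"I3EventHeader": SimpleNamespace(sub_event_stream="NullSplit")},
    {"I3EventHeader": SimpleNamespace(sub_event_stream="InIceSplit")},
]
kept = [f for f in frames if not skip_null_split(f)]
assert len(kept) == 1  # only the InIceSplit frame is retained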
1 change: 1 addition & 0 deletions src/graphnet/data/extractors/__init__.py
@@ -11,6 +11,7 @@
from .i3truthextractor import I3TruthExtractor
from .i3retroextractor import I3RetroExtractor
from .i3splinempeextractor import I3SplineMPEICExtractor
from .i3particleextractor import I3ParticleExtractor
from .i3tumextractor import I3TUMExtractor
from .i3hybridrecoextractor import I3GalacticPlaneHybridRecoExtractor
from .i3genericextractor import I3GenericExtractor
43 changes: 43 additions & 0 deletions src/graphnet/data/extractors/i3particleextractor.py
@@ -0,0 +1,43 @@
"""I3Extractor class(es) for extracting I3Particle properties."""

from typing import TYPE_CHECKING, Dict

from graphnet.data.extractors.i3extractor import I3Extractor

if TYPE_CHECKING:
from icecube import icetray # pyright: reportMissingImports=false


class I3ParticleExtractor(I3Extractor):
"""Class for extracting I3Particle properties.
Can be used to extract predictions from other algorithms for comparisons
with GraphNeT.
"""

def __init__(self, name: str):
"""Construct I3ParticleExtractor."""
# Base class constructor
super().__init__(name)

def __call__(self, frame: "icetray.I3Frame") -> Dict[str, float]:
"""Extract I3Particle properties from I3Particle in frame."""
output = {}
if self._name in frame:
output.update(
{
"zenith_" + self._name: frame[self._name].dir.zenith,
"azimuth_" + self._name: frame[self._name].dir.azimuth,
"dir_x_" + self._name: frame[self._name].dir.x,
"dir_y_" + self._name: frame[self._name].dir.y,
"dir_z_" + self._name: frame[self._name].dir.z,
"pos_x_" + self._name: frame[self._name].pos.x,
"pos_y_" + self._name: frame[self._name].pos.y,
"pos_z_" + self._name: frame[self._name].pos.z,
"time_" + self._name: frame[self._name].time,
"speed_" + self._name: frame[self._name].speed,
"energy_" + self._name: frame[self._name].energy,
}
)

return output
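
For reference, a hedged usage sketch of the new extractor: it pulls direction, position, time, speed, and energy from an I3Particle stored under a given frame key, so reconstructions from other algorithms can be compared with GraphNeT predictions. The frame key "OnlineL2_SplineMPE" is a hypothetical example, and the commented call assumes the usual converter loop over physics frames:

from graphnet.data.extractors import I3ParticleExtractor

# Extract an existing reconstruction stored under a frame key of your choice.
extractor = I3ParticleExtractor("OnlineL2_SplineMPE")  # hypothetical key name

# Inside a DataConverter loop each physics frame would be passed to the
# extractor, yielding e.g. {"zenith_OnlineL2_SplineMPE": ..., "energy_OnlineL2_SplineMPE": ...}:
# values = extractor(frame)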
19 changes: 12 additions & 7 deletions src/graphnet/data/extractors/i3truthextractor.py
@@ -23,7 +23,10 @@ class I3TruthExtractor(I3Extractor):
"""Class for extracting truth-level information."""

def __init__(
self, name: str = "truth", borders: Optional[List[np.ndarray]] = None
self,
name: str = "truth",
borders: Optional[List[np.ndarray]] = None,
mctree: Optional[str] = "I3MCTree",
):
"""Construct I3TruthExtractor.
@@ -33,6 +36,7 @@ def __init__(
coordinates, for identifying, e.g., particles starting and
stopping within the detector. Defaults to hard-coded boundary
coordinates.
mctree: Str of which MCTree to use for truth values.
"""
# Base class constructor
super().__init__(name)
@@ -74,13 +78,14 @@ def __init__(
self._borders = [border_xy, border_z]
else:
self._borders = borders
self._mctree = mctree

def __call__(
self, frame: "icetray.I3Frame", padding_value: Any = -1
) -> Dict[str, Any]:
"""Extract truth-level information."""
is_mc = frame_is_montecarlo(frame)
is_noise = frame_is_noise(frame)
is_mc = frame_is_montecarlo(frame, self._mctree)
is_noise = frame_is_noise(frame, self._mctree)
sim_type = self._find_data_type(is_mc, self._i3_file)

output = {
@@ -217,7 +222,7 @@ def __call__(
def _extract_dbang_decay_length(
self, frame: "icetray.I3Frame", padding_value: float = -1
) -> float:
mctree = frame["I3MCTree"]
mctree = frame[self._mctree]
try:
p_true = mctree.primaries[0]
p_daughters = mctree.get_daughters(p_true)
@@ -346,11 +351,11 @@ def _get_primary_particle_interaction_type_and_elasticity(
try:
MCInIcePrimary = frame["MCInIcePrimary"]
except KeyError:
MCInIcePrimary = frame["I3MCTree"][0]
MCInIcePrimary = frame[self._mctree][0]
if (
MCInIcePrimary.energy != MCInIcePrimary.energy
): # This is a nan check. Only happens for some muons where second item in MCTree is primary. Weird!
MCInIcePrimary = frame["I3MCTree"][
MCInIcePrimary = frame[self._mctree][
1
] # For some strange reason the second entry is identical in all variables and has no nans (always muon)
else:
@@ -380,7 +385,7 @@ def _get_primary_track_energy_and_inelasticity(
Tuple containing the energy of tracks from primary, and the
corresponding inelasticity.
"""
mc_tree = frame["I3MCTree"]
mc_tree = frame[self._mctree]
primary = mc_tree.primaries[0]
daughters = mc_tree.get_daughters(primary)
tracks = []
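
A short sketch of the new mctree argument: simulation sets that store truth information under a different MC-tree key can now be handled without patching the extractor. The key name below is a hypothetical example; omitting the argument keeps the previous "I3MCTree" behaviour:

from graphnet.data.extractors import I3TruthExtractor

# Default behaviour, identical to before this change:
truth = I3TruthExtractor(name="truth")

# Read truth values from an alternative MC tree (hypothetical key name):
signal_truth = I3TruthExtractor(name="truth", mctree="SignalI3MCTree")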
12 changes: 8 additions & 4 deletions src/graphnet/data/extractors/utilities/frames.py
@@ -11,15 +11,19 @@
) # pyright: reportMissingImports=false


def frame_is_montecarlo(frame: "icetray.I3Frame") -> bool:
def frame_is_montecarlo(
frame: "icetray.I3Frame", mctree: Optional[str] = "I3MCTree"
) -> bool:
"""Check whether `frame` is from Monte Carlo simulation."""
return ("MCInIcePrimary" in frame) or ("I3MCTree" in frame)
return ("MCInIcePrimary" in frame) or (mctree in frame)


def frame_is_noise(frame: "icetray.I3Frame") -> bool:
def frame_is_noise(
frame: "icetray.I3Frame", mctree: Optional[str] = "I3MCTree"
) -> bool:
"""Check whether `frame` is from noise."""
try:
frame["I3MCTree"][0].energy
frame[mctree][0].energy
return False
except: # noqa: E722
try:
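
The two frame helpers now accept the MC-tree key as an optional second argument, defaulting to "I3MCTree". A hedged sketch follows; the module import requires an icetray environment, and the frame object and alternative key are assumptions, so the calls are shown commented:

from graphnet.data.extractors.utilities.frames import (
    frame_is_montecarlo,
    frame_is_noise,
)

# With a real icetray.I3Frame in hand:
# is_mc = frame_is_montecarlo(frame)                    # default "I3MCTree"
# is_mc = frame_is_montecarlo(frame, "SignalI3MCTree")  # hypothetical key
# is_noise = frame_is_noise(frame, "SignalI3MCTree")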
2 changes: 1 addition & 1 deletion src/graphnet/data/sqlite/sqlite_selection.py
@@ -5,7 +5,7 @@

import numpy as np
import pandas as pd

import random
from graphnet.utilities.logging import get_logger

logger = get_logger()
39 changes: 32 additions & 7 deletions src/graphnet/models/coarsening.py
Expand Up @@ -220,11 +220,28 @@ def _perform_clustering(self, data: Union[Data, Batch]) -> LongTensor:
class DOMCoarsening(Coarsening):
"""Coarsen pulses to DOM-level."""

def __init__(
self,
reduce: str = "avg",
transfer_attributes: bool = True,
keys: Optional[List[str]] = None,
):
"""Cluster pulses on the same DOM."""
super().__init__(reduce, transfer_attributes)
if keys is None:
self._keys = [
"dom_x",
"dom_y",
"dom_z",
"rde",
"pmt_area",
]
else:
self._keys = keys

def _perform_clustering(self, data: Union[Data, Batch]) -> LongTensor:
"""Cluster nodes in `data` by assigning a cluster index to each."""
dom_index = group_by(
data, ["dom_x", "dom_y", "dom_z", "rde", "pmt_area"]
)
dom_index = group_by(data, self._keys)
return dom_index


@@ -271,23 +288,31 @@ def __init__(
time_window: float,
reduce: str = "avg",
transfer_attributes: bool = True,
keys: List[str] = [
"dom_x",
"dom_y",
"dom_z",
"rde",
"pmt_area",
],
time_key: str = "dom_time",
):
"""Cluster pulses on the same DOM within `time_window`."""
super().__init__(reduce, transfer_attributes)
self._time_window = time_window
self._cluster_method = DBSCAN(self._time_window, min_samples=1)
self._keys = keys
self._time_key = time_key

def _perform_clustering(self, data: Union[Data, Batch]) -> LongTensor:
"""Cluster nodes in `data` by assigning a cluster index to each."""
dom_index = group_by(
data, ["dom_x", "dom_y", "dom_z", "rde", "pmt_area"]
)
dom_index = group_by(data, self._keys)
if data.batch is not None:
features = data.features[0]
else:
features = data.features

ix_time = features.index("dom_time")
ix_time = features.index(self._time_key)
hit_times = data.x[:, ix_time]

# Scale up dom_index to make sure clusters are well separated
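
Both coarsening modules are now configurable for detectors whose pulse-feature names differ from the IceCube defaults. A sketch under the assumption that the second hunk belongs to DOMAndTimeWindowCoarsening; the feature names below are hypothetical examples, not taken from the diff:

from graphnet.models.coarsening import DOMCoarsening, DOMAndTimeWindowCoarsening

# Default IceCube-style clustering, identical to the previous behaviour:
dom_coarsening = DOMCoarsening()

# Custom sensor features (hypothetical names) and time key:
custom_coarsening = DOMAndTimeWindowCoarsening(
    time_window=10.0,  # same units as the time feature, assumed nanoseconds
    keys=["sensor_x", "sensor_y", "sensor_z", "sensor_id"],
    time_key="sensor_time",
)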
