Skip to content

Commit

Permalink
Merge pull request #123 from sot/faster-bad-times
Browse files Browse the repository at this point in the history
Improve performance of bad_times and mask_times
  • Loading branch information
taldcroft authored Sep 8, 2022
2 parents 063a764 + 19aed41 commit cc1883f
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 10 deletions.
52 changes: 42 additions & 10 deletions xija/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@
from collections import OrderedDict
from io import StringIO
from pathlib import Path
from typing import List, Tuple, Optional

import numpy as np

from . import component
from . import tmal

# Optional packages for model fitting or use on HEAD LAN
from Chandra.Time import DateTime, date2secs
from Chandra.Time import DateTime
from cxotime import date2secs
from astropy.io import ascii
import Ska.Numpy
import Ska.DBI
Expand Down Expand Up @@ -51,6 +53,41 @@ def convert_type_star_star(array, ctype_type):
return (f4ptr * len(array))(*[row.ctypes.data_as(f4ptr) for row in array])


def _get_bad_times(
times: np.ndarray,
datestart: str,
datestop: str,
bad_times_in: Optional[List[Tuple[str, str]]] = None,
) -> Tuple[List[Tuple[str, str]], List[Tuple[int, int]]]:
"""Return bad_times, bad_times_indices into ``times`` for elements in the
``bad_times_in`` list that overlap with the ``datestart`` to ``datestop``.
NOTE: bad_times_in is a list of [datestart, datestop] lists. The "time"
name is unfortunate since it has string dates, not float CXCsec times.
:returns: bad_times: List[List[str, str]]], bad_times_indices: List[List[int, int]]
"""
if bad_times_in is None or len(bad_times_in) == 0:
return [], []

bad_times: np.ndarray = np.array(bad_times_in)

# Get inclusive overlap of bad_times with datestart to datestop
ok = (bad_times[:, 1] > datestart) & (bad_times[:, 0] < datestop)
if np.any(ok):
bad_times = bad_times[ok]
bad_times_secs = date2secs(bad_times)
idxs = np.searchsorted(times, bad_times_secs)
ok = idxs[:, 0] < idxs[:, 1]
bad_times_out = bad_times[ok].tolist()
bad_times_indices = idxs[ok].tolist()
else:
bad_times_out = []
bad_times_indices = []

return bad_times_out, bad_times_indices


class FetchError(Exception):
pass

Expand Down Expand Up @@ -144,15 +181,10 @@ def __init__(self, name=None, start=None, stop=None, dt=None,
self.rk4 = rk4
self.limits = limits

self.bad_times_indices = []
self.bad_times = []
if model_spec is not None and 'bad_times' in model_spec:
self.bad_times = model_spec['bad_times']
for d0, d1 in self.bad_times:
t0, t1 = DateTime([d0, d1]).secs
i0, i1 = np.searchsorted(self.times, [t0, t1])
if i1 > i0:
self.bad_times_indices.append((i0, i1))
bad_times = None if (model_spec is None) else model_spec.get('bad_times')
self.bad_times, self.bad_times_indices = _get_bad_times(
self.times, self.datestart, self.datestop, bad_times
)
# This is really setting the mask times for the first
# time in this case
self.reset_mask_times()
Expand Down
50 changes: 50 additions & 0 deletions xija/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path

from xija import ThermalModel, Node, HeatSink, SolarHeat, Pitch, Eclipse, __version__
from xija.get_model_spec import get_xija_model_spec
import xija
from numpy import sin, cos, abs

Expand Down Expand Up @@ -350,3 +351,52 @@ def test_fewer_dP_pitches(solar_class):
model2.calc()

assert np.allclose(model1.comp[msid].mvals, model2.comp[msid].mvals)


def test_bad_times():
"""
Test bad times handling for PM2THV1T model which has a large number of bad
times defined, including in particular:
['2022:083:21:43:01.949', '2022:083:21:52:52.349'],
['2022:083:22:02:09.949', '2022:083:22:40:58.749'],
['2022:083:23:03:56.349', '2022:084:03:45:28.350'],
['2022:084:07:13:45.151', '2022:084:14:54:02.753'],
"""
spec1 = get_xija_model_spec('pm2thv1t', version='3.40.2')[0]
assert len(spec1['bad_times']) == 5876

# Straddling two bad time intervals (model starts within first and ends
# after second).
mdl = xija.XijaModel(
'test', '2022:083:22:30:00', '2022:084:04:00:00', model_spec=spec1,
)
exp = [['2022:083:22:02:09.949', '2022:083:22:40:58.749'],
['2022:083:23:03:56.349', '2022:084:03:45:28.350']]
assert mdl.bad_times == exp
assert mdl.bad_times_indices == [[0, 2], [6, 58]]
assert 58 < mdl.n_times

# Within one bad time interval
mdl = xija.XijaModel(
'test', '2022:084:00:00:00', '2022:084:01:00:00', model_spec=spec1,
)
exp = [['2022:083:23:03:56.349', '2022:084:03:45:28.350']]
assert mdl.bad_times == exp
assert mdl.bad_times_indices == [[0, len(mdl.times)]]

# Within one bad time interval
mdl = xija.XijaModel(
'test', '2022:084:00:00:00', '2022:084:01:00:00', model_spec=spec1,
)
exp = [['2022:083:23:03:56.349', '2022:084:03:45:28.350']]
assert mdl.bad_times == exp
assert mdl.bad_times_indices == [[0, len(mdl.times)]]

# Within no bad time interval
mdl = xija.XijaModel(
'test', '2022:084:04:00:00', '2022:084:07:00:00', model_spec=spec1,
)
exp = []
assert mdl.bad_times == exp
assert mdl.bad_times_indices == []

0 comments on commit cc1883f

Please sign in to comment.