Skip to content

Commit

Permalink
Use cached property where possible
Browse files Browse the repository at this point in the history
  • Loading branch information
taldcroft committed Nov 21, 2023
1 parent ff07cad commit d9b77d5
Showing 1 changed file with 71 additions and 89 deletions.
160 changes: 71 additions & 89 deletions kadi/commands/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"""
import functools
import logging
from abc import ABC, abstractmethod
from abc import ABC
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional
Expand Down Expand Up @@ -133,31 +133,30 @@ def __init__(self, stop=None, days: float = 14, no_exclude: bool = False):
self.start: CxoTime = self.stop - days * u.day
self.no_exclude = no_exclude

# Get the exclude intervals from the google sheet along with any auto-generated
# ones. This creates self.exclude_intervals which is a Table. By virtue of `tlm`
# and `states` properties that get used, this also creates self.tlm and
# self.states.
self.add_exclude_intervals()

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if cls.state_name is not None:
cls.subclasses.append(cls)

@property
@functools.cached_property
def tlm(self):
if not hasattr(self, "_tlm"):
logger.info(
f"Fetching telemetry for {self.msids} between {self.start.date} and"
logger.info(
f"Fetching telemetry for {self.msids} between {self.start.date} and"
f" {self.stop.date}"
)
_tlm = get_telem_table(self.msids, self.start, self.stop)
if len(_tlm) == 0:
raise NoTelemetryError(
f"No telemetry for {self.msids} between {self.start.date} and"
f" {self.stop.date}"
)
self._tlm = get_telem_table(self.msids, self.start, self.stop)
if len(self._tlm) == 0:
raise NoTelemetryError(
f"No telemetry for {self.msids} between {self.start.date} and"
f" {self.stop.date}"
)
self.update_tlm()
self.add_exclude_intervals()
return self._tlm

@abstractmethod
def update_tlm(self):
"""Update the telemetry values with any subclass-specific processing"""
return _tlm

@property
def times(self):
Expand All @@ -168,31 +167,29 @@ def msid(self):
"""Validate classes have first MSID as primary telemetry. Override as needed."""
return self.msids[0]

@property
@functools.cached_property
def states(self):
if not hasattr(self, "_states"):
state_keys = [self.state_name] + (self.state_keys_extra or [])
self._states = get_states(
start=self.tlm["time"][0],
stop=self.tlm["time"][-1],
state_keys=state_keys,
)
return self._states
state_keys = [self.state_name] + (self.state_keys_extra or [])
_states = get_states(
start=self.tlm["time"][0],
stop=self.tlm["time"][-1],
state_keys=state_keys,
)
return _states

@property
@functools.cached_property
def exclude_intervals(self):
"""Intervals that are excluded from state validation.
This includes manually excluded times from the Command Events sheet
(e.g. within a few minutes of an IU-reset), or auto-generated
state-specific intervals like not validating pitch when in NMM.
"""
if not hasattr(self, "_exclude_intervals"):
self._exclude_intervals = Table(
names=["start", "stop", "states", "comment", "source"],
dtype=[str, str, str, str, str],
)
return self._exclude_intervals
_exclude_intervals = Table(
names=["start", "stop", "states", "comment", "source"],
dtype=[str, str, str, str, str],
)
return _exclude_intervals

def add_exclude_intervals(self):
"""Base method to exclude intervals, starting with intervals defined in the
Expand Down Expand Up @@ -305,24 +302,18 @@ def tlm_vals(self):
"""
raise NotImplementedError

@property
@functools.cached_property
def states_at_times(self):
"""Get the states that correspond to the telemetry times"""
if not hasattr(self, "_states_at_times"):
self._states_at_times = interpolate_states(self.states, self.times)
return self._states_at_times
return interpolate_states(self.states, self.times)

@property
@functools.cached_property
def state_vals(self):
if not hasattr(self, "_state_vals"):
self._state_vals = self.states_at_times[self.state_name].copy()
return self._state_vals
return self.states_at_times[self.state_name].copy()

@property
@functools.cached_property
def violations(self) -> Table:
if not hasattr(self, "_violations"):
self._violations = self.get_violations()
return self._violations
return self.get_violations()

def get_violations_mask(self) -> np.ndarray:
"""Get the violations mask for this validation class
Expand Down Expand Up @@ -482,42 +473,35 @@ def get_html(


class ValidateSingleMsid(Validate):
@property
@functools.cached_property
def tlm_vals(self):
if not hasattr(self, "_tlm_vals"):
self._tlm_vals = self.tlm[self.msid].copy()
return self._tlm_vals
return self.tlm[self.msid].copy()


class ValidateStateCode(Validate):
"""Base class for validation of state with state codes like PCAD_MODE"""

@property
@functools.cached_property
def state_codes(self) -> Table:
tsc = Ska.tdb.msids.find(self.msid)[0].Tsc
state_codes = Table(
_state_codes = Table(
[tsc.data["LOW_RAW_COUNT"], tsc.data["STATE_CODE"]],
names=["raw_val", "state_code"],
)
state_codes.sort("raw_val")
return state_codes
_state_codes.sort("raw_val")
return _state_codes

@property
@functools.cached_property
def tlm_vals(self):
if not hasattr(self, "_tlm_vals"):
vals = convert_state_code_to_raw_val(self.tlm[self.msid], self.state_codes)
self._tlm_vals = vals
return self._tlm_vals
return convert_state_code_to_raw_val(self.tlm[self.msid], self.state_codes)

@property
@functools.cached_property
def state_vals(self):
if not hasattr(self, "_state_vals"):
states_interp = interpolate_states(self.states, self.tlm["time"])
state_vals = convert_state_code_to_raw_val(
states_interp[self.state_name], self.state_codes
)
self._state_vals = state_vals
return self._state_vals
states_interp = interpolate_states(self.states, self.tlm["time"])
_state_vals = convert_state_code_to_raw_val(
states_interp[self.state_name], self.state_codes
)
return _state_vals

def get_plot_figure(self) -> pgo.Figure:
fig = super().get_plot_figure()
Expand Down Expand Up @@ -657,31 +641,29 @@ def add_exclude_intervals(self):
super().add_exclude_intervals()
self.exclude_ofp_intervals_except(["NRML"])

@property
@functools.cached_property
def state_codes(self) -> Table:
if not hasattr(self, "_state_codes"):
rows = [
[0, "INSE"],
[1, "INSE_MOVE"],
[2, "RETR_MOVE"],
[3, "RETR"],
]
self._state_codes = Table(rows=rows, names=["raw_val", "state_code"])
return self._state_codes
rows = [
[0, "INSE"],
[1, "INSE_MOVE"],
[2, "RETR_MOVE"],
[3, "RETR"],
]
_state_codes = Table(rows=rows, names=["raw_val", "state_code"])
return _state_codes

@property
@functools.cached_property
def tlm_vals(self):
if not hasattr(self, "_tlm_vals"):
vals = np.repeat("RETR", len(self.tlm))
# use a combination of the select telemetry and the insertion telem to
# approximate the appropriate telemetry values
# fmt: off
ok = ((self.tlm["4ootgsel"] == self.state_name.upper())
& (self.tlm["4ootgmtn"] == "INSE"))
# fmt: on
vals[ok] = "INSE"
self._tlm_vals = convert_state_code_to_raw_val(vals, self.state_codes)
return self._tlm_vals
vals = np.repeat("RETR", len(self.tlm))
# use a combination of the select telemetry and the insertion telem to
# approximate the appropriate telemetry values
# fmt: off
ok = ((self.tlm["4ootgsel"] == self.state_name.upper())
& (self.tlm["4ootgmtn"] == "INSE"))
# fmt: on
vals[ok] = "INSE"
_tlm_vals = convert_state_code_to_raw_val(vals, self.state_codes)
return _tlm_vals


class ValidateLETG(ValidateGrating):
Expand Down

0 comments on commit d9b77d5

Please sign in to comment.