From d9b77d5876537cd38041649edbd3bfe8ce4a97f9 Mon Sep 17 00:00:00 2001 From: Tom Aldcroft Date: Tue, 1 Aug 2023 08:13:45 -0400 Subject: [PATCH] Use cached property where possible --- kadi/commands/validate.py | 160 +++++++++++++++++--------------------- 1 file changed, 71 insertions(+), 89 deletions(-) diff --git a/kadi/commands/validate.py b/kadi/commands/validate.py index 0656445a..f4617780 100644 --- a/kadi/commands/validate.py +++ b/kadi/commands/validate.py @@ -10,7 +10,7 @@ """ import functools import logging -from abc import ABC, abstractmethod +from abc import ABC from dataclasses import dataclass from pathlib import Path from typing import List, Optional @@ -133,31 +133,30 @@ def __init__(self, stop=None, days: float = 14, no_exclude: bool = False): self.start: CxoTime = self.stop - days * u.day self.no_exclude = no_exclude + # Get the exclude intervals from the google sheet along with any auto-generated + # ones. This creates self.exclude_intervals which is a Table. By virtue of `tlm` + # and `states` properties that get used, this also creates self.tlm and + # self.states. + self.add_exclude_intervals() + def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) if cls.state_name is not None: cls.subclasses.append(cls) - @property + @functools.cached_property def tlm(self): - if not hasattr(self, "_tlm"): - logger.info( - f"Fetching telemetry for {self.msids} between {self.start.date} and" + logger.info( + f"Fetching telemetry for {self.msids} between {self.start.date} and" + f" {self.stop.date}" + ) + _tlm = get_telem_table(self.msids, self.start, self.stop) + if len(_tlm) == 0: + raise NoTelemetryError( + f"No telemetry for {self.msids} between {self.start.date} and" f" {self.stop.date}" ) - self._tlm = get_telem_table(self.msids, self.start, self.stop) - if len(self._tlm) == 0: - raise NoTelemetryError( - f"No telemetry for {self.msids} between {self.start.date} and" - f" {self.stop.date}" - ) - self.update_tlm() - self.add_exclude_intervals() - return self._tlm - - @abstractmethod - def update_tlm(self): - """Update the telemetry values with any subclass-specific processing""" + return _tlm @property def times(self): @@ -168,18 +167,17 @@ def msid(self): """Validate classes have first MSID as primary telemetry. Override as needed.""" return self.msids[0] - @property + @functools.cached_property def states(self): - if not hasattr(self, "_states"): - state_keys = [self.state_name] + (self.state_keys_extra or []) - self._states = get_states( - start=self.tlm["time"][0], - stop=self.tlm["time"][-1], - state_keys=state_keys, - ) - return self._states + state_keys = [self.state_name] + (self.state_keys_extra or []) + _states = get_states( + start=self.tlm["time"][0], + stop=self.tlm["time"][-1], + state_keys=state_keys, + ) + return _states - @property + @functools.cached_property def exclude_intervals(self): """Intervals that are excluded from state validation. @@ -187,12 +185,11 @@ def exclude_intervals(self): (e.g. within a few minutes of an IU-reset), or auto-generated state-specific intervals like not validating pitch when in NMM. """ - if not hasattr(self, "_exclude_intervals"): - self._exclude_intervals = Table( - names=["start", "stop", "states", "comment", "source"], - dtype=[str, str, str, str, str], - ) - return self._exclude_intervals + _exclude_intervals = Table( + names=["start", "stop", "states", "comment", "source"], + dtype=[str, str, str, str, str], + ) + return _exclude_intervals def add_exclude_intervals(self): """Base method to exclude intervals, starting with intervals defined in the @@ -305,24 +302,18 @@ def tlm_vals(self): """ raise NotImplementedError - @property + @functools.cached_property def states_at_times(self): """Get the states that correspond to the telemetry times""" - if not hasattr(self, "_states_at_times"): - self._states_at_times = interpolate_states(self.states, self.times) - return self._states_at_times + return interpolate_states(self.states, self.times) - @property + @functools.cached_property def state_vals(self): - if not hasattr(self, "_state_vals"): - self._state_vals = self.states_at_times[self.state_name].copy() - return self._state_vals + return self.states_at_times[self.state_name].copy() - @property + @functools.cached_property def violations(self) -> Table: - if not hasattr(self, "_violations"): - self._violations = self.get_violations() - return self._violations + return self.get_violations() def get_violations_mask(self) -> np.ndarray: """Get the violations mask for this validation class @@ -482,42 +473,35 @@ def get_html( class ValidateSingleMsid(Validate): - @property + @functools.cached_property def tlm_vals(self): - if not hasattr(self, "_tlm_vals"): - self._tlm_vals = self.tlm[self.msid].copy() - return self._tlm_vals + return self.tlm[self.msid].copy() class ValidateStateCode(Validate): """Base class for validation of state with state codes like PCAD_MODE""" - @property + @functools.cached_property def state_codes(self) -> Table: tsc = Ska.tdb.msids.find(self.msid)[0].Tsc - state_codes = Table( + _state_codes = Table( [tsc.data["LOW_RAW_COUNT"], tsc.data["STATE_CODE"]], names=["raw_val", "state_code"], ) - state_codes.sort("raw_val") - return state_codes + _state_codes.sort("raw_val") + return _state_codes - @property + @functools.cached_property def tlm_vals(self): - if not hasattr(self, "_tlm_vals"): - vals = convert_state_code_to_raw_val(self.tlm[self.msid], self.state_codes) - self._tlm_vals = vals - return self._tlm_vals + return convert_state_code_to_raw_val(self.tlm[self.msid], self.state_codes) - @property + @functools.cached_property def state_vals(self): - if not hasattr(self, "_state_vals"): - states_interp = interpolate_states(self.states, self.tlm["time"]) - state_vals = convert_state_code_to_raw_val( - states_interp[self.state_name], self.state_codes - ) - self._state_vals = state_vals - return self._state_vals + states_interp = interpolate_states(self.states, self.tlm["time"]) + _state_vals = convert_state_code_to_raw_val( + states_interp[self.state_name], self.state_codes + ) + return _state_vals def get_plot_figure(self) -> pgo.Figure: fig = super().get_plot_figure() @@ -657,31 +641,29 @@ def add_exclude_intervals(self): super().add_exclude_intervals() self.exclude_ofp_intervals_except(["NRML"]) - @property + @functools.cached_property def state_codes(self) -> Table: - if not hasattr(self, "_state_codes"): - rows = [ - [0, "INSE"], - [1, "INSE_MOVE"], - [2, "RETR_MOVE"], - [3, "RETR"], - ] - self._state_codes = Table(rows=rows, names=["raw_val", "state_code"]) - return self._state_codes + rows = [ + [0, "INSE"], + [1, "INSE_MOVE"], + [2, "RETR_MOVE"], + [3, "RETR"], + ] + _state_codes = Table(rows=rows, names=["raw_val", "state_code"]) + return _state_codes - @property + @functools.cached_property def tlm_vals(self): - if not hasattr(self, "_tlm_vals"): - vals = np.repeat("RETR", len(self.tlm)) - # use a combination of the select telemetry and the insertion telem to - # approximate the appropriate telemetry values - # fmt: off - ok = ((self.tlm["4ootgsel"] == self.state_name.upper()) - & (self.tlm["4ootgmtn"] == "INSE")) - # fmt: on - vals[ok] = "INSE" - self._tlm_vals = convert_state_code_to_raw_val(vals, self.state_codes) - return self._tlm_vals + vals = np.repeat("RETR", len(self.tlm)) + # use a combination of the select telemetry and the insertion telem to + # approximate the appropriate telemetry values + # fmt: off + ok = ((self.tlm["4ootgsel"] == self.state_name.upper()) + & (self.tlm["4ootgmtn"] == "INSE")) + # fmt: on + vals[ok] = "INSE" + _tlm_vals = convert_state_code_to_raw_val(vals, self.state_codes) + return _tlm_vals class ValidateLETG(ValidateGrating):