OverlapWindowPlugin supports multiple outputs (#951)
* `OverlapWindowPlugin` supports multiple outputs

* Add more tests

* Add a few comments

* Extend windows
dachengx authored Jan 12, 2025
1 parent 274fd08 commit 8149cab
Showing 6 changed files with 97 additions and 42 deletions.
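
For orientation before the diffs: with this change, a multi-output `OverlapWindowPlugin` declares its outputs via a `provides` tuple plus `data_kind` and `dtype` dicts, and returns one field dict per output from `compute`, mirroring the new test further down. A minimal sketch, assuming a hypothetical plugin name, window size, and field names:

import numpy as np
import strax

class MultiOutputWindowPlugin(strax.OverlapWindowPlugin):
    """Sketch of a two-output OverlapWindowPlugin; names are illustrative."""

    depends_on = ("peaks",)
    provides = ("output_a", "output_b")
    data_kind = dict(output_a="output_a", output_b="output_b")
    dtype = dict(
        output_a=[("n_peaks", np.int16)] + strax.time_fields,
        output_b=[("peak_length", np.int16)] + strax.time_fields,
    )

    def get_window_size(self):
        # Width (in ns) of the region near chunk edges where results
        # may still change once the next chunk arrives
        return 1000

    def compute(self, peaks):
        # One field dict per output; strax turns each into its own chunk
        output_a = dict(
            n_peaks=np.full(1, len(peaks), dtype=np.int16),
            time=peaks["time"][:1],
            endtime=strax.endtime(peaks)[-1:],
        )
        output_b = dict(
            peak_length=peaks["length"].astype(np.int16),
            time=peaks["time"],
            endtime=strax.endtime(peaks),
        )
        return dict(output_a=output_a, output_b=output_b)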
2 changes: 2 additions & 0 deletions strax/chunk.py
@@ -417,8 +417,10 @@ def split_array(data, t, allow_early_split=False):
     splittable_i = 0
     i_first_beyond = -1
     for i, d in enumerate(data):
+        # only non-overlapping data can be split
         if d["time"] >= latest_end_seen:
             splittable_i = i
+        # can not split beyond t
         if d["time"] >= t:
             i_first_beyond = i
             break
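To make the two new comments concrete: an index is a candidate split point only if no earlier record is still ongoing, and the scan stops at the first record starting at or after t. A standalone sketch of that rule; the helper name is hypothetical, the real logic lives in `strax.split_array`:

import numpy as np

def find_split_points(data, t):
    """Hypothetical helper mirroring the loop in split_array."""
    latest_end_seen = -1
    splittable_i = 0
    i_first_beyond = -1
    for i, d in enumerate(data):
        # only non-overlapping data can be split: everything before i has ended
        if d["time"] >= latest_end_seen:
            splittable_i = i
        # can not split beyond t: stop at the first record at or past t
        if d["time"] >= t:
            i_first_beyond = i
            break
        latest_end_seen = max(latest_end_seen, d["endtime"])
    return splittable_i, i_first_beyond

records = np.array(
    [(0, 5), (3, 8), (10, 12)], dtype=[("time", np.int64), ("endtime", np.int64)]
)
# (3, 8) overlaps (0, 5), so the last clean split point before t=10 is index 2
print(find_split_points(records, t=10))  # -> (2, 2)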
2 changes: 1 addition & 1 deletion strax/context.py
@@ -979,7 +979,7 @@ def __add_lineage_to_plugin(
         if issubclass(plugin.__class__, not_allowed_plugins):
             raise ValueError(
                 f"Can not assign chunk_number for {plugin.__class__} "
-                f"because it is subclass of {not_allowed_plugins}!"
+                f"because it is subclass of one of {not_allowed_plugins}!"
             )
         configs.setdefault("chunk_number", {})
         if d_depends in configs["chunk_number"]:
88 changes: 58 additions & 30 deletions strax/plugins/overlap_window_plugin.py
@@ -17,11 +17,15 @@ class OverlapWindowPlugin(Plugin):
     """

     parallel = False
+    max_trials = 10

     def __init__(self):
         super().__init__()
         self.cached_input = {}
-        self.cached_results = None
+        if self.multi_output:
+            self.cached_results = {}
+        else:
+            self.cached_results = None
         self.sent_until = 0
         if self.clean_chunk_after_compute:
             raise ValueError(
@@ -37,8 +41,7 @@ def iter(self, iters, executor=None):
         yield from super().iter(iters, executor=executor)

         # Yield final results, kept at bay in fear of a new chunk
-        if self.cached_results is not None:
-            yield self.cached_results
+        yield self.cached_results

     def do_compute(self, chunk_i=None, **kwargs):
         if not len(kwargs):
@@ -51,12 +54,6 @@ def do_compute(self, chunk_i=None, **kwargs):
                 [self.cached_input[data_kind], chunk], self.allow_superrun
             )

-        # Compute new results
-        result = super().do_compute(chunk_i=chunk_i, **kwargs)
-
-        # Throw away results we already sent out
-        _, result = result.split(t=self.sent_until, allow_early_split=False)
-
         # When does this batch of inputs end?
         ends = [c.end for c in kwargs.values()]
         if not len(set(ends)) == 1:
@@ -66,43 +63,74 @@ def do_compute(self, chunk_i=None, **kwargs):
         # When can we no longer trust our results?
         # Take slightly larger windows for safety: it is very easy for me
         # (or the user) to have made an off-by-one error
-        invalid_beyond = int(end - self.get_window_size() - 1)
+        invalid_beyond = int(end - 2 * self.get_window_size() - 1)
+
+        # Compute new results
+        result = super().do_compute(chunk_i=chunk_i, **kwargs)
+
+        # Throw away results we already sent out
+        # no error here even though allow_early_split=False,
+        # because result.split(t=invalid_beyond, allow_early_split=True) tunes the
+        # sent_until to be not overlapping with result and
+        # sent_until <= invalid_beyond
+        if self.multi_output:
+            # when multi_output=True, the result is a dict
+            for data_type in result:
+                result[data_type] = result[data_type].split(
+                    t=self.sent_until, allow_early_split=False
+                )[1]
+        else:
+            result = result.split(t=self.sent_until, allow_early_split=False)[1]

         # Prepare to send out valid results, cache the rest
-        # Do not modify result anymore after this
-        # Note result.end <= invalid_beyond, with equality if there are
-        # no overlaps
-        result, self.cached_results = result.split(t=invalid_beyond, allow_early_split=True)
-        self.sent_until = result.end
+        # Do not modify result anymore after these lines
+        # Note result.end <= invalid_beyond, with equality if there are no overlaps
+        if self.multi_output:
+            prev_split = self.cache_beyond(result, invalid_beyond, self.cached_results)
+            for data_type in result:
+                result[data_type], self.cached_results[data_type] = result[data_type].split(
+                    t=prev_split, allow_early_split=True
+                )
+            if len(set([c.start for c in self.cached_results.values()])) != 1:
+                raise ValueError("Output start time inconsistency has not been resolved?")
+            self.sent_until = prev_split
+        else:
+            result, self.cached_results = result.split(t=invalid_beyond, allow_early_split=True)
+            self.sent_until = self.cached_results.start

         # Cache a necessary amount of input for next time
         # Again, take a bit of overkill for good measure
+        # cache_inputs_beyond is smaller than sent_until
         cache_inputs_beyond = int(self.sent_until - 2 * self.get_window_size() - 1)

-        # Cache inputs, make sure that the chunks start at the same time to
-        # prevent issues in input buffers later on
-        prev_split = cache_inputs_beyond
-        max_trials = 10
-        for try_counter in range(max_trials):
-            for data_kind, chunk in kwargs.items():
-                _, self.cached_input[data_kind] = chunk.split(t=prev_split, allow_early_split=True)
-                prev_split = self.cached_input[data_kind].start
-
-            unique_starts = set([c.start for c in self.cached_input.values()])
-            chunk_starts_are_equal = len(unique_starts) == 1
-            if chunk_starts_are_equal:
+        self.cache_beyond(kwargs, cache_inputs_beyond, self.cached_input)
+        return result
+
+    def cache_beyond(self, io, prev_split, cached):
+        original_prev_split = prev_split
+        for try_counter in range(self.max_trials):
+            for data, chunk in io.items():
+                # data here can be either data_kind or data_type
+                # do not temporarily modify result here because it will be used later
+                # keep its original value!
+                cached[data] = chunk.split(t=prev_split, allow_early_split=True)[1]
+                prev_split = cached[data].start
+            unique_starts = set([c.start for c in cached.values()])
+            if len(unique_starts) == 1:
                 self.log.debug(
                     f"Success after {try_counter}. "
-                    f"Extra time = {cache_inputs_beyond - prev_split} ns"
+                    f"Extra time is {original_prev_split - prev_split} ns"
                 )
                 break
             else:
                 self.log.debug(
-                    "Inconsistent start times of the cashed chunks after"
-                    f" {try_counter}/{max_trials} passes.\nChunks {self.cached_input}"
+                    f"Inconsistent start times of the cached chunks {io} after"
+                    f" {try_counter}/{self.max_trials} passes."
                )
         else:
             raise ValueError(
-                f"Buffer start time inconsistency cannot be resolved after {max_trials} tries"
+                f"Buffer start time inconsistency cannot be resolved after {self.max_trials} tries"
             )
-        return result
+        return prev_split
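
The new `cache_beyond` helper resolves a subtle issue: splitting a chunk at time t with allow_early_split=True may land the cut earlier than t, so after splitting every input the cached pieces can start at different times. The helper feeds the earliest observed start back in as the next split time and retries, up to `max_trials` passes. A toy sketch of that fixed-point loop; the `ToyChunk` class and its `valid_cuts` are invented for illustration:

from dataclasses import dataclass

@dataclass
class ToyChunk:
    start: int
    end: int
    valid_cuts: tuple  # times at which this chunk can actually be cut

    def split(self, t):
        # Move the cut earlier until it is allowed, as allow_early_split=True does
        cut = max([c for c in self.valid_cuts if c <= t] + [self.start])
        return ToyChunk(self.start, cut, self.valid_cuts), ToyChunk(cut, self.end, self.valid_cuts)

def cache_beyond(io, prev_split, max_trials=10):
    """Iterate until all cached (right-hand) pieces share one start time."""
    cached = {}
    for _ in range(max_trials):
        for name, chunk in io.items():
            cached[name] = chunk.split(prev_split)[1]
            prev_split = cached[name].start  # feed the earliest start back in
        if len({c.start for c in cached.values()}) == 1:
            return prev_split
    raise ValueError("Buffer start time inconsistency cannot be resolved")

# Example: neither input can be cut exactly at t=50, so the loop settles on 40
io = {"a": ToyChunk(0, 100, (40, 60)), "b": ToyChunk(0, 100, (40, 50))}
print(cache_beyond(io, prev_split=50))  # -> 40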
2 changes: 1 addition & 1 deletion strax/processing/peak_merging.py
@@ -139,7 +139,7 @@ def _merge_peaks(
     max_data = np.array(max_data)

     # Downsample the buffers into
-    # new_p['data'], new_p['data_top'], and new_p['data_bot']
+    # new_p['data'], new_p['data_top'], and new_p['data_start']
     strax.store_downsampled_waveform(
         new_p,
         buffer,
4 changes: 1 addition & 3 deletions strax/utils.py
@@ -192,8 +192,7 @@ def merge_arrs(arrs, dtype=None, replacing=False):
     replacing=True is usually used when you want to convert arrs into a new dtype
-    If you pass one array, it is returned without copying.
-    TODO: hmm... inconsistent
+    If you pass one array, it is returned without copying unless replacing=True.

     Much faster than the similar function in numpy.lib.recfunctions.
@@ -205,7 +204,6 @@ def merge_arrs(arrs, dtype=None, replacing=False):

     n = len(arrs[0])
     if not all([len(x) == n for x in arrs]):
-        print([(len(x), x.dtype) for x in arrs])
         raise ValueError(
             "Arrays to merge must have the same length, got lengths "
             + ", ".join([str(len(x)) for x in arrs])
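A short usage sketch consistent with the corrected docstring; the structured arrays and field names are illustrative:

import numpy as np
import strax

a = np.zeros(3, dtype=[("time", np.int64), ("length", np.int32)])
b = np.zeros(3, dtype=[("area", np.float32)])

# Fields of equal-length structured arrays are merged into one array
merged = strax.merge_arrs([a, b])
print(merged.dtype.names)  # ('time', 'length', 'area')

# Per the docstring: a single array comes back without copying...
same = strax.merge_arrs([a])
print(same is a)  # True
# ...unless replacing=True, which converts to the requested dtype
copy = strax.merge_arrs([a], dtype=a.dtype, replacing=True)
print(copy is a)  # False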
41 changes: 34 additions & 7 deletions tests/test_overlap_plugin.py
@@ -78,14 +78,41 @@ def compute(self, peaks):
                 endtime=strax.endtime(peaks)[-1:],
             )

+    class MultipleWithinWindow(WithinWindow):
+        provides = ("within_window", "multiple_within_window")
+        data_kind = dict(
+            within_window="within_window", multiple_within_window="multiple_within_window"
+        )
+        dtype = dict(
+            within_window=[("n_within_window", np.int16)] + strax.time_fields,
+            multiple_within_window=[("window_length", np.int16)] + strax.time_fields,
+        )
+
+        def compute(self, peaks):
+            within_window = dict(
+                n_within_window=count_in_window(strax.endtime(peaks)),
+                time=peaks["time"][:1],
+                endtime=strax.endtime(peaks)[-1:],
+            )
+            multiple_within_window = dict(
+                window_length=peaks["length"],
+                time=peaks["time"],
+                endtime=strax.endtime(peaks),
+            )
+            return dict(
+                within_window=within_window,
+                multiple_within_window=multiple_within_window,
+            )
+
     st = strax.Context(storage=[])
     st.register(Peaks)
-    st.register(WithinWindow)
+    for plugin in (WithinWindow, MultipleWithinWindow):
+        st.register(plugin)

-    result = st.get_array(run_id="some_run", targets="within_window")
-    expected = count_in_window(strax.endtime(input_peaks))
+        result = st.get_array(run_id="some_run", targets="within_window")
+        expected = count_in_window(strax.endtime(input_peaks))

-    assert len(expected) == len(input_peaks), "WTF??"
-    assert isinstance(result, np.ndarray), "Did not get an array"
-    assert len(result) == len(expected), "Result has wrong length"
-    np.testing.assert_equal(result["n_within_window"], expected, "Counting went wrong")
+        assert len(expected) == len(input_peaks), "WTF??"
+        assert isinstance(result, np.ndarray), "Did not get an array"
+        assert len(result) == len(expected), "Result has wrong length"
+        np.testing.assert_equal(result["n_within_window"], expected, "Counting went wrong")
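
A natural follow-up, not part of this diff, would load the second output the same way; a hedged sketch reusing the test's context and run id:

# Hypothetical extra check: fetch the plugin's second output
multiple = st.get_array(run_id="some_run", targets="multiple_within_window")
print(multiple["window_length"][:5])  # per-peak rows, per the dtype above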
