Skip to content

Commit

Permalink
Add more test
Browse files Browse the repository at this point in the history
  • Loading branch information
dachengx committed Jan 12, 2025
1 parent c7b8c00 commit 464385f
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 13 deletions.
7 changes: 1 addition & 6 deletions strax/plugins/overlap_window_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,7 @@ def iter(self, iters, executor=None):
yield from super().iter(iters, executor=executor)

# Yield final results, kept at bay in fear of a new chunk
if self.multi_output:
if any(len(v) > 0 for v in self.cached_results.values()):
yield self.cached_results
else:
if self.cached_results is not None:
yield self.cached_results
yield self.cached_results

def do_compute(self, chunk_i=None, **kwargs):
if not len(kwargs):
Expand Down
41 changes: 34 additions & 7 deletions tests/test_overlap_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,41 @@ def compute(self, peaks):
endtime=strax.endtime(peaks)[-1:],
)

class MultipleWithinWindow(WithinWindow):
provides = ("within_window", "multiple_within_window")
data_kind = dict(
within_window="within_window", multiple_within_window="multiple_within_window"
)
dtype = dict(
within_window=[("n_within_window", np.int16)] + strax.time_fields,
multiple_within_window=[("window_length", np.int16)] + strax.time_fields,
)

def compute(self, peaks):
within_window = dict(
n_within_window=count_in_window(strax.endtime(peaks)),
time=peaks["time"][:1],
endtime=strax.endtime(peaks)[-1:],
)
multiple_within_window = dict(
window_length=peaks["length"],
time=peaks["time"],
endtime=strax.endtime(peaks),
)
return dict(
within_window=within_window,
multiple_within_window=multiple_within_window,
)

st = strax.Context(storage=[])
st.register(Peaks)
st.register(WithinWindow)
for plugin in (WithinWindow, MultipleWithinWindow):
st.register(plugin)

result = st.get_array(run_id="some_run", targets="within_window")
expected = count_in_window(strax.endtime(input_peaks))
result = st.get_array(run_id="some_run", targets="within_window")
expected = count_in_window(strax.endtime(input_peaks))

assert len(expected) == len(input_peaks), "WTF??"
assert isinstance(result, np.ndarray), "Did not get an array"
assert len(result) == len(expected), "Result has wrong length"
np.testing.assert_equal(result["n_within_window"], expected, "Counting went wrong")
assert len(expected) == len(input_peaks), "WTF??"
assert isinstance(result, np.ndarray), "Did not get an array"
assert len(result) == len(expected), "Result has wrong length"
np.testing.assert_equal(result["n_within_window"], expected, "Counting went wrong")

0 comments on commit 464385f

Please sign in to comment.