diff --git a/strax/plugins/overlap_window_plugin.py b/strax/plugins/overlap_window_plugin.py index 6d513006..4f267a93 100644 --- a/strax/plugins/overlap_window_plugin.py +++ b/strax/plugins/overlap_window_plugin.py @@ -41,12 +41,7 @@ def iter(self, iters, executor=None): yield from super().iter(iters, executor=executor) # Yield final results, kept at bay in fear of a new chunk - if self.multi_output: - if any(len(v) > 0 for v in self.cached_results.values()): - yield self.cached_results - else: - if self.cached_results is not None: - yield self.cached_results + yield self.cached_results def do_compute(self, chunk_i=None, **kwargs): if not len(kwargs): diff --git a/tests/test_overlap_plugin.py b/tests/test_overlap_plugin.py index b2d75c48..48fe69b0 100644 --- a/tests/test_overlap_plugin.py +++ b/tests/test_overlap_plugin.py @@ -78,14 +78,41 @@ def compute(self, peaks): endtime=strax.endtime(peaks)[-1:], ) + class MultipleWithinWindow(WithinWindow): + provides = ("within_window", "multiple_within_window") + data_kind = dict( + within_window="within_window", multiple_within_window="multiple_within_window" + ) + dtype = dict( + within_window=[("n_within_window", np.int16)] + strax.time_fields, + multiple_within_window=[("window_length", np.int16)] + strax.time_fields, + ) + + def compute(self, peaks): + within_window = dict( + n_within_window=count_in_window(strax.endtime(peaks)), + time=peaks["time"][:1], + endtime=strax.endtime(peaks)[-1:], + ) + multiple_within_window = dict( + window_length=peaks["length"], + time=peaks["time"], + endtime=strax.endtime(peaks), + ) + return dict( + within_window=within_window, + multiple_within_window=multiple_within_window, + ) + st = strax.Context(storage=[]) st.register(Peaks) - st.register(WithinWindow) + for plugin in (WithinWindow, MultipleWithinWindow): + st.register(plugin) - result = st.get_array(run_id="some_run", targets="within_window") - expected = count_in_window(strax.endtime(input_peaks)) + result = st.get_array(run_id="some_run", targets="within_window") + expected = count_in_window(strax.endtime(input_peaks)) - assert len(expected) == len(input_peaks), "WTF??" - assert isinstance(result, np.ndarray), "Did not get an array" - assert len(result) == len(expected), "Result has wrong length" - np.testing.assert_equal(result["n_within_window"], expected, "Counting went wrong") + assert len(expected) == len(input_peaks), "WTF??" + assert isinstance(result, np.ndarray), "Did not get an array" + assert len(result) == len(expected), "Result has wrong length" + np.testing.assert_equal(result["n_within_window"], expected, "Counting went wrong")