diff --git a/src/eventio/base.py b/src/eventio/base.py index 7cb36f8..c5a5045 100644 --- a/src/eventio/base.py +++ b/src/eventio/base.py @@ -62,6 +62,7 @@ def __init__(self, path, zcat=True): self.read_process = None self.zstd = False self.next = None + self.peek_error = None if not is_eventio(path): raise ValueError('File {} is not an eventio file'.format(path)) @@ -125,7 +126,12 @@ def __next__(self): def peek(self): if self.next is None: - self.next = next(self) + try: + self.next = next(self) + except (StopIteration, EOFError, IOError) as e: + self.peek_error = e + self.next = None + return self.next def seek(self, position, whence=0): @@ -157,7 +163,7 @@ def check_size_or_raise(data, expected_length, zero_ok=True): else: raise EOFError('File seems to be truncated') - if length < expected_length: + elif length < expected_length: raise EOFError('File seems to be truncated') @@ -182,12 +188,6 @@ def read_header(byte_stream, offset, toplevel=False): ''' header_bytes = byte_stream.read(constants.OBJECT_HEADER_SIZE) - check_size_or_raise( - header_bytes, - constants.OBJECT_HEADER_SIZE, - zero_ok=False, - ) - header = parse_header_bytes(header_bytes, toplevel=toplevel) if header.extended: diff --git a/src/eventio/header.pyx b/src/eventio/header.pyx index 4132159..bcd309f 100644 --- a/src/eventio/header.pyx +++ b/src/eventio/header.pyx @@ -122,6 +122,11 @@ cpdef ObjectHeader parse_header_bytes(const uint8_t[:] header_bytes, bint toplev cdef bint only_subobjects cdef uint64_t length + if len(header_bytes) != OBJECT_HEADER_SIZE: + # for backwards compatibility, we raise the same error as before. + # more appropriate for this free function would be something like + # raise ValueError("header_bytes size must be 12") + raise EOFError('File seems to be truncated') type_int = unpack_uint32(header_bytes[0:4]) type_, version, user, extended = parse_type_field(type_int) diff --git a/src/eventio/simtel/objects.py b/src/eventio/simtel/objects.py index 851269a..d26ee62 100644 --- a/src/eventio/simtel/objects.py +++ b/src/eventio/simtel/objects.py @@ -1461,10 +1461,8 @@ def __str__(self): ) def parse(self): - ''' ''' - assert_version_in(self, (1, 2)) + '''''' self.seek(0) - d = MCEvent.parse_mc_event(self.read(), self.header.version) d['event_id'] = self.header.id return d diff --git a/src/eventio/simtel/parsing.pyx b/src/eventio/simtel/parsing.pyx index 459f9c7..3b7723a 100644 --- a/src/eventio/simtel/parsing.pyx +++ b/src/eventio/simtel/parsing.pyx @@ -54,6 +54,12 @@ cpdef dict parse_mc_event( const uint8_t[:] data, uint32_t version ): + if version > 2: + raise NotImplementedError( + 'Unsupported version of MCEvent:' + ' only versions up to 2 supported,' + f' got: {version} ' + ) cdef uint64_t pos = 0 cdef float xcore, ycore, aweight diff --git a/src/eventio/simtel/simtelfile.py b/src/eventio/simtel/simtelfile.py index c67a234..8f01cc7 100644 --- a/src/eventio/simtel/simtelfile.py +++ b/src/eventio/simtel/simtelfile.py @@ -2,11 +2,13 @@ Implementation of an EventIOFile that loops through SimTel Array events. ''' +from functools import lru_cache import re from copy import copy from collections import defaultdict import warnings import logging +from typing import Dict, Any from ..base import EventIOFile from ..exceptions import check_type @@ -68,6 +70,13 @@ class UnknownObjectWarning(UserWarning): camel_re2 = re.compile('([a-z0-9])([A-Z])') +# these objects mark the end of the current event +NEXT_EVENT_MARKERS = ( + MCEvent, MCShower, CalibrationEvent, CalibrationPhotoelectrons, type(None) +) + + +@lru_cache() def camel_to_snake(name): s1 = camel_re1.sub(r'\1_\2', name) return camel_re2.sub(r'\1_\2', s1).lower() @@ -77,17 +86,48 @@ class NoTrackingPositions(Exception): pass -class SimTelFile(EventIOFile): - def __init__(self, path, allowed_telescopes=None, skip_calibration=False, zcat=True): - super().__init__(path, zcat=zcat) +class SimTelFile: + ''' + This assumes the following top-level structure once events are seen: + + Either: + MCShower[2020] + MCEvent[2021] + # stuff belonging to this MCEvent + optional TelescopeData[1204] + optional PixelMonitoring[2033] for each telescope + optional (CameraMonitoring[2022], LaserCalibration[2023]) for each telescope + optional MCPhotoelectronSum[2026] + optional ArrayEvent[2010] + + optional MCEvent for same shower (reuse) + + Or: + CalibrationEvent[2028] + with possibly more CameraMonitoring / LaserCalibration in between + calibration events + ''' + def __init__( + self, + path, + skip_non_triggered=True, + skip_calibration=False, + allowed_telescopes=None, + zcat=True, + ): + self._file = EventIOFile(path, zcat=zcat) self.path = path + + self.skip_calibration = skip_calibration + self.skip_non_triggered = skip_non_triggered + self.allowed_telescopes = None if allowed_telescopes: self.allowed_telescopes = set(allowed_telescopes) + # object storage self.histograms = None - self.history = [] self.mc_run_headers = [] self.corsika_inputcards = [] @@ -100,21 +140,13 @@ def __init__(self, path, allowed_telescopes=None, skip_calibration=False, zcat=T self.pixel_monitorings = defaultdict(dict) self.camera_monitorings = defaultdict(dict) self.laser_calibrations = defaultdict(dict) + + # wee need to keep the mc_shower separate from the event, + # as it is valid for more than one (CORSIKA re-use) self.current_mc_shower = None self.current_mc_shower_id = None - self.current_mc_event = None - self.current_mc_event_id = None - self.current_telescope_data_event_id = None - self.current_photoelectron_sum = None - self.current_photoelectron_sum_event_id = None - self.current_photoelectrons = {} - self.current_photons = {} - self.current_emitter = {} - self.current_array_event = None - self.current_calibration_event = None - self.current_calibration_event_id = None - self.current_calibration_pe = {} - self.skip_calibration = skip_calibration + self.current_event_id = None + self.current_event = {"type": "data"} # read the header: # assumption: the header is done when @@ -122,13 +154,12 @@ def __init__(self, path, allowed_telescopes=None, skip_calibration=False, zcat=T # and we found the telescope_descriptions of all telescopes check = [] found_all_telescopes = False - while not (any(o for o in check) and found_all_telescopes): - self.next_low_level() + while not (any(o is not None for o in check) and found_all_telescopes): + self._parse_next_object() check = [ self.current_mc_shower, - self.current_array_event, - self.current_calibration_event, + self.current_event_id, self.laser_calibrations, self.camera_monitorings, ] @@ -142,39 +173,85 @@ def __init__(self, path, allowed_telescopes=None, skip_calibration=False, zcat=T found_all_telescopes = found == self.n_telescopes def __iter__(self): - return self.iter_array_events() + ''' + Iterate over all events in the file. + ''' + return self + + def __next__(self): + event = self._read_next_event() + + while self._check_skip(event): + event = self._read_next_event() + + return event + + def _read_next_event(self): + if self._file.peek() is None: + raise StopIteration() + + while isinstance(self._file.peek(), (PixelMonitoring, CameraMonitoring, LaserCalibration)): + self._parse_next_object() + + if isinstance(self._file.peek(), CalibrationPhotoelectrons): + self._parse_next_object() + + if isinstance(self._file.peek(), MCShower): + self._parse_next_object() + + if isinstance(self._file.peek(), (MCEvent, CalibrationEvent)): + self._parse_next_object() + self._read_until_next_event() + return self._build_event() + + # extracted calibration events have "naked" ArrayEvents without + # a preceding MCEvent or CalibrationEvent wrapper + if isinstance(self._file.peek(), ArrayEvent): + self._parse_next_object() + return self._build_event() + + raise ValueError(f"Unexpected obj type: {self._file.peek()}") - def next_low_level(self): - o = next(self) + def _check_skip(self, event): + if event['type'] == 'data': + return self.skip_non_triggered and not event.get('telescope_events') + + if event['type'] == 'calibration': + return self.skip_calibration + + raise ValueError(f'Unexpected event type {event["type"]}') + + def _read_until_next_event(self): + while not isinstance(self._file.peek(), NEXT_EVENT_MARKERS): + self._parse_next_object() + + def _parse_next_object(self): + o = next(self._file) # order of if statements is roughly sorted # by the number of occurences in a simtel file # this should minimize the number of if statements evaluated if isinstance(o, MCEvent): - self.current_mc_event = o.parse() - self.current_mc_event_id = o.header.id + self.current_event["event_id"] = o.header.id + self.current_event["mc_event"] = o.parse() elif isinstance(o, MCShower): self.current_mc_shower = o.parse() self.current_mc_shower_id = o.header.id elif isinstance(o, ArrayEvent): - self.current_array_event = parse_array_event( - o, - self.allowed_telescopes + self.current_event_id = o.header.id + self.current_event["event_id"] = o.header.id + self.current_event.update( + parse_array_event(o, self.allowed_telescopes) ) elif isinstance(o, iact.TelescopeData): - event_id, photons, emitter, photoelectrons = parse_telescope_data(o) - self.current_telescope_data_event_id = event_id - self.current_photons = photons - self.current_emitter = emitter - self.current_photoelectrons = photoelectrons + self.current_event.update(parse_telescope_data(o)) elif isinstance(o, MCPhotoelectronSum): - self.current_photoelectron_sum_event_id = o.header.id - self.current_photoelectron_sum = o.parse() + self.current_event["photoelectron_sums"] = o.parse() elif isinstance(o, CameraMonitoring): self.camera_monitorings[o.telescope_id].update(o.parse()) @@ -200,16 +277,17 @@ def next_low_level(self): self.corsika_inputcards.append(o.parse()) elif isinstance(o, CalibrationEvent): - if not self.skip_calibration: - array_event = next(o) - self.current_calibration_event = parse_array_event( - array_event, - self.allowed_telescopes, - ) - # assign negative event_ids to calibration events to avoid - # duplicated event_ids - self.current_calibration_event_id = -array_event.header.id - self.current_calibration_event['calibration_type'] = o.type + array_event = next(o) + # make event_id negative for calibration events to not overlap with + # later air shower events + self.current_event["event_id"] = -array_event.header.id + self.current_event_id = self.current_event["event_id"] + self.current_event.update( + parse_array_event(array_event, self.allowed_telescopes) + ) + self.current_event['type'] = 'calibration' + self.current_event['calibration_type'] = o.type + elif isinstance(o, CalibrationPhotoelectrons): telescope_data = next(o) if not isinstance(telescope_data, iact.TelescopeData): @@ -219,7 +297,7 @@ def next_low_level(self): ) return - self.current_calibration_pe = {} + self.current_event["photoelectrons"] = {} for photoelectrons in telescope_data: if not isinstance(photoelectrons, iact.PhotoElectrons): warnings.warn( @@ -228,8 +306,7 @@ def next_low_level(self): ) tel_id = photoelectrons.telescope_id - self.current_calibration_pe[tel_id] = photoelectrons.parse() - + self.current_event["photoelectrons"][tel_id] = photoelectrons.parse() elif isinstance(o, History): for sub in o: @@ -252,135 +329,49 @@ def next_low_level(self): UnknownObjectWarning, ) - def iter_mc_events(self): - while True: - try: - next_event = self.try_build_mc_event() - except StopIteration: - break - if next_event is not None: - yield next_event - - def try_build_mc_event(self): - if self.current_mc_event: - - event_data = { - 'event_id': self.current_mc_event_id, - 'mc_shower': self.current_mc_shower, - 'mc_event': self.current_mc_event, - } - # if next object is TelescopeData, it belongs to this event - if isinstance(self.peek(), iact.TelescopeData): - self.next_low_level() - event_data['photons'] = self.current_photons - event_data['emitter'] = self.current_emitter - event_data['photoelectrons'] = self.current_photoelectrons - - self.current_mc_event = None - return event_data - self.next_low_level() - - def iter_array_events(self): - while True: - - next_event = self.try_build_event() - if next_event is not None: - yield next_event - - try: - self.next_low_level() - except StopIteration: - break - - def try_build_event(self): + def _build_event(self): '''check if all necessary info for an event was found, then make an event and invalidate old data ''' - if self.current_array_event: - if ( - self.allowed_telescopes - and not self.current_array_event['telescope_events'] - ): - self.current_array_event = None - return None - - event_id = self.current_array_event['event_id'] - - event_data = { - 'type': 'data', - 'event_id': event_id, - 'mc_shower': None, - 'mc_event': None, - 'telescope_events': self.current_array_event['telescope_events'], - 'tracking_positions': self.current_array_event['tracking_positions'], - 'trigger_information': self.current_array_event['trigger_information'], - 'photons': {}, - 'emitter': {}, - 'photoelectrons': {}, - 'photoelectron_sums': None, - } - - if self.current_mc_event_id == event_id: - event_data['mc_shower'] = self.current_mc_shower - event_data['mc_event'] = self.current_mc_event + event = self.current_event + self.current_event: Dict[str, Any] = {"type": "data"} - if self.current_telescope_data_event_id == event_id: - event_data['photons'] = self.current_photons - event_data['emitter'] = self.current_emitter - event_data['photoelectrons'] = self.current_photoelectrons + if self.current_mc_shower is not None and event["type"] == "data": + event["mc_shower"] = self.current_mc_shower - if self.current_photoelectron_sum_event_id == event_id: - event_data['photoelectron_sums'] = self.current_photoelectron_sum - - event_data['camera_monitorings'] = { + # fill monitoring info if we have telescope events + if 'telescope_events' in event: + tel_ids = event["telescope_events"].keys() + event['camera_monitorings'] = { telescope_id: copy(self.camera_monitorings[telescope_id]) - for telescope_id in self.current_array_event['telescope_events'].keys() + for telescope_id in tel_ids } - event_data['laser_calibrations'] = { + event['laser_calibrations'] = { telescope_id: copy(self.laser_calibrations[telescope_id]) - for telescope_id in self.current_array_event['telescope_events'].keys() + for telescope_id in tel_ids } - event_data['pixel_monitorings'] = { + event['pixel_monitorings'] = { telescope_id: copy(self.pixel_monitorings[telescope_id]) - for telescope_id in self.current_array_event['telescope_events'].keys() + for telescope_id in tel_ids } - self.current_array_event = None - - return event_data - - elif self.current_calibration_event: - event = self.current_calibration_event - if ( - self.allowed_telescopes - and not self.current_array_event['telescope_events'] - ): - self.current_calibration_event = None - return None - - event_data = { - 'type': 'calibration', - 'event_id': self.current_calibration_event_id, - 'telescope_events': event['telescope_events'], - 'tracking_positions': event['tracking_positions'], - 'trigger_information': event['trigger_information'], - 'calibration_type': event['calibration_type'], - 'photoelectrons': self.current_calibration_pe, - } + return event - event_data['camera_monitorings'] = { - telescope_id: copy(self.camera_monitorings[telescope_id]) - for telescope_id in event['telescope_events'].keys() - } - event_data['laser_calibrations'] = { - telescope_id: copy(self.laser_calibrations[telescope_id]) - for telescope_id in event['telescope_events'].keys() - } + def __enter__(self): + return self - self.current_calibration_event = None + def __exit__(self, exc_type, exc_value, traceback): + self.close() - return event_data + def close(self): + self._file.close() + + def tell(self): + return self._file.tell() + + def seek(self, *args, **kwargs): + return self._file.seek(*args, **kwargs) def parse_array_event(array_event, allowed_telescopes=None): @@ -416,7 +407,10 @@ def parse_array_event(array_event, allowed_telescopes=None): # require first element to be a TriggerInformation if i == 0: check_type(o, TriggerInformation) - event_id = o.header.id + # extracted calibration events seem to have a valid event id in the array event + # but not in the trigger + if o.header.id != 0: + event_id = o.header.id trigger_information = o.parse() telescopes = set(trigger_information['telescopes_with_data']) @@ -451,19 +445,16 @@ def parse_telescope_data(telescope_data): ''' Parse the TelescopeData block with Cherenkov Photon information''' check_type(telescope_data, iact.TelescopeData) - photons = {} - emitter = {} - photo_electrons = {} + data = defaultdict(dict) for o in telescope_data: if isinstance(o, iact.PhotoElectrons): - photo_electrons[o.telescope_id] = o.parse() + data["photoelectrons"][o.telescope_id] = o.parse() elif isinstance(o, iact.Photons): p, e = o.parse() - photons[o.telescope_id] = p + data["photons"][o.telescope_id] = p if e is not None: - emitter[o.telescope_id] = e - - return telescope_data.header.id, photons, emitter, photo_electrons + data["emitter"][o.telescope_id] = e + return data def parse_telescope_event(telescope_event): diff --git a/tests/resources/extracted_pedestals.simtel.zst b/tests/resources/extracted_pedestals.simtel.zst new file mode 100644 index 0000000..e782a7a Binary files /dev/null and b/tests/resources/extracted_pedestals.simtel.zst differ diff --git a/tests/simtel/test_simtelfile.py b/tests/simtel/test_simtelfile.py index 7bbaeb2..8985799 100644 --- a/tests/simtel/test_simtelfile.py +++ b/tests/simtel/test_simtelfile.py @@ -1,3 +1,4 @@ +import pytest from pytest import importorskip from eventio.simtel import SimTelFile import numpy as np @@ -19,102 +20,100 @@ test_paths = [prod2_path, prod3_path, prod4_path] -def test_can_open(): - for path in test_paths: - assert SimTelFile(path) - - -def test_at_least_one_event_found(): - for path in test_paths: - one_found = False - for event in SimTelFile(path): - one_found = True - break - assert one_found, path - - -def test_show_we_get_a_dict_with_hower_and_event(): - for path in test_paths: - for event in SimTelFile(path): - assert 'mc_shower' in event - assert 'telescope_events' in event - assert 'mc_event' in event - break - - -def test_show_event_is_not_empty_and_has_some_members_for_sure(): - for path in test_paths: - for event in SimTelFile(path): - assert event['mc_shower'].keys() == { - 'shower', - 'primary_id', - 'energy', - 'azimuth', - 'altitude', - 'depth_start', - 'h_first_int', - 'xmax', - 'hmax', - 'emax', - 'cmax', - 'n_profiles', - 'profiles' +@pytest.mark.parametrize("path", test_paths) +def test_can_open(path): + assert SimTelFile(path) + + +@pytest.mark.parametrize("path", test_paths) +def test_at_least_one_event_found(path): + one_found = False + for event in SimTelFile(path): + one_found = True + break + assert one_found, path + + +@pytest.mark.parametrize("path", test_paths) +def test_show_we_get_a_dict_with_hower_and_event(path): + for event in SimTelFile(path): + assert 'mc_shower' in event + assert 'telescope_events' in event + assert 'mc_event' in event + break + + +@pytest.mark.parametrize("path", test_paths) +def test_show_event_is_not_empty_and_has_some_members_for_sure(path): + for event in SimTelFile(path): + assert event['mc_shower'].keys() == { + 'shower', + 'primary_id', + 'energy', + 'azimuth', + 'altitude', + 'depth_start', + 'h_first_int', + 'xmax', + 'hmax', + 'emax', + 'cmax', + 'n_profiles', + 'profiles' + } + + required = { + 'type', + 'event_id', + 'mc_shower', + 'mc_event', + 'telescope_events', + 'trigger_information', + 'tracking_positions', + 'photoelectron_sums', + 'camera_monitorings', + 'laser_calibrations', + 'pixel_monitorings', + } + assert required.issubset(event.keys()) + + telescope_events = event['telescope_events'] + + assert telescope_events # never empty! + + for telescope_event in telescope_events.values(): + expected_keys = { + 'header', + 'pixel_timing', + 'pixel_lists', } - - assert event.keys() == { - 'type', - 'event_id', - 'mc_shower', - 'mc_event', - 'telescope_events', - 'trigger_information', - 'tracking_positions', - 'photoelectron_sums', - 'photoelectrons', - 'photons', - 'emitter', - 'camera_monitorings', - 'laser_calibrations', - 'pixel_monitorings', + allowed_keys = { + 'image_parameters', + 'adc_sums', + 'adc_samples' } - telescope_events = event['telescope_events'] - - assert telescope_events # never empty! + found_keys = set(telescope_event.keys()) + assert expected_keys.issubset(found_keys) - for telescope_event in telescope_events.values(): - expected_keys = { - 'header', - 'pixel_timing', - 'pixel_lists', - } - allowed_keys = { - 'image_parameters', - 'adc_sums', - 'adc_samples' - } + extra_keys = found_keys.difference(expected_keys) + assert extra_keys.issubset(allowed_keys) + assert 'adc_sums' in found_keys or 'adc_samples' in found_keys - found_keys = set(telescope_event.keys()) - assert expected_keys.issubset(found_keys) - - extra_keys = found_keys.difference(expected_keys) - assert extra_keys.issubset(allowed_keys) - assert 'adc_sums' in found_keys or 'adc_samples' in found_keys - - break + break def test_iterate_complete_file(): expected_counter_values = { + prod4_path: 30, prod2_path: 8, prod3_path: 5, - prod4_path: 30, } for path, expected in expected_counter_values.items(): try: for counter, event in enumerate(SimTelFile(path)): pass - except (EOFError, IndexError): # truncated files might raise these... + except EOFError: # truncated files might raise these... pass assert counter == expected @@ -132,21 +131,22 @@ def test_iterate_complete_file_zst(): def test_iterate_mc_events(): expected = 200 - with SimTelFile(prod4_path) as f: - for counter, event in enumerate(f.iter_mc_events(), start=1): + with SimTelFile(prod4_path, skip_non_triggered=False) as f: + for counter, event in enumerate(f, start=1): assert 'event_id' in event assert 'mc_shower' in event assert 'mc_event' in event assert counter == expected - with SimTelFile('tests/resources/lst_with_photons.simtel.zst') as f: - for counter, event in enumerate(f.iter_mc_events(), start=1): - assert event.keys() == { + path = 'tests/resources/lst_with_photons.simtel.zst' + with SimTelFile(path, skip_non_triggered=False) as f: + for counter, event in enumerate(f, start=1): + assert set(event.keys()).issuperset({ 'event_id', 'mc_shower', 'mc_event', - 'photons', 'photoelectrons', 'emitter' - } + 'photons', 'photoelectrons', + }) def test_allowed_tels(): @@ -214,9 +214,16 @@ def test_new_prod4(): def test_correct_event_ids_iter_mc_events(): with SimTelFile('tests/resources/lst_with_photons.simtel.zst') as f: + n_use = f.mc_run_headers[-1]["n_use"] + n_showers = f.mc_run_headers[-1]["n_showers"] + i = 0 for e in f: - assert f.current_mc_event_id == f.current_telescope_data_event_id - assert f.current_mc_shower_id == f.current_mc_event_id // 100 + i += 1 + expected = i // n_use * 100 + i % n_use + assert e["event_id"] == expected + assert f.current_mc_shower_id == e["event_id"] // 100 + + assert i == n_showers * n_use def test_photons(): @@ -230,17 +237,16 @@ def test_photons(): assert photons.dtype == Photons.long_dtype # no emitter info in file - print(e['emitter']) - assert len(e['emitter']) == 0 + assert 'emitter' not in e def test_missing_photons(): with SimTelFile('tests/resources/gamma_test.simtel.gz') as f: e = next(iter(f)) - assert e['photons'] == {} - assert e['photoelectrons'] == {} - assert e['emitter'] == {} + assert 'photons' not in e + assert 'photoelectrons' not in e + assert 'emitter' not in e def test_calibration_photoelectrons(): @@ -251,6 +257,7 @@ def test_calibration_photoelectrons(): true_image = e['photoelectrons'][0]['photoelectrons'] assert np.isclose(np.mean(true_image), expected_pe, rtol=0.05) + def test_history_meta(): with SimTelFile(history_meta_path) as f: assert isinstance(f.global_meta, dict) @@ -275,3 +282,13 @@ def test_type_2033(): 'flags', 'n_pixels', 'n_gains', 'nsb_rate', 'qe_rel', 'gain_rel', 'hv_rel', 'current', 'fadc_amp_hg', 'disabled', } + + +def test_extracted_pedestals(): + with SimTelFile("tests/resources/extracted_pedestals.simtel.zst") as f: + expected_event_id = 0 + for e in f: + expected_event_id += 1 + assert e["event_id"] == expected_event_id + + assert expected_event_id == 5 diff --git a/tests/test_base.py b/tests/test_base.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_iterator.py b/tests/test_iterator.py index 8758437..370547e 100644 --- a/tests/test_iterator.py +++ b/tests/test_iterator.py @@ -1,4 +1,5 @@ from eventio import EventIOFile +import pytest testfile = 'tests/resources/one_shower.dat' @@ -26,3 +27,27 @@ def test_peek(): assert o is next(f) # make sure peek gives us the next object assert o is not f.peek() # assure we get the next assert f.peek() is next(f) + + # make sure peek returns None at end of file + for o in f: + pass + + assert f.peek() is None + + # make sure peek returns None at end of file also for truncated files + truncated = 'tests/resources/gamma_test_large_truncated.simtel.gz' + + # make sure file was really truncated and we reached end of file + with pytest.raises(EOFError): + f = EventIOFile(truncated) + for o in f: + pass + + f = EventIOFile(truncated) + + while f.peek() is not None: + o = next(f) + + # test we can peak multiple times for a truncated file + assert f.peek() is None + assert f.peek() is None