diff --git a/OTAnalytics/application/analysis/traffic_counting.py b/OTAnalytics/application/analysis/traffic_counting.py index ecdba0436..4213eebb3 100644 --- a/OTAnalytics/application/analysis/traffic_counting.py +++ b/OTAnalytics/application/analysis/traffic_counting.py @@ -184,7 +184,11 @@ def create_mode_tag(tag: str) -> Tag: return SingleTag(level=LEVEL_CLASSIFICATION, id=tag) -def create_timeslot_tag(start_of_time_slot: datetime, interval: timedelta) -> Tag: +def create_timeslot_tag(start: datetime, interval: timedelta) -> Tag: + interval_seconds = interval.total_seconds() + original_time = int(start.timestamp()) + result = int(original_time / interval_seconds) * interval_seconds + start_of_time_slot = datetime.fromtimestamp(result, tz=timezone.utc) end_of_time_slot = start_of_time_slot + interval serialized_start = start_of_time_slot.strftime(r"%Y-%m-%d %H:%M:%S") serialized_end = end_of_time_slot.strftime(r"%Y-%m-%d %H:%M:%S") @@ -375,11 +379,7 @@ def __init__(self, interval: timedelta) -> None: self._interval = interval def create_tag(self, assignment: RoadUserAssignment) -> Tag: - original_time = int(assignment.events.start.occurrence.timestamp()) - interval_seconds = self._interval.total_seconds() - result = int(original_time / interval_seconds) * interval_seconds - start_of_time_slot = datetime.fromtimestamp(result, timezone.utc) - return create_timeslot_tag(start_of_time_slot, self._interval) + return create_timeslot_tag(assignment.events.start.occurrence, self._interval) class CountableAssignments: @@ -918,7 +918,10 @@ def export(self, specification: CountingSpecificationDto) -> None: """ if self._event_repository.is_empty(): self._create_events() - events = self._event_repository.get_all() + events = self._event_repository.get( + start_date=specification.start, + end_date=specification.end, + ) flows = self._flow_repository.get_all() assigned_flows = self._assigner.assign(events, flows) tagger = self._tagger_factory.create_tagger(specification) diff --git a/OTAnalytics/domain/event.py b/OTAnalytics/domain/event.py index 512cb9d7e..bb1d3d834 100644 --- a/OTAnalytics/domain/event.py +++ b/OTAnalytics/domain/event.py @@ -507,6 +507,8 @@ def get_previous_before( def get( self, + start_date: datetime | None = None, + end_date: datetime | None = None, sections: Sequence[SectionId] | None = None, event_types: Sequence[EventType] | None = None, ) -> Iterable[Event]: @@ -514,9 +516,13 @@ def get( event_types = [] if sections is None: sections = [] - filter_function = self.__create_filter(event_types) + type_filter = self.__create_type_filter(event_types) + start_filter = self.__create_start_filter(start_date) + end_filter = self.__create_end_filter(end_date) events = self.__create_event_list(sections) - return list(filter(filter_function, events)) + return list( + filter(start_filter, filter(end_filter, filter(type_filter, events))) + ) def __create_event_list(self, sections: Sequence[SectionId]) -> Iterable[Event]: if sections: @@ -525,7 +531,29 @@ def __create_event_list(self, sections: Sequence[SectionId]) -> Iterable[Event]: return self.get_all() @staticmethod - def __create_filter(event_types: Sequence[EventType]) -> Callable[[Event], bool]: + def __create_type_filter( + event_types: Sequence[EventType], + ) -> Callable[[Event], bool]: if event_types: return lambda actual: actual.event_type in event_types return lambda event: True + + @staticmethod + def __create_start_filter(start_date: datetime | None) -> Callable[[Event], bool]: + if start_date: + return after_filter(start_date) + return lambda event: True + + @staticmethod + def __create_end_filter(end_date: datetime | None) -> Callable[[Event], bool]: + if end_date: + return before_filter(end_date) + return lambda event: True + + +def after_filter(date: datetime) -> Callable[[Event], bool]: + return lambda actual: actual.occurrence >= date + + +def before_filter(date: datetime) -> Callable[[Event], bool]: + return lambda actual: actual.occurrence <= date diff --git a/tests/OTAnalytics/application/analysis/test_traffic_counting.py b/tests/OTAnalytics/application/analysis/test_traffic_counting.py index bd39af776..08f703a1c 100644 --- a/tests/OTAnalytics/application/analysis/test_traffic_counting.py +++ b/tests/OTAnalytics/application/analysis/test_traffic_counting.py @@ -36,6 +36,7 @@ TaggerFactory, TimeslotTagger, create_export_specification, + create_timeslot_tag, ) from OTAnalytics.application.analysis.traffic_counting_specification import ( CountingSpecificationDto, @@ -64,6 +65,39 @@ def track(track_builder: TrackBuilder) -> Track: return track_builder.build_track() +@pytest.mark.parametrize( + "start_time,expected_start_time,expected_end_time", + [ + ("00:00:00", "00:00:00", "00:15:00"), + ("00:03:00", "00:00:00", "00:15:00"), + ], +) +def test_create_timeslot_tag( + start_time: str, + expected_start_time: str, + expected_end_time: str, +) -> None: + start_date = f"2024-01-01 {start_time}" + expected_start_date = f"2024-01-01 {expected_start_time}" + expected_end_date = f"2024-01-01 {expected_end_time}" + current = datetime.strptime(start_date, "%Y-%m-%d %H:%M:%S").replace( + tzinfo=timezone.utc + ) + interval = timedelta(minutes=15) + tag = create_timeslot_tag(current, interval) + + expected_tag = MultiTag( + frozenset( + [ + SingleTag(level=LEVEL_START_TIME, id=expected_start_date), + SingleTag(level=LEVEL_END_TIME, id=expected_end_date), + ] + ) + ) + + assert tag == expected_tag + + class TestCountByFlow: def test_to_dict(self) -> None: value = 2 @@ -758,7 +792,7 @@ def test_count_traffic(self) -> None: use_case.export(counting_specification) - event_repository.get_all.assert_called_once() + event_repository.get.assert_called_once_with(start_date=start, end_date=end) flow_repository.get_all.assert_called_once() create_events.assert_called_once() road_user_assigner.assign.assert_called_once() diff --git a/tests/OTAnalytics/domain/test_event.py b/tests/OTAnalytics/domain/test_event.py index 8a3480dc8..89c1578e5 100644 --- a/tests/OTAnalytics/domain/test_event.py +++ b/tests/OTAnalytics/domain/test_event.py @@ -574,19 +574,89 @@ def test_get_previous_before( assert actual_event == expected_event @pytest.mark.parametrize( - "sections,event_type,expected_events", + "start_date,end_date,sections,event_type,expected_events", [ - ([], [], all_events()), - ([SECTION_ID_1], [], [event_1_section_1(), event_2_section_1()]), - ([SECTION_ID_2], [], [event_1_section_2(), event_2_section_2()]), - ([SECTION_ID_1, SECTION_ID_2], [], all_events()), - ([SECTION_ID_1, SECTION_ID_2], DEFAULT_EVENT_TYPES, all_events()), ( + None, + None, + [], + [], + all_events(), + ), + ( + event_2_section_2().occurrence, + None, + [], + [], + [event_2_section_2()], + ), + ( + event_2_section_1().occurrence, + None, + [], + [], + [event_2_section_1(), event_2_section_2()], + ), + ( + None, + event_1_section_1().occurrence, + [], + [], + [event_1_section_1()], + ), + ( + None, + event_1_section_2().occurrence, + [], + [], + [event_1_section_1(), event_1_section_2()], + ), + ( + None, + event_2_section_1().occurrence, + [], + [], + [event_1_section_1(), event_2_section_1(), event_1_section_2()], + ), + ( + event_1_section_2().occurrence, + event_2_section_1().occurrence, + [], + [], + [event_2_section_1(), event_1_section_2()], + ), + ( + None, + None, + [SECTION_ID_1], + [], + [event_1_section_1(), event_2_section_1()], + ), + ( + None, + None, + [SECTION_ID_2], + [], + [event_1_section_2(), event_2_section_2()], + ), + (None, None, [SECTION_ID_1, SECTION_ID_2], [], all_events()), + ( + None, + None, + [SECTION_ID_1, SECTION_ID_2], + DEFAULT_EVENT_TYPES, + all_events(), + ), + ( + None, + None, [SECTION_ID_1], [EventType.SECTION_ENTER], [event_1_section_1()], ), ( + None, + None, [SECTION_ID_1], [EventType.SECTION_LEAVE], [event_2_section_1()], @@ -595,6 +665,8 @@ def test_get_previous_before( ) def test_get( self, + start_date: datetime, + end_date: datetime, sections: list[SectionId], event_type: list[EventType], expected_events: list[Event], @@ -602,6 +674,11 @@ def test_get( repository = EventRepository() repository.add_all(all_events()) - actual_events = repository.get(sections=sections, event_types=event_type) + actual_events = repository.get( + start_date=start_date, + end_date=end_date, + sections=sections, + event_types=event_type, + ) assert actual_events == expected_events