From 47be0c5740a79bea3f32e9442792cf2f0f4023b1 Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Wed, 10 Apr 2024 09:51:56 +0200 Subject: [PATCH 01/15] Allow method chaining in TrackBuilder --- tests/utils/builders/track_builder.py | 31 ++++++++++++++++++--------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/tests/utils/builders/track_builder.py b/tests/utils/builders/track_builder.py index 746cb4123..872a83b29 100644 --- a/tests/utils/builders/track_builder.py +++ b/tests/utils/builders/track_builder.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from datetime import datetime, timezone from itertools import repeat +from typing import Self from OTAnalytics.domain.track import Detection, Track, TrackId from OTAnalytics.plugin_datastore.python_track_store import PythonDetection, PythonTrack @@ -82,20 +83,25 @@ def create_detection(self) -> Detection: _video_name=self.video_name, ) - def add_track_id(self, id: str) -> None: + def add_track_id(self, id: str) -> Self: self.track_id = id + return self - def add_detection_class(self, classification: str) -> None: + def add_detection_class(self, classification: str) -> Self: self.detection_class = classification + return self - def add_confidence(self, confidence: float) -> None: + def add_confidence(self, confidence: float) -> Self: self.confidence = confidence + return self - def add_frame(self, frame: int) -> None: + def add_frame(self, frame: int) -> Self: self.frame = frame + return self - def add_track_class(self, classification: str) -> None: + def add_track_class(self, classification: str) -> Self: self.track_class = classification + return self def add_occurrence( self, @@ -106,7 +112,7 @@ def add_occurrence( minute: int, second: int, microsecond: int, - ) -> None: + ) -> Self: self.occurrence_year = year self.occurrence_month = month self.occurrence_day = day @@ -114,20 +120,25 @@ def add_occurrence( self.occurrence_minute = minute self.occurrence_second = second self.occurrence_microsecond = microsecond + return self - def add_second(self, second: int) -> None: + def add_second(self, second: int) -> Self: self.occurrence_second = second + return self - def add_microsecond(self, microsecond: int) -> None: + def add_microsecond(self, microsecond: int) -> Self: self.occurrence_microsecond = microsecond + return self - def add_xy_bbox(self, x: float, y: float) -> None: + def add_xy_bbox(self, x: float, y: float) -> Self: self.x = x self.y = y + return self - def add_wh_bbox(self, w: float, h: float) -> None: + def add_wh_bbox(self, w: float, h: float) -> Self: self.w = w self.h = h + return self def get_metadata(self) -> dict: return { From fd915ffc204c1f225a056779c76fb640b1b09587 Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Wed, 10 Apr 2024 13:36:20 +0200 Subject: [PATCH 02/15] Add method to TrackDataset to query for max confidence of tracks --- OTAnalytics/domain/track_dataset.py | 22 +++++++++++++ .../plugin_datastore/python_track_store.py | 14 ++++++++ OTAnalytics/plugin_datastore/track_store.py | 14 ++++++++ .../track_dataset/test_track_dataset.py | 33 ++++++++++++++++++- .../test_python_track_storage.py | 17 ++++++++++ .../plugin_datastore/test_track_store.py | 18 ++++++++++ tests/conftest.py | 6 ++-- tests/utils/builders/track_builder.py | 30 ++++++++--------- 8 files changed, 136 insertions(+), 18 deletions(-) diff --git a/OTAnalytics/domain/track_dataset.py b/OTAnalytics/domain/track_dataset.py index 3c2a3825f..6f4e6f998 100644 --- a/OTAnalytics/domain/track_dataset.py +++ b/OTAnalytics/domain/track_dataset.py @@ -19,6 +19,10 @@ END_VIDEO_NAME: str = "end_video_name" +class TrackDoesNotExistError(Exception): + pass + + @dataclass(frozen=True, order=True) class IntersectionPoint: index: int @@ -201,6 +205,21 @@ def cut_with_section( """ raise NotImplementedError + @abstractmethod + def get_max_confidences_for(self, track_ids: list[str]) -> dict[str, float]: + """Get max confidences for given track ids. + + Args: + track_ids: the track ids to get the max confidences for. + + Returns: + dict[TrackId, float]: the max confidence values for the track ids. + + Raises: + TrackDoesNotExistError: if given track id does not exist within dataset. + """ + raise NotImplementedError + class FilteredTrackDataset(TrackDataset): @property @@ -273,6 +292,9 @@ def get_first_segments(self) -> TrackSegmentDataset: def get_last_segments(self) -> TrackSegmentDataset: return self._filter().get_last_segments() + def get_max_confidences_for(self, track_ids: list[str]) -> dict[str, float]: + return self._filter().get_max_confidences_for(track_ids) + class TrackGeometryDataset(ABC): """Dataset containing track geometries. diff --git a/OTAnalytics/plugin_datastore/python_track_store.py b/OTAnalytics/plugin_datastore/python_track_store.py index 088eb2e4f..d8e476b42 100644 --- a/OTAnalytics/plugin_datastore/python_track_store.py +++ b/OTAnalytics/plugin_datastore/python_track_store.py @@ -36,6 +36,7 @@ FilteredTrackDataset, IntersectionPoint, TrackDataset, + TrackDoesNotExistError, TrackGeometryDataset, TrackSegmentDataset, ) @@ -634,6 +635,19 @@ def _build_track( track_builder.add_detection(detection) return track_builder.build() + def get_max_confidences_for(self, track_ids: list[str]) -> dict[str, float]: + result: dict[str, float] = {} + for track_id in track_ids: + _track = self.get_for(TrackId(track_id)) + if not _track: + raise TrackDoesNotExistError(f"Track {track_id} not found.") + + max_confidence = max( + [detection.confidence for detection in _track.detections] + ) + result[track_id] = max_confidence + return result + class FilteredPythonTrackDataset(FilteredTrackDataset): @property diff --git a/OTAnalytics/plugin_datastore/track_store.py b/OTAnalytics/plugin_datastore/track_store.py index cdd61ed5d..d4773d718 100644 --- a/OTAnalytics/plugin_datastore/track_store.py +++ b/OTAnalytics/plugin_datastore/track_store.py @@ -30,6 +30,7 @@ FilteredTrackDataset, IntersectionPoint, TrackDataset, + TrackDoesNotExistError, TrackGeometryDataset, TrackSegmentDataset, ) @@ -587,6 +588,19 @@ def _create_cut_track_id( return f"{track_id}_{cut_segment_index}" return row[track.TRACK_ID] + def get_max_confidences_for(self, track_ids: list[str]) -> dict[str, float]: + try: + return ( + self._dataset.loc[track_ids][track.CONFIDENCE] + .groupby(level=[LEVEL_TRACK_ID]) + .max() + .to_dict() + ) + except KeyError as cause: + raise TrackDoesNotExistError( + "Some tracks do not exists in dataset with given id" + ) from cause + class FilteredPandasTrackDataset(FilteredTrackDataset, PandasDataFrameProvider): @property diff --git a/tests/OTAnalytics/domain/track_dataset/test_track_dataset.py b/tests/OTAnalytics/domain/track_dataset/test_track_dataset.py index a268d4c5f..a83d1af9d 100644 --- a/tests/OTAnalytics/domain/track_dataset/test_track_dataset.py +++ b/tests/OTAnalytics/domain/track_dataset/test_track_dataset.py @@ -5,7 +5,10 @@ from _pytest.fixtures import FixtureRequest from OTAnalytics.domain.track import Track, TrackId -from OTAnalytics.domain.track_dataset import FilteredTrackDataset +from OTAnalytics.domain.track_dataset import ( + FilteredTrackDataset, + TrackDoesNotExistError, +) from OTAnalytics.plugin_prototypes.track_visualization.track_viz import ( CLASS_BICYCLIST, CLASS_CAR, @@ -412,3 +415,31 @@ def test_cut_with_section(self) -> None: assert result_dataset.include_classes == frozenset() assert result_dataset.exclude_classes == frozenset() mock_other.cut_with_section.assert_called_once_with(section, offset) + + @pytest.mark.parametrize( + "include_classes,exclude_classes,expected", + [ + ([], [], {"1": 0.8, "2": 0.9}), + (["car"], [], {"1": 0.8}), + ([], ["car"], {"2": 0.9}), + ], + ) + def test_get_max_confidences_for( + self, + include_classes: list[str], + exclude_classes: list[str], + expected: dict[str, float], + car_track: Track, + pedestrian_track: Track, + ) -> None: + empty_datasets = self.get_datasets([], include_classes, exclude_classes) + + for empty_dataset in empty_datasets.values(): + with pytest.raises(TrackDoesNotExistError): + empty_dataset.get_max_confidences_for([car_track.id.id]) + + filled_dataset = empty_dataset.add_all([car_track, pedestrian_track]) + all_track_ids = [track_id.id for track_id in filled_dataset.track_ids] + + result = filled_dataset.get_max_confidences_for(all_track_ids) + assert result == expected diff --git a/tests/OTAnalytics/plugin_datastore/test_python_track_storage.py b/tests/OTAnalytics/plugin_datastore/test_python_track_storage.py index c8bd6c881..a23193b04 100644 --- a/tests/OTAnalytics/plugin_datastore/test_python_track_storage.py +++ b/tests/OTAnalytics/plugin_datastore/test_python_track_storage.py @@ -28,6 +28,7 @@ START_VIDEO_NAME, START_X, START_Y, + TrackDoesNotExistError, TrackGeometryDataset, ) from OTAnalytics.plugin_datastore.python_track_store import ( @@ -619,6 +620,22 @@ def test_empty(self, car_track: Track) -> None: filled_dataset = empty_dataset.add_all([car_track]) assert not filled_dataset.empty + def test_get_max_confidences_for( + self, + car_track: Track, + pedestrian_track: Track, + ) -> None: + empty_dataset = PythonTrackDataset() + with pytest.raises(TrackDoesNotExistError): + empty_dataset.get_max_confidences_for([car_track.id.id]) + filled_dataset = empty_dataset.add_all([car_track, pedestrian_track]) + + car_id = car_track.id.id + pedestrian_id = pedestrian_track.id.id + + result = filled_dataset.get_max_confidences_for([car_id, pedestrian_id]) + assert result == {car_id: 0.8, pedestrian_id: 0.9} + class TestSimpleCutTrackSegmentBuilder: def test_build(self) -> None: diff --git a/tests/OTAnalytics/plugin_datastore/test_track_store.py b/tests/OTAnalytics/plugin_datastore/test_track_store.py index 68c87f35d..b795f1a0b 100644 --- a/tests/OTAnalytics/plugin_datastore/test_track_store.py +++ b/tests/OTAnalytics/plugin_datastore/test_track_store.py @@ -12,6 +12,7 @@ from OTAnalytics.domain.track_dataset import ( TRACK_GEOMETRY_FACTORY, TrackDataset, + TrackDoesNotExistError, TrackGeometryDataset, ) from OTAnalytics.plugin_datastore.python_track_store import ( @@ -600,3 +601,20 @@ def test_empty( assert empty_dataset.empty filled_dataset = empty_dataset.add_all([car_track]) assert not filled_dataset.empty + + def test_get_max_confidences_for( + self, + track_geometry_factory: TRACK_GEOMETRY_FACTORY, + car_track: Track, + pedestrian_track: Track, + ) -> None: + empty_dataset = PandasTrackDataset(track_geometry_factory) + with pytest.raises(TrackDoesNotExistError): + empty_dataset.get_max_confidences_for([car_track.id.id]) + filled_dataset = empty_dataset.add_all([car_track, pedestrian_track]) + + car_id = car_track.id.id + pedestrian_id = pedestrian_track.id.id + + result = filled_dataset.get_max_confidences_for([car_id, pedestrian_id]) + assert result == {car_id: 0.8, pedestrian_id: 0.9} diff --git a/tests/conftest.py b/tests/conftest.py index a4addc456..bd49d1cc1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -215,7 +215,7 @@ def closed_track() -> Track: @pytest.fixture def car_track() -> Track: - return create_track("1", [(1, 1), (2, 2)], 1, CLASS_CAR) + return create_track("1", [(1, 1), (2, 2)], 1, CLASS_CAR, confidences=[0.6, 0.8]) @pytest.fixture @@ -225,7 +225,9 @@ def car_track_continuing() -> Track: @pytest.fixture def pedestrian_track() -> Track: - return create_track("2", [(1, 1), (2, 2), (3, 3)], 1, CLASS_PEDESTRIAN) + return create_track( + "2", [(1, 1), (2, 2), (3, 3)], 1, CLASS_PEDESTRIAN, confidences=[0.9, 0.8, 0.7] + ) @pytest.fixture diff --git a/tests/utils/builders/track_builder.py b/tests/utils/builders/track_builder.py index 872a83b29..0ab1d9e3d 100644 --- a/tests/utils/builders/track_builder.py +++ b/tests/utils/builders/track_builder.py @@ -1,6 +1,5 @@ from dataclasses import dataclass from datetime import datetime, timezone -from itertools import repeat from typing import Self from OTAnalytics.domain.track import Detection, Track, TrackId @@ -276,30 +275,31 @@ def create_track( start_second: int, track_class: str = "car", detection_classes: list[str] | None = None, + confidences: list[float] | None = None, ) -> Track: if detection_classes: if len(detection_classes) != len(coord): raise ValueError( "Track coordinates must match length of detection classifications." ) + if confidences: + if len(confidences) != len(coord): + raise ValueError("Track coordinates must match length of confidences.") track_builder = TrackBuilder() track_builder.add_track_id(track_id) track_builder.add_track_class(track_class) - if detection_classes: - detections = [ - (x, y, detection_class) - for (x, y), detection_class in zip(coord, detection_classes) - ] - else: - detections = [ - (x, y, detection_class) - for (x, y), detection_class in zip(coord, repeat(track_class)) - ] - - for second, (x, y, detection_class) in enumerate(detections, start=start_second): - track_builder.add_second(second) + current_second = start_second + for current_index, (x, y) in enumerate(coord, start=0): + track_builder.add_second(current_second) track_builder.add_xy_bbox(x, y) - track_builder.add_detection_class(detection_class) + if detection_classes: + track_builder.add_detection_class(detection_classes[current_index]) + else: + track_builder.add_detection_class(track_class) + if confidences: + track_builder.add_confidence(confidences[current_index]) track_builder.append_detection() + current_second += 1 + return track_builder.build_track() From e2190ff43e3e81220cf637c73f1ab3ad9114d79e Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Wed, 10 Apr 2024 13:46:23 +0200 Subject: [PATCH 03/15] Fix typo --- .../eventlist_exporter/eventlist_exporter.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/OTAnalytics/plugin_prototypes/eventlist_exporter/eventlist_exporter.py b/OTAnalytics/plugin_prototypes/eventlist_exporter/eventlist_exporter.py index a785bee19..2ed140117 100644 --- a/OTAnalytics/plugin_prototypes/eventlist_exporter/eventlist_exporter.py +++ b/OTAnalytics/plugin_prototypes/eventlist_exporter/eventlist_exporter.py @@ -38,16 +38,16 @@ def _convert_to_dataframe(self, events: Iterable[Event]) -> pd.DataFrame: return pd.DataFrame([event.to_dict() for event in events]) def build(self) -> pd.DataFrame: - self._convert_occurence_to_seconds_since_epoch() + self._convert_occurrence_to_seconds_since_epoch() self._split_columns_with_lists() self._add_section_names() return self._df - def _convert_occurence_to_seconds_since_epoch(self) -> None: + def _convert_occurrence_to_seconds_since_epoch(self) -> None: # TODO: Use OTAnalytics´ builtin timestamp methods epoch = pd.Timestamp("1970-01-01") - occurence = pd.to_datetime(self._df[OCCURRENCE]) - self._df[f"{OCCURRENCE}_sec"] = (occurence - epoch).dt.total_seconds() + occurrence = pd.to_datetime(self._df[OCCURRENCE]) + self._df[f"{OCCURRENCE}_sec"] = (occurrence - epoch).dt.total_seconds() def _split_columns_with_lists(self) -> None: self._df[["coordinate_px_x", "coordinate_px_y"]] = pd.DataFrame( From c4621c382cea20448f5d8c659a7447785edd336b Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Wed, 10 Apr 2024 13:57:27 +0200 Subject: [PATCH 04/15] Extract event list export format to own module --- .../application/export_formats/__init__.py | 0 .../application/export_formats/event_list.py | 23 +++++++++++++ .../eventlist_exporter/eventlist_exporter.py | 33 ++++++++++--------- 3 files changed, 41 insertions(+), 15 deletions(-) create mode 100644 OTAnalytics/application/export_formats/__init__.py create mode 100644 OTAnalytics/application/export_formats/event_list.py diff --git a/OTAnalytics/application/export_formats/__init__.py b/OTAnalytics/application/export_formats/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/OTAnalytics/application/export_formats/event_list.py b/OTAnalytics/application/export_formats/event_list.py new file mode 100644 index 000000000..e03d74603 --- /dev/null +++ b/OTAnalytics/application/export_formats/event_list.py @@ -0,0 +1,23 @@ +from OTAnalytics.domain import event + +ROAD_USER_ID = event.ROAD_USER_ID +ROAD_USER_TYPE = event.ROAD_USER_TYPE +HOSTNAME = event.HOSTNAME +OCCURRENCE = event.OCCURRENCE +OCCURRENCE_DATE = "occurrence_day" +OCCURRENCE_TIME = "occurrence_time" +FRAME_NUMBER = event.FRAME_NUMBER +SECTION_ID = event.SECTION_ID +SECTION_NAME = "section_name" +EVENT_COORDINATE = event.EVENT_COORDINATE +EVENT_COORDINATE_X = "coordinate_px_x" +EVENT_COORDINATE_Y = "coordinate_px_y" +EVENT_TYPE = event.EVENT_TYPE +DIRECTION_VECTOR = event.DIRECTION_VECTOR +DIRECTION_VECTOR_X = "vector_px_x" +DIRECTION_VECTOR_Y = "vector_px_y" +VIDEO_NAME = event.VIDEO_NAME + +DATE_FORMAT = "%Y-%m-%d" +TIME_FORMAT = "%H:%M:%S.%f" +DATE_TIME_FORMAT = f"{DATE_FORMAT} {TIME_FORMAT}" diff --git a/OTAnalytics/plugin_prototypes/eventlist_exporter/eventlist_exporter.py b/OTAnalytics/plugin_prototypes/eventlist_exporter/eventlist_exporter.py index 2ed140117..73ffa6a1f 100644 --- a/OTAnalytics/plugin_prototypes/eventlist_exporter/eventlist_exporter.py +++ b/OTAnalytics/plugin_prototypes/eventlist_exporter/eventlist_exporter.py @@ -5,18 +5,13 @@ from OTAnalytics.application.config import DEFAULT_EVENTLIST_FILE_TYPE from OTAnalytics.application.datastore import EventListParser +from OTAnalytics.application.export_formats import event_list from OTAnalytics.application.logger import logger from OTAnalytics.application.use_cases.export_events import ( EventListExporter, ExporterNotFoundError, ) -from OTAnalytics.domain.event import ( - DIRECTION_VECTOR, - EVENT_COORDINATE, - OCCURRENCE, - SECTION_ID, - Event, -) +from OTAnalytics.domain.event import Event from OTAnalytics.domain.section import Section from OTAnalytics.plugin_parser.otvision_parser import OtEventListParser @@ -46,24 +41,32 @@ def build(self) -> pd.DataFrame: def _convert_occurrence_to_seconds_since_epoch(self) -> None: # TODO: Use OTAnalytics´ builtin timestamp methods epoch = pd.Timestamp("1970-01-01") - occurrence = pd.to_datetime(self._df[OCCURRENCE]) - self._df[f"{OCCURRENCE}_sec"] = (occurrence - epoch).dt.total_seconds() + occurrence = pd.to_datetime(self._df[event_list.OCCURRENCE]) + self._df[f"{event_list.OCCURRENCE}_sec"] = ( + occurrence - epoch + ).dt.total_seconds() def _split_columns_with_lists(self) -> None: - self._df[["coordinate_px_x", "coordinate_px_y"]] = pd.DataFrame( - self._df[EVENT_COORDINATE].tolist(), index=self._df.index + self._df[[event_list.EVENT_COORDINATE_X, event_list.EVENT_COORDINATE_Y]] = ( + pd.DataFrame( + self._df[event_list.EVENT_COORDINATE].tolist(), index=self._df.index + ) + ) + self._df[[event_list.DIRECTION_VECTOR_X, event_list.DIRECTION_VECTOR_Y]] = ( + pd.DataFrame( + self._df[event_list.DIRECTION_VECTOR].tolist(), index=self._df.index + ) ) - self._df[["vector_px_x", "vector_px_y"]] = pd.DataFrame( - self._df[DIRECTION_VECTOR].tolist(), index=self._df.index + self._df = self._df.drop( + columns=[event_list.EVENT_COORDINATE, event_list.DIRECTION_VECTOR] ) - self._df = self._df.drop(columns=[EVENT_COORDINATE, DIRECTION_VECTOR]) def _add_section_names(self) -> None: sections_list_of_dicts = [section.to_dict() for section in self._sections] sections_dict = { section["id"]: section["name"] for section in sections_list_of_dicts } - self._df["section_name"] = self._df[SECTION_ID].map( + self._df["section_name"] = self._df[event_list.SECTION_ID].map( lambda x: sections_dict.get(x) ) From bd29e1678e6a8e5caefbb0656937d07e8782dc60 Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Wed, 10 Apr 2024 22:54:46 +0200 Subject: [PATCH 05/15] Export road user assignments to csv --- .../application/analysis/traffic_counting.py | 15 +- .../export_formats/road_user_assignments.py | 46 ++++ .../use_cases/road_user_assignment_export.py | 213 ++++++++++++++++++ .../road_user_assignment_export.py | 66 ++++++ .../test_road_user_assignment_export.py | 117 ++++++++++ .../test_road_user_assignment_export.py | 79 +++++++ tests/conftest.py | 64 ++++++ tests/utils/builders/road_user_assignment.py | 43 ++++ 8 files changed, 640 insertions(+), 3 deletions(-) create mode 100644 OTAnalytics/application/export_formats/road_user_assignments.py create mode 100644 OTAnalytics/application/use_cases/road_user_assignment_export.py create mode 100644 OTAnalytics/plugin_parser/road_user_assignment_export.py create mode 100644 tests/OTAnalytics/application/use_cases/test_road_user_assignment_export.py create mode 100644 tests/OTAnalytics/plugin_parser/test_road_user_assignment_export.py create mode 100644 tests/utils/builders/road_user_assignment.py diff --git a/OTAnalytics/application/analysis/traffic_counting.py b/OTAnalytics/application/analysis/traffic_counting.py index c1cebe4c6..ecdba0436 100644 --- a/OTAnalytics/application/analysis/traffic_counting.py +++ b/OTAnalytics/application/analysis/traffic_counting.py @@ -489,6 +489,15 @@ class RoadUserAssignments: Represents a group of RoadUserAssignment objects. """ + @property + def road_user_ids(self) -> list[str]: + """Returns a sorted list of all road user ids within this group of assignments. + + Returns: + list[str]: the road user ids. + """ + return sorted([assignment.road_user for assignment in self._assignments]) + def __init__(self, assignments: list[RoadUserAssignment]) -> None: self._assignments = assignments.copy() @@ -610,9 +619,9 @@ def __group_flows_by_sections( dict[tuple[SectionId, SectionId], list[Flow]]: flows grouped by start and end section """ - flows_by_start_and_end: dict[ - tuple[SectionId, SectionId], list[Flow] - ] = defaultdict(list) + flows_by_start_and_end: dict[tuple[SectionId, SectionId], list[Flow]] = ( + defaultdict(list) + ) for current in flows: flows_by_start_and_end[(current.start, current.end)].append(current) return flows_by_start_and_end diff --git a/OTAnalytics/application/export_formats/road_user_assignments.py b/OTAnalytics/application/export_formats/road_user_assignments.py new file mode 100644 index 000000000..6460b9b6c --- /dev/null +++ b/OTAnalytics/application/export_formats/road_user_assignments.py @@ -0,0 +1,46 @@ +from OTAnalytics.application.export_formats import event_list + +START_PREFIX = "start" +END_PREFIX = "end" + + +def _prepend_start(key: str) -> str: + return f"{START_PREFIX}_{key}" + + +def _prepend_end(key: str) -> str: + return f"{END_PREFIX}_{key}" + + +FLOW_ID = "flow_id" +FLOW_NAME = "flow_name" +ROAD_USER_ID = event_list.ROAD_USER_ID +ROAD_USER_TYPE = event_list.ROAD_USER_TYPE +MAX_CONFIDENCE = "max_confidence" +START_OCCURRENCE = _prepend_start(event_list.OCCURRENCE) +END_OCCURRENCE = _prepend_end(event_list.OCCURRENCE) +START_OCCURRENCE_DATE = _prepend_start(event_list.OCCURRENCE_DATE) +END_OCCURRENCE_DATE = _prepend_end(event_list.OCCURRENCE_DATE) +START_OCCURRENCE_TIME = _prepend_start(event_list.OCCURRENCE_TIME) +END_OCCURRENCE_TIME = _prepend_end(event_list.OCCURRENCE_TIME) +START_FRAME = _prepend_start(event_list.FRAME_NUMBER) +END_FRAME = _prepend_end(event_list.FRAME_NUMBER) +START_VIDEO_NAME = _prepend_start(event_list.VIDEO_NAME) +END_VIDEO_NAME = _prepend_end(event_list.VIDEO_NAME) +START_SECTION_ID = _prepend_start(event_list.SECTION_ID) +END_SECTION_ID = _prepend_end(event_list.SECTION_ID) +START_SECTION_NAME = _prepend_start(event_list.SECTION_NAME) +END_SECTION_NAME = _prepend_end(event_list.SECTION_NAME) +START_EVENT_COORDINATE_X = _prepend_start(event_list.EVENT_COORDINATE_X) +START_EVENT_COORDINATE_Y = _prepend_start(event_list.EVENT_COORDINATE_Y) +END_EVENT_COORDINATE_X = _prepend_end(event_list.EVENT_COORDINATE_X) +END_EVENT_COORDINATE_Y = _prepend_end(event_list.EVENT_COORDINATE_Y) +START_DIRECTION_VECTOR_X = _prepend_start(event_list.DIRECTION_VECTOR_X) +START_DIRECTION_VECTOR_Y = _prepend_start(event_list.DIRECTION_VECTOR_Y) +END_DIRECTION_VECTOR_X = _prepend_end(event_list.DIRECTION_VECTOR_X) +END_DIRECTION_VECTOR_Y = _prepend_end(event_list.DIRECTION_VECTOR_Y) +HOSTNAME = event_list.HOSTNAME + +DATE_FORMAT = event_list.DATE_FORMAT +TIME_FORMAT = event_list.TIME_FORMAT +DATE_TIME_FORMAT = event_list.DATE_TIME_FORMAT diff --git a/OTAnalytics/application/use_cases/road_user_assignment_export.py b/OTAnalytics/application/use_cases/road_user_assignment_export.py new file mode 100644 index 000000000..808cb8a56 --- /dev/null +++ b/OTAnalytics/application/use_cases/road_user_assignment_export.py @@ -0,0 +1,213 @@ +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Callable, Iterable, Protocol, Self + +from OTAnalytics.application.analysis.traffic_counting import ( + RoadUserAssigner, + RoadUserAssignment, + RoadUserAssignments, +) +from OTAnalytics.application.analysis.traffic_counting_specification import ExportFormat +from OTAnalytics.application.export_formats import road_user_assignments as ras +from OTAnalytics.application.use_cases.create_events import CreateEvents +from OTAnalytics.application.use_cases.track_repository import GetAllTracks +from OTAnalytics.domain.event import EventRepository +from OTAnalytics.domain.flow import FlowRepository +from OTAnalytics.domain.section import Section, SectionId, SectionRepository + +MaxConfidenceLookupTable = dict[str, float] +MaxConfidenceProvider = Callable[[list[str]], MaxConfidenceLookupTable] + + +class RoadUserAssignmentBuildError(Exception): + pass + + +class RoadUserAssignmentBuilder: + def __init__(self) -> None: + self._start_section: Section | None = None + self._end_section: Section | None = None + self._max_confidence: float | None = None + + def add_start_section(self, start: Section) -> Self: + self._start_section = start + return self + + def add_end_section(self, end: Section) -> Self: + self._end_section = end + return self + + def add_max_confidence(self, max_confidence: float) -> Self: + self._max_confidence = max_confidence + return self + + def build(self, assignment: RoadUserAssignment) -> dict: + result = self.__create(assignment) + self.reset() + return result + + def __create(self, assignment: RoadUserAssignment) -> dict: + if self._start_section is None: + raise RoadUserAssignmentBuildError("Start section not set") + if self._end_section is None: + raise RoadUserAssignmentBuildError("End section not set") + if self._max_confidence is None: + raise RoadUserAssignmentBuildError("Max confidence not set") + assigned_flow = assignment.assignment + start = assignment.events.start + end = assignment.events.end + return { + ras.FLOW_ID: assigned_flow.id.id, + ras.FLOW_NAME: assigned_flow.name, + ras.ROAD_USER_ID: assignment.road_user, + ras.MAX_CONFIDENCE: self._max_confidence, + ras.START_OCCURRENCE: start.occurrence.strftime(ras.DATE_TIME_FORMAT), + ras.START_OCCURRENCE_DATE: start.occurrence.strftime(ras.DATE_FORMAT), + ras.START_OCCURRENCE_TIME: start.occurrence.strftime(ras.TIME_FORMAT), + ras.END_OCCURRENCE: end.occurrence.strftime(ras.DATE_TIME_FORMAT), + ras.END_OCCURRENCE_DATE: end.occurrence.strftime(ras.DATE_FORMAT), + ras.END_OCCURRENCE_TIME: end.occurrence.strftime(ras.TIME_FORMAT), + ras.START_FRAME: start.frame_number, + ras.END_FRAME: end.frame_number, + ras.START_VIDEO_NAME: start.video_name, + ras.END_VIDEO_NAME: end.video_name, + ras.START_SECTION_ID: self._start_section.id.id, + ras.END_SECTION_ID: self._end_section.id.id, + ras.START_SECTION_NAME: self._start_section.name, + ras.END_SECTION_NAME: self._end_section.name, + ras.START_EVENT_COORDINATE_X: start.event_coordinate.x, + ras.START_EVENT_COORDINATE_Y: start.event_coordinate.y, + ras.END_EVENT_COORDINATE_X: end.event_coordinate.x, + ras.END_EVENT_COORDINATE_Y: end.event_coordinate.y, + ras.START_DIRECTION_VECTOR_X: start.direction_vector.x1, + ras.START_DIRECTION_VECTOR_Y: start.direction_vector.x2, + ras.END_DIRECTION_VECTOR_X: end.direction_vector.x1, + ras.END_DIRECTION_VECTOR_Y: end.direction_vector.x2, + ras.HOSTNAME: start.hostname, + } + + def reset(self) -> None: + self._start_section = None + self._end_section = None + self._max_confidence = None + + +class RoadUserAssignmentExportError(Exception): + pass + + +class RoadUserAssignmentExporter(ABC): + @property + @abstractmethod + def format(self) -> ExportFormat: + raise NotImplementedError + + def __init__( + self, + section_repository: SectionRepository, + get_all_tracks: GetAllTracks, + builder: RoadUserAssignmentBuilder, + output_file: Path, + ) -> None: + self._section_repository = section_repository + self._get_all_tracks = get_all_tracks + self._builder = builder + self._outputfile = output_file + + def export(self, assignments: RoadUserAssignments) -> None: + dtos = self._convert(assignments) + self._serialize(dtos) + + @abstractmethod + def _serialize(self, dtos: list[dict]) -> None: + """Hook for implementations to serialize in their respective save format. + + Args: + dtos (list[dict]): the vehicle flow assignments as dtos. + """ + raise NotImplementedError + + def _convert(self, assignments: RoadUserAssignments) -> list[dict]: + vehicle_flow_assignments = [] + look_up_table = self._get_max_conf_lookup_table_for(assignments) + for assignment in assignments.as_list(): + start_section = self._get_section_by_id(assignment.assignment.start) + end_section = self._get_section_by_id(assignment.assignment.end) + max_confidence = look_up_table[assignment.road_user] + vehicle_flow_assignments.append( + self._builder.add_start_section(start_section) + .add_end_section(end_section) + .add_max_confidence(max_confidence) + .build(assignment) + ) + return vehicle_flow_assignments + + def _get_max_conf_lookup_table_for( + self, assignments: RoadUserAssignments + ) -> MaxConfidenceLookupTable: + return self._get_all_tracks.as_dataset().get_max_confidences_for( + assignments.road_user_ids + ) + + def _get_section_by_id(self, section_id: SectionId) -> Section: + result = self._section_repository.get(section_id) + if not result: + raise RoadUserAssignmentExportError( + f"No section found with id '{section_id.id}'" + ) + return result + + +class ExportSpecification(Protocol): + save_path: Path + format: str + + +class RoadUserAssignmentExporterFactory(Protocol): + def get_supported_formats(self) -> Iterable[ExportFormat]: + """ + Returns an iterable of the supported export formats. + + Returns: + Iterable[ExportFormat]: supported export formats. + """ + ... + + def create(self, specification: ExportSpecification) -> RoadUserAssignmentExporter: + """ + Create the exporter for the given road user assignment export specification. + + Args: + specification (ExportSpecification): specification of the Exporter. + + Returns: + RoadUserAssignmentExporter: Exporter to export road user assignments. + """ + ... + + +class ExportRoadUserAssignments: + """Use case to export_formats vehicle flow assignments.""" + + def __init__( + self, + event_repository: EventRepository, + flow_repository: FlowRepository, + create_events: CreateEvents, + assigner: RoadUserAssigner, + exporter_factory: RoadUserAssignmentExporterFactory, + ) -> None: + self._event_repository = event_repository + self._flow_repository = flow_repository + self._create_events = create_events + self._assigner = assigner + self._exporter_factory = exporter_factory + + def export(self, specification: ExportSpecification) -> None: + if self._event_repository.is_empty(): + self._create_events() + events = self._event_repository.get_all() + flows = self._flow_repository.get_all() + road_user_assignments = self._assigner.assign(events, flows) + exporter = self._exporter_factory.create(specification) + exporter.export(road_user_assignments) diff --git a/OTAnalytics/plugin_parser/road_user_assignment_export.py b/OTAnalytics/plugin_parser/road_user_assignment_export.py new file mode 100644 index 000000000..9af8fe3c2 --- /dev/null +++ b/OTAnalytics/plugin_parser/road_user_assignment_export.py @@ -0,0 +1,66 @@ +from typing import Iterable + +from pandas import DataFrame + +from OTAnalytics.application.analysis.traffic_counting_specification import ExportFormat +from OTAnalytics.application.use_cases.road_user_assignment_export import ( + ExportSpecification, + RoadUserAssignmentBuilder, + RoadUserAssignmentExporter, +) +from OTAnalytics.application.use_cases.track_repository import GetAllTracks +from OTAnalytics.domain.section import SectionRepository + + +class RoadUserAssignmentCsvExporter(RoadUserAssignmentExporter): + + @property + def format(self) -> ExportFormat: + return ExportFormat("csv", ".csv") + + def _serialize(self, dtos: list[dict]) -> None: + DataFrame(dtos).to_csv(self._outputfile, index=False) + + +class SimpleRoadUserAssignmentExporterFactory: + def __init__( + self, + section_repository: SectionRepository, + get_all_tracks: GetAllTracks, + ) -> None: + self._section_repository = section_repository + self._get_all_tracks = get_all_tracks + self._formats = { + ExportFormat( + "CSV", ".csv" + ): lambda builder, output_file: RoadUserAssignmentCsvExporter( + section_repository, get_all_tracks, builder, output_file + ) + } + self._factories = { + export_format.name: factory + for export_format, factory in self._formats.items() + } + + def get_supported_formats(self) -> Iterable[ExportFormat]: + """ + Returns an iterable of the supported export formats. + + Returns: + Iterable[ExportFormat]: supported export formats. + """ + return self._formats.keys() + + def create(self, specification: ExportSpecification) -> RoadUserAssignmentExporter: + """ + Create the exporter for the given road user assignment export specification. + + Args: + specification (ExportSpecification): specification of the Exporter. + + Returns: + RoadUserAssignmentExporter: Exporter to export road user assignments. + """ + return self._factories[specification.format]( + RoadUserAssignmentBuilder(), specification.save_path + ) diff --git a/tests/OTAnalytics/application/use_cases/test_road_user_assignment_export.py b/tests/OTAnalytics/application/use_cases/test_road_user_assignment_export.py new file mode 100644 index 000000000..a6681af67 --- /dev/null +++ b/tests/OTAnalytics/application/use_cases/test_road_user_assignment_export.py @@ -0,0 +1,117 @@ +from unittest.mock import Mock + +import pytest + +from OTAnalytics.application.analysis.traffic_counting import RoadUserAssignment +from OTAnalytics.application.use_cases.road_user_assignment_export import ( + ExportRoadUserAssignments, + RoadUserAssignmentBuilder, + RoadUserAssignmentBuildError, +) +from OTAnalytics.domain.section import Section +from tests.utils.builders.road_user_assignment import create_road_user_assignment + + +@pytest.fixture +def _builder() -> RoadUserAssignmentBuilder: + return RoadUserAssignmentBuilder() + + +class TestRoadUserAssignmentBuilder: + def test_add_start_section(self, _builder: RoadUserAssignmentBuilder) -> None: + section = Mock() + _builder.add_start_section(section) + assert _builder._start_section == section + + def test_add_end_section(self, _builder: RoadUserAssignmentBuilder) -> None: + section = Mock() + _builder.add_end_section(section) + assert _builder._end_section == section + + def test_add_max_confidence(self, _builder: RoadUserAssignmentBuilder) -> None: + confidence = 0.8 + _builder.add_max_confidence(confidence) + assert _builder._max_confidence == confidence + + def test_build( + self, + _builder: RoadUserAssignmentBuilder, + first_line_section: Section, + second_line_section: Section, + first_road_user_assignment: RoadUserAssignment, + ) -> None: + _builder.add_start_section(first_line_section) + _builder.add_end_section(second_line_section) + _builder.add_max_confidence(0.9) + result = _builder.build(first_road_user_assignment) + assert result == create_road_user_assignment( + first_road_user_assignment, first_line_section, second_line_section + ) + + def test_build_with_start_section_missing( + self, _builder: RoadUserAssignmentBuilder + ) -> None: + _builder.add_end_section(Mock()) + _builder.add_max_confidence(0.9) + with pytest.raises(RoadUserAssignmentBuildError, match="Start section not set"): + _builder.build(Mock()) + + def test_build_with_end_section_missing( + self, _builder: RoadUserAssignmentBuilder + ) -> None: + _builder.add_start_section(Mock()) + _builder.add_max_confidence(0.9) + with pytest.raises(RoadUserAssignmentBuildError, match="End section not set"): + _builder.build(Mock()) + + def test_build_with_max_confidence_missing( + self, _builder: RoadUserAssignmentBuilder + ) -> None: + _builder.add_start_section(Mock()) + _builder.add_end_section(Mock()) + with pytest.raises( + RoadUserAssignmentBuildError, match="Max confidence not set" + ): + _builder.build(Mock()) + + +class TestExportRoadUserAssignments: + def test_export(self) -> None: + event_repository = Mock() + flow_repository = Mock() + create_events = Mock() + road_user_assigner = Mock() + exporter_factory = Mock() + + events = Mock() + event_repository.is_empty.return_value = False + event_repository.get_all.return_value = events + + flows = Mock() + flow_repository.get_all.return_value = flows + + assignments = Mock() + road_user_assigner.assign.return_value = assignments + + exporter = Mock() + exporter_factory.create.return_value = exporter + + export_road_user_assignments = ExportRoadUserAssignments( + event_repository, + flow_repository, + create_events, + road_user_assigner, + exporter_factory, + ) + specification = Mock() + specification.save_path = Mock() + specification.format = "csv" + + export_road_user_assignments.export(specification) + + event_repository.is_empty.assert_called_once() + event_repository.get_all.assert_called_once() + flow_repository.get_all.assert_called_once() + road_user_assigner.assign.assert_called_once_with(events, flows) + exporter_factory.create.assert_called_once_with(specification) + exporter.export.assert_called_once_with(assignments) diff --git a/tests/OTAnalytics/plugin_parser/test_road_user_assignment_export.py b/tests/OTAnalytics/plugin_parser/test_road_user_assignment_export.py new file mode 100644 index 000000000..730ae3b18 --- /dev/null +++ b/tests/OTAnalytics/plugin_parser/test_road_user_assignment_export.py @@ -0,0 +1,79 @@ +from pathlib import Path +from unittest.mock import Mock + +from pandas import DataFrame, read_csv + +from OTAnalytics.application.analysis.traffic_counting import ( + RoadUserAssignment, + RoadUserAssignments, +) +from OTAnalytics.application.export_formats import road_user_assignments as ras +from OTAnalytics.application.use_cases.road_user_assignment_export import ( + RoadUserAssignmentBuilder, +) +from OTAnalytics.domain.section import Section +from OTAnalytics.plugin_parser.road_user_assignment_export import ( + RoadUserAssignmentCsvExporter, +) +from tests.utils.builders.road_user_assignment import create_road_user_assignment + + +class TestRoadUserAssignmentCsvExporter: + def test_export( + self, + test_data_tmp_dir: Path, + first_line_section: Section, + second_line_section: Section, + first_road_user_assignment: RoadUserAssignment, + second_road_user_assignment: RoadUserAssignment, + ) -> None: + save_path = test_data_tmp_dir / "road_user_assignments.csv" + + section_repository = Mock() + get_all_tracks = Mock() + builder = RoadUserAssignmentBuilder() + track_dataset = Mock() + + track_dataset.get_max_confidences_for.return_value = { + first_road_user_assignment.road_user: 0.9, + second_road_user_assignment.road_user: 0.7, + } + get_all_tracks.as_dataset.return_value = track_dataset + section_repository.get.side_effect = [ + first_line_section, + second_line_section, + first_line_section, + second_line_section, + ] + + exporter = RoadUserAssignmentCsvExporter( + section_repository, get_all_tracks, builder, save_path + ) + exporter.export( + RoadUserAssignments( + [first_road_user_assignment, second_road_user_assignment] + ) + ) + expected = DataFrame( + [ + create_road_user_assignment( + first_road_user_assignment, + first_line_section, + second_line_section, + 0.9, + ), + create_road_user_assignment( + second_road_user_assignment, + first_line_section, + second_line_section, + 0.7, + ), + ] + ) + actual = read_csv(save_path) + actual[ras.START_SECTION_ID] = actual[ras.START_SECTION_ID].astype(str) + actual[ras.END_SECTION_ID] = actual[ras.END_SECTION_ID].astype(str) + actual[ras.START_SECTION_NAME] = actual[ras.START_SECTION_NAME].astype(str) + actual[ras.END_SECTION_NAME] = actual[ras.END_SECTION_NAME].astype(str) + + assert actual.equals(expected) diff --git a/tests/conftest.py b/tests/conftest.py index bd49d1cc1..64e48d564 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,12 @@ import pytest +from OTAnalytics.application.analysis.traffic_counting import ( + EventPair, + RoadUserAssignment, +) +from OTAnalytics.domain.event import Event +from OTAnalytics.domain.flow import Flow, FlowId from OTAnalytics.domain.geometry import Coordinate from OTAnalytics.domain.section import LineSection, Section, SectionId from OTAnalytics.domain.track import Track, TrackId @@ -319,3 +325,61 @@ def pandas_track_segment_dataset_builder( track_segment_dataset_builder_provider: TrackSegmentDatasetBuilderProvider, ) -> TrackSegmentDatasetBuilder: return track_segment_dataset_builder_provider.provide(PANDAS) + + +@pytest.fixture +def first_line_section() -> Section: + return LineSection( + SectionId("1"), "First Section", {}, {}, [Coordinate(0, 0), Coordinate(1, 0)] + ) + + +@pytest.fixture +def second_line_section() -> Section: + return LineSection( + SectionId("2"), "Second Section", {}, {}, [Coordinate(0, 0), Coordinate(1, 0)] + ) + + +@pytest.fixture +def first_flow(first_line_section: Section, second_line_section: Section) -> Flow: + _id = FlowId("First Flow") + return Flow(_id, _id.id, first_line_section.id, second_line_section.id) + + +@pytest.fixture +def first_section_event(first_line_section: Section) -> Event: + builder = EventBuilder() + builder.add_road_user_id("Road User 1") + builder.add_section_id(first_line_section.id.id) + return builder.build_section_event() + + +@pytest.fixture +def second_section_event(second_line_section: Section) -> Event: + builder = EventBuilder() + builder.add_road_user_id("Road User 1") + builder.add_section_id(second_line_section.id.id) + return builder.build_section_event() + + +@pytest.fixture +def first_road_user_assignment( + first_flow: Flow, first_section_event: Event, second_section_event: Event +) -> RoadUserAssignment: + return RoadUserAssignment( + "Road User 1", + first_flow, + EventPair(first_section_event, second_section_event), + ) + + +@pytest.fixture +def second_road_user_assignment( + first_flow: Flow, first_section_event: Event, second_section_event: Event +) -> RoadUserAssignment: + return RoadUserAssignment( + "Road User 2", + first_flow, + EventPair(first_section_event, second_section_event), + ) diff --git a/tests/utils/builders/road_user_assignment.py b/tests/utils/builders/road_user_assignment.py new file mode 100644 index 000000000..b3d688595 --- /dev/null +++ b/tests/utils/builders/road_user_assignment.py @@ -0,0 +1,43 @@ +from OTAnalytics.application.analysis.traffic_counting import RoadUserAssignment +from OTAnalytics.application.export_formats import road_user_assignments as ras +from OTAnalytics.domain.section import Section + + +def create_road_user_assignment( + assignment: RoadUserAssignment, + start_section: Section, + end_section: Section, + max_confidence: float = 0.9, +) -> dict: + start_event = assignment.events.start + end_event = assignment.events.end + + return { + ras.FLOW_ID: assignment.assignment.id.id, + ras.FLOW_NAME: assignment.assignment.name, + ras.ROAD_USER_ID: assignment.road_user, + ras.MAX_CONFIDENCE: max_confidence, + ras.START_OCCURRENCE: start_event.occurrence.strftime(ras.DATE_TIME_FORMAT), + ras.START_OCCURRENCE_DATE: start_event.occurrence.strftime(ras.DATE_FORMAT), + ras.START_OCCURRENCE_TIME: start_event.occurrence.strftime(ras.TIME_FORMAT), + ras.END_OCCURRENCE: end_event.occurrence.strftime(ras.DATE_TIME_FORMAT), + ras.END_OCCURRENCE_DATE: end_event.occurrence.strftime(ras.DATE_FORMAT), + ras.END_OCCURRENCE_TIME: end_event.occurrence.strftime(ras.TIME_FORMAT), + ras.START_FRAME: start_event.frame_number, + ras.END_FRAME: end_event.frame_number, + ras.START_VIDEO_NAME: start_event.video_name, + ras.END_VIDEO_NAME: end_event.video_name, + ras.START_SECTION_ID: start_section.id.id, + ras.END_SECTION_ID: end_section.id.id, + ras.START_SECTION_NAME: start_section.name, + ras.END_SECTION_NAME: end_section.name, + ras.START_EVENT_COORDINATE_X: start_event.event_coordinate.x, + ras.START_EVENT_COORDINATE_Y: start_event.event_coordinate.y, + ras.END_EVENT_COORDINATE_X: end_event.event_coordinate.x, + ras.END_EVENT_COORDINATE_Y: end_event.event_coordinate.y, + ras.START_DIRECTION_VECTOR_X: start_event.event_coordinate.x, + ras.START_DIRECTION_VECTOR_Y: start_event.event_coordinate.y, + ras.END_DIRECTION_VECTOR_X: end_event.event_coordinate.x, + ras.END_DIRECTION_VECTOR_Y: end_event.event_coordinate.y, + ras.HOSTNAME: start_event.hostname, + } From 1b6b6adb142404507c0de14094af424f8efe6556 Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Wed, 10 Apr 2024 23:09:13 +0200 Subject: [PATCH 06/15] Wire use case to export road user assignments to OTAnalyticsApplication --- OTAnalytics/application/application.py | 9 +++++++ OTAnalytics/plugin_ui/main_application.py | 30 +++++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/OTAnalytics/application/application.py b/OTAnalytics/application/application.py index b1a746a7e..401c72125 100644 --- a/OTAnalytics/application/application.py +++ b/OTAnalytics/application/application.py @@ -42,6 +42,10 @@ from OTAnalytics.application.use_cases.quick_save_configuration import ( QuickSaveConfiguration, ) +from OTAnalytics.application.use_cases.road_user_assignment_export import ( + ExportRoadUserAssignments, + ExportSpecification, +) from OTAnalytics.application.use_cases.save_otflow import SaveOtflow from OTAnalytics.application.use_cases.section_repository import ( AddSection, @@ -123,6 +127,7 @@ def __init__( quick_save_configuration: QuickSaveConfiguration, load_otconfig: LoadOtconfig, config_has_changed: ConfigHasChanged, + export_road_user_assignments: ExportRoadUserAssignments, ) -> None: self._datastore: Datastore = datastore self.track_state: TrackState = track_state @@ -161,6 +166,7 @@ def __init__( self._quick_save_configuration = quick_save_configuration self._load_otconfig = load_otconfig self._config_has_changed = config_has_changed + self._export_road_user_assignments = export_road_user_assignments def connect_observers(self) -> None: """ @@ -622,6 +628,9 @@ def quick_save_configuration(self) -> None: def config_has_changed(self) -> bool: return self._config_has_changed.has_changed() + def export_road_user_assignments(self, specification: ExportSpecification) -> None: + self._export_road_user_assignments.export(specification) + class MissingTracksError(Exception): pass diff --git a/OTAnalytics/plugin_ui/main_application.py b/OTAnalytics/plugin_ui/main_application.py index 6a37c5fe6..5eca92b59 100644 --- a/OTAnalytics/plugin_ui/main_application.py +++ b/OTAnalytics/plugin_ui/main_application.py @@ -110,6 +110,9 @@ QuickSaveConfiguration, ) from OTAnalytics.application.use_cases.reset_project_config import ResetProjectConfig +from OTAnalytics.application.use_cases.road_user_assignment_export import ( + ExportRoadUserAssignments, +) from OTAnalytics.application.use_cases.save_otflow import SaveOtflow from OTAnalytics.application.use_cases.section_repository import ( AddAllSections, @@ -187,6 +190,9 @@ SimpleVideoParser, ) from OTAnalytics.plugin_parser.pandas_parser import PandasDetectionParser +from OTAnalytics.plugin_parser.road_user_assignment_export import ( + SimpleRoadUserAssignmentExporterFactory, +) from OTAnalytics.plugin_parser.track_export import CsvTrackExport from OTAnalytics.plugin_progress.tqdm_progressbar import TqdmBuilder from OTAnalytics.plugin_prototypes.eventlist_exporter.eventlist_exporter import ( @@ -476,6 +482,13 @@ def start_gui(self, run_config: RunConfiguration) -> None: OtflowHasChanged(flow_parser, get_sections, get_flows), file_state, ) + export_road_user_assignments = self.create_export_road_user_assignments( + get_all_tracks, + section_repository, + event_repository, + flow_repository, + create_events, + ) application = OTAnalyticsApplication( datastore, track_state, @@ -508,6 +521,7 @@ def start_gui(self, run_config: RunConfiguration) -> None: quick_save_configuration, load_otconfig, config_has_changed, + export_road_user_assignments, ) section_repository.register_sections_observer(cut_tracks_intersecting_section) section_repository.register_section_changed_observer( @@ -1005,3 +1019,19 @@ def create_config_parser( flow_parser=flow_parser, format_fixer=format_fixer, ) + + def create_export_road_user_assignments( + self, + get_all_tracks: GetAllTracks, + section_repository: SectionRepository, + event_repository: EventRepository, + flow_repository: FlowRepository, + create_events: CreateEvents, + ) -> ExportRoadUserAssignments: + return ExportRoadUserAssignments( + event_repository, + flow_repository, + create_events, + FilterBySectionEnterEvent(SimpleRoadUserAssigner()), + SimpleRoadUserAssignmentExporterFactory(section_repository, get_all_tracks), + ) From c218f115579c59ddd2c02269b0f871fd9432acfd Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Thu, 11 Apr 2024 01:10:25 +0200 Subject: [PATCH 07/15] Add and wire button to export road user assignments --- OTAnalytics/adapter_ui/view_model.py | 4 ++ OTAnalytics/application/application.py | 5 +++ .../use_cases/road_user_assignment_export.py | 13 +++++- .../eventlist_exporter/eventlist_exporter.py | 7 +-- .../customtkinter_gui/dummy_viewmodel.py | 43 +++++++++++++++++++ .../customtkinter_gui/frame_analysis.py | 14 +++++- .../toplevel_export_events.py | 6 ++- 7 files changed, 85 insertions(+), 7 deletions(-) diff --git a/OTAnalytics/adapter_ui/view_model.py b/OTAnalytics/adapter_ui/view_model.py index c68870484..c319d80c0 100644 --- a/OTAnalytics/adapter_ui/view_model.py +++ b/OTAnalytics/adapter_ui/view_model.py @@ -389,3 +389,7 @@ def get_skip_frames(self) -> int: @abstractmethod def set_video_control_frame(self, frame: AbstractFrame) -> None: raise NotImplementedError + + @abstractmethod + def export_road_user_assignments(self) -> None: + raise NotImplementedError diff --git a/OTAnalytics/application/application.py b/OTAnalytics/application/application.py index 401c72125..19382c370 100644 --- a/OTAnalytics/application/application.py +++ b/OTAnalytics/application/application.py @@ -631,6 +631,11 @@ def config_has_changed(self) -> bool: def export_road_user_assignments(self, specification: ExportSpecification) -> None: self._export_road_user_assignments.export(specification) + def get_road_user_export_formats( + self, + ) -> Iterable[ExportFormat]: + return self._export_road_user_assignments.get_supported_formats() + class MissingTracksError(Exception): pass diff --git a/OTAnalytics/application/use_cases/road_user_assignment_export.py b/OTAnalytics/application/use_cases/road_user_assignment_export.py index 808cb8a56..f51818e5b 100644 --- a/OTAnalytics/application/use_cases/road_user_assignment_export.py +++ b/OTAnalytics/application/use_cases/road_user_assignment_export.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from dataclasses import dataclass from pathlib import Path from typing import Callable, Iterable, Protocol, Self @@ -158,7 +159,8 @@ def _get_section_by_id(self, section_id: SectionId) -> Section: return result -class ExportSpecification(Protocol): +@dataclass(frozen=True) +class ExportSpecification: save_path: Path format: str @@ -211,3 +213,12 @@ def export(self, specification: ExportSpecification) -> None: road_user_assignments = self._assigner.assign(events, flows) exporter = self._exporter_factory.create(specification) exporter.export(road_user_assignments) + + def get_supported_formats(self) -> Iterable[ExportFormat]: + """ + Returns an iterable of the supported export formats. + + Returns: + Iterable[ExportFormat]: supported export formats + """ + return self._exporter_factory.get_supported_formats() diff --git a/OTAnalytics/plugin_prototypes/eventlist_exporter/eventlist_exporter.py b/OTAnalytics/plugin_prototypes/eventlist_exporter/eventlist_exporter.py index 73ffa6a1f..fa14d3cd6 100644 --- a/OTAnalytics/plugin_prototypes/eventlist_exporter/eventlist_exporter.py +++ b/OTAnalytics/plugin_prototypes/eventlist_exporter/eventlist_exporter.py @@ -15,8 +15,9 @@ from OTAnalytics.domain.section import Section from OTAnalytics.plugin_parser.otvision_parser import OtEventListParser -EXTENSION_CSV = "csv" -EXTENSION_EXCEL = "xlsx" +EXTENSION_CSV = ".csv" +EXTENSION_EXCEL = ".xlsx" +EXTENSION_OTEVENTS = f".{DEFAULT_EVENTLIST_FILE_TYPE}" OTC_EXCEL_FORMAT_NAME = "Excel (OpenTrafficCam)" OTC_CSV_FORMAT_NAME = "CSV (OpenTrafficCam)" @@ -134,7 +135,7 @@ def export( self._event_list_parser.serialize(events, sections, file) def get_extension(self) -> str: - return DEFAULT_EVENTLIST_FILE_TYPE + return EXTENSION_OTEVENTS def get_name(self) -> str: return OTC_OTEVENTS_FORMAT_NAME diff --git a/OTAnalytics/plugin_ui/customtkinter_gui/dummy_viewmodel.py b/OTAnalytics/plugin_ui/customtkinter_gui/dummy_viewmodel.py index 0f929c36b..78e6b3c3b 100644 --- a/OTAnalytics/plugin_ui/customtkinter_gui/dummy_viewmodel.py +++ b/OTAnalytics/plugin_ui/customtkinter_gui/dummy_viewmodel.py @@ -64,6 +64,9 @@ from OTAnalytics.application.use_cases.quick_save_configuration import ( NoExistingFileToSave, ) +from OTAnalytics.application.use_cases.road_user_assignment_export import ( + ExportSpecification, +) from OTAnalytics.application.use_cases.save_otflow import NoSectionsToSave from OTAnalytics.domain import geometry from OTAnalytics.domain.date import ( @@ -1694,3 +1697,43 @@ def set_button_quick_save_config( self, button_quick_save_config: AbstractButtonQuickSaveConfig ) -> None: self._button_quick_save_config = button_quick_save_config + + def export_road_user_assignments(self) -> None: + if len(self._application.get_all_flows()) == 0: + InfoBox( + message=( + "Counting needs at least one flow.\n" + "There is no flow configured.\n" + "Please create a flow." + ), + initial_position=( + self._window.get_position() if self._window else (0, 0) + ), + ) + return + export_formats: dict = { + export_format.name: export_format.file_extension + for export_format in self._application.get_road_user_export_formats() + } + default_format = next(iter(export_formats.keys())) + default_values: dict = { + EXPORT_FORMAT: default_format, + } + + try: + export_values = ToplevelExportEvents( + title="Export road user assignments", + initial_position=(50, 50), + input_values=default_values, + export_format_extensions=export_formats, + initial_file_stem="road_user_assignments", + ).get_data() + logger().debug(export_values) + save_path = export_values[toplevel_export_events.EXPORT_FILE] + export_format = export_values[toplevel_export_events.EXPORT_FORMAT] + + export_specification = ExportSpecification(save_path, export_format) + self._application.export_road_user_assignments(export_specification) + logger().info(f"Exporting road user assignments to {save_path}") + except CancelExportEvents: + logger().info("User canceled configuration of export") diff --git a/OTAnalytics/plugin_ui/customtkinter_gui/frame_analysis.py b/OTAnalytics/plugin_ui/customtkinter_gui/frame_analysis.py index 951e2a129..265fcd0f7 100644 --- a/OTAnalytics/plugin_ui/customtkinter_gui/frame_analysis.py +++ b/OTAnalytics/plugin_ui/customtkinter_gui/frame_analysis.py @@ -53,6 +53,11 @@ def _get_widgets(self) -> None: self.button_export_counts = CTkButton( master=self, text="Export counts ...", command=self._viewmodel.export_counts ) + self.button_export_road_user_assignments = CTkButton( + master=self, + text="Export road user assignments ...", + command=self._viewmodel.export_road_user_assignments, + ) def _place_widgets(self) -> None: self.button_export_eventlist.grid( @@ -61,6 +66,13 @@ def _place_widgets(self) -> None: self.button_export_counts.grid( row=1, column=0, padx=PADX, pady=PADY, sticky=STICKY ) + self.button_export_road_user_assignments.grid( + row=2, column=0, padx=PADX, pady=PADY, sticky=STICKY + ) def get_general_buttons(self) -> list[CTkButton]: - return [self.button_export_counts, self.button_export_eventlist] + return [ + self.button_export_counts, + self.button_export_eventlist, + self.button_export_road_user_assignments, + ] diff --git a/OTAnalytics/plugin_ui/customtkinter_gui/toplevel_export_events.py b/OTAnalytics/plugin_ui/customtkinter_gui/toplevel_export_events.py index 9a3c6fab3..c6f58295e 100644 --- a/OTAnalytics/plugin_ui/customtkinter_gui/toplevel_export_events.py +++ b/OTAnalytics/plugin_ui/customtkinter_gui/toplevel_export_events.py @@ -72,10 +72,12 @@ def __init__( self, export_format_extensions: dict[str, str], input_values: dict, + initial_file_stem: str = INITIAL_FILE_STEM, **kwargs: Any, ) -> None: self._input_values = input_values self._export_format_extensions = export_format_extensions + self._initial_file_stem = initial_file_stem super().__init__(**kwargs) def _create_frame_content(self, master: Any) -> FrameContent: @@ -87,12 +89,12 @@ def _create_frame_content(self, master: Any) -> FrameContent: def _choose_file(self) -> None: export_format = self._input_values[EXPORT_FORMAT] # - export_extension = f"*.{self._export_format_extensions[export_format]}" + export_extension = f"*{self._export_format_extensions[export_format]}" export_file = ask_for_save_file_name( title="Save counts as", filetypes=[(export_format, export_extension)], defaultextension=export_extension, - initialfile=INITIAL_FILE_STEM, + initialfile=self._initial_file_stem, ) self._input_values[EXPORT_FILE] = export_file if export_file == "": From bf1946396e95eee99d4dae44cbd34549b3416a81 Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Thu, 11 Apr 2024 01:44:17 +0200 Subject: [PATCH 08/15] Streamline file extension format for exports File extensions for exports always start with a dot. --- .../eventlist_exporter/eventlist_exporter.py | 20 +++++++++---------- OTAnalytics/plugin_ui/cli.py | 2 +- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/OTAnalytics/plugin_prototypes/eventlist_exporter/eventlist_exporter.py b/OTAnalytics/plugin_prototypes/eventlist_exporter/eventlist_exporter.py index fa14d3cd6..fd44ecc8a 100644 --- a/OTAnalytics/plugin_prototypes/eventlist_exporter/eventlist_exporter.py +++ b/OTAnalytics/plugin_prototypes/eventlist_exporter/eventlist_exporter.py @@ -15,9 +15,9 @@ from OTAnalytics.domain.section import Section from OTAnalytics.plugin_parser.otvision_parser import OtEventListParser -EXTENSION_CSV = ".csv" -EXTENSION_EXCEL = ".xlsx" -EXTENSION_OTEVENTS = f".{DEFAULT_EVENTLIST_FILE_TYPE}" +EXTENSION_CSV = "csv" +EXTENSION_EXCEL = "xlsx" +EXTENSION_OTEVENTS = DEFAULT_EVENTLIST_FILE_TYPE OTC_EXCEL_FORMAT_NAME = "Excel (OpenTrafficCam)" OTC_CSV_FORMAT_NAME = "CSV (OpenTrafficCam)" @@ -102,7 +102,7 @@ def _write_to_excel( writer.close() def get_extension(self) -> str: - return EXTENSION_EXCEL + return f".{EXTENSION_EXCEL}" def get_name(self) -> str: return OTC_EXCEL_FORMAT_NAME @@ -119,7 +119,7 @@ def _write_to_excel(self, file: Path, df_events: pd.DataFrame) -> None: df_events.to_csv(file, index=False) def get_extension(self) -> str: - return EXTENSION_CSV + return f".{EXTENSION_CSV}" def get_name(self) -> str: return OTC_CSV_FORMAT_NAME @@ -135,7 +135,7 @@ def export( self._event_list_parser.serialize(events, sections, file) def get_extension(self) -> str: - return EXTENSION_OTEVENTS + return f".{EXTENSION_OTEVENTS}" def get_name(self) -> str: return OTC_OTEVENTS_FORMAT_NAME @@ -190,15 +190,15 @@ def get_name(self) -> str: def provide_available_eventlist_exporter(event_format: str) -> EventListExporter: _format = event_format.lower() - if _format == EXTENSION_CSV: + if _format == EXTENSION_CSV or _format == f".{EXTENSION_CSV}": return AVAILABLE_EVENTLIST_EXPORTERS[OTC_CSV_FORMAT_NAME] - elif _format == EXTENSION_EXCEL: + elif _format == EXTENSION_EXCEL or _format == f".{EXTENSION_EXCEL}": return AVAILABLE_EVENTLIST_EXPORTERS[OTC_EXCEL_FORMAT_NAME] - elif _format == DEFAULT_EVENTLIST_FILE_TYPE: + elif _format == EXTENSION_OTEVENTS or _format == f".{EXTENSION_OTEVENTS}": return AVAILABLE_EVENTLIST_EXPORTERS[OTC_OTEVENTS_FORMAT_NAME] else: raise ExporterNotFoundError( f"{event_format} is a not supported eventlist format. " f"Supported formats are: [{EXTENSION_CSV}, " - f"{EXTENSION_EXCEL}, {DEFAULT_EVENTLIST_FILE_TYPE}]" + f"{EXTENSION_EXCEL}, {EXTENSION_OTEVENTS}]" ) diff --git a/OTAnalytics/plugin_ui/cli.py b/OTAnalytics/plugin_ui/cli.py index ebfe19c55..c8fad1d91 100644 --- a/OTAnalytics/plugin_ui/cli.py +++ b/OTAnalytics/plugin_ui/cli.py @@ -231,7 +231,7 @@ def _export_events(self, sections: Iterable[Section], save_path: Path) -> None: for event_format in self._run_config.event_formats: event_list_exporter = self._provide_eventlist_exporter(event_format) actual_save_path = save_path.with_suffix( - f".events.{event_list_exporter.get_extension()}" + f".events{event_list_exporter.get_extension()}" ) event_list_exporter.export(events, sections, actual_save_path) logger().info(f"Event list saved at '{actual_save_path}'") From f67cff10d5905da3aa2968154dcce2d0c3f6d4ef Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Thu, 11 Apr 2024 02:27:37 +0200 Subject: [PATCH 09/15] Add time and date as separate columns to events and counts --- OTAnalytics/plugin_parser/export.py | 25 ++++++++++++++++++- .../eventlist_exporter/eventlist_exporter.py | 10 ++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/OTAnalytics/plugin_parser/export.py b/OTAnalytics/plugin_parser/export.py index 3f6f6c0df..a40a1e912 100644 --- a/OTAnalytics/plugin_parser/export.py +++ b/OTAnalytics/plugin_parser/export.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import Iterable -from pandas import DataFrame +from pandas import DataFrame, to_datetime from OTAnalytics.application.analysis.traffic_counting import ( LEVEL_CLASSIFICATION, @@ -27,6 +27,14 @@ ) from OTAnalytics.application.logger import logger +START_DATE = "start occurrence date" +START_TIME = "start occurrence time" +END_DATE = "end occurrence date" +END_TIME = "end occurrence time" + +DATE_FORMAT = "%Y-%m-%d" +TIME_FORMAT = "%H:%M:%S" + class CsvExport(Exporter): def __init__(self, output_file: str) -> None: @@ -35,6 +43,7 @@ def __init__(self, output_file: str) -> None: def export(self, counts: Count) -> None: logger().info(f"Exporting counts to {self._output_file}") dataframe = self.__create_data_frame(counts) + dataframe = self._add_detailed_date_time_columns(dataframe) if dataframe.empty: logger().info("Nothing to count.") return @@ -45,11 +54,25 @@ def export(self, counts: Count) -> None: dataframe.to_csv(self.__create_path(), index=False) logger().info(f"Counts saved at {self._output_file}") + def _add_detailed_date_time_columns(self, df: DataFrame) -> DataFrame: + start_occurrence = to_datetime(df[LEVEL_START_TIME]) + end_occurrence = to_datetime(df[LEVEL_END_TIME]) + + df[START_DATE] = start_occurrence.dt.strftime(DATE_FORMAT) + df[START_TIME] = start_occurrence.dt.strftime(TIME_FORMAT) + df[END_DATE] = end_occurrence.dt.strftime(DATE_FORMAT) + df[END_TIME] = end_occurrence.dt.strftime(TIME_FORMAT) + return df + @staticmethod def _set_column_order(dataframe: DataFrame) -> DataFrame: desired_columns_order = [ LEVEL_START_TIME, + START_DATE, + START_TIME, LEVEL_END_TIME, + END_DATE, + END_TIME, LEVEL_CLASSIFICATION, LEVEL_FLOW, LEVEL_FROM_SECTION, diff --git a/OTAnalytics/plugin_prototypes/eventlist_exporter/eventlist_exporter.py b/OTAnalytics/plugin_prototypes/eventlist_exporter/eventlist_exporter.py index fd44ecc8a..6731c2e66 100644 --- a/OTAnalytics/plugin_prototypes/eventlist_exporter/eventlist_exporter.py +++ b/OTAnalytics/plugin_prototypes/eventlist_exporter/eventlist_exporter.py @@ -37,8 +37,18 @@ def build(self) -> pd.DataFrame: self._convert_occurrence_to_seconds_since_epoch() self._split_columns_with_lists() self._add_section_names() + self._add_detailed_date_time_columns() return self._df + def _add_detailed_date_time_columns(self) -> None: + occurrence_column = pd.to_datetime(self._df[event_list.OCCURRENCE]) + self._df[event_list.OCCURRENCE_DATE] = occurrence_column.dt.strftime( + event_list.DATE_FORMAT + ) + self._df[event_list.OCCURRENCE_TIME] = occurrence_column.dt.strftime( + event_list.TIME_FORMAT + ) + def _convert_occurrence_to_seconds_since_epoch(self) -> None: # TODO: Use OTAnalytics´ builtin timestamp methods epoch = pd.Timestamp("1970-01-01") From bf7b181a9d2ac7277194bf4d6746b4e2ef2155ae Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Thu, 11 Apr 2024 02:36:48 +0200 Subject: [PATCH 10/15] Fix unit tests --- OTAnalytics/plugin_parser/export.py | 2 +- tests/OTAnalytics/plugin_parser/test_export.py | 14 +++++++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/OTAnalytics/plugin_parser/export.py b/OTAnalytics/plugin_parser/export.py index a40a1e912..9754133dd 100644 --- a/OTAnalytics/plugin_parser/export.py +++ b/OTAnalytics/plugin_parser/export.py @@ -43,10 +43,10 @@ def __init__(self, output_file: str) -> None: def export(self, counts: Count) -> None: logger().info(f"Exporting counts to {self._output_file}") dataframe = self.__create_data_frame(counts) - dataframe = self._add_detailed_date_time_columns(dataframe) if dataframe.empty: logger().info("Nothing to count.") return + dataframe = self._add_detailed_date_time_columns(dataframe) dataframe = self._set_column_order(dataframe) dataframe = dataframe.sort_values( by=[LEVEL_START_TIME, LEVEL_END_TIME, LEVEL_CLASSIFICATION] diff --git a/tests/OTAnalytics/plugin_parser/test_export.py b/tests/OTAnalytics/plugin_parser/test_export.py index 55de89cb5..d2eaab0f4 100644 --- a/tests/OTAnalytics/plugin_parser/test_export.py +++ b/tests/OTAnalytics/plugin_parser/test_export.py @@ -26,7 +26,15 @@ ExportSpecificationDto, FlowNameDto, ) -from OTAnalytics.plugin_parser.export import CsvExport, FillZerosExporter, TagExploder +from OTAnalytics.plugin_parser.export import ( + END_DATE, + END_TIME, + START_DATE, + START_TIME, + CsvExport, + FillZerosExporter, + TagExploder, +) class TestCsvExport: @@ -53,7 +61,11 @@ def test_export(self, test_data_tmp_dir: Path) -> None: counts.to_dict.return_value = {tag: 1} expected = { LEVEL_START_TIME: {0: "2023-01-02 08:00:00"}, + START_DATE: {0: "2023-01-02"}, + START_TIME: {0: "08:00:00"}, LEVEL_END_TIME: {0: "2023-01-02 08:15:00"}, + END_DATE: {0: "2023-01-02"}, + END_TIME: {0: "08:15:00"}, LEVEL_CLASSIFICATION: {0: "car"}, LEVEL_FLOW: {0: "West --> Ost"}, LEVEL_FROM_SECTION: {0: "West"}, From 3704163b94d016ae309f996bab20bebbceed80f6 Mon Sep 17 00:00:00 2001 From: Lars Briem Date: Fri, 12 Apr 2024 09:22:29 +0200 Subject: [PATCH 11/15] Add methods to recreate regression test results --- tests/regression_otanalytics.py | 115 +++++++++++++++++++++++++++++--- 1 file changed, 104 insertions(+), 11 deletions(-) diff --git a/tests/regression_otanalytics.py b/tests/regression_otanalytics.py index f84df584b..455504c51 100644 --- a/tests/regression_otanalytics.py +++ b/tests/regression_otanalytics.py @@ -1,6 +1,8 @@ from pathlib import Path import pytest +from more_itertools import chunked +from tqdm import tqdm from OTAnalytics.application.parser.flow_parser import FlowParser from OTAnalytics.plugin_parser.otvision_parser import OtFlowParser @@ -32,6 +34,12 @@ def track_files_2hours(test_data_dir: Path) -> list[str]: ] +@pytest.fixture(scope="module") +def all_track_files_test_dataset() -> list[Path]: + data_folder = Path("../../platomo/OpenTrafficCam-testdata/tests/data") + return list(data_folder.glob("*.ottrk")) + + @pytest.fixture(scope="module") def otflow_file(test_data_dir: Path) -> str: return to_cli_path(test_data_dir, "OTCamera19_FR20_2023-05-24.otflow") @@ -43,6 +51,72 @@ def otflow_parser() -> FlowParser: class TestRegressionCompleteApplication: + + @pytest.mark.skip + def test_15_min_recreate_test_data( + self, + otflow_file: str, + all_track_files_test_dataset: list[Path], + otflow_parser: FlowParser, + ) -> None: + for test_file in tqdm(all_track_files_test_dataset, desc="test data file"): + test_data = test_file + test_interval = "15min" + save_dir = test_data.parent + self._run_otanalytics( + count_interval=15, + otflow_file=otflow_file, + test_data=[str(test_data)], + save_dir=save_dir, + test_interval=test_interval, + otflow_parser=otflow_parser, + event_formats=("csv", "otevents"), + ) + + @pytest.mark.skip + def test_2_h_single_recreate_test_data( + self, + otflow_file: str, + all_track_files_test_dataset: list[Path], + otflow_parser: FlowParser, + ) -> None: + batches = list(chunked(sorted(all_track_files_test_dataset), n=8)) + for test_file in tqdm(batches, desc="test data file"): + test_data = test_file + test_interval = "2h" + save_dir = test_data[0].parent + self._run_otanalytics( + count_interval=15, + otflow_file=otflow_file, + test_data=[str(file) for file in test_data], + save_dir=save_dir, + test_interval=test_interval, + otflow_parser=otflow_parser, + event_formats=("csv", "otevents"), + ) + + @pytest.mark.skip + def test_2_h_recreate_test_data( + self, + otflow_file: str, + all_track_files_test_dataset: list[Path], + otflow_parser: FlowParser, + ) -> None: + batches = list(chunked(sorted(all_track_files_test_dataset), n=8)) + for test_file in tqdm(batches, desc="test data file"): + test_data = test_file + test_interval = "2h" + save_dir = test_data[0].parent + self._run_otanalytics( + count_interval=120, + otflow_file=otflow_file, + test_data=[str(file) for file in test_data], + save_dir=save_dir, + test_interval=test_interval, + otflow_parser=otflow_parser, + event_formats=("csv", "otevents"), + ) + def test_15_min( self, otflow_file: str, @@ -114,18 +188,14 @@ def _execute_test( otflow_parser: FlowParser, count_interval: int, ) -> None: - save_name = f"{Path(test_data[0]).stem}_{test_interval}" - - run_config = create_run_config( - track_files=[str(_file) for _file in test_data], - otflow_file=str(otflow_file), - save_dir=str(test_data_tmp_dir), - save_name=save_name, - event_formats=["csv"], - count_intervals=[count_interval], - flow_parser=otflow_parser, + save_name = self._run_otanalytics( + count_interval, + otflow_file, + otflow_parser, + test_data, + test_data_tmp_dir, + test_interval, ) - ApplicationStarter().start_cli(run_config) actual_events_file = Path(test_data_tmp_dir / save_name).with_suffix( ".events.csv" @@ -146,3 +216,26 @@ def _execute_test( .absolute() ) assert_two_files_equal_sorted(actual_counts_file, expected_counts_file) + + def _run_otanalytics( + self, + count_interval: int, + otflow_file: str, + otflow_parser: FlowParser, + test_data: list[str], + save_dir: Path, + test_interval: str, + event_formats: tuple[str, ...] = ("csv",), + ) -> str: + save_name = f"{Path(test_data[0]).stem}_{test_interval}" + run_config = create_run_config( + track_files=[str(_file) for _file in test_data], + otflow_file=str(otflow_file), + save_dir=str(save_dir), + save_name=save_name, + event_formats=list(event_formats), + count_intervals=[count_interval], + flow_parser=otflow_parser, + ) + ApplicationStarter().start_cli(run_config) + return save_name From eca999339bf7880be0bf27ff62db7ff2e7ff21cb Mon Sep 17 00:00:00 2001 From: Lars Briem Date: Fri, 12 Apr 2024 11:33:16 +0200 Subject: [PATCH 12/15] Add methods to recreate regression test results for whole day --- tests/regression_otanalytics.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/regression_otanalytics.py b/tests/regression_otanalytics.py index 455504c51..66364d09b 100644 --- a/tests/regression_otanalytics.py +++ b/tests/regression_otanalytics.py @@ -117,6 +117,28 @@ def test_2_h_recreate_test_data( event_formats=("csv", "otevents"), ) + # @pytest.mark.skip + def test_whole_day_recreate_test_data( + self, + otflow_file: str, + all_track_files_test_dataset: list[Path], + otflow_parser: FlowParser, + ) -> None: + batches = list(chunked(sorted(all_track_files_test_dataset), n=8)) + for test_file in tqdm(batches, desc="test data file"): + test_data = test_file + test_interval = "24h" + save_dir = test_data[0].parent + self._run_otanalytics( + count_interval=15, + otflow_file=otflow_file, + test_data=[str(file) for file in test_data], + save_dir=save_dir, + test_interval=test_interval, + otflow_parser=otflow_parser, + event_formats=("csv", "otevents"), + ) + def test_15_min( self, otflow_file: str, From 8f798c00a2a91976be97893e4cb65682ca3ac4c5 Mon Sep 17 00:00:00 2001 From: Lars Briem Date: Fri, 12 Apr 2024 12:15:12 +0200 Subject: [PATCH 13/15] Add methods to recreate regression test results for whole day --- tests/regression_otanalytics.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/tests/regression_otanalytics.py b/tests/regression_otanalytics.py index 66364d09b..784919eb1 100644 --- a/tests/regression_otanalytics.py +++ b/tests/regression_otanalytics.py @@ -124,20 +124,18 @@ def test_whole_day_recreate_test_data( all_track_files_test_dataset: list[Path], otflow_parser: FlowParser, ) -> None: - batches = list(chunked(sorted(all_track_files_test_dataset), n=8)) - for test_file in tqdm(batches, desc="test data file"): - test_data = test_file - test_interval = "24h" - save_dir = test_data[0].parent - self._run_otanalytics( - count_interval=15, - otflow_file=otflow_file, - test_data=[str(file) for file in test_data], - save_dir=save_dir, - test_interval=test_interval, - otflow_parser=otflow_parser, - event_formats=("csv", "otevents"), - ) + test_data = all_track_files_test_dataset + test_interval = "24h" + save_dir = test_data[0].parent + self._run_otanalytics( + count_interval=15, + otflow_file=otflow_file, + test_data=[str(file) for file in test_data], + save_dir=save_dir, + test_interval=test_interval, + otflow_parser=otflow_parser, + event_formats=("csv", "otevents"), + ) def test_15_min( self, From cc7ff29b787e3c25545042ad73de87e09fda9da8 Mon Sep 17 00:00:00 2001 From: Lars Briem Date: Mon, 15 Apr 2024 08:51:52 +0200 Subject: [PATCH 14/15] Revert unintended config changes --- ...o_Cars-Cyclist_FR20_2020-01-01_00-00-00.otconfig | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/tests/data/Testvideo_Cars-Cyclist_FR20_2020-01-01_00-00-00.otconfig b/tests/data/Testvideo_Cars-Cyclist_FR20_2020-01-01_00-00-00.otconfig index dfbf123e2..d8be65f63 100644 --- a/tests/data/Testvideo_Cars-Cyclist_FR20_2020-01-01_00-00-00.otconfig +++ b/tests/data/Testvideo_Cars-Cyclist_FR20_2020-01-01_00-00-00.otconfig @@ -1,16 +1,7 @@ { "project": { - "name": "Test- Knotenpunk", - "start_date": 1704092400.0, - "metadata": { - "tk_number": "1234", - "counting_location_number": "6789", - "direction": "2", - "weather": "2", - "remark": "Nichts spannendes", - "coordinate_x": "8.1235", - "coordinate_y": "49.2468" - } + "name": "My Test Project", + "start_date": 1577877077.0 }, "videos": [ { From 8e512d6bb964da01c40ee7b313be471b6ab26110 Mon Sep 17 00:00:00 2001 From: Lars Briem Date: Mon, 15 Apr 2024 08:54:39 +0200 Subject: [PATCH 15/15] Skip creation of whole test data --- tests/regression_otanalytics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/regression_otanalytics.py b/tests/regression_otanalytics.py index 784919eb1..ec4fc09ad 100644 --- a/tests/regression_otanalytics.py +++ b/tests/regression_otanalytics.py @@ -117,7 +117,7 @@ def test_2_h_recreate_test_data( event_formats=("csv", "otevents"), ) - # @pytest.mark.skip + @pytest.mark.skip def test_whole_day_recreate_test_data( self, otflow_file: str,