diff --git a/OTAnalytics/application/state.py b/OTAnalytics/application/state.py index 5cdff1bde..8d938895f 100644 --- a/OTAnalytics/application/state.py +++ b/OTAnalytics/application/state.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from datetime import datetime -from typing import Callable, Generic, Iterable, Optional +from typing import Callable, Generic, Optional from OTAnalytics.application.config import DEFAULT_TRACK_OFFSET from OTAnalytics.application.datastore import Datastore, VideoMetadata @@ -17,7 +17,7 @@ SectionRepositoryEvent, SectionType, ) -from OTAnalytics.domain.track import Detection, TrackId, TrackImage +from OTAnalytics.domain.track import TrackId, TrackImage from OTAnalytics.domain.track_repository import ( TrackListObserver, TrackObserver, @@ -436,9 +436,9 @@ def __init__(self, track_repository: TrackRepository) -> None: self._last_detection_occurrence: ObservableOptionalProperty[ datetime ] = ObservableOptionalProperty[datetime]() - self._classifications: ObservableProperty[set[str]] = ObservableProperty[set]( - set() - ) + self._classifications: ObservableProperty[frozenset[str]] = ObservableProperty[ + frozenset + ](frozenset()) self._detection_classifications: ObservableProperty[ frozenset[str] ] = ObservableProperty[frozenset](frozenset([])) @@ -464,7 +464,7 @@ def last_detection_occurrence(self) -> Optional[datetime]: return self._last_detection_occurrence.get() @property - def classifications(self) -> set[str]: + def classifications(self) -> frozenset[str]: """The current classifications in the track repository. Returns: @@ -484,37 +484,16 @@ def detection_classifications(self) -> frozenset[str]: def notify_tracks(self, track_event: TrackRepositoryEvent) -> None: """Update tracks metadata on track repository changes""" self._update_detection_occurrences() - self._update_classifications(track_event.added) + self._update_classifications() def _update_detection_occurrences(self) -> None: """Update the first and last detection occurrences.""" - sorted_detections = sorted( - self._get_all_track_detections(), key=lambda x: x.occurrence - ) - if sorted_detections: - self._first_detection_occurrence.set(sorted_detections[0].occurrence) - self._last_detection_occurrence.set(sorted_detections[-1].occurrence) + self._first_detection_occurrence.set(self._track_repository.first_occurrence) + self._last_detection_occurrence.set(self._track_repository.last_occurrence) - def _update_classifications(self, new_tracks: list[TrackId]) -> None: + def _update_classifications(self) -> None: """Update current classifications.""" - updated_classifications = self._classifications.get().copy() - for track_id in new_tracks: - if track := self._track_repository.get_for(track_id): - updated_classifications.add(track.classification) - self._classifications.set(updated_classifications) - - def _get_all_track_detections(self) -> Iterable[Detection]: - """Get all track detections in the track repository. - - Returns: - Iterable[Detection]: the track detections. - """ - detections: list[Detection] = [] - - for track in self._track_repository.get_all(): - detections.extend(track.detections) - - return detections + self._classifications.set(self._track_repository.classifications) def update_detection_classes(self, new_classes: frozenset[str]) -> None: """Update the classifications used by the detection model.""" diff --git a/OTAnalytics/application/track_filter.py b/OTAnalytics/application/track_filter.py index 446665a68..28454bb44 100644 --- a/OTAnalytics/application/track_filter.py +++ b/OTAnalytics/application/track_filter.py @@ -46,7 +46,7 @@ def __init__(self, start_date: datetime) -> None: self._start_date = start_date def test(self, to_test: Track) -> bool: - return self._start_date <= to_test.detections[0].occurrence + return self._start_date <= to_test.first_detection.occurrence class TrackEndsBeforeOrAtDate(TrackPredicate): @@ -60,7 +60,7 @@ def __init__(self, end_date: datetime) -> None: self._end_date = end_date def test(self, to_test: Track) -> bool: - return to_test.detections[0].occurrence <= self._end_date + return to_test.first_detection.occurrence <= self._end_date class TrackHasClassifications(TrackPredicate): @@ -83,7 +83,7 @@ def test(self, to_test: Track) -> bool: to_test (Track): the track under test Returns: - bool: `True` if track has classification. Otherwise `False`. + bool: `True` if track has classification. Otherwise, `False`. """ return to_test.classification in self._classifications @@ -92,7 +92,6 @@ class TrackFilter(Filter[Track, bool]): """A `Track` filter. Args: - Filter (Filter[Track, bool]): extends the `Filter` interface predicate (Predicate[Track, bool]): the predicate to test against during filtering """ diff --git a/OTAnalytics/domain/track_dataset.py b/OTAnalytics/domain/track_dataset.py index 56d2d86f1..40987c961 100644 --- a/OTAnalytics/domain/track_dataset.py +++ b/OTAnalytics/domain/track_dataset.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass +from datetime import datetime from typing import Callable, Iterable, Iterator, Optional, Sequence from OTAnalytics.domain.event import Event @@ -17,6 +18,21 @@ class TrackDataset(ABC): def __iter__(self) -> Iterator[Track]: yield from self.as_list() + @property + @abstractmethod + def first_occurrence(self) -> datetime | None: + raise NotImplementedError + + @property + @abstractmethod + def last_occurrence(self) -> datetime | None: + raise NotImplementedError + + @property + @abstractmethod + def classifications(self) -> frozenset[str]: + raise NotImplementedError + @abstractmethod def add_all(self, other: Iterable[Track]) -> "TrackDataset": raise NotImplementedError diff --git a/OTAnalytics/domain/track_repository.py b/OTAnalytics/domain/track_repository.py index 0e9b74a59..38b9091c5 100644 --- a/OTAnalytics/domain/track_repository.py +++ b/OTAnalytics/domain/track_repository.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass +from datetime import datetime from pathlib import Path from typing import Iterable, Optional @@ -99,6 +100,18 @@ def __init__(self, track_ids: list[TrackId], message: str): class TrackRepository: + @property + def first_occurrence(self) -> datetime | None: + return self._dataset.first_occurrence + + @property + def last_occurrence(self) -> datetime | None: + return self._dataset.last_occurrence + + @property + def classifications(self) -> frozenset[str]: + return self._dataset.classifications + def __init__(self, dataset: TrackDataset) -> None: self._dataset = dataset self.observers = Subject[TrackRepositoryEvent]() diff --git a/OTAnalytics/plugin_datastore/python_track_store.py b/OTAnalytics/plugin_datastore/python_track_store.py index 66b3b0ec8..e97a58b74 100644 --- a/OTAnalytics/plugin_datastore/python_track_store.py +++ b/OTAnalytics/plugin_datastore/python_track_store.py @@ -233,6 +233,24 @@ def calculate(self, detections: list[Detection]) -> str: class PythonTrackDataset(TrackDataset): """Pure Python implementation of a TrackDataset.""" + @property + def first_occurrence(self) -> datetime | None: + if not len(self): + return None + return min( + [track.first_detection.occurrence for track in self._tracks.values()] + ) + + @property + def last_occurrence(self) -> datetime | None: + if not len(self): + return None + return max([track.last_detection.occurrence for track in self._tracks.values()]) + + @property + def classifications(self) -> frozenset[str]: + return frozenset([track.classification for track in self._tracks.values()]) + def __init__( self, values: Optional[dict[TrackId, Track]] = None, diff --git a/OTAnalytics/plugin_datastore/track_store.py b/OTAnalytics/plugin_datastore/track_store.py index c3a156a52..5aff4dff8 100644 --- a/OTAnalytics/plugin_datastore/track_store.py +++ b/OTAnalytics/plugin_datastore/track_store.py @@ -201,6 +201,24 @@ def extract_hostname(name: str) -> str: class PandasTrackDataset(TrackDataset): + @property + def first_occurrence(self) -> datetime | None: + if not len(self): + return None + return self._dataset.index.get_level_values(LEVEL_OCCURRENCE).min() + + @property + def last_occurrence(self) -> datetime | None: + if not len(self): + return None + return self._dataset.index.get_level_values(LEVEL_OCCURRENCE).max() + + @property + def classifications(self) -> frozenset[str]: + if not len(self): + return frozenset() + return frozenset(self._dataset[track.TRACK_CLASSIFICATION].unique()) + def __init__( self, track_geometry_factory: TRACK_GEOMETRY_FACTORY, diff --git a/OTAnalytics/plugin_prototypes/track_visualization/track_viz.py b/OTAnalytics/plugin_prototypes/track_visualization/track_viz.py index 5b45b5dc7..4e9f82601 100644 --- a/OTAnalytics/plugin_prototypes/track_visualization/track_viz.py +++ b/OTAnalytics/plugin_prototypes/track_visualization/track_viz.py @@ -120,7 +120,7 @@ def __init__(self, default_palette: dict[str, str]) -> None: self._default_palette = default_palette self._palette: dict[str, str] = {} - def update(self, classifications: set[str]) -> None: + def update(self, classifications: frozenset[str]) -> None: for classification in classifications: if classification in self._default_palette.keys(): self._palette[classification] = self._default_palette[classification] diff --git a/tests/OTAnalytics/application/test_state.py b/tests/OTAnalytics/application/test_state.py index c3abfa4f0..7bf333c98 100644 --- a/tests/OTAnalytics/application/test_state.py +++ b/tests/OTAnalytics/application/test_state.py @@ -1,6 +1,6 @@ from datetime import datetime, timedelta, timezone from typing import Callable, Optional -from unittest.mock import Mock, call, patch +from unittest.mock import Mock, call import pytest @@ -342,66 +342,32 @@ def track( track.detections = [first_detection, second_detection, third_detection] return track - @patch("OTAnalytics.application.state.TracksMetadata._get_all_track_detections") - def test_update_detection_occurrences( - self, - mock_get_all_track_detections: Mock, - first_detection: Mock, - second_detection: Mock, - third_detection: Mock, - ) -> None: - mock_track_repository = Mock(spec=TrackRepository) + def test_update_detection_occurrences(self) -> None: + first_occurrence = datetime(2000, 1, 1, 12) + last_occurrence = datetime(2000, 1, 1, 15) + track_repository = Mock(spec=TrackRepository) + track_repository.first_occurrence = first_occurrence + track_repository.last_occurrence = last_occurrence - mock_get_all_track_detections.return_value = [ - first_detection, - third_detection, - second_detection, - ] - tracks_metadata = TracksMetadata(mock_track_repository) + tracks_metadata = TracksMetadata(track_repository) assert tracks_metadata.first_detection_occurrence is None assert tracks_metadata.last_detection_occurrence is None tracks_metadata._update_detection_occurrences() - assert tracks_metadata.first_detection_occurrence == first_detection.occurrence - assert tracks_metadata.last_detection_occurrence == third_detection.occurrence - - mock_get_all_track_detections.assert_called_once() + assert tracks_metadata.first_detection_occurrence == first_occurrence + assert tracks_metadata.last_detection_occurrence == last_occurrence - def test_get_all_track_detections( - self, first_detection: Mock, second_detection: Mock - ) -> None: - track = Mock(spec=Track).return_value - track.detections = [first_detection, second_detection] - track_repository = Mock(spec=TrackRepository) - track_repository.get_all.return_value = [track] - - tracks_metadata = TracksMetadata(track_repository) - detections = tracks_metadata._get_all_track_detections() - - assert detections == [first_detection, second_detection] - track_repository.get_all.assert_called_once() - - def test_update_classifications(self, track: Mock) -> None: + def test_update_classifications(self) -> None: + classifications = frozenset(["truck", "car", "pedestrian"]) mock_track_repository = Mock(spec=TrackRepository) - mock_track_repository.get_for.return_value = track + mock_track_repository.classifications = classifications tracks_metadata = TracksMetadata(mock_track_repository) - assert tracks_metadata.classifications == set() - tracks_metadata._update_classifications([track.id]) - - assert tracks_metadata.classifications == {"car"} - mock_track_repository.get_for.assert_any_call(track.id) - assert mock_track_repository.get_for.call_count == 1 - - track.detections[0].classification = "bicycle" - tracks_metadata._update_classifications([track.id]) - - assert tracks_metadata.classifications == {"car"} - mock_track_repository.get_for.assert_any_call(track.id) - assert mock_track_repository.get_for.call_count == 2 + tracks_metadata._update_classifications() + assert tracks_metadata.classifications == classifications def test_update_detection_classes(self) -> None: tracks_metadata = TracksMetadata(Mock()) diff --git a/tests/OTAnalytics/domain/test_track_repository.py b/tests/OTAnalytics/domain/test_track_repository.py index 4a2a3daaa..37b68008b 100644 --- a/tests/OTAnalytics/domain/test_track_repository.py +++ b/tests/OTAnalytics/domain/test_track_repository.py @@ -146,6 +146,27 @@ def test_len(self) -> None: assert result == expected_size dataset.__len__.assert_called_once() + def test_first_occurrence(self) -> None: + first_occurrence = Mock() + dataset = Mock() + dataset.first_occurrence = first_occurrence + repository = TrackRepository(dataset) + assert repository.first_occurrence == first_occurrence + + def test_last_occurrence(self) -> None: + last_occurrence = Mock() + dataset = Mock() + dataset.last_occurrence = last_occurrence + repository = TrackRepository(dataset) + assert repository.last_occurrence == last_occurrence + + def test_classifications(self) -> None: + classifications = Mock() + dataset = Mock() + dataset.classifications = classifications + repository = TrackRepository(dataset) + assert repository.classifications == classifications + class TestTrackFileRepository: @pytest.fixture diff --git a/tests/OTAnalytics/plugin_datastore/test_python_track_storage.py b/tests/OTAnalytics/plugin_datastore/test_python_track_storage.py index e4a93f017..954654e0d 100644 --- a/tests/OTAnalytics/plugin_datastore/test_python_track_storage.py +++ b/tests/OTAnalytics/plugin_datastore/test_python_track_storage.py @@ -481,3 +481,30 @@ def __create_leave_scene_event(self, track: Track) -> Event: ), video_name=track.last_detection.video_name, ) + + def test_first_occurrence(self, first_track: Track, second_track: Track) -> None: + dataset = PythonTrackDataset.from_list([second_track, first_track]) + assert dataset.first_occurrence == first_track.first_detection.occurrence + assert dataset.first_occurrence == second_track.first_detection.occurrence + + def test_last_occurrence(self, first_track: Track, second_track: Track) -> None: + dataset = PythonTrackDataset.from_list([second_track, first_track]) + assert dataset.last_occurrence == second_track.last_detection.occurrence + + def test_first_occurrence_on_empty_dataset(self) -> None: + dataset = PythonTrackDataset() + assert dataset.first_occurrence is None + + def test_last_occurrence_on_empty_dataset(self) -> None: + dataset = PythonTrackDataset() + assert dataset.last_occurrence is None + + def test_classifications(self, first_track: Track, second_track: Track) -> None: + dataset = PythonTrackDataset.from_list([first_track, second_track]) + assert dataset.classifications == frozenset( + [first_track.classification, second_track.classification] + ) + + def test_classifications_on_empty_dataset(self) -> None: + dataset = PythonTrackDataset() + assert dataset.classifications == frozenset() diff --git a/tests/OTAnalytics/plugin_datastore/test_track_store.py b/tests/OTAnalytics/plugin_datastore/test_track_store.py index 3b648057a..1cd08dfa5 100644 --- a/tests/OTAnalytics/plugin_datastore/test_track_store.py +++ b/tests/OTAnalytics/plugin_datastore/test_track_store.py @@ -474,3 +474,57 @@ def __create_leave_scene_event(self, track: Track) -> Event: ), video_name=track.last_detection.video_name, ) + + def test_first_occurrence( + self, + track_geometry_factory: TRACK_GEOMETRY_FACTORY, + first_track: Track, + second_track: Track, + ) -> None: + dataset = PandasTrackDataset.from_list( + [first_track, second_track], track_geometry_factory + ) + assert dataset.first_occurrence == first_track.first_detection.occurrence + assert dataset.first_occurrence == second_track.first_detection.occurrence + + def test_last_occurrence( + self, + track_geometry_factory: TRACK_GEOMETRY_FACTORY, + first_track: Track, + second_track: Track, + ) -> None: + dataset = PandasTrackDataset.from_list( + [first_track, second_track], track_geometry_factory + ) + assert dataset.last_occurrence == second_track.last_detection.occurrence + + def test_first_occurrence_on_empty_dataset( + self, track_geometry_factory: TRACK_GEOMETRY_FACTORY + ) -> None: + dataset = PandasTrackDataset(track_geometry_factory) + assert dataset.first_occurrence is None + + def test_last_occurrence_on_empty_dataset( + self, track_geometry_factory: TRACK_GEOMETRY_FACTORY + ) -> None: + dataset = PandasTrackDataset(track_geometry_factory) + assert dataset.last_occurrence is None + + def test_classifications( + self, + track_geometry_factory: TRACK_GEOMETRY_FACTORY, + first_track: Track, + second_track: Track, + ) -> None: + dataset = PandasTrackDataset.from_list( + [first_track, second_track], track_geometry_factory + ) + assert dataset.classifications == frozenset( + [first_track.classification, second_track.classification] + ) + + def test_classifications_on_empty_dataset( + self, track_geometry_factory: TRACK_GEOMETRY_FACTORY + ) -> None: + dataset = PandasTrackDataset(track_geometry_factory) + assert dataset.classifications == frozenset() diff --git a/tests/OTAnalytics/plugin_prototypes/track_visualization/test_track_viz.py b/tests/OTAnalytics/plugin_prototypes/track_visualization/test_track_viz.py index 605214efd..d179cdb55 100644 --- a/tests/OTAnalytics/plugin_prototypes/track_visualization/test_track_viz.py +++ b/tests/OTAnalytics/plugin_prototypes/track_visualization/test_track_viz.py @@ -306,7 +306,7 @@ class TestColorPaletteProvider: ) def test_update_with_filled_default( self, - new_classifications: set[str], + new_classifications: frozenset[str], default_palette: dict[str, str], expected: dict[str, str], ) -> None: