diff --git a/.run/pytest-in-tests.run.xml b/.run/pytest-in-tests.run.xml
index 5f90722c2..edc3ff2b2 100644
--- a/.run/pytest-in-tests.run.xml
+++ b/.run/pytest-in-tests.run.xml
@@ -4,7 +4,7 @@
-
+
diff --git a/OTAnalytics/adapter_ui/view_model.py b/OTAnalytics/adapter_ui/view_model.py
index ffa2c9f75..f0d03d318 100644
--- a/OTAnalytics/adapter_ui/view_model.py
+++ b/OTAnalytics/adapter_ui/view_model.py
@@ -335,3 +335,27 @@ def set_frame_track_plotting(
@abstractmethod
def set_analysis_frame(self, frame: AbstractFrame) -> None:
raise NotImplementedError
+
+ @abstractmethod
+ def next_frame(self) -> None:
+        raise NotImplementedError
+
+ @abstractmethod
+ def previous_frame(self) -> None:
+        raise NotImplementedError
+
+ @abstractmethod
+ def update_skip_time(self, seconds: int, frames: int) -> None:
+ raise NotImplementedError
+
+ @abstractmethod
+ def get_skip_seconds(self) -> int:
+ raise NotImplementedError
+
+ @abstractmethod
+ def get_skip_frames(self) -> int:
+ raise NotImplementedError
+
+ @abstractmethod
+ def set_video_control_frame(self, frame: AbstractFrame) -> None:
+ raise NotImplementedError
diff --git a/OTAnalytics/application/application.py b/OTAnalytics/application/application.py
index 700ec1adc..572c03db8 100644
--- a/OTAnalytics/application/application.py
+++ b/OTAnalytics/application/application.py
@@ -17,6 +17,10 @@
TrackViewState,
VideosMetadata,
)
+from OTAnalytics.application.ui.frame_control import (
+ SwitchToNextFrame,
+ SwitchToPreviousFrame,
+)
from OTAnalytics.application.use_cases.config import SaveOtconfig
from OTAnalytics.application.use_cases.create_events import (
CreateEvents,
@@ -98,6 +102,8 @@ def __init__(
start_new_project: StartNewProject,
project_updater: ProjectUpdater,
load_track_files: LoadTrackFiles,
+ previous_frame: SwitchToPreviousFrame,
+ next_frame: SwitchToNextFrame,
) -> None:
self._datastore: Datastore = datastore
self.track_state: TrackState = track_state
@@ -129,6 +135,8 @@ def __init__(
self._track_repository_size = TrackRepositorySize(
self._datastore._track_repository
)
+ self._previous_frame = previous_frame
+ self._next_frame = next_frame
def connect_observers(self) -> None:
"""
@@ -449,6 +457,12 @@ def get_current_track_offset(self) -> Optional[RelativeOffsetCoordinate]:
"""
return self.track_view_state.track_offset.get()
+ def next_frame(self) -> None:
+ self._next_frame.set_next_frame()
+
+ def previous_frame(self) -> None:
+ self._previous_frame.set_previous_frame()
+
def update_date_range_tracks_filter(self, date_range: DateRange) -> None:
"""Update the date range of the track filter.
diff --git a/OTAnalytics/application/datastore.py b/OTAnalytics/application/datastore.py
index 5f8ce0e65..464c5ba91 100644
--- a/OTAnalytics/application/datastore.py
+++ b/OTAnalytics/application/datastore.py
@@ -61,6 +61,12 @@ def duration(self) -> timedelta:
return self.expected_duration
return timedelta(seconds=self.number_of_frames / self.recorded_fps)
+ @property
+ def fps(self) -> float:
+ if self.actual_fps:
+ return self.actual_fps
+ return self.recorded_fps
+
@dataclass(frozen=True)
class TrackParseResult:
@@ -120,7 +126,7 @@ def serialize(
class VideoParser(ABC):
@abstractmethod
- def parse(self, file: Path) -> Video:
+ def parse(self, file: Path, start_date: Optional[datetime]) -> Video:
pass
@abstractmethod
@@ -294,7 +300,7 @@ def load_video_files(self, files: list[Path]) -> None:
videos = []
for file in files:
try:
- videos.append(self._video_parser.parse(file))
+ videos.append(self._video_parser.parse(file, None))
except Exception as cause:
raised_exceptions.append(cause)
if raised_exceptions:
@@ -514,7 +520,11 @@ def get_video_for(self, track_id: TrackId) -> Optional[Video]:
def get_all_videos(self) -> list[Video]:
return self._video_repository.get_all()
- def get_image_of_track(self, track_id: TrackId) -> Optional[TrackImage]:
+ def get_image_of_track(
+ self,
+ track_id: TrackId,
+ frame: int = 0,
+ ) -> Optional[TrackImage]:
"""
Retrieve an image for the given track.
@@ -525,4 +535,6 @@ def get_image_of_track(self, track_id: TrackId) -> Optional[TrackImage]:
Optional[TrackImage]: an image of the track if the track is available and
the image can be loaded
"""
- return video.get_frame(0) if (video := self.get_video_for(track_id)) else None
+ return (
+ video.get_frame(frame) if (video := self.get_video_for(track_id)) else None
+ )
diff --git a/OTAnalytics/application/playback.py b/OTAnalytics/application/playback.py
new file mode 100644
index 000000000..d84bb4c96
--- /dev/null
+++ b/OTAnalytics/application/playback.py
@@ -0,0 +1,7 @@
+from dataclasses import dataclass
+
+
+@dataclass(frozen=True)
+class SkipTime:
+ seconds: int
+ frames: int
diff --git a/OTAnalytics/application/plotting.py b/OTAnalytics/application/plotting.py
index 8fe4d66e8..0fd80f73d 100644
--- a/OTAnalytics/application/plotting.py
+++ b/OTAnalytics/application/plotting.py
@@ -1,4 +1,6 @@
-from abc import abstractmethod
+from abc import ABC, abstractmethod
+from datetime import datetime, timedelta
+from math import floor
from typing import Any, Callable, Generic, Iterable, Optional, Sequence, TypeVar
from OTAnalytics.application.state import (
@@ -6,8 +8,10 @@
ObservableProperty,
Plotter,
TrackViewState,
+ VideosMetadata,
)
from OTAnalytics.domain.track import TrackImage
+from OTAnalytics.domain.video import Video
class Layer:
@@ -98,16 +102,31 @@ def __add(self, image: TrackImage) -> None:
self._current_image = image
+class VisualizationTimeProvider(ABC):
+ @abstractmethod
+ def get_time(self) -> datetime:
+ raise NotImplementedError
+
+
+VideoProvider = Callable[[], list[Video]]
+
+
class TrackBackgroundPlotter(Plotter):
"""Plot video frame as background."""
- def __init__(self, track_view_state: TrackViewState) -> None:
- self._track_view_state = track_view_state
+ def __init__(
+ self,
+ video_provider: VideoProvider,
+ visualization_time_provider: VisualizationTimeProvider,
+ ) -> None:
+ self._video_provider = video_provider
+ self._visualization_time_provider = visualization_time_provider
def plot(self) -> Optional[TrackImage]:
- if videos := self._track_view_state.selected_videos.get():
- if len(videos) > 0:
- return videos[0].get_frame(0)
+ if videos := self._video_provider():
+ visualization_time = self._visualization_time_provider.get_time()
+ frame_number = videos[0].get_frame_number_for(visualization_time)
+ return videos[0].get_frame(frame_number)
return None
@@ -238,3 +257,50 @@ def _handle_remove(self, entities: Iterable[ENTITY]) -> None:
for entity in entities:
del self._plotter_mapping[entity]
del self._layer_mapping[entity]
+
+
+class GetCurrentVideoPath:
+ """
+ This use case provides the currently visible video path. It uses the current filters
+ end date to retrieve the corresponding file path.
+ """
+
+ def __init__(
+ self,
+ state: TrackViewState,
+ videos_metadata: VideosMetadata,
+ ) -> None:
+ self._state = state
+ self._videos_metadata = videos_metadata
+
+ def get_video(self) -> Optional[str]:
+ if end_date := self._state.filter_element.get().date_range.end_date:
+ if metadata := self._videos_metadata.get_metadata_for(end_date):
+ return metadata.path
+ return None
+
+
+class GetCurrentFrame:
+ """
+ This use case provides the currently visible frame. It uses the current filters
+ end date to retrieve the corresponding frame.
+ """
+
+ def __init__(
+ self,
+ state: TrackViewState,
+ videos_metadata: VideosMetadata,
+ ) -> None:
+ self._state = state
+ self._videos_metadata = videos_metadata
+
+ def get_frame_number(self) -> int:
+ if end_date := self._state.filter_element.get().date_range.end_date:
+ if metadata := self._videos_metadata.get_metadata_for(end_date):
+ time_in_video = end_date - metadata.start
+ if time_in_video < timedelta(0):
+ return 0
+ if time_in_video > metadata.duration:
+                    return metadata.number_of_frames - 1
+ return floor(metadata.fps * time_in_video.total_seconds())
+ return 0
diff --git a/OTAnalytics/application/state.py b/OTAnalytics/application/state.py
index 9d64d4948..947f58ba6 100644
--- a/OTAnalytics/application/state.py
+++ b/OTAnalytics/application/state.py
@@ -1,9 +1,11 @@
+import bisect
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Callable, Generic, Optional
from OTAnalytics.application.config import DEFAULT_TRACK_OFFSET
from OTAnalytics.application.datastore import Datastore, VideoMetadata
+from OTAnalytics.application.playback import SkipTime
from OTAnalytics.application.use_cases.section_repository import GetSectionsById
from OTAnalytics.domain.date import DateRange
from OTAnalytics.domain.event import EventRepositoryEvent
@@ -185,6 +187,7 @@ def __init__(self) -> None:
self.selected_videos: ObservableProperty[list[Video]] = ObservableProperty[
list[Video]
](default=[])
+        self.skip_time = ObservableProperty[SkipTime](default=SkipTime(1, 0))
def reset(self) -> None:
"""Reset to default settings."""
@@ -404,19 +407,50 @@ def _update_image(self) -> None:
class VideosMetadata:
def __init__(self) -> None:
- self._metadata: list[VideoMetadata] = []
+ self._metadata: dict[datetime, VideoMetadata] = {}
+ self._first_video_start: Optional[datetime] = None
+ self._last_video_end: Optional[datetime] = None
def update(self, metadata: VideoMetadata) -> None:
- self._metadata.append(metadata)
- self._metadata.sort(key=lambda current: current.start)
+ """
+ Update the stored metadata.
+ """
+ if metadata.start in self._metadata.keys():
+ raise ValueError(
+ f"metadata with start date {metadata.start} already exists."
+ )
+ self._metadata[metadata.start] = metadata
+ self._metadata = dict(sorted(self._metadata.items()))
+ self._update_start_end_by(metadata)
+
+ def _update_start_end_by(self, metadata: VideoMetadata) -> None:
+ if (not self._first_video_start) or metadata.start < self._first_video_start:
+ self._first_video_start = metadata.start
+ if (not self._last_video_end) or metadata.end > self._last_video_end:
+ self._last_video_end = metadata.end
+
+ def get_metadata_for(self, current: datetime) -> Optional[VideoMetadata]:
+ """
+ Find the metadata for the given datetime. If the datetime matches exactly a
+ start time of a video, the corresponding VideoMetadata is returned. Otherwise,
+ the metadata of the video containing the datetime will be returned.
+ """
+ if current in self._metadata:
+ return self._metadata[current]
+        keys = list(self._metadata.keys())
+        if (key := bisect.bisect_left(keys, current) - 1) >= 0:
+            metadata = self._metadata[keys[key]]
+            if metadata.start <= current <= metadata.end:
+                return metadata
+        return None
@property
def first_video_start(self) -> Optional[datetime]:
- return self._metadata[0].recorded_start_date if self._metadata else None
+ return self._first_video_start
@property
def last_video_end(self) -> Optional[datetime]:
- return self._metadata[-1].end if self._metadata else None
+ return self._last_video_end
class TracksMetadata(TrackListObserver):
diff --git a/OTAnalytics/application/ui/__init__.py b/OTAnalytics/application/ui/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/OTAnalytics/application/ui/frame_control.py b/OTAnalytics/application/ui/frame_control.py
new file mode 100644
index 000000000..3d613026a
--- /dev/null
+++ b/OTAnalytics/application/ui/frame_control.py
@@ -0,0 +1,57 @@
+from datetime import timedelta
+
+from OTAnalytics.application.state import TrackViewState, VideosMetadata
+from OTAnalytics.domain.date import DateRange
+
+
+class SwitchToNextFrame:
+ def __init__(self, state: TrackViewState, videos_metadata: VideosMetadata) -> None:
+ self._state = state
+ self._videos_metadata = videos_metadata
+
+ def set_next_frame(self) -> None:
+ if filter_element := self._state.filter_element.get():
+ current_date_range = filter_element.date_range
+ if current_date_range.start_date and current_date_range.end_date:
+ if metadata := self._videos_metadata.get_metadata_for(
+ current_date_range.end_date
+ ):
+ fps = metadata.fps
+ skip_time = self._state.skip_time.get()
+ subseconds = min(skip_time.frames, fps) / fps
+ milliseconds = subseconds * 1000
+ current_skip = timedelta(
+ seconds=skip_time.seconds, milliseconds=milliseconds
+ )
+ next_start = current_date_range.start_date + current_skip
+ next_end = current_date_range.end_date + current_skip
+ next_date_range = DateRange(next_start, next_end)
+ self._state.filter_element.set(
+ filter_element.derive_date(next_date_range)
+ )
+
+
+class SwitchToPreviousFrame:
+ def __init__(self, state: TrackViewState, videos_metadata: VideosMetadata) -> None:
+ self._state = state
+ self._videos_metadata = videos_metadata
+
+ def set_previous_frame(self) -> None:
+ if filter_element := self._state.filter_element.get():
+ current_date_range = filter_element.date_range
+ if current_date_range.start_date and current_date_range.end_date:
+ if metadata := self._videos_metadata.get_metadata_for(
+ current_date_range.end_date
+ ):
+ fps = metadata.fps
+ skip_time = self._state.skip_time.get()
+ subseconds = min(skip_time.frames, fps) / fps
+ current_skip = timedelta(seconds=skip_time.seconds) + timedelta(
+ seconds=subseconds
+ )
+ next_start = current_date_range.start_date - current_skip
+ next_end = current_date_range.end_date - current_skip
+ next_date_range = DateRange(next_start, next_end)
+ self._state.filter_element.set(
+ filter_element.derive_date(next_date_range)
+ )
diff --git a/OTAnalytics/application/use_cases/__init__.py b/OTAnalytics/application/use_cases/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/OTAnalytics/domain/track.py b/OTAnalytics/domain/track.py
index 1bced9fc1..7393a1413 100644
--- a/OTAnalytics/domain/track.py
+++ b/OTAnalytics/domain/track.py
@@ -241,6 +241,9 @@ def height(self) -> int:
"""
pass
+ def save(self, name: str) -> None:
+ self.as_image().save(name)
+
@dataclass(frozen=True)
class PilImage(TrackImage):
diff --git a/OTAnalytics/domain/video.py b/OTAnalytics/domain/video.py
index b6d92e189..72bf030d6 100644
--- a/OTAnalytics/domain/video.py
+++ b/OTAnalytics/domain/video.py
@@ -1,5 +1,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
+from datetime import datetime, timedelta
+from math import floor
from os import path
from os.path import normcase, splitdrive
from pathlib import Path
@@ -12,6 +14,10 @@
class VideoReader(ABC):
+ @abstractmethod
+ def get_fps(self, video: Path) -> float:
+ raise NotImplementedError
+
@abstractmethod
def get_frame(self, video: Path, index: int) -> TrackImage:
"""Get frame of `video` at `index`.
@@ -23,8 +29,22 @@ def get_frame(self, video: Path, index: int) -> TrackImage:
"""
pass
+ @abstractmethod
+ def get_frame_number_for(self, video_path: Path, date: timedelta) -> int:
+ raise NotImplementedError
+
class Video(ABC):
+ @property
+ @abstractmethod
+ def start_date(self) -> Optional[datetime]:
+ raise NotImplementedError
+
+ @property
+ @abstractmethod
+ def fps(self) -> float:
+ raise NotImplementedError
+
@abstractmethod
def get_path(self) -> Path:
pass
@@ -33,6 +53,10 @@ def get_path(self) -> Path:
def get_frame(self, index: int) -> TrackImage:
pass
+ @abstractmethod
+ def get_frame_number_for(self, date: datetime) -> int:
+ raise NotImplementedError
+
@abstractmethod
def to_dict(
self,
@@ -73,6 +97,16 @@ class SimpleVideo(Video):
video_reader: VideoReader
path: Path
+ _start_date: Optional[datetime]
+    _fps: Optional[float] = None
+
+ @property
+ def start_date(self) -> Optional[datetime]:
+ return self._start_date
+
+ @property
+ def fps(self) -> float:
+ return self._fps if self._fps else self.video_reader.get_fps(self.path)
def __post_init__(self) -> None:
self.check_path_exists()
@@ -95,6 +129,15 @@ def get_frame(self, index: int) -> TrackImage:
"""
return self.video_reader.get_frame(self.path, index)
+ def get_frame_number_for(self, date: datetime) -> int:
+ if not self.start_date:
+ return 0
+ time_in_video = date - self.start_date
+ if time_in_video < timedelta(0):
+ return 0
+
+ return floor(self.fps * time_in_video.total_seconds())
+
def to_dict(
self,
relative_to: Path,
diff --git a/OTAnalytics/plugin_parser/otvision_parser.py b/OTAnalytics/plugin_parser/otvision_parser.py
index d63b60519..3155d3aa0 100644
--- a/OTAnalytics/plugin_parser/otvision_parser.py
+++ b/OTAnalytics/plugin_parser/otvision_parser.py
@@ -798,8 +798,8 @@ class SimpleVideoParser(VideoParser):
def __init__(self, video_reader: VideoReader) -> None:
self._video_reader = video_reader
- def parse(self, file: Path) -> Video:
- return SimpleVideo(self._video_reader, file)
+ def parse(self, file: Path, start_date: Optional[datetime]) -> Video:
+ return SimpleVideo(self._video_reader, file, start_date)
def parse_list(
self,
@@ -816,7 +816,7 @@ def __create_video(
if PATH not in entry:
raise MissingPath(entry)
video_path = Path(base_folder, entry[PATH])
- return self.parse(video_path)
+ return self.parse(video_path, None)
def convert(
self,
@@ -833,6 +833,14 @@ class CachedVideo(Video):
other: Video
cache: dict[int, TrackImage] = field(default_factory=dict)
+ @property
+ def start_date(self) -> Optional[datetime]:
+ return self.other.start_date
+
+ @property
+ def fps(self) -> float:
+ return self.other.fps
+
def get_path(self) -> Path:
return self.other.get_path()
@@ -843,6 +851,9 @@ def get_frame(self, index: int) -> TrackImage:
self.cache[index] = new_frame
return new_frame
+ def get_frame_number_for(self, date: datetime) -> int:
+ return self.other.get_frame_number_for(date)
+
def to_dict(self, relative_to: Path) -> dict:
return self.other.to_dict(relative_to)
@@ -851,8 +862,8 @@ class CachedVideoParser(VideoParser):
def __init__(self, other: VideoParser) -> None:
self._other = other
- def parse(self, file: Path) -> Video:
- other_video = self._other.parse(file)
+ def parse(self, file: Path, start_date: Optional[datetime]) -> Video:
+ other_video = self._other.parse(file, start_date)
return self.__create_cached_video(other_video)
def __create_cached_video(self, other_video: Video) -> Video:
@@ -882,9 +893,14 @@ def parse(
content = parse_json_bz2(file)
metadata = content[ottrk_format.METADATA][ottrk_format.VIDEO]
video_file = metadata[ottrk_format.FILENAME] + metadata[ottrk_format.FILETYPE]
- video = self._video_parser.parse(file.parent / video_file)
+ start_date = self.__parse_recorded_start_date(metadata)
+ video = self._video_parser.parse(file.parent / video_file, start_date)
return track_ids, [video] * len(track_ids)
+ def __parse_recorded_start_date(self, metadata: dict) -> datetime:
+ start_date = metadata[ottrk_format.RECORDED_START_DATE]
+ return datetime.fromtimestamp(start_date, tz=timezone.utc)
+
class OtEventListParser(EventListParser):
def serialize(
diff --git a/OTAnalytics/plugin_prototypes/track_visualization/track_viz.py b/OTAnalytics/plugin_prototypes/track_visualization/track_viz.py
index bd827e810..8fede242d 100644
--- a/OTAnalytics/plugin_prototypes/track_visualization/track_viz.py
+++ b/OTAnalytics/plugin_prototypes/track_visualization/track_viz.py
@@ -8,11 +8,17 @@
from matplotlib.axes import Axes
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
+from matplotlib.patches import Rectangle
from mpl_toolkits.axes_grid1 import Divider, Size
from pandas import DataFrame
from PIL import Image
-from OTAnalytics.application.plotting import DynamicLayersPlotter, EntityPlotterFactory
+from OTAnalytics.application.plotting import (
+ DynamicLayersPlotter,
+ EntityPlotterFactory,
+ GetCurrentFrame,
+ GetCurrentVideoPath,
+)
from OTAnalytics.application.state import (
FlowState,
Plotter,
@@ -32,11 +38,15 @@
SectionRepositoryEvent,
)
from OTAnalytics.domain.track import (
+ H,
PilImage,
Track,
TrackId,
TrackIdProvider,
TrackImage,
+ W,
+ X,
+ Y,
)
from OTAnalytics.domain.track_repository import (
TrackListObserver,
@@ -46,6 +56,9 @@
from OTAnalytics.plugin_datastore.track_store import PandasTrackDataset
from OTAnalytics.plugin_filter.dataframe_filter import DataFrameFilterBuilder
+"""Frames start with 1 in OTVision but frames of videos are loaded zero based."""
+FRAME_OFFSET = 1
+
ENCODING = "UTF-8"
DPI = 100
@@ -371,12 +384,10 @@ class PandasTrackProvider(PandasDataFrameProvider):
def __init__(
self,
track_repository: TrackRepository,
- track_view_state: TrackViewState,
filter_builder: DataFrameFilterBuilder,
progressbar: ProgressbarBuilder,
) -> None:
self._track_repository = track_repository
- self._track_view_state = track_view_state
self._filter_builder = filter_builder
self._progressbar = progressbar
@@ -463,13 +474,10 @@ class CachedPandasTrackProvider(PandasTrackProvider, TrackListObserver):
def __init__(
self,
track_repository: TrackRepository,
- track_view_state: TrackViewState,
filter_builder: DataFrameFilterBuilder,
progressbar: ProgressbarBuilder,
) -> None:
- super().__init__(
- track_repository, track_view_state, filter_builder, progressbar
- )
+ super().__init__(track_repository, filter_builder, progressbar)
track_repository.register_tracks_observer(self)
self._cache_df: DataFrame = DataFrame()
@@ -679,6 +687,136 @@ def _plot_dataframe(self, track_df: DataFrame, axes: Axes) -> None:
)
+class FilterByVideo(PandasDataFrameProvider):
+ """
+ Filter the data of the other data provider using the video name / path of the
+ currently displayed video.
+ """
+
+ def __init__(
+ self, data_provider: PandasDataFrameProvider, current_video: GetCurrentVideoPath
+ ) -> None:
+ self._data_provider = data_provider
+ self._current_video = current_video
+
+ def get_data(self) -> DataFrame:
+ track_df = self._data_provider.get_data()
+ if track_df.empty:
+ return track_df
+ current_video = self._current_video.get_video()
+ return track_df[track_df[track.VIDEO_NAME] == current_video]
+
+
+class FilterByFrame(PandasDataFrameProvider):
+ """
+ Filter the data of the other data provider using the frame number of the
+ currently displayed frame. If multiple videos are loaded, the filter will return all
+ detections for the given frame number. If only the frame of the currently displayed
+ video should be shown, combine this filter with FilterByVideo.
+ """
+
+ def __init__(
+ self,
+ data_provider: PandasDataFrameProvider,
+ current_frame: GetCurrentFrame,
+ ) -> None:
+ self._data_provider = data_provider
+ self._current_frame = current_frame
+
+ def get_data(self) -> DataFrame:
+ track_df = self._data_provider.get_data()
+ if track_df.empty:
+ return track_df
+ current_frame = self._current_frame.get_frame_number() + FRAME_OFFSET
+ return track_df[track_df[track.FRAME] == current_frame]
+
+
+class TrackBoundingBoxPlotter(MatplotlibPlotterImplementation):
+ """Plot bounding boxes of detections."""
+
+ def __init__(
+ self,
+ data_provider: PandasDataFrameProvider,
+ color_palette_provider: ColorPaletteProvider,
+ track_view_state: TrackViewState,
+ alpha: float = 0.5,
+ ) -> None:
+ self._data_provider = data_provider
+ self._color_palette_provider = color_palette_provider
+ self._track_view_state = track_view_state
+ self._alpha = alpha
+
+ def plot(self, axes: Axes) -> None:
+ data = self._data_provider.get_data()
+ if not data.empty:
+ self._plot_dataframe(data, axes)
+
+ def _plot_dataframe(self, track_df: DataFrame, axes: Axes) -> None:
+ """
+ Plot given tracks on the given axes with the given transparency (alpha)
+
+ Args:
+ track_df (DataFrame): tracks to plot
+ alpha (float): transparency of the lines
+ axes (Axes): axes to plot on
+ """
+ for index, row in track_df.iterrows():
+ x = row[X]
+ y = row[Y]
+ width = row[W]
+ height = row[H]
+ classification = row[track.TRACK_CLASSIFICATION]
+ color = self._color_palette_provider.get()[classification]
+ axes.add_patch(
+ Rectangle(
+ xy=(x, y),
+ width=width,
+ height=height,
+ fc="none",
+ linewidth=0.5,
+ color=color,
+                    alpha=self._alpha,
+ )
+ )
+
+
+class TrackPointPlotter(MatplotlibPlotterImplementation):
+ """Plot point of bounding boxes of detections."""
+
+ def __init__(
+ self,
+ data_provider: PandasDataFrameProvider,
+ color_palette_provider: ColorPaletteProvider,
+ alpha: float = 0.5,
+ ) -> None:
+ self._data_provider = data_provider
+ self._color_palette_provider = color_palette_provider
+ self._alpha = alpha
+
+ def plot(self, axes: Axes) -> None:
+ data = self._data_provider.get_data()
+ if not data.empty:
+ self._plot_dataframe(data, axes)
+
+ def _plot_dataframe(self, track_df: DataFrame, axes: Axes) -> None:
+ """
+ Plot given tracks on the given axes with the given transparency (alpha)
+
+ Args:
+ track_df (DataFrame): tracks to plot
+ axes (Axes): axes to plot on
+ """
+ for index, row in track_df.iterrows():
+ classification = row[track.TRACK_CLASSIFICATION]
+ color = self._color_palette_provider.get()[classification]
+ axes.plot(
+ row[X],
+ row[Y],
+ marker="o",
+ color=color,
+ )
+
+
class MatplotlibTrackPlotter(TrackPlotter):
"""
Implementation of the TrackPlotter interface using matplotlib.
diff --git a/OTAnalytics/plugin_ui/customtkinter_gui/dummy_viewmodel.py b/OTAnalytics/plugin_ui/customtkinter_gui/dummy_viewmodel.py
index ce123fd63..92d7928d5 100644
--- a/OTAnalytics/plugin_ui/customtkinter_gui/dummy_viewmodel.py
+++ b/OTAnalytics/plugin_ui/customtkinter_gui/dummy_viewmodel.py
@@ -48,6 +48,7 @@
)
from OTAnalytics.application.datastore import FlowParser, NoSectionsToSave
from OTAnalytics.application.logger import logger
+from OTAnalytics.application.playback import SkipTime
from OTAnalytics.application.use_cases.config import MissingDate
from OTAnalytics.application.use_cases.cut_tracks_with_sections import CutTracksDto
from OTAnalytics.application.use_cases.export_events import (
@@ -191,6 +192,7 @@ def __init__(
self._frame_tracks: Optional[AbstractFrameTracks] = None
self._frame_videos: Optional[AbstractFrame] = None
self._frame_canvas: Optional[AbstractFrameCanvas] = None
+ self._frame_video_control: Optional[AbstractFrame] = None
self._frame_sections: Optional[AbstractFrame] = None
self._frame_flows: Optional[AbstractFrame] = None
self._frame_filter: Optional[AbstractFrameFilter] = None
@@ -1354,6 +1356,12 @@ def change_track_offset_to_section_offset(self) -> None:
raise MissingInjectedInstanceError(type(self._frame_tracks).__name__)
self.update_section_offset_button_state()
+ def next_frame(self) -> None:
+ self._application.next_frame()
+
+ def previous_frame(self) -> None:
+ self._application.previous_frame()
+
def validate_date(self, date: str) -> bool:
return any(
[validate_date(date, date_format) for date_format in SUPPORTED_FORMATS]
@@ -1591,3 +1599,27 @@ def on_tracks_cut(self, cut_tracks_dto: CutTracksDto) -> None:
def set_analysis_frame(self, frame: AbstractFrame) -> None:
self._frame_analysis = frame
+
+ def update_skip_time(self, seconds: int, frames: int) -> None:
+ self._application.track_view_state.skip_time.set(SkipTime(seconds, frames))
+
+ def get_skip_seconds(self) -> int:
+ return self._application.track_view_state.skip_time.get().seconds
+
+ def get_skip_frames(self) -> int:
+ return self._application.track_view_state.skip_time.get().frames
+
+ def set_video_control_frame(self, frame: AbstractFrame) -> None:
+ self._frame_video_control = frame
+ self.notify_filter_element_change(
+ self._application.track_view_state.filter_element.get()
+ )
+
+ def notify_filter_element_change(self, filter_element: FilterElement) -> None:
+ if not self._frame_video_control:
+ raise MissingInjectedInstanceError("Frame video control missing")
+ filter_element_is_set = (
+ filter_element.date_range.start_date is not None
+ and filter_element.date_range.end_date is not None
+ )
+ self._frame_video_control.set_enabled_general_buttons(filter_element_is_set)
diff --git a/OTAnalytics/plugin_ui/customtkinter_gui/frame_video_control.py b/OTAnalytics/plugin_ui/customtkinter_gui/frame_video_control.py
new file mode 100644
index 000000000..1ab330987
--- /dev/null
+++ b/OTAnalytics/plugin_ui/customtkinter_gui/frame_video_control.py
@@ -0,0 +1,66 @@
+import tkinter
+from typing import Any
+
+from customtkinter import CTkButton, CTkEntry, CTkLabel
+
+from OTAnalytics.adapter_ui.view_model import ViewModel
+from OTAnalytics.plugin_ui.customtkinter_gui.abstract_ctk_frame import AbstractCTkFrame
+from OTAnalytics.plugin_ui.customtkinter_gui.constants import PADX, STICKY
+
+
+class FrameVideoControl(AbstractCTkFrame):
+ def __init__(self, viewmodel: ViewModel, **kwargs: Any) -> None:
+ super().__init__(**kwargs)
+ self._viewmodel = viewmodel
+ self._seconds = tkinter.IntVar(value=viewmodel.get_skip_seconds())
+ self._frames = tkinter.IntVar(value=viewmodel.get_skip_frames())
+ self._get_widgets()
+ self._place_widgets()
+ self._wire_widgets()
+ self.introduce_to_viewmodel()
+
+ def introduce_to_viewmodel(self) -> None:
+ self._viewmodel.set_video_control_frame(self)
+
+ def _get_widgets(self) -> None:
+ self._button_next_frame = CTkButton(
+ master=self,
+ text=">",
+ command=self._viewmodel.next_frame,
+ )
+ self._button_previous_frame = CTkButton(
+ master=self,
+ text="<",
+ command=self._viewmodel.previous_frame,
+ )
+ self._label_seconds = CTkLabel(
+ master=self, text="Seconds", anchor="e", justify="right"
+ )
+ self._label_frames = CTkLabel(
+ master=self, text="Frames", anchor="e", justify="right"
+ )
+ self._entry_seconds = CTkEntry(master=self, textvariable=self._seconds)
+ self._entry_frames = CTkEntry(master=self, textvariable=self._frames)
+
+ def _place_widgets(self) -> None:
+ PADY = 10
+ self._button_previous_frame.grid(
+ row=0, column=1, rowspan=2, padx=PADX, pady=PADY, sticky=STICKY
+ )
+ self._label_seconds.grid(row=0, column=2, padx=PADX, pady=PADY, sticky=STICKY)
+ self._entry_seconds.grid(row=0, column=3, padx=PADX, pady=PADY, sticky=STICKY)
+ self._label_frames.grid(row=1, column=2, padx=PADX, pady=PADY, sticky=STICKY)
+ self._entry_frames.grid(row=1, column=3, padx=PADX, pady=PADY, sticky=STICKY)
+ self._button_next_frame.grid(
+ row=0, column=4, rowspan=2, padx=PADX, pady=PADY, sticky=STICKY
+ )
+
+ def _wire_widgets(self) -> None:
+ self._seconds.trace_add("write", callback=self._update_skip_time)
+ self._frames.trace_add("write", callback=self._update_skip_time)
+
+ def get_general_buttons(self) -> list[CTkButton]:
+ return [self._button_previous_frame, self._button_next_frame]
+
+ def _update_skip_time(self, name: str, other: str, mode: str) -> None:
+ self._viewmodel.update_skip_time(self._seconds.get(), self._frames.get())
diff --git a/OTAnalytics/plugin_ui/customtkinter_gui/gui.py b/OTAnalytics/plugin_ui/customtkinter_gui/gui.py
index 61eec38ae..fdbd7b29c 100644
--- a/OTAnalytics/plugin_ui/customtkinter_gui/gui.py
+++ b/OTAnalytics/plugin_ui/customtkinter_gui/gui.py
@@ -25,6 +25,9 @@
FrameTrackPlotting,
)
from OTAnalytics.plugin_ui.customtkinter_gui.frame_tracks import TracksFrame
+from OTAnalytics.plugin_ui.customtkinter_gui.frame_video_control import (
+ FrameVideoControl,
+)
from OTAnalytics.plugin_ui.customtkinter_gui.frame_videos import FrameVideos
from OTAnalytics.plugin_ui.customtkinter_gui.helpers import get_widget_position
from OTAnalytics.plugin_ui.customtkinter_gui.messagebox import InfoBox
@@ -117,13 +120,20 @@ def __init__(
master=self,
viewmodel=self._viewmodel,
)
+ self._frame_video_control = FrameVideoControl(
+ master=self, viewmodel=self._viewmodel
+ )
self.grid_rowconfigure(0, weight=0)
self.grid_rowconfigure(1, weight=1)
+ self.grid_rowconfigure(2, weight=0)
self.grid_columnconfigure(0, weight=0)
self.grid_columnconfigure(1, weight=1)
self._frame_canvas.grid(row=0, column=0, pady=PADY, sticky=STICKY)
self._frame_track_plotting.grid(row=0, column=1, pady=PADY, sticky=STICKY)
self._frame_filter.grid(row=1, column=0, pady=PADY, sticky=STICKY)
+ self._frame_video_control.grid(
+ row=2, column=0, columnspan=2, pady=PADY, sticky=STICKY
+ )
class FrameNavigation(EmbeddedCTkScrollableFrame):
diff --git a/OTAnalytics/plugin_ui/main_application.py b/OTAnalytics/plugin_ui/main_application.py
index 43c7756ea..b8bf674c8 100644
--- a/OTAnalytics/plugin_ui/main_application.py
+++ b/OTAnalytics/plugin_ui/main_application.py
@@ -46,6 +46,10 @@
TrackViewState,
VideosMetadata,
)
+from OTAnalytics.application.ui.frame_control import (
+ SwitchToNextFrame,
+ SwitchToPreviousFrame,
+)
from OTAnalytics.application.use_cases.clear_repositories import ClearRepositories
from OTAnalytics.application.use_cases.create_events import (
CreateEvents,
@@ -264,10 +268,12 @@ def start_gui(self, run_config: RunConfiguration) -> None:
section_repository.register_section_changed_observer(
clear_all_intersections.on_section_changed
)
+ videos_metadata = VideosMetadata()
layers = self._create_layers(
datastore,
intersection_repository,
track_view_state,
+ videos_metadata,
flow_state,
section_state,
pulling_progressbar_builder,
@@ -289,7 +295,6 @@ def start_gui(self, run_config: RunConfiguration) -> None:
tracks_metadata._classifications.register(
observer=color_palette_provider.update
)
- videos_metadata = VideosMetadata()
action_state = self._create_action_state()
filter_element_settings_restorer = (
self._create_filter_element_setting_restorer()
@@ -388,6 +393,8 @@ def start_gui(self, run_config: RunConfiguration) -> None:
remove_tracks,
remove_section,
)
+ previous_frame = SwitchToPreviousFrame(track_view_state, videos_metadata)
+ next_frame = SwitchToNextFrame(track_view_state, videos_metadata)
application = OTAnalyticsApplication(
datastore,
track_state,
@@ -410,6 +417,8 @@ def start_gui(self, run_config: RunConfiguration) -> None:
start_new_project,
project_updater,
load_track_files,
+ previous_frame,
+ next_frame,
)
section_repository.register_sections_observer(cut_tracks_intersecting_section)
section_repository.register_section_changed_observer(
@@ -450,6 +459,9 @@ def start_gui(self, run_config: RunConfiguration) -> None:
application.action_state.action_running.register(
dummy_viewmodel._notify_action_running_state
)
+ application.track_view_state.filter_element.register(
+ dummy_viewmodel.notify_filter_element_change
+ )
# TODO: Refactor observers - move registering to subjects happening in
# constructor dummy_viewmodel
# cut_tracks_intersecting_section.register(
@@ -630,6 +642,7 @@ def _create_layers(
datastore: Datastore,
intersection_repository: IntersectionRepository,
track_view_state: TrackViewState,
+ videos_metadata: VideosMetadata,
flow_state: FlowState,
section_state: SectionState,
pulling_progressbar_builder: ProgressbarBuilder,
@@ -640,6 +653,7 @@ def _create_layers(
datastore,
intersection_repository,
track_view_state,
+ videos_metadata,
section_state,
color_palette_provider,
pulling_progressbar_builder,
diff --git a/OTAnalytics/plugin_ui/visualization/visualization.py b/OTAnalytics/plugin_ui/visualization/visualization.py
index 29b49ffa1..8a7a302e5 100644
--- a/OTAnalytics/plugin_ui/visualization/visualization.py
+++ b/OTAnalytics/plugin_ui/visualization/visualization.py
@@ -1,3 +1,4 @@
+from datetime import datetime, timezone
from typing import Callable, Optional, Sequence
from pandas import DataFrame
@@ -7,14 +8,18 @@
from OTAnalytics.application.datastore import Datastore
from OTAnalytics.application.plotting import (
CachedPlotter,
+ GetCurrentFrame,
+ GetCurrentVideoPath,
PlottingLayer,
TrackBackgroundPlotter,
+ VisualizationTimeProvider,
)
from OTAnalytics.application.state import (
FlowState,
Plotter,
SectionState,
TrackViewState,
+ VideosMetadata,
)
from OTAnalytics.application.use_cases.highlight_intersections import (
IntersectionRepository,
@@ -40,8 +45,10 @@
ColorPaletteProvider,
EventToFlowResolver,
FilterByClassification,
+ FilterByFrame,
FilterById,
FilterByOccurrence,
+ FilterByVideo,
FlowLayerPlotter,
MatplotlibTrackPlotter,
PandasDataFrameProvider,
@@ -49,10 +56,34 @@
PandasTracksOffsetProvider,
PlotterPrototype,
SectionLayerPlotter,
+ TrackBoundingBoxPlotter,
TrackGeometryPlotter,
+ TrackPointPlotter,
TrackStartEndPointPlotter,
)
+LONG_IN_THE_PAST = datetime(
+ year=1970,
+ month=1,
+ day=1,
+ hour=0,
+ minute=0,
+ second=0,
+ tzinfo=timezone.utc,
+)
+ALPHA_BOUNDING_BOX = 0.5
+
+
+class FilterEndDateProvider(VisualizationTimeProvider):
+ def __init__(self, state: TrackViewState) -> None:
+ self._state = state
+
+ def get_time(self) -> datetime:
+ if end_date := self._state.filter_element.get().date_range.end_date:
+ return end_date
+ return LONG_IN_THE_PAST
+
+
ALPHA_ALL_TRACKS_PLOTTER = 0.5
ALPHA_HIGHLIGHT_TRACKS = 1
ALPHA_HIGHLIGHT_TRACKS_NOT_ASSIGNED_TO_FLOWS = ALPHA_HIGHLIGHT_TRACKS
@@ -99,6 +130,7 @@ def __init__(
datastore: Datastore,
intersection_repository: IntersectionRepository,
track_view_state: TrackViewState,
+ videos_metadata: VideosMetadata,
section_state: SectionState,
color_palette_provider: ColorPaletteProvider,
pulling_progressbar_builder: ProgressbarBuilder,
@@ -113,9 +145,21 @@ def __init__(
self._flow_repository = datastore._flow_repository
self._intersection_repository = intersection_repository
self._event_repository = datastore._event_repository
+ self._get_current_frame = GetCurrentFrame(track_view_state, videos_metadata)
+ self._get_current_video = GetCurrentVideoPath(track_view_state, videos_metadata)
+ self._visualization_time_provider: VisualizationTimeProvider = (
+ FilterEndDateProvider(track_view_state)
+ )
self._pandas_data_provider: Optional[PandasDataFrameProvider] = None
+ self._pandas_data_provider_with_offset: Optional[PandasDataFrameProvider] = None
self._data_provider_all_filters: Optional[PandasDataFrameProvider] = None
+ self._data_provider_all_filters_with_offset: Optional[
+ PandasDataFrameProvider
+ ] = None
self._data_provider_class_filter: Optional[PandasDataFrameProvider] = None
+ self._data_provider_class_filter_with_offset: Optional[
+ PandasDataFrameProvider
+ ] = None
self._tracks_intersection_selected_sections: Optional[
TracksIntersectingSelectedSections
] = None
@@ -128,7 +172,10 @@ def build(
flow_state: FlowState,
road_user_assigner: RoadUserAssigner,
) -> Sequence[PlottingLayer]:
- background_image_plotter = TrackBackgroundPlotter(self._track_view_state)
+ background_image_plotter = TrackBackgroundPlotter(
+ self._track_view_state.selected_videos.get,
+ self._visualization_time_provider,
+ )
all_tracks_plotter = self._create_all_tracks_plotter()
highlight_tracks_assigned_to_flows_plotter = (
self._create_highlight_tracks_assigned_to_flows_plotter(
@@ -140,6 +187,10 @@ def build(
road_user_assigner, flow_state
)
)
+
+ track_bounding_box_plotter = self._create_track_bounding_box_plotter()
+ track_point_plotter = self._create_track_point_plotter()
+
layer_definitions = [
("Background", background_image_plotter, True),
("Show all tracks", all_tracks_plotter, False),
@@ -178,6 +229,16 @@ def build(
highlight_tracks_not_assigned_to_flows_plotter,
False,
),
+ (
+ "Show bounding boxes of current frame",
+ track_bounding_box_plotter,
+ False,
+ ),
+ (
+ "Show track point of bounding boxes of current frame",
+ track_point_plotter,
+ False,
+ ),
]
return [
@@ -187,7 +248,7 @@ def build(
def _create_all_tracks_plotter(self) -> Plotter:
track_geometry_plotter = self._create_track_geometry_plotter(
- self._get_data_provider_all_filters(),
+ self._get_data_provider_all_filters_with_offset(),
self._color_palette_provider,
alpha=ALPHA_ALL_TRACKS_PLOTTER,
enable_legend=True,
@@ -221,7 +282,7 @@ def _create_start_end_point_intersecting_sections_plotter(self) -> Plotter:
start_end_points_intersecting = self._create_cached_section_layer_plotter(
self._create_start_end_point_intersecting_section_factory(
self._create_tracks_start_end_point_intersecting_given_sections_filter(
- self._get_data_provider_class_filter(),
+ self._get_data_provider_class_filter_with_offset(),
self._create_tracks_intersecting_sections(),
self._create_get_sections_by_id(),
),
@@ -235,7 +296,7 @@ def _create_start_end_point_intersecting_sections_plotter(self) -> Plotter:
def _create_start_end_point_not_intersection_sections_plotter(self) -> Plotter:
section_filter = (
self._create_tracks_start_end_point_not_intersecting_given_sections_filter(
- self._get_data_provider_class_filter(),
+ self._get_data_provider_class_filter_with_offset(),
self._create_tracks_intersecting_sections(),
self._create_get_sections_by_id(),
)
@@ -252,7 +313,7 @@ def _create_start_end_point_not_intersection_sections_plotter(self) -> Plotter:
def _create_start_end_point_plotter(self) -> Plotter:
track_start_end_point_plotter = self._create_track_start_end_point_plotter(
self._create_track_start_end_point_data_provider(
- self._get_data_provider_class_filter()
+ self._get_data_provider_class_filter_with_offset()
),
self._color_palette_provider,
enable_legend=False,
@@ -268,7 +329,8 @@ def _create_highlight_tracks_assigned_to_flows_plotter(
return self._create_highlight_tracks_assigned_to_flow(
self._create_highlight_tracks_assigned_to_flows_factory(
self._create_tracks_assigned_to_flows_filter(
- self._get_data_provider_all_filters(), road_user_assigner
+ self._get_data_provider_all_filters_with_offset(),
+ road_user_assigner,
),
self._color_palette_provider,
alpha=ALPHA_HIGHLIGHT_TRACKS_ASSIGNED_TO_FLOWS,
@@ -281,7 +343,7 @@ def _create_highlight_tracks_not_assigned_to_flows_plotter(
self, road_user_assigner: RoadUserAssigner, flow_state: FlowState
) -> Plotter:
flows_filter = TracksNotAssignedToSelection(
- self._get_data_provider_all_filters(),
+ self._get_data_provider_all_filters_with_offset(),
road_user_assigner,
self._event_repository,
self._flow_repository,
@@ -302,16 +364,41 @@ def _get_data_provider_class_filter(self) -> PandasDataFrameProvider:
)
return self._data_provider_class_filter
+ def _get_data_provider_class_filter_with_offset(self) -> PandasDataFrameProvider:
+ if not self._data_provider_class_filter_with_offset:
+ self._data_provider_class_filter_with_offset = (
+ self._build_filter_by_classification(
+ self._get_pandas_data_provider_with_offset()
+ )
+ )
+ return self._data_provider_class_filter_with_offset
+
def _get_data_provider_all_filters(self) -> PandasDataFrameProvider:
if not self._data_provider_all_filters:
- self._data_provider_all_filters = self._build_filter_by_classification(
- self._create_filter_by_occurrence()
+ self._data_provider_all_filters = self._create_all_filters(
+ self._get_pandas_data_provider()
)
return self._data_provider_all_filters
- def _create_filter_by_occurrence(self) -> PandasDataFrameProvider:
+ def _get_data_provider_all_filters_with_offset(self) -> PandasDataFrameProvider:
+ if not self._data_provider_all_filters_with_offset:
+ self._data_provider_all_filters_with_offset = self._create_all_filters(
+ self._get_pandas_data_provider_with_offset()
+ )
+ return self._data_provider_all_filters_with_offset
+
+ def _create_all_filters(
+ self, data_provider: PandasDataFrameProvider
+ ) -> PandasDataFrameProvider:
+ return self._build_filter_by_classification(
+ self._create_filter_by_occurrence(data_provider)
+ )
+
+ def _create_filter_by_occurrence(
+ self, data_provider: PandasDataFrameProvider
+ ) -> PandasDataFrameProvider:
return FilterByOccurrence(
- self._get_pandas_data_provider(),
+ data_provider,
self._track_view_state,
self._create_dataframe_filter_builder(),
)
@@ -323,7 +410,7 @@ def _get_tracks_not_intersecting_selected_sections_filter(
self,
) -> Callable[[SectionId], PandasDataFrameProvider]:
return lambda section: FilterById(
- self._get_data_provider_all_filters(),
+ self._get_data_provider_all_filters_with_offset(),
TracksNotIntersectingSelection(
TracksIntersectingGivenSections(
{section},
@@ -344,15 +431,14 @@ def _build_filter_by_classification(
self._create_dataframe_filter_builder(),
)
- def _get_pandas_data_provider(self) -> PandasDataFrameProvider:
- if not self._pandas_data_provider:
- cached_pandas_track_provider = self._create_pandas_track_provider(
- self._pulling_progressbar_builder
+ def _get_pandas_data_provider_with_offset(self) -> PandasDataFrameProvider:
+ if not self._pandas_data_provider_with_offset:
+ self._pandas_data_provider_with_offset = (
+ self._wrap_pandas_track_offset_provider(
+ self._get_pandas_data_provider()
+ )
)
- self._pandas_data_provider = self._wrap_pandas_track_offset_provider(
- cached_pandas_track_provider
- )
- return self._pandas_data_provider
+ return self._pandas_data_provider_with_offset
def _wrap_plotter_with_cache(self, other: Plotter) -> Plotter:
"""
@@ -365,22 +451,15 @@ def _wrap_plotter_with_cache(self, other: Plotter) -> Plotter:
self._track_view_state.track_offset.register(invalidate)
return cached_plotter
- def _create_pandas_track_provider(
- self, progressbar: ProgressbarBuilder
- ) -> PandasTrackProvider:
+ def _get_pandas_data_provider(self) -> PandasDataFrameProvider:
dataframe_filter_builder = self._create_dataframe_filter_builder()
- return PandasTrackProvider(
- self._track_repository,
- self._track_view_state,
- dataframe_filter_builder,
- progressbar,
- )
- # return CachedPandasTrackProvider(
- # self._track_repository,
- # self._track_view_state,
- # dataframe_filter_builder,
- # progressbar,
- # )
+ if not self._pandas_data_provider:
+ self._pandas_data_provider = PandasTrackProvider(
+ self._track_repository,
+ dataframe_filter_builder,
+ self._pulling_progressbar_builder,
+ )
+ return self._pandas_data_provider
def _wrap_pandas_track_offset_provider(
self, other: PandasDataFrameProvider
@@ -437,7 +516,7 @@ def _get_tracks_intersecting_sections_filter(
self,
) -> Callable[[SectionId], PandasDataFrameProvider]:
return lambda section: FilterById(
- self._get_data_provider_all_filters(),
+ self._get_data_provider_all_filters_with_offset(),
TracksIntersectingGivenSections(
{section},
self._create_tracks_intersecting_sections(),
@@ -612,3 +691,38 @@ def _create_tracks_intersecting_sections(self) -> TracksIntersectingSections:
return SimpleTracksIntersectingSections(
GetAllTracks(self._track_repository),
)
+
+ def _create_track_bounding_box_plotter(
+ self,
+ ) -> Plotter:
+ track_plotter = MatplotlibTrackPlotter(
+ TrackBoundingBoxPlotter(
+ FilterByFrame(
+ FilterByVideo(
+ self._get_data_provider_class_filter(),
+ self._get_current_video,
+ ),
+ self._get_current_frame,
+ ),
+ self._color_palette_provider,
+ self._track_view_state,
+ alpha=ALPHA_BOUNDING_BOX,
+ ),
+ )
+ return PlotterPrototype(self._track_view_state, track_plotter)
+
+ def _create_track_point_plotter(self) -> Plotter:
+ track_plotter = MatplotlibTrackPlotter(
+ TrackPointPlotter(
+ FilterByFrame(
+ FilterByVideo(
+ self._get_data_provider_class_filter_with_offset(),
+ self._get_current_video,
+ ),
+ self._get_current_frame,
+ ),
+ self._color_palette_provider,
+ alpha=ALPHA_BOUNDING_BOX,
+ ),
+ )
+ return PlotterPrototype(self._track_view_state, track_plotter)
diff --git a/OTAnalytics/plugin_video_processing/video_reader.py b/OTAnalytics/plugin_video_processing/video_reader.py
index d7d22aadc..02c3dfc2f 100644
--- a/OTAnalytics/plugin_video_processing/video_reader.py
+++ b/OTAnalytics/plugin_video_processing/video_reader.py
@@ -1,3 +1,5 @@
+from datetime import timedelta
+from math import floor
from pathlib import Path
import cv2
@@ -19,6 +21,9 @@ class FrameDoesNotExistError(Exception):
class OpenCvVideoReader(VideoReader):
+ def get_fps(self, video_path: Path) -> float:
+ return self.__get_clip(video_path).get(cv2.CAP_PROP_FPS)
+
def get_frame(self, video_path: Path, index: int) -> TrackImage:
"""Get image of video at `frame`.
Args:
@@ -44,3 +49,6 @@ def __get_clip(video_path: Path) -> VideoCapture:
return VideoCapture(str(video_path.absolute()))
except IOError as e:
raise InvalidVideoError(f"{str(video_path)} is not a valid video") from e
+
+ def get_frame_number_for(self, video_path: Path, delta: timedelta) -> int:
+ return floor(self.get_fps(video_path) * delta.total_seconds())
diff --git a/tests/OTAnalytics/application/test_datastore.py b/tests/OTAnalytics/application/test_datastore.py
index 546754e12..83fccd6c5 100644
--- a/tests/OTAnalytics/application/test_datastore.py
+++ b/tests/OTAnalytics/application/test_datastore.py
@@ -46,6 +46,9 @@
class MockVideoReader(VideoReader):
+ def get_fps(self, video: Path) -> float:
+ return 20
+
def get_frame(self, video: Path, index: int) -> TrackImage:
del video
del index
@@ -62,32 +65,40 @@ def height(self) -> int:
return MockImage()
+ def get_frame_number_for(self, video_path: Path, date: timedelta) -> int:
+ return 0
+
class TestVideoMetadata:
def test_fully_specified_metadata(self) -> None:
+ recorded_fps = 20.0
+ actual_fps = 20.0
metadata = VideoMetadata(
path="video_path_1.mp4",
recorded_start_date=FIRST_START_DATE,
expected_duration=timedelta(seconds=3),
- recorded_fps=20.0,
- actual_fps=20.0,
+ recorded_fps=recorded_fps,
+ actual_fps=actual_fps,
number_of_frames=60,
)
assert metadata.start == FIRST_START_DATE
assert metadata.end == FIRST_START_DATE + timedelta(seconds=3)
+ assert metadata.fps == actual_fps
def test_partially_specified_metadata(self) -> None:
+ recorded_fps = 20.0
metadata = VideoMetadata(
path="video_path_1.mp4",
recorded_start_date=FIRST_START_DATE,
expected_duration=None,
- recorded_fps=20.0,
+ recorded_fps=recorded_fps,
actual_fps=None,
number_of_frames=60,
)
assert metadata.start == FIRST_START_DATE
expected_video_end = FIRST_START_DATE + timedelta(seconds=3)
assert metadata.end == expected_video_end
+ assert metadata.fps == recorded_fps
class TestSimpleVideo:
@@ -95,15 +106,15 @@ class TestSimpleVideo:
def test_raise_error_if_file_not_exists(self) -> None:
with pytest.raises(ValueError):
- SimpleVideo(video_reader=self.video_reader, path=Path("foo/bar.mp4"))
+ SimpleVideo(self.video_reader, Path("foo/bar.mp4"), FIRST_START_DATE)
def test_init_with_valid_args(self, cyclist_video: Path) -> None:
- video = SimpleVideo(video_reader=self.video_reader, path=cyclist_video)
+ video = SimpleVideo(self.video_reader, cyclist_video, FIRST_START_DATE)
assert video.path == cyclist_video
assert video.video_reader == self.video_reader
def test_get_frame_return_correct_image(self, cyclist_video: Path) -> None:
- video = SimpleVideo(video_reader=self.video_reader, path=cyclist_video)
+ video = SimpleVideo(self.video_reader, cyclist_video, FIRST_START_DATE)
assert video.get_frame(0).as_image() == Image.fromarray(
array([[1, 0], [0, 1]], int32)
)
diff --git a/tests/OTAnalytics/application/test_plotting.py b/tests/OTAnalytics/application/test_plotting.py
index b2289c141..27927b0fc 100644
--- a/tests/OTAnalytics/application/test_plotting.py
+++ b/tests/OTAnalytics/application/test_plotting.py
@@ -1,13 +1,21 @@
+from datetime import datetime, timedelta
from unittest.mock import Mock, call
import pytest
+from OTAnalytics.application.datastore import VideoMetadata
from OTAnalytics.application.plotting import (
+ GetCurrentFrame,
+ GetCurrentVideoPath,
LayeredPlotter,
PlottingLayer,
TrackBackgroundPlotter,
+ VideoProvider,
+ VisualizationTimeProvider,
)
-from OTAnalytics.application.state import Plotter, TrackViewState
+from OTAnalytics.application.state import Plotter, TrackViewState, VideosMetadata
+from OTAnalytics.domain.date import DateRange
+from OTAnalytics.domain.filter import FilterElement
from OTAnalytics.domain.track import TrackImage
from OTAnalytics.domain.video import Video
@@ -75,27 +83,99 @@ def test_plot(self, plotter: Mock) -> None:
class TestBackgroundPlotter:
def test_plot(self) -> None:
- expected_image = Mock(spec=TrackImage)
- video = Mock(spec=Video)
- video.get_frame.return_value = expected_image
- track_view_state = Mock(spec=TrackViewState)
- track_view_state.selected_videos = Mock()
- track_view_state.selected_videos.get.return_value = [video]
-
- background_plotter = TrackBackgroundPlotter(track_view_state)
+ expected_image = Mock()
+ single_video = Mock(spec=Video)
+ frame_number = 0
+ single_video.get_frame_number_for.return_value = frame_number
+ single_video.get_frame.return_value = expected_image
+ videos: list[Video] = [single_video]
+ video_provider = Mock(spec=VideoProvider)
+ video_provider.return_value = videos
+ some_time = datetime(2023, 1, 1, 0, 0)
+ visualization_time_provider = Mock(spec=VisualizationTimeProvider)
+ visualization_time_provider.get_time.return_value = some_time
+
+ background_plotter = TrackBackgroundPlotter(
+ video_provider=video_provider,
+ visualization_time_provider=visualization_time_provider,
+ )
result = background_plotter.plot()
- track_view_state.selected_videos.get.assert_called_once()
- video.get_frame.assert_called_once()
+ video_provider.assert_called_once()
+ visualization_time_provider.get_time.assert_called_once()
+ single_video.get_frame_number_for.assert_called_with(some_time)
+ single_video.get_frame.assert_called_once_with(frame_number)
assert result is not None
assert result == expected_image
def test_plot_empty_track_repository_returns_none(self) -> None:
- track_view_state = Mock(spec=TrackViewState)
- track_view_state.selected_videos = Mock()
- track_view_state.selected_videos.get.return_value = []
- background_plotter = TrackBackgroundPlotter(track_view_state)
+ videos: list[Video] = []
+ video_provider = Mock(spec=VideoProvider)
+ video_provider.return_value = videos
+ visualization_time_provider = Mock(spec=VisualizationTimeProvider)
+ background_plotter = TrackBackgroundPlotter(
+ video_provider=video_provider,
+ visualization_time_provider=visualization_time_provider,
+ )
result = background_plotter.plot()
- track_view_state.selected_videos.get.assert_called_once()
+ video_provider.assert_called_once()
+ visualization_time_provider.get_time.assert_not_called()
assert result is None
+
+
+class TestGetCurrentVideoPath:
+ def test_get_video(self) -> None:
+ filter_end_date = datetime(2023, 1, 1, 0, 1)
+ mocked_filter_element = FilterElement(
+ DateRange(start_date=None, end_date=filter_end_date), classifications=set()
+ )
+ video_path = "some/path"
+ state = TrackViewState()
+ state.filter_element.set(mocked_filter_element)
+ metadata = Mock(spec=VideoMetadata)
+ metadata.path = video_path
+ videos_metadata = Mock(spec=VideosMetadata)
+ videos_metadata.get_metadata_for.return_value = metadata
+ use_case = GetCurrentVideoPath(state, videos_metadata)
+
+ actual = use_case.get_video()
+
+ assert actual == video_path
+
+
+class TestGetCurrentFrame:
+ @pytest.mark.parametrize(
+ "filter_end_date, expected_frame_number",
+ [
+ (datetime(2023, 1, 1, 0, 1), 0),
+ (datetime(2023, 1, 1, 0, 1, 1), 20),
+ (datetime(2023, 1, 1, 0, 1, 3), 60),
+ (datetime(2023, 1, 1, 0, 1, 4), 60),
+ ],
+ )
+ def test_get_frame_number(
+ self,
+ filter_end_date: datetime,
+ expected_frame_number: int,
+ ) -> None:
+ video_start_date = datetime(2023, 1, 1, 0, 1)
+ mocked_filter_element = FilterElement(
+ DateRange(start_date=None, end_date=filter_end_date), classifications=set()
+ )
+ state = TrackViewState()
+ state.filter_element.set(mocked_filter_element)
+ metadata = Mock(spec=VideoMetadata)
+ metadata.start = video_start_date
+ metadata.duration = timedelta(seconds=3)
+ metadata.fps = 20
+ metadata.number_of_frames = 60
+ videos_metadata = Mock(spec=VideosMetadata)
+ videos_metadata.get_metadata_for.return_value = metadata
+ use_case = GetCurrentFrame(state, videos_metadata)
+
+ frame_number = use_case.get_frame_number()
+
+ assert frame_number == expected_frame_number
+
+ videos_metadata.get_metadata_for.assert_called_with(filter_end_date)
diff --git a/tests/OTAnalytics/application/test_state.py b/tests/OTAnalytics/application/test_state.py
index 2b0332287..a8620d41f 100644
--- a/tests/OTAnalytics/application/test_state.py
+++ b/tests/OTAnalytics/application/test_state.py
@@ -215,40 +215,43 @@ def test_update_image(self) -> None:
plotter.plot.assert_called_once()
-class TestVideosMetadata:
- @pytest.fixture
- def first_full_metadata(self) -> VideoMetadata:
- return VideoMetadata(
- path="video_path_1.mp4",
- recorded_start_date=FIRST_START_DATE,
- expected_duration=timedelta(seconds=3),
- recorded_fps=20.0,
- actual_fps=20.0,
- number_of_frames=60,
- )
-
- @pytest.fixture
- def second_full_metadata(self) -> VideoMetadata:
- return VideoMetadata(
- path="video_path_2.mp4",
- recorded_start_date=SECOND_START_DATE,
- expected_duration=timedelta(seconds=3),
- recorded_fps=20.0,
- actual_fps=20.0,
- number_of_frames=60,
- )
+@pytest.fixture
+def first_full_metadata() -> VideoMetadata:
+ return VideoMetadata(
+ path="video_path_1.mp4",
+ recorded_start_date=FIRST_START_DATE,
+ expected_duration=timedelta(seconds=3),
+ recorded_fps=20.0,
+ actual_fps=20.0,
+ number_of_frames=60,
+ )
+
+
+@pytest.fixture
+def second_full_metadata() -> VideoMetadata:
+ return VideoMetadata(
+ path="video_path_2.mp4",
+ recorded_start_date=SECOND_START_DATE,
+ expected_duration=timedelta(seconds=3),
+ recorded_fps=20.0,
+ actual_fps=20.0,
+ number_of_frames=60,
+ )
+
+
+@pytest.fixture
+def first_partial_metadata() -> VideoMetadata:
+ return VideoMetadata(
+ path="video_path_1.mp4",
+ recorded_start_date=FIRST_START_DATE,
+ expected_duration=None,
+ recorded_fps=20.0,
+ actual_fps=None,
+ number_of_frames=60,
+ )
- @pytest.fixture
- def first_partial_metadata(self) -> VideoMetadata:
- return VideoMetadata(
- path="video_path_1.mp4",
- recorded_start_date=FIRST_START_DATE,
- expected_duration=None,
- recorded_fps=20.0,
- actual_fps=None,
- number_of_frames=60,
- )
+class TestVideosMetadata:
def test_nothing_updated(self) -> None:
videos_metadata = VideosMetadata()
@@ -311,6 +314,51 @@ def test_ensure_order(
seconds=3
)
+ def test_add_metadata_with_same_start_date_fails(
+ self, first_full_metadata: VideoMetadata, first_partial_metadata: VideoMetadata
+ ) -> None:
+ videos_metadata = VideosMetadata()
+
+ videos_metadata.update(first_full_metadata)
+ with pytest.raises(ValueError):
+ videos_metadata.update(first_partial_metadata)
+
+ def test_get_metadata_for_date(
+ self,
+ first_full_metadata: VideoMetadata,
+ second_full_metadata: VideoMetadata,
+ ) -> None:
+ metadata = VideosMetadata()
+
+ metadata.update(first_full_metadata)
+ metadata.update(second_full_metadata)
+
+ exact_result = metadata.get_metadata_for(first_full_metadata.start)
+ floored_result = metadata.get_metadata_for(
+ first_full_metadata.start + timedelta(seconds=1)
+ )
+ floored_second_result = metadata.get_metadata_for(
+ second_full_metadata.start + timedelta(seconds=1)
+ )
+
+ assert exact_result == first_full_metadata
+ assert floored_result == first_full_metadata
+ assert floored_second_result == second_full_metadata
+
+ def test_get_metadata_for_too_late_date(
+ self,
+ first_full_metadata: VideoMetadata,
+ ) -> None:
+ metadata = VideosMetadata()
+
+ metadata.update(first_full_metadata)
+
+ result = metadata.get_metadata_for(
+ first_full_metadata.start + timedelta(seconds=4)
+ )
+
+ assert result is None
+
class TestTracksMetadata:
@pytest.fixture
diff --git a/tests/OTAnalytics/application/ui/__init__.py b/tests/OTAnalytics/application/ui/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/OTAnalytics/application/ui/test_frame_control.py b/tests/OTAnalytics/application/ui/test_frame_control.py
new file mode 100644
index 000000000..eda3f31c6
--- /dev/null
+++ b/tests/OTAnalytics/application/ui/test_frame_control.py
@@ -0,0 +1,105 @@
+from datetime import datetime, timedelta
+from unittest.mock import Mock
+
+import pytest
+
+from OTAnalytics.application.datastore import VideoMetadata
+from OTAnalytics.application.playback import SkipTime
+from OTAnalytics.application.state import (
+ ObservableProperty,
+ TrackViewState,
+ VideosMetadata,
+)
+from OTAnalytics.application.ui.frame_control import (
+ SwitchToNextFrame,
+ SwitchToPreviousFrame,
+)
+from OTAnalytics.domain.date import DateRange
+from OTAnalytics.domain.filter import FilterElement
+
+FPS = 1
+TIME_OF_A_FRAME = timedelta(seconds=1) / FPS
+START_DATE = datetime(2023, 1, 1, 0, 0, 0)
+END_DATE = datetime(2023, 1, 1, 0, 0, 1)
+
+
+def observable(value: Mock) -> Mock:
+ observable_property = Mock(spec=ObservableProperty)
+ observable_property.get.return_value = value
+ return observable_property
+
+
+@pytest.fixture
+def filter_element() -> Mock:
+ filter_element = Mock(spec=FilterElement)
+ filter_element.date_range = DateRange(START_DATE, END_DATE)
+ return filter_element
+
+
+@pytest.fixture
+def skip_time() -> Mock:
+ skip_time = Mock(spec=SkipTime)
+ skip_time.seconds = 0
+ skip_time.frames = 1
+ return skip_time
+
+
+@pytest.fixture
+def track_view_state(filter_element: Mock, skip_time: Mock) -> Mock:
+ track_view_state = Mock(spec=TrackViewState)
+ track_view_state.filter_element = observable(filter_element)
+ track_view_state.skip_time = observable(skip_time)
+ return track_view_state
+
+
+@pytest.fixture
+def videos_metadata() -> VideosMetadata:
+ metadata = Mock(spec=VideoMetadata)
+ metadata.fps = FPS
+ videos_metadata = Mock(spec=VideosMetadata)
+ videos_metadata.get_metadata_for.return_value = metadata
+ return videos_metadata
+
+
+class TestSwitchToNextFrame:
+ def test_set_next_frame(
+ self,
+ track_view_state: Mock,
+ videos_metadata: Mock,
+ filter_element: Mock,
+ ) -> None:
+ derived_filter_element = Mock(spec=FilterElement)
+ filter_element.derive_date.return_value = derived_filter_element
+
+ new_date_range = DateRange(
+ START_DATE + TIME_OF_A_FRAME, END_DATE + TIME_OF_A_FRAME
+ )
+ use_case = SwitchToNextFrame(track_view_state, videos_metadata)
+
+ use_case.set_next_frame()
+
+ filter_element.derive_date.assert_called_with(new_date_range)
+ track_view_state.filter_element.set.assert_called_with(derived_filter_element)
+ videos_metadata.get_metadata_for.assert_called_with(END_DATE)
+
+
+class TestSwitchToPreviousFrame:
+ def test_set_next_frame(
+ self,
+ track_view_state: Mock,
+ videos_metadata: Mock,
+ filter_element: Mock,
+ ) -> None:
+ derived_filter_element = Mock(spec=FilterElement)
+ filter_element.derive_date.return_value = derived_filter_element
+
+ new_date_range = DateRange(
+ START_DATE - TIME_OF_A_FRAME, END_DATE - TIME_OF_A_FRAME
+ )
+ use_case = SwitchToPreviousFrame(track_view_state, videos_metadata)
+
+ use_case.set_previous_frame()
+
+ filter_element.derive_date.assert_called_with(new_date_range)
+ track_view_state.filter_element.set.assert_called_with(derived_filter_element)
+ videos_metadata.get_metadata_for.assert_called_with(END_DATE)
diff --git a/tests/OTAnalytics/application/use_cases/test_load_track_files.py b/tests/OTAnalytics/application/use_cases/test_load_track_files.py
index 2cefb4cb8..a15808c7c 100644
--- a/tests/OTAnalytics/application/use_cases/test_load_track_files.py
+++ b/tests/OTAnalytics/application/use_cases/test_load_track_files.py
@@ -1,3 +1,4 @@
+from datetime import datetime
from pathlib import Path
from unittest.mock import MagicMock, Mock, call, patch
@@ -5,6 +6,8 @@
from OTAnalytics.domain.track import TrackId
from OTAnalytics.domain.video import SimpleVideo
+START_DATE = datetime(2023, 1, 1)
+
class TestLoadTrackFile:
@patch("OTAnalytics.application.use_cases.load_track_files.LoadTrackFiles.load")
@@ -34,7 +37,7 @@ def test_load(self) -> None:
some_track = Mock()
some_track_id = TrackId("1")
some_track.id = some_track_id
- some_video = SimpleVideo(video_reader=Mock(), path=Path(""))
+ some_video = SimpleVideo(Mock(), Path(""), START_DATE)
detection_metadata = Mock()
parse_result = Mock()
parse_result.tracks = [some_track]
diff --git a/tests/OTAnalytics/domain/test_video.py b/tests/OTAnalytics/domain/test_video.py
index f806c9967..68cfb0b72 100644
--- a/tests/OTAnalytics/domain/test_video.py
+++ b/tests/OTAnalytics/domain/test_video.py
@@ -1,3 +1,4 @@
+from datetime import datetime
from pathlib import Path
from unittest.mock import Mock, call, patch
@@ -13,6 +14,8 @@
VideoRepository,
)
+START_DATE = datetime(2023, 1, 1)
+
@pytest.fixture
def video_reader() -> Mock:
@@ -30,7 +33,7 @@ def test_resolve_relative_paths(
config_path.parent.mkdir(parents=True)
video_path.touch()
config_path.touch()
- video = SimpleVideo(path=video_path, video_reader=video_reader)
+ video = SimpleVideo(video_reader, video_path, START_DATE)
result = video.to_dict(config_path)
@@ -41,7 +44,7 @@ def test_resolve_relative_paths_on_different_drives(
) -> None:
video_path = Mock(spec=Path)
config_path = Mock(spec=Path)
- video = SimpleVideo(path=video_path, video_reader=video_reader)
+ video = SimpleVideo(video_reader, video_path, START_DATE)
with patch(
"OTAnalytics.domain.video.splitdrive",
@@ -70,7 +73,7 @@ def test_remove(self, video_reader: VideoReader, test_data_tmp_dir: Path) -> Non
observer = Mock(spec=VideoListObserver)
path = test_data_tmp_dir / "dummy.mp4"
path.touch()
- video = SimpleVideo(video_reader, path)
+ video = SimpleVideo(video_reader, path, START_DATE)
repository = VideoRepository()
repository.register_videos_observer(observer)
diff --git a/tests/OTAnalytics/plugin_parser/test_otvision_parser.py b/tests/OTAnalytics/plugin_parser/test_otvision_parser.py
index 7dde7eee6..ed3018171 100644
--- a/tests/OTAnalytics/plugin_parser/test_otvision_parser.py
+++ b/tests/OTAnalytics/plugin_parser/test_otvision_parser.py
@@ -758,7 +758,7 @@ def test_parse_to_cached_video(self, test_data_tmp_dir: Path) -> None:
cached_parser = CachedVideoParser(video_parser)
- parsed_video = cached_parser.parse(video_file)
+ parsed_video = cached_parser.parse(video_file, start_date=None)
assert isinstance(parsed_video, CachedVideo)
assert parsed_video.other == video
diff --git a/tests/OTAnalytics/plugin_prototypes/track_visualization/test_track_viz.py b/tests/OTAnalytics/plugin_prototypes/track_visualization/test_track_viz.py
index f2b793af7..189997206 100644
--- a/tests/OTAnalytics/plugin_prototypes/track_visualization/test_track_viz.py
+++ b/tests/OTAnalytics/plugin_prototypes/track_visualization/test_track_viz.py
@@ -1,4 +1,4 @@
-from datetime import datetime
+from datetime import datetime, timedelta, timezone
from unittest.mock import Mock, patch
import pytest
@@ -96,12 +96,10 @@ class TestPandasTrackProvider:
def test_get_data_empty_track_repository(self) -> None:
track_repository = Mock(spec=TrackRepository)
track_repository.get_all.return_value = PythonTrackDataset.from_list([])
- track_view_state = Mock(spec=TrackViewState).return_value
- track_view_state.track_offset.get.return_value = RelativeOffsetCoordinate(0, 0)
filter_builder = Mock(FilterBuilder)
provider = PandasTrackProvider(
- track_repository, track_view_state, filter_builder, NoProgressbarBuilder()
+ track_repository, filter_builder, NoProgressbarBuilder()
)
result = provider.get_data()
@@ -122,6 +120,11 @@ def track_2(self) -> Track:
def set_up_track(self, id: str) -> Track:
"""Create a dummy track with the given id and 5 car detections."""
+ first_detection_occurrence = datetime(2020, 1, 1, 0, 0, tzinfo=timezone.utc)
+ second_occurrence = first_detection_occurrence + timedelta(seconds=1)
+ third_occurrence = second_occurrence + timedelta(seconds=1)
+ fourth_occurrence = third_occurrence + timedelta(seconds=1)
+        fifth_occurrence = fourth_occurrence + timedelta(seconds=1)
         t_id = TrackId(id)
         detections: list[Detection] = [
             PythonDetection(
@@ -132,7 +135,7 @@ def set_up_track(self, id: str) -> Track:
                 2,
                 7,
                 1,
-                datetime.min,
+                first_detection_occurrence,
                 False,
                 t_id,
                 "video_name",
@@ -145,7 +148,7 @@ def set_up_track(self, id: str) -> Track:
                 2,
                 7,
                 2,
-                datetime.min,
+                second_occurrence,
                 False,
                 t_id,
                 "video_name",
@@ -158,7 +161,7 @@ def set_up_track(self, id: str) -> Track:
                 2,
                 7,
                 3,
-                datetime.min,
+                third_occurrence,
                 False,
                 t_id,
                 "video_name",
@@ -171,7 +174,7 @@ def set_up_track(self, id: str) -> Track:
                 2,
                 7,
                 4,
-                datetime.min,
+                fourth_occurrence,
                 False,
                 t_id,
                 "video_name",
@@ -184,7 +187,7 @@ def set_up_track(self, id: str) -> Track:
                 2,
                 7,
                 5,
-                datetime.min,
+                fifth_occurrence,
False,
t_id,
"video_name",
@@ -203,11 +206,9 @@ def set_up_provider(
track_repository = Mock(spec=TrackRepository)
track_repository.get_for.side_effect = query_tracks
- track_view_state = Mock(spec=TrackViewState).return_value
- track_view_state.track_offset.get.return_value = RelativeOffsetCoordinate(0, 0)
filter_builder = Mock(spec=FilterBuilder)
provider = CachedPandasTrackProvider(
- track_repository, track_view_state, filter_builder, NoProgressbarBuilder()
+ track_repository, filter_builder, NoProgressbarBuilder()
)
assert provider._cache_df.empty
diff --git a/tests/OTAnalytics/plugin_video_processing/test_video_reader.py b/tests/OTAnalytics/plugin_video_processing/test_video_reader.py
index 51b1c1f7a..2507833f6 100644
--- a/tests/OTAnalytics/plugin_video_processing/test_video_reader.py
+++ b/tests/OTAnalytics/plugin_video_processing/test_video_reader.py
@@ -1,3 +1,4 @@
+from datetime import timedelta
from pathlib import Path
from OTAnalytics.domain.track import PilImage
@@ -14,3 +15,10 @@ def test_get_image_possible(self, cyclist_video: Path) -> None:
def test_get_frame_out_of_bounds(self, cyclist_video: Path) -> None:
image = self.video_reader.get_frame(cyclist_video, 100)
assert isinstance(image, PilImage)
+
+ def test_get_frame_number_for(self, cyclist_video: Path) -> None:
+ delta = timedelta(seconds=1)
+
+ frame_number = self.video_reader.get_frame_number_for(cyclist_video, delta)
+
+ assert frame_number == 20