Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature/3419-remove-only-events-of-changed-or-deleted-sections #383

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 34 additions & 6 deletions OTAnalytics/application/use_cases/create_events.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from abc import ABC, abstractmethod
from typing import Callable

from OTAnalytics.application.analysis.intersect import RunIntersect
from OTAnalytics.application.eventlist import SceneActionDetector
from OTAnalytics.application.use_cases.event_repository import AddEvents, ClearAllEvents
from OTAnalytics.application.use_cases.track_repository import (
GetTracksWithoutSingleDetections,
)
from OTAnalytics.domain.section import SectionRepository
from OTAnalytics.domain.event import EventRepository
from OTAnalytics.domain.section import Section, SectionRepository


class CreateIntersectionEvents(ABC):
Expand All @@ -30,6 +32,32 @@ def __call__(self) -> None:
raise NotImplementedError


SectionProvider = Callable[[], list[Section]]


class MissingEventsSectionProvider:
"""
Calculates the section to be intersected with. All sections which have already
been intersected are retrieved from the event repository.

Args:
section_repository (SectionRepository): section repository to get all
sections from
event_repository (EventRepository): event repository to calculate
the sections to intersect
"""

def __init__(
self, section_repository: SectionRepository, event_repository: EventRepository
):
self._section_repository = section_repository
self._event_repository = event_repository

def __call__(self) -> list[Section]:
all = self._section_repository.get_all()
return self._event_repository.retain_missing(all)


class SimpleCreateIntersectionEvents(CreateIntersectionEvents):
"""Intersect tracks with sections to create intersection events and add them to the
event repository.
Expand All @@ -43,20 +71,21 @@ class SimpleCreateIntersectionEvents(CreateIntersectionEvents):
def __init__(
self,
run_intersect: RunIntersect,
section_repository: SectionRepository,
section_provider: SectionProvider,
add_events: AddEvents,
) -> None:
self._run_intersect = run_intersect
self._section_repository = section_repository
self._section_provider = section_provider
self._add_events = add_events

def __call__(self) -> None:
"""Runs the intersection of tracks with sections in the repository."""
sections = self._section_repository.get_all()
sections = self._section_provider()
if not sections:
return
events = self._run_intersect(sections)
self._add_events(events)
section_ids = [section.id for section in sections]
self._add_events(events, section_ids)


class SimpleCreateSceneEvents(CreateSceneEvents):
Expand Down Expand Up @@ -102,6 +131,5 @@ def __call__(self) -> None:
Intersect all tracks with all sections and write the events into the event
repository.
"""
self._clear_event_repository()
self._create_intersection_events()
self._create_scene_events()
12 changes: 8 additions & 4 deletions OTAnalytics/application/use_cases/event_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@ class AddEvents:
def __init__(self, event_repository: EventRepository) -> None:
self._event_repository = event_repository

def __call__(self, events: Iterable[Event]) -> None:
def __call__(
self, events: Iterable[Event], sections: list[SectionId] | None = None
) -> None:
if sections is None:
sections = []
if events:
self._event_repository.add_all(events)
self._event_repository.add_all(events, sections)


class ClearAllEvents(SectionListObserver, TrackListObserver):
Expand All @@ -38,13 +42,13 @@ def clear(self) -> None:
self._event_repository.clear()

def notify_sections(self, section_event: SectionRepositoryEvent) -> None:
self.clear()
self._event_repository.remove(list(section_event.removed))

def notify_tracks(self, track_event: TrackRepositoryEvent) -> None:
self.clear()

def on_section_changed(self, section_id: SectionId) -> None:
self.clear()
self._event_repository.remove([section_id])

def on_tracks_cut(self, _: CutTracksDto) -> None:
self.clear()
60 changes: 51 additions & 9 deletions OTAnalytics/domain/event.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import itertools
import re
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from typing import Iterable, Optional

from OTAnalytics.domain.common import DataclassValidation
from OTAnalytics.domain.geometry import DirectionVector2D, ImageCoordinate
from OTAnalytics.domain.observer import OBSERVER, Subject
from OTAnalytics.domain.section import SectionId
from OTAnalytics.domain.section import Section, SectionId
from OTAnalytics.domain.track import Detection
from OTAnalytics.domain.types import EventType

Expand Down Expand Up @@ -334,7 +336,8 @@ def __init__(
self, subject: Subject[EventRepositoryEvent] = Subject[EventRepositoryEvent]()
) -> None:
self._subject = subject
self._events: list[Event] = []
self._events: dict[SectionId, list[Event]] = defaultdict(list)
self._non_section_events = list[Event]()

def register_observer(self, observer: OBSERVER[EventRepositoryEvent]) -> None:
"""Register observer to listen to repository changes.
Expand All @@ -350,16 +353,36 @@ def add(self, event: Event) -> None:
Args:
event (Event): the event to add
"""
self._events.append(event)
self.__do_add(event)
self._subject.notify(EventRepositoryEvent([event], []))

def add_all(self, events: Iterable[Event]) -> None:
"""Add multiple events at once to the repository.
def __do_add(self, event: Event) -> None:
"""
Internal add method that does not notify observers.
"""
if event.section_id:
self._events[event.section_id].append(event)
else:
self._non_section_events.append(event)

def add_all(
self, events: Iterable[Event], sections: list[SectionId] | None = None
) -> None:
"""
Add multiple events at once to the repository. Preserve the sections used
to generate the events for later usage.

Args:
events (Iterable[Event]): the events
sections (list[SectionId]): the sections which have been used to generate
the events
"""
self._events.extend(events)
if sections is None:
sections = []
for event in events:
self.__do_add(event)
for section in sections:
self._events.setdefault(section, [])
self._subject.notify(EventRepositoryEvent(events, []))

def get_all(self) -> Iterable[Event]:
Expand All @@ -368,17 +391,36 @@ def get_all(self) -> Iterable[Event]:
Returns:
Iterable[Event]: the events
"""
return self._events
return itertools.chain.from_iterable(
[self._non_section_events, *self._events.values()]
)

def clear(self) -> None:
"""
Clear the repository and notify observers only if repository was filled.
"""
if self._events:
removed = self._events
self._events = []
removed = list(self.get_all())
self._events = defaultdict(list)
self._non_section_events = list[Event]()
self._subject.notify(EventRepositoryEvent([], removed))

def remove(self, sections: list[SectionId]) -> None:
if self._events:
removed = [
event for event in self.get_all() if event.section_id in sections
]
for section in sections:
del self._events[section]
self._subject.notify((EventRepositoryEvent([], removed)))

def is_empty(self) -> bool:
"""Whether repository is empty."""
return not self._events

def retain_missing(self, all: list[Section]) -> list[Section]:
"""
Returns a new list of sections. The list contains all Sections from the input
except the ones event have been generated for.
"""
return [section for section in all if section.id not in self._events.keys()]
19 changes: 12 additions & 7 deletions OTAnalytics/plugin_ui/main_application.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
from OTAnalytics.application.use_cases.create_events import (
CreateEvents,
CreateIntersectionEvents,
MissingEventsSectionProvider,
SectionProvider,
SimpleCreateIntersectionEvents,
SimpleCreateSceneEvents,
)
Expand Down Expand Up @@ -288,16 +290,19 @@ def start_gui(self) -> None:
datastore._track_to_video_repository
)

section_provider = MissingEventsSectionProvider(
section_repository, event_repository
)
create_events = self._create_use_case_create_events(
section_repository,
section_provider,
clear_all_events,
get_tracks_without_single_detections,
add_events,
DEFAULT_NUM_PROCESSES,
)
intersect_tracks_with_sections = (
self._create_use_case_create_intersection_events(
section_repository,
section_provider,
get_tracks_without_single_detections,
add_events,
DEFAULT_NUM_PROCESSES,
Expand Down Expand Up @@ -457,7 +462,7 @@ def start_cli(self, cli_args: CliArguments) -> None:
get_all_track_ids = GetAllTrackIds(track_repository)
clear_all_events = ClearAllEvents(event_repository)
create_events = self._create_use_case_create_events(
section_repository,
section_repository.get_all,
clear_all_events,
get_tracks_without_single_detections,
add_events,
Expand Down Expand Up @@ -640,13 +645,13 @@ def _create_flow_generator(

def _create_use_case_create_intersection_events(
self,
section_repository: SectionRepository,
section_provider: SectionProvider,
get_tracks: GetTracksWithoutSingleDetections,
add_events: AddEvents,
num_processes: int,
) -> CreateIntersectionEvents:
intersect = self._create_intersect(get_tracks, num_processes)
return SimpleCreateIntersectionEvents(intersect, section_repository, add_events)
return SimpleCreateIntersectionEvents(intersect, section_provider, add_events)

@staticmethod
def _create_intersect(
Expand Down Expand Up @@ -692,15 +697,15 @@ def _create_export_counts(

def _create_use_case_create_events(
self,
section_repository: SectionRepository,
section_provider: SectionProvider,
clear_events: ClearAllEvents,
get_tracks: GetTracksWithoutSingleDetections,
add_events: AddEvents,
num_processes: int,
) -> CreateEvents:
run_intersect = self._create_intersect(get_tracks, num_processes)
create_intersection_events = SimpleCreateIntersectionEvents(
run_intersect, section_repository, add_events
run_intersect, section_provider, add_events
)
scene_action_detector = SceneActionDetector(SceneEventBuilder())
create_scene_events = SimpleCreateSceneEvents(
Expand Down
37 changes: 17 additions & 20 deletions tests/OTAnalytics/application/use_cases/test_create_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
CreateEvents,
CreateIntersectionEvents,
CreateSceneEvents,
SectionProvider,
SimpleCreateIntersectionEvents,
SimpleCreateSceneEvents,
)
from OTAnalytics.application.use_cases.event_repository import AddEvents, ClearAllEvents
from OTAnalytics.application.use_cases.track_repository import GetAllTracks
from OTAnalytics.domain.event import Event
from OTAnalytics.domain.section import Section, SectionRepository
from OTAnalytics.domain.section import Section, SectionId
from OTAnalytics.domain.track import Track


Expand All @@ -25,7 +26,9 @@ def track() -> Mock:

@pytest.fixture
def section() -> Mock:
return Mock(spec=Section)
section = Mock(spec=Section)
section.id = SectionId("section")
return section


@pytest.fixture
Expand All @@ -35,36 +38,38 @@ def event() -> Mock:

class TestSimpleCreateIntersectionEvents:
def test_intersection_event_creation(self, section: Mock, event: Mock) -> None:
section_repository = Mock(spec=SectionRepository)
section_repository.get_all.return_value = [section]
section_provider = Mock(spec=SectionProvider)
provided_sections = [section]
section_provider.return_value = provided_sections

run_intersect = Mock(spec=RunIntersect)
run_intersect.return_value = [event]

add_events = Mock(spec=AddEvents)

create_intersections_events = SimpleCreateIntersectionEvents(
run_intersect, section_repository, add_events
run_intersect, section_provider, add_events
)
create_intersections_events()

section_repository.get_all.assert_called_once()
run_intersect.assert_called_once_with([section])
add_events.assert_called_once_with([event])
section_provider.assert_called_once()
run_intersect.assert_called_once_with(provided_sections)
add_events.assert_called_once()
assert add_events.call_args == call([event], [section.id])

def test_empty_section_repository_should_not_run_intersection(self) -> None:
section_repository = Mock(spec=SectionRepository)
section_repository.get_all.return_value = []
section_provider = Mock(spec=SectionProvider)
section_provider.return_value = []

run_intersect = Mock(spec=RunIntersect)
add_events = Mock(spec=AddEvents)

create_intersections_events = SimpleCreateIntersectionEvents(
run_intersect, section_repository, add_events
run_intersect, section_provider, add_events
)
create_intersections_events()

section_repository.get_all.assert_called_once()
section_provider.assert_called_once()
run_intersect.assert_not_called()
add_events.assert_not_called()

Expand Down Expand Up @@ -107,13 +112,5 @@ def test_create_events(self) -> None:

create_events()

clear_all_events.assert_called_once()
create_intersection_events.assert_called_once()
create_scene_events.assert_called_once()

# Check that clearing event repository is called first
method_execution_order_observer.mock_calls == [
call.clear_event_repository(),
call.create_intersection_events(),
call.create_scene_events(),
]
Loading