From d55fbc1866cf4b0585875e8ee764c046faef27f5 Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Fri, 16 Aug 2024 15:39:43 +0200 Subject: [PATCH 01/17] Add script to update pre-commit type stubs --- update_precommit.py | 133 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 133 insertions(+) create mode 100644 update_precommit.py diff --git a/update_precommit.py b/update_precommit.py new file mode 100644 index 000000000..d3d1eaa85 --- /dev/null +++ b/update_precommit.py @@ -0,0 +1,133 @@ +import re +from pathlib import Path +from typing import Iterable + +import requests +import yaml + +REPOSITORIES = "repos" +REPOSITORY = "repo" +MYPY_REPOSITORY = "https://github.com/pre-commit/mirrors-mypy" +HOOKS = "hooks" +ADDITIONAL_DEPENDENCIES = "additional_dependencies" + + +class CustomDumper(yaml.SafeDumper): + def increase_indent(self, flow: bool = False, indentless: bool = False) -> None: + return super(CustomDumper, self).increase_indent(flow, False) + + +def check_type_stub_exists(package_name: str) -> bool: + """Check if a type stub exists for a given package name.""" + types_package_name = f"types-{package_name}" + response = requests.get(f"https://pypi.org/pypi/{types_package_name}/json") + return response.status_code == 200 + + +def extract_package_name(requirement_line: str) -> str | None: + """Extract package name from a requirement line using regex.""" + # Regex pattern to capture the package name, ignoring version specifiers + pattern = re.compile(r"^([a-zA-Z0-9_\-\.]+)(?:[<>=~!]+\S*)?") + match = pattern.match(requirement_line) + if match: + return match.group(1).strip() + return None + + +def parse_requirements(requirements_file: Path) -> set[str]: + """Parse requirements.txt and extract package names using regex.""" + with open(requirements_file, "r") as file: + lines = file.readlines() + + packages = set() + for line in lines: + line = line.strip() + if ( + line and not line.startswith("#") and line != "-r requirements.txt" + ): # Ignore empty lines, comments '-r requirements.txt' + package_name = extract_package_name(line) + if package_name: + packages.add(package_name) + + return packages + + +def parse_multiple_requirements(files: Iterable[Path]) -> set[str]: + packages: set[str] = set() + for _file in files: + packages.update(parse_requirements(_file)) + return packages + + +def retrieve_type_stubs(packages: Iterable[str]) -> list[str]: + """Generate a list of type stubs for the given list of packages.""" + type_stubs = [] + for package in packages: + if check_type_stub_exists(package): + type_stubs.append(f"types-{package}") + else: + print(f"No type stub found for package: {package}") + return type_stubs + + +def read_precommit_file(precommit_file: Path) -> dict: + with open(precommit_file, "r") as stream: + yaml_config = yaml.safe_load(stream) + return yaml_config + + +def update_precommit_config(config: dict, type_stubs: list[str]) -> dict: + updated_config = config.copy() + for repo in updated_config[REPOSITORIES]: + if repo[REPOSITORY] == MYPY_REPOSITORY: + repo[HOOKS][0][ADDITIONAL_DEPENDENCIES] = type_stubs + break + return updated_config + + +def save_precommit_config(config: dict, save_path: Path) -> None: + with open(save_path, "w") as yaml_file: + yaml.dump( + data=config, + stream=yaml_file, + Dumper=CustomDumper, + explicit_start=True, + default_flow_style=False, + sort_keys=False, + ) + + +def display_available_type_stubs(type_stubs: list[str]) -> None: + if type_stubs: + print("\nType stubs that can be 
added to your pre-commit configuration:") + for stub in type_stubs: + print(f"- {stub}") + else: + print("\n No type stubs to be added to your pre-commit configuration.") + + +def main() -> None: + requirements_file = Path("requirements.txt") + requirements_dev_file = Path("requirements-dev.txt") + precommit_file = Path(".pre-commit-config.yaml") + + print("Parsing requirements.txt and requirements-dev.txt...") + packages = parse_multiple_requirements([requirements_file, requirements_dev_file]) + + print("Checking for type stubs...") + type_stubs = retrieve_type_stubs(packages) + + display_available_type_stubs(type_stubs) + + print("Read pre-commit config...") + precommit_config = read_precommit_file(precommit_file) + + print("Update pre-commit config...") + updated_precommit_config = update_precommit_config(precommit_config, type_stubs) + + print("Save updated pre-commit config...") + save_precommit_config(updated_precommit_config, precommit_file) + + +if __name__ == "__main__": + main() From 427675567bbda2e8d1f332317ef92909806d5ab9 Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Mon, 19 Aug 2024 15:28:28 +0200 Subject: [PATCH 02/17] Update all dependencies --- .pre-commit-config.yaml | 34 +++++++++++++++++++++++----------- requirements-dev.txt | 10 ++++++---- requirements.txt | 6 +++--- 3 files changed, 32 insertions(+), 18 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d7deba46d..c7e92856f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,12 +1,13 @@ --- repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.6.0 hooks: - id: check-yaml - id: check-json - id: end-of-file-fixer - exclude_types: [json] + exclude_types: + - json - id: trailing-whitespace - id: no-commit-to-branch - id: debug-statements @@ -14,32 +15,43 @@ repos: - id: check-executables-have-shebangs - id: detect-private-key - repo: https://github.com/PyCQA/flake8 - rev: 6.1.0 + rev: 7.1.1 hooks: - id: flake8 - repo: https://github.com/pycqa/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort - args: ["--profile", "black"] + args: + - --profile + - black - repo: https://github.com/psf/black - rev: 24.2.0 + rev: 24.4.2 hooks: - id: black - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.6.1 + rev: v1.11.1 hooks: - id: mypy - entry: mypy . 
- additional_dependencies: [types-all, pydantic] + entry: mypy OTAnalytics tests + additional_dependencies: + - types-seaborn + - types-tqdm + - types-openpyxl + - types-PyYAML + - types-flake8 + - types-ujson + - types-pillow + - types-shapely always_run: true pass_filenames: false - repo: https://github.com/adrienverge/yamllint.git rev: v1.35.1 hooks: - id: yamllint - args: [-c=./.yamllint.yaml] + args: + - -c=./.yamllint.yaml - repo: https://github.com/koalaman/shellcheck-precommit - rev: v0.9.0 + rev: v0.10.0 hooks: - id: shellcheck diff --git a/requirements-dev.txt b/requirements-dev.txt index 0225a45ed..3a27a2880 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,13 +1,15 @@ -r requirements.txt black==24.4.2 -flake8==7.1.0 +flake8==7.1.1 hatch-requirements-txt==0.4.1 interrogate==1.7.0 isort==5.13.2 -mypy==1.10.0 -pre-commit==3.7.1 -pytest==8.2.2 +mypy==1.11.1 +pandas-stubs==2.2.2.240807 +pre-commit==3.8.0 +pytest==8.3.2 pytest-benchmark==4.0.0 pytest-cov==5.0.0 +PyYAML==6.0.1 twine==5.0.0 yamllint==1.35.1 diff --git a/requirements.txt b/requirements.txt index 1ed6ee3c9..41dbce70e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,10 +2,10 @@ customtkinter==5.2.2 matplotlib==3.9.0 more-itertools==10.3.0 numpy==1.26.4 -opencv-python==4.10.0.82 -openpyxl==3.1.4 +opencv-python==4.10.0.84 +openpyxl==3.1.5 pandas==2.2.2 -pillow==10.3.0 +pillow==10.4.0 pygeos==0.14 seaborn==0.13.2 shapely==2.0.4 From 0c6f0000e24d3010be78a92966753904eb539091 Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Mon, 19 Aug 2024 15:35:13 +0200 Subject: [PATCH 03/17] Remove unused method --- .../shapely/create_intersection_events.py | 70 ------ .../plugin_intersect/shapely/intersect.py | 216 ------------------ .../test_shapely_intersect.py | 123 ---------- 3 files changed, 409 deletions(-) delete mode 100644 OTAnalytics/plugin_intersect/shapely/create_intersection_events.py delete mode 100644 OTAnalytics/plugin_intersect/shapely/intersect.py delete mode 100644 tests/OTAnalytics/plugin_intersect/test_shapely_intersect.py diff --git a/OTAnalytics/plugin_intersect/shapely/create_intersection_events.py b/OTAnalytics/plugin_intersect/shapely/create_intersection_events.py deleted file mode 100644 index 7f25aa9f8..000000000 --- a/OTAnalytics/plugin_intersect/shapely/create_intersection_events.py +++ /dev/null @@ -1,70 +0,0 @@ -from functools import singledispatchmethod -from typing import Callable, Iterable - -from shapely import LineString, Polygon - -from OTAnalytics.application.geometry import GeometryBuilder -from OTAnalytics.domain.geometry import RelativeOffsetCoordinate, apply_offset -from OTAnalytics.domain.section import Area, LineSection -from OTAnalytics.domain.track import Track, TrackId - - -class ShapelyGeometryBuilder(GeometryBuilder[LineString, Polygon]): - def __init__( - self, - apply_offset_: Callable[ - [float, float, float, float, RelativeOffsetCoordinate], tuple[float, float] - ] = apply_offset, - ): - self._apply_offset = apply_offset_ - - @singledispatchmethod - def create_section(self) -> LineString | Polygon: - raise NotImplementedError - - @create_section.register - def _(self, section: LineSection) -> LineString: - return LineString([(coord.x, coord.y) for coord in section.get_coordinates()]) - - @create_section.register - def _(self, section: Area) -> Polygon: - return Polygon([(coord.x, coord.y) for coord in section.get_coordinates()]) - - def create_track( - self, track: Track, offset: RelativeOffsetCoordinate - ) -> 
LineString: - return LineString( - [ - self._apply_offset( - detection.x, detection.y, detection.w, detection.h, offset - ) - for detection in track.detections - ] - ) - - def create_line_segments(self, geometry: LineString) -> Iterable[LineString]: - line_segments: list[LineString] = [] - - for _current, _next in zip(geometry.coords[0:-1], geometry.coords[1:]): - line_segments.append(LineString([_current, _next])) - return line_segments - - -class ShapelyTrackLookupTable: - def __init__( - self, - lookup_table: dict[TrackId, LineString], - geometry_builder: GeometryBuilder[LineString, Polygon], - offset: RelativeOffsetCoordinate, - ): - self._table = lookup_table - self._geometry_builder = geometry_builder - self._offset = offset - - def look_up(self, track: Track) -> LineString: - if line := self._table.get(track.id): - return line - new_line = self._geometry_builder.create_track(track, self._offset) - - self._table[track.id] = new_line - return new_line diff --git a/OTAnalytics/plugin_intersect/shapely/intersect.py b/OTAnalytics/plugin_intersect/shapely/intersect.py deleted file mode 100644 index 4c9247988..000000000 --- a/OTAnalytics/plugin_intersect/shapely/intersect.py +++ /dev/null @@ -1,216 +0,0 @@ -from functools import lru_cache - -from numpy import ndarray -from shapely import GeometryCollection, LineString -from shapely import Polygon as ShapelyPolygon -from shapely import contains_xy, prepare -from shapely.ops import snap, split - -from OTAnalytics.domain.geometry import Coordinate, Line, Polygon -from OTAnalytics.domain.intersect import IntersectImplementation -from OTAnalytics.plugin_intersect.shapely.mapping import ShapelyMapper - - -@lru_cache(maxsize=100000) -def cached_intersects(line_1: LineString, line_2: LineString) -> bool: - return line_1.intersects(line_2) - - -class ShapelyIntersector(IntersectImplementation): - """Provides shapely geometry operations.""" - - def __init__(self, mapper: ShapelyMapper = ShapelyMapper()) -> None: - self._mapper = mapper - - def line_intersects_line(self, line_1: Line, line_2: Line) -> bool: - """Checks if a line intersects with another line. - - Args: - line_1 (Line): the first line. - line_2 (Line): the second line. - - Returns: - bool: `True` if they intersect. Otherwise `False`. - """ - shapely_line_1 = self._mapper.map_to_shapely_line_string(line_1) - shapely_line_2 = self._mapper.map_to_shapely_line_string(line_2) - - return cached_intersects(shapely_line_1, shapely_line_2) - - def line_intersects_polygon(self, line: Line, polygon: Polygon) -> bool: - """Checks if a line intersects with a polygon. - - Args: - line (Line): the line. - polygon (Polygon): the polygon. - - Returns: - bool: `True` if they intersect. Otherwise `False`. - """ - shapely_line = self._mapper.map_to_shapely_line_string(line) - shapely_polygon = self._mapper.map_to_shapely_polygon(polygon) - return shapely_line.intersects(shapely_polygon) - - def intersection_line_with_line( - self, line_1: Line, line_2: Line - ) -> list[Coordinate]: - """Calculates the intersection points of to lines if they exist. - - Args: - line_1 (Line): the first line to intersect with. - line_2 (Line): the second line to intersect with. - - Returns: - list[Coordinate]: the intersection points if they intersect. - Otherwise, `None`. 
- """ - shapely_line_1 = self._mapper.map_to_shapely_line_string(line_1) - shapely_line_2 = self._mapper.map_to_shapely_line_string(line_2) - intersection = shapely_line_1.intersection(shapely_line_2) - if intersection.is_empty: - return [] - else: - try: - intersection_points: list[Coordinate] = [] - - for intersection_point in intersection.geoms: - intersection_points.append( - self._mapper.map_to_domain_coordinate(intersection_point) - ) - return intersection_points - except AttributeError: - return [self._mapper.map_to_domain_coordinate(intersection)] - - def split_line_with_line(self, line: Line, splitter: Line) -> list[Line]: - """Use a LineString to split another LineString. - - If `line` intersects `splitter` then line_1 will be splitted at the - intersection points. - I.e. Let line_1 = [p_1, p_2, ..., p_n], n a natural number and p_x - the intersection point. - - Then `line` will be splitted as follows: - [[p_1, p_2, ..., p_x], [p_x, p_(x+1), ..., p_n]. - - Args: - line (Line): the line to be splitted. - splitter (Line): the line used for splitting. - - Returns: - list[Line]: the splitted lines. - """ - shapely_line = self._mapper.map_to_shapely_line_string(line) - shapely_splitter = self._mapper.map_to_shapely_line_string(splitter) - intersection_points = self._complex_split(shapely_line, shapely_splitter) - if len(intersection_points.geoms) == 1: - # If there are no splits, intersection_points holds a single element - # that is the original `line`. - return [] - - return [ - self._mapper.map_to_domain_line(line_string) - for line_string in intersection_points.geoms - ] - - def distance_between(self, point_1: Coordinate, point_2: Coordinate) -> float: - """Calculates the distance between two points. - - Args: - point_1 (Coordinate): the first coordinate to calculate the distance for. - point_2 (Coordinate): the second coordinate to calculate the distance for. - - Returns: - float: the unitless distance between p1 and p2. - """ - shapely_p1 = self._mapper.map_to_shapely_point(point_1) - shapely_p2 = self._mapper.map_to_shapely_point(point_2) - - return shapely_p1.distance(shapely_p2) - - def are_coordinates_within_polygon( - self, coordinates: list[Coordinate], polygon: Polygon - ) -> list[bool]: - """Checks if the points are within the polygon. - - A point is within a polygon if it is enclosed by it. Meaning that a point - sitting on the boundary of a polygon is treated as not being within it. - - Args: - coordinates (list[Coordinate]): the coordinates. - polygon (Polygon): the polygon. - - Returns: - list[bool]: the boolean mask holding the information whether a coordinate is - within the polygon or not. - """ - shapely_points = self._mapper.map_to_tuple_coordinates(coordinates) - shapely_polygon = self._mapper.map_to_shapely_polygon(polygon) - prepare(shapely_polygon) - mask: ndarray = contains_xy(shapely_polygon, shapely_points) - return mask.tolist() - - def _complex_split( - self, geom: LineString, splitter: LineString | ShapelyPolygon - ) -> GeometryCollection: - """Split a complex linestring by another geometry without splitting at - self-intersection points. - - Split a complex linestring using shapely. - - Inspired by https://github.com/Toblerity/Shapely/issues/1068 - - Parameters - ---------- - geom : LineString - An optionally complex LineString. - splitter : Geometry - A geometry to split by. - - Warnings - -------- - A known vulnerability is where the splitter intersects the complex - linestring at one of the self-intersecting points of the linestring. 
- In this case, only the first path through the self-intersection - will be split. - - Examples - -------- - >>> complex_line_string = LineString([(0, 0), (1, 1), (1, 0), (0, 1)]) - >>> splitter = LineString([(0, 0.5), (0.5, 1)]) - >>> complex_split(complex_line_string, splitter).wkt - 'GEOMETRYCOLLECTION ( - LINESTRING (0 0, 1 1, 1 0, 0.25 0.75), LINESTRING (0.25 0.75, 0 1) - )' - - Return - ------ - GeometryCollection - A collection of the geometries resulting from the split. - """ - if geom.is_simple: - return split(geom, splitter) - - if isinstance(splitter, ShapelyPolygon): - splitter = splitter.exterior - - # Ensure that intersection exists and is zero dimensional. - relate_str = geom.relate(splitter) - if relate_str[0] == "1": - raise ValueError( - "Cannot split LineString by a geometry which intersects a " - "continuous portion of the LineString." - ) - if not (relate_str[0] == "0" or relate_str[1] == "0"): - return GeometryCollection((geom,)) - - intersection_points = geom.intersection(splitter) - # This only inserts the point at the first pass of a self-intersection if - # the point falls on a self-intersection. - snapped_geom = snap( - geom, intersection_points, tolerance=1.0e-12 - ) # may want to make tolerance a parameter. - # A solution to the warning in the docstring is to roll your own split method - # here. The current one in shapely returns early when a point is found to be - # part of a segment. But if the point was at a self-intersection it could be - # part of multiple segments. - return split(snapped_geom, intersection_points) diff --git a/tests/OTAnalytics/plugin_intersect/test_shapely_intersect.py b/tests/OTAnalytics/plugin_intersect/test_shapely_intersect.py deleted file mode 100644 index 0c7210202..000000000 --- a/tests/OTAnalytics/plugin_intersect/test_shapely_intersect.py +++ /dev/null @@ -1,123 +0,0 @@ -import pytest - -from OTAnalytics.domain.geometry import Coordinate, Line, Polygon -from OTAnalytics.plugin_intersect.shapely.intersect import ShapelyIntersector - - -class TestShapelyIntersector: - @pytest.fixture - def polygon(self) -> Polygon: - return Polygon( - [ - Coordinate(0, 5), - Coordinate(5, 1), - Coordinate(10, 5), - Coordinate(5, 5), - Coordinate(0, 5), - ] - ) - - def test_line_intersects_line_true(self) -> None: - first_line = Line([Coordinate(0, 0.5), Coordinate(1, 0.5)]) - second_line = Line([Coordinate(0.5, 0), Coordinate(0.5, 1)]) - - intersector = ShapelyIntersector() - intersects = intersector.line_intersects_line(first_line, second_line) - - assert intersects - - def test_line_intersects_line_false(self) -> None: - first_line = Line([Coordinate(0, 0), Coordinate(10, 0)]) - second_line = Line([Coordinate(0, 5), Coordinate(10, 5)]) - - intersector = ShapelyIntersector() - intersects = intersector.line_intersects_line(first_line, second_line) - - assert not intersects - - def test_line_intersects_polygon_true(self, polygon: Polygon) -> None: - line = Line([Coordinate(5, 0), Coordinate(5, 10)]) - - intersector = ShapelyIntersector() - intersects = intersector.line_intersects_polygon(line, polygon) - - assert intersects - - def test_line_intersects_polygon_false(self, polygon: Polygon) -> None: - line = Line([Coordinate(20, 0), Coordinate(20, 10)]) - - intersector = ShapelyIntersector() - intersects = intersector.line_intersects_polygon(line, polygon) - - assert not intersects - - def test_intersection_line_with_line_exists(self) -> None: - first_line = Line([Coordinate(0, 0.5), Coordinate(1, 0.5)]) - second_line = Line([Coordinate(0.5, 0), 
Coordinate(0.5, 1)]) - - intersector = ShapelyIntersector() - intersection = intersector.intersection_line_with_line(first_line, second_line) - - assert intersection == [Coordinate(0.5, 0.5)] - - def test_intersection_line_with_line_does_not_exist(self) -> None: - first_line = Line([Coordinate(0, 0), Coordinate(10, 0)]) - second_line = Line([Coordinate(0, 5), Coordinate(10, 5)]) - - intersector = ShapelyIntersector() - intersection = intersector.intersection_line_with_line(first_line, second_line) - - assert intersection == [] - - def test_split_line_with_line_has_intersections(self) -> None: - first_line = Line([Coordinate(0, 0.5), Coordinate(1, 0.5)]) - splitter = Line([Coordinate(0.5, 0), Coordinate(0.5, 1)]) - - intersector = ShapelyIntersector() - splitted_lines = intersector.split_line_with_line(first_line, splitter) - - expected = [ - Line([Coordinate(0, 0.5), Coordinate(0.5, 0.5)]), - Line([Coordinate(0.5, 0.5), Coordinate(1, 0.5)]), - ] - assert splitted_lines == expected - - def test_split_line_with_line_no_intersections(self) -> None: - first_line = Line([Coordinate(0, 0), Coordinate(10, 0)]) - splitter = Line([Coordinate(0, 5), Coordinate(10, 5)]) - - intersector = ShapelyIntersector() - splitted_lines = intersector.split_line_with_line(first_line, splitter) - - assert splitted_lines == [] - - def test_distance_point_point(self) -> None: - first_point = Coordinate(0, 0) - second_point = Coordinate(1, 0) - - intersector = ShapelyIntersector() - distance = intersector.distance_between(first_point, second_point) - - assert distance == 1 - - def test_are_points_within_polygon(self) -> None: - coords = [ - Coordinate(0.0, 0.0), - Coordinate(0.0, 1.0), - Coordinate(1.0, 1.0), - Coordinate(1.0, 0.0), - Coordinate(0.0, 0.0), - ] - polygon = Polygon(coords) - points: list[Coordinate] = [ - Coordinate(0.0, 0.0), - Coordinate(0.5, 0.5), - Coordinate(2.0, 2.0), - Coordinate(0.1, 0.0), - ] - - intersector = ShapelyIntersector() - result_mask = intersector.are_coordinates_within_polygon(points, polygon) - expected_mask = [False, True, False, False] - - assert result_mask == expected_mask From 48d1595c362922c3feedf052b266776982be9308 Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Mon, 19 Aug 2024 15:56:23 +0200 Subject: [PATCH 04/17] Ensure Factory Method Always Receives a DataFrame Adjust track flyweight factory method to guarantee it always receives a DataFrame. Fix edge case where indexing a DataFrame with a single value might return a Series instead of a DataFrame. --- OTAnalytics/domain/track_dataset.py | 4 ++-- .../track_geometry_store/pygeos_store.py | 2 +- OTAnalytics/plugin_datastore/track_store.py | 6 +++--- .../plugin_datastore/test_track_store.py | 17 ++++++++++------- 4 files changed, 16 insertions(+), 13 deletions(-) diff --git a/OTAnalytics/domain/track_dataset.py b/OTAnalytics/domain/track_dataset.py index 6f4e6f998..5e0324815 100644 --- a/OTAnalytics/domain/track_dataset.py +++ b/OTAnalytics/domain/track_dataset.py @@ -360,13 +360,13 @@ def remove(self, ids: Sequence[str]) -> "TrackGeometryDataset": raise NotImplementedError @abstractmethod - def get_for(self, track_ids: Iterable[str]) -> "TrackGeometryDataset": + def get_for(self, track_ids: list[str]) -> "TrackGeometryDataset": """Get geometries for given track ids if they exist. Ids that do not exist will not be included in the dataset. Args: - track_ids (Iterable[str]): the track ids. + track_ids (list[str]): the track ids. 
Returns: TrackGeometryDataset: the dataset with tracks. diff --git a/OTAnalytics/plugin_datastore/track_geometry_store/pygeos_store.py b/OTAnalytics/plugin_datastore/track_geometry_store/pygeos_store.py index b9b13286b..e593fa207 100644 --- a/OTAnalytics/plugin_datastore/track_geometry_store/pygeos_store.py +++ b/OTAnalytics/plugin_datastore/track_geometry_store/pygeos_store.py @@ -230,7 +230,7 @@ def remove(self, ids: Sequence[str]) -> TrackGeometryDataset: updated = self._dataset.drop(index=ids, errors="ignore") return PygeosTrackGeometryDataset(self._offset, updated) - def get_for(self, track_ids: Iterable[str]) -> "TrackGeometryDataset": + def get_for(self, track_ids: list[str]) -> "TrackGeometryDataset": _ids = self._dataset.index.intersection(track_ids) filtered_df = self._dataset.loc[_ids] diff --git a/OTAnalytics/plugin_datastore/track_store.py b/OTAnalytics/plugin_datastore/track_store.py index b5eef5add..d9e16ee8d 100644 --- a/OTAnalytics/plugin_datastore/track_store.py +++ b/OTAnalytics/plugin_datastore/track_store.py @@ -78,7 +78,7 @@ def frame(self) -> int: @property def occurrence(self) -> datetime: - return self._occurrence + return self._occurrence[1] @property def interpolated_detection(self) -> bool: @@ -397,7 +397,7 @@ def as_list(self) -> list[Track]: return [self.__create_track_flyweight(current) for current in track_ids] def __create_track_flyweight(self, track_id: str) -> Track: - track_frame = self._dataset.loc[track_id, :] + track_frame = self._dataset.loc[[track_id], :] return PandasTrack(track_id, track_frame) def get_data(self) -> DataFrame: @@ -427,7 +427,7 @@ def get_track_ids_as_string(self) -> Sequence[str]: return self._dataset.index.get_level_values(LEVEL_TRACK_ID).unique() def _get_geometries_for( - self, track_ids: Iterable[str] + self, track_ids: list[str] ) -> dict[RelativeOffsetCoordinate, TrackGeometryDataset]: geometry_datasets = {} for offset, geometry_dataset in self._geometry_datasets.items(): diff --git a/tests/OTAnalytics/plugin_datastore/test_track_store.py b/tests/OTAnalytics/plugin_datastore/test_track_store.py index 2235285f9..3bfadb851 100644 --- a/tests/OTAnalytics/plugin_datastore/test_track_store.py +++ b/tests/OTAnalytics/plugin_datastore/test_track_store.py @@ -60,7 +60,7 @@ def test_properties(self) -> None: python_detection = builder.build_detections()[0] data = Series( python_detection.to_dict(), - name=python_detection.occurrence, + name=(python_detection.track_id.id, python_detection.occurrence), ) pandas_detection = PandasDetection(python_detection.track_id.id, data) @@ -77,9 +77,12 @@ def test_properties(self) -> None: builder.append_detection() python_track = builder.build_track() detections = [detection.to_dict() for detection in python_track.detections] - data = DataFrame(detections).set_index([track.OCCURRENCE]).sort_index() + data = ( + DataFrame(detections) + .set_index([track.TRACK_ID, track.OCCURRENCE]) + .sort_index() + ) data[track.TRACK_CLASSIFICATION] = data[track.CLASSIFICATION] - data = data.drop([track.TRACK_ID], axis=1) pandas_track = PandasTrack(python_track.id.id, data) assert_equal_track_properties(pandas_track, python_track) @@ -439,12 +442,12 @@ def test_split_with_existing_geometries( PandasTrackDataset.from_list([pedestrian_track], track_geometry_factory), ) assert geometry_dataset_no_offset.get_for.call_args_list == [ - call((car_track.id.id,)), - call((pedestrian_track.id.id,)), + call([car_track.id.id]), + call([pedestrian_track.id.id]), ] assert geometry_dataset_with_offset.get_for.call_args_list == [ 
- call((car_track.id.id,)), - call((pedestrian_track.id.id,)), + call([car_track.id.id]), + call([pedestrian_track.id.id]), ] def test_filter_by_minimum_detection_length( From a2460fc20b0af747c94e82ac520efcc88a7d9c13 Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Mon, 19 Aug 2024 15:59:21 +0200 Subject: [PATCH 05/17] Format code --- OTAnalytics/plugin_parser/ottrk_dataformat.py | 1 + .../use_cases/test_highlight_intersections.py | 26 +++++++++++-------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/OTAnalytics/plugin_parser/ottrk_dataformat.py b/OTAnalytics/plugin_parser/ottrk_dataformat.py index da6f0042d..ceebd6234 100644 --- a/OTAnalytics/plugin_parser/ottrk_dataformat.py +++ b/OTAnalytics/plugin_parser/ottrk_dataformat.py @@ -1,5 +1,6 @@ """Defines the dictionary keys to access an ottrk file. """ + DATE_FORMAT: str = "%Y-%m-%d %H:%M:%S.%f" # Ottrk Data Format diff --git a/tests/OTAnalytics/application/use_cases/test_highlight_intersections.py b/tests/OTAnalytics/application/use_cases/test_highlight_intersections.py index d570e1c7d..cb3a6c622 100644 --- a/tests/OTAnalytics/application/use_cases/test_highlight_intersections.py +++ b/tests/OTAnalytics/application/use_cases/test_highlight_intersections.py @@ -470,17 +470,21 @@ def test_filter( end_time = datetime(2020, 1, 1, 13, 30) end_detection.occurrence = end_time - with patch.object( - TracksOverlapOccurrenceWindow, "_has_overlap", return_value=True - ) as mock_has_overlap, patch( - "OTAnalytics.application.use_cases.highlight_intersections.Track.start", - new_callable=PropertyMock, - return_value=start_time, - ) as mock_start, patch( - "OTAnalytics.application.use_cases.highlight_intersections.Track.end", - new_callable=PropertyMock, - return_value=end_time, - ) as mock_end: + with ( + patch.object( + TracksOverlapOccurrenceWindow, "_has_overlap", return_value=True + ) as mock_has_overlap, + patch( + "OTAnalytics.application.use_cases.highlight_intersections.Track.start", + new_callable=PropertyMock, + return_value=start_time, + ) as mock_start, + patch( + "OTAnalytics.application.use_cases.highlight_intersections.Track.end", + new_callable=PropertyMock, + return_value=end_time, + ) as mock_end, + ): track_id = TrackId("1") track = Mock(spec=Track) track.id = track_id From 97b39e022388755f82cae2b8508b11386acd5524 Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Mon, 19 Aug 2024 16:01:28 +0200 Subject: [PATCH 06/17] Fix track export when using flyweights to create pandas DataFrame --- OTAnalytics/plugin_parser/track_export.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/OTAnalytics/plugin_parser/track_export.py b/OTAnalytics/plugin_parser/track_export.py index 5e40cee28..d55bf5c58 100644 --- a/OTAnalytics/plugin_parser/track_export.py +++ b/OTAnalytics/plugin_parser/track_export.py @@ -43,11 +43,14 @@ def _get_data(self) -> DataFrame: dataset = self._track_repository.get_all() if isinstance(dataset, PandasDataFrameProvider): return dataset.get_data().reset_index() - detections = [ - [detection.to_dict() for detection in track.detections] - for track in dataset.as_list() - ] - return DataFrame.from_dict(detections) + detections = [] + for _track in dataset.as_list(): + track_classification = _track.classification + for detection in _track.detections: + current = detection.to_dict() + current[track.TRACK_CLASSIFICATION] = track_classification + detections.append(current) + return 
DataFrame(detections) def set_column_order(dataframe: DataFrame) -> DataFrame: From 0d2addd4c2391718a0f95676aef91db719303883 Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Mon, 19 Aug 2024 16:08:45 +0200 Subject: [PATCH 07/17] Convert to pandas index to list to resolve interface incompatibility with PandasTrackDataset --- OTAnalytics/plugin_datastore/track_store.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/OTAnalytics/plugin_datastore/track_store.py b/OTAnalytics/plugin_datastore/track_store.py index d9e16ee8d..fa1111b04 100644 --- a/OTAnalytics/plugin_datastore/track_store.py +++ b/OTAnalytics/plugin_datastore/track_store.py @@ -409,8 +409,9 @@ def split(self, batches: int) -> Sequence["PandasTrackDataset"]: new_batches = [] for batch_ids in batched(self.get_track_ids_as_string(), batch_size): - batch_dataset = self._dataset.loc[list(batch_ids), :] - batch_geometries = self._get_geometries_for(batch_ids) + batch_ids_as_list = list(batch_ids) + batch_dataset = self._dataset.loc[batch_ids_as_list, :] + batch_geometries = self._get_geometries_for(batch_ids_as_list) new_batches.append( PandasTrackDataset.from_dataframe( batch_dataset, @@ -421,10 +422,10 @@ def split(self, batches: int) -> Sequence["PandasTrackDataset"]: ) return new_batches - def get_track_ids_as_string(self) -> Sequence[str]: + def get_track_ids_as_string(self) -> list[str]: if self._dataset.empty: return [] - return self._dataset.index.get_level_values(LEVEL_TRACK_ID).unique() + return self._dataset.index.get_level_values(LEVEL_TRACK_ID).unique().to_list() def _get_geometries_for( self, track_ids: list[str] @@ -673,7 +674,7 @@ def _get_dataset_with_classes(self, classes: list[str]) -> PandasTrackDataset: tracks_to_keep = filtered_df.index.get_level_values(LEVEL_TRACK_ID).unique() tracks_to_remove = tracks_to_keep.symmetric_difference( self._other.get_track_ids_as_string() - ) + ).to_list() updated_geometry_datasets = self._other._remove_from_geometry_dataset( tracks_to_remove ) From 3822680eb66ab2ec028cd6ce98e2940052fab83b Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Mon, 19 Aug 2024 16:10:15 +0200 Subject: [PATCH 08/17] Fix type hints --- .../track_geometry_store/pygeos_store.py | 12 +++++++----- OTAnalytics/plugin_datastore/track_store.py | 11 ++++++----- OTAnalytics/plugin_filter/dataframe_filter.py | 7 ++++--- OTAnalytics/plugin_parser/export.py | 2 +- OTAnalytics/plugin_parser/pandas_parser.py | 4 +++- OTAnalytics/plugin_progress/tqdm_progressbar.py | 9 ++++++--- 6 files changed, 27 insertions(+), 18 deletions(-) diff --git a/OTAnalytics/plugin_datastore/track_geometry_store/pygeos_store.py b/OTAnalytics/plugin_datastore/track_geometry_store/pygeos_store.py index e593fa207..472ba0856 100644 --- a/OTAnalytics/plugin_datastore/track_geometry_store/pygeos_store.py +++ b/OTAnalytics/plugin_datastore/track_geometry_store/pygeos_store.py @@ -3,7 +3,7 @@ from itertools import chain from typing import Any, Iterable, Literal, Sequence -from pandas import DataFrame, concat +from pandas import DataFrame, Series, concat from pygeos import ( Geometry, contains, @@ -194,9 +194,7 @@ def __create_entries_from_dataframe( new_y = filtered_tracks[track.Y] + offset.y * filtered_tracks[track.H] tracks = concat([new_x, new_y], keys=[track.X, track.Y], axis=1) tracks_by_id = tracks.groupby(level=LEVEL_TRACK_ID, group_keys=True) - geometries = tracks_by_id.apply( - lambda coords: linestrings(coords[track.X], 
coords[track.Y]) - ) + geometries = tracks_by_id.apply(convert_to_linestrings) projections = calculate_all_projections(tracks) result = concat([geometries, projections], keys=COLUMNS, axis=1) @@ -339,7 +337,7 @@ def __eq__(self, other: Any) -> bool: ) -def calculate_all_projections(tracks: DataFrame) -> DataFrame: +def calculate_all_projections(tracks: DataFrame) -> Series: tracks_by_id = tracks.groupby(level=0, group_keys=True) tracks["last_x"] = tracks_by_id[track.X].shift(1) tracks["last_y"] = tracks_by_id[track.Y].shift(1) @@ -354,3 +352,7 @@ def calculate_all_projections(tracks: DataFrame) -> DataFrame: "distance" ].cumsum() return tracks.groupby(level=0, group_keys=True)["cum-distance"].agg(list) + + +def convert_to_linestrings(coords: DataFrame) -> Geometry: + return linestrings(coords[track.X], coords[track.Y]) diff --git a/OTAnalytics/plugin_datastore/track_store.py b/OTAnalytics/plugin_datastore/track_store.py index fa1111b04..10a513b7b 100644 --- a/OTAnalytics/plugin_datastore/track_store.py +++ b/OTAnalytics/plugin_datastore/track_store.py @@ -144,7 +144,7 @@ class PandasTrackClassificationCalculator(ABC): """ @abstractmethod - def calculate(self, detections: DataFrame) -> Series: + def calculate(self, detections: DataFrame) -> DataFrame: """Determine a track's classification. Args: @@ -441,7 +441,10 @@ def __len__(self) -> int: return len(self._dataset.index.get_level_values(LEVEL_TRACK_ID).unique()) def filter_by_min_detection_length(self, length: int) -> "PandasTrackDataset": - detection_counts_per_track = self._dataset.groupby(level=LEVEL_TRACK_ID).size() + # groupby.size always returns a series + detection_counts_per_track: Series[int] = self._dataset.groupby( # type: ignore + level=LEVEL_TRACK_ID + ).size() filtered_ids = detection_counts_per_track[ detection_counts_per_track >= length ].index @@ -585,9 +588,7 @@ def cut_with_section( intersection_points.keys() ) - def _create_cut_track_id( - self, row: DataFrame, cut_info: dict[str, list[int]] - ) -> str: + def _create_cut_track_id(self, row: Series, cut_info: dict[str, list[int]]) -> str: if (track_id := row[track.TRACK_ID]) in cut_info.keys(): cut_segment_index = bisect(cut_info[track_id], row["cumcount"]) return f"{track_id}_{cut_segment_index}" diff --git a/OTAnalytics/plugin_filter/dataframe_filter.py b/OTAnalytics/plugin_filter/dataframe_filter.py index 0ca675538..d3057a1e2 100644 --- a/OTAnalytics/plugin_filter/dataframe_filter.py +++ b/OTAnalytics/plugin_filter/dataframe_filter.py @@ -1,7 +1,8 @@ +from abc import ABC from datetime import datetime from typing import Iterable, Optional -from pandas import DataFrame, Series +from pandas import DataFrame from OTAnalytics.application.plotting import GetCurrentFrame from OTAnalytics.application.use_cases.video_repository import GetVideos @@ -35,7 +36,7 @@ def conjunct_with( return DataFrameConjunction(self, other) -class DataFramePredicate(Predicate[DataFrame, DataFrame]): +class DataFramePredicate(Predicate[DataFrame, DataFrame], ABC): """Checks DataFrame entries against predicate. Entries that do not fulfill predicate are filtered out. 
@@ -274,7 +275,7 @@ def _reset(self) -> None: self._result = None def _extend_complex_predicate( - self, predicate: Predicate[DataFrame, Series] + self, predicate: Predicate[DataFrame, DataFrame] ) -> None: if self._complex_predicate: self._complex_predicate = self._complex_predicate.conjunct_with(predicate) diff --git a/OTAnalytics/plugin_parser/export.py b/OTAnalytics/plugin_parser/export.py index 9754133dd..a9665fd52 100644 --- a/OTAnalytics/plugin_parser/export.py +++ b/OTAnalytics/plugin_parser/export.py @@ -93,7 +93,7 @@ def __create_data_frame(counts: Count) -> DataFrame: result_dict: dict = key.as_dict() result_dict["count"] = value indexed.append(result_dict) - return DataFrame.from_dict(indexed) + return DataFrame(indexed) def __create_path(self) -> Path: fixed_file_ending = ( diff --git a/OTAnalytics/plugin_parser/pandas_parser.py b/OTAnalytics/plugin_parser/pandas_parser.py index 7930e3c2d..501a0da98 100644 --- a/OTAnalytics/plugin_parser/pandas_parser.py +++ b/OTAnalytics/plugin_parser/pandas_parser.py @@ -73,7 +73,9 @@ def _parse_as_dataframe( inplace=True, ) data[track.TRACK_ID] = ( - data[track.TRACK_ID].astype(str).apply(id_generator).astype(str) + data[track.TRACK_ID] + .astype(str) + .apply(lambda track_id: str(id_generator(track_id))) ) data[track.VIDEO_NAME] = video_name data[track.INPUT_FILE] = input_file diff --git a/OTAnalytics/plugin_progress/tqdm_progressbar.py b/OTAnalytics/plugin_progress/tqdm_progressbar.py index 1d7adfdf3..389e515f9 100644 --- a/OTAnalytics/plugin_progress/tqdm_progressbar.py +++ b/OTAnalytics/plugin_progress/tqdm_progressbar.py @@ -8,18 +8,21 @@ class TqdmProgressBar(Progressbar): def __init__(self, sequence: Sequence, description: str, unit: str) -> None: self.__sequence = sequence - self.__current_iterator = tqdm(self.__sequence) self.__description = description self.__unit = unit + self.__current_iterator = self.__get_iterator() def __iter__(self) -> Iterator: - self.__current_iterator = tqdm( + self.__current_iterator = self.__get_iterator() + return self + + def __get_iterator(self) -> Iterator: + return tqdm( iterable=self.__sequence, desc=self.__description, unit=self.__unit, total=len(self.__sequence), ).__iter__() - return self def __next__(self) -> Any: return next(self.__current_iterator) From c133b5222e6dba3b5862fbef4ba8ea1c4711867b Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Tue, 20 Aug 2024 11:54:24 +0200 Subject: [PATCH 09/17] Don't reuse geometry cache when filtering --- OTAnalytics/plugin_datastore/track_store.py | 40 ++++++++++----------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/OTAnalytics/plugin_datastore/track_store.py b/OTAnalytics/plugin_datastore/track_store.py index 10a513b7b..fed307871 100644 --- a/OTAnalytics/plugin_datastore/track_store.py +++ b/OTAnalytics/plugin_datastore/track_store.py @@ -8,7 +8,7 @@ import numpy import pandas from more_itertools import batched -from pandas import DataFrame, MultiIndex, Series +from pandas import DataFrame, Index, MultiIndex, Series from OTAnalytics.application.logger import logger from OTAnalytics.domain import track @@ -278,11 +278,11 @@ def __iter__(self) -> Iterator[Track]: yield from self.as_generator() def as_generator(self) -> Generator[Track, None, None]: - if self._dataset.empty: + if (track_ids := self.get_index()) is None: yield from [] - track_ids = self.get_track_ids_as_string() - for current in track_ids: - yield self.__create_track_flyweight(current) + else: + for current in track_ids.array: 
+ yield self.__create_track_flyweight(current) @staticmethod def from_list( @@ -387,9 +387,8 @@ def _remove_from_geometry_dataset( return updated_dataset def as_list(self) -> list[Track]: - if self._dataset.empty: + if (track_ids := self.get_index()) is None: return [] - track_ids = self.get_track_ids_as_string() logger().warning( "Creating track flyweight objects which is really slow in " f"'{PandasTrackDataset.as_list.__name__}'." @@ -408,7 +407,11 @@ def split(self, batches: int) -> Sequence["PandasTrackDataset"]: batch_size = ceil(dataset_size / batches) new_batches = [] - for batch_ids in batched(self.get_track_ids_as_string(), batch_size): + + if (track_ids := self.get_index()) is None: + return [self] + + for batch_ids in batched(track_ids, batch_size): batch_ids_as_list = list(batch_ids) batch_dataset = self._dataset.loc[batch_ids_as_list, :] batch_geometries = self._get_geometries_for(batch_ids_as_list) @@ -422,10 +425,10 @@ def split(self, batches: int) -> Sequence["PandasTrackDataset"]: ) return new_batches - def get_track_ids_as_string(self) -> list[str]: + def get_index(self) -> Index | None: if self._dataset.empty: - return [] - return self._dataset.index.get_level_values(LEVEL_TRACK_ID).unique().to_list() + return None + return self._dataset.index.get_level_values(LEVEL_TRACK_ID).unique() def _get_geometries_for( self, track_ids: list[str] @@ -672,18 +675,11 @@ def _get_dataset_with_classes(self, classes: list[str]) -> PandasTrackDataset: dataset = self._other.get_data() mask = dataset[track.TRACK_CLASSIFICATION].isin(classes) filtered_df = dataset[mask] - tracks_to_keep = filtered_df.index.get_level_values(LEVEL_TRACK_ID).unique() - tracks_to_remove = tracks_to_keep.symmetric_difference( - self._other.get_track_ids_as_string() - ).to_list() - updated_geometry_datasets = self._other._remove_from_geometry_dataset( - tracks_to_remove - ) return PandasTrackDataset( - self._other.track_geometry_factory, - filtered_df, - updated_geometry_datasets, - self._other.calculator, + track_geometry_factory=self._other.track_geometry_factory, + dataset=filtered_df, + geometry_datasets=None, + calculator=self._other.calculator, ) def add_all(self, other: Iterable[Track]) -> TrackDataset: From 233e2bb9dc2eb1b458cfa5ae9b3abd5bb275eb1d Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Tue, 20 Aug 2024 11:54:38 +0200 Subject: [PATCH 10/17] Format code --- OTAnalytics/plugin_ui/customtkinter_gui/treeview_template.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/OTAnalytics/plugin_ui/customtkinter_gui/treeview_template.py b/OTAnalytics/plugin_ui/customtkinter_gui/treeview_template.py index 800c74544..3f30daec1 100644 --- a/OTAnalytics/plugin_ui/customtkinter_gui/treeview_template.py +++ b/OTAnalytics/plugin_ui/customtkinter_gui/treeview_template.py @@ -14,7 +14,7 @@ class TreeviewTemplate(AbstractTreeviewInterface, Treeview): def __init__( self, show: Literal["tree", "headings", "tree headings", ""] = "tree", - **kwargs: Any + **kwargs: Any, ) -> None: super().__init__(selectmode="none", show=show, **kwargs) self.bind(tk_events.RIGHT_BUTTON_UP, self._on_deselect) From 4e1e241d334b1b1fd980c971c2fd27b3aa8d183e Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Tue, 20 Aug 2024 15:48:58 +0200 Subject: [PATCH 11/17] Fix performance drop when creating track flyweights --- OTAnalytics/plugin_datastore/track_store.py | 23 +++++++++++-------- .../plugin_datastore/test_track_store.py | 21 
++++++++++++----- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/OTAnalytics/plugin_datastore/track_store.py b/OTAnalytics/plugin_datastore/track_store.py index fed307871..c96d38bc9 100644 --- a/OTAnalytics/plugin_datastore/track_store.py +++ b/OTAnalytics/plugin_datastore/track_store.py @@ -78,7 +78,7 @@ def frame(self) -> int: @property def occurrence(self) -> datetime: - return self._occurrence[1] + return self._occurrence @property def interpolated_detection(self) -> bool: @@ -282,7 +282,7 @@ def as_generator(self) -> Generator[Track, None, None]: yield from [] else: for current in track_ids.array: - yield self.__create_track_flyweight(current) + yield self._create_track_flyweight(current) @staticmethod def from_list( @@ -354,7 +354,7 @@ def get_for(self, id: TrackId) -> Optional[Track]: if self._dataset.empty: return None try: - return self.__create_track_flyweight(id.id) + return self._create_track_flyweight(id.id) except KeyError: return None @@ -393,11 +393,15 @@ def as_list(self) -> list[Track]: "Creating track flyweight objects which is really slow in " f"'{PandasTrackDataset.as_list.__name__}'." ) - return [self.__create_track_flyweight(current) for current in track_ids] + return [self._create_track_flyweight(current) for current in track_ids] - def __create_track_flyweight(self, track_id: str) -> Track: - track_frame = self._dataset.loc[[track_id], :] - return PandasTrack(track_id, track_frame) + def _create_track_flyweight(self, track_id: str) -> Track: + track_frame = self._dataset.loc[track_id, :] + if isinstance(track_frame, DataFrame): + return PandasTrack(track_id, track_frame) + if isinstance(track_frame, Series): + return PandasTrack(track_id, track_frame.to_frame(track_id)) + raise NotImplementedError(f"Not implemented for {type(track_frame)}") def get_data(self) -> DataFrame: return self._dataset @@ -444,10 +448,9 @@ def __len__(self) -> int: return len(self._dataset.index.get_level_values(LEVEL_TRACK_ID).unique()) def filter_by_min_detection_length(self, length: int) -> "PandasTrackDataset": - # groupby.size always returns a series - detection_counts_per_track: Series[int] = self._dataset.groupby( # type: ignore + detection_counts_per_track: Series[int] = self._dataset.groupby( level=LEVEL_TRACK_ID - ).size() + )[track.FRAME].size() filtered_ids = detection_counts_per_track[ detection_counts_per_track >= length ].index diff --git a/tests/OTAnalytics/plugin_datastore/test_track_store.py b/tests/OTAnalytics/plugin_datastore/test_track_store.py index 3bfadb851..b56d8dc9d 100644 --- a/tests/OTAnalytics/plugin_datastore/test_track_store.py +++ b/tests/OTAnalytics/plugin_datastore/test_track_store.py @@ -60,7 +60,7 @@ def test_properties(self) -> None: python_detection = builder.build_detections()[0] data = Series( python_detection.to_dict(), - name=(python_detection.track_id.id, python_detection.occurrence), + name=python_detection.occurrence, ) pandas_detection = PandasDetection(python_detection.track_id.id, data) @@ -77,12 +77,9 @@ def test_properties(self) -> None: builder.append_detection() python_track = builder.build_track() detections = [detection.to_dict() for detection in python_track.detections] - data = ( - DataFrame(detections) - .set_index([track.TRACK_ID, track.OCCURRENCE]) - .sort_index() - ) + data = DataFrame(detections).set_index([track.OCCURRENCE]).sort_index() data[track.TRACK_CLASSIFICATION] = data[track.CLASSIFICATION] + data = data.drop([track.TRACK_ID], axis=1) pandas_track = PandasTrack(python_track.id.id, data) 
assert_equal_track_properties(pandas_track, python_track) @@ -628,3 +625,15 @@ def test_get_max_confidences_for( result = filled_dataset.get_max_confidences_for([car_id, pedestrian_id]) assert result == {car_id: 0.8, pedestrian_id: 0.9} + + def test_create_test_flyweight_with_single_detection( + self, track_geometry_factory: TRACK_GEOMETRY_FACTORY + ) -> None: + track_builder = TrackBuilder() + track_builder.append_detection() + single_detection_track = track_builder.build_track() + dataset = PandasTrackDataset.from_list( + [single_detection_track], track_geometry_factory + ) + result = dataset._create_track_flyweight(single_detection_track.id.id) + assert_equal_track_properties(result, single_detection_track) From 6e7cf32b3fb9fcfdbae3d6c241f2aab242d1b5fb Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Tue, 20 Aug 2024 15:52:15 +0200 Subject: [PATCH 12/17] Make code more readable --- OTAnalytics/plugin_datastore/track_store.py | 48 ++++++++++----------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/OTAnalytics/plugin_datastore/track_store.py b/OTAnalytics/plugin_datastore/track_store.py index c96d38bc9..5f0a8ab69 100644 --- a/OTAnalytics/plugin_datastore/track_store.py +++ b/OTAnalytics/plugin_datastore/track_store.py @@ -387,13 +387,14 @@ def _remove_from_geometry_dataset( return updated_dataset def as_list(self) -> list[Track]: - if (track_ids := self.get_index()) is None: - return [] - logger().warning( - "Creating track flyweight objects which is really slow in " - f"'{PandasTrackDataset.as_list.__name__}'." - ) - return [self._create_track_flyweight(current) for current in track_ids] + if (track_ids := self.get_index()) is not None: + logger().warning( + "Creating track flyweight objects which is really slow in " + f"'{PandasTrackDataset.as_list.__name__}'." 
+ ) + return [self._create_track_flyweight(current) for current in track_ids] + + return [] def _create_track_flyweight(self, track_id: str) -> Track: track_frame = self._dataset.loc[track_id, :] @@ -410,24 +411,23 @@ def split(self, batches: int) -> Sequence["PandasTrackDataset"]: dataset_size = len(self) batch_size = ceil(dataset_size / batches) - new_batches = [] - - if (track_ids := self.get_index()) is None: - return [self] - - for batch_ids in batched(track_ids, batch_size): - batch_ids_as_list = list(batch_ids) - batch_dataset = self._dataset.loc[batch_ids_as_list, :] - batch_geometries = self._get_geometries_for(batch_ids_as_list) - new_batches.append( - PandasTrackDataset.from_dataframe( - batch_dataset, - self.track_geometry_factory, - batch_geometries, - calculator=self.calculator, + if (track_ids := self.get_index()) is not None: + new_batches = [] + for batch_ids in batched(track_ids, batch_size): + batch_ids_as_list = list(batch_ids) + batch_dataset = self._dataset.loc[batch_ids_as_list, :] + batch_geometries = self._get_geometries_for(batch_ids_as_list) + new_batches.append( + PandasTrackDataset.from_dataframe( + batch_dataset, + self.track_geometry_factory, + batch_geometries, + calculator=self.calculator, + ) ) - ) - return new_batches + return new_batches + + return [self] def get_index(self) -> Index | None: if self._dataset.empty: From dbbeacc463fad7a67800bc2eb26905ad2b965559 Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Tue, 20 Aug 2024 16:31:36 +0200 Subject: [PATCH 13/17] Add update precommit script as custom precommit hook --- .pre-commit-config.yaml | 17 +++++++++++++---- update_precommit.py | 24 +++++++++--------------- 2 files changed, 22 insertions(+), 19 deletions(-) mode change 100644 => 100755 update_precommit.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c7e92856f..04c34a337 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -35,14 +35,14 @@ repos: - id: mypy entry: mypy OTAnalytics tests additional_dependencies: - - types-seaborn - - types-tqdm - - types-openpyxl - types-PyYAML - types-flake8 - - types-ujson + - types-openpyxl - types-pillow + - types-seaborn - types-shapely + - types-tqdm + - types-ujson always_run: true pass_filenames: false - repo: https://github.com/adrienverge/yamllint.git @@ -55,3 +55,12 @@ repos: rev: v0.10.0 hooks: - id: shellcheck + - repo: local + hooks: + - id: update-type-stubs + name: Check for Type Stubs and Update Config + entry: ./update_precommit.py + language: system + files: ^requirements.*\.txt$ + stages: + - commit diff --git a/update_precommit.py b/update_precommit.py old mode 100644 new mode 100755 index d3d1eaa85..58d29f8ac --- a/update_precommit.py +++ b/update_precommit.py @@ -1,4 +1,7 @@ +#!/usr/bin/env python3 + import re +from copy import deepcopy from pathlib import Path from typing import Iterable @@ -65,9 +68,7 @@ def retrieve_type_stubs(packages: Iterable[str]) -> list[str]: for package in packages: if check_type_stub_exists(package): type_stubs.append(f"types-{package}") - else: - print(f"No type stub found for package: {package}") - return type_stubs + return sorted(type_stubs) def read_precommit_file(precommit_file: Path) -> dict: @@ -77,7 +78,7 @@ def read_precommit_file(precommit_file: Path) -> dict: def update_precommit_config(config: dict, type_stubs: list[str]) -> dict: - updated_config = config.copy() + updated_config = deepcopy(config) for repo in updated_config[REPOSITORIES]: if repo[REPOSITORY] == 
MYPY_REPOSITORY: repo[HOOKS][0][ADDITIONAL_DEPENDENCIES] = type_stubs @@ -106,26 +107,19 @@ def display_available_type_stubs(type_stubs: list[str]) -> None: print("\n No type stubs to be added to your pre-commit configuration.") +def type_stubs_have_changed(actual: dict, to_compare: dict) -> bool: + return actual != to_compare + + def main() -> None: requirements_file = Path("requirements.txt") requirements_dev_file = Path("requirements-dev.txt") precommit_file = Path(".pre-commit-config.yaml") - print("Parsing requirements.txt and requirements-dev.txt...") packages = parse_multiple_requirements([requirements_file, requirements_dev_file]) - - print("Checking for type stubs...") type_stubs = retrieve_type_stubs(packages) - - display_available_type_stubs(type_stubs) - - print("Read pre-commit config...") precommit_config = read_precommit_file(precommit_file) - - print("Update pre-commit config...") updated_precommit_config = update_precommit_config(precommit_config, type_stubs) - - print("Save updated pre-commit config...") save_precommit_config(updated_precommit_config, precommit_file) From 4fdb58f40b2338bed5d96e1607be34f79bfdfb1f Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Tue, 20 Aug 2024 17:32:47 +0200 Subject: [PATCH 14/17] Check for correctly configured mypy before running mypy in pre-commit --- .pre-commit-config.yaml | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 04c34a337..4c1cbe6db 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,6 +14,15 @@ repos: - id: requirements-txt-fixer - id: check-executables-have-shebangs - id: detect-private-key + - repo: local + hooks: + - id: update-type-stubs + name: Check for Type Stubs and Update Config + entry: ./update_precommit.py + language: system + files: ^requirements.*\.txt$ + stages: + - commit - repo: https://github.com/PyCQA/flake8 rev: 7.1.1 hooks: @@ -55,12 +64,3 @@ repos: rev: v0.10.0 hooks: - id: shellcheck - - repo: local - hooks: - - id: update-type-stubs - name: Check for Type Stubs and Update Config - entry: ./update_precommit.py - language: system - files: ^requirements.*\.txt$ - stages: - - commit From 2f61750b7fccbfdb785110669088169bd0fcbc54 Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Fri, 23 Aug 2024 09:56:32 +0200 Subject: [PATCH 15/17] Update black formatter --- .pre-commit-config.yaml | 2 +- requirements-dev.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4c1cbe6db..962077d54 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -35,7 +35,7 @@ repos: - --profile - black - repo: https://github.com/psf/black - rev: 24.4.2 + rev: 24.8.0 hooks: - id: black - repo: https://github.com/pre-commit/mirrors-mypy diff --git a/requirements-dev.txt b/requirements-dev.txt index 3a27a2880..3ed094386 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,5 @@ -r requirements.txt -black==24.4.2 +black==24.8.0 flake8==7.1.1 hatch-requirements-txt==0.4.1 interrogate==1.7.0 From 7853a9c852c2ada274ea717f1d4ae44fe01fa9ea Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Thu, 12 Sep 2024 15:01:57 +0200 Subject: [PATCH 16/17] Add comment --- OTAnalytics/plugin_parser/track_export.py | 1 + 1 file changed, 1 insertion(+) diff --git a/OTAnalytics/plugin_parser/track_export.py 
b/OTAnalytics/plugin_parser/track_export.py index d55bf5c58..65d863337 100644 --- a/OTAnalytics/plugin_parser/track_export.py +++ b/OTAnalytics/plugin_parser/track_export.py @@ -48,6 +48,7 @@ def _get_data(self) -> DataFrame: track_classification = _track.classification for detection in _track.detections: current = detection.to_dict() + # Add missing track classification to detection dict current[track.TRACK_CLASSIFICATION] = track_classification detections.append(current) return DataFrame(detections) From f6b3187a18e46f729bf2238da78a0648e01367c9 Mon Sep 17 00:00:00 2001 From: Randy Seng <19281702+randy-seng@users.noreply.github.com> Date: Thu, 12 Sep 2024 15:29:22 +0200 Subject: [PATCH 17/17] Add comment explaining why we invalidate the geometry cache --- OTAnalytics/plugin_datastore/track_store.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/OTAnalytics/plugin_datastore/track_store.py b/OTAnalytics/plugin_datastore/track_store.py index 5f0a8ab69..4420f8166 100644 --- a/OTAnalytics/plugin_datastore/track_store.py +++ b/OTAnalytics/plugin_datastore/track_store.py @@ -678,6 +678,14 @@ def _get_dataset_with_classes(self, classes: list[str]) -> PandasTrackDataset: dataset = self._other.get_data() mask = dataset[track.TRACK_CLASSIFICATION].isin(classes) filtered_df = dataset[mask] + # The pandas Index does not implement the Sequence interface, which causes + # compatibility issues with the PandasTrackDataset._remove_from_geometry method + # when trying to remove geometries for tracks that have been deleted. + # To address this, we invalidate the entire geometry cache rather than + # attempting selective removal. + # This approach is acceptable because track removal only occurs when + # cutting tracks, which is a rare use case. + return PandasTrackDataset( track_geometry_factory=self._other.track_geometry_factory, dataset=filtered_df,
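
A minimal, self-contained sketch of the pandas indexing behaviour that
PATCH 04 guards against and PATCH 11 later handles explicitly: selecting
rows with a scalar label can squeeze the result from a DataFrame down to
a Series, while a list of labels always preserves the DataFrame. The
frame, index label and columns below are illustrative only, not taken
from the OTAnalytics code base.

    import pandas as pd

    # One-row frame with a plain string index, standing in for a track
    # dataset keyed by track id (illustrative data only).
    df = pd.DataFrame({"x": [1.0], "y": [2.0]}, index=["track-1"])

    row = df.loc["track-1", :]      # scalar label, single match -> Series
    frame = df.loc[["track-1"], :]  # list of labels -> always a DataFrame

    assert isinstance(row, pd.Series)
    assert isinstance(frame, pd.DataFrame)

PATCH 04 takes the list-of-labels route (`.loc[[track_id], :]`) to
guarantee a DataFrame, while PATCH 11 switches back to the scalar form
for speed and instead branches on the returned type in
`_create_track_flyweight`.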