diff --git a/game/flightplan/waypointsolver.py b/game/flightplan/waypointsolver.py new file mode 100644 index 0000000000..8bf46e85b9 --- /dev/null +++ b/game/flightplan/waypointsolver.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import json +from pathlib import Path +from typing import TYPE_CHECKING + +from dcs import Point + +if TYPE_CHECKING: + from .waypointstrategy import WaypointStrategy + + +class WaypointSolver: + def __init__(self) -> None: + self.strategies: list[WaypointStrategy] = [] + self.debug_output_directory: Path | None = None + + def add_strategy(self, strategy: WaypointStrategy) -> None: + self.strategies.append(strategy) + + def set_debug_output_directory(self, path: Path) -> None: + self.debug_output_directory = path + + def dump_debug_info(self) -> None: + path = self.debug_output_directory + if path is None: + return + + for idx, strategy in enumerate(self.strategies): + strategy_path = path / f"{idx}.json" + with strategy_path.open("w", encoding="utf-8") as strategy_debug_file: + json.dump( + { + "strategy_name": strategy.__class__.__name__, + "geojson": { + "type": "FeatureCollection", + "features": [ + d.to_geojson() for d in strategy.iter_debug_info() + ], + }, + }, + strategy_debug_file, + ) + + def solve(self) -> Point: + if not self.strategies: + raise ValueError( + "WaypointSolver.solve() called before any strategies were added" + ) + + for strategy in self.strategies: + if (point := strategy.find()) is not None: + return point + + self.dump_debug_info() + debug_details = "No debug output directory set" + if (debug_path := self.debug_output_directory) is not None: + debug_details = f"Debug details written to {debug_path}" + raise RuntimeError(f"No solutions found for waypoint. {debug_details}") diff --git a/game/flightplan/waypointstrategy.py b/game/flightplan/waypointstrategy.py new file mode 100644 index 0000000000..9516b4b66b --- /dev/null +++ b/game/flightplan/waypointstrategy.py @@ -0,0 +1,235 @@ +from __future__ import annotations + +import math +from abc import abstractmethod, ABC +from collections.abc import Iterator +from dataclasses import dataclass +from typing import Any + +from dcs.mapping import heading_between_points +from shapely import to_geojson +from shapely.geometry import Point, MultiPolygon, Polygon +from shapely.geometry.base import BaseGeometry as Geometry +from shapely.ops import nearest_points + +from game.utils import Distance, nautical_miles, Heading + + +def angle_between_points(a: Point, b: Point) -> float: + return heading_between_points(a.x, a.y, b.x, b.y) + + +def point_at_heading(p: Point, heading: Heading, distance: Distance) -> Point: + rad_heading = heading.radians + return Point( + p.x + math.cos(rad_heading) * distance.meters, + p.y + math.sin(rad_heading) * distance.meters, + ) + + +class Prerequisite(ABC): + @abstractmethod + def is_satisfied(self) -> bool: + ... + + +class DistancePrerequisite(Prerequisite): + def __init__(self, a: Point, b: Point, min_range: Distance) -> None: + self.a = a + self.b = b + self.min_range = min_range + + def is_satisfied(self) -> bool: + return self.a.distance(self.b) >= self.min_range.meters + + +class SafePrerequisite(Prerequisite): + def __init__(self, point: Point, threat_zones: MultiPolygon) -> None: + self.point = point + self.threat_zones = threat_zones + + def is_satisfied(self) -> bool: + return not self.point.intersects(self.threat_zones) + + +class PrerequisiteBuilder: + def __init__( + self, subject: Point, threat_zones: MultiPolygon, strategy: WaypointStrategy + ) -> None: + self.subject = subject + self.threat_zones = threat_zones + self.strategy = strategy + + def safe(self) -> None: + self.strategy.add_prerequisite( + SafePrerequisite(self.subject, self.threat_zones) + ) + + def min_distance_from(self, target: Point, distance: Distance) -> None: + self.strategy.add_prerequisite( + DistancePrerequisite(self.subject, target, distance) + ) + + +@dataclass(frozen=True) +class ThreatTolerance: + target: Point + target_buffer: Distance + tolerance: Distance + + +class RequirementBuilder: + def __init__(self, threat_zones: MultiPolygon, strategy: WaypointStrategy) -> None: + self.threat_zones = threat_zones + self.strategy = strategy + + def safe(self) -> None: + self.strategy.exclude_threat_zone() + + def at_least(self, distance: Distance) -> DistanceRequirementBuilder: + return DistanceRequirementBuilder(self.strategy, min_distance=distance) + + def at_most(self, distance: Distance) -> DistanceRequirementBuilder: + return DistanceRequirementBuilder(self.strategy, max_distance=distance) + + def maximum_turn_to( + self, turn_point: Point, next_point: Point, turn_limit: Heading + ) -> None: + + large_distance = nautical_miles(400) + next_heading = Heading.from_degrees( + angle_between_points(next_point, turn_point) + ) + limit_ccw = point_at_heading( + turn_point, next_heading - turn_limit, large_distance + ) + limit_cw = point_at_heading( + turn_point, next_heading + turn_limit, large_distance + ) + + allowed_wedge = Polygon([turn_point, limit_ccw, limit_cw]) + self.strategy.exclude( + f"restrict turn from {turn_point} to {next_point} to {turn_limit}", + turn_point.buffer(large_distance.meters).difference(allowed_wedge), + ) + + +class DistanceRequirementBuilder: + def __init__( + self, + strategy: WaypointStrategy, + min_distance: Distance | None = None, + max_distance: Distance | None = None, + ) -> None: + if min_distance is None and max_distance is None: + raise ValueError + self.strategy = strategy + self.min_distance = min_distance + self.max_distance = max_distance + + def away_from(self, target: Point) -> None: + if self.min_distance is not None: + self.strategy.exclude( + f"at least {self.min_distance} away from {target}", + target.buffer(self.min_distance.meters), + ) + if self.max_distance is not None: + self.strategy.exclude_beyond( + f"at most {self.min_distance} away from {target}", + target.buffer(self.max_distance.meters), + ) + + +@dataclass(frozen=True) +class WaypointDebugInfo: + description: str + geometry: Geometry + + def to_geojson(self) -> dict[str, Any]: + return { + "type": "Feature", + "properties": { + "description": self.description, + }, + "geometry": to_geojson(self.geometry), + } + + +class WaypointStrategy: + def __init__(self, threat_zones: MultiPolygon) -> None: + self.threat_zones = threat_zones + self.prerequisites: list[Prerequisite] = [] + self.allowed_area: Polygon = Point(0, 0).buffer(1_000_000) + self.debug_infos: list[WaypointDebugInfo] = [] + self._threat_tolerance: ThreatTolerance | None = None + self.point_for_nearest_solution: Point | None = None + + def add_prerequisite(self, prerequisite: Prerequisite) -> None: + self.prerequisites.append(prerequisite) + + def prerequisite(self, subject: Point) -> PrerequisiteBuilder: + return PrerequisiteBuilder(subject, self.threat_zones, self) + + def exclude(self, description: str, geometry: Geometry) -> None: + self.debug_infos.append(WaypointDebugInfo(description, geometry)) + self.allowed_area = self.allowed_area.difference(geometry) + + def exclude_beyond(self, description: str, geometry: Geometry) -> None: + self.exclude(description, self.allowed_area.difference(geometry)) + + def exclude_threat_zone(self) -> None: + if (tolerance := self._threat_tolerance) is not None: + description = ( + f"safe with a {tolerance.tolerance} tolerance to a " + f"{tolerance.target_buffer} radius about {tolerance.target}" + ) + else: + description = "safe" + self.exclude(description, self.threat_zones) + + def prerequisites_are_satisfied(self) -> bool: + for prereq in self.prerequisites: + if not prereq.is_satisfied(): + return False + return True + + def require(self) -> RequirementBuilder: + return RequirementBuilder(self.threat_zones, self) + + def threat_tolerance( + self, target: Point, target_size: Distance, wiggle: Distance + ) -> None: + if self.threat_zones.is_empty: + return + + min_distance_from_threat_to_target_buffer = target.buffer( + target_size.meters + ).distance(self.threat_zones.boundary) + threat_mask = self.threat_zones.buffer( + -min_distance_from_threat_to_target_buffer - wiggle.meters + ) + self._threat_tolerance = ThreatTolerance(target, target_size, wiggle) + self.threat_zones = self.threat_zones.difference(threat_mask) + + def nearest(self, point: Point) -> None: + if self.point_for_nearest_solution is not None: + raise RuntimeError("WaypointStrategy.nearest() called more than once") + self.point_for_nearest_solution = point + + def find(self) -> Point | None: + if self.point_for_nearest_solution is None: + raise RuntimeError( + "Must call WaypointStrategy.nearest() before WaypointStrategy.find()" + ) + + if not self.prerequisites_are_satisfied(): + return None + + try: + return nearest_points(self.allowed_area, self.point_for_nearest_solution)[0] + except ValueError: + # No solutions. + return None + + def iter_debug_info(self) -> Iterator[WaypointDebugInfo]: + yield from self.debug_infos diff --git a/tests/flightplan/test_waypointsolver.py b/tests/flightplan/test_waypointsolver.py new file mode 100644 index 0000000000..fe1718262a --- /dev/null +++ b/tests/flightplan/test_waypointsolver.py @@ -0,0 +1,117 @@ +import json +from pathlib import Path + +import pytest +from shapely.geometry import Point, MultiPolygon + +from game.flightplan.waypointsolver import WaypointSolver +from game.flightplan.waypointstrategy import WaypointStrategy + + +class NoSolutionsStrategy(WaypointStrategy): + def __init__(self) -> None: + super().__init__(MultiPolygon([])) + + def find(self) -> Point | None: + return None + + +class PointStrategy(WaypointStrategy): + def __init__(self, x: float, y: float) -> None: + super().__init__(MultiPolygon([])) + self.point = Point(x, y) + + def find(self) -> Point | None: + return self.point + + +class OriginStrategy(PointStrategy): + def __init__(self) -> None: + super().__init__(0, 0) + + +class DebuggableStrategy(NoSolutionsStrategy): + def __init__(self, distance_factor: int) -> None: + super().__init__() + center = Point(0, 0) + self.exclude("foo", center.buffer(1 * distance_factor)) + self.exclude( + "bar", + center.buffer(3 * distance_factor).difference( + center.buffer(2 * distance_factor) + ), + ) + + +def test_solver_tries_strategies_in_order() -> None: + solver = WaypointSolver() + solver.add_strategy(OriginStrategy()) + solver.add_strategy(PointStrategy(1, 1)) + assert solver.solve() == Point(0, 0) + + +def test_individual_failed_strategies_do_not_fail_solver() -> None: + solver = WaypointSolver() + solver.add_strategy(NoSolutionsStrategy()) + solver.add_strategy(OriginStrategy()) + assert solver.solve() == Point(0, 0) + + +def test_no_solutions_raises() -> None: + solver = WaypointSolver() + solver.add_strategy(NoSolutionsStrategy()) + with pytest.raises(RuntimeError): + solver.solve() + + +def test_no_strategies_raises() -> None: + solver = WaypointSolver() + with pytest.raises(ValueError): + solver.solve() + + +def test_success_does_not_dump_debug_info(tmp_path: Path) -> None: + solver = WaypointSolver() + solver.set_debug_output_directory(tmp_path) + solver.add_strategy(OriginStrategy()) + solver.solve() + assert not list(tmp_path.iterdir()) + + +def test_no_solutions_dumps_debug_info(tmp_path: Path) -> None: + solver = WaypointSolver() + solver.set_debug_output_directory(tmp_path) + strategy_0 = DebuggableStrategy(distance_factor=1) + strategy_1 = DebuggableStrategy(distance_factor=2) + solver.add_strategy(strategy_0) + solver.add_strategy(strategy_1) + with pytest.raises(RuntimeError): + solver.solve() + + strategy_0_path = Path(tmp_path / "0.json") + strategy_1_path = Path(tmp_path / "1.json") + assert set(tmp_path.iterdir()) == {strategy_0_path, strategy_1_path} + + with strategy_0_path.open("r", encoding="utf-8") as metadata_file: + data = json.load(metadata_file) + assert data["strategy_name"] == "DebuggableStrategy" + assert len(data["geojson"]) == 2 + assert len(data.keys()) == 2 + feature_collection = data["geojson"] + assert feature_collection["type"] == "FeatureCollection" + features = feature_collection["features"] + assert len(features) == 2 + for debug_info, feature in zip(strategy_0.iter_debug_info(), features): + assert debug_info.to_geojson() == feature + + with strategy_1_path.open("r", encoding="utf-8") as metadata_file: + data = json.load(metadata_file) + assert data["strategy_name"] == "DebuggableStrategy" + assert len(data["geojson"]) == 2 + assert len(data.keys()) == 2 + feature_collection = data["geojson"] + assert feature_collection["type"] == "FeatureCollection" + features = feature_collection["features"] + assert len(features) == 2 + for debug_info, feature in zip(strategy_1.iter_debug_info(), features): + assert debug_info.to_geojson() == feature diff --git a/tests/flightplan/test_waypointstrategy.py b/tests/flightplan/test_waypointstrategy.py new file mode 100644 index 0000000000..af569c5e21 --- /dev/null +++ b/tests/flightplan/test_waypointstrategy.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest +from dcs.terrain import Terrain, Caucasus +from pytest import approx +from shapely.geometry import Point, MultiPolygon + +from game.flightplan.waypointstrategy import WaypointStrategy, angle_between_points +from game.utils import meters, Heading + + +@pytest.fixture(name="terrain") +def terrain_fixture() -> Terrain: + return Caucasus() + + +def test_safe_prerequisite_safe_point() -> None: + strategy = WaypointStrategy(MultiPolygon([])) + strategy.prerequisite(Point(0, 0)).safe() + assert strategy.prerequisites_are_satisfied() + + +def test_safe_prerequisite_unsafe_point() -> None: + strategy = WaypointStrategy(MultiPolygon([Point(0, 0).buffer(1)])) + strategy.prerequisite(Point(0, 0)).safe() + assert not strategy.prerequisites_are_satisfied() + + +def test_no_solution_if_prerequisites_failed() -> None: + strategy = WaypointStrategy(MultiPolygon([Point(0, 0).buffer(1)])) + strategy.prerequisite(Point(0, 0)).safe() + strategy.nearest(Point(0, 0)) + assert strategy.find() is None + + +def test_has_solution_if_prerequisites_satisfied() -> None: + strategy = WaypointStrategy(MultiPolygon([])) + strategy.prerequisite(Point(0, 0)).safe() + strategy.nearest(Point(0, 0)) + assert strategy.find() is not None + + +def test_require_nearest() -> None: + strategy = WaypointStrategy(MultiPolygon([])) + center = Point(0, 0) + strategy.nearest(center) + assert strategy.find() == center + + +def test_find_without_nearest_raises() -> None: + with pytest.raises(RuntimeError): + WaypointStrategy(MultiPolygon([])).find() + + +def test_multiple_nearest_raises() -> None: + strategy = WaypointStrategy(MultiPolygon([])) + strategy.nearest(Point(0, 0)) + with pytest.raises(RuntimeError): + strategy.nearest(Point(0, 0)) + + +def test_require_at_least() -> None: + strategy = WaypointStrategy(MultiPolygon([])) + center = Point(0, 0) + strategy.require().at_least(meters(10)).away_from(center) + strategy.nearest(center) + solution = strategy.find() + assert solution is not None + assert solution.distance(center) == approx(10, 0.1) + + +def test_require_at_most() -> None: + strategy = WaypointStrategy(MultiPolygon([])) + center = Point(0, 0) + strategy.require().at_most(meters(1)).away_from(center) + strategy.nearest(Point(10, 0)) + solution = strategy.find() + assert solution is not None + assert solution.distance(center) <= 1 + + +def test_require_safe() -> None: + threat = MultiPolygon([Point(0, 0).buffer(10)]) + strategy = WaypointStrategy(threat) + strategy.require().safe() + strategy.nearest(Point(0, 0)) + solution = strategy.find() + assert solution is not None + assert not solution.intersects(threat) + + +def test_require_maximum_turn_to() -> None: + strategy = WaypointStrategy(MultiPolygon([])) + turn_point = Point(1, 0) + turn_target = Point(0, 0) + strategy.require().maximum_turn_to(turn_point, turn_target, Heading(90)) + strategy.nearest(Point(0, 1)) + pre_turn_heading = Heading.from_degrees( + angle_between_points(strategy.find(), turn_point) + ) + post_turn_heading = Heading.from_degrees( + angle_between_points(turn_point, turn_target) + ) + assert pre_turn_heading.angle_between(post_turn_heading) <= Heading(90) + + +def test_combined_constraints() -> None: + strategy = WaypointStrategy(MultiPolygon([])) + center = Point(0, 0) + offset = Point(1, 0) + midpoint = Point(0.5, 0) + strategy.require().at_least(meters(1)).away_from(center) + strategy.require().at_least(meters(1)).away_from(offset) + strategy.nearest(midpoint) + solution = strategy.find() + assert solution is not None + assert solution.distance(center) == approx(1, rel=0.1, abs=0.1) + assert solution.distance(offset) == approx(1, rel=0.1, abs=0.1) + assert solution.distance(midpoint) < 1 + + +def test_threat_tolerance(tmp_path: Path) -> None: + home = Point(20, 0) + target = Point(-1, 0) + max_distance = meters(5) + threat = MultiPolygon([Point(0, 0).buffer(10)]) + strategy = WaypointStrategy(threat) + strategy.require().at_most(max_distance).away_from(target) + strategy.threat_tolerance(target, max_distance, meters(1)) + strategy.require().safe() + strategy.nearest(home) + solution = strategy.find() + assert solution is not None + # Max distance of 5 from -1, so the point should be at 4. Home is at 20. + assert solution.distance(home) == 16 + + +def test_threat_tolerance_does_nothing_if_no_threats(tmp_path: Path) -> None: + strategy = WaypointStrategy(MultiPolygon([])) + strategy.threat_tolerance(Point(0, 0), meters(1), meters(1)) + assert strategy._threat_tolerance is None + + +def test_no_solutions() -> None: + strategy = WaypointStrategy(MultiPolygon([])) + strategy.require().at_most(meters(1)).away_from(Point(0, 0)) + strategy.require().at_least(meters(2)).away_from(Point(0, 0)) + strategy.nearest(Point(0, 0)) + assert strategy.find() is None + + +def test_debug() -> None: + center = Point(0, 0) + threat = MultiPolygon([center.buffer(5)]) + strategy = WaypointStrategy(threat) + strategy.require().at_most(meters(10)).away_from(center) + strategy.require().at_least(meters(2)).away_from(center) + strategy.require().safe() + strategy.nearest(center) + debug_info = list(strategy.iter_debug_info()) + assert len(debug_info) == 3 + max_distance_debug, min_distance_debug, safe_debug = debug_info + assert max_distance_debug.description == "at most None away from POINT (0 0)" + assert max_distance_debug.geometry.distance(center) == approx(10, 0.1) + assert ( + min_distance_debug.description + == "at least Distance(distance_in_meters=2) away from POINT (0 0)" + ) + assert max_distance_debug.geometry.boundary.distance(center) == approx(10, 0.1) + assert safe_debug.description == "safe" + assert safe_debug.geometry == threat