Skip to content

Commit

Permalink
Add type-hints to adaptive/learner/triangulation.py
Browse files Browse the repository at this point in the history
  • Loading branch information
basnijholt committed Oct 12, 2022
1 parent 157574f commit 4bc151f
Showing 1 changed file with 96 additions and 48 deletions.
144 changes: 96 additions & 48 deletions adaptive/learner/triangulation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from __future__ import annotations

import collections.abc
import numbers
from collections import Counter
from collections.abc import Iterable, Sized
from itertools import chain, combinations
from math import factorial, sqrt
from typing import Any, Iterable, Iterator, List, Sequence, Tuple, Union

import scipy.spatial
from numpy import abs as np_abs
Expand All @@ -13,6 +17,7 @@
dot,
eye,
mean,
ndarray,
ones,
square,
subtract,
Expand All @@ -22,8 +27,22 @@
from numpy.linalg import det as ndet
from numpy.linalg import matrix_rank, norm, slogdet, solve

from adaptive.types import Bool

try:
from typing import TypeAlias
except ImportError:
# Remove this when we drop support for Python 3.9
from typing_extensions import TypeAlias


def fast_norm(v):
SimplexPoints: TypeAlias = Union[List[Tuple[float, ...]], ndarray]
Simplex: TypeAlias = Union[Sequence[numbers.Integral], ndarray]
Point: TypeAlias = Union[Tuple[float, ...], ndarray]
Points: TypeAlias = Union[Sequence[Tuple[float, ...]], ndarray]


def fast_norm(v: tuple[float, ...] | ndarray) -> float:
"""Take the vector norm for len 2, 3 vectors.
Defaults to a square root of the dot product for larger vectors.
Expand All @@ -41,7 +60,9 @@ def fast_norm(v):
return sqrt(dot(v, v))


def fast_2d_point_in_simplex(point, simplex, eps=1e-8):
def fast_2d_point_in_simplex(
point: Point, simplex: SimplexPoints, eps: float = 1e-8
) -> Bool:
(p0x, p0y), (p1x, p1y), (p2x, p2y) = simplex
px, py = point

Expand All @@ -55,7 +76,7 @@ def fast_2d_point_in_simplex(point, simplex, eps=1e-8):
return (t >= -eps) and (s + t <= 1 + eps)


def point_in_simplex(point, simplex, eps=1e-8):
def point_in_simplex(point: Point, simplex: SimplexPoints, eps: float = 1e-8) -> Bool:
if len(point) == 2:
return fast_2d_point_in_simplex(point, simplex, eps)

Expand All @@ -66,7 +87,7 @@ def point_in_simplex(point, simplex, eps=1e-8):
return all(alpha > -eps) and sum(alpha) < 1 + eps


def fast_2d_circumcircle(points):
def fast_2d_circumcircle(points: Points) -> tuple[tuple[float, float], float]:
"""Compute the center and radius of the circumscribed circle of a triangle
Parameters
Expand All @@ -79,7 +100,7 @@ def fast_2d_circumcircle(points):
tuple
(center point : tuple(float), radius: float)
"""
points = array(points)
points = array(points, dtype=float)
# transform to relative coordinates
pts = points[1:] - points[0]

Expand All @@ -102,7 +123,9 @@ def fast_2d_circumcircle(points):
return (x + points[0][0], y + points[0][1]), radius


def fast_3d_circumcircle(points):
def fast_3d_circumcircle(
points: Points,
) -> tuple[tuple[float, float, float], float]:
"""Compute the center and radius of the circumscribed sphere of a simplex.
Parameters
Expand Down Expand Up @@ -142,7 +165,7 @@ def fast_3d_circumcircle(points):
return center, radius


def fast_det(matrix):
def fast_det(matrix: ndarray) -> float:
matrix = asarray(matrix, dtype=float)
if matrix.shape == (2, 2):
return matrix[0][0] * matrix[1][1] - matrix[1][0] * matrix[0][1]
Expand All @@ -153,7 +176,7 @@ def fast_det(matrix):
return ndet(matrix)


def circumsphere(pts):
def circumsphere(pts: Simplex) -> tuple[tuple[float, ...], float]:
"""Compute the center and radius of a N dimension sphere which touches each point in pts.
Parameters
Expand Down Expand Up @@ -201,7 +224,7 @@ def circumsphere(pts):
return tuple(center), radius


def orientation(face, origin):
def orientation(face: tuple | ndarray, origin: tuple | ndarray) -> int:
"""Compute the orientation of the face with respect to a point, origin.
Parameters
Expand All @@ -224,14 +247,14 @@ def orientation(face, origin):
sign, logdet = slogdet(vectors - origin)
if logdet < -50: # assume it to be zero when it's close to zero
return 0
return sign
return int(sign)


def is_iterable_and_sized(obj):
return isinstance(obj, Iterable) and isinstance(obj, Sized)
def is_iterable_and_sized(obj: Any) -> bool:
return isinstance(obj, collections.abc.Collection)


def simplex_volume_in_embedding(vertices) -> float:
def simplex_volume_in_embedding(vertices: Sequence[Point]) -> float:
"""Calculate the volume of a simplex in a higher dimensional embedding.
That is: dim > len(vertices) - 1. For example if you would like to know the
surface area of a triangle in a 3d space.
Expand Down Expand Up @@ -312,7 +335,7 @@ class Triangulation:
or more simplices in the
"""

def __init__(self, coords):
def __init__(self, coords: Points) -> None:
if not is_iterable_and_sized(coords):
raise TypeError("Please provide a 2-dimensional list of points")
coords = list(coords)
Expand Down Expand Up @@ -340,38 +363,40 @@ def __init__(self, coords):
"(the points are linearly dependent)"
)

self.vertices = list(coords)
self.simplices = set()
self.vertices: list[Point] = list(coords)
self.simplices: set[Simplex] = set()
# initialise empty set for each vertex
self.vertex_to_simplices = [set() for _ in coords]
self.vertex_to_simplices: list[set[Simplex]] = [set() for _ in coords]

# find a Delaunay triangulation to start with, then we will throw it
# away and continue with our own algorithm
initial_tri = scipy.spatial.Delaunay(coords)
for simplex in initial_tri.simplices:
self.add_simplex(simplex)

def delete_simplex(self, simplex):
def delete_simplex(self, simplex: Simplex) -> None:
simplex = tuple(sorted(simplex))
self.simplices.remove(simplex)
for vertex in simplex:
self.vertex_to_simplices[vertex].remove(simplex)

def add_simplex(self, simplex):
def add_simplex(self, simplex: Simplex) -> None:
simplex = tuple(sorted(simplex))
self.simplices.add(simplex)
for vertex in simplex:
self.vertex_to_simplices[vertex].add(simplex)

def get_vertices(self, indices):
def get_vertices(self, indices: Iterable[numbers.Integral]) -> list[Point | None]:
return [self.get_vertex(i) for i in indices]

def get_vertex(self, index):
def get_vertex(self, index: numbers.Integral | None) -> Point | None:
if index is None:
return None
return self.vertices[index]

def get_reduced_simplex(self, point, simplex, eps=1e-8) -> list:
def get_reduced_simplex(
self, point: Point, simplex: Simplex, eps: float = 1e-8
) -> list[numbers.Integral]:
"""Check whether vertex lies within a simplex.
Returns
Expand All @@ -396,11 +421,13 @@ def get_reduced_simplex(self, point, simplex, eps=1e-8) -> list:

return [simplex[i] for i in result]

def point_in_simplex(self, point, simplex, eps=1e-8):
def point_in_simplex(
self, point: Point, simplex: Simplex, eps: float = 1e-8
) -> Bool:
vertices = self.get_vertices(simplex)
return point_in_simplex(point, vertices, eps)

def locate_point(self, point):
def locate_point(self, point: Point) -> Simplex:
"""Find to which simplex the point belongs.
Return indices of the simplex containing the point.
Expand All @@ -412,10 +439,15 @@ def locate_point(self, point):
return ()

@property
def dim(self):
def dim(self) -> int:
return len(self.vertices[0])

def faces(self, dim=None, simplices=None, vertices=None):
def faces(
self,
dim: int | None = None,
simplices: Iterable[Simplex] | None = None,
vertices: Iterable[int] | None = None,
) -> Iterator[tuple[numbers.Integral, ...]]:
"""Iterator over faces of a simplex or vertex sequence."""
if dim is None:
dim = self.dim
Expand All @@ -436,11 +468,11 @@ def faces(self, dim=None, simplices=None, vertices=None):
else:
return faces

def containing(self, face):
def containing(self, face: tuple[int, ...]) -> set[Simplex]:
"""Simplices containing a face."""
return set.intersection(*(self.vertex_to_simplices[i] for i in face))

def _extend_hull(self, new_vertex, eps=1e-8):
def _extend_hull(self, new_vertex: Point, eps: float = 1e-8) -> set[Simplex]:
# count multiplicities in order to get all hull faces
multiplicities = Counter(face for face in self.faces())
hull_faces = [face for face, count in multiplicities.items() if count == 1]
Expand Down Expand Up @@ -480,7 +512,9 @@ def _extend_hull(self, new_vertex, eps=1e-8):

return new_simplices

def circumscribed_circle(self, simplex, transform):
def circumscribed_circle(
self, simplex: Simplex, transform: ndarray
) -> tuple[tuple[float, ...], float]:
"""Compute the center and radius of the circumscribed circle of a simplex.
Parameters
Expand All @@ -496,7 +530,9 @@ def circumscribed_circle(self, simplex, transform):
pts = dot(self.get_vertices(simplex), transform)
return circumsphere(pts)

def point_in_cicumcircle(self, pt_index, simplex, transform):
def point_in_cicumcircle(
self, pt_index: int, simplex: Simplex, transform: ndarray
) -> Bool:
# return self.fast_point_in_circumcircle(pt_index, simplex, transform)
eps = 1e-8

Expand All @@ -506,10 +542,15 @@ def point_in_cicumcircle(self, pt_index, simplex, transform):
return norm(center - pt) < (radius * (1 + eps))

@property
def default_transform(self):
def default_transform(self) -> ndarray:
return eye(self.dim)

def bowyer_watson(self, pt_index, containing_simplex=None, transform=None):
def bowyer_watson(
self,
pt_index: int,
containing_simplex: Simplex | None = None,
transform: ndarray | None = None,
) -> tuple[set[Simplex], set[Simplex]]:
"""Modified Bowyer-Watson point adding algorithm.
Create a hole in the triangulation around the new point,
Expand Down Expand Up @@ -569,10 +610,10 @@ def bowyer_watson(self, pt_index, containing_simplex=None, transform=None):
new_triangles = self.vertex_to_simplices[pt_index]
return bad_triangles - new_triangles, new_triangles - bad_triangles

def _simplex_is_almost_flat(self, simplex):
def _simplex_is_almost_flat(self, simplex: Simplex) -> Bool:
return self._relative_volume(simplex) < 1e-8

def _relative_volume(self, simplex):
def _relative_volume(self, simplex: Simplex) -> float:
"""Compute the volume of a simplex divided by the average (Manhattan)
distance of its vertices. The advantage of this is that the relative
volume is only dependent on the shape of the simplex and not on the
Expand All @@ -583,20 +624,25 @@ def _relative_volume(self, simplex):
average_edge_length = mean(np_abs(vectors))
return self.volume(simplex) / (average_edge_length**self.dim)

def add_point(self, point, simplex=None, transform=None):
def add_point(
self,
point: Point,
simplex: Simplex | None = None,
transform: ndarray | None = None,
) -> tuple[set[Simplex], set[Simplex]]:
"""Add a new vertex and create simplices as appropriate.
Parameters
----------
point : float vector
Coordinates of the point to be added.
transform : N*N matrix of floats
Multiplication matrix to apply to the point (and neighbouring
simplices) when running the Bowyer Watson method.
simplex : tuple of ints, optional
Simplex containing the point. Empty tuple indicates points outside
the hull. If not provided, the algorithm costs O(N), so this should
be used whenever possible.
transform : N*N matrix of floats
Multiplication matrix to apply to the point (and neighbouring
simplices) when running the Bowyer Watson method.
"""
point = tuple(point)
if simplex is None:
Expand Down Expand Up @@ -632,16 +678,16 @@ def add_point(self, point, simplex=None, transform=None):
self.vertices.append(point)
return self.bowyer_watson(pt_index, actual_simplex, transform)

def volume(self, simplex):
def volume(self, simplex: Simplex) -> float:
prefactor = factorial(self.dim)
vertices = array(self.get_vertices(simplex))
vectors = vertices[1:] - vertices[0]
return float(abs(fast_det(vectors)) / prefactor)

def volumes(self):
def volumes(self) -> list[float]:
return [self.volume(sim) for sim in self.simplices]

def reference_invariant(self):
def reference_invariant(self) -> bool:
"""vertex_to_simplices and simplices are compatible."""
for vertex in range(len(self.vertices)):
if any(vertex not in tri for tri in self.vertex_to_simplices[vertex]):
Expand All @@ -655,26 +701,28 @@ def vertex_invariant(self, vertex):
"""Simplices originating from a vertex don't overlap."""
raise NotImplementedError

def get_neighbors_from_vertices(self, simplex):
def get_neighbors_from_vertices(self, simplex: Simplex) -> set[Simplex]:
return set.union(*[self.vertex_to_simplices[p] for p in simplex])

def get_face_sharing_neighbors(self, neighbors, simplex):
def get_face_sharing_neighbors(
self, neighbors: set[Simplex], simplex: Simplex
) -> set[Simplex]:
"""Keep only the simplices sharing a whole face with simplex."""
return {
simpl for simpl in neighbors if len(set(simpl) & set(simplex)) == self.dim
} # they share a face

def get_simplices_attached_to_points(self, indices):
def get_simplices_attached_to_points(self, indices: Simplex) -> set[Simplex]:
# Get all simplices that share at least a point with the simplex
neighbors = self.get_neighbors_from_vertices(indices)
return self.get_face_sharing_neighbors(neighbors, indices)

def get_opposing_vertices(self, simplex):
def get_opposing_vertices(self, simplex: Simplex) -> tuple[int, ...]:
if simplex not in self.simplices:
raise ValueError("Provided simplex is not part of the triangulation")
neighbors = self.get_simplices_attached_to_points(simplex)

def find_opposing_vertex(vertex):
def find_opposing_vertex(vertex: int):
# find the simplex:
simp = next((x for x in neighbors if vertex not in x), None)
if simp is None:
Expand All @@ -687,7 +735,7 @@ def find_opposing_vertex(vertex):
return result

@property
def hull(self):
def hull(self) -> set[numbers.Integral]:
"""Compute hull from triangulation.
Parameters
Expand Down

0 comments on commit 4bc151f

Please sign in to comment.