From cd59fd75ef4709c076caed2c18fc8e8d7b4180b6 Mon Sep 17 00:00:00 2001 From: Adam Reichold Date: Sun, 30 Jul 2023 09:29:09 +0200 Subject: [PATCH] Make coordinate types generic within reason, i.e. restricted to floating-point values where appropriate. --- Cargo.toml | 3 ++- src/lib.rs | 39 ++++++++++++++++++++++++++------------- src/look_up.rs | 50 ++++++++++++++++++++++++++++++-------------------- src/nearest.rs | 8 ++++++-- 4 files changed, 64 insertions(+), 36 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 47ff530..fcbc171 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "sif-kdtree" description = "simple, immutable, flat k-d tree" -version = "0.2.0" +version = "0.3.0" edition = "2018" rust-version = "1.55" authors = ["Adam Reichold "] @@ -13,6 +13,7 @@ keywords = ["kdtree"] categories = ["data-structures", "simulation", "science::geo"] [dependencies] +num-traits = "0.2" rayon = { version = "1.7", optional = true } serde = { version = "1.0", features = ["derive"], optional = true } diff --git a/src/lib.rs b/src/lib.rs index 6d5cf62..32dca8f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -109,33 +109,46 @@ pub use look_up::{Query, WithinBoundingBox, WithinDistance}; use std::marker::PhantomData; use std::ops::Deref; +use num_traits::Num; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; -/// Defines a [finite-dimensional][Self::DIM] real space in terms of [coordinate values][Self::coord] along a chosen set of axes +/// Defines a [finite-dimensional][Self::DIM] space in terms of [coordinate values][Self::coord] along a chosen set of axes pub trait Point { - /// The dimension of the underlying real space + /// The dimension of the underlying space const DIM: usize; + /// The type of the coordinate values + type Coord: Num + Copy + PartialOrd; + /// Access the coordinate value of the point along the given `axis` - fn coord(&self, axis: usize) -> f64; + fn coord(&self, axis: usize) -> Self::Coord; /// Return the squared distance between `self` and `other`. /// /// This is called during nearest neighbour search and hence only the relation between two distance values is required so that computing square roots can be avoided. - fn distance_2(&self, other: &Self) -> f64; + fn distance_2(&self, other: &Self) -> Self::Coord; } -/// `N`-dimensional real space using [Euclidean distance](https://en.wikipedia.org/wiki/Euclidean_distance) -impl Point for [f64; N] { +/// `N`-dimensional space using [Euclidean distance](https://en.wikipedia.org/wiki/Euclidean_distance) +impl Point for [T; N] +where + T: Num + Copy + PartialOrd, +{ const DIM: usize = N; - fn coord(&self, axis: usize) -> f64 { + type Coord = T; + + fn coord(&self, axis: usize) -> Self::Coord { self[axis] } - fn distance_2(&self, other: &Self) -> f64 { - (0..N).map(|axis| (self[axis] - other[axis]).powi(2)).sum() + fn distance_2(&self, other: &Self) -> Self::Coord { + (0..N).fold(T::zero(), |res, axis| { + let diff = self[axis] - other[axis]; + + res + diff * diff + }) } } @@ -234,13 +247,13 @@ mod tests { use proptest::{collection::vec, strategy::Strategy}; - pub fn random_points(len: usize) -> impl Strategy> { - (vec(0.0..=1.0, len), vec(0.0..=1.0, len)) + pub fn random_points(len: usize) -> impl Strategy> { + (vec(0.0_f32..=1.0, len), vec(0.0_f32..=1.0, len)) .prop_map(|(x, y)| x.into_iter().zip(y).map(|(x, y)| [x, y]).collect()) } #[derive(Debug, PartialEq)] - pub struct RandomObject(pub [f64; 2]); + pub struct RandomObject(pub [f32; 2]); impl Eq for RandomObject {} @@ -257,7 +270,7 @@ mod tests { } impl Object for RandomObject { - type Point = [f64; 2]; + type Point = [f32; 2]; fn position(&self) -> &Self::Point { &self.0 diff --git a/src/look_up.rs b/src/look_up.rs index 8c69ec3..5207573 100644 --- a/src/look_up.rs +++ b/src/look_up.rs @@ -1,5 +1,6 @@ use std::ops::ControlFlow; +use num_traits::Num; #[cfg(feature = "rayon")] use rayon::join; @@ -23,59 +24,68 @@ pub trait Query { fn test(&self, position: &P) -> bool; } -/// A query which yields all objects within a given axis-aligned boundary box (AABB) in `N`-dimensional real space +/// A query which yields all objects within a given axis-aligned boundary box (AABB) in `N`-dimensional space #[derive(Debug)] -pub struct WithinBoundingBox { - aabb: ([f64; N], [f64; N]), +pub struct WithinBoundingBox { + aabb: ([T; N], [T; N]), } -impl WithinBoundingBox { +impl WithinBoundingBox { /// Construct a query from first the corner smallest coordinate values `lower` and then the corner with the largest coordinate values `upper` - pub fn new(lower: [f64; N], upper: [f64; N]) -> Self { + pub fn new(lower: [T; N], upper: [T; N]) -> Self { Self { aabb: (lower, upper), } } } -impl Query<[f64; N]> for WithinBoundingBox { - fn aabb(&self) -> &([f64; N], [f64; N]) { +impl Query<[T; N]> for WithinBoundingBox +where + T: Num + Copy + PartialOrd, +{ + fn aabb(&self) -> &([T; N], [T; N]) { &self.aabb } - fn test(&self, _position: &[f64; N]) -> bool { + fn test(&self, _position: &[T; N]) -> bool { true } } /// A query which yields all objects within a given distance to a central point in `N`-dimensional real space #[derive(Debug)] -pub struct WithinDistance { - aabb: ([f64; N], [f64; N]), - center: [f64; N], - distance_2: f64, +pub struct WithinDistance { + aabb: ([T; N], [T; N]), + center: [T; N], + distance_2: T, } -impl WithinDistance { +impl WithinDistance +where + T: Num + Copy + PartialOrd, +{ /// Construct a query from the `center` and the largest allowed Euclidean `distance` to it - pub fn new(center: [f64; N], distance: f64) -> Self { + pub fn new(center: [T; N], distance: T) -> Self { Self { aabb: ( center.map(|coord| coord - distance), center.map(|coord| coord + distance), ), center, - distance_2: distance.powi(2), + distance_2: distance * distance, } } } -impl Query<[f64; N]> for WithinDistance { - fn aabb(&self) -> &([f64; N], [f64; N]) { +impl Query<[T; N]> for WithinDistance +where + T: Num + Copy + PartialOrd, +{ + fn aabb(&self) -> &([T; N], [T; N]) { &self.aabb } - fn test(&self, position: &[f64; N]) -> bool { + fn test(&self, position: &[T; N]) -> bool { self.center.distance_2(position) <= self.distance_2 } } @@ -224,8 +234,8 @@ mod tests { use crate::tests::{random_objects, random_points}; - pub fn random_queries(len: usize) -> impl Strategy>> { - (random_points(len), vec(0.0..=1.0, len)).prop_map(|(centers, distances)| { + pub fn random_queries(len: usize) -> impl Strategy>> { + (random_points(len), vec(0.0_f32..=1.0, len)).prop_map(|(centers, distances)| { centers .into_iter() .zip(distances) diff --git a/src/nearest.rs b/src/nearest.rs index ad4fb4d..9a51405 100644 --- a/src/nearest.rs +++ b/src/nearest.rs @@ -1,10 +1,13 @@ use std::mem::swap; +use num_traits::Float; + use crate::{split, KdTree, Object, Point}; impl KdTree where O: Object, + ::Coord: Float, S: AsRef<[O]>, { /// Find the object nearest to the given `target` @@ -15,7 +18,7 @@ where pub fn nearest(&self, target: &O::Point) -> Option<&O> { let mut args = NearestArgs { target, - distance_2: f64::INFINITY, + distance_2: ::Coord::infinity(), best_match: None, }; @@ -34,13 +37,14 @@ where O: Object, { target: &'b O::Point, - distance_2: f64, + distance_2: ::Coord, best_match: Option<&'a O>, } fn nearest<'a, O>(args: &mut NearestArgs<'a, '_, O>, mut objects: &'a [O], mut axis: usize) where O: Object, + ::Coord: Float, { loop { let (mut left, object, mut right) = split(objects);