Skip to content

Commit

Permalink
Make coordinate types generic within reason, i.e. restricted to float…
Browse files Browse the repository at this point in the history
…ing-point values where appropriate.
  • Loading branch information
adamreichold committed Aug 1, 2023
1 parent d6b0754 commit cd59fd7
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 36 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[package]
name = "sif-kdtree"
description = "simple, immutable, flat k-d tree"
version = "0.2.0"
version = "0.3.0"
edition = "2018"
rust-version = "1.55"
authors = ["Adam Reichold <adam.reichold@t-online.de>"]
Expand All @@ -13,6 +13,7 @@ keywords = ["kdtree"]
categories = ["data-structures", "simulation", "science::geo"]

[dependencies]
num-traits = "0.2"
rayon = { version = "1.7", optional = true }
serde = { version = "1.0", features = ["derive"], optional = true }

Expand Down
39 changes: 26 additions & 13 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,33 +109,46 @@ pub use look_up::{Query, WithinBoundingBox, WithinDistance};
use std::marker::PhantomData;
use std::ops::Deref;

use num_traits::Num;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

/// Defines a [finite-dimensional][Self::DIM] real space in terms of [coordinate values][Self::coord] along a chosen set of axes
/// Defines a [finite-dimensional][Self::DIM] space in terms of [coordinate values][Self::coord] along a chosen set of axes
pub trait Point {
/// The dimension of the underlying real space
/// The dimension of the underlying space
const DIM: usize;

/// The type of the coordinate values
type Coord: Num + Copy + PartialOrd;

/// Access the coordinate value of the point along the given `axis`
fn coord(&self, axis: usize) -> f64;
fn coord(&self, axis: usize) -> Self::Coord;

/// Return the squared distance between `self` and `other`.
///
/// This is called during nearest neighbour search and hence only the relation between two distance values is required so that computing square roots can be avoided.
fn distance_2(&self, other: &Self) -> f64;
fn distance_2(&self, other: &Self) -> Self::Coord;
}

/// `N`-dimensional real space using [Euclidean distance](https://en.wikipedia.org/wiki/Euclidean_distance)
impl<const N: usize> Point for [f64; N] {
/// `N`-dimensional space using [Euclidean distance](https://en.wikipedia.org/wiki/Euclidean_distance)
impl<T, const N: usize> Point for [T; N]
where
T: Num + Copy + PartialOrd,
{
const DIM: usize = N;

fn coord(&self, axis: usize) -> f64 {
type Coord = T;

fn coord(&self, axis: usize) -> Self::Coord {
self[axis]
}

fn distance_2(&self, other: &Self) -> f64 {
(0..N).map(|axis| (self[axis] - other[axis]).powi(2)).sum()
fn distance_2(&self, other: &Self) -> Self::Coord {
(0..N).fold(T::zero(), |res, axis| {
let diff = self[axis] - other[axis];

res + diff * diff
})
}
}

Expand Down Expand Up @@ -234,13 +247,13 @@ mod tests {

use proptest::{collection::vec, strategy::Strategy};

pub fn random_points(len: usize) -> impl Strategy<Value = Vec<[f64; 2]>> {
(vec(0.0..=1.0, len), vec(0.0..=1.0, len))
pub fn random_points(len: usize) -> impl Strategy<Value = Vec<[f32; 2]>> {
(vec(0.0_f32..=1.0, len), vec(0.0_f32..=1.0, len))
.prop_map(|(x, y)| x.into_iter().zip(y).map(|(x, y)| [x, y]).collect())
}

#[derive(Debug, PartialEq)]
pub struct RandomObject(pub [f64; 2]);
pub struct RandomObject(pub [f32; 2]);

impl Eq for RandomObject {}

Expand All @@ -257,7 +270,7 @@ mod tests {
}

impl Object for RandomObject {
type Point = [f64; 2];
type Point = [f32; 2];

fn position(&self) -> &Self::Point {
&self.0
Expand Down
50 changes: 30 additions & 20 deletions src/look_up.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::ops::ControlFlow;

use num_traits::Num;
#[cfg(feature = "rayon")]
use rayon::join;

Expand All @@ -23,59 +24,68 @@ pub trait Query<P: Point> {
fn test(&self, position: &P) -> bool;
}

/// A query which yields all objects within a given axis-aligned boundary box (AABB) in `N`-dimensional real space
/// A query which yields all objects within a given axis-aligned boundary box (AABB) in `N`-dimensional space
#[derive(Debug)]
pub struct WithinBoundingBox<const N: usize> {
aabb: ([f64; N], [f64; N]),
pub struct WithinBoundingBox<T, const N: usize> {
aabb: ([T; N], [T; N]),
}

impl<const N: usize> WithinBoundingBox<N> {
impl<T, const N: usize> WithinBoundingBox<T, N> {
/// Construct a query from first the corner smallest coordinate values `lower` and then the corner with the largest coordinate values `upper`
pub fn new(lower: [f64; N], upper: [f64; N]) -> Self {
pub fn new(lower: [T; N], upper: [T; N]) -> Self {
Self {
aabb: (lower, upper),
}
}
}

impl<const N: usize> Query<[f64; N]> for WithinBoundingBox<N> {
fn aabb(&self) -> &([f64; N], [f64; N]) {
impl<T, const N: usize> Query<[T; N]> for WithinBoundingBox<T, N>
where
T: Num + Copy + PartialOrd,
{
fn aabb(&self) -> &([T; N], [T; N]) {
&self.aabb
}

fn test(&self, _position: &[f64; N]) -> bool {
fn test(&self, _position: &[T; N]) -> bool {
true
}
}

/// A query which yields all objects within a given distance to a central point in `N`-dimensional real space
#[derive(Debug)]
pub struct WithinDistance<const N: usize> {
aabb: ([f64; N], [f64; N]),
center: [f64; N],
distance_2: f64,
pub struct WithinDistance<T, const N: usize> {
aabb: ([T; N], [T; N]),
center: [T; N],
distance_2: T,
}

impl<const N: usize> WithinDistance<N> {
impl<T, const N: usize> WithinDistance<T, N>
where
T: Num + Copy + PartialOrd,
{
/// Construct a query from the `center` and the largest allowed Euclidean `distance` to it
pub fn new(center: [f64; N], distance: f64) -> Self {
pub fn new(center: [T; N], distance: T) -> Self {
Self {
aabb: (
center.map(|coord| coord - distance),
center.map(|coord| coord + distance),
),
center,
distance_2: distance.powi(2),
distance_2: distance * distance,
}
}
}

impl<const N: usize> Query<[f64; N]> for WithinDistance<N> {
fn aabb(&self) -> &([f64; N], [f64; N]) {
impl<T, const N: usize> Query<[T; N]> for WithinDistance<T, N>
where
T: Num + Copy + PartialOrd,
{
fn aabb(&self) -> &([T; N], [T; N]) {
&self.aabb
}

fn test(&self, position: &[f64; N]) -> bool {
fn test(&self, position: &[T; N]) -> bool {
self.center.distance_2(position) <= self.distance_2
}
}
Expand Down Expand Up @@ -224,8 +234,8 @@ mod tests {

use crate::tests::{random_objects, random_points};

pub fn random_queries(len: usize) -> impl Strategy<Value = Vec<WithinDistance<2>>> {
(random_points(len), vec(0.0..=1.0, len)).prop_map(|(centers, distances)| {
pub fn random_queries(len: usize) -> impl Strategy<Value = Vec<WithinDistance<f32, 2>>> {
(random_points(len), vec(0.0_f32..=1.0, len)).prop_map(|(centers, distances)| {
centers
.into_iter()
.zip(distances)
Expand Down
8 changes: 6 additions & 2 deletions src/nearest.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use std::mem::swap;

use num_traits::Float;

use crate::{split, KdTree, Object, Point};

impl<O, S> KdTree<O, S>
where
O: Object,
<O::Point as Point>::Coord: Float,
S: AsRef<[O]>,
{
/// Find the object nearest to the given `target`
Expand All @@ -15,7 +18,7 @@ where
pub fn nearest(&self, target: &O::Point) -> Option<&O> {
let mut args = NearestArgs {
target,
distance_2: f64::INFINITY,
distance_2: <O::Point as Point>::Coord::infinity(),
best_match: None,
};

Expand All @@ -34,13 +37,14 @@ where
O: Object,
{
target: &'b O::Point,
distance_2: f64,
distance_2: <O::Point as Point>::Coord,
best_match: Option<&'a O>,
}

fn nearest<'a, O>(args: &mut NearestArgs<'a, '_, O>, mut objects: &'a [O], mut axis: usize)
where
O: Object,
<O::Point as Point>::Coord: Float,
{
loop {
let (mut left, object, mut right) = split(objects);
Expand Down

0 comments on commit cd59fd7

Please sign in to comment.