Skip to content

Commit

Permalink
Add tests for argmin_distance.
Browse files Browse the repository at this point in the history
  • Loading branch information
isaksamsten committed Oct 19, 2023
1 parent cb72fd8 commit f082a38
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions tests/wildboar/distance/test_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from numpy.testing import assert_almost_equal, assert_equal
from wildboar.datasets import load_gun_point, load_two_lead_ecg
from wildboar.distance import (
argmin_distance,
paired_distance,
paired_subsequence_distance,
pairwise_distance,
Expand Down Expand Up @@ -1458,3 +1459,24 @@ def test_pairwise_distance_dim_mean():
)
actual = pairwise_distance(x, y, dim="mean", metric="euclidean")
assert_almost_equal(actual, expected)


@pytest.mark.parametrize("metric", list(_METRICS.keys()))
@pytest.mark.parametrize("k", [1, 3, 7])
def test_argmin_equals_pairwise_distance_argpartition(metric, k):
print(metric)
X, y = load_two_lead_ecg()
X, Y = X[:10], X[300:350]
ind_argmin, min_dist_argmin = argmin_distance(
X, Y, metric=metric, k=k, return_distance=True
)
ind_argmin = np.sort(ind_argmin, axis=1)

dist = pairwise_distance(X, Y, metric=metric)
ind_pairwise = np.argpartition(dist, k, axis=1)[:, :k]
# ind_pairwise = np.sort(ind_pairwise, axis=1)
# assert_equal(ind_pairwise, ind_argmin)
assert_almost_equal(
np.sort(np.take_along_axis(dist, ind_pairwise, axis=1), axis=1),
np.sort(min_dist_argmin, axis=1),
)

0 comments on commit f082a38

Please sign in to comment.