Skip to content

Commit

Permalink
Add plotting method for average NN distance
Browse files Browse the repository at this point in the history
  • Loading branch information
gclen committed May 1, 2021
1 parent b84a72f commit cc7e2f5
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 1 deletion.
37 changes: 36 additions & 1 deletion umap/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

from matplotlib.patches import Patch

from umap.utils import submatrix
from umap.utils import submatrix, average_nn_distance

from bokeh.plotting import show as show_interactive
from bokeh.plotting import output_file, output_notebook
Expand Down Expand Up @@ -1556,3 +1556,38 @@ def interactive(
)

return plot


def nearest_neighbour_distribution(umap_object, bins=25, ax=None):
"""Create a histogram of the average distance to each points
nearest neighbors.
Parameters
----------
umap_object: trained UMAP object
A trained UMAP object that has an embedding.
bins: int (optional, default 25)
Number of bins to put the points into
ax: matlotlib axis (optional, default None)
A matplotlib axis to plot to, or, if None, a new
axis will be created and returned.
Returns
-------
"""
nn_distances = average_nn_distance(umap_object.graph_)

if ax is None:
fig = plt.figure()
ax = fig.add_subplot(111)

ax.set_xlabel(f'Average distance to nearest neighbors')
ax.set_ylabel('Frequency')

ax.hist(nn_distances, bins=bins)

return ax

27 changes: 27 additions & 0 deletions umap/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# License: BSD 3 clause

import time
from warnings import warn

import numpy as np
import numba
Expand Down Expand Up @@ -188,3 +189,29 @@ def disconnected_vertices(model):
else:
vertices_disconnected = np.array(model.graph_.sum(axis=1)).flatten() == 0
return vertices_disconnected


def average_nn_distance(dist_matrix):
"""Calculate the average distance to each points nearest neighbors.
Parameters
----------
dist_matrix: a csr_matrix
A distance matrix (usually umap_model.graph_)
Returns
-------
An array with the average distance to each points nearest neighbors
"""
(row_idx, col_idx, val) = scipy.sparse.find(dist_matrix)

# Count/sum is done per row
count_non_zero_elems = np.bincount(row_idx)
sum_non_zero_elems = np.bincount(row_idx, weights=val)
averages = sum_non_zero_elems/count_non_zero_elems

if any(np.isnan(averages)):
warn("Embedding contains disconnected vertices which will be ignored. Use umap.utils.disconnected_vertices() to identify them.")

return averages

0 comments on commit cc7e2f5

Please sign in to comment.