Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add plotting method for average NN distance #661

Merged
merged 2 commits into from
May 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion umap/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

from matplotlib.patches import Patch

from umap.utils import submatrix
from umap.utils import submatrix, average_nn_distance

from bokeh.plotting import show as show_interactive
from bokeh.plotting import output_file, output_notebook
Expand Down Expand Up @@ -1556,3 +1556,37 @@ def interactive(
)

return plot


def nearest_neighbour_distribution(umap_object, bins=25, ax=None):
"""Create a histogram of the average distance to each points
nearest neighbors.

Parameters
----------
umap_object: trained UMAP object
A trained UMAP object that has an embedding.

bins: int (optional, default 25)
Number of bins to put the points into

ax: matlotlib axis (optional, default None)
A matplotlib axis to plot to, or, if None, a new
axis will be created and returned.

Returns
-------

"""
nn_distances = average_nn_distance(umap_object.graph_)

if ax is None:
fig = plt.figure()
ax = fig.add_subplot(111)

ax.set_xlabel(f'Average distance to nearest neighbors')
ax.set_ylabel('Frequency')

ax.hist(nn_distances, bins=bins)

return ax
30 changes: 30 additions & 0 deletions umap/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# License: BSD 3 clause

import time
from warnings import warn

import numpy as np
import numba
Expand Down Expand Up @@ -188,3 +189,32 @@ def disconnected_vertices(model):
else:
vertices_disconnected = np.array(model.graph_.sum(axis=1)).flatten() == 0
return vertices_disconnected


def average_nn_distance(dist_matrix):
"""Calculate the average distance to each points nearest neighbors.

Parameters
----------
dist_matrix: a csr_matrix
A distance matrix (usually umap_model.graph_)

Returns
-------
An array with the average distance to each points nearest neighbors

"""
(row_idx, col_idx, val) = scipy.sparse.find(dist_matrix)

# Count/sum is done per row
count_non_zero_elems = np.bincount(row_idx)
sum_non_zero_elems = np.bincount(row_idx, weights=val)
averages = sum_non_zero_elems/count_non_zero_elems

if any(np.isnan(averages)):
warn(
"Embedding contains disconnected vertices which will be ignored."
"Use umap.utils.disconnected_vertices() to identify them."
)

return averages