Fix issue in KL estimation using knn
Before, the method threw an error when all samples were equal. However, in these cases, it should instead return a KL divergence of 0.

Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
bloebp committed Nov 17, 2023
1 parent 601c2ae commit b2e75a7
Showing 1 changed file with 9 additions and 18 deletions.
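In practice, the change means that calling the kNN estimator on two entirely identical samples now yields 0 instead of raising. A minimal sketch of the new behaviour (assuming a dowhy build containing this commit; the data is illustrative):

    import numpy as np
    from dowhy.gcm.divergence import estimate_kl_divergence_continuous_knn

    X = np.array([1.0, 1.0, 1.0, 1.0])
    Y = np.array([1.0, 1.0, 1.0, 1.0])

    # Every value of X also occurs in Y, so after removing common elements
    # fewer than k + 1 samples remain and the estimator now returns 0
    # (identical samples imply zero divergence).
    print(estimate_kl_divergence_continuous_knn(X, Y))  # 0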
dowhy/gcm/divergence.py: 9 additions & 18 deletions

--- a/dowhy/gcm/divergence.py
+++ b/dowhy/gcm/divergence.py
@@ -21,15 +21,11 @@ def auto_estimate_kl_divergence(X: np.ndarray, Y: np.ndarray) -> float:
     if X.ndim == 2 and X.shape[1] > 1:
         return estimate_kl_divergence_continuous_clf(X, Y)
     else:
-        try:
-            return estimate_kl_divergence_continuous_knn(X, Y)
-        except _KNNTooFewSamples:
-            # If there are too many common elements, this error can happen.
-            return estimate_kl_divergence_continuous_clf(X, Y)
+        return estimate_kl_divergence_continuous_knn(X, Y)


 def estimate_kl_divergence_continuous_knn(
-    X: np.ndarray, Y: np.ndarray, k: int = 1, remove_common_elements: bool = True
+    X: np.ndarray, Y: np.ndarray, k: int = 1, remove_common_elements: bool = True, n_jobs: int = 1
 ) -> float:
     """Estimates KL-Divergence using k-nearest neighbours (Wang et al., 2009).
@@ -46,6 +42,9 @@ def estimate_kl_divergence_continuous_knn(
     :param remove_common_elements: If true, common values in X and Y are removed. This would otherwise lead to
                                    a KNN distance of zero for these values if k is set to 1, which would cause a
                                    division by zero error.
+    :param n_jobs: Number of parallel jobs used for the nearest neighbors model. -1 means it uses all available cores.
+                   Note that in most applications, parallelizing this rather introduces more overhead, leading to a
+                   slower runtime.
     :return: Estimated value of D(P_X||P_Y).
     """
     X, Y = shape_into_2d(X, Y)
@@ -64,19 +63,17 @@
     if remove_common_elements:
         X = setdiff2d(X, Y, assume_unique=True)
         if X.shape[0] < k + 1:
-            raise _KNNTooFewSamples(
-                "After removing common elements, there are too few samples left! "
-                "Got %d samples with k=%d." % (X.shape[0], k)
-            )
+            # All elements are equal (or at least less than k samples are different)
+            return 0

     n, m = X.shape[0], Y.shape[0]
     if n == 0:
         return 0

     d = float(X.shape[1])

-    x_neighbourhood = NearestNeighbors(n_neighbors=k + 1).fit(X)
-    y_neighbourhood = NearestNeighbors(n_neighbors=k).fit(Y)
+    x_neighbourhood = NearestNeighbors(n_neighbors=k + 1, n_jobs=n_jobs).fit(X)
+    y_neighbourhood = NearestNeighbors(n_neighbors=k, n_jobs=n_jobs).fit(Y)

     distances_x, _ = x_neighbourhood.kneighbors(X, n_neighbors=k + 1)
     distances_y, _ = y_neighbourhood.kneighbors(X, n_neighbors=k)
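The hunk above also threads the new n_jobs argument into both NearestNeighbors models. A hypothetical call, with the caveat from the docstring that parallelizing the neighbour search often costs more in overhead than it saves:

    # Use all available cores for the nearest-neighbour search.
    estimate_kl_divergence_continuous_knn(X, Y, k=1, n_jobs=-1)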
@@ -206,9 +203,3 @@ def is_probability_matrix(X: np.ndarray) -> bool:
         return False
     else:
         return np.all(np.isclose(np.sum(abs(X.astype(np.float64)), axis=1), 1))
-
-
-class _KNNTooFewSamples(Exception):
-    def __init__(self, message: str):
-        self.message = message
-        super().__init__(self.message)
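For context on why common elements were problematic at all: the function follows the k-nearest-neighbour estimator of Wang et al. (2009), which takes ratios of neighbour distances. A condensed, self-contained sketch of that construction (illustrative code under that assumption, not the library's exact implementation; X and Y are 2d arrays of shape (n, d) and (m, d)):

    import numpy as np
    from sklearn.neighbors import NearestNeighbors

    def knn_kl_sketch(X: np.ndarray, Y: np.ndarray, k: int = 1) -> float:
        # D(P_X || P_Y) ~= d/n * sum_i log(nu_k(x_i) / rho_k(x_i)) + log(m / (n - 1)),
        # where rho_k(x_i) is the distance from x_i to its k-th neighbour in X
        # (excluding x_i itself) and nu_k(x_i) is the distance to its k-th
        # neighbour in Y.
        n, m, d = X.shape[0], Y.shape[0], X.shape[1]
        rho = NearestNeighbors(n_neighbors=k + 1).fit(X).kneighbors(X)[0][:, -1]
        nu = NearestNeighbors(n_neighbors=k).fit(Y).kneighbors(X)[0][:, -1]
        # If a sample of X also occurs in Y, nu is exactly 0 for k=1; a duplicate
        # within X makes rho 0, so the ratio divides by zero. This is why the
        # patched function strips common elements first and simply returns 0
        # when (almost) nothing distinct remains.
        return d / n * np.sum(np.log(nu / rho)) + np.log(m / (n - 1))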
