From b2e75a7a709bab018acb976348569f0296285c82 Mon Sep 17 00:00:00 2001
From: Patrick Bloebaum
Date: Fri, 17 Nov 2023 07:20:06 -0800
Subject: [PATCH] Fix issue in KL estimation using knn

Before, the method threw an error when all samples were equal. However, in
these cases, it should instead return a KL divergence of 0.

Signed-off-by: Patrick Bloebaum
---
 dowhy/gcm/divergence.py | 27 +++++++++------------------
 1 file changed, 9 insertions(+), 18 deletions(-)

diff --git a/dowhy/gcm/divergence.py b/dowhy/gcm/divergence.py
index 13833d5bf9..88ddf4e749 100644
--- a/dowhy/gcm/divergence.py
+++ b/dowhy/gcm/divergence.py
@@ -21,15 +21,11 @@ def auto_estimate_kl_divergence(X: np.ndarray, Y: np.ndarray) -> float:
     if X.ndim == 2 and X.shape[1] > 1:
         return estimate_kl_divergence_continuous_clf(X, Y)
     else:
-        try:
-            return estimate_kl_divergence_continuous_knn(X, Y)
-        except _KNNTooFewSamples:
-            # If there are too many common elements, this error can happen.
-            return estimate_kl_divergence_continuous_clf(X, Y)
+        return estimate_kl_divergence_continuous_knn(X, Y)
 
 
 def estimate_kl_divergence_continuous_knn(
-    X: np.ndarray, Y: np.ndarray, k: int = 1, remove_common_elements: bool = True
+    X: np.ndarray, Y: np.ndarray, k: int = 1, remove_common_elements: bool = True, n_jobs: int = 1
 ) -> float:
     """Estimates KL-Divergence using k-nearest neighbours (Wang et al., 2009).
 
@@ -46,6 +42,9 @@ def estimate_kl_divergence_continuous_knn(
     :param remove_common_elements: If true, common values in X and Y are removed. This would otherwise lead to
                                    a KNN distance of zero for these values if k is set to 1, which would cause a
                                    division by zero error.
+    :param n_jobs: Number of parallel jobs used for the nearest neighbors model. -1 means it uses all available cores.
+                   Note that in most applications, parallelizing this rather introduces more overhead, leading to a
+                   slower runtime.
     return: Estimated value of D(P_X||P_Y).
     """
     X, Y = shape_into_2d(X, Y)
@@ -64,10 +63,8 @@ def estimate_kl_divergence_continuous_knn(
     if remove_common_elements:
         X = setdiff2d(X, Y, assume_unique=True)
         if X.shape[0] < k + 1:
-            raise _KNNTooFewSamples(
-                "After removing common elements, there are too few samples left! "
-                "Got %d samples with k=%d." % (X.shape[0], k)
-            )
+            # All elements are equal (or at least less than k samples are different)
+            return 0
 
     n, m = X.shape[0], Y.shape[0]
     if n == 0:
@@ -75,8 +72,8 @@ def estimate_kl_divergence_continuous_knn(
 
     d = float(X.shape[1])
 
-    x_neighbourhood = NearestNeighbors(n_neighbors=k + 1).fit(X)
-    y_neighbourhood = NearestNeighbors(n_neighbors=k).fit(Y)
+    x_neighbourhood = NearestNeighbors(n_neighbors=k + 1, n_jobs=n_jobs).fit(X)
+    y_neighbourhood = NearestNeighbors(n_neighbors=k, n_jobs=n_jobs).fit(Y)
 
     distances_x, _ = x_neighbourhood.kneighbors(X, n_neighbors=k + 1)
     distances_y, _ = y_neighbourhood.kneighbors(X, n_neighbors=k)
@@ -206,9 +203,3 @@ def is_probability_matrix(X: np.ndarray) -> bool:
         return False
     else:
         return np.all(np.isclose(np.sum(abs(X.astype(np.float64)), axis=1), 1))
-
-
-class _KNNTooFewSamples(Exception):
-    def __init__(self, message: str):
-        self.message = message
-        super().__init__(self.message)
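
A minimal sketch of the behavior this patch targets, assuming the estimator is
imported from dowhy/gcm/divergence.py as shown in the diff above; the sample
sizes and distributions below are illustrative only, not taken from the patch.

import numpy as np

from dowhy.gcm.divergence import estimate_kl_divergence_continuous_knn

rng = np.random.default_rng(0)
X = rng.normal(size=1000)

# Identical samples: removing common elements leaves fewer than k + 1 samples,
# so with this patch the estimator returns 0 instead of raising an error.
print(estimate_kl_divergence_continuous_knn(X, X))  # 0

# Shifted distribution: the estimate should be clearly positive.
Y = rng.normal(loc=2.0, size=1000)
print(estimate_kl_divergence_continuous_knn(X, Y, k=1, n_jobs=1))

Returning 0 in the degenerate case matches the intent stated in the commit
message: identical samples provide no evidence that the two distributions
differ, so a divergence of 0 is the sensible estimate.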