Skip to content

Commit

Permalink
Add new method to estimate KL divergence using classifier
Browse files Browse the repository at this point in the history
This should work better with multivariate data and mixed data types. However, it is generally slower than the kNN approach.

Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
  • Loading branch information
bloebp committed Nov 10, 2023
1 parent 395d1fa commit 8dc8bf8
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 21 deletions.
90 changes: 86 additions & 4 deletions dowhy/gcm/divergence.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,40 @@
"""Functions in this module should be considered experimental, meaning there might be breaking API changes in the
future.
"""
from functools import partial
from typing import Callable, Union

import numpy as np
from scipy.stats import entropy
from sklearn.model_selection import KFold
from sklearn.neighbors import NearestNeighbors

from dowhy.gcm.auto import AssignmentQuality, select_model
from dowhy.gcm.constant import EPS
from dowhy.gcm.util.general import is_categorical, setdiff2d, shape_into_2d
from dowhy.gcm.ml.classification import ClassificationModel, create_logistic_regression_classifier
from dowhy.gcm.util.general import has_categorical, is_categorical, setdiff2d, shape_into_2d


def auto_estimate_kl_divergence(X: np.ndarray, Y: np.ndarray) -> float:
if is_categorical(X):
return estimate_kl_divergence_categorical(X, Y)
elif is_probability_matrix(X):
elif not has_categorical(X) and is_probability_matrix(X):
return estimate_kl_divergence_of_probabilities(X, Y)
else:
return estimate_kl_divergence_continuous(X, Y)
if X.ndim == 2 and X.shape[1] > 1:
return estimate_kl_divergence_continuous_clf(X, Y)
else:
return estimate_kl_divergence_continuous_knn(X, Y)


def estimate_kl_divergence_continuous(
def estimate_kl_divergence_continuous_knn(
X: np.ndarray, Y: np.ndarray, k: int = 1, remove_common_elements: bool = True
) -> float:
"""Estimates KL-Divergence using k-nearest neighbours (Wang et al., 2009).
While, in theory, this handles multidimensional inputs, consider using estimate_kl_divergence_continuous_clf
for data with more than one dimension.
Q. Wang, S. R. Kulkarni, and S. Verdú,
"Divergence estimation for multidimensional densities via k-nearest-neighbor distances",
IEEE Transactions on Information Theory, vol. 55, no. 5, pp. 2392-2405, May 2009.
Expand Down Expand Up @@ -82,6 +93,75 @@ def estimate_kl_divergence_continuous(
return result


def estimate_kl_divergence_continuous_clf(
    samples_P: np.ndarray,
    samples_Q: np.ndarray,
    n_splits: int = 5,
    classifier_model: Union[AssignmentQuality, Callable[[], ClassificationModel]] = partial(
        create_logistic_regression_classifier, max_iter=10000
    ),
    epsilon: float = EPS,
) -> float:
    r"""Estimates KL-Divergence based on probabilities given by classifier. This is:

        D_f(P || Q) = \int f(p(x)/q(x)) q(x) dx ~= -1/N \sum_x log(p(Y = 1 | x) / (1 - p(Y = 1 | x)))

    Here, the KL divergence can be approximated using the log ratios of probabilities to predict whether a sample
    comes from distribution P or Q.

    A classifier is trained to separate samples of P (label "0.0") from samples of Q (label "1.0"). The sum above runs
    over held-out samples x of P, and p(Y = 1 | x) is taken from column 1 of the predicted probabilities (presumably
    the column belonging to the label of Q; this relies on the class ordering of the classifier). The folds are
    shuffled, so repeated calls on the same data can return slightly different estimates.

    :param samples_P: Samples drawn from P. Can have a different number of samples than Q.
    :param samples_Q: Samples drawn from Q. Can have a different number of samples than P.
    :param n_splits: Number of splits of the training and test data. The classifier is trained on the training
                     data and evaluated on the test data to obtain the probabilities.
    :param classifier_model: Used to estimate the probabilities for the log ratio. This can either be a callable
                             without arguments that returns a ClassificationModel or an AssignmentQuality. In the
                             latter, a model is automatically selected based on the best performance on a training
                             set.
    :param epsilon: If the probability is either 1 or 0, this value will be used for clipping, i.e., 0 becomes epsilon
                    and 1 becomes 1 - epsilon.
    :return: Estimated value of the KL divergence D(P||Q). Negative estimates are reported as 0.
    :raises ValueError: If samples_P and samples_Q have a different number of features.
    """
    samples_P, samples_Q = shape_into_2d(samples_P, samples_Q)

    if samples_P.shape[1] != samples_Q.shape[1]:
        raise ValueError("X and Y need to have the same number of features!")

    # P and Q are split independently because they may contain a different number of samples.
    splits_p = list(KFold(n_splits=n_splits, shuffle=True).split(samples_P))
    splits_q = list(KFold(n_splits=n_splits, shuffle=True).split(samples_Q))

    if isinstance(classifier_model, AssignmentQuality):
        # Let the auto assignment pick a classifier for the "P vs. Q" problem on all available samples.
        classifier_model = select_model(
            np.vstack([samples_P, samples_Q]),
            np.concatenate([np.zeros(samples_P.shape[0]), np.ones(samples_Q.shape[0])]).astype(str),
            classifier_model,
        )[0]
    else:
        # Otherwise, classifier_model is a factory that creates the model.
        classifier_model = classifier_model()

    all_probs = []

    for (train_p, test_p), (train_q, _) in zip(splits_p, splits_q):
        # Balance the classes by using the same number of training samples from P and Q.
        num_samples = min(len(train_p), len(train_q))

        classifier_model.fit(
            np.vstack([samples_P[train_p[:num_samples]], samples_Q[train_q[:num_samples]]]),
            np.concatenate([np.zeros(num_samples), np.ones(num_samples)]).astype(str),
        )

        # Only the held-out samples of P are evaluated, since the estimate is an average over P.
        probs_P = classifier_model.predict_probabilities(samples_P[test_p])[:, 1]
        # Avoid log(0) and divisions by zero in the log ratio below.
        probs_P[probs_P == 0] = epsilon
        probs_P[probs_P == 1] = 1 - epsilon
        all_probs.append(probs_P)

    all_probs = np.concatenate(all_probs)
    kl_divergence = -np.mean(np.log(all_probs / (1 - all_probs)))

    # The KL divergence is non-negative; a negative value can only be an estimation artifact.
    if kl_divergence < 0:
        return 0.0

    return float(kl_divergence)


def estimate_kl_divergence_categorical(X: np.ndarray, Y: np.ndarray) -> float:
X, Y = shape_into_2d(X, Y)

Expand Down Expand Up @@ -116,5 +196,7 @@ def estimate_kl_divergence_of_probabilities(X: np.ndarray, Y: np.ndarray) -> flo
def is_probability_matrix(X: np.ndarray) -> bool:
    """Checks whether the absolute values in X sum up to (approximately) 1.

    A one-dimensional array is checked as a whole, i.e., all of its absolute values together need to sum up to 1.
    For an array with more than one column, the absolute values of each row need to sum up to 1. An array with
    exactly one column is never considered to be a probability matrix.

    :param X: The array to check. Its values need to be convertible to float64.
    :return: True if X passes the check described above, False otherwise. This is always a plain Python bool.
    """
    if X.ndim == 1:
        sum_axis = 0
    elif X.shape[1] == 1:
        # NOTE(review): presumably excluded because a single column looks like ordinary one-dimensional data
        # rather than a set of class probabilities - confirm with the callers.
        return False
    else:
        sum_axis = 1

    absolute_sums = np.sum(np.abs(X.astype(np.float64)), axis=sum_axis)

    # np.all returns a numpy.bool_; convert it to match the annotated return type.
    return bool(np.all(np.isclose(absolute_sums, 1)))
5 changes: 3 additions & 2 deletions dowhy/gcm/stochastic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from sklearn.mixture import BayesianGaussianMixture

from dowhy.gcm.causal_mechanisms import StochasticModel
from dowhy.gcm.divergence import estimate_kl_divergence_continuous
from dowhy.gcm.util.general import shape_into_2d

_CONTINUOUS_DISTRIBUTIONS = [
Expand Down Expand Up @@ -127,7 +126,9 @@ def find_suitable_continuous_distribution(
generated_samples = distribution.rvs(size=distribution_samples.shape[0], loc=loc, scale=scale, *arg)

# Check the KL divergence between the distribution of the given and fitted distribution.
divergence = estimate_kl_divergence_continuous(distribution_samples, generated_samples)
from dowhy.gcm.divergence import estimate_kl_divergence_continuous_knn

divergence = estimate_kl_divergence_continuous_knn(distribution_samples, generated_samples)
if divergence < divergence_threshold:
currently_best_distribution = distribution
currently_best_parameters = params
Expand Down
25 changes: 18 additions & 7 deletions tests/gcm/test_arrow_strength.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,13 @@
fit,
)
from dowhy.gcm.auto import assign_causal_mechanisms
from dowhy.gcm.divergence import estimate_kl_divergence_continuous
from dowhy.gcm.divergence import estimate_kl_divergence_continuous_knn
from dowhy.gcm.influence import arrow_strength_of_model
from dowhy.gcm.ml import create_linear_regressor, create_logistic_regression_classifier
from dowhy.gcm.ml import (
create_linear_regressor,
create_linear_regressor_with_given_parameters,
create_logistic_regression_classifier,
)


@pytest.fixture
Expand All @@ -34,10 +38,11 @@ def preserve_random_generator_state():
@flaky(max_runs=5)
def test_given_kl_divergence_attribution_func_when_estimate_arrow_strength_then_returns_expected_results():
causal_strengths = arrow_strength(
_create_causal_model(), "X2", difference_estimation_func=estimate_kl_divergence_continuous
_create_causal_model(), "X2", difference_estimation_func=estimate_kl_divergence_continuous_knn
)
assert causal_strengths[("X0", "X2")] == approx(2.76, abs=0.4)
assert causal_strengths[("X1", "X2")] == approx(1.6, abs=0.4)

assert causal_strengths[("X0", "X2")] == approx(1.2, abs=0.2)
assert causal_strengths[("X1", "X2")] == approx(0.3, abs=0.1)


@flaky(max_runs=5)
Expand Down Expand Up @@ -199,12 +204,18 @@ def _create_causal_model():
causal_model = ProbabilisticCausalModel(nx.DiGraph([("X1", "X2"), ("X0", "X2")]))
causal_model.set_causal_mechanism("X1", ScipyDistribution(stats.norm, loc=0, scale=1))
causal_model.set_causal_mechanism("X0", ScipyDistribution(stats.norm, loc=0, scale=1))
causal_model.set_causal_mechanism("X2", AdditiveNoiseModel(prediction_model=create_linear_regressor()))
causal_model.set_causal_mechanism(
"X2",
AdditiveNoiseModel(
prediction_model=create_linear_regressor_with_given_parameters([3, 1]),
noise_model=ScipyDistribution(stats.norm, loc=0, scale=1),
),
)

X0 = np.random.normal(0, 1, 1000)
X1 = np.random.normal(0, 1, 1000)

test_data = pd.DataFrame({"X0": X0, "X1": X1, "X2": 3 * X0 + X1 + np.random.normal(0, 0.2, X0.shape[0])})
test_data = pd.DataFrame({"X0": X0, "X1": X1, "X2": 3 * X0 + X1 + np.random.normal(0, 1, X0.shape[0])})
fit(causal_model, test_data)

return causal_model
56 changes: 51 additions & 5 deletions tests/gcm/test_divergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from dowhy.gcm.divergence import (
auto_estimate_kl_divergence,
estimate_kl_divergence_categorical,
estimate_kl_divergence_continuous,
estimate_kl_divergence_continuous_clf,
estimate_kl_divergence_continuous_knn,
estimate_kl_divergence_of_probabilities,
is_probability_matrix,
)
Expand All @@ -16,8 +17,8 @@ def test_given_simple_gaussian_data_when_estimate_kl_divergence_continuous_then_
X = np.random.normal(0, 1, 2000)
Y = np.random.normal(1, 1, 2000)

assert estimate_kl_divergence_continuous(X, X) == approx(0, abs=0.001)
assert estimate_kl_divergence_continuous(X, Y) == approx(0.5, abs=0.15)
assert estimate_kl_divergence_continuous_knn(X, X) == approx(0, abs=0.001)
assert estimate_kl_divergence_continuous_knn(X, Y) == approx(0.5, abs=0.15)


@flaky(max_runs=3)
Expand Down Expand Up @@ -81,5 +82,50 @@ def test_given_simple_gaussian_data_with_overlap_when_estimate_kl_divergence_con

Y[:10] = X[:10]

assert estimate_kl_divergence_continuous(X, X) == approx(0, abs=0.001)
assert estimate_kl_divergence_continuous(X, Y) == approx(0.5, abs=0.15)
assert estimate_kl_divergence_continuous_knn(X, X) == approx(0, abs=0.001)
assert estimate_kl_divergence_continuous_knn(X, Y) == approx(0.5, abs=0.15)


@flaky(max_runs=3)
def test_given_simple_gaussian_data_when_estimate_kl_divergence_continuous_clf_then_returns_correct_result():
    """The classifier based estimator should give ~0 for identical samples and ~0.5 for N(0, 1) vs. N(1, 1)."""
    standard_normal = np.random.normal(0, 1, 2000)
    shifted_normal = np.random.normal(1, 1, 2000)

    kl_identical = estimate_kl_divergence_continuous_clf(standard_normal, standard_normal)
    assert kl_identical == approx(0, abs=0.001)

    kl_shifted = estimate_kl_divergence_continuous_clf(standard_normal, shifted_normal)
    assert kl_shifted == approx(0.5, abs=0.15)


@flaky(max_runs=3)
def test_given_multi_dim_simple_gaussian_data_when_estimate_kl_divergence_continuous_clf_then_returns_correct_result():
    """For two independent dimensions that are each shifted by 1, the expected KL divergence is ~1."""
    sample_shape = (2000, 2)
    standard_normal = np.random.normal(0, 1, sample_shape)
    shifted_normal = np.random.normal(1, 1, sample_shape)

    kl_identical = estimate_kl_divergence_continuous_clf(standard_normal, standard_normal)
    assert kl_identical == approx(0, abs=0.001)

    kl_shifted = estimate_kl_divergence_continuous_clf(standard_normal, shifted_normal)
    assert kl_shifted == approx(1, abs=0.2)


@flaky(max_runs=3)
def test_given_multi_dim_gaussian_and_categorical_data_when_estimate_kl_divergence_continuous_clf_then_returns_correct_result():
    """Checks the classifier based estimator on data that mixes a numerical and a categorical column."""
    num_samples = 2000

    # The random draws are made in this order on purpose: reference columns first, then the changed ones.
    gaussian = np.random.normal(0, 1, num_samples)
    categories = np.random.choice(3, num_samples, replace=True).astype(str)

    gaussian_shifted = np.random.normal(1, 1, num_samples)
    categories_extended = (np.random.choice(4, num_samples, replace=True)).astype(str)

    reference = np.array([gaussian, categories], dtype=object).T

    assert estimate_kl_divergence_continuous_clf(reference, reference) == approx(0, abs=0.001)

    # Only Gaussian component changed
    only_gaussian_changed = np.array([gaussian_shifted, categories], dtype=object).T
    assert estimate_kl_divergence_continuous_clf(reference, only_gaussian_changed) == approx(0.5, abs=0.15)

    # Gaussian and categorical component changed
    both_changed = np.array([gaussian_shifted, categories_extended], dtype=object).T
    assert estimate_kl_divergence_continuous_clf(reference, both_changed) == approx(0.78, abs=0.15)


@flaky(max_runs=3)
def test_given_exponential_data_when_estimate_kl_divergence_continuous_then_returns_correct_result():
    """Both continuous estimators should agree on ~0.31 for exponential data with scale 1 vs. scale 0.5."""
    expected_kl = approx(0.31, abs=0.1)

    samples_p = np.random.exponential(1, 2000)
    samples_q = np.random.exponential(0.5, 2000)

    kl_clf = estimate_kl_divergence_continuous_clf(samples_p, samples_q)
    assert kl_clf == expected_kl

    kl_knn = estimate_kl_divergence_continuous_knn(samples_p, samples_q)
    assert kl_knn == expected_kl
6 changes: 3 additions & 3 deletions tests/gcm/test_fcms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
fit,
)
from dowhy.gcm.auto import assign_causal_mechanisms
from dowhy.gcm.divergence import estimate_kl_divergence_continuous
from dowhy.gcm.divergence import estimate_kl_divergence_continuous_clf
from dowhy.gcm.ml import (
SklearnRegressionModel,
create_linear_regressor,
Expand Down Expand Up @@ -65,7 +65,7 @@ def test_given_linear_data_when_draw_samples_from_fitted_anm_then_generates_corr
generated_samples = scm.causal_mechanism("X1").draw_samples(np.array([2] * 1000))
assert np.mean(generated_samples) == approx(6, abs=0.05)
assert np.std(generated_samples) == approx(0.1, abs=0.05)
assert estimate_kl_divergence_continuous(
assert estimate_kl_divergence_continuous_clf(
test_data["X1"].to_numpy(), draw_samples(scm, 10000)["X1"].to_numpy()
) == approx(0, abs=0.05)

Expand Down Expand Up @@ -102,7 +102,7 @@ def test_given_categorical_input_data_when_draw_from_fitted_causal_graph_with_li
test_data[:, 2].astype(float).reshape(-1, 1)
)

assert estimate_kl_divergence_continuous(test_data[:, 2], draw_samples(scm, 1000)["X2"].to_numpy()) == approx(
assert estimate_kl_divergence_continuous_clf(test_data[:, 2], draw_samples(scm, 1000)["X2"].to_numpy()) == approx(
0, abs=0.05
)

Expand Down

0 comments on commit 8dc8bf8

Please sign in to comment.