Skip to content

Commit 86fcb28

Browse files
committed
Add prediction_model parameter to ICC sample function
Previously, only the 'exact' model could be used in the sample-wise ICC function. Now, there is a "prediction_model" parameter, similar to the population-based ICC function, that allows for the passing of a pre-defined prediction model. Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
1 parent 7eb4a0c commit 86fcb28

File tree

2 files changed

+88
-15
lines changed

2 files changed

+88
-15
lines changed

dowhy/gcm/influence.py

+39-14
Original file line numberDiff line numberDiff line change
@@ -243,10 +243,12 @@ def intrinsic_causal_influence(
243243
:param prediction_model: Prediction model for estimating the functional relationship between subsets of ancestor
244244
noise terms and the target node. This can be an instance of a PredictionModel, the string
245245
'approx' or the string 'exact'. With 'exact', the underlying causal models in the graph
246-
are utilized directly by propagating given noise inputs through the graph. This is
247-
generally more accurate but slow. With 'approx', an appropriate model is selected and
248-
trained based on sampled data from the graph, which is less accurate but faster. A more
249-
detailed treatment on why we need this parameter is also provided in :ref:`icc`.
246+
are utilized directly by propagating given noise inputs through the graph, which ensures
247+
that generated samples follow the fitted models. In contrast, the 'approx' method involves
248+
selecting and training a suitable model based on data sampled from the graph. This might
249+
lead to deviations from the outcomes of the fitted models, but is faster and can be more
250+
robust in certain settings. A more detailed treatment on why we need this parameter is
251+
also provided in :ref:`icc`.
250252
:param attribution_func: Optional attribution function to measure the statistical property of the target node. This
251253
function expects two inputs; predictions after the randomization of certain features (i.e.
252254
samples from noise nodes) and a baseline where no features were randomized. The baseline
@@ -325,9 +327,11 @@ def intrinsic_causal_influence_sample(
325327
target_node: Any,
326328
baseline_samples: pd.DataFrame,
327329
noise_feature_samples: Optional[pd.DataFrame] = None,
330+
prediction_model: Union[PredictionModel, ClassificationModel, str] = "approx",
328331
subset_scoring_func: Optional[Callable[[np.ndarray, np.ndarray], Union[np.ndarray, float]]] = None,
329332
num_noise_feature_samples: int = 5000,
330333
max_batch_size: int = 100,
334+
auto_assign_quality: auto.AssignmentQuality = auto.AssignmentQuality.GOOD,
331335
shapley_config: Optional[ShapleyConfig] = None,
332336
) -> List[Dict[Any, Any]]:
333337
"""Estimates the intrinsic causal impact of upstream nodes on a specified target_node, using the provided
@@ -342,9 +346,18 @@ def intrinsic_causal_influence_sample(
342346
:param causal_model: The fitted invertible structural causal model.
343347
:param target_node: Node of interest.
344348
:param baseline_samples: Samples for which the influence should be estimated.
345-
:param noise_feature_samples: Optional noise samples of upstream nodes used as 'background' samples.. If None is
349+
:param noise_feature_samples: Optional noise samples of upstream nodes used as 'background' samples. If None is
346350
given, new noise samples are generated based on the graph. These samples are used for
347351
randomizing features that are not in the subset.
352+
:param prediction_model: Prediction model for estimating the functional relationship between subsets of ancestor
353+
noise terms and the target node. This can be an instance of a PredictionModel, the string
354+
'approx' or the string 'exact'. With 'exact', the underlying causal models in the graph
355+
are utilized directly by propagating given noise inputs through the graph, which ensures
356+
that generated samples follow the fitted models. In contrast, the 'approx' method involves
357+
selecting and training a suitable model based on data sampled from the graph. This might
358+
lead to deviations from the outcomes of the fitted models, but is faster and can be more
359+
robust in certain settings. A more detailed treatment on why we need this parameter is
360+
also provided in :ref:`icc`.
348361
:param subset_scoring_func: Set function for estimating the quantity of interest. This function
349362
expects two inputs; the outcome of the model for some samples if certain features are permuted and the
350363
outcome of the model for the same samples when no features were permuted. By default,
@@ -353,6 +366,7 @@ def intrinsic_causal_influence_sample(
353366
This parameter indicates how many.
354367
:param max_batch_size: Maximum batch size for estimating multiple predictions at once. This has a significant influence on the
355368
overall memory usage. If set to -1, all samples are used in one batch.
369+
:param auto_assign_quality: Auto assign quality for the 'approx' prediction_model option.
356370
:param shapley_config: :class:`~dowhy.gcm.shapley.ShapleyConfig` for the Shapley estimator.
357371
:return: A list of dictionaries indicating the intrinsic causal influence of a node on the target for a particular
358372
sample. That is, each dictionary belongs to one baseline sample.
@@ -376,21 +390,32 @@ def intrinsic_causal_influence_sample(
376390
if subset_scoring_func is None:
377391
subset_scoring_func = means_difference
378392

393+
target_samples = feature_samples[target_node].to_numpy()
394+
node_names = noise_feature_samples.columns
395+
noise_feature_samples, target_samples = shape_into_2d(noise_feature_samples.to_numpy(), target_samples)
396+
397+
prediction_method = _get_icc_noise_function(
398+
causal_model,
399+
target_node,
400+
prediction_model,
401+
noise_feature_samples,
402+
node_names,
403+
target_samples,
404+
auto_assign_quality,
405+
False, # Currently only supports a continuous target since we need to reconstruct its noise term.
406+
)
407+
379408
shapley_vales = feature_relevance_sample(
380-
_get_icc_noise_function(
381-
causal_model, target_node, "exact", noise_feature_samples, noise_feature_samples.columns, None, None, False
382-
),
383-
feature_samples=noise_feature_samples.to_numpy(),
384-
baseline_samples=compute_noise_from_data(causal_model, baseline_samples)[
385-
noise_feature_samples.columns
386-
].to_numpy(),
409+
prediction_method,
410+
feature_samples=noise_feature_samples,
411+
baseline_samples=compute_noise_from_data(causal_model, baseline_samples)[node_names].to_numpy(),
387412
subset_scoring_func=subset_scoring_func,
388413
max_batch_size=max_batch_size,
389414
shapley_config=shapley_config,
390415
)
391416

392417
return [
393-
{(predecessor, target_node): shapley_vales[i][q] for q, predecessor in enumerate(noise_feature_samples.columns)}
418+
{(predecessor, target_node): shapley_vales[i][q] for q, predecessor in enumerate(node_names)}
394419
for i in range(shapley_vales.shape[0])
395420
]
396421

@@ -432,7 +457,7 @@ def icc_set_function(subset: np.ndarray) -> Union[np.ndarray, float]:
432457

433458

434459
def _get_icc_noise_function(
435-
causal_model: InvertibleStructuralCausalModel,
460+
causal_model: StructuralCausalModel,
436461
target_node: Any,
437462
prediction_model: Union[PredictionModel, ClassificationModel, str],
438463
noise_samples: np.ndarray,

tests/gcm/test_intrinsic_influence.py

+49-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@
1717
)
1818
from dowhy.gcm._noise import noise_samples_of_ancestors
1919
from dowhy.gcm.influence import intrinsic_causal_influence_sample
20-
from dowhy.gcm.ml import create_hist_gradient_boost_classifier, create_linear_regressor_with_given_parameters
20+
from dowhy.gcm.ml import (
21+
create_hist_gradient_boost_classifier,
22+
create_hist_gradient_boost_regressor,
23+
create_linear_regressor,
24+
create_linear_regressor_with_given_parameters,
25+
)
2126
from dowhy.gcm.uncertainty import estimate_entropy_of_probabilities, estimate_variance
2227
from dowhy.gcm.util.general import apply_one_hot_encoding, fit_one_hot_encoders
2328
from dowhy.graph import node_connected_subgraph_view
@@ -247,3 +252,46 @@ def test_given_linear_gaussian_data_when_estimate_sample_wise_intrinsic_causal_i
247252
assert shapley_values[1][("X1", "X3")] == approx(0.5, abs=0.1)
248253
assert shapley_values[1][("X2", "X3")] == approx(2, abs=0.1)
249254
assert shapley_values[1][("X3", "X3")] == approx(1, abs=0.1)
255+
256+
257+
@flaky(max_runs=3)
258+
def test_given_linear_gaussian_data_when_estimate_sample_wise_intrinsic_causal_influence_with_a_pre_defined_model_then_returns_expected_values():
259+
causal_model = InvertibleStructuralCausalModel(nx.DiGraph([("X0", "X1"), ("X1", "X2"), ("X2", "X3")]))
260+
261+
causal_model.set_causal_mechanism("X0", ScipyDistribution(stats.norm, loc=0, scale=1))
262+
causal_model.set_causal_mechanism(
263+
"X1",
264+
AdditiveNoiseModel(
265+
create_linear_regressor_with_given_parameters(np.array([2])), ScipyDistribution(stats.norm, loc=0, scale=1)
266+
),
267+
)
268+
causal_model.set_causal_mechanism(
269+
"X2",
270+
AdditiveNoiseModel(
271+
create_linear_regressor_with_given_parameters(np.array([1])), ScipyDistribution(stats.norm, loc=0, scale=1)
272+
),
273+
)
274+
causal_model.set_causal_mechanism(
275+
"X3",
276+
AdditiveNoiseModel(
277+
create_linear_regressor_with_given_parameters(np.array([1])), ScipyDistribution(stats.norm, loc=0, scale=1)
278+
),
279+
)
280+
_persist_parents(causal_model.graph)
281+
282+
shapley_values = intrinsic_causal_influence_sample(
283+
causal_model,
284+
"X3",
285+
pd.DataFrame({"X0": [0, 1], "X1": [0.5, 2.5], "X2": [1.5, 4.5], "X3": [1.5, 5.5]}),
286+
prediction_model=create_linear_regressor(),
287+
)
288+
289+
assert shapley_values[0][("X0", "X3")] == approx(0, abs=0.15)
290+
assert shapley_values[0][("X1", "X3")] == approx(0.5, abs=0.15)
291+
assert shapley_values[0][("X2", "X3")] == approx(1, abs=0.15)
292+
assert shapley_values[0][("X3", "X3")] == approx(0, abs=0.15)
293+
294+
assert shapley_values[1][("X0", "X3")] == approx(2, abs=0.15)
295+
assert shapley_values[1][("X1", "X3")] == approx(0.5, abs=0.15)
296+
assert shapley_values[1][("X2", "X3")] == approx(2, abs=0.15)
297+
assert shapley_values[1][("X3", "X3")] == approx(1, abs=0.15)

0 commit comments

Comments
 (0)