diff --git a/econml/tests/test_cate_interpreter.py b/econml/tests/test_cate_interpreter.py index eff4393cc..8aa359087 100644 --- a/econml/tests/test_cate_interpreter.py +++ b/econml/tests/test_cate_interpreter.py @@ -4,7 +4,6 @@ import numpy as np import unittest import pytest -import matplotlib from econml.cate_interpreter import SingleTreeCateInterpreter, SingleTreePolicyInterpreter from econml.dml import LinearDML from sklearn.linear_model import LinearRegression, LogisticRegression @@ -14,11 +13,11 @@ from graphviz import Graph g = Graph() g.render() + import matplotlib + matplotlib.use('Agg') except Exception: graphviz_works = False -matplotlib.use('Agg') - @pytest.mark.skipif(not graphviz_works, reason="graphviz must be installed to run CATE interpreter tests") class TestCateInterpreter(unittest.TestCase): diff --git a/econml/tests/test_dowhy.py b/econml/tests/test_dowhy.py index 7208e8057..4d7c37ddf 100644 --- a/econml/tests/test_dowhy.py +++ b/econml/tests/test_dowhy.py @@ -62,6 +62,11 @@ def clf(): else: est_dowhy = est.dowhy.fit(Y, T, X=X, W=W) # test causal graph + # need to set matplotlib backend before viewing model + + import matplotlib + matplotlib.use('Agg') + est_dowhy.view_model(layout=None) # test refutation estimate est_dowhy.refute_estimate(method_name="random_common_cause", num_simulations=3) diff --git a/econml/tests/test_shap.py b/econml/tests/test_shap.py index fecbb202c..0761d565b 100644 --- a/econml/tests/test_shap.py +++ b/econml/tests/test_shap.py @@ -125,6 +125,8 @@ def test_discrete_t(self): def test_identical_output(self): # import here since otherwise test collection would fail if matplotlib is not installed + import matplotlib + matplotlib.use('Agg') from shap.plots import scatter, heatmap, bar, beeswarm, waterfall # Treatment effect function