
Implement DynamicDML #446

Merged 36 commits from moprescu/dynamicdml into master on Aug 11, 2021.

Commits:
762e0e4
Implement DynamicDML
Mar 31, 2021
1070aea
Add performance tests and an example notebook
Apr 7, 2021
5f6da40
Add scores.
Apr 9, 2021
d8bc1f3
store some internal variables to allow calling from diased inference …
heimengqi Jun 4, 2021
6daf315
Swap t and j indexes to match the paper
Jul 29, 2021
0615d70
Update covariance matrix to include off-diagonal elements
Jul 29, 2021
7dc65b8
Add support for out of order groups
Jul 31, 2021
efd634d
Implement score
Jul 31, 2021
a508615
Merge branch 'master' into moprescu/dynamicdml
Aug 2, 2021
ac4dd70
Update docstring test outputs
Aug 2, 2021
a44a960
Fix merge issues
Aug 2, 2021
1950fd1
Address PR suggestions
Aug 2, 2021
28a92b6
Merge branch 'master' into moprescu/dynamicdml
Aug 2, 2021
4636257
Fix subscript printing in summary
Aug 3, 2021
9328a22
Address PR suggestions
Aug 5, 2021
24ca086
Update nuisance models in notebook
Aug 5, 2021
a39f1b5
Reverse effect indices to match paper
Aug 6, 2021
4210d1d
Add sample code to README
Aug 6, 2021
e7e7289
Merge branch 'master' into moprescu/dynamicdml
Aug 6, 2021
4cc1156
Adjust heterogeneity to depend only on features from the first period
Aug 6, 2021
e74067e
moved dynamic_dml to separate module. fixed remaining bugs in dgp. fi…
vsyrgkanis Aug 6, 2021
99e62e7
fixed ref in doc
vsyrgkanis Aug 6, 2021
a675564
Merge branch 'master' into moprescu/dynamicdml
vsyrgkanis Aug 6, 2021
1cf663a
doc bug
vsyrgkanis Aug 6, 2021
669c284
relaxed dynamci dml tests
vsyrgkanis Aug 6, 2021
0420656
fixed doctest
vsyrgkanis Aug 6, 2021
42c65dd
add ROI notebook
heimengqi Aug 8, 2021
8691c00
Merge branch 'master' into moprescu/dynamicdml
vsyrgkanis Aug 8, 2021
1c069e1
Merge branch 'master' into moprescu/dynamicdml
vsyrgkanis Aug 8, 2021
a5bd8e9
update setup to install jbl file
heimengqi Aug 9, 2021
e155ac6
Merge branch 'moprescu/dynamicdml' of https://github.com/microsoft/Ec…
heimengqi Aug 9, 2021
da8085d
update setup to install jbl file
heimengqi Aug 9, 2021
ebef40b
update roi notebook
heimengqi Aug 9, 2021
f0a5a29
Merge branch 'master' into moprescu/dynamicdml
heimengqi Aug 11, 2021
e77c568
Limit test paralellization
kbattocchi Aug 9, 2021
6aadad9
Merge branch 'master' into moprescu/dynamicdml
heimengqi Aug 11, 2021
1 change: 1 addition & 0 deletions doc/reference.rst
@@ -17,6 +17,7 @@ Double Machine Learning (DML)
econml.dml.SparseLinearDML
econml.dml.CausalForestDML
econml.dml.NonParamDML
econml.dml.DynamicDML

.. _dr_api:

40 changes: 36 additions & 4 deletions doc/spec/estimation/dml.rst
@@ -34,7 +34,8 @@ What are the relevant estimator classes?
This section describes the methodology implemented in the classes, :class:`._RLearner`,
:class:`.DML`, :class:`.LinearDML`,
:class:`.SparseLinearDML`, :class:`.KernelDML`, :class:`.NonParamDML`,
:class:`.CausalForestDML`.
:class:`.CausalForestDML`,
:class:`.DynamicDML`.
Click on each of these links for a detailed module documentation and input parameters of each class.


@@ -50,6 +51,7 @@ characteristics :math:`X` of the treated samples, then one can use this method.

.. testsetup::

# DML
import numpy as np
X = np.random.choice(np.arange(5), size=(100,3))
Y = np.random.normal(size=(100,2))
@@ -58,6 +60,12 @@ characteristics :math:`X` of the treated samples, then one can use this method.
t = t0 = t1 = T[:,0]
W = np.random.normal(size=(100,2))

# DynamicDML
groups = np.repeat(a=np.arange(100), repeats=3, axis=0)
X_dyn = np.random.normal(size=(300, 1))
T_dyn = np.random.normal(size=(300, 2))
y_dyn = np.random.normal(size=(300, ))

.. testcode::

from econml.dml import LinearDML
@@ -71,8 +79,10 @@ Most of the methods provided make a parametric form assumption on the heterogene
linear on some pre-defined; potentially high-dimensional; featurization). These methods include:
:class:`.DML`, :class:`.LinearDML`,
:class:`.SparseLinearDML`, :class:`.KernelDML`.
For fully non-parametric heterogeneous treatment effect models, check out the :class:`.NonParamDML`
and the :class:`.CausalForestDML`. For more options of non-parametric CATE estimators,
For fully non-parametric heterogeneous treatment effect models, check out the :class:`.NonParamDML`
and the :class:`.CausalForestDML`.
For treatments assigned sequentially over several time periods, see the class :class:`.DynamicDML`.
For more options of non-parametric CATE estimators,
check out the :ref:`Forest Estimators User Guide <orthoforestuserguide>`
and the :ref:`Meta Learners User Guide <metalearnersuserguide>`.

@@ -155,7 +165,7 @@ Class Hierarchy Structure
In this library we implement variants of several of the approaches mentioned in the last section. The hierarchy
structure of the implemented CATE estimators is as follows.

.. inheritance-diagram:: econml.dml.LinearDML econml.dml.SparseLinearDML econml.dml.KernelDML econml.dml.NonParamDML econml.dml.CausalForestDML
.. inheritance-diagram:: econml.dml.LinearDML econml.dml.SparseLinearDML econml.dml.KernelDML econml.dml.NonParamDML econml.dml.CausalForestDML econml.dml.DynamicDML
:parts: 1
:private-bases:
:top-classes: econml._rlearner._RLearner, econml._cate_estimator.StatsModelsCateEstimatorMixin, econml._cate_estimator.DebiasedLassoCateEstimatorMixin
@@ -286,6 +296,28 @@ Below we give a brief description of each of these classes:
Check out :ref:`Forest Estimators User Guide <orthoforestuserguide>` for more information on forest based CATE models and other
alternatives to the :class:`.CausalForestDML`.

* **DynamicDML.** The class :class:`.DynamicDML` is an extension of the Double ML approach for treatments assigned sequentially over time periods.
This estimator will adjust for treatments that can have causal effects on future outcomes. The data corresponds to a Markov decision process :math:`\{X_t, W_t, T_t, Y_t\}_{t=1}^m`,
    where :math:`X_t, W_t` correspond to the state at time :math:`t`, :math:`T_t` is the treatment at time :math:`t`, and :math:`Y_t` is the observed outcome at time :math:`t`.

The model makes the following structural equation assumptions on the data generating process:

.. math::

X_t =~& A \cdot T_{t-1} + B \cdot X_{t-1} + \eta_t\\
T_t =~& p(T_{t-1}, X_t, \zeta_t) \\
        Y_t =~& \theta_0'T_t + \mu'X_t + \epsilon_t

For more details about this model and underlying assumptions, see [Lewis2021]_.
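The structural equations above can be simulated directly. The sketch below uses scalar states and treatments with hypothetical coefficient values (`A`, `B`, `theta`, `mu` are illustrative, not taken from the paper) to show how each period's state depends on the previous treatment, each treatment depends on the current state, and the final outcome is linear in the last period's treatment and state:

```python
import numpy as np

# Illustrative simulation of the Markov decision process above.
rng = np.random.default_rng(0)
n_units, n_periods = 500, 3

A, B = 0.2, 0.5     # state dynamics: X_t = A*T_{t-1} + B*X_{t-1} + eta_t
theta = 1.0         # true effect of the final-period treatment on Y
mu = 0.8            # effect of the final-period state on Y

x = rng.normal(size=n_units)      # initial state
t_prev = np.zeros(n_units)        # no treatment before the first period
for _ in range(n_periods):
    x = A * t_prev + B * x + 0.1 * rng.normal(size=n_units)   # X_t
    t_prev = 0.3 * x + rng.normal(size=n_units)               # T_t = p(T_{t-1}, X_t, zeta_t)

# Outcome observed after the last period: Y = theta*T_m + mu*X_m + eps
y = theta * t_prev + mu * x + 0.1 * rng.normal(size=n_units)
```

Because treatments feed into future states, a naive regression of `y` on the last treatment alone would be confounded by the treatment history, which is exactly what the dynamic DML correction addresses.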

    To learn the effects of the treatments in each period on the final-period outcome, one can simply call:

.. testcode::

from econml.dml import DynamicDML
est = DynamicDML()
est.fit(y_dyn, T_dyn, X=X_dyn, W=None, groups=groups, inference="auto")
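The fit call expects the panel in long format: one row per (unit, period), with all of a unit's periods contiguous and tagged with the same `groups` id so that crossfitting never splits a unit across folds. A numpy-only sketch of building those arrays from a hypothetical wide panel (the `*_wide` layout is an assumption for illustration):

```python
import numpy as np

n_units, n_periods = 100, 3

# Hypothetical wide panel: one row per unit, one slice per period.
y_wide = np.random.normal(size=(n_units, n_periods))
T_wide = np.random.normal(size=(n_units, n_periods, 2))
X_wide = np.random.normal(size=(n_units, n_periods, 1))

# Long format: one row per (unit, period), periods contiguous within a unit.
y_dyn = y_wide.reshape(-1)
T_dyn = T_wide.reshape(n_units * n_periods, -1)
X_dyn = X_wide.reshape(n_units * n_periods, -1)

# groups tags each row with its unit id.
groups = np.repeat(np.arange(n_units), n_periods)   # [0, 0, 0, 1, 1, 1, ...]
assert groups.shape[0] == y_dyn.shape[0]
```

These arrays match the shapes used in the testsetup block above (300 rows for 100 units over 3 periods).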

* **_RLearner.** The internal private class :class:`._RLearner` is a parent of the :class:`.DML`
and allows the user to specify any way of fitting a final model that takes as input the residual :math:`\tilde{T}`,
the features :math:`X` and predicts the residual :math:`\tilde{Y}`. Moreover, the nuisance models take as input
7 changes: 6 additions & 1 deletion doc/spec/references.rst
@@ -113,4 +113,9 @@ References
.. [Lundberg2017]
Lundberg, S., Lee, S. (2017).
A Unified Approach to Interpreting Model Predictions.
URL https://arxiv.org/abs/1705.07874
URL https://arxiv.org/abs/1705.07874

.. [Lewis2021]
Lewis, G., Syrgkanis, V. (2021).
Double/Debiased Machine Learning for Dynamic Treatment Effects.
URL https://arxiv.org/abs/2002.07285
2 changes: 1 addition & 1 deletion econml/_cate_estimator.py
@@ -563,7 +563,7 @@ def effect(self, X=None, *, T0, T1):
"""
Calculate the heterogeneous treatment effect :math:`\\tau(X, T0, T1)`.

The effect is calculatred between the two treatment points
The effect is calculated between the two treatment points
conditional on a vector of features on a set of m test samples :math:`\\{T0_i, T1_i, X_i\\}`.
Since this class assumes a linear effect, only the difference between T0ᵢ and T1ᵢ
matters for this computation.
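The docstring's claim that only the difference between T0ᵢ and T1ᵢ matters follows directly from linearity: τ(X, T0, T1) = θ(X)·(T1 − T0). A tiny numpy check with a hypothetical constant effect vector:

```python
import numpy as np

# Hypothetical constant marginal effect theta(X) for a 2-dimensional treatment.
theta = np.array([1.5, -0.5])

T0 = np.array([0.0, 1.0])
T1 = np.array([1.0, 1.0])

# Linear effect model: tau(X, T0, T1) = theta(X) @ (T1 - T0).
tau = theta @ (T1 - T0)

# Shifting both treatment points by the same amount leaves the effect unchanged.
shifted = theta @ ((T1 + 3.0) - (T0 + 3.0))
assert np.isclose(tau, shifted)
```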
21 changes: 13 additions & 8 deletions econml/_ortho_learner.py
@@ -685,7 +685,8 @@ def fit(self, Y, T, X=None, W=None, Z=None, *, sample_weight=None, freq_weight=N
nuisances=nuisances,
sample_weight=sample_weight,
freq_weight=freq_weight,
sample_var=sample_var)
sample_var=sample_var,
groups=groups)

return self

@@ -770,18 +771,19 @@ def _fit_nuisances(self, Y, T, X=None, W=None, Z=None, sample_weight=None, group
return nuisances, fitted_models, fitted_inds, scores

def _fit_final(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None,
freq_weight=None, sample_var=None):
freq_weight=None, sample_var=None, groups=None):
self._ortho_learner_model_final.fit(Y, T, **filter_none_kwargs(X=X, W=W, Z=Z,
nuisances=nuisances,
sample_weight=sample_weight,
freq_weight=freq_weight,
sample_var=sample_var))
sample_var=sample_var,
groups=groups))
self.score_ = None
if hasattr(self._ortho_learner_model_final, 'score'):
self.score_ = self._ortho_learner_model_final.score(Y, T, **filter_none_kwargs(X=X, W=W, Z=Z,
nuisances=nuisances,
sample_weight=sample_weight)
)
sample_weight=sample_weight,
groups=groups))

def const_marginal_effect(self, X=None):
moprescu marked this conversation as resolved.
X, = check_input_arrays(X)
@@ -816,7 +818,7 @@ def effect_inference(self, X=None, *, T0=0, T1=1):
return super().effect_inference(X, T0=T0, T1=T1)
effect_inference.__doc__ = LinearCateEstimator.effect_inference.__doc__

def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None):
def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
"""
Score the fitted CATE model on a new data set. Generates nuisance parameters
for the new data set based on the fitted nuisance models created at fit time.
@@ -840,6 +842,8 @@ def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None):
Instruments for each sample
sample_weight: optional(n,) vector or None (Default=None)
Weights for each samples
groups: (n,) vector, optional
All rows corresponding to the same group will be kept together during splitting.

Returns
-------
@@ -862,7 +866,7 @@ def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None):
for i, models_nuisances in enumerate(self._models_nuisance):
# for each model under cross fit setting
for j, mdl in enumerate(models_nuisances):
nuisance_temp = mdl.predict(Y, T, **filter_none_kwargs(X=X, W=W, Z=Z))
nuisance_temp = mdl.predict(Y, T, **filter_none_kwargs(X=X, W=W, Z=Z, groups=groups))
if not isinstance(nuisance_temp, tuple):
nuisance_temp = (nuisance_temp,)

@@ -876,7 +880,8 @@ def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None):
nuisances[it] = np.mean(nuisances[it], axis=0)

return self._ortho_learner_model_final.score(Y, T, nuisances=nuisances,
**filter_none_kwargs(X=X, W=W, Z=Z, sample_weight=sample_weight))
**filter_none_kwargs(X=X, W=W, Z=Z,
sample_weight=sample_weight, groups=groups))

@property
def ortho_learner_model_final_(self):
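The hunk above threads `groups` through `filter_none_kwargs` at every call site, so final models that never accepted a `groups` argument keep working when it is left at its default of None. A minimal sketch of the helper's assumed behavior (the real one lives in econml's utilities module):

```python
def filter_none_kwargs(**kwargs):
    """Drop keyword arguments whose value is None (assumed behavior of the
    helper used in the hunk above)."""
    return {name: value for name, value in kwargs.items() if value is not None}

# With groups=None (the default), nothing extra is forwarded to fit/score,
# so existing final-model signatures are unaffected by the new argument.
kept = filter_none_kwargs(X=[[1.0]], W=None, Z=None, groups=None)
assert list(kept) == ["X"]
```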
9 changes: 7 additions & 2 deletions econml/dml/__init__.py
@@ -8,7 +8,7 @@
Then estimates a CATE model by regressing the residual outcome on the residual treatment
in a manner that accounts for heterogeneity in the regression coefficient, with respect
to X. For the theoretical foundations of these methods see [dml]_, [rlearner]_, [paneldml]_,
[lassodml]_, [ortholearner]_.
[lassodml]_, [ortholearner]_, [dynamicdml]_.

References
----------
@@ -33,10 +33,14 @@
Orthogonal Statistical Learning.
ACM Conference on Learning Theory. `<https://arxiv.org/abs/1901.09036>`_

.. [dynamicdml] Greg Lewis and Vasilis Syrgkanis.
Double/Debiased Machine Learning for Dynamic Treatment Effects.
`<https://arxiv.org/abs/2002.07285>`_, 2021.
"""

from .dml import (DML, LinearDML, SparseLinearDML,
KernelDML, NonParamDML, ForestDML)
from .dynamic_dml import DynamicDML
from .causal_forest import CausalForestDML

__all__ = ["DML",
@@ -45,4 +49,5 @@
"KernelDML",
"NonParamDML",
"ForestDML",
"CausalForestDML", ]
"CausalForestDML",
"DynamicDML"]
5 changes: 3 additions & 2 deletions econml/dml/_rlearner.py
@@ -91,7 +91,8 @@ class _ModelFinal:
def __init__(self, model_final):
self._model_final = model_final

def fit(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None, freq_weight=None, sample_var=None):
def fit(self, Y, T, X=None, W=None, Z=None, nuisances=None,
sample_weight=None, freq_weight=None, sample_var=None, groups=None):
Y_res, T_res = nuisances
self._model_final.fit(X, T, T_res, Y_res, sample_weight=sample_weight,
freq_weight=freq_weight, sample_var=sample_var)
@@ -100,7 +101,7 @@ def predict(self, X=None):
def predict(self, X=None):
return self._model_final.predict(X)

def score(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None):
def score(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None, groups=None):
Y_res, T_res = nuisances
if Y_res.ndim == 1:
Y_res = Y_res.reshape((-1, 1))
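The `_ModelFinal.score` above evaluates the R-learner's residual-on-residual loss: how well the effect model explains the outcome residuals from the treatment residuals. A numpy sketch for the simplest case of a scalar, constant effect (all names and coefficient values here are illustrative):

```python
import numpy as np

# Synthetic first-stage residuals with a known constant effect.
rng = np.random.default_rng(1)
n = 2000
t_res = rng.normal(size=n)                       # T - E[T | X, W]
theta_true = 2.0
y_res = theta_true * t_res + 0.1 * rng.normal(size=n)   # Y - E[Y | X, W]

# Final stage: no-intercept regression of y_res on t_res recovers the effect.
theta_hat = (t_res @ y_res) / (t_res @ t_res)

# The score is the mean squared residual-on-residual loss.
score = np.mean((y_res - theta_hat * t_res) ** 2)
```

A lower score means the fitted effect model leaves less structure in the outcome residuals, which is what makes it usable for model selection across candidate final models.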
2 changes: 1 addition & 1 deletion econml/dml/causal_forest.py
@@ -53,7 +53,7 @@ def _ate_and_stderr(self, drpreds, mask=None):
stderr = (np.nanstd(drpreds, axis=0) / np.sqrt(nonnan)).reshape(self._d_y + self._d_t)
return point, stderr

def fit(self, X, T, T_res, Y_res, sample_weight=None, freq_weight=None, sample_var=None):
def fit(self, X, T, T_res, Y_res, sample_weight=None, freq_weight=None, sample_var=None, groups=None):
# Track training dimensions to see if Y or T is a vector instead of a 2-dimensional array
self._d_t = shape(T_res)[1:]
self._d_y = shape(Y_res)[1:]
2 changes: 1 addition & 1 deletion econml/dml/dml.py
@@ -134,7 +134,7 @@ def _combine(self, X, T, fitting=True):
F = np.ones((T.shape[0], 1))
return cross_product(F, T)

def fit(self, X, T, T_res, Y_res, sample_weight=None, freq_weight=None, sample_var=None):
def fit(self, X, T, T_res, Y_res, sample_weight=None, freq_weight=None, sample_var=None, groups=None):
# Track training dimensions to see if Y or T is a vector instead of a 2-dimensional array
self._d_t = shape(T_res)[1:]
self._d_y = shape(Y_res)[1:]
Expand Down