Skip to content

Commit

Permalink
Refactor DynamicDML to remove incompatible method signatures
Browse files Browse the repository at this point in the history
  • Loading branch information
kbattocchi committed Jan 17, 2023
1 parent b84072b commit db4978f
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions econml/panel/dml/_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,32 +547,29 @@ def _gen_model_t(self):
def _gen_model_final(self):
return StatsModelsLinearRegression(fit_intercept=False)

def _gen_ortho_learner_model_nuisance(self, n_periods):
def _gen_ortho_learner_model_nuisance(self):
return _DynamicModelNuisance(
model_t=self._gen_model_t(),
model_y=self._gen_model_y(),
n_periods=n_periods)
n_periods=self._n_periods)

def _gen_ortho_learner_model_final(self, n_periods):
def _gen_ortho_learner_model_final(self):
wrapped_final_model = _DynamicFinalWrapper(
StatsModelsLinearRegression(fit_intercept=False),
fit_cate_intercept=self.fit_cate_intercept,
featurizer=self.featurizer,
use_weight_trick=False)
return _LinearDynamicModelFinal(wrapped_final_model, n_periods=n_periods)
return _LinearDynamicModelFinal(wrapped_final_model, n_periods=self._n_periods)

def _prefit(self, Y, T, *args, groups=None, only_final=False, **kwargs):
# we need to set the number of periods before calling super()._prefit, since that will generate the
# final and nuisance models, which need to have self._n_periods set
u_periods = np.unique(np.unique(groups, return_counts=True)[1])
if len(u_periods) > 1:
raise AttributeError(
"Imbalanced panel. Method currently expects only panels with equal number of periods. Pad your data")
self._n_periods = u_periods[0]
# generate an instance of the final model
self._ortho_learner_model_final = self._gen_ortho_learner_model_final(self._n_periods)
if not only_final:
# generate an instance of the nuisance model
self._ortho_learner_model_nuisance = self._gen_ortho_learner_model_nuisance(self._n_periods)
TreatmentExpansionMixin._prefit(self, Y, T, *args, **kwargs)
super()._prefit(self, Y, T, *args, **kwargs)

def _postfit(self, Y, T, *args, **kwargs):
super()._postfit(Y, T, *args, **kwargs)
Expand Down

0 comments on commit db4978f

Please sign in to comment.