
Several minor improvements #804

Merged
merged 10 commits into main from kebatt/minorFixes on Oct 25, 2023
Conversation

kbattocchi (Collaborator) commented Aug 4, 2023:

  • Support direct covariance fitting for DRIV
  • Ensure that groups can be passed to DMLIV and DRIV
  • Dependency cleanup:
    • Enable newer versions of shap, matplotlib, seaborn, and dowhy
    • Drop support for sklearn<1.0 and enable support for sklearn 1.3
  • CI improvements:
    • Run doctests as part of build
    • Don't fail fast when building packages fails on one platform
    • Store test output in an artifact

@@ -526,7 +526,7 @@ def score(self, Y, T, Z, X=None, W=None, sample_weight=None):
The MSE of the final CATE model on the new data.
"""
# Replacing score from _OrthoLearner, to enforce Z to be required and improve the docstring
-        return super().score(Y, T, X=X, W=W, Z=Z, sample_weight=sample_weight)
+        return super().score(Y, T, X=X, W=W, Z=Z, sample_weight=sample_weight, groups=None)
Collaborator:

should it be groups=groups here?

Collaborator Author (kbattocchi):

Yes, good catch. (It doesn't affect the results since groups are never used in scoring, but I'll fix it in the next set of changes).

Collaborator Author (kbattocchi):

Upon further consideration, I've removed groups from the DMLIV and DRIV scoring methods, because they are never used and so there's no point in including them.

The groups argument needs to exist on the nuisance models, because the signatures for fit, predict, and score all need to be compatible for how we do cross-fitting, but there's no need for them to pollute the estimators themselves, and indeed our existing classes like LinearDML do not have groups on their scoring methods.
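The signature-compatibility constraint described above can be sketched in a few lines. This is an illustrative toy (MeanModel and crossfit are hypothetical names, not econml's actual internals): the cross-fitting loop calls every nuisance model with one shared keyword set, so each model must accept groups even when it ignores it.

```python
class MeanModel:
    """Trivial nuisance model: predicts the training mean."""
    def fit(self, Y, T, X=None, W=None, Z=None, groups=None):
        # groups is accepted for signature compatibility; this model ignores it
        self.mean_ = sum(Y) / len(Y)
        return self

    def predict(self, X=None):
        return self.mean_

def crossfit(models, Y, T, folds, groups=None):
    # Every model is called with the same keyword set, including groups,
    # which is why the nuisance signature must carry it even when unused.
    preds = []
    for train_idx, test_idx in folds:
        Y_train = [Y[i] for i in train_idx]
        g_train = [groups[i] for i in train_idx] if groups is not None else None
        fitted = [m.fit(Y_train, T, groups=g_train) for m in models]
        preds.append([m.predict() for m in fitted])
    return preds

Y = [1.0, 2.0, 3.0, 4.0]
folds = [([0, 1], [2, 3]), ([2, 3], [0, 1])]
out = crossfit([MeanModel()], Y, None, folds, groups=[0, 0, 1, 1])
print(out)  # [[1.5], [3.5]]
```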

@@ -318,7 +318,7 @@ def predict(self, X=None):
X = self._transform_X(X, fitting=False)
return self._model_final.predict(X).reshape((-1,) + self.d_y + self.d_t)

-    def score(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None):
+    def score(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None, groups=None):
Collaborator:

groups=groups?

Collaborator Author (kbattocchi):

This is the method definition, so groups=None is correct.
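As a tiny illustration of the distinction (toy classes, not econml code): groups=None in a def declares the default value, while groups=groups at a call site forwards the caller's value instead of resetting it.

```python
class Base:
    def score(self, Y, groups=None):            # definition: default is None
        return ("base", groups)

class Child(Base):
    def score(self, Y, groups=None):            # definition: groups=None is correct here
        # call site: forward the caller's value, don't hard-code None
        return super().score(Y, groups=groups)

assert Child().score([1], groups=[0, 1]) == ("base", [0, 1])
assert Child().score([1]) == ("base", None)     # default still applies
```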

]

for est in est_list:
est.fit(y, T, Z=Z, X=X, W=W, groups=groups)
Collaborator:

Is there a way to make sure the groups are actually being used here? To avoid problems like when groups is accidentally left as None in the call to super().score() instead of threaded through from the args.
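One way to implement such a check, sketched with a hand-rolled splitter rather than econml's actual machinery (group_kfold and groups_respected are illustrative names): record the folds a splitter produces and assert that no group ever appears on both sides of a train/test split.

```python
def group_kfold(groups, n_splits=2):
    """Minimal hand-rolled group-aware K-fold, for illustration only."""
    unique = sorted(set(groups))
    buckets = [unique[i::n_splits] for i in range(n_splits)]
    for held_out in buckets:
        test = [i for i, g in enumerate(groups) if g in held_out]
        train = [i for i, g in enumerate(groups) if g not in held_out]
        yield train, test

def groups_respected(groups, folds):
    # The property to test: each group is confined to one side of every split.
    for train, test in folds:
        train_groups = {groups[i] for i in train}
        test_groups = {groups[i] for i in test}
        if train_groups & test_groups:   # a group leaked across the split
            return False
    return True

groups = [0, 0, 1, 1, 2, 2, 3, 3]
assert groups_respected(groups, group_kfold(groups))
# A split that ignores groups would fail the check:
assert not groups_respected([0, 0, 1, 1], [([0, 2], [1, 3])])
```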

@@ -526,7 +526,7 @@ def score(self, Y, T, Z, X=None, W=None, sample_weight=None):
The MSE of the final CATE model on the new data.
Collaborator:

Minor, since groups aren't really used for scoring, but they are not included in the docstring as parameters.

Collaborator Author (kbattocchi):

As mentioned in a previous comment, removed groups from scoring on the estimator since they do nothing

@@ -837,7 +837,7 @@ def fit(self, Y, T, *, Z, X=None, W=None, sample_weight=None, freq_weight=None,
sample_weight=sample_weight, freq_weight=freq_weight, sample_var=sample_var, groups=groups,
cache_values=cache_values, inference=inference)

-    def score(self, Y, T, Z, X=None, W=None, sample_weight=None):
+    def score(self, Y, T, Z, X=None, W=None, sample_weight=None, groups=None):
"""
Score the fitted CATE model on a new data set. Generates nuisance parameters
for the new data set based on the fitted residual nuisance models created at fit time.
Collaborator:

groups missing from docstring

Collaborator Author (kbattocchi):

As mentioned, removed groups from this method.

@@ -1151,7 +1151,7 @@ def test_groups(self):
est.fit(y, t, groups=groups)

# test outer grouping
est = LinearDML(model_y=LinearRegression(), model_t=LinearRegression(), cv=GroupKFold(2))
Collaborator:

Is it worth adding some check to verify that a GroupKFold splitter was used under the hood?

for est in ests_list:
with self.subTest(est=est):
# no heterogeneity
Collaborator:

Minor question but is there a benefit to moving this inside the for loop?

Collaborator Author (kbattocchi), Oct 12, 2023:

The test passes :-). The default for fit_cov_directly is True, which means the previous random seed no longer generates results identical to what they were before; that led to a marginal failure on this test, but just slightly reorganizing it made it pass again.

Logically, I think this makes more sense anyway: it's weird to have different loops creating two sets of identical subtests that test different things; if you run the tests locally via unittest you'll see one result per subtest but there won't be any way to tell which was which.

Collaborator:

The diff is hard to parse here for some reason, even though the actual changes are minimal, just like in the econml+dowhy version of the notebook. Not sure why. Different jupyter version?

Collaborator Author (kbattocchi):

That was indeed very weird; fixed.

@@ -793,9 +793,12 @@ def test_groups(self):
est.fit(y, t, W=w, groups=groups)

# test outer grouping
# NOTE: we should ideally use a stratified split with grouping, but sklearn doesn't have one yet
# NOTE: StratifiedGroupKFold has a bug when shuffle is True where it doesn't always stratify properly
Collaborator:

Is this bug worth worrying about for our users, since crossfit uses StratifiedGroupKFold with shuffle=True?

Collaborator Author (kbattocchi):

Hopefully it will be fixed in sklearn and then it will have the right behavior, but until then it's possible that users can run into it (although the buggy behavior only occurs with certain datasets, so hopefully it works most of the time).

However, I don't think there's any good fix on our end: in general we do want to shuffle. For the purposes of this one test we can ignore that, but turning shuffle off wouldn't be an appropriate substitute in general.
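The property the buggy shuffle can violate is stratification, which a test could check directly. A sketch (class_fraction and well_stratified are illustrative helpers, not sklearn or econml code): compare the class balance inside each test fold against the overall balance.

```python
def class_fraction(labels, idx):
    """Fraction of positive labels among the given indices."""
    return sum(labels[i] for i in idx) / len(idx)

def well_stratified(labels, folds, tol=0.2):
    # Each test fold's class balance should be close to the overall balance;
    # a grouped splitter that fails to stratify would trip this tolerance.
    overall = sum(labels) / len(labels)
    return all(abs(class_fraction(labels, test) - overall) <= tol
               for _, test in folds)

labels = [0, 1, 0, 1, 0, 1, 0, 1]            # 50% positives overall
good_folds = [([0, 1, 2, 3], [4, 5, 6, 7]),  # each test fold is 50/50
              ([4, 5, 6, 7], [0, 1, 2, 3])]
bad_folds = [([1, 3, 5, 7], [0, 2, 4, 6]),   # test folds are all-0 / all-1
             ([0, 2, 4, 6], [1, 3, 5, 7])]

assert well_stratified(labels, good_folds)
assert not well_stratified(labels, bad_folds)
```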

fverac linked an issue on Oct 13, 2023 that may be closed by this pull request
kbattocchi force-pushed the kebatt/minorFixes branch 3 times, most recently from d6aa09e to ff63e62, on October 20, 2023 16:11
kbattocchi marked this pull request as ready for review on October 20, 2023 16:43
@@ -1151,7 +1151,7 @@ def test_groups(self):
est.fit(y, t, groups=groups)

# test outer grouping
-    est = LinearDML(model_y=LinearRegression(), model_t=LinearRegression(), cv=GroupKFold(2))
+    est = LinearDML(model_y=LinearRegression(), model_t=LinearRegression())
est.fit(y, t, groups=groups)

Collaborator:

Suggested change:
    assert isinstance(est.splitter, GroupKFold)

What about adding something like this. Just to protect against the case where groups isn't actually used under the hood.

Collaborator:

Though seems like currently we don't save the splitter to our ests as an attribute
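A hypothetical pattern that would enable the suggested assertion (econml does not currently expose this, per the comment above; all names here are illustrative): store the splitter chosen during fit as a fitted attribute, so a test can assert on it directly instead of inferring the behavior indirectly.

```python
class GroupAwareSplitter:
    """Stand-in for a group-respecting splitter such as GroupKFold."""

class PlainSplitter:
    """Stand-in for an ordinary K-fold splitter."""

class Estimator:
    def fit(self, Y, groups=None):
        # choose a group-aware splitter automatically when groups are supplied,
        # and record the choice as a fitted attribute for introspection
        self.splitter_ = GroupAwareSplitter() if groups is not None else PlainSplitter()
        return self

est = Estimator().fit([1, 2, 3], groups=[0, 0, 1])
assert isinstance(est.splitter_, GroupAwareSplitter)   # the suggested test-time check
```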

Signed-off-by: Keith Battocchi <kebatt@microsoft.com>
fverac (Collaborator) left a review:

Looks good

kbattocchi merged commit 5423183 into main on Oct 25, 2023
65 checks passed
kbattocchi deleted the kebatt/minorFixes branch October 25, 2023 05:38
Successfully merging this pull request may close these issues.

Changing covariance logic in DRIV