Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scaling ortholearners using Ray #800

Merged
merged 19 commits into from
Oct 27, 2023
Merged

Conversation

v-shaal
Copy link
Contributor

@v-shaal v-shaal commented Aug 2, 2023

issue : 793

  • Added Implementation of Ray based distributed parallelization to crossfit.
  • set flag use_ray = True or False to use ray implementation vs normal implementation
  • parallelized fit_nuisance via ray .
  • Added Testcases to compare ray vs regular implementation
  • Current PR implementation is for DML , can be extended to other estimators using _Othrolearners as baseclass

Signed-off-by: Vishal Verma <vishalmverma27@gmail.com>
Signed-off-by: Vishal Verma <vishalmverma27@gmail.com>
Signed-off-by: Vishal Verma <vishalmverma27@gmail.com>
Signed-off-by: Vishal Verma <vishalmverma27@gmail.com>
@v-shaal v-shaal marked this pull request as ready for review August 2, 2023 18:56
Signed-off-by: Vishal Verma <vishalmverma27@gmail.com>
@v-shaal v-shaal marked this pull request as draft August 3, 2023 03:17
@v-shaal v-shaal marked this pull request as ready for review August 3, 2023 05:26
@v-shaal v-shaal marked this pull request as draft August 3, 2023 05:42
Signed-off-by: Vishal Verma <vishalmverma27@gmail.com>
Signed-off-by: Vishal Verma <vishalmverma27@gmail.com>
Copy link
Collaborator

@kbattocchi kbattocchi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall, this looks like a great addition to the library. However, there are a few changes that need to be addressed before it can be added.

First of all, please revert your changes to setup.cfg, merge the latest main back into your branch, and then make those changes to pyproject.toml instead - sorry that we changed this out from under you while your PR was in progress, but the package metadata has been moved there instead.

In addition to my comments on individual files, here are some other thoughts:

  • To be broadly useful, these changes need to be propagated to at least the main DML subclasses, rather than just OrthoLearner, RLearner, and DML, but really ideally to everything that uses _crossfit.
  • The coverage report shows that most of the new code in the _OrthoLearner class is never run by any of the tests, since you set use_ray=False for all of the tests that use the class directly. Setting use_ray=True should fix that specific coverage issue, but consider whether additional tests for RLearner or DML would also be useful.
  • This seems like a potentially very helpful feature, so it's probably worth creating a documentation page or notebook, or at the very least an FAQ entry, describing when/why/how to use it.

Comment on lines 196 to 229
extras: "[tf,plt]"
extras: "[tf,plt,ray]"
- kind: other
opts: '-m "cate_api" -n auto'
extras: "[tf,plt]"
extras: "[tf,plt,ray]"
- kind: dml
opts: '-m "dml"'
extras: "[tf,plt]"
extras: "[tf,plt,ray]"
- kind: main
opts: '-m "not (notebook or automl or dml or serial or cate_api or treatment_featurization)" -n 2'
extras: "[tf,plt,dowhy]"
extras: "[tf,plt,dowhy,ray]"
- kind: treatment
opts: '-m "treatment_featurization" -n auto'
extras: "[tf,plt]"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe ray only needs to be added to the main test kind, since that is where the test_ortho_learner tests are run.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in latest commit

Comment on lines 116 to 124
- kind: "except-customer-scenarios"
extras: "[tf,plt]"
extras: "[tf,plt,ray]"
pattern: "(?!CustomerScenarios)"
install_graphviz: true
version: '3.8' # no supported version of tensorflow for 3.9
- kind: "customer-scenarios"
extras: "[plt,dowhy]"
extras: "[plt,dowhy,ray]"
pattern: "CustomerScenarios"
version: '3.9'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unless you make any changes to the notebooks to take advantage of the new ray functionality, these changes should not be necessary.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in latest commit

econml/dml/dml.py Show resolved Hide resolved
random_state=None):
random_state=None,
use_ray=False,
**ray_remote_func_options
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
**ray_remote_func_options
ray_remote_func_options={}

I think it would be better to make this an explicit dictionary argument, rather than having it implicitly include any other keyword arguments passed to the DML initializer since in the future we might want similar arguments for other compute backends.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(This also applies all the way up the hierarchy, to the RLearner and OrthoLearner initializer arguments)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in latest commit

setup.cfg Outdated
@@ -66,6 +66,8 @@ plt =
matplotlib < 3.6.0
dowhy =
dowhy < 0.9
ray =
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies for the inconvenience but these changes now need to be made to pyproject.toml instead - we've tried to move as much of the static metadata for the project as possible to that file.

@@ -272,15 +272,17 @@ def _gen_rlearner_model_final(self):
"""

def __init__(self, *, discrete_treatment, treatment_featurizer, categories,
cv, random_state, mc_iters=None, mc_agg='mean'):
cv, random_state, mc_iters=None, mc_agg='mean', use_ray=False, **ray_remote_func_options):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
cv, random_state, mc_iters=None, mc_agg='mean', use_ray=False, **ray_remote_func_options):
cv, random_state, mc_iters=None, mc_agg='mean', use_ray=False, ray_remote_func_options=ray_remote_func_options):

return nuisance_temp, model, test_idxs, (score_temp if calculate_scores else None)


def _crossfit(model, use_ray, folds, ray_remote_fun_option, *args, **kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it makes more sense for folds to come before the ray arguments (and certainly for the ray arguments to be adjacent), and these changes make the specification match the docstring.

Suggested change
def _crossfit(model, use_ray, folds, ray_remote_fun_option, *args, **kwargs):
def _crossfit(model, folds, use_ray=False, ray_remote_fun_option={}, *args, **kwargs):

@@ -60,6 +120,10 @@ def _crossfit(model, folds, *args, **kwargs):
function estimates a model of the nuisance function, based on the input
data to fit. Predict evaluates the fitted nuisance function on the input
data to predict.
use_ray: bool, default False (optional)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
use_ray: bool, default False (optional)
use_ray: bool, default False

having a default implies optional

@@ -60,6 +120,10 @@ def _crossfit(model, folds, *args, **kwargs):
function estimates a model of the nuisance function, based on the input
data to fit. Predict evaluates the fitted nuisance function on the input
data to predict.
use_ray: bool, default False (optional)
Flag to indicate whether to use ray to parallelize the cross-fitting step.
ray_remote_fun_option: dict, default None (optional)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ray_remote_fun_option: dict, default None (optional)
ray_remote_fun_option: dict, default {}

Having a default implies optional

nuisance, model_list, fitted_inds, scores = _crossfit(Wrapper(model), folds, X, y, W=y, Z=None)
use_ray = False
ray_remote_fun_option = {}
nuisance, model_list, fitted_inds, scores = _crossfit(Wrapper(model),use_ray, folds,ray_remote_fun_option,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
nuisance, model_list, fitted_inds, scores = _crossfit(Wrapper(model),use_ray, folds,ray_remote_fun_option,
nuisance, model_list, fitted_inds, scores = _crossfit(Wrapper(model), folds, use_ray, ray_remote_fun_option,

1) Fixed ci.yml extras dependencies
2) Added Description of all the added option in doc string in case of dml and rlearner
3) Addressed chaneges suggested for _ortho_learner.py
4)Removed ray.shutdown(), it can be taken care of explicitly on case to case basis .
5)Made ray_remote_func_options as explicit dictionary argument.

What has been added ?
1) Extended the changes to all estimators using _crossfit.
2) Added Test case to run for with_ray and without_ray for above changes
3) Added Notebook on how to use this feature.

Signed-off-by: Vishal Verma <vishalmverma27@gmail.com>
@v-shaal
Copy link
Contributor Author

v-shaal commented Aug 16, 2023

What have been fixed since last commit ?

  1. Fixed ci.yml extras dependencies
  2. Added Description of all the added option in doc string in case of dml and rlearner
  3. Addressed chaneges suggested for _ortho_learner.py
    4)Removed ray.shutdown(), it can be taken care of explicitly on case to case basis .
    5)Made ray_remote_func_options as explicit dictionary argument.

What has been added ?

  1. Extended the changes to all estimators using _crossfit.
  2. Added Test case to run for with_ray and without_ray for above changes
  3. Added Notebook on how to use this feature.

@kbattocchi kindly review the latest commit and provide feedback if any !

… testcases

Signed-off-by: Vishal Verma <vishalmverma27@gmail.com>
…mode for tests.

Signed-off-by: Vishal Verma <vishalmverma27@gmail.com>
use_ray=False,
ray_remote_func_options=None,
):
if ray_remote_func_options is None:
Copy link
Collaborator

@fverac fverac Aug 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you move this logic to within the OrthoLearner fit function, and remove this logic from all subclass __init__ functions. That way we avoid redundant code in all of the subclass __init__ functions and maintain a scikit-learn-like API. If interested in more context, see the "Instantiation" section of this page https://scikit-learn.org/stable/developers/develop.html#apis-of-scikit-learn-objects.

For instance, imagine a user does the following

est = LinearDML(use_ray=some_dict)
est.use_ray = None # user changes their mind about use_ray
est.fit(…)

We want the logic of converting None to an empty dict in .fit so we can allow for this kind of behavior.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Noted make sense, I will move the redundant code, to fit function within Ortholearner

@@ -642,6 +657,12 @@ class LinearDML(StatsModelsCateEstimatorMixin, DML):
If None, the random number generator is the :class:`~numpy.random.mtrand.RandomState` instance used
by :mod:`np.random<numpy.random>`.

use_ray: bool, default False
Whether to use Ray to parallelize the cross-fitting step. If True, Ray must be installed.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you fix the spacing here with a new line in between the arg descriptions. Same for SparseLinearDML and KernelDML

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Noted

v-shaal and others added 2 commits August 26, 2023 19:31
-removed redundant code for ray_remote_function and moved to ortholearner's fit

Signed-off-by: Vishal Verma <vishalmverma27@gmail.com>
Signed-off-by: Keith Battocchi <kebatt@microsoft.com>
@kbattocchi kbattocchi force-pushed the scaling_ortholearners branch 2 times, most recently from a7c168a to d76dff0 Compare October 25, 2023 05:37
Signed-off-by: Keith Battocchi <kebatt@microsoft.com>
Signed-off-by: Keith Battocchi <kebatt@microsoft.com>
Signed-off-by: Keith Battocchi <kebatt@microsoft.com>
Signed-off-by: Keith Battocchi <kebatt@microsoft.com>
@kbattocchi kbattocchi marked this pull request as ready for review October 27, 2023 18:49
Copy link
Collaborator

@kbattocchi kbattocchi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've broken the tests out into a new mark and I think things look good, so I'll merge once the checks pass. Thanks for this contribution!

@@ -412,7 +418,6 @@ def _gen_ortho_learner_model_final(self):
discrete_instrument=False, categories='auto', random_state=None)
est.fit(y, X[:, 0], W=X[:, 1:])

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Or (for parallelization using ray)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry if my previous comment was unclear: I think including the comment is helpful for understanding why est is being redefined; it's just that it needs to be a comment so that the entire block is valid python code that can be run.

Comment on lines 514 to 515
if ray_remote_func_options is None:
ray_remote_func_options = {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider whether just making the default {} instead of None would make sense. In general, we try not to put any logic in our initializers, because it's possible the user will do something like this:

    est = LinearDML()
    est.use_ray = True
    est.ray_remote_options = None

and then the logic to turn it into {} won't run. So I think it's fine to require it to be an actual dictionary instead of None and skip the extra logic.

Signed-off-by: Keith Battocchi <kebatt@microsoft.com>
Signed-off-by: Keith Battocchi <kebatt@microsoft.com>
@kbattocchi kbattocchi merged commit 01899a8 into py-why:main Oct 27, 2023
72 checks passed
kbattocchi added a commit that referenced this pull request Oct 31, 2023
Added Implementation of ray-based distributed parallelization to crossfit.

---------

Signed-off-by: Vishal Verma <vishalmverma27@gmail.com>
Signed-off-by: Keith Battocchi <kebatt@microsoft.com>
Co-authored-by: Keith Battocchi <kebatt@microsoft.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants