Learned kernel MMD with KeOps backend #602
Conversation
Codecov Report
@@            Coverage Diff             @@
##           master     #602      +/-   ##
==========================================
- Coverage   83.58%   82.14%    -1.44%
==========================================
  Files         207      209        +2
  Lines       13838    14159      +321
==========================================
+ Hits        11566    11631       +65
- Misses       2272     2528      +256
n_features = [5]
n_instances = [(100, 100), (100, 75)]
kernel_a = ['GaussianRBF', 'MyKernel']
kernel_b = ['GaussianRBF', 'MyKernel', None]
eps = [0.5, 'trainable']
tests_dk = list(product(n_features, n_instances, kernel_a, kernel_b, eps))
n_tests_dk = len(tests_dk)


@pytest.fixture
def deep_kernel_params(request):
    return tests_dk[request.param]


@pytest.mark.skipif(not has_keops, reason='Skipping since pykeops is not installed.')
@pytest.mark.parametrize('deep_kernel_params', list(range(n_tests_dk)), indirect=True)
def test_deep_kernel(deep_kernel_params):
Not a huge fan of using this pattern for parametrizing tests, is there a reason not to parametrize each parameter directly? @ascillitoe
Yeah, the more conventional way to do it would be to parametrize each parameter separately, e.g. see `test_save_model` in `test_saving.py`:
@parametrize('model', [encoder_model])
@parametrize('layer', [None, -1])
def test_save_model(data, model, layer, backend, tmp_path):
This approach has the advantage of giving much more descriptive test names, which is useful when things go wrong.
We keep writing tests with the current pattern for consistency with existing tests. But, unless we are going to refactor the existing tests very soon, maybe we should prioritise adopting/exploring a new pattern so that we have less refactoring to do later...
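For illustration, a hypothetical sketch of the direct-parametrization pattern being suggested, reusing the parameter values from the fixture above (`has_keops` comes from the surrounding test module; the test body and exact signature are placeholders):

```python
import pytest

# Stacked parametrize decorators generate the same cartesian product as tests_dk,
# but each generated test gets an ID built from its parameter values (or an index
# for non-primitive values such as tuples), which makes failures easier to trace.
@pytest.mark.skipif(not has_keops, reason='Skipping since pykeops is not installed.')
@pytest.mark.parametrize('n_features', [5])
@pytest.mark.parametrize('n_instances', [(100, 100), (100, 75)])
@pytest.mark.parametrize('kernel_a', ['GaussianRBF', 'MyKernel'])
@pytest.mark.parametrize('kernel_b', ['GaussianRBF', 'MyKernel', None])
@pytest.mark.parametrize('eps', [0.5, 'trainable'])
def test_deep_kernel(n_features, n_instances, kernel_a, kernel_b, eps):
    ...  # construct the deep kernel from the parameters and assert on its output
```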
Ok, I was not aware of this tbh and just followed the existing patterns.
if has_keops:
    class MyKernel(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, x: LazyTensor, y: LazyTensor) -> LazyTensor:
            return (- ((x - y) ** 2).sum(-1)).exp()
Is typing the only reason why this needs to be inside the `if has_keops` block? If so we could just use forward-references to `'LazyTensor'`? @ascillitoe
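For reference, a minimal sketch of the forward-reference idea (hypothetical; it assumes `LazyTensor` is only needed for the type hints, so the runtime import of pykeops can be avoided):

```python
from typing import TYPE_CHECKING

import torch.nn as nn

if TYPE_CHECKING:
    # Imported only for static type checking; not needed at runtime.
    from pykeops.torch import LazyTensor


class MyKernel(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: 'LazyTensor', y: 'LazyTensor') -> 'LazyTensor':
        # String annotations (forward references) mean the class no longer has to
        # be defined inside an `if has_keops:` block just to satisfy typing.
        return (- ((x - y) ** 2).sum(-1)).exp()
```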
Good point. Looks like that might be the case...
Did the forward-ref not work?
@jklaise and I were chatting earlier and decided (correct me if I'm wrong) that it's not necessarily a better solution, just a bit different, so I kept it as is.
OK fair enough!
(I don't have a strong opinion either way)
Overall LGTM. Ok with the proposed departure for the `DeepKernel` API for keops, assuming `DeepKernel` is never really meant to be used by the user directly?

No, its intended usage is always within a learned detector. I will wait then until @ascillitoe leaves comments before possibly making some final changes and merging.

@arnaudvl the proposed api for the keops
@@ -37,7 +41,16 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.einsum('ji,ki->jk', self.dense(x), self.dense(y))


tests_lkdrift = ['tensorflow', 'pytorch', 'PyToRcH', 'mxnet']
if has_keops:
    class MyKernelKeops(nn.Module):
Same comment as this one?
@@ -2,14 +2,22 @@
  "cells": [
   {
    "cell_type": "markdown",
-   "metadata": {},
+   "metadata": {
+    "pycharm": {
Super nitpicky: It is preferable to strip out this unnecessary metadata...
LGTM!
Following #548, also extending the learned kernel MMD detector with the KeOps backend to further scale and speed up the detector.

- More testing of `LearnedKernelDrift` is required given the decreased coverage.
- `kernel_a` and optionally `kernel_b` need to be set.
- The KeOps deep kernel is defined in `utils.keops.kernels.DeepKernel`. The main issue here is that `DeepKernel.proj` is not used within the DeepKernel's forward pass, but explicitly called by the learned kernel drift detector. The reason is that we first need to do regular torch tensor computations using the projection, and then compute the kernel matrix with those projected features. KeOps is only used for the latter, not when computing the projection itself. So technically you could pass a separate projection (i.e. `DeepKernel.proj`) to the drift detector and apply a weighted sum kernel later on the original and projected data. But this would break the consistency of the API/input kwargs with the other backends and realistically make the detector harder to understand. As a result, I chose to keep the same DeepKernel format as the PyTorch and TensorFlow backends and deal with this difference directly in the drift detector. This also means there are explicit checks in place (`self.has_proj` and `self.has_kernel_b`) to check whether the DeepKernel format is used and the projection can be done separately. An illustrative sketch of this split is included at the end of this description.
- Expose `num_workers` for both KeOps and PyTorch backends, since it can make a significant difference to dataloader speed for higher numbers of instances (Add `num_workers` to PyTorch/KeOps backend where relevant #611).
- The KeOps backend adds a `batch_size_predict` kwarg which I would also like to add to the PyTorch backend. The reason is that the optimal batch size for training can be wildly different from that for prediction (where we just care about being as fast as possible within our compute budget). So if we e.g. pick `batch_size=32` for training, we might want to change this to `batch_size_predict=1000` for prediction using the trained kernel. There is also another reason why they can be very different: during training of the detector, the whole training batch (all tensors incl. the projection etc.) needs to fit on the GPU. But for predictions across all permutations at once, we can first compute all the projections separately, and then lazily evaluate both the projections and original instances for all permutations. This means we can likely get away with much higher batch sizes for the projection predictions (Add `batch_size_predict` as kwarg to PyTorch backend for learned detectors. #612).

The smaller potential PyTorch changes (`num_workers` and `batch_size_predict`) can be done in a quick follow-up PR.
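To illustrate the projection/kernel split described above, here is a minimal, hypothetical sketch (not the actual detector code): the projection is a regular dense torch computation, and only the kernel matrix over the projected features is evaluated lazily with KeOps. The names `gaussian_kernel` and `deep_kernel_matrix` and the plain RBF kernel are illustrative stand-ins for `kernel_a`; the weighted sum with `kernel_b` on the original inputs is omitted.

```python
import torch
import torch.nn as nn
from pykeops.torch import LazyTensor


def gaussian_kernel(x: LazyTensor, y: LazyTensor) -> LazyTensor:
    # Illustrative stand-in for kernel_a: an RBF kernel on KeOps LazyTensors.
    return (-((x - y) ** 2).sum(-1)).exp()


def deep_kernel_matrix(x: torch.Tensor, y: torch.Tensor, proj: nn.Module) -> torch.Tensor:
    # 1. The projection is a regular (dense) torch computation. In the detector it is
    #    called explicitly, outside the kernel's forward pass.
    x_proj, y_proj = proj(x), proj(y)

    # 2. Only the kernel evaluation uses KeOps: wrap the projected features as
    #    LazyTensors so the full (n, m) kernel matrix is never materialised in memory.
    x_lazy = LazyTensor(x_proj[:, None, :].contiguous())  # symbolic shape (n, 1, d)
    y_lazy = LazyTensor(y_proj[None, :, :].contiguous())  # symbolic shape (1, m, d)
    k_lazy = gaussian_kernel(x_lazy, y_lazy)              # lazy (n, m) kernel matrix

    # Reductions (here: row sums, as used in MMD-style estimates) return dense tensors.
    return k_lazy.sum(dim=1)
```

In practice the kernel is never called like this by the user: as noted in the conversation above, its intended usage is always within the learned drift detector, which is also where the `batch_size` and `batch_size_predict` kwargs discussed in this description apply.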