Add KeOps MMD detector #548

Merged 53 commits on Aug 19, 2022
Commits
b322dbf
first commit keops
arnaudvl Jun 24, 2022
29ad4d9
update kernel and mmd keops
arnaudvl Jun 24, 2022
dd2d13e
allow multiple kernel bandwidths for keops
arnaudvl Jun 24, 2022
0a7944c
fix bug
arnaudvl Jun 24, 2022
17d1662
update mmd
arnaudvl Jun 30, 2022
cf528f7
remove learned kernel and base kernel_matrix MMD function
arnaudvl Jul 4, 2022
38ec19b
unify batched mmd2
arnaudvl Jul 5, 2022
e943c6c
update keops mmd
arnaudvl Jul 6, 2022
0487cb2
update docs and kernel import
arnaudvl Jul 6, 2022
a6a4641
bugfixes
arnaudvl Jul 7, 2022
e2b27f5
remove unused imports
arnaudvl Jul 8, 2022
d442be8
add benchmarking example
arnaudvl Jul 8, 2022
244ceb2
update test mmd
arnaudvl Jul 8, 2022
f913f5b
add test mmd keops
arnaudvl Jul 8, 2022
2da4a9a
update readme
arnaudvl Jul 8, 2022
c49f1d2
bugfix kernel and update mmd test
arnaudvl Jul 8, 2022
75481cc
remove print from test
arnaudvl Jul 8, 2022
e6996b9
update keops tests
arnaudvl Jul 8, 2022
afda208
Merge master and resolve conflicts
ascillitoe Jul 26, 2022
c87109f
Add save warning and update tests
ascillitoe Jul 26, 2022
eb307b6
Update setup and associated docs
ascillitoe Jul 26, 2022
0db2239
Fix typing issue in
ascillitoe Jul 27, 2022
45e7211
Install keops as part of CI
ascillitoe Jul 27, 2022
9cee6bc
Add keops tox environment
mauicv Jul 27, 2022
fb52d4c
Merge branch 'master' into keops
ascillitoe Jul 27, 2022
3fa460c
Add keops to all dependency bucket
mauicv Jul 28, 2022
44df386
Merge branch 'master' into arnaudvl-keops
mauicv Jul 28, 2022
82f6d3c
Fix minor issue
mauicv Jul 28, 2022
71142d7
Protect GaussianRBF with import optional
mauicv Jul 28, 2022
9c2da77
Skip keops tests on Windows, and keops notebook test. Fix backend val…
ascillitoe Jul 28, 2022
b126a6d
Skip keops kernel tests if not installed
ascillitoe Jul 28, 2022
48ad925
Add pykeops to op deps ERROR_TYPES
ascillitoe Jul 29, 2022
74ff992
Skip keops tests on MacOS
ascillitoe Jul 29, 2022
7c5e70d
Add note to docs about linux-only support for keops
ascillitoe Jul 29, 2022
fc12b9d
Add batch_size_permutations to pydantic models
ascillitoe Jul 29, 2022
f6b331b
remove print
arnaudvl Aug 9, 2022
ace20cc
remove unnecessary comment
arnaudvl Aug 9, 2022
718fb85
change default bandwidth fn to None
arnaudvl Aug 9, 2022
5922e3f
update infer sigma
arnaudvl Aug 9, 2022
b8adfbe
update test warning, update and clarify keops kernels logic
arnaudvl Aug 9, 2022
015cc5e
clean up
arnaudvl Aug 9, 2022
148019a
update docstring
arnaudvl Aug 9, 2022
7453368
fix bug
arnaudvl Aug 9, 2022
2d88bfc
undo unnecessary kwarg removal
arnaudvl Aug 10, 2022
54df257
make test consistent with torch/tf backends
arnaudvl Aug 10, 2022
211eeb9
add _mmd2 test
arnaudvl Aug 11, 2022
f98fd83
remove unused import
arnaudvl Aug 11, 2022
751d3a0
clarify docs, remove redundant framework checks
arnaudvl Aug 16, 2022
7c2d781
remove print
arnaudvl Aug 16, 2022
3f69740
update docs keops
arnaudvl Aug 16, 2022
ac5fe64
batched version of sigma_mean part 1
arnaudvl Aug 16, 2022
4ce018b
remove unused import
arnaudvl Aug 16, 2022
95634a1
update keops kernels test
arnaudvl Aug 16, 2022
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -54,7 +54,7 @@ jobs:
         if [ "$RUNNER_OS" != "Windows" ] && [ ${{ matrix.python }} < '3.10' ]; then # Skip Prophet tests on Windows as installation complex. Skip on Python 3.10 as not supported.
           python -m pip install --upgrade --upgrade-strategy eager -e .[prophet]
         fi
-        python -m pip install --upgrade --upgrade-strategy eager -e .[tensorflow,torch]
+        python -m pip install --upgrade --upgrade-strategy eager -e .[tensorflow,torch,keops]
         python -m pip freeze

     - name: Lint with flake8
2 changes: 1 addition & 1 deletion .github/workflows/test_all_notebooks.yml
@@ -44,7 +44,7 @@ jobs:
         if [ "$RUNNER_OS" != "Windows" ] && [ ${{ matrix.python }} < '3.10' ]; then # Skip Prophet tests on Windows as installation complex. Skip on Python 3.10 as not supported.
           python -m pip install --upgrade --upgrade-strategy eager -e .[prophet]
         fi
-        python -m pip install --upgrade --upgrade-strategy eager -e .[torch,tensorflow]
+        python -m pip install --upgrade --upgrade-strategy eager -e .[torch,tensorflow,keops]
         python -m pip freeze

     - name: Run notebooks
2 changes: 1 addition & 1 deletion .github/workflows/test_changed_notebooks.yml
@@ -59,7 +59,7 @@ jobs:
         if [ "$RUNNER_OS" != "Windows" ] && [ ${{ matrix.python }} < '3.10' ]; then # Skip Prophet tests on Windows as installation complex. Skip on Python 3.10 as not supported.
           python -m pip install --upgrade --upgrade-strategy eager -e .[prophet]
         fi
-        python -m pip install --upgrade --upgrade-strategy eager -e .[torch,tensorflow]
+        python -m pip install --upgrade --upgrade-strategy eager -e .[torch,tensorflow,keops]
         python -m pip freeze

     - name: Run notebooks
18 changes: 15 additions & 3 deletions README.md
@@ -78,7 +78,7 @@ The package, `alibi-detect` can be installed from:
 pip install git+https://github.com/SeldonIO/alibi-detect.git
 ```

-- To install with the tensorflow backend:
+- To install with the TensorFlow backend:
 ```bash
 pip install alibi-detect[tensorflow]
 ```
@@ -88,6 +88,11 @@ The package, `alibi-detect` can be installed from:
 pip install alibi-detect[torch]
 ```

+- To install with the KeOps backend:
+```bash
+pip install alibi-detect[keops]
+```
+
 - To use the `Prophet` time series outlier detector:

 ```bash
@@ -180,8 +185,8 @@ The following tables show the advised use cases for each algorithm. The column *

 #### TensorFlow and PyTorch support

-The drift detectors support TensorFlow and PyTorch backends. Alibi Detect does not install these as default. See the
-[installation options](#installation-and-usage) for more details.
+The drift detectors support TensorFlow, PyTorch and (where applicable) [KeOps](https://www.kernel-operations.io/keops/index.html) backends.
+However, Alibi Detect does not install these by default. See the [installation options](#installation-and-usage) for more details.

 ```python
 from alibi_detect.cd import MMDDrift
@@ -197,6 +202,13 @@ cd = MMDDrift(x_ref, backend='pytorch', p_val=.05)
 preds = cd.predict(x)
 ```

+Or in KeOps:
+
+```python
+cd = MMDDrift(x_ref, backend='keops', p_val=.05)
+preds = cd.predict(x)
+```
+
 #### Built-in preprocessing steps

 Alibi Detect also comes with various preprocessing steps such as randomly initialized encoders, pretrained text
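Note on the README change: since none of the backends are installed by default, calling code may want to check what is available before passing `backend=...` to `MMDDrift`. The sketch below is one possible way to do that using only the standard library. `pick_mmd_backend` is a hypothetical helper and not part of alibi-detect; the Linux-only restriction for KeOps is taken from the commit note in this PR, and the assumption that the KeOps backend needs both `pykeops` and `torch` comes from the imports in `alibi_detect/cd/keops/mmd.py`.

```python
import importlib.util
import platform

# Packages each backend name is assumed to need (hypothetical mapping for illustration).
REQUIRED = {
    "keops": ("pykeops", "torch"),
    "pytorch": ("torch",),
    "tensorflow": ("tensorflow",),
}


def pick_mmd_backend(preferred=("keops", "pytorch", "tensorflow")):
    """Return the first backend whose packages appear importable, else None."""
    for name in preferred:
        # The PR docs describe KeOps support as Linux-only.
        if name == "keops" and platform.system() != "Linux":
            continue
        # find_spec returns None for a top-level package that is not installed.
        if all(importlib.util.find_spec(pkg) is not None for pkg in REQUIRED[name]):
            return name
    return None
```

The returned string could then be passed as `MMDDrift(x_ref, backend=pick_mmd_backend(), p_val=.05)`, after handling the `None` case.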
5 changes: 0 additions & 5 deletions alibi_detect/cd/base.py
@@ -602,11 +602,6 @@ def preprocess(self, x: Union[np.ndarray, list]) -> Tuple[np.ndarray, np.ndarray
         else:
             return self.x_ref, x  # type: ignore[return-value]

-    @abstractmethod
-    def kernel_matrix(self, x: Union['torch.Tensor', 'tf.Tensor'], y: Union['torch.Tensor', 'tf.Tensor']) \
-            -> Union['torch.Tensor', 'tf.Tensor']:
-        pass
-
     @abstractmethod
     def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, float]:
         pass
Empty file.
179 changes: 179 additions & 0 deletions alibi_detect/cd/keops/mmd.py
@@ -0,0 +1,179 @@
import logging
import numpy as np
from pykeops.torch import LazyTensor
import torch
from typing import Callable, Dict, List, Optional, Tuple, Union
from alibi_detect.cd.base import BaseMMDDrift
from alibi_detect.utils.keops.kernels import GaussianRBF
from alibi_detect.utils.pytorch import get_device

logger = logging.getLogger(__name__)


class MMDDriftKeops(BaseMMDDrift):
    def __init__(
            self,
            x_ref: Union[np.ndarray, list],
            p_val: float = .05,
            x_ref_preprocessed: bool = False,
            preprocess_at_init: bool = True,
            update_x_ref: Optional[Dict[str, int]] = None,
            preprocess_fn: Optional[Callable] = None,
            kernel: Callable = GaussianRBF,
            sigma: Optional[np.ndarray] = None,
            configure_kernel_from_x_ref: bool = True,
            n_permutations: int = 100,
            batch_size_permutations: int = 1000000,
            device: Optional[str] = None,
            input_shape: Optional[tuple] = None,
            data_type: Optional[str] = None
    ) -> None:
        """
        Maximum Mean Discrepancy (MMD) data drift detector using a permutation test.

        Parameters
        ----------
        x_ref
            Data used as reference distribution.
        p_val
            p-value used for the significance of the permutation test.
        x_ref_preprocessed
            Whether the given reference data `x_ref` has been preprocessed yet. If `x_ref_preprocessed=True`, only
            the test data `x` will be preprocessed at prediction time. If `x_ref_preprocessed=False`, the reference
            data will also be preprocessed.
        preprocess_at_init
            Whether to preprocess the reference data when the detector is instantiated. Otherwise, the reference
            data will be preprocessed at prediction time. Only applies if `x_ref_preprocessed=False`.
        update_x_ref
            Reference data can optionally be updated to the last n instances seen by the detector
            or via reservoir sampling with size n. For the former, the parameter equals {'last': n} while
            for reservoir sampling {'reservoir_sampling': n} is passed.
        preprocess_fn
            Function to preprocess the data before computing the data drift metrics.
        kernel
            Kernel used for the MMD computation, defaults to Gaussian RBF kernel.
        sigma
            Optionally set the GaussianRBF kernel bandwidth. Can also pass multiple bandwidth values as an array.
            The kernel evaluation is then averaged over those bandwidths.
        configure_kernel_from_x_ref
            Whether to already configure the kernel bandwidth from the reference data.
        n_permutations
            Number of permutations used in the permutation test.
        batch_size_permutations
            KeOps computes the n_permutations of the MMD^2 statistics in chunks of batch_size_permutations.
        device
            Device type used. The default None tries to use the GPU and falls back on CPU if needed.
            Can be specified by passing either 'cuda', 'gpu' or 'cpu'.
        input_shape
            Shape of input data.
        data_type
            Optionally specify the data type (tabular, image or time-series). Added to metadata.
        """
        super().__init__(
            x_ref=x_ref,
            p_val=p_val,
            x_ref_preprocessed=x_ref_preprocessed,
            preprocess_at_init=preprocess_at_init,
            update_x_ref=update_x_ref,
            preprocess_fn=preprocess_fn,
            sigma=sigma,
            configure_kernel_from_x_ref=configure_kernel_from_x_ref,
            n_permutations=n_permutations,
            input_shape=input_shape,
            data_type=data_type
        )
        self.meta.update({'backend': 'keops'})

        # set device
        self.device = get_device(device)

        # initialize kernel
        sigma = torch.from_numpy(sigma).to(self.device) if isinstance(sigma,  # type: ignore[assignment]
                                                                      np.ndarray) else None
        self.kernel = kernel(sigma).to(self.device) if kernel == GaussianRBF else kernel

        # set the correct MMD^2 function based on the batch size for the permutations
        self.batch_size = batch_size_permutations
        self.n_batches = 1 + (n_permutations - 1) // batch_size_permutations

        # infer the kernel bandwidth from the reference data
        if self.infer_sigma or isinstance(sigma, torch.Tensor):
            x = torch.from_numpy(self.x_ref).to(self.device)
            _ = self.kernel(LazyTensor(x[:, None, :]), LazyTensor(x[None, :, :]), infer_sigma=self.infer_sigma)
            self.infer_sigma = False
        else:
            self.infer_sigma = True
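
Note on `n_batches`: the expression `1 + (n_permutations - 1) // batch_size_permutations` is a ceiling division, so a final partial batch is still counted. The sketch below illustrates how the permutation indices would be chunked under that formula. `n_batches` and `batch_slices` are hypothetical standalone helpers written for this note, not functions from the PR.

```python
def n_batches(n_permutations: int, batch_size: int) -> int:
    """Ceiling division written the same way as in the detector's __init__."""
    return 1 + (n_permutations - 1) // batch_size


def batch_slices(n_permutations: int, batch_size: int):
    """Return (start, stop) index pairs covering all permutations in order."""
    slices = []
    for batch in range(n_batches(n_permutations, batch_size)):
        start = batch * batch_size
        # Python slicing would clip an overshooting stop; min() makes that explicit.
        stop = min((batch + 1) * batch_size, n_permutations)
        slices.append((start, stop))
    return slices
```

With the default `batch_size_permutations=1000000` and `n_permutations=100`, this gives a single batch, which is why the default behaves like an unbatched computation.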

    def _mmd2(self, x_all: torch.Tensor, perms: List[torch.Tensor], m: int, n: int) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Batched (across the permutations) MMD^2 computation for the original test statistic and the permutations.

        Parameters
        ----------
        x_all
            Concatenated reference and test instances.
        perms
            List with permutation vectors.
        m
            Number of reference instances.
        n
            Number of test instances.

        Returns
        -------
        MMD^2 statistic for the original and permuted reference and test sets.
        """
        k_xx, k_yy, k_xy = [], [], []
        for batch in range(self.n_batches):
            i, j = batch * self.batch_size, (batch + 1) * self.batch_size
            # construct stacked tensors with a batch of permutations for the reference set x and test set y
            x = torch.cat([x_all[perm[:m]][None, :, :] for perm in perms[i:j]], 0)
            y = torch.cat([x_all[perm[m:]][None, :, :] for perm in perms[i:j]], 0)
            if batch == 0:
                x = torch.cat([x_all[None, :m, :], x], 0)
                y = torch.cat([x_all[None, m:, :], y], 0)
            x, y = x.to(self.device), y.to(self.device)

            # batch-wise kernel matrix computation over the permutations
            k_xx.append(self.kernel(
                LazyTensor(x[:, :, None, :]), LazyTensor(x[:, None, :, :])).sum(1).sum(1).squeeze(-1))
            k_yy.append(self.kernel(
                LazyTensor(y[:, :, None, :]), LazyTensor(y[:, None, :, :])).sum(1).sum(1).squeeze(-1))
            k_xy.append(self.kernel(
                LazyTensor(x[:, :, None, :]), LazyTensor(y[:, None, :, :])).sum(1).sum(1).squeeze(-1))
        c_xx, c_yy, c_xy = 1 / (m * (m - 1)), 1 / (n * (n - 1)), 2. / (m * n)
        stats = c_xx * (torch.cat(k_xx) - m) + c_yy * (torch.cat(k_yy) - n) - c_xy * torch.cat(k_xy)
        return stats[0], stats[1:]
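
Note on the constants in `_mmd2`: subtracting `m` and `n` from the full kernel sums removes the diagonal terms, which relies on `k(a, a) = 1`. That holds for a Gaussian RBF kernel but not for an arbitrary kernel. The pure-Python sketch below reproduces the same estimator on small lists so the arithmetic can be checked by hand. The kernel form `exp(-||a - b||^2 / (2 * sigma^2))` is an assumption about the bandwidth convention, and all function names here are illustrative only.

```python
import math


def rbf(a, b, sigma=1.0):
    """Gaussian RBF kernel between two vectors (assumed bandwidth convention)."""
    d2 = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return math.exp(-d2 / (2.0 * sigma ** 2))


def kernel_sum(xs, ys, sigma=1.0):
    """Sum of k(a, b) over every pair, diagonal included."""
    return sum(rbf(a, b, sigma) for a in xs for b in ys)


def mmd2_unbiased(x, y, sigma=1.0):
    """Unbiased MMD^2 estimate using the same constants as `_mmd2`."""
    m, n = len(x), len(y)
    c_xx, c_yy, c_xy = 1 / (m * (m - 1)), 1 / (n * (n - 1)), 2.0 / (m * n)
    # k(a, a) = 1 for the RBF kernel, so the diagonals sum to m and n.
    return (c_xx * (kernel_sum(x, x, sigma) - m)
            + c_yy * (kernel_sum(y, y, sigma) - n)
            - c_xy * kernel_sum(x, y, sigma))
```

For identical samples such as `x = y = [[0.0], [1.0]]` the estimate is negative, which is expected for the unbiased estimator and is why the permutation test compares statistics rather than thresholding at zero.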

    def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, float]:
        """
        Compute the p-value resulting from a permutation test using the maximum mean discrepancy
        as a distance measure between the reference data and the data to be tested.

        Parameters
        ----------
        x
            Batch of instances.

        Returns
        -------
        p-value obtained from the permutation test, the MMD^2 between the reference and test set,
        and the MMD^2 threshold above which drift is flagged.
        """
        x_ref, x = self.preprocess(x)
        x_ref = torch.from_numpy(x_ref).float()  # type: ignore[assignment]
        x = torch.from_numpy(x).float()  # type: ignore[assignment]
        # compute kernel matrix, MMD^2 and apply permutation test
        m, n = x_ref.shape[0], x.shape[0]
        perms = [torch.randperm(m + n) for _ in range(self.n_permutations)]
        # TODO - Rethink typings (related to https://github.com/SeldonIO/alibi-detect/issues/540)
        x_all = torch.cat([x_ref, x], 0)  # type: ignore[list-item]
        mmd2, mmd2_permuted = self._mmd2(x_all, perms, m, n)
        if self.device.type == 'cuda':
            mmd2, mmd2_permuted = mmd2.cpu(), mmd2_permuted.cpu()
        p_val = (mmd2 <= mmd2_permuted).float().mean()
        # compute distance threshold
        idx_threshold = int(self.p_val * len(mmd2_permuted))
        distance_threshold = torch.sort(mmd2_permuted, descending=True).values[idx_threshold]
        return p_val.numpy().item(), mmd2.numpy().item(), distance_threshold.numpy()
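
Note on `score`: the p-value is the fraction of permuted statistics that are at least as large as the observed MMD^2, and the distance threshold is the permuted statistic at position `int(p_val * n_permutations)` when sorted in descending order. The sketch below mirrors those two steps on plain Python lists. The function names are hypothetical and chosen for this note only.

```python
def permutation_p_value(stat, permuted):
    """Fraction of permuted statistics that are >= the observed statistic."""
    return sum(stat <= s for s in permuted) / len(permuted)


def distance_threshold(permuted, p_val=0.05):
    """Permuted statistic at index int(p_val * len(permuted)), sorted descending."""
    idx = int(p_val * len(permuted))
    return sorted(permuted, reverse=True)[idx]
```

With only a few permutations the threshold index can be 0, in which case the threshold is simply the largest permuted statistic, so a small `n_permutations` gives a coarse threshold.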
97 changes: 97 additions & 0 deletions alibi_detect/cd/keops/tests/test_mmd_keops.py
@@ -0,0 +1,97 @@
from functools import partial
from itertools import product
import numpy as np
import pytest
import torch
import torch.nn as nn
from typing import Callable, List
from alibi_detect.cd.keops.mmd import MMDDriftKeops
from alibi_detect.cd.pytorch.preprocess import HiddenOutput, preprocess_drift

n, n_hidden, n_classes = 500, 10, 5


class MyModel(nn.Module):
    def __init__(self, n_features: int):
        super().__init__()
        self.dense1 = nn.Linear(n_features, 20)
        self.dense2 = nn.Linear(20, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = nn.ReLU()(self.dense1(x))
        return self.dense2(x)


# test List[Any] inputs to the detector
def preprocess_list(x: List[np.ndarray]) -> np.ndarray:
    return np.concatenate(x, axis=0)


n_features = [10]
n_enc = [None, 3]
preprocess = [
    (None, None),
    (preprocess_drift, {'model': HiddenOutput, 'layer': -1}),
    (preprocess_list, None)
]
preprocess_at_init = [True, False]
n_permutations = [10]
batch_size_permutations = [10, 1000000]
configure_kernel_from_x_ref = [True, False]
tests_mmddrift = list(product(n_features, n_enc, preprocess, n_permutations, preprocess_at_init,
                              batch_size_permutations, configure_kernel_from_x_ref))
n_tests = len(tests_mmddrift)


@pytest.fixture
def mmd_params(request):
    return tests_mmddrift[request.param]


@pytest.mark.parametrize('mmd_params', list(range(n_tests)), indirect=True)
def test_mmd(mmd_params):
    n_features, n_enc, preprocess, n_permutations, preprocess_at_init, \
        batch_size_permutations, configure_kernel_from_x_ref = mmd_params

    np.random.seed(0)
    torch.manual_seed(0)

    x_ref = np.random.randn(n * n_features).reshape(n, n_features).astype(np.float32)
    preprocess_fn, preprocess_kwargs = preprocess
    to_list = False
    if hasattr(preprocess_fn, '__name__') and preprocess_fn.__name__ == 'preprocess_list':
        if not preprocess_at_init:
            return
        to_list = True
        x_ref = [_[None, :] for _ in x_ref]
    elif isinstance(preprocess_fn, Callable) and 'layer' in list(preprocess_kwargs.keys()) \
            and preprocess_kwargs['model'].__name__ == 'HiddenOutput':
        model = MyModel(n_features)
        layer = preprocess_kwargs['layer']
        preprocess_fn = partial(preprocess_fn, model=HiddenOutput(model=model, layer=layer))
    else:
        preprocess_fn = None

    cd = MMDDriftKeops(
        x_ref=x_ref,
        p_val=.05,
        preprocess_at_init=preprocess_at_init if isinstance(preprocess_fn, Callable) else False,
        preprocess_fn=preprocess_fn,
        configure_kernel_from_x_ref=configure_kernel_from_x_ref,
        n_permutations=n_permutations,
        batch_size_permutations=batch_size_permutations
    )
    x = x_ref.copy()
    preds = cd.predict(x, return_p_val=True)
    assert preds['data']['is_drift'] == 0 and preds['data']['p_val'] >= cd.p_val

    x_h1 = np.random.randn(n * n_features).reshape(n, n_features).astype(np.float32)
    if to_list:
        x_h1 = [_[None, :] for _ in x_h1]
    preds = cd.predict(x_h1, return_p_val=True)
    if preds['data']['is_drift'] == 1:
        assert preds['data']['p_val'] < preds['data']['threshold'] == cd.p_val
        assert preds['data']['distance'] > preds['data']['distance_threshold']
    else:
        assert preds['data']['p_val'] >= preds['data']['threshold'] == cd.p_val
        assert preds['data']['distance'] <= preds['data']['distance_threshold']
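
Note on the test grid: the parametrization is the Cartesian product of the option lists, so its size is the product of their lengths. The sketch below counts the cases using string placeholders in place of the real preprocessing callables (an assumption made so it runs without torch or alibi-detect), and also counts the combinations that return early because `preprocess_list` is paired with `preprocess_at_init=False`.

```python
from itertools import product

# Same option lists as the test module, with placeholder names for the callables.
n_features = [10]
n_enc = [None, 3]
preprocess = ["none", "preprocess_drift", "preprocess_list"]
preprocess_at_init = [True, False]
n_permutations = [10]
batch_size_permutations = [10, 1000000]
configure_kernel_from_x_ref = [True, False]

grid = list(product(n_features, n_enc, preprocess, n_permutations, preprocess_at_init,
                    batch_size_permutations, configure_kernel_from_x_ref))

# Cases where the test body returns before building a detector.
early_return = [g for g in grid if g[2] == "preprocess_list" and not g[4]]

n_total = len(grid)
n_exercised = n_total - len(early_return)
```

Under these lists the grid has 1 * 2 * 3 * 1 * 2 * 2 * 2 cases, of which the early-return combinations pass without exercising the detector.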