Add support for online drift detectors (#1108)
ascillitoe authored and Adrian Gonzalez-Martin committed Apr 20, 2023
1 parent 02cda70 commit 3ba89e7
Showing 4 changed files with 209 additions and 18 deletions.
90 changes: 77 additions & 13 deletions runtimes/alibi-detect/mlserver_alibi_detect/runtime.py
@@ -1,8 +1,9 @@
import os
import numpy as np

from pydantic.error_wrappers import ValidationError
from typing import Optional, List
from pydantic import BaseSettings
from typing import Optional, List, Dict
from pydantic import BaseSettings, Field
from functools import cached_property

from alibi_detect.saving import load_detector
@@ -14,6 +15,7 @@
from mlserver.codecs import NumpyCodec, NumpyRequestCodec
from mlserver.utils import get_model_uri
from mlserver.errors import MLServerError, InferenceError
from mlserver.logging import logger

ENV_PREFIX_ALIBI_DETECT_SETTINGS = "MLSERVER_MODEL_ALIBI_DETECT_"

@@ -41,6 +43,12 @@ class Config:
inference runs for all of them).
"""

state_save_freq: Optional[int] = Field(100, gt=0)
"""
Save the detector state after every `state_save_freq` predictions.
Only applicable to detectors with a `save_state` method.
"""


class AlibiDetectRuntime(MLModel):
"""
@@ -58,9 +66,11 @@ def __init__(self, settings: ModelSettings):
super().__init__(settings)

async def load(self) -> bool:
model_uri = await get_model_uri(self._settings)
self._model_uri = await get_model_uri(self._settings)
try:
self._model = load_detector(model_uri)
self._model = load_detector(self._model_uri)
# Check whether this is an online drift detector (i.e. it has a save_state method)
self._online = hasattr(self._model, "save_state")
except (
ValueError,
FileNotFoundError,
@@ -76,7 +86,7 @@ async def load(self) -> bool:

async def predict(self, payload: InferenceRequest) -> InferenceResponse:
# If batch is not configured, run the detector and return the output
if not self._ad_settings.batch_size:
if self._online or not self._ad_settings.batch_size:
return self._detect(payload)

if len(self._batch) < self._ad_settings.batch_size:
@@ -105,19 +115,32 @@ def _detect(self, payload: InferenceRequest) -> InferenceResponse:
input_data = self.decode_request(payload, default_codec=NumpyRequestCodec)
predict_kwargs = self._ad_settings.predict_parameters

try:
y = self._model.predict(np.array(input_data), **predict_kwargs)
return self._encode_response(y)
except (ValueError, IndexError) as e:
raise InferenceError(
f"Invalid predict parameters for model {self._settings.name}: {e}"
) from e
# If the detector is offline (or X holds a single instance), wrap X in a list
# so the detector runs once over the full array; online detectors otherwise
# iterate over instances one at a time
X = np.array(input_data)
if not self._online or len(input_data) == 1:
X = [X] # type: ignore[assignment]

# Run detector inference
pred = []
for x in X:
# Run the detector on x (the full batch if offline, a single instance if online)
try:
pred.append(self._model.predict(x, **predict_kwargs))
except (ValueError, IndexError) as e:
raise InferenceError(
f"Invalid predict parameters for model {self._settings.name}: {e}"
) from e
# Save state if necessary
if self._should_save_state:
self._save_state()

return self._encode_response(self._postproc_pred(pred))

def _encode_response(self, y: dict) -> InferenceResponse:
outputs = []
for key in y["data"]:
outputs.append(
NumpyCodec.encode_output(name=key, payload=np.array([y["data"][key]]))
NumpyCodec.encode_output(name=key, payload=np.array(y["data"][key]))
)

# Add headers
@@ -132,6 +155,47 @@ def _encode_response(self, y: dict) -> InferenceResponse:
outputs=outputs,
)

@staticmethod
def _postproc_pred(pred: List[dict]) -> dict:
"""
Postprocess the detector's prediction(s) to return a single results dictionary.
- If a single instance (or batch of instances) was run, the predictions will be
a length 1 list containing one dictionary, which is returned as is.
- If N instances were run in an online fashion, the predictions will be a
length N list of results dictionaries, which are merged into a single
dictionary containing data lists of length N.
"""
data: Dict[str, list] = {key: [] for key in pred[0]["data"]}
for pred_i in pred:
for key in data:
data[key].append(pred_i["data"][key])
y = {"data": data, "meta": pred[0]["meta"]}
return y

@property
def _should_save_state(self) -> bool:
return (
self._online
and self._model.t % self._ad_settings.state_save_freq == 0
and self._model.t > 0
)

def _save_state(self) -> None:
# The detector should have a save_state method, but double-check...
if hasattr(self._model, "save_state"):
try:
self._model.save_state(os.path.join(self._model_uri, "state"))
except Exception as e:
raise MLServerError(
f"Error whilst attempting to save state for model "
f"{self._settings.name}: {e}"
) from e
else:
logger.warning(
"Attempting to save state but detector doesn't have save_state method."
)

@cached_property
def alibi_method(self) -> str:
module: str = type(self._model).__module__
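A minimal sketch of how the new `state_save_freq` setting can be supplied through the runtime's model settings, mirroring the `extra` parameters used in the test fixtures below (the model name and URI here are placeholders):

from mlserver.settings import ModelSettings, ModelParameters
from mlserver_alibi_detect import AlibiDetectRuntime

# Placeholder name and URI; state is saved to "<uri>/state" every 10 predictions
settings = ModelSettings(
    name="cvm-drift-detector",
    implementation=AlibiDetectRuntime,
    parameters=ModelParameters(
        uri="./alibi-detector-artifacts",
        extra={"state_save_freq": 10},
    ),
)
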
44 changes: 43 additions & 1 deletion runtimes/alibi-detect/tests/conftest.py
@@ -5,7 +5,7 @@
import tensorflow as tf

from tensorflow.keras.layers import Dense, InputLayer
from alibi_detect.cd import TabularDrift
from alibi_detect.cd import TabularDrift, CVMDriftOnline
from alibi_detect.od import OutlierVAE
from alibi_detect.saving import save_detector

@@ -18,6 +18,8 @@
tf.keras.backend.clear_session()

P_VAL_THRESHOLD = 0.05
ERT = 50
WINDOW_SIZES = [10]

TESTS_PATH = os.path.dirname(__file__)
TESTDATA_PATH = os.path.join(TESTS_PATH, "testdata")
@@ -125,6 +127,24 @@ def drift_detector_settings(
)


@pytest.fixture
def online_drift_detector_settings(
online_drift_detector_uri: str,
) -> ModelSettings:
return ModelSettings(
name="alibi-detect-model",
implementation=AlibiDetectRuntime,
parameters=ModelParameters(
uri=online_drift_detector_uri,
version="v1.2.3",
extra={
"batch_size": 50,
"state_save_freq": 10,
},  # specify batch_size to check that it is ignored for online detectors
),
)


@pytest.fixture
def drift_detector_uri(tmp_path: str) -> str:
X_ref = np.array([[1, 2, 3]])
@@ -137,9 +157,31 @@ def drift_detector_uri(tmp_path: str) -> str:
return detector_uri


@pytest.fixture
def online_drift_detector_uri(tmp_path: str) -> str:
X_ref = np.ones((10, 3))

cd = CVMDriftOnline(X_ref, ert=ERT, window_sizes=WINDOW_SIZES)

detector_uri = os.path.join(tmp_path, "alibi-detector-artifacts")
save_detector(cd, detector_uri)

return detector_uri


@pytest.fixture
async def drift_detector(drift_detector_settings: ModelSettings) -> AlibiDetectRuntime:
model = AlibiDetectRuntime(drift_detector_settings)
model.ready = await model.load()

return model


@pytest.fixture
async def online_drift_detector(
online_drift_detector_settings: ModelSettings,
) -> AlibiDetectRuntime:
model = AlibiDetectRuntime(online_drift_detector_settings)
model.ready = await model.load()

return model
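
A rough sketch of the detector saved by the fixture above: `ert` is the expected run time (the average number of in-distribution instances processed before a false positive is raised) and `window_sizes` sets the lengths of the sliding test windows. Online detectors consume one instance per `predict` call and track an internal timestep `t`:

import numpy as np
from alibi_detect.cd import CVMDriftOnline

X_ref = np.ones((10, 3))
cd = CVMDriftOnline(X_ref, ert=50, window_sizes=[10])

pred = cd.predict(np.ones(3))  # one instance per call
print(pred["data"]["time"], pred["data"]["is_drift"])
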
70 changes: 66 additions & 4 deletions runtimes/alibi-detect/tests/test_drift_detector.py
@@ -1,13 +1,14 @@
import numpy as np

from alibi_detect.cd import TabularDrift
from alibi_detect.cd import TabularDrift, CVMDriftOnline

from mlserver.types import InferenceRequest
from mlserver.codecs import NumpyRequestCodec
from mlserver.types import InferenceRequest, Parameters, RequestInput

from mlserver.codecs import NumpyCodec, NumpyRequestCodec

from mlserver_alibi_detect import AlibiDetectRuntime

from .conftest import P_VAL_THRESHOLD
from .conftest import P_VAL_THRESHOLD, ERT, WINDOW_SIZES


async def test_load_folder(
@@ -17,6 +18,13 @@ async def test_load_folder(
assert type(drift_detector._model) == TabularDrift


async def test_load_folder_online(
online_drift_detector: AlibiDetectRuntime,
):
assert online_drift_detector.ready
assert type(online_drift_detector._model) == CVMDriftOnline


async def test_predict(
drift_detector: AlibiDetectRuntime,
inference_request: InferenceRequest,
@@ -79,3 +87,57 @@ async def test_predict_batch_cleared(
# Batch should now be cleared (and started from scratch)
response = await drift_detector.predict(inference_request)
assert len(response.outputs) == 0


async def test_predict_online(
online_drift_detector: AlibiDetectRuntime,
inference_request: InferenceRequest,
):
# Test a request of length 1
response = await online_drift_detector.predict(inference_request)
assert len(response.outputs) == 7
assert response.outputs[0].name == "is_drift"
assert response.outputs[0].shape == [1, 1]
assert response.outputs[1].name == "distance"
assert response.outputs[2].name == "p_val"
assert response.outputs[3].name == "threshold"
assert response.outputs[4].name == "time"
assert response.outputs[4].data[0] == 1
assert response.outputs[5].name == "ert"
assert response.outputs[5].data[0] == ERT
assert response.outputs[6].name == "test_stat"
assert response.outputs[6].shape == [1, 1, 3]


async def test_predict_batch_online(online_drift_detector: AlibiDetectRuntime):
# Send a batch request; the drift detector should run on one instance at a time
batch_size = 50
data = np.random.normal(size=(batch_size, 3))
inference_request = InferenceRequest(
parameters=Parameters(content_type=NumpyRequestCodec.ContentType),
inputs=[
RequestInput(
name="predict",
shape=data.shape,
data=data.tolist(),
datatype="FP32",
)
],
)
response = await online_drift_detector.predict(inference_request)
assert len(response.outputs) == 7
assert response.outputs[0].name == "is_drift"
assert response.outputs[0].shape == [50, 1]
assert response.outputs[1].name == "distance"
assert response.outputs[2].name == "p_val"
assert response.outputs[3].name == "threshold"
assert response.outputs[4].name == "time"
assert response.outputs[4].data[-1] == 50
assert response.outputs[5].name == "ert"
assert response.outputs[5].data[0] == ERT
assert response.outputs[6].name == "test_stat"
assert response.outputs[6].shape == [50, 1, 3]
# Test stat should be NaN until the test window is filled
test_stats = NumpyCodec.decode_output(response.outputs[6])
assert np.isnan(test_stats[0]).all()
assert not np.isnan(test_stats[WINDOW_SIZES[0]]).all()
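
The output shapes asserted above are produced by `_postproc_pred` merging the N single-instance result dictionaries returned by the online detector; a hypothetical illustration of that merge (keys abridged, values invented):

# Two single-instance results from an online detector:
pred = [
    {"data": {"is_drift": 0, "time": 1}, "meta": {"name": "CVMDriftOnline"}},
    {"data": {"is_drift": 1, "time": 2}, "meta": {"name": "CVMDriftOnline"}},
]
# _postproc_pred merges them, keeping `meta` from the first result:
# {"data": {"is_drift": [0, 1], "time": [1, 2]}, "meta": {"name": "CVMDriftOnline"}}
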
23 changes: 23 additions & 0 deletions runtimes/alibi-detect/tests/test_runtime.py
@@ -1,3 +1,4 @@
import os
import pytest

from mlserver.codecs import CodecError
@@ -16,3 +17,25 @@ async def test_multiple_inputs_error(

with pytest.raises(CodecError):
await outlier_detector.predict(inference_request)


async def test_saving_state(
online_drift_detector: AlibiDetectRuntime,
inference_request: InferenceRequest,
):
save_freq = online_drift_detector._ad_settings.state_save_freq
state_uri = os.path.join(online_drift_detector._model_uri, "state")

# Check nothing is written after (save_freq - 1) requests
for _ in range(save_freq - 1): # type: ignore
await online_drift_detector.predict(inference_request)
assert not os.path.isdir(state_uri)

# Check state is written after save_freq requests
await online_drift_detector.predict(inference_request)
assert os.path.isdir(state_uri)

# Check state properly loaded in new runtime
new_online_drift_detector = AlibiDetectRuntime(online_drift_detector.settings)
await new_online_drift_detector.load()
assert new_online_drift_detector._model.t == save_freq
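
The final assertion works because `load_detector` picks saved state back up from the `state` subdirectory of the detector URI, as written by `_save_state`; a sketch of that round trip outside the runtime (paths are placeholders):

from alibi_detect.saving import load_detector

detector = load_detector("./alibi-detector-artifacts")
# ... run some predictions, then persist the detector's state:
detector.save_state("./alibi-detector-artifacts/state")

# Later, in a fresh process, the timestep `t` carries over:
restored = load_detector("./alibi-detector-artifacts")
assert restored.t == detector.t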
