Skip to content

Commit

Permalink
Addressed PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertSamoilescu committed Dec 11, 2024
1 parent 5c73414 commit 4b35346
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 13 deletions.
6 changes: 3 additions & 3 deletions alibi_detect/cd/tensorflow/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import tensorflow as tf

from alibi_detect.utils.tensorflow.prediction import (
predict_batch, predict_batch_transformer, get_named_arg
predict_batch, predict_batch_transformer, get_call_arg_mapping
)
from tensorflow.keras.layers import Dense, Flatten, Input, Lambda
from tensorflow.keras.models import Model
Expand Down Expand Up @@ -37,7 +37,7 @@ def __init__(

def call(self, x: Union[np.ndarray, tf.Tensor, Dict[str, tf.Tensor]]) -> tf.Tensor:
if not isinstance(x, (np.ndarray, tf.Tensor)):
x = get_named_arg(self.input_layer, x)
x = get_call_arg_mapping(self.input_layer, x)
x = self.input_layer(**x)
else:
x = self.input_layer(x)
Expand Down Expand Up @@ -68,7 +68,7 @@ def __init__(

def call(self, x: Union[np.ndarray, tf.Tensor, Dict[str, tf.Tensor]]) -> tf.Tensor:
if not isinstance(x, (np.ndarray, tf.Tensor)):
x = get_named_arg(self.encoder, x)
x = get_call_arg_mapping(self.encoder, x)
return self.encoder(**x)
else:
return self.encoder(x)
Expand Down
3 changes: 0 additions & 3 deletions alibi_detect/saving/_tensorflow/tests/test_saving_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@
backend = param_fixture("backend", ['tensorflow'])


# Note: The full save/load functionality of optimizers (inc. validation) is tested in test_save_classifierdrift.
@pytest.mark.skipif(version.parse(tf.__version__) < version.parse('2.16.0'),
reason="Skipping since tensorflow < 2.16.0")
def test_load_optimizer_object_tf2pt11(backend):
"""
Test the _load_optimizer_config with a tensorflow optimizer config. Only run if tensorflow>=2.16.
Expand Down
14 changes: 9 additions & 5 deletions alibi_detect/utils/tensorflow/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@
from alibi_detect.utils.prediction import tokenize_transformer


def get_named_arg(model: tf.keras.Model, x: Any) -> Dict[str, Any]:
""" Extract argument names from the model call function
because keras3 does not accept other types of input
as a positional argument.
def get_call_arg_mapping(model: tf.keras.Model, x: Any) -> Dict[str, Any]:
""" Generates a dictionary mapping the first argument name of the
`call` method of a Keras model to the provided input value.
This function is particularly useful when working with Keras 3,
which enforces stricter input handling and requires named arguments
for certain operations. It extracts the argument names from the
`call` method of the provided model and maps the first argument to `x`.
Parameters
----------
Expand Down Expand Up @@ -67,7 +71,7 @@ def predict_batch(
x_batch = preprocess_fn(x_batch)

if not isinstance(x_batch, (np.ndarray, tf.Tensor)):
x_batch = get_named_arg(model, x_batch)
x_batch = get_call_arg_mapping(model, x_batch)
preds_tmp = model(**x_batch)
else:
preds_tmp = model(x_batch)
Expand Down
2 changes: 0 additions & 2 deletions alibi_detect/utils/tests/test_saving_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,6 @@ def test_save_load(select_detector):


# Note: The full save/load functionality of optimizers (inc. validation) is tested in test_save_classifierdrift.
@pytest.mark.skipif(version.parse(tf.__version__) < version.parse('2.16.0'),
reason="Skipping since tensorflow < 2.16.0")
@parametrize('legacy', [True, False])
def test_load_optimizer_object_tf2pt11(legacy, backend):
"""
Expand Down

0 comments on commit 4b35346

Please sign in to comment.