Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix RandomBrightness, Enhance IndexLookup Initialization and Expand Test Coverage for Preprocessing Layers #19513

Merged
merged 10 commits into from
Apr 16, 2024
68 changes: 68 additions & 0 deletions keras/layers/preprocessing/category_encoding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,71 @@ def test_tf_data_compatibility(self):
for output in ds.take(1):
output = output.numpy()
self.assertAllClose(output, expected_output)

def test_category_encoding_without_num_tokens(self):
    """Expect a ValueError when the layer is built without `num_tokens`."""
    expected_message = r"num_tokens must be set to use this layer"
    with self.assertRaisesRegex(ValueError, expected_message):
        layers.CategoryEncoding(output_mode="multi_hot")

def test_category_encoding_with_invalid_num_tokens(self):
    """Expect a ValueError for `num_tokens` of 0 and of -1."""
    expected_message = r"`num_tokens` must be >= 1"
    # Same check for each invalid value, in the original order.
    for bad_num_tokens in (0, -1):
        with self.assertRaisesRegex(ValueError, expected_message):
            layers.CategoryEncoding(
                num_tokens=bad_num_tokens, output_mode="multi_hot"
            )

def test_category_encoding_with_unnecessary_count_weights(self):
    """Expect a ValueError when `count_weights` is passed in multi-hot mode."""
    encoder = layers.CategoryEncoding(num_tokens=4, output_mode="multi_hot")
    token_ids = np.array([0, 1, 2, 3])
    weights = np.array([0.1, 0.2, 0.3, 0.4])
    expected_message = r"`count_weights` is not used when `output_mode`"
    with self.assertRaisesRegex(ValueError, expected_message):
        encoder(token_ids, count_weights=weights)

def test_invalid_output_mode_raises_error(self):
    """Expect a ValueError naming the unknown `output_mode` value."""
    expected_message = r"Unknown arg for output_mode: invalid_mode"
    with self.assertRaisesRegex(ValueError, expected_message):
        layers.CategoryEncoding(num_tokens=4, output_mode="invalid_mode")

def test_encode_one_hot_single_sample(self):
    """Check `_encode` in one-hot mode on a 1-D array of four token ids."""
    encoder = layers.CategoryEncoding(num_tokens=4, output_mode="one_hot")
    token_ids = np.array([1, 2, 3, 1])
    # One row per input id, with the 1 placed at that id's index.
    expected_rows = [
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
        [0, 1, 0, 0],
    ]
    expected_output = np.array(expected_rows)
    encoded = encoder._encode(token_ids)
    self.assertAllClose(expected_output, encoded)

def test_encode_one_hot_batched_samples(self):
    """Check `_encode` in one-hot mode on a batch of two identical samples."""
    encoder = layers.CategoryEncoding(num_tokens=4, output_mode="one_hot")
    sample_ids = [3, 2, 0, 1]
    token_ids = np.array([sample_ids, sample_ids])
    # Both samples are identical, so both share the same one-hot rows.
    sample_one_hot = [[0, 0, 0, 1], [0, 0, 1, 0], [1, 0, 0, 0], [0, 1, 0, 0]]
    expected_output = np.array([sample_one_hot, sample_one_hot])
    encoded = encoder._encode(token_ids)
    self.assertAllClose(expected_output, encoded)

def test_count_single_sample(self):
    """Check `_count` on a 1-D sample: id 1 twice, ids 2 and 3 once each."""
    counter = layers.CategoryEncoding(num_tokens=4, output_mode="count")
    token_ids = np.array([1, 2, 3, 1])
    counts = counter._count(token_ids)
    self.assertAllClose(np.array([0, 2, 1, 1]), counts)

def test_count_batched_samples(self):
    """Check `_count` on a batch of two samples, one count row per sample."""
    counter = layers.CategoryEncoding(num_tokens=4, output_mode="count")
    token_ids = np.array([[1, 2, 3, 1], [0, 3, 1, 0]])
    counts = counter._count(token_ids)
    self.assertAllClose(np.array([[0, 2, 1, 1], [2, 1, 0, 1]]), counts)
37 changes: 37 additions & 0 deletions keras/layers/preprocessing/hashing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,43 @@ def test_hash_list_input(self, input_data, expected):
expected, backend.convert_to_numpy(out_data).tolist()
)

def test_hashing_invalid_num_bins(self):
    """Expect a ValueError when `num_bins` is `None` and when it is 0."""
    expected_message = (
        "The `num_bins` for `Hashing` cannot be `None` or non-positive"
    )
    # `None` first, then 0, matching the original order of the two checks.
    for invalid_num_bins in (None, 0):
        with self.assertRaisesRegex(ValueError, expected_message):
            layers.Hashing(num_bins=invalid_num_bins)

def test_hashing_invalid_output_mode(self):
    """Expect a ValueError for an unsupported `output_mode` string."""
    expected_message = (
        "Invalid value for argument `output_mode`. Expected one of"
    )
    with self.assertRaisesRegex(ValueError, expected_message):
        layers.Hashing(num_bins=3, output_mode="unsupported_mode")

def test_hashing_invalid_dtype_for_int_mode(self):
    """Expect a ValueError when `output_mode="int"` is paired with float32."""
    expected_message = (
        'When `output_mode="int"`, `dtype` should be an integer type,'
    )
    with self.assertRaisesRegex(ValueError, expected_message):
        layers.Hashing(num_bins=3, output_mode="int", dtype="float32")

def test_hashing_sparse_with_int_mode(self):
    """Expect a ValueError when `sparse=True` is used with `output_mode="int"`."""
    expected_message = "`sparse` may only be true if `output_mode` is"
    with self.assertRaisesRegex(ValueError, expected_message):
        layers.Hashing(num_bins=3, output_mode="int", sparse=True)


# TODO: support tf.RaggedTensor.
# def test_hash_ragged_string_input_farmhash(self):
Expand Down
10 changes: 5 additions & 5 deletions keras/layers/preprocessing/index_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,11 @@ class IndexLookup(Layer):

def __init__(
self,
max_tokens,
num_oov_indices,
mask_token,
oov_token,
vocabulary_dtype,
mask_token=None,
oov_token="[OOV]",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just realized -- the reason this layer had no default args is that there should be no defaults. This layer handles both strings and ints, and it is only used internally. We can't provide defaults if we don't know `vocabulary_dtype`.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I will leave it unchanged.
Thank you for your feedback, @fchollet.

vocabulary_dtype="string",
num_oov_indices=1,
max_tokens=None,
Faisal-Alsrheed marked this conversation as resolved.
Show resolved Hide resolved
vocabulary=None,
idf_weights=None,
invert=False,
Expand Down
180 changes: 180 additions & 0 deletions keras/layers/preprocessing/index_lookup_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,3 +427,183 @@ def test_adapt_with_tf_data(self):
self.assertEqual(list(output), [2, 3, 1])
if backend.backend() != "torch":
self.run_class_serialization_test(layer)

def test_max_tokens_less_than_two(self):
    """Expect a ValueError when `max_tokens` is 1."""
    config = {
        "max_tokens": 1,
        "num_oov_indices": 1,
        "mask_token": None,
        "oov_token": None,
        "vocabulary_dtype": "int64",
    }
    expected_message = "If set, `max_tokens` must be greater than 1."
    with self.assertRaisesRegex(ValueError, expected_message):
        layers.IndexLookup(**config)

def test_max_tokens_none_with_pad_to_max_tokens(self):
    """Expect a ValueError for `pad_to_max_tokens=True` without `max_tokens`."""
    config = {
        "num_oov_indices": 1,
        "mask_token": None,
        "oov_token": None,
        "vocabulary_dtype": "int64",
        "pad_to_max_tokens": True,
    }
    expected_message = "If pad_to_max_tokens is True, must set `max_tokens`."
    with self.assertRaisesRegex(ValueError, expected_message):
        layers.IndexLookup(**config)

def test_negative_num_oov_indices(self):
    """Expect a ValueError when `num_oov_indices` is negative."""
    config = {
        "max_tokens": 10,
        "num_oov_indices": -1,
        "mask_token": None,
        "oov_token": None,
        "vocabulary_dtype": "int64",
    }
    expected_message = "`num_oov_indices` must be greater than or equal to 0."
    with self.assertRaisesRegex(ValueError, expected_message):
        layers.IndexLookup(**config)

def test_invert_with_non_int_output_mode(self):
    """Expect a ValueError when `invert=True` is paired with `one_hot` output."""
    config = {
        "num_oov_indices": 1,
        "mask_token": None,
        "oov_token": None,
        "vocabulary_dtype": "string",
        "invert": True,
        # Invalid combination with `invert=True`.
        "output_mode": "one_hot",
    }
    expected_message = r"`output_mode` must be `'int'` when `invert` is true."
    with self.assertRaisesRegex(ValueError, expected_message):
        layers.IndexLookup(**config)

def test_sparse_true_with_int_output_mode(self):
    """Expect a ValueError when `sparse=True` is paired with `int` output."""
    config = {
        "num_oov_indices": 1,
        "mask_token": None,
        "oov_token": None,
        "vocabulary_dtype": "string",
        "sparse": True,
        # Invalid combination with `sparse=True`.
        "output_mode": "int",
    }
    expected_message = (
        r"`sparse` may only be true if `output_mode` is `'one_hot'`"
    )
    with self.assertRaisesRegex(ValueError, expected_message):
        layers.IndexLookup(**config)

def test_idf_weights_set_with_non_tfidf_output_mode(self):
    """Expect a ValueError when `idf_weights` is given with `int` output."""
    # `idf_weights` should not be set for non-TF-IDF modes.
    unwanted_idf_weights = [0.5, 0.1, 0.3]
    expected_message = (
        r"`idf_weights` should only be set if `output_mode` is `'tf_idf'`"
    )
    with self.assertRaisesRegex(ValueError, expected_message):
        layers.IndexLookup(
            num_oov_indices=1,
            mask_token=None,
            oov_token=None,
            vocabulary_dtype="string",
            idf_weights=unwanted_idf_weights,
            output_mode="int",
        )

def test_unrecognized_kwargs(self):
    """Expect a ValueError when an unknown keyword argument is passed."""
    config = {
        "num_oov_indices": 1,
        "mask_token": None,
        "oov_token": None,
        "vocabulary_dtype": "string",
        "output_mode": "int",
        # This is the unrecognized argument.
        "extra_arg": True,
    }
    expected_message = "Unrecognized keyword argument"
    with self.assertRaisesRegex(ValueError, expected_message):
        layers.IndexLookup(**config)

def test_non_tf_idf_with_idf_weights(self):
    """Expect a ValueError when `idf_weights` is given with `multi_hot` output."""
    # `idf_weights` is not valid for multi_hot mode.
    unwanted_idf_weights = [0.5, 0.1, 0.3]
    expected_message = "`idf_weights` should only be set if `output_mode` is"
    with self.assertRaisesRegex(ValueError, expected_message):
        layers.IndexLookup(
            num_oov_indices=1,
            mask_token=None,
            oov_token=None,
            vocabulary_dtype="string",
            output_mode="multi_hot",
            idf_weights=unwanted_idf_weights,
        )

def test_vocabulary_file_does_not_exist(self):
    """Expect a ValueError when `vocabulary` is a path to a missing file."""
    config = {
        "num_oov_indices": 1,
        "mask_token": None,
        "oov_token": None,
        "vocabulary_dtype": "string",
        "output_mode": "int",
        # Nonexistent file path.
        "vocabulary": "path/to/missing_vocab.txt",
    }
    expected_message = (
        "Vocabulary file path/to/missing_vocab.txt does not exist"
    )
    with self.assertRaisesRegex(ValueError, expected_message):
        layers.IndexLookup(**config)

def test_repeated_tokens_in_vocabulary(self):
    """Expect a ValueError when the vocabulary list repeats a token."""
    # "token" appears twice.
    duplicated_vocabulary = ["token", "token", "unique"]
    expected_message = "The passed vocabulary has at least one repeated term."
    with self.assertRaisesRegex(ValueError, expected_message):
        layers.IndexLookup(
            num_oov_indices=1,
            mask_token=None,
            oov_token=None,
            vocabulary_dtype="string",
            vocabulary=duplicated_vocabulary,
        )

def test_mask_token_in_wrong_position(self):
    """Expect a ValueError when the mask token is not first in the vocabulary."""
    # 'mask' should be at the start if included explicitly.
    misplaced_mask_vocabulary = ["token", "mask", "unique"]
    expected_message = (
        "Found reserved mask token at unexpected location in `vocabulary`."
    )
    with self.assertRaisesRegex(ValueError, expected_message):
        layers.IndexLookup(
            num_oov_indices=1,
            mask_token="mask",
            oov_token=None,
            vocabulary_dtype="string",
            vocabulary=misplaced_mask_vocabulary,
        )

def test_ensure_known_vocab_size_without_vocabulary(self):
    """Expect a RuntimeError when a `multi_hot` layer is called with no vocabulary."""
    lookup = layers.IndexLookup(
        num_oov_indices=1,
        # Assume empty string or some default token is valid.
        mask_token="",
        # Assume [OOV] or some default token is valid.
        oov_token="[OOV]",
        output_mode="multi_hot",
        pad_to_max_tokens=False,
        vocabulary_dtype="string",
    )
    sample_inputs = ["sample", "data"]
    expected_message = "When using `output_mode=multi_hot` and"

    # Try calling the layer without setting the vocabulary.
    with self.assertRaisesRegex(RuntimeError, expected_message):
        lookup(sample_inputs)
45 changes: 41 additions & 4 deletions keras/layers/preprocessing/normalization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,6 @@ def test_normalization_adapt(self, input_type):
self.assertAllClose(np.var(output, axis=(0, 3)), 1.0, atol=1e-5)
self.assertAllClose(np.mean(output, axis=(0, 3)), 0.0, atol=1e-5)

def test_normalization_errors(self):
# TODO
pass

@pytest.mark.skipif(
backend.backend() != "torch",
reason="Test symbolic call for torch meta device.",
Expand All @@ -107,3 +103,44 @@ def test_call_on_meta_device_after_built(self):
layer.adapt(data)
with core.device_scope("meta"):
layer(data)

def test_normalization_with_mean_only_raises_error(self):
    """Expect a ValueError when only `mean` is supplied."""
    expected_message = "both `mean` and `variance` must be set"
    with self.assertRaisesRegex(ValueError, expected_message):
        layers.Normalization(mean=0.5)

def test_normalization_with_variance_only_raises_error(self):
    """Expect a ValueError when only `variance` is supplied."""
    expected_message = "both `mean` and `variance` must be set"
    with self.assertRaisesRegex(ValueError, expected_message):
        layers.Normalization(variance=0.1)

def test_normalization_axis_too_high(self):
    """Expect a ValueError for `axis=3` with a rank-2 build shape."""
    expected_message = "All `axis` values must be in the range"
    # Construction and build both stay inside the context, as in the original.
    with self.assertRaisesRegex(ValueError, expected_message):
        layers.Normalization(axis=3).build((2, 2))

def test_normalization_axis_too_low(self):
    """Expect a ValueError for `axis=-4` with a rank-3 build shape."""
    expected_message = "All `axis` values must be in the range"
    # Construction and build both stay inside the context, as in the original.
    with self.assertRaisesRegex(ValueError, expected_message):
        layers.Normalization(axis=-4).build((2, 3, 4))

def test_normalization_unknown_axis_shape(self):
    """Expect a ValueError when the kept axis has an unknown (`None`) size."""
    expected_message = "All `axis` values to be kept"
    with self.assertRaisesRegex(ValueError, expected_message):
        layers.Normalization(axis=1).build((None, None))

def test_normalization_adapt_with_incompatible_shape(self):
    """Expect a ValueError when `adapt` data does not match the built shape."""
    normalizer = layers.Normalization(axis=-1)
    normalizer.build((10, 5))
    # The last dimension is 3 here, but the layer was built with 5.
    mismatched_data = np.random.random((10, 3))
    with self.assertRaisesRegex(ValueError, "an incompatible shape"):
        normalizer.adapt(mismatched_data)
4 changes: 2 additions & 2 deletions keras/layers/preprocessing/random_brightness.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ def __init__(self, factor, value_range=(0, 255), seed=None, **kwargs):
def _set_value_range(self, value_range):
if not isinstance(value_range, (tuple, list)):
raise ValueError(
self.value_range_VALIDATION_ERROR
self._VALUE_RANGE_VALIDATION_ERROR
+ f"Received: value_range={value_range}"
)
if len(value_range) != 2:
raise ValueError(
self.value_range_VALIDATION_ERROR
self._VALUE_RANGE_VALIDATION_ERROR
+ f"Received: value_range={value_range}"
)
self.value_range = sorted(value_range)
Expand Down
Loading
Loading