Fix incorrect usage of quantize #19541

Merged · 3 commits · Apr 18, 2024
7 changes: 2 additions & 5 deletions keras/src/layers/core/dense.py
@@ -1,7 +1,6 @@
import ml_dtypes

from keras.src import activations
from keras.src import backend
from keras.src import constraints
from keras.src import dtype_policies
from keras.src import initializers
@@ -347,6 +346,7 @@ def _int8_build(
initializer=kernel_scale_initializer,
trainable=False,
)
self._is_quantized = True

def _float8_build(self):
if not isinstance(
@@ -396,6 +396,7 @@ def _float8_build(self):
self.kernel_amax_history.overwrite_with_gradient = True
self.outputs_grad_scale.overwrite_with_gradient = True
self.outputs_grad_amax_history.overwrite_with_gradient = True
self._is_quantized = True

def quantized_call(self, inputs):
if self.dtype_policy.quantization_mode == "int8":
@@ -552,8 +553,6 @@ def quantize(self, mode):

self._tracker.unlock()
if mode == "int8":
if backend.standardize_dtype(self._kernel.dtype) == "int8":
raise ValueError("`quantize` can only be done once per layer.")
# Configure `self.inputs_quantizer`
self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1)
# Quantize `self._kernel` to int8 and compute corresponding scale
@@ -572,8 +571,6 @@ def quantize(self, mode):
lambda shape, dtype: kernel_scale,
)
elif mode == "float8":
if hasattr(self, "inputs_amax_history"):
raise ValueError("`quantize` can only be done once per layer.")
self._float8_build()
else:
raise NotImplementedError(
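A minimal usage sketch (illustrative only, not part of the diff, assuming the Keras 3 public API): with this dense.py change, `_int8_build` and `_float8_build` record `self._is_quantized = True`, so a repeat call to `quantize()` is rejected by the shared base-class check instead of by inspecting the kernel dtype.

from keras import layers

layer = layers.Dense(units=2)
layer.build((None, 2))
layer.quantize("int8")  # `_int8_build` runs and sets `_is_quantized = True`

try:
    layer.quantize("float8")  # any second call fails, regardless of mode
except ValueError as e:
    print(e)  # expected to mention "is already quantized with dtype_policy='int8_from_float32'"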
22 changes: 18 additions & 4 deletions keras/src/layers/core/dense_test.py
@@ -420,10 +420,24 @@ def test_quantize_when_already_quantized(self, mode):
layer = layers.Dense(units=2)
layer.build((None, 2))
layer.quantize(mode)
with self.assertRaisesRegex(
ValueError, "`quantize` can only be done once per layer."
):
layer.quantize(mode)
for m in ["int8", "float8"]:
with self.assertRaisesRegex(
ValueError, "is already quantized with dtype_policy="
):
layer.quantize(m)

@parameterized.named_parameters(
("int8", "int8_from_float32"),
("float8", "float8_from_float32"),
)
def test_quantize_when_already_quantized_using_dtype_argument(self, mode):
layer = layers.Dense(units=2, dtype=mode)
layer.build((None, 2))
for m in ["int8", "float8"]:
with self.assertRaisesRegex(
ValueError, "is already quantized with dtype_policy="
):
layer.quantize(m)

@parameterized.named_parameters(
("int8", "int8_from_float32", 3),
7 changes: 2 additions & 5 deletions keras/src/layers/core/einsum_dense.py
@@ -5,7 +5,6 @@
import numpy as np

from keras.src import activations
from keras.src import backend
from keras.src import constraints
from keras.src import dtype_policies
from keras.src import initializers
@@ -431,6 +430,7 @@ def _int8_build(
initializer=kernel_scale_initializer,
trainable=False,
)
self._is_quantized = True

def _float8_build(self):
if not isinstance(
@@ -480,6 +480,7 @@ def _float8_build(self):
self.kernel_amax_history.overwrite_with_gradient = True
self.outputs_grad_scale.overwrite_with_gradient = True
self.outputs_grad_amax_history.overwrite_with_gradient = True
self._is_quantized = True

def quantized_call(self, inputs):
if self.dtype_policy.quantization_mode == "int8":
@@ -665,8 +666,6 @@ def quantize(self, mode):

self._tracker.unlock()
if mode == "int8":
if backend.standardize_dtype(self._kernel.dtype) == "int8":
raise ValueError("`quantize` can only be done once per layer.")
(
self._input_reduced_axes,
self._kernel_reduced_axes,
@@ -709,8 +708,6 @@ def quantize(self, mode):
lambda shape, dtype: kernel_scale,
)
elif mode == "float8":
if hasattr(self, "inputs_amax_history"):
raise ValueError("`quantize` can only be done once per layer.")
self._float8_build()
else:
raise NotImplementedError(
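The einsum_dense.py change mirrors the Dense fix. The sketch below (illustrative only, assuming the Keras 3 public API) shows the construction-time path the new tests cover: a layer built with a quantized dtype policy is now treated as already quantized, so `quantize()` refuses to run again.

from keras import layers

layer = layers.EinsumDense(
    equation="ab,bcd->acd",
    output_shape=(8, 16),
    bias_axes="d",
    dtype="int8_from_float32",  # quantized policy supplied at construction
)
layer.build((None, 3))  # the int8 build path sets `_is_quantized = True`

try:
    layer.quantize("float8")
except ValueError as e:
    print(e)  # expected to mention "is already quantized with dtype_policy="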
29 changes: 24 additions & 5 deletions keras/src/layers/core/einsum_dense_test.py
@@ -508,15 +508,34 @@ class MyEinsumDense(layers.EinsumDense):
def test_quantize_when_already_quantized(self, mode):
layer = layers.EinsumDense(
equation="ab,bcd->acd",
output_shape=(8, 32),
output_shape=(8, 16),
bias_axes="d",
)
layer.build((None, 3))
layer.quantize(mode)
with self.assertRaisesRegex(
ValueError, "`quantize` can only be done once per layer."
):
layer.quantize(mode)
for m in ["int8", "float8"]:
with self.assertRaisesRegex(
ValueError, "is already quantized with dtype_policy="
):
layer.quantize(m)

@parameterized.named_parameters(
("int8", "int8_from_float32"),
("float8", "float8_from_float32"),
)
def test_quantize_when_already_quantized_using_dtype_argument(self, mode):
layer = layers.EinsumDense(
equation="ab,bcd->acd",
output_shape=(8, 16),
bias_axes="d",
dtype=mode,
)
layer.build((None, 3))
for m in ["int8", "float8"]:
with self.assertRaisesRegex(
ValueError, "is already quantized with dtype_policy="
):
layer.quantize(m)

@parameterized.named_parameters(
("int8", "int8_from_float32", 3),
3 changes: 1 addition & 2 deletions keras/src/layers/core/embedding.py
@@ -325,6 +325,7 @@ def _int8_build(
initializer=embeddings_scale_initializer,
trainable=False,
)
self._is_quantized = True

def quantized_call(self, inputs):
if self.dtype_policy.quantization_mode == "int8":
@@ -374,8 +375,6 @@ def quantize(self, mode):

self._tracker.unlock()
if mode == "int8":
if backend.standardize_dtype(self._embeddings.dtype) == "int8":
raise ValueError("`quantize` can only be done once per layer.")
# Configure `self.inputs_quantizer`
self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1)
# Quantize `self._embeddings` to int8 and compute corresponding
10 changes: 9 additions & 1 deletion keras/src/layers/core/embedding_test.py
@@ -314,7 +314,15 @@ def test_quantize_when_already_quantized(self):
layer.build()
layer.quantize("int8")
with self.assertRaisesRegex(
ValueError, "`quantize` can only be done once per layer."
ValueError, "is already quantized with dtype_policy="
):
layer.quantize("int8")

def test_quantize_when_already_quantized_using_dtype_argument(self):
layer = layers.Embedding(10, 16, dtype="int8_from_float32")
layer.build()
with self.assertRaisesRegex(
ValueError, "is already quantized with dtype_policy="
):
layer.quantize("int8")

6 changes: 6 additions & 0 deletions keras/src/layers/layer.py
@@ -1143,6 +1143,12 @@ def _check_quantize_args(self, mode, compute_dtype):
f"Layer '{self.name}' (of type '{self.__class__.__name__}') "
"is not built yet."
)
if getattr(self, "_is_quantized", False):
raise ValueError(
f"Layer '{self.name}' is already quantized with "
f"dtype_policy='{self.dtype_policy.name}'. "
f"Received: mode={mode}"
)
if mode not in dtype_policies.QUANTIZATION_MODES:
raise ValueError(
"Invalid quantization mode. "
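With the removed per-mode checks, a cross-mode sequence such as quantize("int8") followed by quantize("float8") could slip past the old guards (the float8 branch only checked for `inputs_amax_history`); the centralized check in `_check_quantize_args` now rejects any second call. A hedged end-to-end sketch, assuming the Keras 3 public API:

from keras import layers

for first, second in [("int8", "float8"), ("float8", "int8")]:
    layer = layers.Dense(units=2)
    layer.build((None, 2))
    layer.quantize(first)       # quantized build sets `_is_quantized = True`
    try:
        layer.quantize(second)  # base-class check raises for any mode
    except ValueError as e:
        assert "is already quantized with dtype_policy=" in str(e)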