From 77ced609f9f71c0efb54693ac32ca0640b534b56 Mon Sep 17 00:00:00 2001 From: Luis Montero Date: Mon, 22 Jul 2024 14:34:04 +0200 Subject: [PATCH] fix: fix dtype check in quantizer dequant The commit fixes a dtype check that we have in the `dequant` method of our quantizer. Dtype objects are weird. https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype --- src/concrete/ml/quantization/quantizers.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/concrete/ml/quantization/quantizers.py b/src/concrete/ml/quantization/quantizers.py index 3ce3eb825..c1bd058a0 100644 --- a/src/concrete/ml/quantization/quantizers.py +++ b/src/concrete/ml/quantization/quantizers.py @@ -103,7 +103,11 @@ class QuantizationOptions: is_precomputed_qat: bool = False def __init__( - self, n_bits: int, is_signed: bool = False, is_symmetric: bool = False, is_qat: bool = False + self, + n_bits: int, + is_signed: bool = False, + is_symmetric: bool = False, + is_qat: bool = False, ): self.n_bits = n_bits self.is_signed = is_signed @@ -789,7 +793,7 @@ def dequant(self, qvalues: numpy.ndarray) -> Union[float, numpy.ndarray, Tracer] assert_true( isinstance(self.scale, (numpy.floating, float)) - or (isinstance(self.scale, numpy.ndarray) and self.scale.dtype is numpy.float64), + or (isinstance(self.scale, numpy.ndarray) and self.scale.dtype == numpy.float64), "Scale is a of type " + type(self.scale).__name__ + ((" " + str(self.scale.dtype)) if isinstance(self.scale, numpy.ndarray) else ""),