From e727c6277dc5d1d38b6d22e64a067a332c547403 Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Mon, 26 Feb 2024 22:21:05 +0000 Subject: [PATCH] Quantized storage type required to be signless --- stablehlo/dialect/Base.td | 4 ++-- stablehlo/tests/ops_stablehlo_quantized.mlir | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/stablehlo/dialect/Base.td b/stablehlo/dialect/Base.td index 46f681856a8..ced30e2826b 100644 --- a/stablehlo/dialect/Base.td +++ b/stablehlo/dialect/Base.td @@ -57,7 +57,7 @@ class StableHLO_UniformQuantizedSignedInt CPred<"$_self.cast()" # ".getStorageTypeIntegralWidth() == " # width>, CPred<"$_self.cast()" # - ".isSigned()">]>, + ".getStorageType().cast().isSignless()">]>, "QI" # width # " type"> { string name = "UniformQuantizedSignedInt"; int bitwidth = width; @@ -68,7 +68,7 @@ class StableHLO_UniformQuantizedPerAxisSignedInt CPred<"$_self.cast()" # ".getStorageTypeIntegralWidth() == " # width>, CPred<"$_self.cast()" # - ".isSigned()">]>, + ".getStorageType().cast().isSignless()">]>, "QI" # width # " type"> { string name = "UniformQuantizedPerAxisSignedInt"; int bitwidth = width; diff --git a/stablehlo/tests/ops_stablehlo_quantized.mlir b/stablehlo/tests/ops_stablehlo_quantized.mlir index 1d31bba9a56..a4fdd960dcc 100644 --- a/stablehlo/tests/ops_stablehlo_quantized.mlir +++ b/stablehlo/tests/ops_stablehlo_quantized.mlir @@ -807,3 +807,11 @@ func.func @negative_select_and_scatter_quantization(%arg0: tensor<10x24x24x64x!q } : (tensor<10x24x24x64x!quant.uniform>, tensor<10x23x23x64x!quant.uniform>, tensor>) -> tensor<10x24x24x64x!quant.uniform> func.return %0 : tensor<10x24x24x64x!quant.uniform> } + +// ----- + +func.func @main(%arg0: tensor<4x!quant.uniform>) -> tensor<4xf32> { + // expected-error@+1 {{operand #0 must be tensor of 4/8/16/32-bit uniform quantized signed integer or 4/8/16/32-bit uniform quantized unsigned integer or 4/8/16/32-bit uniform quantized per axis signed integer or 4/8/16/32-bit uniform quantized per axis unsigned integer values, but got 'tensor<4x!quant.uniform>}} + %0 = "stablehlo.uniform_dequantize"(%arg0) : (tensor<4x!quant.uniform>) -> tensor<4xf32> + func.return %0 : tensor<4xf32> +}