From fd3eb99b006914f6db94c9c81f253ae64f3833b8 Mon Sep 17 00:00:00 2001 From: hamptonm1 <79232909+hamptonm1@users.noreply.github.com> Date: Mon, 16 Sep 2024 19:34:27 -0400 Subject: [PATCH 1/2] LLVM/StableHLO Upgrade eaa95a1 (#2943) Co-authored-by: Megan Hampton --- docs/BuildOnLinuxOSX.md | 2 +- docs/BuildOnWindows.md | 2 +- third_party/stablehlo | 2 +- utils/clone-mlir.sh | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/BuildOnLinuxOSX.md b/docs/BuildOnLinuxOSX.md index abc789c5cd..c0571a37fc 100644 --- a/docs/BuildOnLinuxOSX.md +++ b/docs/BuildOnLinuxOSX.md @@ -15,7 +15,7 @@ Firstly, install MLIR (as a part of LLVM-Project): ``` bash git clone -n https://github.com/llvm/llvm-project.git # Check out a specific branch that is known to work with ONNX-MLIR. -cd llvm-project && git checkout f142f8afe21bceb00fb495468aa0b5043e98c419 && cd .. +cd llvm-project && git checkout eaa95a1c2bd38332c1a4e634595f29d22b28ffea && cd .. ``` [same-as-file]: <> (utils/build-mlir.sh) diff --git a/docs/BuildOnWindows.md b/docs/BuildOnWindows.md index 0c4f778713..ad7283a53c 100644 --- a/docs/BuildOnWindows.md +++ b/docs/BuildOnWindows.md @@ -52,7 +52,7 @@ Install MLIR (as a part of LLVM-Project): ```shell git clone -n https://github.com/llvm/llvm-project.git # Check out a specific branch that is known to work with ONNX-MLIR. -cd llvm-project && git checkout f142f8afe21bceb00fb495468aa0b5043e98c419 && cd .. +cd llvm-project && git checkout eaa95a1c2bd38332c1a4e634595f29d22b28ffea && cd .. ``` [same-as-file]: <> (utils/build-mlir.cmd) diff --git a/third_party/stablehlo b/third_party/stablehlo index 54aa1a5717..e51fd95e5b 160000 --- a/third_party/stablehlo +++ b/third_party/stablehlo @@ -1 +1 @@ -Subproject commit 54aa1a57178251981da616b877dda1a88d840d11 +Subproject commit e51fd95e5b2c28861f22dc9d609fb2a7f002124e diff --git a/utils/clone-mlir.sh b/utils/clone-mlir.sh index e9dfb24e72..b01bbf0b1f 100644 --- a/utils/clone-mlir.sh +++ b/utils/clone-mlir.sh @@ -1,3 +1,3 @@ git clone -n https://github.com/llvm/llvm-project.git # Check out a specific branch that is known to work with ONNX-MLIR. -cd llvm-project && git checkout f142f8afe21bceb00fb495468aa0b5043e98c419 && cd .. +cd llvm-project && git checkout eaa95a1c2bd38332c1a4e634595f29d22b28ffea && cd .. From a6ebca0297ca16bc7b9d65938e3e278c67e0ffad Mon Sep 17 00:00:00 2001 From: Alexandre Eichenberger Date: Tue, 17 Sep 2024 10:04:28 -0400 Subject: [PATCH 2/2] added support for no-zero-point quantization (#2938) Signed-off-by: Alexandre Eichenberger Co-authored-by: Tung D. Le --- src/Compiler/CompilerOptions.cpp | 17 +- src/Compiler/CompilerOptions.hpp | 1 + .../ONNXToKrnl/Math/Elementwise.cpp | 10 +- .../Quantization/DynamicQuantizeLinear.cpp | 23 ++- .../Quantization/QuantizeHelper.hpp | 6 +- .../Quantization/QuantizeLinear.cpp | 24 ++- .../DequantizeLinear_with_canonicalize.mlir | 10 +- .../QuantizationWithoutZeroPoint.mlir | 176 ++++++++++++++++++ 8 files changed, 242 insertions(+), 25 deletions(-) create mode 100644 test/mlir/conversion/onnx_to_krnl/Quantization/QuantizationWithoutZeroPoint.mlir diff --git a/src/Compiler/CompilerOptions.cpp b/src/Compiler/CompilerOptions.cpp index c8e28ba457..8a5bd9d657 100644 --- a/src/Compiler/CompilerOptions.cpp +++ b/src/Compiler/CompilerOptions.cpp @@ -42,6 +42,7 @@ bool enableONNXHybridPass; // common for both std::vector functionsToDecompose; // common for both std::string opsForCall; // common for both bool disableKrnlOpFusion; // common for both +bool disableQuantZeroPoint; // common for both bool enableKrnlBufferReuse; // common for both bool disableMemRefPrefetch; // common for both EmissionTargetType emissionTarget; // onnx-mlir only @@ -195,7 +196,7 @@ static llvm::cl::list> llvm::cl::cat(OnnxMlirCommonOptions)); static llvm::cl::opt enableONNXHybridPassOpt("onnx-hybrid-pass", - llvm::cl::desc("Enable ONNX hybrid pass (default=true)\n" + llvm::cl::desc("Enable ONNX hybrid pass (default=true).\n" "Set to 'false' if you want to disable ONNX hybrid pass."), llvm::cl::location(enableONNXHybridPass), llvm::cl::init(true), llvm::cl::cat(OnnxMlirCommonOptions)); @@ -208,11 +209,20 @@ static llvm::cl::list> static llvm::cl::opt disableKrnlOpFusionOpt( "disable-krnl-op-fusion", - llvm::cl::desc("disable op fusion in onnx-to-krnl pass (default=false)\n" + llvm::cl::desc("Disable op fusion in onnx-to-krnl pass (default=false).\n" "Set to 'true' if you want to disable fusion."), llvm::cl::location(disableKrnlOpFusion), llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions)); +static llvm::cl::opt disable_quantization_zero_point( + "disable-quantization-zero-point", + llvm::cl::desc( + "Disable the use of zero-point in quantization (default=false).\n" + "Set to 'true' if you want to disable the use of zero-point\n" + "in dyn/static quantization/dequantization."), + llvm::cl::location(disableQuantZeroPoint), llvm::cl::init(false), + llvm::cl::cat(OnnxMlirCommonOptions)); + static llvm::cl::opt enableKrnlBufferReuseOpt( "enable-krnl-buffer-reuse", llvm::cl::desc("enable buffer reuse within an op in onnx-to-krnl pass" @@ -223,7 +233,7 @@ static llvm::cl::opt enableKrnlBufferReuseOpt( static llvm::cl::opt disableMemRefPrefetchOpt( "disable-memref-prefetch", - llvm::cl::desc("disable generation of memref.prefetch (default=false)\n" + llvm::cl::desc("Disable generation of memref.prefetch (default=false).\n" "Set to 'true' if you want to disable prefetch."), llvm::cl::location(disableMemRefPrefetch), llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions)); @@ -1145,7 +1155,6 @@ std::string getLibraryPath() { // as lrodataScript. std::string getToolPath( const std::string &tool, bool flag /*false by default*/) { - if (!flag) { std::string execDir = llvm::sys::path::parent_path(getExecPath()).str(); llvm::SmallString<8> toolPath(execDir); diff --git a/src/Compiler/CompilerOptions.hpp b/src/Compiler/CompilerOptions.hpp index 7bdc758129..3e0940d70b 100644 --- a/src/Compiler/CompilerOptions.hpp +++ b/src/Compiler/CompilerOptions.hpp @@ -87,6 +87,7 @@ extern bool enableONNXHybridPass; // common for both extern std::vector functionsToDecompose; // common for both extern std::string opsForCall; // common for both extern bool disableKrnlOpFusion; // common for both +extern bool disableQuantZeroPoint; // common for both extern bool enableKrnlBufferReuse; // common for both extern bool disableMemRefPrefetch; // common for both extern EmissionTargetType emissionTarget; // onnx-mlir only diff --git a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp index 1205b2c7f0..81c4b9768b 100644 --- a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp @@ -1358,9 +1358,15 @@ Value emitScalarOpFor( Value scaleFloat = scalarOperands[1]; Value zeroPointInt = scalarOperands[2]; - Value zeroPointFloat = create.math.cast(elementType, zeroPointInt); Value xFloat = create.math.cast(elementType, XInt); - Value sub = create.math.sub(xFloat, zeroPointFloat); + + Value sub; + if (!disableQuantZeroPoint && !isNoneValue(zeroPointInt)) { + Value zeroPointFloat = create.math.cast(elementType, zeroPointInt); + sub = create.math.sub(xFloat, zeroPointFloat); + } else { + sub = xFloat; + } Value res = create.math.mul(sub, scaleFloat); return res; } diff --git a/src/Conversion/ONNXToKrnl/Quantization/DynamicQuantizeLinear.cpp b/src/Conversion/ONNXToKrnl/Quantization/DynamicQuantizeLinear.cpp index 8d325c1964..5484974624 100644 --- a/src/Conversion/ONNXToKrnl/Quantization/DynamicQuantizeLinear.cpp +++ b/src/Conversion/ONNXToKrnl/Quantization/DynamicQuantizeLinear.cpp @@ -12,6 +12,7 @@ // //===----------------------------------------------------------------------===// +#include "src/Compiler/CompilerOptions.hpp" #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" #include "src/Conversion/ONNXToKrnl/Quantization/QuantizeHelper.hpp" #include "src/Dialect/Krnl/DialectBuilder.hpp" @@ -29,7 +30,7 @@ void emitDynamicQuantizationLinearScalarParameters( ConversionPatternRewriter &rewriter, Location loc, Operation *op, MemRefType inputType, MemRefType quantizedType, Value input, Value qMin, Value qMax, Value &scale, Value &zeroPoint, Value &quantizedZeroPoint, - bool enableSIMD, bool enableParallel) { + bool wantZeroPoint, bool enableSIMD, bool enableParallel) { MultiDialectBuilder create(rewriter, loc); // Types @@ -62,11 +63,15 @@ void emitDynamicQuantizationLinearScalarParameters( scale = create.math.div(xDiff, boundDiff); // Compute y_zero_point. - Value interZeroPoint = create.math.sub(qMin, create.math.div(xMin, scale)); - // Saturate zero point. - Value saturateZeroPoint = create.math.clip(interZeroPoint, qMin, qMax); - // Round zero point. - zeroPoint = create.math.round(saturateZeroPoint); + if (wantZeroPoint) { + Value interZeroPoint = create.math.sub(qMin, create.math.div(xMin, scale)); + // Saturate zero point. + Value saturateZeroPoint = create.math.clip(interZeroPoint, qMin, qMax); + // Round zero point. + zeroPoint = create.math.round(saturateZeroPoint); + } else { + zeroPoint = zero; + } quantizedZeroPoint = create.math.cast(quantizedElementType, zeroPoint); } @@ -122,15 +127,17 @@ struct ONNXDynamicQuantizeLinearOpLowering Value qMin = create.math.constant(elementType, 0.0); Value scale, zeroPoint, zeroPointInt; + bool wantZeroPoint = !disableQuantZeroPoint; emitDynamicQuantizationLinearScalarParameters(rewriter, loc, op, xMemRefType, yMemRefType, X, qMin, qMax, scale, zeroPoint, zeroPointInt, - enableSIMD, enableParallel); + wantZeroPoint, enableSIMD, enableParallel); create.krnl.store(scale, YScale); create.krnl.store(zeroPointInt, YZeroPoint); emitQuantizationLinearScalarParameters(rewriter, loc, op, xMemRefType, yMemRefType, Y, shapeHelper.getOutputDims(0), X, qMin, qMax, scale, - zeroPoint, enableSIMD, enableParallel); + zeroPoint, wantZeroPoint /*wanted one, so we have a zero point*/, + enableSIMD, enableParallel); rewriter.replaceOp(op, {Y, YScale, YZeroPoint}); onnxToKrnlSimdReport(op); diff --git a/src/Conversion/ONNXToKrnl/Quantization/QuantizeHelper.hpp b/src/Conversion/ONNXToKrnl/Quantization/QuantizeHelper.hpp index 124b854bde..96042bd799 100644 --- a/src/Conversion/ONNXToKrnl/Quantization/QuantizeHelper.hpp +++ b/src/Conversion/ONNXToKrnl/Quantization/QuantizeHelper.hpp @@ -23,7 +23,8 @@ void emitQuantizationLinearScalarParameters( mlir::Operation *op, mlir::MemRefType inputType, mlir::MemRefType quantizedType, mlir::Value alloc, DimsExpr &allocDims, mlir::Value input, mlir::Value qMin, mlir::Value qMax, mlir::Value scale, - mlir::Value zeroPoint, bool enableSIMD, bool enableParallel); + mlir::Value zeroPoint, bool hasZeroPoint, bool enableSIMD, + bool enableParallel); // Scan the input to compute scale, zeroPoint, and quantizedZeroPoint given qMin // and qMax. @@ -32,5 +33,6 @@ void emitDynamicQuantizationLinearScalarParameters( mlir::Operation *op, mlir::MemRefType inputType, mlir::MemRefType quantizedType, mlir::Value input, mlir::Value qMin, mlir::Value qMax, mlir::Value &scale, mlir::Value &zeroPoint, - mlir::Value &quantizedZeroPoint, bool enableSIMD, bool enableParallel); + mlir::Value &quantizedZeroPoint, bool wantZeroPoint, bool enableSIMD, + bool enableParallel); } // namespace onnx_mlir diff --git a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp index 715968583d..2567c4a1f4 100644 --- a/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp +++ b/src/Conversion/ONNXToKrnl/Quantization/QuantizeLinear.cpp @@ -12,6 +12,7 @@ // //===----------------------------------------------------------------------===// +#include "src/Compiler/CompilerOptions.hpp" #include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" #include "src/Dialect/Krnl/DialectBuilder.hpp" #include "src/Dialect/ONNX/DialectBuilder.hpp" @@ -26,7 +27,8 @@ namespace onnx_mlir { void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, Location loc, Operation *op, MemRefType inputType, MemRefType quantizedType, Value alloc, DimsExpr &allocDims, Value input, Value qMin, Value qMax, - Value scale, Value zeroPoint, bool enableSIMD, bool enableParallel) { + Value scale, Value zeroPoint, bool hasZeroPoint, bool enableSIMD, + bool enableParallel) { MultiDialectBuilder create( rewriter, loc); @@ -77,7 +79,11 @@ void emitQuantizationLinearScalarParameters(ConversionPatternRewriter &rewriter, // Round Value roundX = create.math.round(scaleX); // Adjust - Value adjustX = create.math.add(roundX, zeroPoint); + Value adjustX; + if (hasZeroPoint) + adjustX = create.math.add(roundX, zeroPoint); + else + adjustX = roundX; // Saturate Value saturateX = create.math.clip(adjustX, qMin, qMax); Value res = create.math.cast(quantizedElementType, saturateX); @@ -160,15 +166,21 @@ struct ONNXQuantizeLinearOpLowering // Load y_zero_point. Value zeroPoint; + bool hasZeroPoint = false; if (!isNoneValue(YZeroPoint)) { zeroPoint = create.krnl.load(adaptor.getYZeroPoint()); zeroPoint = create.math.cast(elementType, zeroPoint); - } else - zeroPoint = create.math.constant(elementType, 0.0); - + hasZeroPoint = true; + } + if (disableQuantZeroPoint) { + // TODO: should we expect to disable hasZeroPoint forcefully, or generate + // an error if we had a zero point? Right now, just forcefully assert we + // have no zero point, i.e. ignore one even if we had a zero point. + hasZeroPoint = false; + } emitQuantizationLinearScalarParameters(rewriter, loc, op, xMemRefType, yMemRefType, Y, shapeHelper.getOutputDims(0), X, qMin, qMax, scale, - zeroPoint, enableSIMD, enableParallel); + zeroPoint, hasZeroPoint, enableSIMD, enableParallel); rewriter.replaceOp(op, {Y}); onnxToKrnlSimdReport(op); diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/DequantizeLinear_with_canonicalize.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/DequantizeLinear_with_canonicalize.mlir index f6b022444a..93d38fc77a 100644 --- a/test/mlir/conversion/onnx_to_krnl/Quantization/DequantizeLinear_with_canonicalize.mlir +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/DequantizeLinear_with_canonicalize.mlir @@ -3,6 +3,8 @@ // Adding canonicalize is important here as this is the only way to check the values of the map, // which are otherwise before the function, and thus are hard to test. +// ----- + func.func @test_dequantizelinear_i8(%arg0: tensor<4xi8>, %arg1: tensor, %arg2: tensor) -> tensor<4xf32> { %0 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<4xi8>, tensor, tensor) -> tensor<4xf32> return %0 : tensor<4xf32> @@ -29,10 +31,12 @@ func.func @test_dequantizelinear_i8(%arg0: tensor<4xi8>, %arg1: tensor, %ar // ----- + func.func @test_dequantizelinear_ui8(%arg0: tensor<4xui8>, %arg1: tensor, %arg2: tensor) -> tensor<4xf32> { %0 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<4xui8>, tensor, tensor) -> tensor<4xf32> return %0 : tensor<4xf32> +// mlir2FileCheck.py // CHECK-LABEL: func.func @test_dequantizelinear_ui8 // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<4xui8>, [[PARAM_1_:%.+]]: memref, [[PARAM_2_:%.+]]: memref) -> memref<4xf32> { // CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<4xf32> @@ -42,11 +46,11 @@ func.func @test_dequantizelinear_ui8(%arg0: tensor<4xui8>, %arg1: tensor, % // CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]{{.}} : memref<4xui8> // CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref // CHECK-DAG: [[LOAD_PARAM_2_MEM_:%.+]] = krnl.load [[PARAM_2_]][] : memref -// CHECK: [[VAR_5_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_2_MEM_]] : ui8 to i8 +// CHECK: [[VAR_5_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_]] : ui8 to i8 // CHECK-DAG: [[VAR_6_:%.+]] = arith.uitofp [[VAR_5_]] : i8 to f32 -// CHECK-DAG: [[VAR_7_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_]] : ui8 to i8 +// CHECK-DAG: [[VAR_7_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_2_MEM_]] : ui8 to i8 // CHECK: [[VAR_8_:%.+]] = arith.uitofp [[VAR_7_]] : i8 to f32 -// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_8_]], [[VAR_6_]] : f32 +// CHECK: [[VAR_9_:%.+]] = arith.subf [[VAR_6_]], [[VAR_8_]] : f32 // CHECK: [[VAR_10_:%.+]] = arith.mulf [[VAR_9_]], [[LOAD_PARAM_1_MEM_]] : f32 // CHECK: krnl.store [[VAR_10_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<4xf32> // CHECK: } diff --git a/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizationWithoutZeroPoint.mlir b/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizationWithoutZeroPoint.mlir new file mode 100644 index 0000000000..e456311773 --- /dev/null +++ b/test/mlir/conversion/onnx_to_krnl/Quantization/QuantizationWithoutZeroPoint.mlir @@ -0,0 +1,176 @@ +// RUN: onnx-mlir-opt --disable-quantization-zero-point --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s + +// Test quantization with disabled zero point + +// Adding canonicalize is important here as this is the only way to check the values of the map, +// which are otherwise before the function, and thus are hard to test. + +// ----- + + +func.func @test_dequantizelinear_ui8(%arg0: tensor<4xui8>, %arg1: tensor, %arg2: tensor) -> tensor<4xf32> { + %0 = "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<4xui8>, tensor, tensor) -> tensor<4xf32> + return %0 : tensor<4xf32> + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_dequantizelinear_ui8 +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<4xui8>, [[PARAM_1_:%.+]]: memref, [[PARAM_2_:%.+]]: memref) -> memref<4xf32> { +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<4xf32> +// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 4){ +// CHECK: [[VAR_1_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]{{.}} : memref<4xui8> +// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref +// CHECK: [[VAR_4_:%.+]] = builtin.unrealized_conversion_cast [[LOAD_PARAM_0_MEM_]] : ui8 to i8 +// CHECK: [[VAR_5_:%.+]] = arith.uitofp [[VAR_4_]] : i8 to f32 +// CHECK: [[VAR_6_:%.+]] = arith.mulf [[VAR_5_]], [[LOAD_PARAM_1_MEM_]] : f32 +// CHECK: krnl.store [[VAR_6_]], [[RES_]]{{.}}[[VAR_1_]]{{.}} : memref<4xf32> +// CHECK: } +// CHECK: return [[RES_]] : memref<4xf32> +// CHECK: } +} + +// ----- + + +func.func @test_dynamic_quantize_linear(%arg0: tensor) -> (tensor, tensor, tensor) { + %y, %y_scale, %y_zero_point = "onnx.DynamicQuantizeLinear"(%arg0) : (tensor) -> (tensor, tensor, tensor) + return %y, %y_scale, %y_zero_point: tensor, tensor, tensor + +// mlir2FileCheck.py +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<()[s0] -> (s0 * 2)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0) -> (d0 * 2)> +// CHECK-LABEL: func.func @test_dynamic_quantize_linear +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref) -> (memref, memref, memref) { +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i8 +// CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 +// CHECK-DAG: [[CST_2_dot_000000_:%.+]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0xFF800000 : f32 +// CHECK-DAG: [[CST_0_2_:%.+]] = arith.constant 0x7F800000 : f32 +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[CST_2_dot_550000_:%.+]] = arith.constant 2.550000e+02 : f32 +// CHECK-DAG: [[CST_0_3_:%.+]] = arith.constant 0 : index +// CHECK: [[VAR_dim_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_3_]] : memref +// CHECK-DAG: [[RES_:%.+]] = memref.alloc([[VAR_dim_]]) {{.*}}: memref +// CHECK-DAG: [[RES_1_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[RES_2_:%.+]] = memref.alloc() : memref +// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() : memref +// CHECK: krnl.memset [[RES_3_]], [[CST_0_2_]] : memref +// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2 +// CHECK-DAG: [[VAR_dim_9_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_3_]] : memref +// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to [[VAR_dim_9_]], [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 2){ +// CHECK: [[VAR_12_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_12_]]#0, [[VAR_12_]]#1] : memref +// CHECK-DAG: [[LOAD_RES_3_MEM_:%.+]] = krnl.load [[RES_3_]][] : memref +// CHECK: [[VAR_15_:%.+]] = arith.minnumf [[LOAD_RES_3_MEM_]], [[LOAD_PARAM_0_MEM_]] : f32 +// CHECK: krnl.store [[VAR_15_]], [[RES_3_]][] : memref +// CHECK: } +// CHECK: [[RES_4_:%.+]] = memref.alloc() : memref +// CHECK: krnl.memset [[RES_4_]], [[CST_0_1_]] : memref +// CHECK-DAG: [[LOOP_1_:%.+]]:2 = krnl.define_loops 2 +// CHECK-DAG: [[VAR_dim_11_:%.+]] = memref.dim [[PARAM_0_]], [[CST_0_3_]] : memref +// CHECK: krnl.iterate([[LOOP_1_]]#0, [[LOOP_1_]]#1) with ([[LOOP_1_]]#0 -> [[I_2_:%.+]] = 0 to [[VAR_dim_11_]], [[LOOP_1_]]#1 -> [[I_3_:%.+]] = 0 to 2){ +// CHECK: [[VAR_12_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_1_]]#0, [[LOOP_1_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index) +// CHECK-DAG: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_12_1_]]#0, [[VAR_12_1_]]#1] : memref +// CHECK-DAG: [[LOAD_RES_3_MEM_1_:%.+]] = krnl.load [[RES_4_]][] : memref +// CHECK: [[VAR_15_1_:%.+]] = arith.maxnumf [[LOAD_RES_3_MEM_1_]], [[LOAD_PARAM_0_MEM_1_]] : f32 +// CHECK: krnl.store [[VAR_15_1_]], [[RES_4_]][] : memref +// CHECK: } +// CHECK-DAG: [[LOAD_RES_3_MEM_2_:%.+]] = krnl.load [[RES_3_]][] : memref +// CHECK-DAG: [[LOAD_RES_4_MEM_:%.+]] = krnl.load [[RES_4_]][] : memref +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_4_:%.+]] = arith.maxnumf [[LOAD_RES_4_MEM_]], [[CST_0_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_5_:%.+]] = arith.minnumf [[LOAD_RES_3_MEM_2_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_6_:%.+]] = arith.subf [[VAR_4_]], [[VAR_5_]] : f32 +// CHECK-DAG: [[VAR_7_:%.+]] = arith.divf [[VAR_6_]], [[CST_2_dot_550000_]] : f32 +// CHECK-DAG: [[VAR_8_:%.+]] = builtin.unrealized_conversion_cast [[CST_0_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_7_]], [[RES_1_]][] : memref +// CHECK: krnl.store [[VAR_8_]], [[RES_2_]][] : memref +// CHECK-DAG: [[VAR_9_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}} +// CHECK-DAG: [[RES_5_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[VAR_9_]], [[RES_5_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_:%.+]] = memref.reshape [[PARAM_0_]]([[RES_5_]]) : (memref, memref<1xindex>) -> memref +// CHECK-DAG: [[VAR_10_:%.+]] = affine.apply [[MAP_0_]](){{.}}[[VAR_dim_]]{{.}} +// CHECK-DAG: [[RES_6_:%.+]] = memref.alloc() {{.*}}: memref<1xindex> +// CHECK: affine.store [[VAR_10_]], [[RES_6_]][0] : memref<1xindex> +// CHECK-DAG: [[VAR_reshape_14_:%.+]] = memref.reshape [[RES_]]([[RES_]]_13) : (memref, memref<1xindex>) -> memref +// CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_4_:%.+]] = 0 to [[MAP_1_]]([[VAR_dim_]])){ +// CHECK: [[VAR_12_2_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_1_:%.+]] = krnl.load [[VAR_reshape_]]{{.}}[[VAR_12_2_]]{{.}} : memref +// CHECK: [[LOAD_RES_3_MEM_1_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_1_]], [[VAR_7_]] : f32 +// CHECK: [[VAR_15_2_:%.+]] = math.floor [[LOAD_RES_3_MEM_1_]] : f32 +// CHECK: [[VAR_16_:%.+]] = arith.subf [[LOAD_RES_3_MEM_1_]], [[VAR_15_2_]] : f32 +// CHECK-DAG: [[VAR_17_:%.+]] = arith.cmpf ogt, [[VAR_16_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_18_:%.+]] = arith.addf [[VAR_15_2_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_19_:%.+]] = arith.select [[VAR_17_]], [[VAR_18_]], [[VAR_15_2_]] : f32 +// CHECK-DAG: [[VAR_20_:%.+]] = arith.mulf [[VAR_15_2_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_21_:%.+]] = math.floor [[VAR_20_]] : f32 +// CHECK: [[VAR_22_:%.+]] = arith.mulf [[VAR_21_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_23_:%.+]] = arith.subf [[VAR_15_2_]], [[VAR_22_]] : f32 +// CHECK-DAG: [[VAR_24_:%.+]] = arith.cmpf oeq, [[VAR_23_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_25_:%.+]] = arith.addf [[VAR_15_2_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_26_:%.+]] = arith.select [[VAR_24_]], [[VAR_25_]], [[VAR_15_2_]] : f32 +// CHECK-DAG: [[VAR_27_:%.+]] = arith.cmpf oeq, [[VAR_16_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_28_:%.+]] = arith.select [[VAR_27_]], [[VAR_26_]], [[VAR_19_]] : f32 +// CHECK: [[VAR_29_:%.+]] = arith.maxnumf [[VAR_28_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_30_:%.+]] = arith.minnumf [[VAR_29_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_31_:%.+]] = arith.fptoui [[VAR_30_]] : f32 to i8 +// CHECK: [[VAR_32_:%.+]] = builtin.unrealized_conversion_cast [[VAR_31_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_32_]], [[VAR_reshape_14_]]{{.}}[[VAR_12_2_]]{{.}} : memref +// CHECK: } +// CHECK: return [[RES_]], [[RES_]]_6, [[RES_]]_7 : memref, memref, memref +// CHECK: } +} + +// ----- + + +func.func @test_quantize_linear_ui8(%arg0: tensor<6xf32>, %arg1: tensor, %arg2: tensor) -> tensor<6xui8> { + %0 = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<6xf32>, tensor, tensor) -> tensor<6xui8> + return %0 : tensor<6xui8> + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_quantize_linear_ui8 +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<6xf32>, [[PARAM_1_:%.+]]: memref, [[PARAM_2_:%.+]]: memref) -> memref<6xui8> { +// CHECK-DAG: [[CST_5_dot_000000_:%.+]] = arith.constant 5.000000e-01 : f32 +// CHECK-DAG: [[CST_2_dot_000000_:%.+]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: [[CST_1_dot_000000_:%.+]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[CST_2_dot_550000_:%.+]] = arith.constant 2.550000e+02 : f32 +// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<6xui8> +// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]][] : memref +// CHECK-DAG: [[LOOP_0_:%.+]] = krnl.define_loops 1 +// CHECK: krnl.iterate([[LOOP_0_]]) with ([[LOOP_0_]] -> [[I_0_:%.+]] = 0 to 6){ +// CHECK: [[VAR_2_:%.+]] = krnl.get_induction_var_value([[LOOP_0_]]) : (!krnl.loop) -> index +// CHECK: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_2_]]{{.}} : memref<6xf32> +// CHECK: [[VAR_4_:%.+]] = arith.divf [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : f32 +// CHECK: [[VAR_5_:%.+]] = math.floor [[VAR_4_]] : f32 +// CHECK: [[VAR_6_:%.+]] = arith.subf [[VAR_4_]], [[VAR_5_]] : f32 +// CHECK-DAG: [[VAR_7_:%.+]] = arith.cmpf ogt, [[VAR_6_]], [[CST_5_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_8_:%.+]] = arith.addf [[VAR_5_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_9_:%.+]] = arith.select [[VAR_7_]], [[VAR_8_]], [[VAR_5_]] : f32 +// CHECK-DAG: [[VAR_10_:%.+]] = arith.mulf [[VAR_5_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_11_:%.+]] = math.floor [[VAR_10_]] : f32 +// CHECK: [[VAR_12_:%.+]] = arith.mulf [[VAR_11_]], [[CST_2_dot_000000_]] : f32 +// CHECK: [[VAR_13_:%.+]] = arith.subf [[VAR_5_]], [[VAR_12_]] : f32 +// CHECK-DAG: [[VAR_14_:%.+]] = arith.cmpf oeq, [[VAR_13_]], [[CST_1_dot_000000_]] : f32 +// CHECK-DAG: [[VAR_15_:%.+]] = arith.addf [[VAR_5_]], [[CST_1_dot_000000_]] : f32 +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_16_:%.+]] = arith.select [[VAR_14_]], [[VAR_15_]], [[VAR_5_]] : f32 +// CHECK-DAG: [[VAR_17_:%.+]] = arith.cmpf oeq, [[VAR_6_]], [[CST_5_dot_000000_]] : f32 +// CHECK: [[VAR_18_:%.+]] = arith.select [[VAR_17_]], [[VAR_16_]], [[VAR_9_]] : f32 +// CHECK: [[VAR_19_:%.+]] = arith.maxnumf [[VAR_18_]], [[CST_0_dot_000000_]] : f32 +// CHECK: [[VAR_20_:%.+]] = arith.minnumf [[VAR_19_]], [[CST_2_dot_550000_]] : f32 +// CHECK: [[VAR_21_:%.+]] = arith.fptoui [[VAR_20_]] : f32 to i8 +// CHECK: [[VAR_22_:%.+]] = builtin.unrealized_conversion_cast [[VAR_21_]] : i8 to ui8 +// CHECK: krnl.store [[VAR_22_]], [[RES_]]{{.}}[[VAR_2_]]{{.}} : memref<6xui8> +// CHECK: } +// CHECK: return [[RES_]] : memref<6xui8> +// CHECK: } +} +