From f205faa8c27693b3f6737157134bf0e92f1b7604 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sun, 13 Aug 2023 18:22:21 -0700 Subject: [PATCH] Fix a bug in lowering mhlo.convolution in ConvertMHLOQuantToInt pass The current impl use the same tensor shape for lhs, rhs and result, even if they might be different in case of Convolution. PiperOrigin-RevId: 556625694 --- .../bridge/convert_mhlo_quant_to_int.cc | 12 ++++--- .../bridge/convert-mhlo-quant-to-int.mlir | 32 +++++++++++++++++++ 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc index bdf2ed63936cd9..dfcdacd7062c21 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc @@ -484,6 +484,10 @@ LogicalResult matchAndRewriteDotLikeOp(OpType &op, OpAdaptorType &adaptor, return rewriter.notifyMatchFailure(op, "Unsupported input element type."); } + auto lhs_float32_tensor_type = + op.getLhs().getType().clone(rewriter.getF32Type()); + auto rhs_float32_tensor_type = + op.getRhs().getType().clone(rewriter.getF32Type()); auto res_float32_tensor_type = op.getResult().getType().clone(rewriter.getF32Type()); @@ -508,14 +512,14 @@ LogicalResult matchAndRewriteDotLikeOp(OpType &op, OpAdaptorType &adaptor, // Offset xxx_int32_tensor according to zero points. Value lhs_float32_tensor = rewriter.create( - op->getLoc(), res_float32_tensor_type, lhs); + op->getLoc(), lhs_float32_tensor_type, lhs); lhs_float32_tensor = rewriter.create( - op->getLoc(), res_float32_tensor_type, lhs_float32_tensor, lhs_zero_point, + op->getLoc(), lhs_float32_tensor_type, lhs_float32_tensor, lhs_zero_point, nullptr); Value rhs_float32_tensor = rewriter.create( - op->getLoc(), res_float32_tensor_type, rhs); + op->getLoc(), rhs_float32_tensor_type, rhs); rhs_float32_tensor = rewriter.create( - op->getLoc(), res_float32_tensor_type, rhs_float32_tensor, rhs_zero_point, + op->getLoc(), rhs_float32_tensor_type, rhs_float32_tensor, rhs_zero_point, nullptr); // Execute the conversion target op. diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir index 8a9cec3af57e2d..ef8a7e0d8a2241 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir @@ -350,6 +350,38 @@ func.func @uniform_quantized_convolution(%arg0: tensor, %arg1: tens // ----- +// CHECK-LABEL: func @uniform_quantized_convolution_static_shape +func.func @uniform_quantized_convolution_static_shape(%arg0: tensor<128x28x28x1xf32>, %arg1: tensor<3x3x1x128xf32>) { + // CHECK: %[[VAL28:.*]] = mhlo.convert %[[VAL12:.*]] : (tensor<128x28x28x1xi8>) -> tensor<128x28x28x1xf32> + // CHECK: %[[LHS:.*]] = chlo.broadcast_subtract %[[VAL28]], %[[VAL26:.*]] : (tensor<128x28x28x1xf32>, tensor) -> tensor<128x28x28x1xf32> + // CHECK: %[[VAL30:.*]] = mhlo.convert %[[VAL25:.*]] : (tensor<3x3x1x128xi8>) -> tensor<3x3x1x128xf32> + // CHECK: %[[RHS:.*]] = chlo.broadcast_subtract %[[VAL30]], %[[VAL27:.*]] : (tensor<3x3x1x128xf32>, tensor) -> tensor<3x3x1x128xf32> + // CHECK: %[[VAL32:.*]] = mhlo.convolution(%[[LHS]], %[[RHS]]) + // CHECK-SAME{LITERAL}: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] + // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} + // CHECK-SAME{LITERAL}: batch_group_count = 1 : i64, feature_group_count = 1 : i64 + // CHECK-SAME: (tensor<128x28x28x1xf32>, tensor<3x3x1x128xf32>) -> tensor<128x26x26x128xf32> + // CHECK: %[[VAL43:.*]] = mhlo.clamp %[[VAL41:.*]], %[[VAL40:.*]], %[[VAL42:.*]] : (tensor, tensor<128x26x26x128xi32>, tensor) -> tensor<128x26x26x128xi32> + // CHECK: %[[VAL44:.*]] = mhlo.convert %[[VAL43]] : tensor<128x26x26x128xi32> + %0 = mhlo.uniform_quantize %arg0 : (tensor<128x28x28x1xf32>) -> tensor<128x28x28x1x!quant.uniform> + %1 = mhlo.uniform_quantize %arg1 : (tensor<3x3x1x128xf32>) -> tensor<3x3x1x128x!quant.uniform> + %2 = mhlo.convolution(%0, %1) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], pad = [[0, 0], [0, 0]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<128x28x28x1x!quant.uniform>, tensor<3x3x1x128x!quant.uniform>) + -> tensor<128x26x26x128x!quant.uniform> + return +} + +// ----- + // CHECK-LABEL: func @uniform_quantize_dot_hybrid func.func @uniform_quantize_dot_hybrid(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK: %[[VAL1:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor) -> tensor