Fix a bug in lowering mhlo.convolution in ConvertMHLOQuantToInt pass
The current implementation uses the same tensor shape for lhs, rhs, and result, even though they may differ in the case of convolution.

PiperOrigin-RevId: 556625694
tensorflower-gardener committed Aug 14, 2023
1 parent 4c562cc commit f205faa
Showing 2 changed files with 40 additions and 4 deletions.
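
Why the shapes matter: in a convolution, lhs, rhs, and result generally have three different shapes, so cloning the result type for all three tensors produces ill-typed intermediate ops. The sketch below (illustrative only, not code from this commit) derives the shapes used by the new @uniform_quantized_convolution_static_shape test added in the second file:

// Standalone C++ sketch: computes the three tensor shapes from the test
// added in this commit (stride 1, no padding, no dilation,
// [b, 0, 1, f] x [0, 1, i, o] -> [b, 0, 1, f]).
#include <cstdio>

int main() {
  const int lhs[4] = {128, 28, 28, 1};  // input:  [batch, h, w, features]
  const int rhs[4] = {3, 3, 1, 128};    // kernel: [h, w, in, out]
  // With stride 1 and no padding/dilation, each spatial dim shrinks by
  // (kernel_size - 1); batch comes from lhs, features from rhs's 'o' dim.
  const int res[4] = {lhs[0], lhs[1] - (rhs[0] - 1), lhs[2] - (rhs[1] - 1),
                      rhs[3]};
  std::printf("lhs:    %dx%dx%dx%d\n", lhs[0], lhs[1], lhs[2], lhs[3]);
  std::printf("rhs:    %dx%dx%dx%d\n", rhs[0], rhs[1], rhs[2], rhs[3]);
  std::printf("result: %dx%dx%dx%d\n", res[0], res[1], res[2], res[3]);
  return 0;
}

The output is 128x28x28x1, 3x3x1x128, and 128x26x26x128: three distinct shapes, which is why the fix below clones each operand's own type to f32 instead of reusing the result type for everything.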
@@ -484,6 +484,10 @@ LogicalResult matchAndRewriteDotLikeOp(OpType &op, OpAdaptorType &adaptor,
     return rewriter.notifyMatchFailure(op, "Unsupported input element type.");
   }
 
+  auto lhs_float32_tensor_type =
+      op.getLhs().getType().clone(rewriter.getF32Type());
+  auto rhs_float32_tensor_type =
+      op.getRhs().getType().clone(rewriter.getF32Type());
   auto res_float32_tensor_type =
       op.getResult().getType().clone(rewriter.getF32Type());
 
@@ -508,14 +512,14 @@ LogicalResult matchAndRewriteDotLikeOp(OpType &op, OpAdaptorType &adaptor,
 
   // Offset xxx_int32_tensor according to zero points.
   Value lhs_float32_tensor = rewriter.create<mhlo::ConvertOp>(
-      op->getLoc(), res_float32_tensor_type, lhs);
+      op->getLoc(), lhs_float32_tensor_type, lhs);
   lhs_float32_tensor = rewriter.create<chlo::BroadcastSubOp>(
-      op->getLoc(), res_float32_tensor_type, lhs_float32_tensor, lhs_zero_point,
+      op->getLoc(), lhs_float32_tensor_type, lhs_float32_tensor, lhs_zero_point,
       nullptr);
   Value rhs_float32_tensor = rewriter.create<mhlo::ConvertOp>(
-      op->getLoc(), res_float32_tensor_type, rhs);
+      op->getLoc(), rhs_float32_tensor_type, rhs);
   rhs_float32_tensor = rewriter.create<chlo::BroadcastSubOp>(
-      op->getLoc(), res_float32_tensor_type, rhs_float32_tensor, rhs_zero_point,
+      op->getLoc(), rhs_float32_tensor_type, rhs_float32_tensor, rhs_zero_point,
       nullptr);
 
   // Execute the conversion target op.
@@ -350,6 +350,38 @@ func.func @uniform_quantized_convolution(%arg0: tensor<?x?x?x?xf32>, %arg1: tens
 
 // -----
 
+// CHECK-LABEL: func @uniform_quantized_convolution_static_shape
+func.func @uniform_quantized_convolution_static_shape(%arg0: tensor<128x28x28x1xf32>, %arg1: tensor<3x3x1x128xf32>) {
+  // CHECK: %[[VAL28:.*]] = mhlo.convert %[[VAL12:.*]] : (tensor<128x28x28x1xi8>) -> tensor<128x28x28x1xf32>
+  // CHECK: %[[LHS:.*]] = chlo.broadcast_subtract %[[VAL28]], %[[VAL26:.*]] : (tensor<128x28x28x1xf32>, tensor<f32>) -> tensor<128x28x28x1xf32>
+  // CHECK: %[[VAL30:.*]] = mhlo.convert %[[VAL25:.*]] : (tensor<3x3x1x128xi8>) -> tensor<3x3x1x128xf32>
+  // CHECK: %[[RHS:.*]] = chlo.broadcast_subtract %[[VAL30]], %[[VAL27:.*]] : (tensor<3x3x1x128xf32>, tensor<f32>) -> tensor<3x3x1x128xf32>
+  // CHECK: %[[VAL32:.*]] = mhlo.convolution(%[[LHS]], %[[RHS]])
+  // CHECK-SAME{LITERAL}: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]
+  // CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
+  // CHECK-SAME{LITERAL}: batch_group_count = 1 : i64, feature_group_count = 1 : i64
+  // CHECK-SAME: (tensor<128x28x28x1xf32>, tensor<3x3x1x128xf32>) -> tensor<128x26x26x128xf32>
+  // CHECK: %[[VAL43:.*]] = mhlo.clamp %[[VAL41:.*]], %[[VAL40:.*]], %[[VAL42:.*]] : (tensor<i32>, tensor<128x26x26x128xi32>, tensor<i32>) -> tensor<128x26x26x128xi32>
+  // CHECK: %[[VAL44:.*]] = mhlo.convert %[[VAL43]] : tensor<128x26x26x128xi32>
+  %0 = mhlo.uniform_quantize %arg0 : (tensor<128x28x28x1xf32>) -> tensor<128x28x28x1x!quant.uniform<i8:f32, 2.000000e+00:4>>
+  %1 = mhlo.uniform_quantize %arg1 : (tensor<3x3x1x128xf32>) -> tensor<3x3x1x128x!quant.uniform<i8:f32, 3.000000e+00:1>>
+  %2 = mhlo.convolution(%0, %1)
+    dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
+    window = {
+      stride = [1, 1], pad = [[0, 0], [0, 0]],
+      lhs_dilate = [1, 1],
+      rhs_dilate = [1, 1]
+    }
+    {
+      batch_group_count = 1 : i64,
+      feature_group_count = 1 : i64
+    } : (tensor<128x28x28x1x!quant.uniform<i8:f32, 2.000000e+00:4>>, tensor<3x3x1x128x!quant.uniform<i8:f32, 3.000000e+00:1>>)
+    -> tensor<128x26x26x128x!quant.uniform<i32:f32, 1.000000e+00:5>>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @uniform_quantize_dot_hybrid
 func.func @uniform_quantize_dot_hybrid(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
   // CHECK: %[[VAL1:.*]] = mhlo.convert %[[VAL0:.*]] : (tensor<?x?xi8>) -> tensor<?x?xf32>
