Skip to content

Commit

Permalink
Support fake_quantize_per_tensor_affine_cachemask (llvm#3477)
Browse files Browse the repository at this point in the history
Add a new op with shape/dtypes and decompose into
`fake_quantize_per_tensor_affine` when the second result is unused.

The xfail_set change is on ONNX because torch cannot export this op to
ONNX.
  • Loading branch information
mgehre-amd authored Jun 21, 2024
1 parent 83bfb6f commit acd57a3
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 0 deletions.
28 changes: 28 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -4595,6 +4595,34 @@ def Torch_AtenFakeQuantizePerTensorAffineOp : Torch_Op<"aten.fake_quantize_per_t
}];
}

def Torch_AtenFakeQuantizePerTensorAffineCachemaskOp : Torch_Op<"aten.fake_quantize_per_tensor_affine_cachemask", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::fake_quantize_per_tensor_affine_cachemask : (Tensor, float, int, int, int) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_FloatType:$scale,
Torch_IntType:$zero_point,
Torch_IntType:$quant_min,
Torch_IntType:$quant_max
);
let results = (outs
AnyTorchOptionalTensorType:$output,
AnyTorchOptionalTensorType:$mask
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenFakeQuantizePerTensorAffineCachemaskOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 2);
}
void AtenFakeQuantizePerTensorAffineCachemaskOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 2);
}
}];
}

def Torch_AtenMaximumOp : Torch_Op<"aten.maximum", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
31 changes: 31 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6328,6 +6328,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.fake_quantize_per_tensor_affine_cachemask\"(%arg0: !torch.list<int>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" %1 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" %2 = torch.prim.TupleConstruct %0, %1 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" return %2 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.sin\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
Expand Down Expand Up @@ -10189,6 +10195,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = torch.prim.ListConstruct %int5, %int15, %int6, %int7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.fake_quantize_per_tensor_affine_cachemask\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<int, int> {\n"
" %int11 = torch.constant.int 11\n"
" %int15 = torch.constant.int 15\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int1 = torch.constant.int 1\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %2 = torch.aten.ne.int %0#1, %int15 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
" %4 = torch.prim.TupleConstruct %3, %int11 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %4 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.cosh\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n"
Expand Down
27 changes: 27 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8146,6 +8146,31 @@ class DecomposeAtenFakeQuantizePerTensorAffineOp
};
} // namespace

namespace {
// Decompose aten.fake_quantize_per_tensor_affine_cachemask
// into aten.fake_quantize_per_tensor_affine
// when the second result is unused.
class DecomposeAtenFakeQuantizePerTensorAffineCachemaskOp
: public OpRewritePattern<AtenFakeQuantizePerTensorAffineCachemaskOp> {
public:
using OpRewritePattern<
AtenFakeQuantizePerTensorAffineCachemaskOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenFakeQuantizePerTensorAffineCachemaskOp op,
PatternRewriter &rewriter) const override {
if (!op->getResult(1).use_empty())
return failure();

auto newOp = rewriter.create<AtenFakeQuantizePerTensorAffineOp>(
op.getLoc(), op->getResult(0).getType(), op.getSelf(), op.getScale(),
op.getZeroPoint(), op.getQuantMin(), op.getQuantMax());

rewriter.replaceAllUsesWith(op->getResult(0), newOp);
rewriter.eraseOp(op);
return success();
}
};
} // namespace

namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
Expand Down Expand Up @@ -8375,6 +8400,8 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeAsOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTriuOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgNormOp>(patterns);
addPatternIfTargetOpIsIllegal<
DecomposeAtenFakeQuantizePerTensorAffineCachemaskOp>(patterns);
// More specific conv ops
addPatternIfTargetOpIsIllegal<DecomposeAtenConvTbcOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenConv1dOp>(patterns);
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenRelu6Op>();
target.addIllegalOp<AtenEluOp>();
target.addIllegalOp<AtenFakeQuantizePerTensorAffineOp>();
target.addIllegalOp<AtenFakeQuantizePerTensorAffineCachemaskOp>();
target.addIllegalOp<AtenGluOp>();
target.addIllegalOp<AtenSeluOp>();
target.addIllegalOp<AtenHardswishOp>();
Expand Down
3 changes: 3 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@
"ElementwiseRreluTrainStaticModule_basic",
"ElementwiseToDtypeI64ToUI8Module_basic",
"EqIntModule_basic",
"FakeQuantizePerTensorAffineCachemaskModule_basic",
"FakeQuantizePerTensorAffineDynamicShapeModule_basic",
"FakeQuantizePerTensorAffineModule_basic",
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
Expand Down Expand Up @@ -1055,6 +1056,7 @@
"EmptyStridedModule_basic",
"EqIntModule_basic",
"ExpandAsIntModule_basic",
"FakeQuantizePerTensorAffineCachemaskModule_basic",
"FakeQuantizePerTensorAffineModule_basic",
"FakeQuantizePerTensorAffineRoundToEvenModule_basic",
"Fill_TensorFloat64WithFloat32Static_basic",
Expand Down Expand Up @@ -2400,6 +2402,7 @@
"EmptyStridedSizeIntStrideModule_basic",
"EqIntModule_basic",
"ExponentialModule_basic",
"FakeQuantizePerTensorAffineCachemaskModule_basic",
"FloatImplicitModule_basic",
"GeFloatIntModule_basic",
"GeFloatModule_basic",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ def aten〇diagonal〡shape(self: List[int], offset: int = 0, dim1: int = 0, dim
def aten〇fake_quantize_per_tensor_affine〡shape(self: List[int], scale: float, zero_point: int, quant_min: int, quant_max: int) -> List[int]:
return upstream_shape_functions.unary(self)

def aten〇fake_quantize_per_tensor_affine_cachemask〡shape(self: List[int], scale: float, zero_point: int, quant_min: int, quant_max: int) -> Tuple[List[int], List[int]]:
return (upstream_shape_functions.unary(self), upstream_shape_functions.unary(self))

def aten〇sin〡shape(self: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)

Expand Down Expand Up @@ -2162,6 +2165,14 @@ def aten〇fake_quantize_per_tensor_affine〡dtype(self_rank_dtype: Tuple[int, i
assert self_dtype != torch.bfloat16
return self_dtype

# note: fake_quantize_per_tensor_affine doesn't support "meta" device, use "cpu" instead.
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, tensor_device="cpu", scale=0.1, zero_point=0, quant_min=0, quant_max=255, error_types={torch.complex128, torch.complex64, torch.bfloat16, torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool}))
def aten〇fake_quantize_per_tensor_affine_cachemask〡dtype(self_rank_dtype: Tuple[int, int], scale: float, zero_point: int, quant_min: int, quant_max: int) -> Tuple[int, int]:
self_rank, self_dtype = self_rank_dtype
assert is_float_dtype(self_dtype)
assert self_dtype != torch.bfloat16
return (self_rank_dtype[1], torch.bool)

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇cosh〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,9 @@ def emit_with_mutating_variants(key, **kwargs):
emit(
"aten::fake_quantize_per_tensor_affine : (Tensor, float, int, int, int) -> (Tensor)"
)
emit(
"aten::fake_quantize_per_tensor_affine_cachemask : (Tensor, float, int, int, int) -> (Tensor, Tensor)"
)
emit("aten::maximum : (Tensor, Tensor) -> (Tensor)")
emit("aten::minimum : (Tensor, Tensor) -> (Tensor)")
emit("aten::mish : (Tensor) -> (Tensor)")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,28 @@ def get_quantized_mlp():
@register_test_case(module_factory=get_quantized_mlp)
def QuantizedMLP_basic(module, tu: TestUtils):
module.forward(get_quant_model_input())


# ==============================================================================


class FakeQuantizePerTensorAffineCachemaskModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([6, 4], torch.float32, True),
]
)
def forward(self, a):
return torch.ops.aten.fake_quantize_per_tensor_affine_cachemask(
a, 2.0, 0, -128, 127
)[0]


@register_test_case(module_factory=lambda: FakeQuantizePerTensorAffineCachemaskModule())
def FakeQuantizePerTensorAffineCachemaskModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 4))

0 comments on commit acd57a3

Please sign in to comment.