From e568f7e99980722f22e988cd2fe1b437be614ce5 Mon Sep 17 00:00:00 2001
From: Ramiro Leal-Cavazos
Date: Wed, 29 Nov 2023 17:43:09 +0000
Subject: [PATCH] Move handling of integer signedness to the backend
 conversions (#2597)

The function `getTypeForScalarType` currently takes an argument to
specify the signedness of integer types. This leaks backend-specific
requirements into the torch dialect world. Because
`getTypeForScalarType` is a utility function for the torch dialect, it
should only produce types that match the sign conventions used by
PyTorch (regular integers are signed and unsigned integers are
unsigned).

This commit removes the signedness argument from
`getTypeForScalarType` and moves the backend-specific handling of
integer types to the backend code.
---
 .../torch-mlir/Dialect/Torch/Utils/Utils.h    |  5 ++---
 .../TorchToLinalg/TensorConstructors.cpp      | 12 ++++++------
 .../TorchToLinalg/Uncategorized.cpp           |  6 +++---
 lib/Conversion/TorchToLinalg/Utils.cpp        | 16 +++++++++++++++-
 lib/Conversion/TorchToLinalg/Utils.h          |  7 +++++++
 lib/Conversion/TorchToStablehlo/Basic.cpp     |  9 +++++++--
 lib/Dialect/Torch/Utils/Utils.cpp             |  9 ++++-----
 .../test_suite/constant_alloc.py              | 19 +++++++++++++++++++
 8 files changed, 63 insertions(+), 20 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
index 14622f654139..0e4c2b0a0ab7 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -26,9 +26,8 @@
 bool getListConstructElements(Value v, SmallVectorImpl<Value> &elems);
 std::optional<int64_t> matchLegalConstantIndexIntoListOfSize(Value v,
                                                              int64_t length);
 torch_upstream::ScalarType getScalarTypeForType(Type type);
-FailureOr<Type> getTypeForScalarType(
-    MLIRContext *context, torch_upstream::ScalarType dtypeInt,
-    mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed);
+FailureOr<Type> getTypeForScalarType(MLIRContext *context,
+                                     torch_upstream::ScalarType dtypeInt);
 Type getTypeForTorchType(
     MLIRContext *context, Type type,
diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
index 7e73fabd8e9f..65aa22711800 100644
--- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
+++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
@@ -127,9 +127,9 @@ class ConvertConstantTensorAllocOp : public OpConversionPattern<OpTy> {
       if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
         return rewriter.notifyMatchFailure(
             op, "unimplemented: dtype must be a constant integer or none");
-      FailureOr<Type> maybeResultElementType = getTypeForScalarType(
-          op->getContext(), (torch_upstream::ScalarType)dtypeInt,
-          IntegerType::Signless);
+      FailureOr<Type> maybeResultElementType =
+          torch_to_linalg::getBackendTypeForScalarType(
+              op->getContext(), (torch_upstream::ScalarType)dtypeInt);
       if (failed(maybeResultElementType)) {
         return rewriter.notifyMatchFailure(
             op, "unable to convert `dtypeInt` to builtin type");
       }
@@ -233,9 +233,9 @@ class ConvertAtenEmptyMemoryFormatOp
       if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt)))
         return rewriter.notifyMatchFailure(
             op, "unimplemented: dtype must be a constant integer or none");
-      FailureOr<Type> maybeResultElementType = getTypeForScalarType(
-          op->getContext(), (torch_upstream::ScalarType)dtypeInt,
-          IntegerType::Signless);
+      FailureOr<Type> maybeResultElementType =
+          torch_to_linalg::getBackendTypeForScalarType(
+              op->getContext(), (torch_upstream::ScalarType)dtypeInt);
       if (failed(maybeResultElementType)) {
         return rewriter.notifyMatchFailure(
             op, "unable to convert `dtypeInt` to builtin type");
diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index c9f5364ba246..9c99b5e52982 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -1057,9 +1057,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
       atenToDtype.emitError("unimplemented: dtype must be a constant integer");
       return nullptr;
     }
-    FailureOr<Type> maybeResultElementType = getTypeForScalarType(
-        atenToDtype->getContext(), (torch_upstream::ScalarType)dtypeInt,
-        IntegerType::Signless);
+    FailureOr<Type> maybeResultElementType =
+        torch_to_linalg::getBackendTypeForScalarType(
+            atenToDtype->getContext(), (torch_upstream::ScalarType)dtypeInt);
     if (failed(maybeResultElementType)) {
       atenToDtype.emitError("unable to convert `dtypeInt` to builtin type");
       return nullptr;
diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp
index a666ca30b02f..84bf8a83e449 100644
--- a/lib/Conversion/TorchToLinalg/Utils.cpp
+++ b/lib/Conversion/TorchToLinalg/Utils.cpp
@@ -20,7 +20,6 @@
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
-#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"

 using namespace mlir;
@@ -546,3 +545,18 @@ Value torch_to_linalg::convertTensorToElementType(OpBuilder &b, Location loc,
   return torch_to_linalg::createElementwiseLinalgGeneric(
       b, loc, {tensor}, elementType, dtypePromoteBody);
 }
+
+FailureOr<Type> torch_to_linalg::getBackendTypeForScalarType(
+    MLIRContext *context, torch_upstream::ScalarType dtypeInt) {
+  FailureOr<Type> maybeType =
+      getTypeForScalarType(context, (torch_upstream::ScalarType)dtypeInt);
+  if (failed(maybeType)) {
+    return failure();
+  }
+  Type type = *maybeType;
+  // The linalg-on-tensors backend currently expects integers to be signless.
+  if (auto intType = type.dyn_cast<mlir::IntegerType>()) {
+    type = IntegerType::get(context, intType.getWidth(), IntegerType::Signless);
+  }
+  return type;
+}
diff --git a/lib/Conversion/TorchToLinalg/Utils.h b/lib/Conversion/TorchToLinalg/Utils.h
index 3bee8d642533..134fbeca46dc 100644
--- a/lib/Conversion/TorchToLinalg/Utils.h
+++ b/lib/Conversion/TorchToLinalg/Utils.h
@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//

 #include "mlir/Transforms/DialectConversion.h"
+#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"

 namespace mlir {
 namespace torch {
@@ -88,6 +89,12 @@ Value removeSizeInformation(OpBuilder &b, Location loc, Value tensor);
 Value convertTensorToElementType(OpBuilder &b, Location loc, Value tensor,
                                  Type elementType);

+// Convert a scalar type to the corresponding builtin type in the
+// linalg-on-tensors backend.
+FailureOr<Type>
+getBackendTypeForScalarType(MLIRContext *context,
+                            torch_upstream::ScalarType dtypeInt);
+
 } // namespace torch_to_linalg
 } // namespace torch
 } // namespace mlir
diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp
index 979182ae7fd7..73710997709a 100644
--- a/lib/Conversion/TorchToStablehlo/Basic.cpp
+++ b/lib/Conversion/TorchToStablehlo/Basic.cpp
@@ -1672,13 +1672,18 @@ LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
       return rewriter.notifyMatchFailure(
           op, "unimplemented: dtype must be a constant integer or none");
     FailureOr<Type> maybeResultElementType = getTypeForScalarType(
-        op->getContext(), (torch_upstream::ScalarType)dtypeInt,
-        IntegerType::Signless);
+        op->getContext(), (torch_upstream::ScalarType)dtypeInt);
     if (failed(maybeResultElementType)) {
       return rewriter.notifyMatchFailure(
           op, "unable to convert `dtypeInt` to builtin type");
     }
     resultElementType = *maybeResultElementType;
+    // The stablehlo backend expects signed integers to be signless.
+    if (resultElementType.isSignedInteger()) {
+      resultElementType = IntegerType::get(
+          op->getContext(), resultElementType.getIntOrFloatBitWidth(),
+          IntegerType::Signless);
+    }
   }

   // Create an uninitialized tensor of `resultSize` shape.
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index d4f59deced6a..0e0fc0f819da 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -85,17 +85,16 @@ Type Torch::getTypeForTorchType(

 FailureOr<Type>
 Torch::getTypeForScalarType(MLIRContext *context,
-                            torch_upstream::ScalarType dtypeInt,
-                            mlir::IntegerType::SignednessSemantics signedness) {
+                            torch_upstream::ScalarType dtypeInt) {
   switch (dtypeInt) {
   case torch_upstream::ScalarType::Float:
     return Float32Type::get(context);
   case torch_upstream::ScalarType::Double:
     return Float64Type::get(context);
   case torch_upstream::ScalarType::Long:
-    return IntegerType::get(context, 64, signedness);
+    return IntegerType::get(context, 64, mlir::IntegerType::Signed);
   case torch_upstream::ScalarType::Int:
-    return IntegerType::get(context, 32, signedness);
+    return IntegerType::get(context, 32, mlir::IntegerType::Signed);
   case torch_upstream::ScalarType::Bool:
     return IntegerType::get(context, 1);
   case torch_upstream::ScalarType::BFloat16:
@@ -105,7 +104,7 @@ Torch::getTypeForScalarType(MLIRContext *context,
   case torch_upstream::ScalarType::Byte:
     return mlir::IntegerType::get(context, 8, mlir::IntegerType::Unsigned);
   case torch_upstream::ScalarType::Char:
-    return mlir::IntegerType::get(context, 8, signedness);
+    return mlir::IntegerType::get(context, 8, mlir::IntegerType::Signed);
   case torch_upstream::ScalarType::ComplexHalf:
     return mlir::ComplexType::get(Float16Type::get(context));
   case torch_upstream::ScalarType::ComplexFloat:
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py
index 9277e987e4e2..eb0143b9d06b 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/constant_alloc.py
@@ -451,6 +451,25 @@ def EmptyModule_int(module, tu: TestUtils):
     module.forward()


+class EmptyUInt8Module(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        empty = torch.ops.aten.empty([1], dtype=torch.uint8)
+        return torch.ops.aten.zeros_like(empty).to(torch.int8)
+
+
+@register_test_case(module_factory=lambda: EmptyUInt8Module())
+def EmptyModule_uint8(module, tu: TestUtils):
+    module.forward()
+
+
 class EmptyFloatModule(torch.nn.Module):

     def __init__(self):
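
The signedness-erasure step this patch moves into the backends is small enough to show standalone. Below is a minimal sketch of what the new `torch_to_linalg::getBackendTypeForScalarType` helper does for integer types, assuming only the MLIR builtin-type C++ API; the free function `convertToSignless` is illustrative and does not appear in the patch. The torch dialect keeps PyTorch's conventions (`torch.int64` maps to signed `si64`, `torch.uint8` to unsigned `ui8`), and the linalg-on-tensors backend then maps both to signless builtin integers of the same width; the stablehlo hunk above narrows the same conversion to signed integers only.

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

// Illustrative only: erase torch-dialect signedness. Integer types become
// signless builtin integers of the same width; other types pass through.
static Type convertToSignless(MLIRContext *context, Type type) {
  if (auto intType = type.dyn_cast<IntegerType>())
    return IntegerType::get(context, intType.getWidth(),
                            IntegerType::Signless);
  return type;
}

// Hypothetical usage: si64 -> i64, ui8 -> i8, f32 -> f32 (unchanged).
//   Type si64 = IntegerType::get(context, 64, IntegerType::Signed);
//   Type i64 = convertToSignless(context, si64);

The new `EmptyModule_uint8` test exercises exactly this boundary end to end: `aten.empty` with `dtype=torch.uint8` produces a `ui8` tensor in the torch dialect, which the backend must lower without leaking the unsigned convention into the signless linalg world.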