diff --git a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt index 1b67c5db261e2..a0334ad25d985 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt @@ -89,6 +89,7 @@ iree_cc_library( "ConvertAccGEMMToGEMMPass.cpp" "ConvertBf16ArithToF32.cpp" "ConvertBf16ToUInt16Buffers.cpp" + "BubbleUpOrdinalOps.cpp" "ConvertToDestinationPassingStylePass.cpp" "ConvolutionToIGEMM.cpp" "DecomposeAffineOpsPass.cpp" diff --git a/compiler/src/iree/compiler/Codegen/Common/ConvertF8ArithToF32.cpp b/compiler/src/iree/compiler/Codegen/Common/ConvertF8ArithToF32.cpp new file mode 100644 index 0000000000000..59440d1f3f4ce --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/ConvertF8ArithToF32.cpp @@ -0,0 +1,83 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +//===- BubbleUpOrdinalOpPass.cpp -----------------------------------------===// +// +// The workgroup count computation when using slices needs the ordinal +// annotation ops to be bubbled up as much as possible. This pass implements +// patterns to bubble these operations up. +// +//===---------------------------------------------------------------------===// + +#include "iree/compiler/Codegen/Common/Passes.h" +#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir::iree_compiler { + +#define GEN_PASS_DEF_CONVERTF8ARITHTOF32PASS +#include "iree/compiler/Codegen/Common/Passes.h.inc" + +namespace { + +/// Replace the following sequence +/// +/// ```mlir +/// %1 = arith.negf %input : vector<1x2x1x1x1x1xf8E4M3FNUZ> +/// ``` +/// +/// with +/// +/// ```mlir +/// %0 = arith.extf %input : f8E4M3FNUZ to f32 +/// %1 = arith.negf %0 : vector<1x2x1x1x1x1xf32> +/// %2 = arith.truncf %1 : vector<1x2x1x1x1x1xf8E4M3FNUZ> +/// ``` +/// +/// to make all the uses flow through `flow.dispatch.workload.ordinal` ops. +template +struct F8ArithToF32CastOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + + auto inputType = op.getOperand().getType().cast(); + if (inputType.getElementType().isF8E4M3FNUZ()) { + // Extend the input to f32 + auto extendedType = inputType.clone(rewriter.getF32Type()); + auto extended = rewriter.create(op.getLoc(), extendedType, + op.getOperand()); + + // Negate the extended value + auto negated = rewriter.create(op.getLoc(), extended); + + // Truncate back to f8E4M3FNUZ + auto truncated = + rewriter.create(op.getLoc(), inputType, negated); + + // Replace the original operation + rewriter.replaceOp(op, truncated.getResult()); + return success(); + } + return failure(); + } +}; + +struct ConvertF8ArithToF32Pass final + : impl::ConvertF8ArithToF32PassBase { + void runOnOperation() override; +}; +} // namespace + +void ConvertF8ArithToF32Pass::runOnOperation() { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + patterns.insert>(context); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { + return signalPassFailure(); + } +} +} // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Codegen/Common/Passes.td b/compiler/src/iree/compiler/Codegen/Common/Passes.td index 811fa9ccc588b..e937984ffd217 100644 --- a/compiler/src/iree/compiler/Codegen/Common/Passes.td +++ b/compiler/src/iree/compiler/Codegen/Common/Passes.td @@ -80,6 +80,11 @@ def ConvertBf16ToUInt16BuffersPass : let summary = "Convert BF16 buffer ops and conversions to simulated behavior with uint16."; } +def ConvertF8ArithToF32Pass : + Pass<"iree-codegen-convert-f8-to-f32-buffers", ""> { + let summary = "Convert f8 buffer ops and conversions to simulated behavior with f32."; +} + def ConvertToDestinationPassingStylePass : InterfacePass<"iree-codegen-convert-to-destination-passing-style", "mlir::FunctionOpInterface"> { let summary =