Skip to content

Commit

Permalink
[Codegen] Add f8 to f32 pass for arith.negf
Browse files Browse the repository at this point in the history
  • Loading branch information
AmosLewis committed Feb 10, 2025
1 parent 0781072 commit dc944ce
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 0 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ iree_cc_library(
"ConvertAccGEMMToGEMMPass.cpp"
"ConvertBf16ArithToF32.cpp"
"ConvertBf16ToUInt16Buffers.cpp"
"BubbleUpOrdinalOps.cpp"
"ConvertToDestinationPassingStylePass.cpp"
"ConvolutionToIGEMM.cpp"
"DecomposeAffineOpsPass.cpp"
Expand Down
83 changes: 83 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/ConvertF8ArithToF32.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===- BubbleUpOrdinalOpPass.cpp -----------------------------------------===//
//
// The workgroup count computation when using slices needs the ordinal
// annotation ops to be bubbled up as much as possible. This pass implements
// patterns to bubble these operations up.
//
//===---------------------------------------------------------------------===//

#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_CONVERTF8ARITHTOF32PASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"

namespace {

/// Replace the following sequence
///
/// ```mlir
/// %1 = arith.negf %input : vector<1x2x1x1x1x1xf8E4M3FNUZ>
/// ```
///
/// with
///
/// ```mlir
/// %0 = arith.extf %input : f8E4M3FNUZ to f32
/// %1 = arith.negf %0 : vector<1x2x1x1x1x1xf32>
/// %2 = arith.truncf %1 : vector<1x2x1x1x1x1xf8E4M3FNUZ>
/// ```
///
/// to make all the uses flow through `flow.dispatch.workload.ordinal` ops.
template <typename OpTy>
struct F8ArithToF32CastOp : public OpRewritePattern<OpTy> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {

auto inputType = op.getOperand().getType().cast<VectorType>();
if (inputType.getElementType().isF8E4M3FNUZ()) {
// Extend the input to f32
auto extendedType = inputType.clone(rewriter.getF32Type());
auto extended = rewriter.create<arith::ExtFOp>(op.getLoc(), extendedType,
op.getOperand());

// Negate the extended value
auto negated = rewriter.create<OpTy>(op.getLoc(), extended);

// Truncate back to f8E4M3FNUZ
auto truncated =
rewriter.create<arith::TruncFOp>(op.getLoc(), inputType, negated);

// Replace the original operation
rewriter.replaceOp(op, truncated.getResult());
return success();
}
return failure();
}
};

struct ConvertF8ArithToF32Pass final
: impl::ConvertF8ArithToF32PassBase<ConvertF8ArithToF32Pass> {
void runOnOperation() override;
};
} // namespace

void ConvertF8ArithToF32Pass::runOnOperation() {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.insert<F8ArithToF32CastOp<arith::NegFOp>>(context);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}
} // namespace mlir::iree_compiler
5 changes: 5 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ def ConvertBf16ToUInt16BuffersPass :
let summary = "Convert BF16 buffer ops and conversions to simulated behavior with uint16.";
}

def ConvertF8ArithToF32Pass :
Pass<"iree-codegen-convert-f8-to-f32-buffers", ""> {
let summary = "Convert f8 buffer ops and conversions to simulated behavior with f32.";
}

def ConvertToDestinationPassingStylePass :
InterfacePass<"iree-codegen-convert-to-destination-passing-style", "mlir::FunctionOpInterface"> {
let summary =
Expand Down

0 comments on commit dc944ce

Please sign in to comment.