-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Lower vector.contract to vector.outerproduct.
Implements the lowering of vector contraction op to vector outerproduct wrapped inside an scf.for loop with iterargs to accumulate the result of each outerproduct. The idea is to exploit the AVX feature to generate optimal vector code.
- Loading branch information
Showing
6 changed files
with
576 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,234 @@ | ||
//===--------------- VectorContractToOuterproduct.cpp ------------*- C++-*-===// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This file implements lowering of vector contraction to vector outerproduct. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "TPP/Transforms/Transforms.h" | ||
#include "mlir/Dialect/Linalg/IR/Linalg.h" | ||
#include "mlir/Dialect/SCF/IR/SCF.h" | ||
#include "mlir/Dialect/Tensor/IR/Tensor.h" | ||
#include "mlir/Dialect/Vector/IR/VectorOps.h" | ||
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" | ||
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" | ||
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" | ||
#include "mlir/IR/Attributes.h" | ||
#include "mlir/IR/BuiltinAttributes.h" | ||
#include "mlir/IR/BuiltinTypeInterfaces.h" | ||
#include "mlir/IR/TypeUtilities.h" | ||
#include "mlir/Interfaces/VectorInterfaces.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
#include "llvm/ADT/TypeSwitch.h" | ||
#include "llvm/Support/Debug.h" | ||
#include <cstdint> | ||
|
||
#define DEBUG_TYPE "vector-contract-to-outerproduct" | ||
|
||
namespace mlir { | ||
namespace tpp { | ||
#define GEN_PASS_DEF_VECTORCONTRACTTOOUTERPRODUCT | ||
#include "TPP/Passes.h.inc" | ||
} // namespace tpp | ||
} // namespace mlir | ||
|
||
using namespace mlir; | ||
using namespace mlir::tpp; | ||
|
||
namespace mlir { | ||
namespace tpp { | ||
|
||
/// Returns true if the \p map is transposed for the given operand in context of | ||
/// matrix multiply. \p 'isRHS = false' flag indicates thet the check is being | ||
/// performed on LHS operand, otherwise RHS operand. | ||
bool isTransposed(AffineMap map, unsigned matrixDims, bool isRHS = false) { | ||
auto results = map.getResults(); | ||
if (results.size() != matrixDims) | ||
return false; | ||
|
||
// Check the last two dimensions for transposition | ||
auto secondLast = dyn_cast<AffineDimExpr>(results[matrixDims - 2]); | ||
auto last = dyn_cast<AffineDimExpr>(results[matrixDims - 1]); | ||
|
||
if (!secondLast || !last) | ||
return false; | ||
|
||
// For LHS, If the last dimension comes before the second last, it's | ||
// transposed. its opposite for RHS. | ||
return isRHS ? last.getPosition() > secondLast.getPosition() | ||
: last.getPosition() < secondLast.getPosition(); | ||
} | ||
|
||
struct VectorContractToOuterproductPattern | ||
: public OpRewritePattern<vector::ContractionOp> { | ||
using OpRewritePattern<vector::ContractionOp>::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(vector::ContractionOp contractOp, | ||
PatternRewriter &rewriter) const override { | ||
|
||
if (contractOp.getKind() != vector::CombiningKind::ADD) | ||
return rewriter.notifyMatchFailure( | ||
contractOp, | ||
"Unsupported combining kind, only supports ADD at the moment)"); | ||
|
||
auto maps = contractOp.getIndexingMapsArray(); | ||
if (llvm::any_of( | ||
maps, [](AffineMap map) { return !map.isProjectedPermutation(); })) | ||
return rewriter.notifyMatchFailure(contractOp, "Unexpected map"); | ||
|
||
// Check if this is a matrix multiply | ||
auto iteratorTypes = contractOp.getIteratorTypesArray(); | ||
if (iteratorTypes.size() < 3 || | ||
iteratorTypes[iteratorTypes.size() - 3] != | ||
vector::IteratorType::parallel || | ||
iteratorTypes[iteratorTypes.size() - 2] != | ||
vector::IteratorType::parallel || | ||
iteratorTypes[iteratorTypes.size() - 1] != | ||
vector::IteratorType::reduction) | ||
return rewriter.notifyMatchFailure(contractOp, "Not a gemm"); | ||
|
||
Location loc = contractOp.getLoc(); | ||
Value lhs = contractOp.getLhs(); | ||
Value rhs = contractOp.getRhs(); | ||
Value acc = contractOp.getAcc(); | ||
|
||
// Find the original tensor operands | ||
auto lhsDefiningOp = lhs.getDefiningOp<vector::TransferReadOp>(); | ||
auto rhsDefiningOp = rhs.getDefiningOp<vector::TransferReadOp>(); | ||
if (!lhsDefiningOp || !rhsDefiningOp) | ||
return failure(); | ||
|
||
Value lhsTensor = lhsDefiningOp.getSource(); | ||
Value rhsTensor = rhsDefiningOp.getSource(); | ||
|
||
auto lhsType = cast<ShapedType>(lhsTensor.getType()); | ||
auto rhsType = cast<ShapedType>(rhsTensor.getType()); | ||
auto accType = cast<VectorType>(acc.getType()); | ||
|
||
// Handle 3D subviews | ||
auto mapLHS = maps[0]; | ||
auto mapRHS = maps[1]; | ||
unsigned matrixDimsLHS = mapLHS.getNumResults(); | ||
unsigned matrixDimsRHS = mapRHS.getNumResults(); | ||
|
||
int64_t M = accType.getDimSize(0); | ||
int64_t N = accType.getDimSize(1); | ||
int64_t K = !isTransposed(mapLHS, matrixDimsLHS) | ||
? lhsType.getDimSize(lhsType.getRank() - 1) | ||
: lhsType.getDimSize(lhsType.getRank() - 2); | ||
|
||
// Create constants | ||
Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0); | ||
Value c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1); | ||
Value cK = rewriter.create<arith::ConstantIndexOp>(loc, K); | ||
|
||
auto elementType = lhsType.getElementType(); | ||
FloatType floatType = cast<FloatType>(elementType); | ||
Value f0 = rewriter.create<arith::ConstantFloatOp>( | ||
loc, APFloat::getZero(floatType.getFloatSemantics()), floatType); | ||
|
||
// Create the outer scf.for loop | ||
auto forOp = rewriter.create<scf::ForOp>( | ||
loc, c0, cK, c1, ValueRange{acc}, | ||
[&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, | ||
ValueRange iterArgs) { | ||
// Prepare indices and map to iterate over rows/colums and read | ||
// slices of lhs/rhs input operands. | ||
SmallVector<Value, 3> lhsIndices, rhsIndices; | ||
AffineMap lhsMap, rhsMap; | ||
for (int i = 0; i < lhsType.getRank() - 2; ++i) | ||
lhsIndices.push_back(c0); | ||
// LHS operand | ||
if (!isTransposed(mapLHS, matrixDimsLHS)) { | ||
// If not transposed, iterate over colums and read each column | ||
// using map. | ||
lhsIndices.push_back(c0); | ||
lhsIndices.push_back(iv); | ||
lhsMap = AffineMap::get(lhsType.getRank(), 0, | ||
{nestedBuilder.getAffineDimExpr(0)}, | ||
nestedBuilder.getContext()); | ||
} else { | ||
// If transposed, iterate over rows and read each row with default | ||
// map. | ||
lhsIndices.push_back(iv); | ||
lhsIndices.push_back(c0); | ||
lhsMap = AffineMap::get(lhsType.getRank(), 0, | ||
{nestedBuilder.getAffineDimExpr(1)}, | ||
nestedBuilder.getContext()); | ||
} | ||
|
||
for (int i = 0; i < rhsType.getRank() - 2; ++i) | ||
rhsIndices.push_back(c0); | ||
// RHS operand | ||
if (!isTransposed(mapRHS, matrixDimsRHS, true)) { | ||
// If not transposed, iterate over rows and read each row using | ||
// default map. | ||
rhsIndices.push_back(iv); | ||
rhsIndices.push_back(c0); | ||
rhsMap = AffineMap::get(rhsType.getRank(), 0, | ||
{nestedBuilder.getAffineDimExpr(1)}, | ||
nestedBuilder.getContext()); | ||
} else { | ||
// If transposed, iterate over columns and read each column with | ||
// default map. | ||
rhsIndices.push_back(c0); | ||
rhsIndices.push_back(iv); | ||
rhsMap = AffineMap::get(rhsType.getRank(), 0, | ||
{nestedBuilder.getAffineDimExpr(0)}, | ||
nestedBuilder.getContext()); | ||
} | ||
|
||
// Read vector slices using TransferReadOp | ||
auto lhsSlice = nestedBuilder.create<vector::TransferReadOp>( | ||
nestedLoc, VectorType::get({M}, lhsType.getElementType()), | ||
lhsTensor, lhsIndices, AffineMapAttr::get(lhsMap), f0, Value(), | ||
rewriter.getBoolArrayAttr({true})); | ||
|
||
auto rhsSlice = nestedBuilder.create<vector::TransferReadOp>( | ||
nestedLoc, VectorType::get({N}, rhsType.getElementType()), | ||
rhsTensor, rhsIndices, rhsMap, f0, Value(), | ||
rewriter.getBoolArrayAttr({true})); | ||
|
||
// Perform outer product | ||
auto outerProduct = nestedBuilder.create<vector::OuterProductOp>( | ||
nestedLoc, accType, lhsSlice, rhsSlice, iterArgs[0], | ||
vector::CombiningKind::ADD); | ||
|
||
// Yield the result | ||
nestedBuilder.create<scf::YieldOp>(nestedLoc, | ||
ValueRange{outerProduct}); | ||
}); | ||
|
||
// Replace the original contraction with the result of the loop | ||
rewriter.replaceOp(contractOp, forOp.getResults()); | ||
|
||
return success(); | ||
} | ||
}; | ||
|
||
struct VectorContractToOuterproduct | ||
: public tpp::impl::VectorContractToOuterproductBase< | ||
VectorContractToOuterproduct> { | ||
|
||
using VectorContractToOuterproductBase::VectorContractToOuterproductBase; | ||
|
||
void runOnOperation() override { | ||
auto funcOp = getOperation(); | ||
MLIRContext *context = &getContext(); | ||
|
||
RewritePatternSet patterns(context); | ||
patterns.add<VectorContractToOuterproductPattern>(context); | ||
|
||
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { | ||
signalPassFailure(); | ||
} | ||
} | ||
}; | ||
|
||
} // namespace tpp | ||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
// RUN: tpp-opt %s | tpp-run -e entry --entry-point-result=void -seed 123 -print > %t.1 | ||
// RUN: tpp-opt %s --vector-contract-to-outerproduct | tpp-run -e entry --entry-point-result=void -seed 123 -print > %t.2 | ||
// RUN: diff %t.1 %t.2 | FileCheck %s --check-prefix=DIFF --allow-empty | ||
|
||
// RUN: tpp-opt %s | tpp-run -e permA --entry-point-result=void -seed 123 -print > %t.1 | ||
// RUN: tpp-opt %s --vector-contract-to-outerproduct | tpp-run -e permA --entry-point-result=void -seed 123 -print > %t.2 | ||
// RUN: diff %t.1 %t.2 | FileCheck %s --check-prefix=DIFF-PERMA --allow-empty | ||
|
||
// RUN: tpp-opt %s | tpp-run -e permB --entry-point-result=void -seed 123 -print > %t.1 | ||
// RUN: tpp-opt %s --vector-contract-to-outerproduct | tpp-run -e permB --entry-point-result=void -seed 123 -print > %t.2 | ||
// RUN: diff %t.1 %t.2 | FileCheck %s --check-prefix=DIFF-PERMA --allow-empty | ||
|
||
// RUN: tpp-opt %s | tpp-run -e permAB --entry-point-result=void -seed 123 -print > %t.1 | ||
// RUN: tpp-opt %s --vector-contract-to-outerproduct | tpp-run -e permAB --entry-point-result=void -seed 123 -print > %t.2 | ||
// RUN: diff %t.1 %t.2 | FileCheck %s --check-prefix=DIFF-PERMAB --allow-empty | ||
|
||
|
||
// DIFF-NOT: {{.}} | ||
#map = affine_map<(d0, d1, d2) -> (d0, d2)> | ||
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> | ||
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> | ||
|
||
func.func @entry(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>, %arg2: tensor<16x16xf32>) -> tensor<16x16xf32> { | ||
%c0 = arith.constant 0 : index | ||
%cst = arith.constant 0.000000e+00 : f32 | ||
%0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<16x16xf32>, vector<16x16xf32> | ||
%1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<16x16xf32>, vector<16x16xf32> | ||
%2 = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<16x16xf32>, vector<16x16xf32> | ||
%3 = vector.contract {indexing_maps = [#map, #map1, #map2], | ||
iterator_types = ["parallel", "parallel", "reduction"], | ||
kind = #vector.kind<add>} %0, %1, %2 | ||
: vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32> | ||
%4 = vector.transfer_write %3, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, tensor<16x16xf32> | ||
return %4 : tensor<16x16xf32> | ||
} | ||
|
||
// ----- | ||
|
||
// DIFF-PERMA-NOT: {{.}} | ||
#permA0 = affine_map<(d0, d1, d2) -> (d2, d0)> | ||
#permA1 = affine_map<(d0, d1, d2) -> (d2, d1)> | ||
#permA2 = affine_map<(d0, d1, d2) -> (d0, d1)> | ||
|
||
func.func @permA(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> tensor<4x4xf32> { | ||
%c0 = arith.constant 0 : index | ||
%cst = arith.constant 0.000000e+00 : f32 | ||
%0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32> | ||
%1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32> | ||
%2 = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32> | ||
%3 = vector.contract {indexing_maps = [#permA0, #permA1, #permA2], | ||
iterator_types = ["parallel", "parallel", "reduction"], | ||
kind = #vector.kind<add>} %0, %1, %2 | ||
: vector<4x4xf32>, vector<4x4xf32> into vector<4x4xf32> | ||
%4 = vector.transfer_write %3, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, tensor<4x4xf32> | ||
return %4 : tensor<4x4xf32> | ||
} | ||
|
||
// ----- | ||
|
||
// DIFF-PERMB-NOT: {{.}} | ||
#permB0 = affine_map<(d0, d1, d2) -> (d0, d2)> | ||
#permB1 = affine_map<(d0, d1, d2) -> (d1, d2)> | ||
#permB2 = affine_map<(d0, d1, d2) -> (d0, d1)> | ||
|
||
func.func @permB(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> tensor<4x4xf32> { | ||
%c0 = arith.constant 0 : index | ||
%cst = arith.constant 0.000000e+00 : f32 | ||
%0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32> | ||
%1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32> | ||
%2 = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32> | ||
%3 = vector.contract {indexing_maps = [#permB0, #permB1, #permB2], | ||
iterator_types = ["parallel", "parallel", "reduction"], | ||
kind = #vector.kind<add>} %0, %1, %2 | ||
: vector<4x4xf32>, vector<4x4xf32> into vector<4x4xf32> | ||
%4 = vector.transfer_write %3, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, tensor<4x4xf32> | ||
return %4 : tensor<4x4xf32> | ||
} | ||
|
||
// ----- | ||
|
||
// DIFF-PERMAB-NOT: {{.}} | ||
#permAB0 = affine_map<(d0, d1, d2) -> (d2, d0)> | ||
#permAB1 = affine_map<(d0, d1, d2) -> (d1, d2)> | ||
#permAB2 = affine_map<(d0, d1, d2) -> (d0, d1)> | ||
|
||
func.func @permAB(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> tensor<4x4xf32> { | ||
%c0 = arith.constant 0 : index | ||
%cst = arith.constant 0.000000e+00 : f32 | ||
%0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32> | ||
%1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32> | ||
%2 = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32> | ||
%3 = vector.contract {indexing_maps = [#permAB0, #permAB1, #permAB2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %0, %1, %2 : vector<4x4xf32>, vector<4x4xf32> into vector<4x4xf32> | ||
%4 = vector.transfer_write %3, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<4x4xf32>, tensor<4x4xf32> | ||
return %4 : tensor<4x4xf32> | ||
} | ||
|
||
// ----- |
Oops, something went wrong.