Skip to content

Commit

Permalink
Lower vector.contract to vector.outerproduct.
Browse files Browse the repository at this point in the history
Implements the lowering of the vector contraction op to vector outerproduct ops
wrapped inside an scf.for loop whose iter_args accumulate the result of each outerproduct.
The idea is to exploit AVX features to generate optimal vector code.
  • Loading branch information
shahidact committed Sep 24, 2024
1 parent 7b521f2 commit 2621935
Show file tree
Hide file tree
Showing 6 changed files with 576 additions and 0 deletions.
9 changes: 9 additions & 0 deletions include/TPP/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@ def VectorizationPass : Pass<"vectorization-pass",
let dependentDialects = [ "memref::MemRefDialect", "linalg::LinalgDialect", "vector::VectorDialect" ];
}

// Lowers vector.contract ops into an scf.for loop of vector.outerproduct ops.
def VectorContractToOuterproduct : Pass<
    "vector-contract-to-outerproduct"> {
  let summary = "Perform outerproduct lowering of vector contraction ops";
  // Every dialect whose ops the pass may create has to be listed here so it is
  // loaded before the (possibly multi-threaded) pass pipeline starts. The
  // lowering materializes arith constants for the loop bounds, hence arith.
  let dependentDialects = ["arith::ArithDialect",
                           "memref::MemRefDialect",
                           "scf::SCFDialect",
                           "tensor::TensorDialect",
                           "vector::VectorDialect"];
}


def ConvertXsmmToFunc : Pass<"convert-xsmm-to-func", "ModuleOp"> {
let summary = "Convert xsmm to func";
Expand Down
1 change: 1 addition & 0 deletions lib/TPP/DefaultTppPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ struct DefaultTppPasses
if (linalgToVector) {
pm.addNestedPass<func::FuncOp>(createVectorizationPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createVectorContractToOuterproduct());
} else {
// Lower all Tile operations.
pm.addNestedPass<func::FuncOp>(createLinalgLowering());
Expand Down
1 change: 1 addition & 0 deletions lib/TPP/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ add_mlir_library(TPPTransforms
FoldIntoEltwise.cpp
FoldAddIntoDest.cpp
Vectorization.cpp
VectorContractToOuterproduct.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/TPP
Expand Down
234 changes: 234 additions & 0 deletions lib/TPP/Transforms/VectorContractToOuterproduct.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
//===--------------- VectorContractToOuterproduct.cpp ------------*- C++-*-===//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector contraction to vector outerproduct.
//
//===----------------------------------------------------------------------===//

#include "TPP/Transforms/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <cstdint>

#define DEBUG_TYPE "vector-contract-to-outerproduct"

namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_VECTORCONTRACTTOOUTERPRODUCT
#include "TPP/Passes.h.inc"
} // namespace tpp
} // namespace mlir

using namespace mlir;
using namespace mlir::tpp;

namespace mlir {
namespace tpp {

/// Returns true if the \p map is transposed for the given operand in the
/// context of a matrix multiply. \p isRHS = false indicates that the check is
/// being performed on the LHS operand, otherwise on the RHS operand.
///
/// Only the last two results of \p map are inspected. The function returns
/// false (i.e. "not transposed") when \p map does not have exactly
/// \p matrixDims results, when fewer than two results are available, or when
/// either of the last two results is not a plain dimension expression.
bool isTransposed(AffineMap map, unsigned matrixDims, bool isRHS = false) {
  auto results = map.getResults();
  // `matrixDims - 2` below is unsigned arithmetic: without the lower bound a
  // 0- or 1-result map would wrap around and index out of bounds.
  if (matrixDims < 2 || results.size() != matrixDims)
    return false;

  // Check the last two dimensions for transposition
  auto secondLast = dyn_cast<AffineDimExpr>(results[matrixDims - 2]);
  auto last = dyn_cast<AffineDimExpr>(results[matrixDims - 1]);

  if (!secondLast || !last)
    return false;

  // For LHS, if the last dimension comes before the second last, it's
  // transposed. It is the opposite for RHS.
  return isRHS ? last.getPosition() > secondLast.getPosition()
               : last.getPosition() < secondLast.getPosition();
}

/// Rewrites a GEMM-shaped `vector.contract` (combining kind ADD) whose LHS and
/// RHS are produced by `vector.transfer_read` into an `scf.for` loop over the
/// reduction dimension. Each iteration re-reads one 1-D slice of the LHS and
/// one 1-D slice of the RHS straight from the transfer_read sources and
/// accumulates their `vector.outerproduct` into the loop-carried accumulator.
///
/// The pattern only matches when:
///  * the indexing maps are projected permutations describing
///    (m,k)|(k,m) x (k,n)|(n,k) -> (m,n) over the trailing
///    [parallel, parallel, reduction] iterators;
///  * the accumulator is a 2-D vector (required by vector.outerproduct);
///  * both transfer_reads are unmasked, fully in-bounds and use a minor
///    identity permutation map, so a slice of the vector corresponds to a
///    slice of the source at the same offsets.
struct VectorContractToOuterproductPattern
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {

    if (contractOp.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(
          contractOp,
          "Unsupported combining kind, only supports ADD at the moment");

    auto maps = contractOp.getIndexingMapsArray();
    if (llvm::any_of(
            maps, [](AffineMap map) { return !map.isProjectedPermutation(); }))
      return rewriter.notifyMatchFailure(contractOp, "Unexpected map");

    // Check if this is a matrix multiply
    auto iteratorTypes = contractOp.getIteratorTypesArray();
    if (iteratorTypes.size() < 3 ||
        iteratorTypes[iteratorTypes.size() - 3] !=
            vector::IteratorType::parallel ||
        iteratorTypes[iteratorTypes.size() - 2] !=
            vector::IteratorType::parallel ||
        iteratorTypes[iteratorTypes.size() - 1] !=
            vector::IteratorType::reduction)
      return rewriter.notifyMatchFailure(contractOp, "Not a gemm");

    // Positions of the m / n / k iterators inside the iteration space.
    const unsigned numIters = iteratorTypes.size();
    const unsigned dimM = numIters - 3;
    const unsigned dimN = numIters - 2;
    const unsigned dimK = numIters - 1;

    AffineMap mapLHS = maps[0];
    AffineMap mapRHS = maps[1];
    AffineMap mapACC = maps[2];

    // True if `map` has exactly the two results {a, b}, in either order.
    auto indexesPair = [](AffineMap map, unsigned a, unsigned b) {
      if (map.getNumResults() != 2)
        return false;
      unsigned first = map.getDimPosition(0);
      unsigned second = map.getDimPosition(1);
      return (first == a && second == b) || (first == b && second == a);
    };
    // Anything else (batch dims, transposed accumulator, operands indexed by
    // unrelated dims) cannot be expressed as a sequence of 2-D outer products.
    if (!indexesPair(mapLHS, dimM, dimK) || !indexesPair(mapRHS, dimK, dimN) ||
        mapACC.getNumResults() != 2 || mapACC.getDimPosition(0) != dimM ||
        mapACC.getDimPosition(1) != dimN)
      return rewriter.notifyMatchFailure(
          contractOp, "Indexing maps do not describe a 2-D matrix multiply");

    Location loc = contractOp.getLoc();
    Value acc = contractOp.getAcc();

    auto accType = dyn_cast<VectorType>(acc.getType());
    if (!accType || accType.getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Accumulator must be a 2-D vector");

    // Find the transfer_reads feeding the contraction; their sources are
    // re-read slice by slice inside the loop.
    auto lhsRead = contractOp.getLhs().getDefiningOp<vector::TransferReadOp>();
    auto rhsRead = contractOp.getRhs().getDefiningOp<vector::TransferReadOp>();
    if (!lhsRead || !rhsRead)
      return rewriter.notifyMatchFailure(
          contractOp, "Operands must be produced by vector.transfer_read");

    // The slice reads below are emitted unmasked and in-bounds, and assume the
    // vector dims line up with the innermost source dims.
    auto isPlainRead = [](vector::TransferReadOp readOp) {
      return !readOp.getMask() && !readOp.hasOutOfBoundsDim() &&
             readOp.getPermutationMap().isMinorIdentity();
    };
    if (!isPlainRead(lhsRead) || !isPlainRead(rhsRead))
      return rewriter.notifyMatchFailure(
          contractOp,
          "Only unmasked, in-bounds, minor-identity transfer_reads are "
          "supported");

    const bool lhsTransposed = isTransposed(mapLHS, mapLHS.getNumResults());
    const bool rhsTransposed =
        isTransposed(mapRHS, mapRHS.getNumResults(), /*isRHS=*/true);

    // Sizes come from the (always static) vector types rather than from the
    // transfer_read sources, which may be dynamic or larger than the vectors.
    VectorType lhsVecType = contractOp.getLhsType();
    const int64_t M = accType.getDimSize(0);
    const int64_t N = accType.getDimSize(1);
    const int64_t K = lhsVecType.getDimSize(lhsTransposed ? 0 : 1);

    // Create constants (all match checks are done; IR may be modified now).
    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    Value cK = rewriter.create<arith::ConstantIndexOp>(loc, K);

    // NOTE(review): for memref sources the slices are read at the position of
    // the contraction rather than at the original transfer_read; this assumes
    // no intervening write to the source - confirm against callers.

    // Create the outer scf.for loop over the reduction dimension.
    auto forOp = rewriter.create<scf::ForOp>(
        loc, c0, cK, c1, ValueRange{acc},
        [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
            ValueRange iterArgs) {
          // Reads the 1-D slice of `readOp`'s source that corresponds to
          // reduction index `iv`. `kIsInnermost` tells whether k indexes the
          // last source dim (slice runs along the second-to-last dim) or the
          // second-to-last one (slice runs along the last dim).
          auto readSlice = [&](vector::TransferReadOp readOp,
                               bool kIsInnermost, int64_t length) -> Value {
            auto sourceType = cast<ShapedType>(readOp.getSource().getType());
            const int64_t rank = sourceType.getRank();
            const int64_t kPos = kIsInnermost ? rank - 1 : rank - 2;
            const int64_t slicePos = kIsInnermost ? rank - 2 : rank - 1;

            // Start from the original read offsets and advance along k.
            // createOrFold returns `iv` directly when the offset is zero.
            SmallVector<Value> indices(readOp.getIndices().begin(),
                                       readOp.getIndices().end());
            indices[kPos] = nestedBuilder.createOrFold<arith::AddIOp>(
                nestedLoc, iv, indices[kPos]);

            AffineMap sliceMap =
                AffineMap::get(rank, 0,
                               {nestedBuilder.getAffineDimExpr(slicePos)},
                               nestedBuilder.getContext());
            return nestedBuilder.create<vector::TransferReadOp>(
                nestedLoc,
                VectorType::get({length}, sourceType.getElementType()),
                readOp.getSource(), indices, AffineMapAttr::get(sliceMap),
                readOp.getPadding(), Value(),
                nestedBuilder.getBoolArrayAttr({true}));
          };

          // LHS (m,k): k is innermost unless transposed.
          // RHS (k,n): k is innermost only when transposed.
          Value lhsSlice = readSlice(lhsRead, !lhsTransposed, M);
          Value rhsSlice = readSlice(rhsRead, rhsTransposed, N);

          // Perform outer product and accumulate into the loop-carried value.
          auto outerProduct = nestedBuilder.create<vector::OuterProductOp>(
              nestedLoc, accType, lhsSlice, rhsSlice, iterArgs[0],
              vector::CombiningKind::ADD);

          // Yield the result
          nestedBuilder.create<scf::YieldOp>(nestedLoc,
                                             ValueRange{outerProduct});
        });

    // Replace the original contraction with the result of the loop
    rewriter.replaceOp(contractOp, forOp.getResults());

    return success();
  }
};

/// Pass driver: greedily applies VectorContractToOuterproductPattern to the
/// operation the pass is scheduled on and signals failure if the greedy
/// rewrite driver reports failure.
struct VectorContractToOuterproduct
    : public tpp::impl::VectorContractToOuterproductBase<
          VectorContractToOuterproduct> {

  using VectorContractToOuterproductBase::VectorContractToOuterproductBase;

  void runOnOperation() override {
    MLIRContext &ctx = getContext();

    // Single-pattern set holding the contract -> outerproduct lowering.
    RewritePatternSet loweringPatterns(&ctx);
    loweringPatterns.add<VectorContractToOuterproductPattern>(&ctx);

    auto root = getOperation();
    if (succeeded(
            applyPatternsAndFoldGreedily(root, std::move(loweringPatterns))))
      return;

    signalPassFailure();
  }
};

} // namespace tpp
} // namespace mlir
97 changes: 97 additions & 0 deletions test/Integration/vector-contract-to-outerproduct.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// RUN: tpp-opt %s | tpp-run -e entry --entry-point-result=void -seed 123 -print > %t.1
// RUN: tpp-opt %s --vector-contract-to-outerproduct | tpp-run -e entry --entry-point-result=void -seed 123 -print > %t.2
// RUN: diff %t.1 %t.2 | FileCheck %s --check-prefix=DIFF --allow-empty

// RUN: tpp-opt %s | tpp-run -e permA --entry-point-result=void -seed 123 -print > %t.1
// RUN: tpp-opt %s --vector-contract-to-outerproduct | tpp-run -e permA --entry-point-result=void -seed 123 -print > %t.2
// RUN: diff %t.1 %t.2 | FileCheck %s --check-prefix=DIFF-PERMA --allow-empty

// RUN: tpp-opt %s | tpp-run -e permB --entry-point-result=void -seed 123 -print > %t.1
// RUN: tpp-opt %s --vector-contract-to-outerproduct | tpp-run -e permB --entry-point-result=void -seed 123 -print > %t.2
// RUN: diff %t.1 %t.2 | FileCheck %s --check-prefix=DIFF-PERMB --allow-empty

// RUN: tpp-opt %s | tpp-run -e permAB --entry-point-result=void -seed 123 -print > %t.1
// RUN: tpp-opt %s --vector-contract-to-outerproduct | tpp-run -e permAB --entry-point-result=void -seed 123 -print > %t.2
// RUN: diff %t.1 %t.2 | FileCheck %s --check-prefix=DIFF-PERMAB --allow-empty


// DIFF-NOT: {{.}}
// Plain GEMM layout: the LHS is indexed (d0, d2), the RHS (d2, d1) and the
// accumulator (d0, d1).
#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>

func.func @entry(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>, %arg2: tensor<16x16xf32>) -> tensor<16x16xf32> {
  %zero = arith.constant 0 : index
  %pad = arith.constant 0.000000e+00 : f32
  // Load both inputs and the accumulator as full 16x16 vectors.
  %lhs = vector.transfer_read %arg0[%zero, %zero], %pad {in_bounds = [true, true]} : tensor<16x16xf32>, vector<16x16xf32>
  %rhs = vector.transfer_read %arg1[%zero, %zero], %pad {in_bounds = [true, true]} : tensor<16x16xf32>, vector<16x16xf32>
  %acc = vector.transfer_read %arg2[%zero, %zero], %pad {in_bounds = [true, true]} : tensor<16x16xf32>, vector<16x16xf32>
  %res = vector.contract {
           indexing_maps = [#map, #map1, #map2],
           iterator_types = ["parallel", "parallel", "reduction"],
           kind = #vector.kind<add>
         } %lhs, %rhs, %acc
         : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32>
  // Write the contraction result back over the accumulator tensor.
  %out = vector.transfer_write %res, %arg2[%zero, %zero] {in_bounds = [true, true]} : vector<16x16xf32>, tensor<16x16xf32>
  return %out : tensor<16x16xf32>
}

// -----

// DIFF-PERMA-NOT: {{.}}
// Transposed LHS: the LHS is indexed (d2, d0); the RHS (d2, d1) and the
// accumulator (d0, d1) keep the plain layout.
#permA0 = affine_map<(d0, d1, d2) -> (d2, d0)>
#permA1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#permA2 = affine_map<(d0, d1, d2) -> (d0, d1)>

func.func @permA(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> tensor<4x4xf32> {
  %zero = arith.constant 0 : index
  %pad = arith.constant 0.000000e+00 : f32
  // Load both inputs and the accumulator as full 4x4 vectors.
  %lhs = vector.transfer_read %arg0[%zero, %zero], %pad {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
  %rhs = vector.transfer_read %arg1[%zero, %zero], %pad {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
  %acc = vector.transfer_read %arg2[%zero, %zero], %pad {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
  %res = vector.contract {
           indexing_maps = [#permA0, #permA1, #permA2],
           iterator_types = ["parallel", "parallel", "reduction"],
           kind = #vector.kind<add>
         } %lhs, %rhs, %acc
         : vector<4x4xf32>, vector<4x4xf32> into vector<4x4xf32>
  // Write the contraction result back over the accumulator tensor.
  %out = vector.transfer_write %res, %arg2[%zero, %zero] {in_bounds = [true, true]} : vector<4x4xf32>, tensor<4x4xf32>
  return %out : tensor<4x4xf32>
}

// -----

// DIFF-PERMB-NOT: {{.}}
// Transposed RHS: the RHS is indexed (d1, d2); the LHS (d0, d2) and the
// accumulator (d0, d1) keep the plain layout.
#permB0 = affine_map<(d0, d1, d2) -> (d0, d2)>
#permB1 = affine_map<(d0, d1, d2) -> (d1, d2)>
#permB2 = affine_map<(d0, d1, d2) -> (d0, d1)>

func.func @permB(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> tensor<4x4xf32> {
  %zero = arith.constant 0 : index
  %pad = arith.constant 0.000000e+00 : f32
  // Load both inputs and the accumulator as full 4x4 vectors.
  %lhs = vector.transfer_read %arg0[%zero, %zero], %pad {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
  %rhs = vector.transfer_read %arg1[%zero, %zero], %pad {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
  %acc = vector.transfer_read %arg2[%zero, %zero], %pad {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
  %res = vector.contract {
           indexing_maps = [#permB0, #permB1, #permB2],
           iterator_types = ["parallel", "parallel", "reduction"],
           kind = #vector.kind<add>
         } %lhs, %rhs, %acc
         : vector<4x4xf32>, vector<4x4xf32> into vector<4x4xf32>
  // Write the contraction result back over the accumulator tensor.
  %out = vector.transfer_write %res, %arg2[%zero, %zero] {in_bounds = [true, true]} : vector<4x4xf32>, tensor<4x4xf32>
  return %out : tensor<4x4xf32>
}

// -----

// DIFF-PERMAB-NOT: {{.}}
// Both operands transposed: the LHS is indexed (d2, d0) and the RHS (d1, d2);
// the accumulator keeps the plain (d0, d1) layout.
#permAB0 = affine_map<(d0, d1, d2) -> (d2, d0)>
#permAB1 = affine_map<(d0, d1, d2) -> (d1, d2)>
#permAB2 = affine_map<(d0, d1, d2) -> (d0, d1)>

func.func @permAB(%arg0: tensor<4x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32>) -> tensor<4x4xf32> {
  %zero = arith.constant 0 : index
  %pad = arith.constant 0.000000e+00 : f32
  // Load both inputs and the accumulator as full 4x4 vectors.
  %lhs = vector.transfer_read %arg0[%zero, %zero], %pad {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
  %rhs = vector.transfer_read %arg1[%zero, %zero], %pad {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
  %acc = vector.transfer_read %arg2[%zero, %zero], %pad {in_bounds = [true, true]} : tensor<4x4xf32>, vector<4x4xf32>
  %res = vector.contract {
           indexing_maps = [#permAB0, #permAB1, #permAB2],
           iterator_types = ["parallel", "parallel", "reduction"],
           kind = #vector.kind<add>
         } %lhs, %rhs, %acc
         : vector<4x4xf32>, vector<4x4xf32> into vector<4x4xf32>
  // Write the contraction result back over the accumulator tensor.
  %out = vector.transfer_write %res, %arg2[%zero, %zero] {in_bounds = [true, true]} : vector<4x4xf32>, tensor<4x4xf32>
  return %out : tensor<4x4xf32>
}

// -----
Loading

0 comments on commit 2621935

Please sign in to comment.