Skip to content

Commit

Permalink
[Codegen][GPU] Make operand promotion controlled by lowering config
Browse files Browse the repository at this point in the history
Promoting the operands of a matmul is optional and best controlled
through the lowering config rather than through on-the-fly analysis.
This gives greater flexibility for adding support for other operations
too (such as promotion of another kind of contraction or convolution-like
op, without having to extend this pass every time).
  • Loading branch information
qedawkins committed Sep 22, 2024
1 parent 5a6bd8d commit 49a8ee7
Show file tree
Hide file tree
Showing 7 changed files with 144 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,47 +64,29 @@ void promoteOperand(OpBuilder &builder, Operation *op, unsigned index) {
op->setOperand(index, copy.getResult(0));
}

/// Returns true if `linalgOp` is a contraction whose m and n dimension
/// spaces are not statically known to collapse to a single element, i.e. a
/// matmul-like op rather than a matvec/vecmat. Dynamic m/n extents are
/// conservatively treated as non-unit.
bool isNonMatvecContraction(linalg::LinalgOp linalgOp) {
// Static loop extents, indexed by loop position; dynamic extents hold the
// ShapedType dynamic sentinel value.
SmallVector<int64_t, 4> bounds = linalgOp.getStaticLoopRanges();
FailureOr<mlir::linalg::ContractionDimensions> contractionDims =
mlir::linalg::inferContractionDims(linalgOp);
if (failed(contractionDims)) {
return false;
}

// Require at least one m, n, and k dimension to call this a contraction.
if (contractionDims->k.size() < 1 || contractionDims->m.size() < 1 ||
contractionDims->n.size() < 1) {
return false;
}

// Product of the static extents of `dims`. Bails out with the dynamic
// sentinel (which compares != 1 below) as soon as any extent is dynamic.
auto getElementCount = [&](ArrayRef<unsigned> dims) {
int64_t acc = 1;
for (auto mDim : dims) {
int64_t size = bounds[mDim];
if (ShapedType::isDynamic(size)) {
return size;
}
acc *= size;
}
return acc;
};
// Matvec-like iff either the m or n space is statically a single element.
return getElementCount(contractionDims->m) != 1 &&
getElementCount(contractionDims->n) != 1;
}

struct GPUPromoteMatmulOperandsPass final
: impl::GPUPromoteMatmulOperandsPassBase<GPUPromoteMatmulOperandsPass> {
void runOnOperation() override {
FunctionOpInterface funcOp = getOperation();

OpBuilder builder(funcOp);
funcOp.walk([&](linalg::LinalgOp linalgOp) {
if (!isNonMatvecContraction(linalgOp)) {
funcOp.walk([&](Operation *op) {
auto loweringConfig =
getLoweringConfig<IREE::GPU::LoweringConfigAttr>(op);
if (!loweringConfig) {
return;
}

std::optional<SmallVector<int64_t>> promotedOperands =
loweringConfig.getPromotedOperandList();
if (!promotedOperands) {
return;
}
builder.setInsertionPoint(linalgOp);
promoteOperand(builder, linalgOp, 0);
promoteOperand(builder, linalgOp, 1);

builder.setInsertionPoint(op);
for (auto operand : promotedOperands.value()) {
promoteOperand(builder, op, operand);
}
});
}
};
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
// RUN: iree-opt %s --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-promote-matmul-operands))" | FileCheck %s

#lowering_config = #iree_gpu.lowering_config<{promote_operands = [0, 1]}>

func.func @matmul(%a: tensor<32x1024xf32>, %b: tensor<1024x128xf32>) -> tensor<32x128xf32> {
%cst = arith.constant 0.000000e+00 : f32
%empty = tensor.empty() : tensor<32x128xf32>
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<32x128xf32>) -> tensor<32x128xf32>
%mm = linalg.matmul ins(%a, %b : tensor<32x1024xf32>, tensor<1024x128xf32>) outs(%fill : tensor<32x128xf32>) -> tensor<32x128xf32>
%mm = linalg.matmul {lowering_config = #lowering_config}
ins(%a, %b : tensor<32x1024xf32>, tensor<1024x128xf32>) outs(%fill : tensor<32x128xf32>) -> tensor<32x128xf32>
return %mm : tensor<32x128xf32>
}

Expand All @@ -13,33 +16,40 @@ func.func @matmul(%a: tensor<32x1024xf32>, %b: tensor<1024x128xf32>) -> tensor<3
// CHECK-SAME: %[[B:[A-Za-z0-9]+]]: tensor<1024x128xf32>
// CHECK-DAG: %[[PA:.+]] = linalg.copy {{.*}} ins(%[[A]] : tensor<32x1024xf32>)
// CHECK-DAG: %[[PB:.+]] = linalg.copy {{.*}} ins(%[[B]] : tensor<1024x128xf32>)
// CHECK: linalg.matmul ins(%[[PA]], %[[PB]] : tensor<32x1024xf32>, tensor<1024x128xf32>)
// CHECK: linalg.matmul {{.*}} ins(%[[PA]], %[[PB]] : tensor<32x1024xf32>, tensor<1024x128xf32>)

// -----

func.func @matvec(%a: tensor<1x1024xf32>, %b: tensor<1024x128xf32>) -> tensor<1x128xf32> {
#lowering_config = #iree_gpu.lowering_config<{promote_operands = []}>

func.func @empty_config(%a: tensor<1x1024xf32>, %b: tensor<1024x128xf32>) -> tensor<1x128xf32> {
%cst = arith.constant 0.000000e+00 : f32
%empty = tensor.empty() : tensor<1x128xf32>
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x128xf32>) -> tensor<1x128xf32>
%mm = linalg.matmul ins(%a, %b : tensor<1x1024xf32>, tensor<1024x128xf32>) outs(%fill : tensor<1x128xf32>) -> tensor<1x128xf32>
%mm = linalg.matmul {lowering_config = #lowering_config}
ins(%a, %b : tensor<1x1024xf32>, tensor<1024x128xf32>) outs(%fill : tensor<1x128xf32>) -> tensor<1x128xf32>
return %mm : tensor<1x128xf32>
}

// Verify that no copies are generated for matvec operations.
// CHECK-LABEL: func.func @matvec
// Verify that no copies are generated with an empty lowering config
// CHECK-LABEL: func.func @empty_config
// CHECK-NOT: linalg.copy
// CHECK: return

// -----

#lowering_config = #iree_gpu.lowering_config<{promote_operands = [0]}>

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
func.func @generic_matmul(%a: tensor<32x1024xf32>, %b: tensor<1024x128xf32>) -> tensor<32x128xf32> {
func.func @lhs_only_matmul(%a: tensor<32x1024xf32>, %b: tensor<1024x128xf32>) -> tensor<32x128xf32> {
%cst = arith.constant 0.000000e+00 : f32
%empty = tensor.empty() : tensor<32x128xf32>
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<32x128xf32>) -> tensor<32x128xf32>
%mm = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]}
%mm = linalg.generic {
indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"],
lowering_config = #lowering_config}
ins(%a, %b : tensor<32x1024xf32>, tensor<1024x128xf32>) outs(%fill : tensor<32x128xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%7 = arith.mulf %in, %in_0 : f32
Expand All @@ -49,9 +59,8 @@ func.func @generic_matmul(%a: tensor<32x1024xf32>, %b: tensor<1024x128xf32>) ->
return %mm : tensor<32x128xf32>
}

// CHECK-LABEL: func.func @generic_matmul
// CHECK-LABEL: func.func @lhs_only_matmul
// CHECK-SAME: %[[A:[A-Za-z0-9]+]]: tensor<32x1024xf32>
// CHECK-SAME: %[[B:[A-Za-z0-9]+]]: tensor<1024x128xf32>
// CHECK-DAG: %[[PA:.+]] = linalg.copy {{.*}} ins(%[[A]] : tensor<32x1024xf32>)
// CHECK-DAG: %[[PB:.+]] = linalg.copy {{.*}} ins(%[[B]] : tensor<1024x128xf32>)
// CHECK: linalg.generic {{.*}} ins(%[[PA]], %[[PB]] : tensor<32x1024xf32>, tensor<1024x128xf32>)
// CHECK: linalg.generic {{.*}} ins(%[[PA]], %[[B]] : tensor<32x1024xf32>, tensor<1024x128xf32>)
24 changes: 19 additions & 5 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1312,17 +1312,20 @@ static StringRef getTilingLevelName(GPU::TilingLevel level) {
return StringAttr();
}

static SmallVector<int64_t> getTileSizes(DictionaryAttr config,
GPU::TilingLevel level) {
auto sizes = config.getAs<ArrayAttr>(getTilingLevelName(level));
if (!sizes || !llvm::all_of(sizes.getValue(), llvm::IsaPred<IntegerAttr>)) {
static SmallVector<int64_t> getIntegerList(ArrayAttr array) {
if (!array || !llvm::all_of(array.getValue(), llvm::IsaPred<IntegerAttr>)) {
return {};
}
return llvm::map_to_vector(sizes.getValue(), [](Attribute s) -> int64_t {
return llvm::map_to_vector(array.getValue(), [](Attribute s) -> int64_t {
return cast<IntegerAttr>(s).getInt();
});
}

/// Returns the integer tile sizes stored in `config` under the name of the
/// given tiling `level`, or an empty vector when the entry is absent or not
/// a list of integers.
static SmallVector<int64_t> getTileSizes(DictionaryAttr config,
                                         GPU::TilingLevel level) {
  auto sizesArray = config.getAs<ArrayAttr>(getTilingLevelName(level));
  return getIntegerList(sizesArray);
}

/// Returns the workgroup-level tile sizes recorded in this lowering config.
SmallVector<int64_t> LoweringConfigAttr::getWorkgroupTileSizes() const {
  DictionaryAttr dict = getAttributes();
  return getTileSizes(dict, GPU::TilingLevel::Workgroup);
}
Expand Down Expand Up @@ -1366,6 +1369,17 @@ IREE::GPU::MmaInterfaceAttr LoweringConfigAttr::getMmaKind() const {
return getAttributes().getAs<IREE::GPU::MmaInterfaceAttr>(kMmaKindName);
}

/// Dictionary key under which the list of operand indices to promote is
/// stored in the lowering config.
constexpr StringLiteral kPromoteOperandsName = "promote_operands";

/// Returns the list of operand indices to promote, or std::nullopt when the
/// config carries no `promote_operands` entry.
std::optional<SmallVector<int64_t>>
LoweringConfigAttr::getPromotedOperandList() const {
  ArrayAttr promoteAttr = getAttributes().getAs<ArrayAttr>(kPromoteOperandsName);
  if (!promoteAttr) {
    return std::nullopt;
  }
  return getIntegerList(promoteAttr);
}

//===----------------------------------------------------------------------===//
// DerivedThreadConfigAttr
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def IREEGPU_LoweringConfigAttr :
let extraClassDeclaration = [{
/// Helper to retrieve a target mma intrinsic if present.
::mlir::iree_compiler::IREE::GPU::MmaInterfaceAttr getMmaKind() const;

/// Helper to retrieve a list of operand indices to promote.
std::optional<SmallVector<int64_t>> getPromotedOperandList() const;
}];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,8 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
attrs.emplace_back(StringAttr::get(context, "subgroup"),
b.getI64ArrayAttr(subgroupTileSizes));
attrs.emplace_back(StringAttr::get(context, "mma_kind"), mmaKind);
attrs.emplace_back(StringAttr::get(context, "promote_operands"),
b.getI64ArrayAttr({0, 1}));
auto configDict = DictionaryAttr::get(context, attrs);
auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);

Expand All @@ -220,6 +222,35 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
workgroupSize, targetSubgroupSize, pipelineConfig);
}

/// Helper to identify contraction-like operations for operand promotion.
/// Returns true when `linalgOp` is a contraction whose m and n dimension
/// spaces are not statically a single element (i.e. not matvec/vecmat-like).
/// Dynamic m/n extents are conservatively treated as non-unit.
static bool isNonMatvecContraction(linalg::LinalgOp linalgOp) {
  FailureOr<mlir::linalg::ContractionDimensions> dims =
      mlir::linalg::inferContractionDims(linalgOp);
  if (failed(dims)) {
    return false;
  }

  // A full contraction needs at least one m, n, and k dimension.
  if (dims->m.empty() || dims->n.empty() || dims->k.empty()) {
    return false;
  }

  SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
  // Product of the static extents of `dimList`. Bails out with the dynamic
  // sentinel (which compares != 1 below) as soon as any extent is dynamic.
  auto staticSizeProduct = [&](ArrayRef<unsigned> dimList) -> int64_t {
    int64_t product = 1;
    for (unsigned dim : dimList) {
      int64_t extent = loopBounds[dim];
      if (ShapedType::isDynamic(extent)) {
        return extent;
      }
      product *= extent;
    }
    return product;
  };
  // Matvec-like iff either the m or n space collapses to one element.
  return staticSizeProduct(dims->m) != 1 && staticSizeProduct(dims->n) != 1;
}

LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
mlir::FunctionOpInterface entryPoint,
Operation *op) {
Expand Down Expand Up @@ -439,6 +470,11 @@ LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
attrs.emplace_back(StringAttr::get(context, "thread"),
b.getI64ArrayAttr(threadTileSizes));

if (isNonMatvecContraction(linalgOp)) {
attrs.emplace_back(StringAttr::get(context, "promote_operands"),
b.getI64ArrayAttr({0, 1}));
}

// Heuristic value chosen to limit maximum vector sizes when tiling below.
const unsigned maxVectorSize = 32;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ func.func @expanded_matmul_transpose_b(%lhs: tensor<2x64x2048xf16>, %rhs: tensor

// CHECK: linalg.generic {{.*}}lowering_config = #iree_gpu.lowering_config
// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
// CHECK-SAME: promote_operands = [0, 1]
// CHECK-SAME: reduction = [0, 0, 0, 0, 4]
// CHECK-SAME: subgroup = [0, 0, 4, 1, 0]
// CHECK-SAME: workgroup = [1, 1, 64, 64, 0]
Expand All @@ -59,6 +60,7 @@ func.func @mfma_matmul_1024x1024x1024(%lhs: tensor<1024x1024xf16>, %rhs: tensor<

// CHECK: linalg.matmul {{.*}}lowering_config = #iree_gpu.lowering_config
// CHECK-SAME: mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>
// CHECK-SAME: promote_operands = [0, 1]
// CHECK-SAME: reduction = [0, 0, 2]
// CHECK-SAME: subgroup = [4, 4, 0]
// CHECK-SAME: workgroup = [128, 128, 0]
Expand Down Expand Up @@ -100,6 +102,7 @@ module {
// CHECK-LABEL: func.func @matmul_dynamic_dim
// CHECK-SAME: #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64>
// CHECK: linalg.matmul {{.*}}lowering_config = #iree_gpu.lowering_config
// CHECK-SAME: promote_operands = [0, 1]
// CHECK-SAME: reduction = [0, 0, 4]
// CHECK-SAME: thread = [1, 1, 0]
// CHECK-SAME: workgroup = [1, 64, 0]
Expand Down
Loading

0 comments on commit 49a8ee7

Please sign in to comment.