Tile configuration addition pass (#890)
Test cases to be added; draft PR in progress. Runtime and code changes
are done; LICM hoisting is incomplete and still being worked on.
KavithaTipturMadhu authored and nhasabni committed Mar 14, 2024
1 parent c1349ac commit f062714
Showing 17 changed files with 700 additions and 35 deletions.
6 changes: 4 additions & 2 deletions include/TPP/Dialect/Xsmm/XsmmEnum.td
@@ -76,7 +76,9 @@ def Xsmm_GemmFlags : I64EnumAttr<
I64EnumAttrCase<"BETA_0", 4, "beta_0">,
I64EnumAttrCase<"VNNI_A", 2048, "vnni_a">,
I64EnumAttrCase<"VNNI_B", 4096, "vnni_b">,
- I64EnumAttrCase<"VNNI_C", 8192, "vnni_c">
- ]> {
+ I64EnumAttrCase<"VNNI_C", 8192, "vnni_c">,
+ I64EnumAttrCase<"NO_RESET_TILECONFIG", 64, "no_reset_tileconfig">,
+ I64EnumAttrCase<"NO_SETUP_TILECONFIG", 128, "no_setup_tileconfig">
+ ]> {
let cppNamespace = "mlir::xsmm";
}
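
The two new enum cases extend the existing GEMM flag set with AMX tile-configuration controls; the values 64 and 128 appear chosen to avoid the existing flag bits, presumably mirroring libxsmm's corresponding no-reset/no-setup tile-config GEMM flags. A minimal stand-alone C++ sketch (not project code) of how such disjoint flag bits combine into the single i64 word that a dispatch call eventually receives:

#include <cstdint>
#include <initializer_list>

// Assumed mirror of the TableGen enum above; only the cases shown there.
enum class GemmFlags : int64_t {
  BETA_0 = 4,
  NO_RESET_TILECONFIG = 64,
  NO_SETUP_TILECONFIG = 128,
  VNNI_A = 2048,
  VNNI_B = 4096,
  VNNI_C = 8192,
};

// OR selected flags into one word; valid because every case is a distinct bit.
int64_t oredFlags(std::initializer_list<GemmFlags> flags) {
  int64_t ored = 0;
  for (GemmFlags f : flags)
    ored |= static_cast<int64_t>(f);
  return ored;
}

// Example: oredFlags({GemmFlags::NO_RESET_TILECONFIG,
//                     GemmFlags::NO_SETUP_TILECONFIG}) == 192
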
19 changes: 19 additions & 0 deletions include/TPP/Dialect/Xsmm/XsmmOps.td
@@ -263,6 +263,15 @@ def Xsmm_GemmDispatchOp : Xsmm_GemmLikeOp<"gemm.dispatch"> {
def Xsmm_BrgemmDispatchOp : Xsmm_GemmLikeOp<"brgemm.dispatch"> {
let summary = "dispatch for brgemm operation.";
let hasVerifier = 1;

}

//===----------------------------------------------------------------------===//
// IntelAMXTileConfigDispatchOp
//===----------------------------------------------------------------------===//

def Xsmm_IntelAMXTileConfigDispatchOp : Xsmm_GemmLikeOp<"IntelAMXtileConfig.dispatch"> {
let summary = "dispatch for Intel amx tileConfig operation.";
}

//===----------------------------------------------------------------------===//
@@ -298,4 +307,14 @@ def Xsmm_FusedBrgemmDispatchOp : Xsmm_Op<"fused_brgemm.dispatch", [Pure]> {
let hasVerifier = 1;
}


//===----------------------------------------------------------------------===//
// IntelAMXTileConfigOp
//===----------------------------------------------------------------------===//

def Xsmm_IntelAMXTileConfigOp : Xsmm_Op<"IntelAMXtileConfig", [MemoryEffects<[MemWrite, MemRead]>]> {
let summary = "invoke for Intel AMX tileConfig operation.";
let arguments = (ins I64:$dispatch, Variadic<AnyMemRef>:$inputs);
}

#endif // TPP_XSMM_OPS
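
The invoke op takes an i64 dispatch handle plus the memrefs it touches, and declares both read and write memory effects, which keeps DCE and CSE from dropping or merging the configuration calls. A hedged C++ sketch of creating it with the default ODS-generated builder (the helper, header paths, and namespaces are assumptions; only the operand list comes from the definition above):

#include "TPP/Dialect/Xsmm/XsmmOps.h"  // assumed header for the generated ops
#include "mlir/IR/PatternMatch.h"

// Hypothetical helper, not project code: emit an xsmm.IntelAMXtileConfig
// invoke given a previously created dispatch handle and the scratch buffers.
static void emitTileConfigInvoke(mlir::PatternRewriter &rewriter,
                                 mlir::Location loc, mlir::Value dispatchHandle,
                                 mlir::ValueRange scratchBuffers) {
  // Default builder generated from (ins I64:$dispatch, Variadic<AnyMemRef>:$inputs);
  // the op produces no results, only memory effects.
  rewriter.create<mlir::xsmm::IntelAMXTileConfigOp>(loc, dispatchHandle,
                                                    scratchBuffers);
}
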
34 changes: 32 additions & 2 deletions include/TPP/Passes.td
Expand Up @@ -204,7 +204,10 @@ def DefaultTppPasses : Pass<"default-tpp-passes", "ModuleOp"> {
let options= [
Option<"linalgToLoops", "linalg-to-loops",
"bool", /*default=*/"false",
"Skip all TPP transformations. Lower linalg directly to loops.">
"Skip all TPP transformations. Lower linalg directly to loops.">,
ListOption<"parallelTaskGrid", "parallel-task-grid",
"unsigned", "Grid-sizes for parallel tasks.">

];
}

@@ -289,6 +292,12 @@ def LocalDialectsLowering : Pass<"lower-local-dialects", "ModuleOp"> {
"tensor::TensorDialect",
"xsmm::XsmmDialect",
"LLVM::LLVMDialect"];
let options = [
ListOption<"parallelTaskGrid", "parallel-task-grid",
"unsigned", "Grid-sizes for parallel tasks.">

];

}

def Postprocessing : Pass<"postprocess", "func::FuncOp"> {
@@ -470,10 +479,11 @@ def FoldXsmmFlags : Pass<"fold-xsmm-flags", "func::FuncOp"> {
let dependentDialects = [ "memref::MemRefDialect", "xsmm::XsmmDialect" ];
}

+
def SCFParallelLoopTiling : Pass<"scf-parallel-loop-tiling-pass"> {
let summary = "Tile parallel loops";
let options = [
- ListOption<"tileSizes", "parallel-loop-tile-sizes", "int64_t",
+ ListOption<"tileSizes", "parallel-loop-tile-sizes", "unsigned",
"Factors to tile parallel loops by">,
Option<"noMinMaxBounds", "no-min-max-bounds", "bool",
/*default=*/"false",
@@ -494,4 +504,24 @@ def GpuInlineConstants : Pass<"gpu-inline-constants", "func::FuncOp"> {
"arith::ArithDialect"];
}

def IntelAMXTileConfigInsertionPass : Pass<"intel-amx-tile-config-insertion-pass",
"func::FuncOp"> {
let summary = "Insert intel amx tile configuration xsmm calls";
let description = [{
Insert intel amx tile configuration xsmm calls.
}];

let dependentDialects = [ "memref::MemRefDialect", "xsmm::XsmmDialect" ];
}

def IntelAMXTileConfigHoistingPass : Pass<"intel-amx-tile-config-hoisting-pass",
"func::FuncOp"> {
let summary = "Hoist intel amx tile configuration invoke xsmm calls";
let description = [{
Run LICM on intel amx tile configuration invoke calls.
}];

let dependentDialects = [ "memref::MemRefDialect", "xsmm::XsmmDialect" ];
}

#endif // TPP_DIALECT_TPP_PASSES
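
Both new passes and the parallel-task-grid list option surface through the ODS-generated option structs, so a client pipeline can configure them programmatically. A sketch under that assumption (struct and factory names are taken from the definitions above and from the pipeline changes further down; header paths and namespaces are guesses):

#include "TPP/Passes.h"  // assumed location of the generated pass declarations
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Run the default TPP pipeline with a custom parallel task grid.
mlir::LogicalResult runWithCustomGrid(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  mlir::tpp::DefaultTppPassesOptions options;
  options.linalgToLoops = false;      // keep the TPP lowering path
  options.parallelTaskGrid = {4, 4};  // overrides the {2, 8} default
  pm.addPass(mlir::tpp::createDefaultTppPasses(options));
  return pm.run(module);
}

Inside the pipeline itself (see lib/TPP/DefaultTppPasses.cpp below), the grid feeds SCF parallel-loop tiling, after which the tile-config insertion pass, canonicalization, LICM, and the hoisting pass run on each function before the XSMM ops are lowered to calls.
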
53 changes: 42 additions & 11 deletions lib/TPP/Conversion/ConvertXsmmToFunc/ConvertXsmmToFunc.cpp
@@ -171,6 +171,21 @@ struct ConvertFusedBrgemmXsmmOp : public OpRewritePattern<FusedBrgemmOp> {
}
};

struct ConvertIntelAMXTileConfigXsmmOp
: public OpRewritePattern<IntelAMXTileConfigOp> {
using OpRewritePattern<IntelAMXTileConfigOp>::OpRewritePattern;

LogicalResult matchAndRewrite(IntelAMXTileConfigOp tileConfigOp,
PatternRewriter &rewriter) const override {
std::string funcName = "xsmm_intel_amx_tile_config_invoke";
buildInvokeCall(
rewriter, tileConfigOp.getLoc(), funcName, tileConfigOp,
xsmm::DataTypeAttr::get(rewriter.getContext(), xsmm::DataType::BF16));
rewriter.eraseOp(tileConfigOp);
return success();
}
};
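
The pattern hard-codes BF16 as the data-type attribute, builds the runtime call, and erases the original op. buildInvokeCall itself is defined elsewhere in this file; as a rough, hedged sketch of the general shape such a lowering takes (names and body here are illustrative, not the project's helper):

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"

// Illustrative only: declare the runtime symbol if it is not present yet,
// then emit a func.call with the op's operands.
static mlir::func::CallOp emitRuntimeCall(mlir::PatternRewriter &rewriter,
                                          mlir::Location loc,
                                          mlir::ModuleOp module,
                                          llvm::StringRef fnName,
                                          mlir::ValueRange operands) {
  if (!module.lookupSymbol<mlir::func::FuncOp>(fnName)) {
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(module.getBody());
    auto fnType =
        rewriter.getFunctionType(mlir::TypeRange(operands), mlir::TypeRange{});
    auto decl = rewriter.create<mlir::func::FuncOp>(loc, fnName, fnType);
    decl.setPrivate();
  }
  return rewriter.create<mlir::func::CallOp>(loc, fnName, mlir::TypeRange{},
                                             operands);
}
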

static func::CallOp buildDispatchCall(RewriterBase &rewriter, Location loc,
ArrayRef<Value> dispatchOperands,
ArrayRef<Type> dispatchOperandTypes,
@@ -195,10 +210,9 @@ static func::CallOp buildDispatchCall(RewriterBase &rewriter, Location loc,
return call;
}

- template <typename OpTy,
- typename = std::enable_if_t<
- std::is_same<OpTy, xsmm::UnaryDispatchOp>::value ||
- std::is_same<OpTy, xsmm::BinaryDispatchOp>::value>>
+ template <typename OpTy, typename = std::enable_if_t<
+ std::is_same<OpTy, xsmm::UnaryDispatchOp>::value ||
+ std::is_same<OpTy, xsmm::BinaryDispatchOp>::value>>
void addKindOperand(RewriterBase &rewriter, OpTy dispatchOp,
SmallVectorImpl<Value> &dispatchOperands,
SmallVectorImpl<Type> &dispatchOperandTypes) {
@@ -227,6 +241,13 @@ void addKindOperand(RewriterBase &rewriter, FusedBrgemmDispatchOp dispatchOp,
/* do nothing */
}

void addKindOperand(RewriterBase &rewriter,
IntelAMXTileConfigDispatchOp dispatchOp,
SmallVectorImpl<Value> &dispatchOperands,
SmallVectorImpl<Type> &dispatchOperandTypes) {
/* do nothing */
}

static int64_t getOredFlags(ArrayAttr flags) {
int64_t oredFlag = 0;
for (auto flag : flags) {
@@ -370,6 +391,17 @@ struct ConvertUnaryDispatchOp : public OpRewritePattern<UnaryDispatchOp> {
}
};

struct ConvertIntelAMXTileConfigDispatchOp
: public OpRewritePattern<IntelAMXTileConfigDispatchOp> {
using OpRewritePattern<IntelAMXTileConfigDispatchOp>::OpRewritePattern;

LogicalResult matchAndRewrite(IntelAMXTileConfigDispatchOp dispatchOp,
PatternRewriter &rewriter) const override {
return buildDispatchOp<IntelAMXTileConfigDispatchOp>(
rewriter, dispatchOp, "xsmm_intel_amx_tile_config_dispatch");
}
};

struct ConvertFusedBrgemmOp : public OpRewritePattern<FusedBrgemmDispatchOp> {
using OpRewritePattern<FusedBrgemmDispatchOp>::OpRewritePattern;

@@ -393,13 +425,12 @@ struct ConvertXsmmToFunc
: public tpp::impl::ConvertXsmmToFuncBase<ConvertXsmmToFunc> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
- patterns
- .add<ConvertBinaryXsmmOp, ConvertUnaryXsmmOp,
- ConvertGemmXsmmOp, ConvertBrgemmXsmmOp, ConvertFusedBrgemmXsmmOp>(
- patterns.getContext());
- patterns.add<ConvertBinaryDispatchOp,
- ConvertUnaryDispatchOp, ConvertGemmDispatchOp,
- ConvertBrgemmDispatchOp, ConvertFusedBrgemmOp>(
+ patterns.add<ConvertBinaryXsmmOp, ConvertUnaryXsmmOp, ConvertGemmXsmmOp,
+ ConvertBrgemmXsmmOp, ConvertFusedBrgemmXsmmOp,
+ ConvertIntelAMXTileConfigXsmmOp>(patterns.getContext());
+ patterns.add<ConvertBinaryDispatchOp, ConvertUnaryDispatchOp,
+ ConvertGemmDispatchOp, ConvertBrgemmDispatchOp,
+ ConvertFusedBrgemmOp, ConvertIntelAMXTileConfigDispatchOp>(
patterns.getContext());
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
13 changes: 5 additions & 8 deletions lib/TPP/DefaultPipeline.cpp
@@ -47,10 +47,10 @@ llvm::cl::opt<bool>
llvm::cl::init(false));

// Control grid parallelism sizes.
- llvm::cl::list<int64_t>
+ llvm::cl::list<unsigned>
parallelTaskGrid("parallel-task-grid",
llvm::cl::desc("Grid-sizes for parallel tasks"),
- llvm::cl::list_init<int64_t>(SmallVector<int64_t>{2, 8}),
+ llvm::cl::list_init<unsigned>(SmallVector<unsigned>{2, 8}),
llvm::cl::CommaSeparated);

namespace mlir {
@@ -127,7 +127,8 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase<DefaultPipeline>,
pm.addPass(createGpuPipeline(GpuPipelineOptions{gpuBackend}));
} else {
// Apply the default preprocessing pass
- DefaultTppPassesOptions tppDefaultOptions{linalgToLoops};
+ DefaultTppPassesOptions tppDefaultOptions{linalgToLoops,
+ parallelTaskGrid};
pm.addPass(createDefaultTppPasses(tppDefaultOptions));
}

@@ -140,12 +141,8 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase<DefaultPipeline>,
pm.addPass(tpp::createConvertPerfToFunc());
pm.addPass(createConvertTensorToLinalgPass());
pm.addNestedPass<func::FuncOp>(createConvertLinalgToLoopsPass());
- if (defParallel) {
- mlir::tpp::SCFParallelLoopTilingOptions tilingOptions;
- tilingOptions.tileSizes = parallelTaskGrid;
- pm.addPass(createSCFParallelLoopTiling(tilingOptions));
+ if (defParallel)
pm.addPass(createConvertSCFToOpenMPPass());
- }
pm.addPass(createConvertVectorToSCFPass());
pm.addPass(arith::createArithExpandOpsPass());
pm.addPass(createLowerAffinePass());
17 changes: 16 additions & 1 deletion lib/TPP/DefaultTppPasses.cpp
@@ -77,6 +77,10 @@ struct LocalDialectsLowering
: public tpp::impl::LocalDialectsLoweringBase<LocalDialectsLowering>,
UtilityPassBase<ModuleOp> {

LocalDialectsLowering() {}
LocalDialectsLowering(const LocalDialectsLoweringOptions &options) {
parallelTaskGrid = options.parallelTaskGrid;
}
void runOnOperation() override {
auto module = getOperation();

@@ -106,6 +110,16 @@ struct LocalDialectsLowering
// that they are hoisted out of loops.
pm.addNestedPass<func::FuncOp>(createCleanup());

mlir::tpp::SCFParallelLoopTilingOptions tilingOptions;
tilingOptions.tileSizes = parallelTaskGrid;
pm.addPass(createSCFParallelLoopTiling(tilingOptions));

pm.addNestedPass<func::FuncOp>(createIntelAMXTileConfigInsertionPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createLoopInvariantCodeMotionPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createIntelAMXTileConfigHoistingPass());

pm.addPass(createConvertXsmmToFunc());
pm.addPass(createConvertPerfToFunc());
}
@@ -310,7 +324,8 @@ struct DefaultTppPasses
pm.addPass(createConvertForAllToParallelOp());

// Covert all local TPP-related dialects.
- pm.addPass(createLocalDialectsLowering());
+ LocalDialectsLoweringOptions localDialectsLowering{parallelTaskGrid};
+ pm.addPass(createLocalDialectsLowering(localDialectsLowering));

// Clean up after the default pipeline.
pm.addNestedPass<func::FuncOp>(createPostprocessing());
17 changes: 16 additions & 1 deletion lib/TPP/Dialect/Xsmm/XsmmOps.cpp
@@ -179,7 +179,7 @@ static void printerDataTypeImpl(OpAsmPrinter &printer, OpTy op) {

template <typename AttrTy>
static void printerFlagsImpl(OpAsmPrinter &printer,
- const std::function<ArrayAttr()>& fn,
+ const std::function<ArrayAttr()> &fn,
const std::string_view &flagsName) {
printer << " " << flagsName << " = (";
llvm::interleaveComma(fn(), printer, [&](auto &flag) {
@@ -235,6 +235,21 @@ void BinaryDispatchOp::print(OpAsmPrinter &printer) {
printerDataTypeImpl<BinaryDispatchOp>(printer, *this);
}

void IntelAMXTileConfigDispatchOp::print(OpAsmPrinter &printer) {
printerInputImpl<IntelAMXTileConfigDispatchOp>(printer, *this);
auto getOpFlags = [this]() -> ArrayAttr { return this->getFlags(); };
printerFlagsImpl<GemmFlagsAttr>(printer, getOpFlags, FLAGS_NAME);
printerDataTypeImpl<IntelAMXTileConfigDispatchOp>(printer, *this);
}

ParseResult IntelAMXTileConfigDispatchOp::parse(OpAsmParser &parser,
OperationState &result) {
if (failed(parseInputImpl(parser, result)) ||
failed(parserFlagsImpl<GemmFlags>(parser, result, FLAGS_NAME)))
return failure();
return parseDataTypeImpl(parser, result);
}

template <typename FLAGS>
static LogicalResult
verifyUniquenessAndConsistency(ArrayAttr flags, Operation *op,
2 changes: 2 additions & 0 deletions lib/TPP/Transforms/CMakeLists.txt
@@ -17,6 +17,8 @@ add_mlir_library(TPPTransforms
TransformUtils.cpp
CombineXsmmPass.cpp
SCFParallelLoopTiling.cpp
IntelAMXTileConfig.cpp
IntelAMXTileConfigHoisting.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/TPP