minor code changes and test updates
KavithaTipturMadhu committed Oct 14, 2024
1 parent ff762d9 commit 853e9e4
Showing 14 changed files with 97 additions and 278 deletions.
5 changes: 4 additions & 1 deletion include/TPP/PassBundles.td
@@ -37,7 +37,10 @@ def DefaultTppPasses : Pass<"default-tpp-passes", "ModuleOp"> {
"unsigned", "Grid-sizes for parallel tasks.">,
Option<"lowerPackUnpackWithoutTranspose", "lower-pack-unpack-without-transpose",
"bool", /*default=*/"false",
"Lower non-constant packs and unpacks reverting any dim permutations.">
"Lower non-constant packs and unpacks reverting any dim permutations.">,
Option<"contractToOuterProduct", "contract-to-outer-product",
"bool",/*default=*/"false",
"Convert Contractions to outer product operations.">
];
}
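For reference, a minimal sketch of how this new bundle option can be enabled programmatically, mirroring the options-struct usage later in this commit. The surrounding function and include path are assumptions, not code from this repository:

#include "TPP/PassBundles.h" // assumed header for the generated declarations
#include "mlir/Pass/PassManager.h"

// Hypothetical driver: enable the outer-product lowering in the default bundle.
void buildTppPipeline(mlir::PassManager &pm) {
  mlir::tpp::DefaultTppPassesOptions options;
  options.contractToOuterProduct = true; // C++ name of the option defined above
  pm.addPass(mlir::tpp::createDefaultTppPasses(options));
}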

25 changes: 20 additions & 5 deletions lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.cpp
@@ -115,11 +115,14 @@ buildBrgemm(PatternRewriter &rewriter, Operation *contractOp, Value input0,
int64_t strideA = brgemmInfo.strideA;
int64_t strideB = brgemmInfo.strideB;
auto loc = contractOp->getLoc();
auto functionOp = contractOp->getParentOfType<func::FuncOp>();
auto dtype =
xsmm::utils::getDataType(rewriter, contractOp->getOperand(0).getType());
IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64);
SmallVector<Value, 10> dispatchOperands;
SmallVector<Type, 10> dispatchOperandTypes;
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(&*functionOp.getBody().op_begin());
// Dispatch the data type.
dispatchOperands.push_back(rewriter.create<arith::ConstantOp>(
loc, integer64, cast<TypedAttr>(dtype)));
@@ -154,7 +157,6 @@ buildBrgemm(PatternRewriter &rewriter, Operation *contractOp, Value input0,
dispatchOperandTypes.push_back(integer64);
}
int64_t oredFlag = xsmm::utils::getOredFlags(brgemmFlags);

dispatchOperands.push_back(rewriter.create<arith::ConstantOp>(
loc, integer64, IntegerAttr::get(rewriter.getI64Type(), oredFlag)));
dispatchOperandTypes.push_back(integer64);
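For context, xsmm::utils::getOredFlags (used just above) folds the per-kernel flag attributes into the single integer word passed to the dispatch call. A standalone sketch of that folding, inferred from the helper's name and usage rather than from its implementation:

#include <cstdint>
#include <initializer_list>

// Fold individual flag bit-masks into one 64-bit word (illustrative only).
int64_t oredFlags(std::initializer_list<int64_t> flags) {
  int64_t result = 0;
  for (int64_t flag : flags)
    result |= flag; // each flag is assumed to occupy distinct bits
  return result;
}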
@@ -173,6 +175,7 @@ buildBrgemm(PatternRewriter &rewriter, Operation *contractOp, Value input0,
loc, integer64, rewriter.getIntegerAttr(integer64, batch));
operandRange.push_back(batchDim);
}
rewriter.setInsertionPoint(contractOp);
auto invokeCall = xsmm::utils::buildInvokeCall(
rewriter, loc, module, operandRange, invokeName, dtype);
return std::make_pair(&*dispatched, &*invokeCall);
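The recurring change in this file: dispatch-side IR is now emitted at the function entry while the invoke call stays at the original operation, with an OpBuilder::InsertionGuard restoring the rewriter's state afterwards. A condensed sketch of the idiom, assuming MLIR's PatternRewriter API; emitDispatch and emitInvoke are illustrative placeholders for the dispatch- and invoke-building code:

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"

// Placeholders standing in for the dispatch/invoke construction above.
void emitDispatch(mlir::PatternRewriter &rewriter, mlir::Operation *op);
void emitInvoke(mlir::PatternRewriter &rewriter, mlir::Operation *op);

void lowerWithHoistedDispatch(mlir::PatternRewriter &rewriter,
                              mlir::Operation *op) {
  auto funcOp = op->getParentOfType<mlir::func::FuncOp>();
  // The guard restores the previous insertion point on scope exit.
  mlir::OpBuilder::InsertionGuard guard(rewriter);
  // Emit the dispatch at the function entry so it sits outside any loops.
  rewriter.setInsertionPoint(&*funcOp.getBody().op_begin());
  emitDispatch(rewriter, op);
  // Switch back to the original operation for the invoke.
  rewriter.setInsertionPoint(op);
  emitInvoke(rewriter, op);
}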
@@ -262,12 +265,16 @@ buildFusedBrgemm(PatternRewriter &rewriter, Operation *contractOp, Value input0,
IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64);
SmallVector<Value, 10> dispatchOperands;
SmallVector<Type, 10> dispatchOperandTypes;
auto functionOp = contractOp->getParentOfType<func::FuncOp>();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(&*functionOp.getBody().op_begin());
// Dispatch the data type.
dispatchOperands.push_back(rewriter.create<arith::ConstantOp>(
loc, integer64, cast<TypedAttr>(dtype)));
dispatchOperandTypes.push_back(integer64);
std::string dispatchName = "xsmm_fused_brgemm_dispatch";
std::string invokeName = "xsmm_fused_brgemm_invoke";

// TODO: Support more than just COL_0 BCAST
auto addf = addfTransferWrite->getOperand(0).getDefiningOp();
auto broadcastInput =
@@ -311,6 +318,7 @@ buildFusedBrgemm(PatternRewriter &rewriter, Operation *contractOp, Value input0,
dispatchOperands.push_back(rewriter.create<arith::ConstantOp>(
loc, integer64, IntegerAttr::get(rewriter.getI64Type(), oredFlag)));
dispatchOperandTypes.push_back(integer64);

dispatchOperands.push_back(rewriter.create<arith::ConstantOp>(
loc, integer64,
cast<TypedAttr>(xsmm::UnaryFlagsAttr::get(rewriter.getContext(),
@@ -324,6 +332,7 @@ buildFusedBrgemm(PatternRewriter &rewriter, Operation *contractOp, Value input0,
dispatchOperandTypes.push_back(integer64);

ModuleOp module = contractOp->getParentOfType<ModuleOp>();

auto dispatched = xsmm::utils::buildDispatchCall(
rewriter, loc, dispatchOperands, dispatchOperandTypes, module,
SymbolRefAttr::get(contractOp->getContext(), dispatchName));
@@ -339,6 +348,7 @@ buildFusedBrgemm(PatternRewriter &rewriter, Operation *contractOp, Value input0,
loc, integer64, rewriter.getIntegerAttr(integer64, batch));

operandRange.push_back(batchDim);
rewriter.setInsertionPoint(contractOp);
auto invokeCall = xsmm::utils::buildInvokeCall(
rewriter, loc, module, operandRange, invokeName, dtype);

@@ -414,6 +424,9 @@ buildTransposeOp(PatternRewriter &rewriter, Operation *transposeOp,
// Adjust ldo based on the VNNI factor.
unaryInfo.ldo =
stridesOnOutput->front() / *vnni::utils::getVnniBlockingFactor(output);
auto functionOp = transposeOp->getParentOfType<func::FuncOp>();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(&*functionOp.getBody().op_begin());

// If `OpTy` is unary or binary we need to dispatch an extra
// integer for the kind of operation to invoke.
@@ -466,7 +479,7 @@ buildTransposeOp(PatternRewriter &rewriter, Operation *transposeOp,
rewriter.create<func::FuncOp>(loc, fnName.getValue(), libFnType);
funcOp.setPrivate();
}

rewriter.setInsertionPoint(transposeOp);
auto invokeCall = rewriter.create<func::CallOp>(
loc, fnName.getValue(), TypeRange(),
xsmm::utils::getOperands(rewriter, loc, transposeOp->getOperands(),
@@ -477,8 +490,10 @@ buildTransposeOp(PatternRewriter &rewriter, Operation *transposeOp,
auto unaryInfo = xsmm::utils::getUnaryInfo(transposeOp->getOperand(0),
transposeOp->getResult(0),
xsmm::UnaryFlags::NONE);
// If `OpTy` is unary or binary we need to dispatch an extra
// integer for the kind of operation to invoke.
auto functionOp = transposeOp->getParentOfType<func::FuncOp>();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(&*functionOp.getBody().op_begin());

dispatchOperands.push_back(rewriter.create<arith::ConstantOp>(
loc, integer64, cast<TypedAttr>(dtype)));
dispatchOperandTypes.push_back(integer64);
@@ -532,7 +547,7 @@ buildTransposeOp(PatternRewriter &rewriter, Operation *transposeOp,
rewriter.create<func::FuncOp>(loc, fnName.getValue(), libFnType);
funcOp.setPrivate();
}

rewriter.setInsertionPoint(transposeOp);
auto invokeCall = rewriter.create<func::CallOp>(
loc, fnName.getValue(), TypeRange(),
xsmm::utils::getOperands(rewriter, loc, transposeOp->getOperands(),
11 changes: 9 additions & 2 deletions lib/TPP/DefaultPipeline.cpp
@@ -57,6 +57,12 @@ llvm::cl::opt<bool> lowerPackUnpackWithoutTranspose(
llvm::cl::desc("Lower packs and unpacks reverting any dim permutations"),
llvm::cl::init(false));

// Control contraction lowering.
llvm::cl::opt<bool> contractToOuterProduct(
"contract-to-outer-product",
llvm::cl::desc("Convert Contractions to Outer Product operations"),
llvm::cl::init(false));

namespace mlir {
namespace tpp {
#define GEN_PASS_DEF_DEFAULTPIPELINE
@@ -127,8 +133,9 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase<DefaultPipeline>,
pm.addPass(createGpuPipeline(GpuPipelineOptions{gpuBackend}));
} else {
// Apply the default preprocessing pass
DefaultTppPassesOptions tppDefaultOptions{
linalgToLoops, parallelTaskGrid, lowerPackUnpackWithoutTranspose};
DefaultTppPassesOptions tppDefaultOptions{linalgToLoops, parallelTaskGrid,
lowerPackUnpackWithoutTranspose,
contractToOuterProduct};
pm.addPass(createDefaultTppPasses(tppDefaultOptions));
}

13 changes: 8 additions & 5 deletions lib/TPP/DefaultTppPasses.cpp
@@ -101,17 +101,20 @@ struct DefaultTppPasses

pm.addNestedPass<func::FuncOp>(createVectorizationPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(createVectorContractToOuterproduct());

pm.addPass(createCleanup());
if (contractToOuterProduct) {
pm.addNestedPass<func::FuncOp>(createVectorContractToOuterproduct());
pm.addPass(createCleanup());
}
}

// Converting scf.forall to parallel loops should run after bufferization,
// as scf.parallel does not handle tensors.
pm.addPass(createConvertForAllToParallelOp());
LowLevelParallelizationOptions LowLevelParallelization{parallelTaskGrid};

pm.addPass(createConvertVectorToXsmm());
if (!contractToOuterProduct && !linalgToLoops) {
pm.addPass(createConvertVectorToXsmm());
pm.addPass(createLoopInvariantCodeMotionPass());
}
// Low level parallelization passes.
pm.addPass(createLowLevelParallelization(LowLevelParallelization));
// Convert all local TPP-related dialects.
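For intuition on the outer-product path gated above: a GEMM-style contraction C[M][N] += A[M][K] * B[K][N] can be computed as K rank-1 (outer-product) updates, which is the decomposition the VectorContractToOuterproduct pass targets at the vector level. A scalar sketch of the decomposition only, not of the pass's actual rewrite:

// C += A * B computed as K rank-1 updates (illustrative only).
void contractAsOuterProducts(int M, int N, int K, const float *A,
                             const float *B, float *C) {
  for (int k = 0; k < K; ++k)      // one outer product per reduction step
    for (int m = 0; m < M; ++m)
      for (int n = 0; n < N; ++n)
        C[m * N + n] += A[m * K + k] * B[k * N + n];
}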
2 changes: 2 additions & 0 deletions lib/TPP/Dialect/Xsmm/XsmmUtils.cpp
@@ -885,6 +885,8 @@ func::CallOp buildDispatchCall(RewriterBase &rewriter, Location loc,
funcOp.setPrivate();
}

rewriter.setInsertionPoint(dispatchOperands.back().getDefiningOp());

func::CallOp call = rewriter.create<func::CallOp>(
loc, fnName.getValue(), IntegerType::get(rewriter.getContext(), 64),
dispatchOperands);
1 change: 1 addition & 0 deletions lib/TPP/PassBundles/LowLevelParallelization.cpp
@@ -64,5 +64,6 @@ struct LowLevelParallelization
mlir::tpp::SCFParallelLoopTilingOptions tilingOptions;
tilingOptions.tileSizes = parallelTaskGrid;
pm.addPass(createSCFParallelLoopTiling(tilingOptions));
pm.addPass(createLoopInvariantCodeMotionPass());
}
};
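The LICM run added here complements the function-entry insertion points introduced in ConvertVectorToXsmm.cpp: a kernel-handle dispatch whose operands are constants does not depend on the surrounding loop, so it can be computed once while the per-iteration invoke stays inside. A plain-C++ analogue of what loop-invariant code motion does, illustrative only; dispatch and invoke are hypothetical stand-ins for the XSMM runtime calls:

#include <cstdint>

int64_t dispatch(int64_t config);      // hypothetical: returns a kernel handle
void invoke(int64_t handle, float *p); // hypothetical: runs the kernel

void beforeLICM(int64_t config, float **data, int n) {
  for (int i = 0; i < n; ++i) {
    int64_t h = dispatch(config); // invariant: same handle every iteration
    invoke(h, data[i]);
  }
}

void afterLICM(int64_t config, float **data, int n) {
  int64_t h = dispatch(config); // hoisted: computed once, outside the loop
  for (int i = 0; i < n; ++i)
    invoke(h, data[i]);
}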
2 changes: 0 additions & 2 deletions test/BF16/Integration/tpp-run-splat-shape.mlir
@@ -43,5 +43,3 @@ func.func @entry(%arg0: tensor<4x8x8x8xbf16>, %output: tensor<4x8x8x8xbf16>) ->
// CHECK-DAG: memref.global "private" constant @__constant_{{.*}}: memref<4x8x8x8xbf16>
// CHECK-DAG: memref.global "private" constant @__constant_{{.*}}: memref<8x8x4x8x2xbf16>
// CHECK: xsmm_brgemm_invoke
// CHECK: xsmm_binary_invoke
// CHECK: xsmm_unary_invoke
4 changes: 0 additions & 4 deletions test/Integration/copy.mlir
@@ -17,7 +17,6 @@
// IR-LABEL: copytppbrcast
func.func @copytppbrcast(%A: tensor<1x6xf32>) -> tensor<9x6xf32> {
%B = tensor.empty() : tensor<9x6xf32>
// IR: xsmm_unary_invoke
%O = linalg.generic { indexing_maps = [#map1, #map0],
iterator_types = ["parallel", "parallel"] }
ins(%A: tensor<1x6xf32>) outs(%B: tensor<9x6xf32>) {
@@ -30,7 +29,6 @@ func.func @copytppbrcast(%A: tensor<1x6xf32>) -> tensor<9x6xf32> {
// IR-LABEL: copytppbrcastother
func.func @copytppbrcastother(%A: tensor<6x1xf32>) -> tensor<6x9xf32> {
%B = tensor.empty() : tensor<6x9xf32>
// IR: xsmm_unary_invoke
%O = linalg.generic { indexing_maps = [#map2, #map0],
iterator_types = ["parallel", "parallel"] }
ins(%A: tensor<6x1xf32>) outs(%B: tensor<6x9xf32>) {
@@ -43,7 +41,6 @@ func.func @copytppbrcastother(%A: tensor<6x1xf32>) -> tensor<6x9xf32> {
// IR-LABEL: copyscalar
func.func @copyscalar(%A: f32) -> tensor<6x9xf32> {
%B = tensor.empty() : tensor<6x9xf32>
// IR: linalg.fill
%O = linalg.generic { indexing_maps = [#map3, #map0],
iterator_types = ["parallel", "parallel"] }
ins(%A: f32) outs(%B: tensor<6x9xf32>) {
@@ -57,7 +54,6 @@ func.func @copyscalar(%A: f32) -> tensor<6x9xf32> {
func.func @copyscalarother(%A: tensor<1x1xf32>) -> tensor<6x9xf32> {
%B = tensor.empty() : tensor<6x9xf32>
// A rank-0 input is not matched to xsmm.
// IR: linalg.generic
%O = linalg.generic { indexing_maps = [#map4, #map0],
iterator_types = ["parallel", "parallel"] }
ins(%A: tensor<1x1xf32>) outs(%B: tensor<6x9xf32>) {
4 changes: 3 additions & 1 deletion test/Integration/tpp-relu.mlir
@@ -13,7 +13,9 @@
// IR-LABEL: relutpp
func.func @relutpp(%A: tensor<9x6xf32>) -> tensor<9x6xf32> {
%c0 = arith.constant 0.0 : f32
// IR: xsmm_unary_invoke
// IR: vector.transfer_read
// IR: arith.maximumf
// IR: vector.transfer_write
%O = linalg.generic { indexing_maps = [#map0], iterator_types = ["parallel", "parallel"] }
outs(%A: tensor<9x6xf32>) {
^bb0(%a: f32):
14 changes: 11 additions & 3 deletions test/Integration/xsmm-fusion.mlir
@@ -48,9 +48,17 @@ func.func @entry(%A: tensor<2x4x8xf32>,
// CHECK-DAG: %[[c0_i64:.*]] = arith.constant 0 : i64
// CHECK-DAG: %[[c5_i64:.*]] = arith.constant 5 : i64
// CHECK-DAG: %[[c2_i64:.*]] = arith.constant 2 : i64
// CHECK: %[[DISPATCH:.*]] = call @xsmm_fused_brgemm_dispatch(%[[c1_i64]], %[[c4_i64]], %[[c4_i64]], %[[c8_i64]], %[[c8_i64]], %[[c4_i64]], %[[c4_i64]], %[[c32_i64]], %[[c32_i64]], %[[c4_i64]], %[[c0_i64]], %[[c5_i64]], %[[c4_i64]], %[[c1_i64]])
// CHECK: call @xsmm_fused_brgemm_invoke(%[[c1_i64]], %[[DISPATCH]], %{{.*}}, %[[c0]], %{{.*}}, %[[c0]], %{{.*}}, %[[c0]], %{{.*}}, %[[c0]], %[[c2_i64]])

// CHECK: %[[DISPATCH:.*]] = call @xsmm_fused_brgemm_dispatch(%[[c1_i64]], %[[c4_i64]], %[[c4_i64]], %[[c8_i64]], %[[c8_i64]], %[[c4_i64]], %[[c4_i64]], %[[c32_i64]], %[[c32_i64]], %[[c4_i64]], %[[c0_i64]], %[[c5_i64]])
// CHECK-DAG: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<2x4x8xf32> -> index
// CHECK-DAG: %[[index:.*]] = arith.index_cast %[[intptr]] : index to i64
// CHECK-DAG: %[[ptr:.*]] = llvm.inttoptr %[[index]] : i64 to !llvm.ptr
// CHECK-DAG: %[[intptr_0:.*]] = memref.extract_aligned_pointer_as_index {{.*}}: memref<2x8x4xf32> -> index
// CHECK-DAG: %[[intptr1:.*]] = arith.index_cast %[[intptr_0]] : index to i64
// CHECK-DAG: %[[inttoptr:.*]] = llvm.inttoptr %[[intptr1]] : i64 to !llvm.ptr
// CHECK-DAG: %[[intptr_1:.*]] = memref.extract_aligned_pointer_as_index {{.*}} : memref<4x4xf32> -> index
// CHECK-DAG: %[[intptr2:.*]] = arith.index_cast %[[intptr_1]] : index to i64
// CHECK-DAG: %[[intptr3:.*]] = llvm.inttoptr %[[intptr2]] : i64 to !llvm.ptr
// CHECK: call @xsmm_fused_brgemm_invoke(%[[c1_i64]], %[[DISPATCH]], %[[ptr]], %[[c0]], %[[inttoptr]], %[[c0]], %[[intptr3]], %[[c0]], %[[intptr3]], %[[c0]], %[[c2_i64]])
// RESULT: ( 3.62953, 3.87851, 3.65424, 3.69154 )
// RESULT: ( 1.34322, 1.59219, 1.36792, 1.40522 )
// RESULT: ( 0.766812, 1.01579, 0.791519, 0.828817 )
4 changes: 2 additions & 2 deletions test/Passes/DefaultPipeline/default-pipeline.mlir
@@ -6,8 +6,8 @@ func.func @matmul(%A: tensor<4x8xf32>,
return %D : tensor<4x4xf32>
}

// CHECK: llvm.func @xsmm_gemm_invoke
// CHECK: llvm.func @xsmm_gemm_dispatch
// CHECK-DAG: llvm.func @xsmm_gemm_invoke
// CHECK-DAG: llvm.func @xsmm_gemm_dispatch
// CHECK: llvm.func @matmul(%[[ARG0:.+]]: !llvm.ptr,
// CHECK: llvm.insertvalue
// CHECK: llvm.mlir.constant