From b89ba05f1b260b42b62a79f402a3ef323c592a03 Mon Sep 17 00:00:00 2001 From: Han-Chung Wang Date: Fri, 4 Oct 2024 09:11:00 -0700 Subject: [PATCH] [CPU] Switching to linalg::LinalgOp for MaterializeContractionOp pattern (#18690) It moves the logics to MaterializeContractionOp pattern; the other one (i.e., MaterializeDPSOperation) becomes ShapeIndependent pattern. The former one is a CPU specfic pattern, and the latter one is general after the shapes are resolved. It is not moved to ShapeIndependent category because the support of tile swizzling is not implemented yet. I need to think about the broadcast_map cases. It should be an NFC, so there are no new tests. Signed-off-by: hanhanW --- .../MaterializeEncodingIntoPackUnPack.cpp | 64 ++++++++----------- 1 file changed, 25 insertions(+), 39 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp index 351bedd0c76c..b2d21ccc4a96 100644 --- a/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp @@ -410,15 +410,6 @@ static FailureOr lowerGenericOpWithEncoding( RewriterBase &rewriter, linalg::GenericOp genericOp, ValueRange convertedInputOperands, ValueRange convertedOutputOperands, const MaterializeEncodingTypeConverter &typeConverter) { - if (!genericOp.hasPureTensorSemantics()) { - return failure(); - } - if (genericOp.getNumReductionLoops() != 0) { - return rewriter.notifyMatchFailure(genericOp, "Loops are not all parallel"); - } - if (genericOp.getNumDpsInits() != 1) { - return rewriter.notifyMatchFailure(genericOp, "Not only 1 init operand"); - } OpOperand *outputOperand = genericOp.getDpsInitOperand(0); AffineMap outputMap = genericOp.getMatchingIndexingMap(outputOperand); if (!outputMap.isIdentity()) { @@ -499,30 +490,29 @@ static FailureOr lowerGenericOpWithEncoding( /// Utility method to convert from a linalg::LinalgOp on `tensor` types with /// encodings to a linalg::LinalgOp on the materialized type. The current /// supported op types are: -/// - linalg::LinalgOp that `isaContractionOpInterface` /// - linalg::FillOp -/// - linalg::GenericOp with parallel iterators and a single output +/// - linalg::GenericOp +// - All the iterators are parallel iterators. +// - The op has a single output. static FailureOr lowerOpWithEncoding(RewriterBase &rewriter, linalg::LinalgOp linalgOp, ValueRange convertedInputOperands, ValueRange convertedOutputOperands, const MaterializeEncodingTypeConverter &typeConverter, MaterializeEncodingValueFn) { - if (linalg::isaContractionOpInterface(linalgOp)) { - SmallVector operands; - operands.append(convertedInputOperands.begin(), - convertedInputOperands.end()); - operands.append(convertedOutputOperands.begin(), - convertedOutputOperands.end()); - return lowerContractionOpWithEncoding(rewriter, linalgOp, operands, - typeConverter); + if (!linalgOp.hasPureTensorSemantics()) { + return rewriter.notifyMatchFailure(linalgOp, "Not pure tensor semantics"); + } + if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops()) { + return rewriter.notifyMatchFailure(linalgOp, "Loops are not all parallel"); + } + if (linalgOp.getNumDpsInits() != 1) { + return rewriter.notifyMatchFailure(linalgOp, "Not only 1 init operand"); } return TypeSwitch>(linalgOp) .Case( [&](linalg::FillOp fillOp) -> FailureOr { - if (!fillOp.hasPureTensorSemantics()) - return failure(); Operation *materializedFillOp = rewriter.create( fillOp.getLoc(), convertedOutputOperands[0].getType(), convertedInputOperands, convertedOutputOperands); @@ -816,11 +806,6 @@ struct UnsetEncodingOpToUnPackOpConversion }; /// Generic pattern to convert operation that is in Destination Passing Style. -/// TODO(hanchung): Implement a different pattern for non-elementwise -/// operations. Because they should implement their own patterns based on -/// backends. The elementwise operations are just like shape-like op in -/// data-tiling concept. They still have the same computation but with different -/// shapes. template struct MaterializeDPSOperation : public OpMaterializeEncodingPattern { using OpMaterializeEncodingPattern::OpMaterializeEncodingPattern; @@ -891,31 +876,30 @@ struct MaterializeOptimizationBarrierOp }; /// Pattern to convert contraction operations. -class MaterializeContractionOp : public OpInterfaceConversionPattern< - mlir::linalg::ContractionOpInterface> { +class MaterializeContractionOp + : public OpInterfaceConversionPattern { public: MaterializeContractionOp( MLIRContext *context, const MaterializeEncodingTypeConverter &typeConverter, MaterializeEncodingValueFn materializeEncodingValueFn = {}, PatternBenefit benefit = 1) - : OpInterfaceConversionPattern( - typeConverter, context, benefit), + : OpInterfaceConversionPattern(typeConverter, context, + benefit), materializeEncodingValueFn(materializeEncodingValueFn) {} LogicalResult - matchAndRewrite(mlir::linalg::ContractionOpInterface op, - ArrayRef operands, + matchAndRewrite(linalg::LinalgOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { + if (!linalg::isaContractionOpInterface(op)) { + return rewriter.notifyMatchFailure( + op, "does not implement ContractionOpInterface"); + } + auto converter = static_cast( this->getTypeConverter()); - auto linalgOp = dyn_cast(op.getOperation()); - if (!linalgOp || operands.size() != 3) { - return failure(); - } - FailureOr convertedOp = lowerOpWithEncoding( - rewriter, linalgOp, operands.take_front(2), operands.take_back(1), - *converter, this->materializeEncodingValueFn); + FailureOr convertedOp = + lowerContractionOpWithEncoding(rewriter, op, operands, *converter); if (failed(convertedOp)) { return failure(); } @@ -934,6 +918,8 @@ void populateMaterializeEncodingIntoPackUnPackPatterns( MaterializeEncodingTypeConverter &typeConverter, MaterializeEncodingValueFn materializeEncodingValueFn) { MLIRContext *context = patterns.getContext(); + // TODO(hanchung): Move the generic op pattern to ShapeIndependent category + // after we add the support for tile swizzling variants. patterns.insert, MaterializeContractionOp, SetEncodingOpToPackOpConversion, UnsetEncodingOpToUnPackOpConversion>(