Skip to content

Commit

Permalink
[CPU] Switching to linalg::LinalgOp for MaterializeContractionOp patt…
Browse files Browse the repository at this point in the history
…ern (#18690)

It moves the logics to MaterializeContractionOp pattern; the other one
(i.e., MaterializeDPSOperation<linalg::GenericOp>) becomes
ShapeIndependent pattern. The former one is a CPU specfic pattern, and
the latter one is general after the shapes are resolved. It is not moved
to ShapeIndependent category because the support of tile swizzling is
not implemented yet. I need to think about the broadcast_map cases.

It should be an NFC, so there are no new tests.

Signed-off-by: hanhanW <hanhan0912@gmail.com>
  • Loading branch information
hanhanW authored Oct 4, 2024
1 parent 3801a5d commit b89ba05
Showing 1 changed file with 25 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -410,15 +410,6 @@ static FailureOr<Operation *> lowerGenericOpWithEncoding(
RewriterBase &rewriter, linalg::GenericOp genericOp,
ValueRange convertedInputOperands, ValueRange convertedOutputOperands,
const MaterializeEncodingTypeConverter &typeConverter) {
if (!genericOp.hasPureTensorSemantics()) {
return failure();
}
if (genericOp.getNumReductionLoops() != 0) {
return rewriter.notifyMatchFailure(genericOp, "Loops are not all parallel");
}
if (genericOp.getNumDpsInits() != 1) {
return rewriter.notifyMatchFailure(genericOp, "Not only 1 init operand");
}
OpOperand *outputOperand = genericOp.getDpsInitOperand(0);
AffineMap outputMap = genericOp.getMatchingIndexingMap(outputOperand);
if (!outputMap.isIdentity()) {
Expand Down Expand Up @@ -499,30 +490,29 @@ static FailureOr<Operation *> lowerGenericOpWithEncoding(
/// Utility method to convert from a linalg::LinalgOp on `tensor` types with
/// encodings to a linalg::LinalgOp on the materialized type. The current
/// supported op types are:
/// - linalg::LinalgOp that `isaContractionOpInterface`
/// - linalg::FillOp
/// - linalg::GenericOp with parallel iterators and a single output
/// - linalg::GenericOp
// - All the iterators are parallel iterators.
// - The op has a single output.
static FailureOr<Operation *>
lowerOpWithEncoding(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
ValueRange convertedInputOperands,
ValueRange convertedOutputOperands,
const MaterializeEncodingTypeConverter &typeConverter,
MaterializeEncodingValueFn) {
if (linalg::isaContractionOpInterface(linalgOp)) {
SmallVector<Value> operands;
operands.append(convertedInputOperands.begin(),
convertedInputOperands.end());
operands.append(convertedOutputOperands.begin(),
convertedOutputOperands.end());
return lowerContractionOpWithEncoding(rewriter, linalgOp, operands,
typeConverter);
if (!linalgOp.hasPureTensorSemantics()) {
return rewriter.notifyMatchFailure(linalgOp, "Not pure tensor semantics");
}
if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops()) {
return rewriter.notifyMatchFailure(linalgOp, "Loops are not all parallel");
}
if (linalgOp.getNumDpsInits() != 1) {
return rewriter.notifyMatchFailure(linalgOp, "Not only 1 init operand");
}

return TypeSwitch<Operation *, FailureOr<Operation *>>(linalgOp)
.Case<linalg::FillOp>(
[&](linalg::FillOp fillOp) -> FailureOr<Operation *> {
if (!fillOp.hasPureTensorSemantics())
return failure();
Operation *materializedFillOp = rewriter.create<linalg::FillOp>(
fillOp.getLoc(), convertedOutputOperands[0].getType(),
convertedInputOperands, convertedOutputOperands);
Expand Down Expand Up @@ -816,11 +806,6 @@ struct UnsetEncodingOpToUnPackOpConversion
};

/// Generic pattern to convert operation that is in Destination Passing Style.
/// TODO(hanchung): Implement a different pattern for non-elementwise
/// operations. Because they should implement their own patterns based on
/// backends. The elementwise operations are just like shape-like op in
/// data-tiling concept. They still have the same computation but with different
/// shapes.
template <typename OpTy>
struct MaterializeDPSOperation : public OpMaterializeEncodingPattern<OpTy> {
using OpMaterializeEncodingPattern<OpTy>::OpMaterializeEncodingPattern;
Expand Down Expand Up @@ -891,31 +876,30 @@ struct MaterializeOptimizationBarrierOp
};

/// Pattern to convert contraction operations.
class MaterializeContractionOp : public OpInterfaceConversionPattern<
mlir::linalg::ContractionOpInterface> {
class MaterializeContractionOp
: public OpInterfaceConversionPattern<linalg::LinalgOp> {
public:
MaterializeContractionOp(
MLIRContext *context,
const MaterializeEncodingTypeConverter &typeConverter,
MaterializeEncodingValueFn materializeEncodingValueFn = {},
PatternBenefit benefit = 1)
: OpInterfaceConversionPattern<mlir::linalg::ContractionOpInterface>(
typeConverter, context, benefit),
: OpInterfaceConversionPattern<linalg::LinalgOp>(typeConverter, context,
benefit),
materializeEncodingValueFn(materializeEncodingValueFn) {}

LogicalResult
matchAndRewrite(mlir::linalg::ContractionOpInterface op,
ArrayRef<Value> operands,
matchAndRewrite(linalg::LinalgOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (!linalg::isaContractionOpInterface(op)) {
return rewriter.notifyMatchFailure(
op, "does not implement ContractionOpInterface");
}

auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
this->getTypeConverter());
auto linalgOp = dyn_cast<linalg::LinalgOp>(op.getOperation());
if (!linalgOp || operands.size() != 3) {
return failure();
}
FailureOr<Operation *> convertedOp = lowerOpWithEncoding(
rewriter, linalgOp, operands.take_front(2), operands.take_back(1),
*converter, this->materializeEncodingValueFn);
FailureOr<Operation *> convertedOp =
lowerContractionOpWithEncoding(rewriter, op, operands, *converter);
if (failed(convertedOp)) {
return failure();
}
Expand All @@ -934,6 +918,8 @@ void populateMaterializeEncodingIntoPackUnPackPatterns(
MaterializeEncodingTypeConverter &typeConverter,
MaterializeEncodingValueFn materializeEncodingValueFn) {
MLIRContext *context = patterns.getContext();
// TODO(hanchung): Move the generic op pattern to ShapeIndependent category
// after we add the support for tile swizzling variants.
patterns.insert<MaterializeDPSOperation<linalg::GenericOp>,
MaterializeContractionOp, SetEncodingOpToPackOpConversion,
UnsetEncodingOpToUnPackOpConversion>(
Expand Down

0 comments on commit b89ba05

Please sign in to comment.