From 566124222e308bd0321c537c136705e1ebae7ba4 Mon Sep 17 00:00:00 2001 From: Dmitriy Smirnov Date: Thu, 11 Jan 2024 12:23:00 +0000 Subject: [PATCH] [TOSA] FFT2D operator (#77005) This PR adds lowering for TOSA Fft2d operator down to Linalg. --- .../Conversion/TosaToLinalg/TosaToLinalg.cpp | 130 +++++++++++++++++ .../TosaToLinalg/tosa-to-linalg.mlir | 134 ++++++++++++++++++ 2 files changed, 264 insertions(+) diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 678081837b8138..1e94dfd7feb94e 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -2344,6 +2344,135 @@ struct RFFT2dConverter final : public OpRewritePattern { } }; +struct FFT2dConverter final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(FFT2dOp fft2d, + PatternRewriter &rewriter) const override { + if (!llvm::all_of(fft2d->getOperandTypes(), + RFFT2dConverter::isRankedTensor) || + !llvm::all_of(fft2d->getResultTypes(), + RFFT2dConverter::isRankedTensor)) { + return rewriter.notifyMatchFailure(fft2d, "only supports ranked tensors"); + } + + Location loc = fft2d.getLoc(); + Value input_real = fft2d.getInputReal(); + Value input_imag = fft2d.getInputImag(); + BoolAttr inverse = fft2d.getInverseAttr(); + + auto real_el_ty = cast( + cast(input_real.getType()).getElementType()); + auto imag_el_ty = cast( + cast(input_imag.getType()).getElementType()); + + assert(real_el_ty == imag_el_ty); + + // Compute the output type and set of dynamic sizes + SmallVector dynamicSizes; + + // Get [N, H, W] + ArrayRef dims = + tensor::getMixedSizes(rewriter, loc, input_real); + + SmallVector staticSizes; + dispatchIndexOpFoldResults(dims, dynamicSizes, staticSizes); + + auto outputType = RankedTensorType::get(staticSizes, real_el_ty); + + // Iterator types for the linalg.generic implementation + SmallVector iteratorTypes = { + utils::IteratorType::parallel, utils::IteratorType::parallel, + utils::IteratorType::parallel, utils::IteratorType::reduction, + utils::IteratorType::reduction}; + + // Inputs/outputs to the linalg.generic implementation + SmallVector genericOpInputs = {input_real, input_imag}; + SmallVector genericOpOutputs = { + RFFT2dConverter::createZeroTensor(rewriter, loc, outputType, + dynamicSizes), + RFFT2dConverter::createZeroTensor(rewriter, loc, outputType, + dynamicSizes)}; + + // Indexing maps for input and output tensors + auto indexingMaps = AffineMap::inferFromExprList( + ArrayRef{RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4), + RFFT2dConverter::affineDimsExpr(rewriter, 0, 3, 4), + RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2), + RFFT2dConverter::affineDimsExpr(rewriter, 0, 1, 2)}); + + // Width and height dimensions of the original input. + auto dimH = rewriter.createOrFold(loc, input_real, 1); + auto dimW = rewriter.createOrFold(loc, input_real, 2); + + // Constants and dimension sizes + auto twoPiAttr = rewriter.getFloatAttr(real_el_ty, 6.283185307179586); + auto twoPi = rewriter.create(loc, twoPiAttr); + Value constH = + RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimH); + Value constW = + RFFT2dConverter::castIndexToFloat(rewriter, loc, real_el_ty, dimW); + + auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange args) { + Value valReal = args[0]; + Value valImag = args[1]; + Value sumReal = args[2]; + Value sumImag = args[3]; + + // Indices for angle computation + Value oy = + RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 1); + Value ox = + RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 2); + Value iy = + RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 3); + Value ix = + RFFT2dConverter::createLinalgIndex(builder, loc, real_el_ty, 4); + + // float_t angle = sign_val * 2 * pi() * ((iy * oy) / H + (ix * ox) / W); + auto iyXoy = builder.create(loc, iy, oy); + auto ixXox = builder.create(loc, ix, ox); + auto yComponent = builder.create(loc, iyXoy, constH); + auto xComponent = builder.create(loc, ixXox, constW); + auto sumXY = builder.create(loc, yComponent, xComponent); + auto angle = builder.create(loc, twoPi, sumXY); + if (inverse.getValue()) { + angle = builder.create( + loc, angle, + rewriter.create( + loc, rewriter.getFloatAttr(real_el_ty, -1.0))); + } + + // realComponent = val_real * cos(a) + val_imag * sin(a); + // imagComponent = -val_real * sin(a) + val_imag * cos(a); + auto cosAngle = builder.create(loc, angle); + auto sinAngle = builder.create(loc, angle); + + auto rcos = builder.create(loc, valReal, cosAngle); + auto rsin = builder.create(loc, valImag, sinAngle); + auto realComponent = builder.create(loc, rcos, rsin); + + auto icos = builder.create(loc, valImag, cosAngle); + auto isin = builder.create(loc, valReal, sinAngle); + + auto imagComponent = builder.create(loc, icos, isin); + + // outReal = sumReal + realComponent + // outImag = sumImag - imagComponent + auto outReal = builder.create(loc, sumReal, realComponent); + auto outImag = builder.create(loc, sumImag, imagComponent); + + builder.create(loc, ValueRange{outReal, outImag}); + }; + + rewriter.replaceOpWithNewOp( + fft2d, fft2d.getResultTypes(), genericOpInputs, genericOpOutputs, + indexingMaps, iteratorTypes, buildBody); + + return success(); + } +}; + } // namespace void mlir::tosa::populateTosaToLinalgConversionPatterns( @@ -2407,6 +2536,7 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns( RescaleConverter, ReverseConverter, RFFT2dConverter, + FFT2dConverter, TableConverter, TileConverter>(patterns->getContext()); // clang-format on diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 8a29752ff8d7f8..1f63b7d5ca6c8b 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -1739,3 +1739,137 @@ func.func @test_dynamic_rfft2d(%arg0: tensor) -> (tensor, %output_real, %output_imag = "tosa.rfft2d"(%arg0) {} : (tensor) -> (tensor, tensor) return %output_real, %output_imag : tensor, tensor } + +// ----- +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> + +// CHECK-LABEL: func.func @test_static_fft2d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8x8xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<8x8x8xf32>) -> (tensor<8x8x8xf32>, tensor<8x8x8xf32>) { +// CHECK: %[[VAL_2:.*]] = tensor.empty() : tensor<8x8x8xf32> +// CHECK: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_4:.*]] = linalg.fill ins(%[[VAL_3]] : f32) outs(%[[VAL_2]] : tensor<8x8x8xf32>) -> tensor<8x8x8xf32> +// CHECK: %[[VAL_5:.*]] = tensor.empty() : tensor<8x8x8xf32> +// CHECK: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_7:.*]] = linalg.fill ins(%[[VAL_6]] : f32) outs(%[[VAL_5]] : tensor<8x8x8xf32>) -> tensor<8x8x8xf32> +// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_9:.*]] = arith.constant 8 : index +// CHECK: %[[VAL_10:.*]] = arith.constant 2 : index +// CHECK: %[[VAL_11:.*]] = arith.constant 8 : index +// CHECK: %[[VAL_12:.*]] = arith.constant 6.28318548 : f32 +// CHECK: %[[VAL_13:.*]] = arith.index_castui %[[VAL_9]] : index to i32 +// CHECK: %[[VAL_14:.*]] = arith.uitofp %[[VAL_13]] : i32 to f32 +// CHECK: %[[VAL_15:.*]] = arith.index_castui %[[VAL_11]] : index to i32 +// CHECK: %[[VAL_16:.*]] = arith.uitofp %[[VAL_15]] : i32 to f32 +// CHECK: %[[VAL_17:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_1]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : tensor<8x8x8xf32>, tensor<8x8x8xf32>) outs(%[[VAL_4]], %[[VAL_7]] : tensor<8x8x8xf32>, tensor<8x8x8xf32>) { +// CHECK: ^bb0(%[[VAL_18:.*]]: f32, %[[VAL_19:.*]]: f32, %[[VAL_20:.*]]: f32, %[[VAL_21:.*]]: f32): +// CHECK: %[[VAL_22:.*]] = linalg.index 1 : index +// CHECK: %[[VAL_23:.*]] = arith.index_castui %[[VAL_22]] : index to i32 +// CHECK: %[[VAL_24:.*]] = arith.uitofp %[[VAL_23]] : i32 to f32 +// CHECK: %[[VAL_25:.*]] = linalg.index 2 : index +// CHECK: %[[VAL_26:.*]] = arith.index_castui %[[VAL_25]] : index to i32 +// CHECK: %[[VAL_27:.*]] = arith.uitofp %[[VAL_26]] : i32 to f32 +// CHECK: %[[VAL_28:.*]] = linalg.index 3 : index +// CHECK: %[[VAL_29:.*]] = arith.index_castui %[[VAL_28]] : index to i32 +// CHECK: %[[VAL_30:.*]] = arith.uitofp %[[VAL_29]] : i32 to f32 +// CHECK: %[[VAL_31:.*]] = linalg.index 4 : index +// CHECK: %[[VAL_32:.*]] = arith.index_castui %[[VAL_31]] : index to i32 +// CHECK: %[[VAL_33:.*]] = arith.uitofp %[[VAL_32]] : i32 to f32 +// CHECK: %[[VAL_34:.*]] = arith.mulf %[[VAL_30]], %[[VAL_24]] : f32 +// CHECK: %[[VAL_35:.*]] = arith.mulf %[[VAL_33]], %[[VAL_27]] : f32 +// CHECK: %[[VAL_36:.*]] = arith.divf %[[VAL_34]], %[[VAL_14]] : f32 +// CHECK: %[[VAL_37:.*]] = arith.divf %[[VAL_35]], %[[VAL_16]] : f32 +// CHECK: %[[VAL_38:.*]] = arith.addf %[[VAL_36]], %[[VAL_37]] : f32 +// CHECK: %[[VAL_39:.*]] = arith.mulf %[[VAL_12]], %[[VAL_38]] : f32 +// CHECK: %[[VAL_40:.*]] = math.cos %[[VAL_39]] : f32 +// CHECK: %[[VAL_41:.*]] = math.sin %[[VAL_39]] : f32 +// CHECK: %[[VAL_42:.*]] = arith.mulf %[[VAL_18]], %[[VAL_40]] : f32 +// CHECK: %[[VAL_43:.*]] = arith.mulf %[[VAL_19]], %[[VAL_41]] : f32 +// CHECK: %[[VAL_44:.*]] = arith.addf %[[VAL_42]], %[[VAL_43]] : f32 +// CHECK: %[[VAL_45:.*]] = arith.mulf %[[VAL_19]], %[[VAL_40]] : f32 +// CHECK: %[[VAL_46:.*]] = arith.mulf %[[VAL_18]], %[[VAL_41]] : f32 +// CHECK: %[[VAL_47:.*]] = arith.subf %[[VAL_45]], %[[VAL_46]] : f32 +// CHECK: %[[VAL_48:.*]] = arith.addf %[[VAL_20]], %[[VAL_44]] : f32 +// CHECK: %[[VAL_49:.*]] = arith.addf %[[VAL_21]], %[[VAL_47]] : f32 +// CHECK: linalg.yield %[[VAL_48]], %[[VAL_49]] : f32, f32 +// CHECK: } -> (tensor<8x8x8xf32>, tensor<8x8x8xf32>) +// CHECK: return %[[VAL_50:.*]]#0, %[[VAL_50]]#1 : tensor<8x8x8xf32>, tensor<8x8x8xf32> +// CHECK: } +func.func @test_static_fft2d(%arg0: tensor<8x8x8xf32>, %arg1: tensor<8x8x8xf32>) -> (tensor<8x8x8xf32>, tensor<8x8x8xf32>) { + %output_real, %output_imag = "tosa.fft2d"(%arg0, %arg1) {inverse=false} : (tensor<8x8x8xf32>, tensor<8x8x8xf32>) -> (tensor<8x8x8xf32>, tensor<8x8x8xf32>) + return %output_real, %output_imag : tensor<8x8x8xf32>, tensor<8x8x8xf32> +} + +// ----- +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py + +// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)> +// CHECK: #[[$ATTR_3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)> + +// CHECK-LABEL: func.func @test_dynamic_fft2d( +// CHECK-SAME: %[[VAL_0:.*]]: tensor, +// CHECK-SAME: %[[VAL_1:.*]]: tensor) -> (tensor, tensor) { +// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor +// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor +// CHECK: %[[VAL_6:.*]] = arith.constant 2 : index +// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_6]] : tensor +// CHECK: %[[VAL_8:.*]] = tensor.empty(%[[VAL_3]], %[[VAL_5]], %[[VAL_7]]) : tensor +// CHECK: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_10:.*]] = linalg.fill ins(%[[VAL_9]] : f32) outs(%[[VAL_8]] : tensor) -> tensor +// CHECK: %[[VAL_11:.*]] = tensor.empty(%[[VAL_3]], %[[VAL_5]], %[[VAL_7]]) : tensor +// CHECK: %[[VAL_12:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_13:.*]] = linalg.fill ins(%[[VAL_12]] : f32) outs(%[[VAL_11]] : tensor) -> tensor +// CHECK: %[[VAL_14:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_15:.*]] = tensor.dim %[[VAL_0]], %[[VAL_14]] : tensor +// CHECK: %[[VAL_16:.*]] = arith.constant 2 : index +// CHECK: %[[VAL_17:.*]] = tensor.dim %[[VAL_0]], %[[VAL_16]] : tensor +// CHECK: %[[VAL_18:.*]] = arith.constant 6.28318548 : f32 +// CHECK: %[[VAL_19:.*]] = arith.index_castui %[[VAL_15]] : index to i32 +// CHECK: %[[VAL_20:.*]] = arith.uitofp %[[VAL_19]] : i32 to f32 +// CHECK: %[[VAL_21:.*]] = arith.index_castui %[[VAL_17]] : index to i32 +// CHECK: %[[VAL_22:.*]] = arith.uitofp %[[VAL_21]] : i32 to f32 +// CHECK: %[[VAL_23:.*]]:2 = linalg.generic {indexing_maps = [#[[$ATTR_2]], #[[$ATTR_2]], #[[$ATTR_3]], #[[$ATTR_3]]], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%[[VAL_0]], %[[VAL_1]] : tensor, tensor) outs(%[[VAL_10]], %[[VAL_13]] : tensor, tensor) { +// CHECK: ^bb0(%[[VAL_24:.*]]: f32, %[[VAL_25:.*]]: f32, %[[VAL_26:.*]]: f32, %[[VAL_27:.*]]: f32): +// CHECK: %[[VAL_28:.*]] = linalg.index 1 : index +// CHECK: %[[VAL_29:.*]] = arith.index_castui %[[VAL_28]] : index to i32 +// CHECK: %[[VAL_30:.*]] = arith.uitofp %[[VAL_29]] : i32 to f32 +// CHECK: %[[VAL_31:.*]] = linalg.index 2 : index +// CHECK: %[[VAL_32:.*]] = arith.index_castui %[[VAL_31]] : index to i32 +// CHECK: %[[VAL_33:.*]] = arith.uitofp %[[VAL_32]] : i32 to f32 +// CHECK: %[[VAL_34:.*]] = linalg.index 3 : index +// CHECK: %[[VAL_35:.*]] = arith.index_castui %[[VAL_34]] : index to i32 +// CHECK: %[[VAL_36:.*]] = arith.uitofp %[[VAL_35]] : i32 to f32 +// CHECK: %[[VAL_37:.*]] = linalg.index 4 : index +// CHECK: %[[VAL_38:.*]] = arith.index_castui %[[VAL_37]] : index to i32 +// CHECK: %[[VAL_39:.*]] = arith.uitofp %[[VAL_38]] : i32 to f32 +// CHECK: %[[VAL_40:.*]] = arith.mulf %[[VAL_36]], %[[VAL_30]] : f32 +// CHECK: %[[VAL_41:.*]] = arith.mulf %[[VAL_39]], %[[VAL_33]] : f32 +// CHECK: %[[VAL_42:.*]] = arith.divf %[[VAL_40]], %[[VAL_20]] : f32 +// CHECK: %[[VAL_43:.*]] = arith.divf %[[VAL_41]], %[[VAL_22]] : f32 +// CHECK: %[[VAL_44:.*]] = arith.addf %[[VAL_42]], %[[VAL_43]] : f32 +// CHECK: %[[VAL_45:.*]] = arith.mulf %[[VAL_18]], %[[VAL_44]] : f32 +// CHECK: %[[VAL_46:.*]] = arith.constant -1.000000e+00 : f32 +// CHECK: %[[VAL_47:.*]] = arith.mulf %[[VAL_45]], %[[VAL_46]] : f32 +// CHECK: %[[VAL_48:.*]] = math.cos %[[VAL_47]] : f32 +// CHECK: %[[VAL_49:.*]] = math.sin %[[VAL_47]] : f32 +// CHECK: %[[VAL_50:.*]] = arith.mulf %[[VAL_24]], %[[VAL_48]] : f32 +// CHECK: %[[VAL_51:.*]] = arith.mulf %[[VAL_25]], %[[VAL_49]] : f32 +// CHECK: %[[VAL_52:.*]] = arith.addf %[[VAL_50]], %[[VAL_51]] : f32 +// CHECK: %[[VAL_53:.*]] = arith.mulf %[[VAL_25]], %[[VAL_48]] : f32 +// CHECK: %[[VAL_54:.*]] = arith.mulf %[[VAL_24]], %[[VAL_49]] : f32 +// CHECK: %[[VAL_55:.*]] = arith.subf %[[VAL_53]], %[[VAL_54]] : f32 +// CHECK: %[[VAL_56:.*]] = arith.addf %[[VAL_26]], %[[VAL_52]] : f32 +// CHECK: %[[VAL_57:.*]] = arith.addf %[[VAL_27]], %[[VAL_55]] : f32 +// CHECK: linalg.yield %[[VAL_56]], %[[VAL_57]] : f32, f32 +// CHECK: } -> (tensor, tensor) +// CHECK: return %[[VAL_58:.*]]#0, %[[VAL_58]]#1 : tensor, tensor +// CHECK: } +func.func @test_dynamic_fft2d(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %output_real, %output_imag = "tosa.fft2d"(%arg0, %arg1) {inverse = true} : (tensor, tensor) -> (tensor, tensor) + return %output_real, %output_imag : tensor, tensor +}