diff --git a/lib/Transforms/FlattenMemRefs.cpp b/lib/Transforms/FlattenMemRefs.cpp
index 20cc5e1291ca..672073dc3098 100644
--- a/lib/Transforms/FlattenMemRefs.cpp
+++ b/lib/Transforms/FlattenMemRefs.cpp
@@ -24,6 +24,7 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
 
 namespace circt {
@@ -46,6 +47,21 @@ struct FunctionRewrite {
   FunctionType type;
 };
 
+static std::atomic<unsigned> globalCounter(0);
+static DenseMap<StringAttr, StringAttr> globalNameMap;
+
+static MemRefType getFlattenedMemRefType(MemRefType type) {
+  return MemRefType::get(SmallVector<int64_t>{type.getNumElements()},
+                         type.getElementType());
+}
+
+static std::string getFlattenedMemRefName(StringAttr baseName,
+                                          MemRefType type) {
+  unsigned uniqueID = globalCounter++;
+  return llvm::formatv("{0}_{1}x{2}_{3}", baseName, type.getNumElements(),
+                       type.getElementType(), uniqueID);
+}
+
 // Flatten indices by generating the product of the i'th index and the [0:i-1]
 // shapes, for each index, and then summing these.
 static Value flattenIndices(ConversionPatternRewriter &rewriter, Operation *op,
@@ -154,13 +170,74 @@ struct AllocOpConversion : public OpConversionPattern<memref::AllocOp> {
     MemRefType type = op.getType();
     if (isUniDimensional(type) || !type.hasStaticShape())
       return failure();
-    MemRefType newType = MemRefType::get(
-        SmallVector<int64_t>{type.getNumElements()}, type.getElementType());
+    MemRefType newType = getFlattenedMemRefType(type);
     rewriter.replaceOpWithNewOp<memref::AllocOp>(op, newType);
     return success();
   }
 };
 
+struct GlobalOpConversion : public OpConversionPattern<memref::GlobalOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::GlobalOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MemRefType type = op.getType();
+    if (isUniDimensional(type) || !type.hasStaticShape())
+      return failure();
+    MemRefType newType = getFlattenedMemRefType(type);
+
+    auto cstAttr =
+        llvm::dyn_cast_or_null<DenseElementsAttr>(op.getConstantInitValue());
+
+    SmallVector<Attribute> flattenedVals;
+    for (auto attr : cstAttr.getValues<Attribute>())
+      flattenedVals.push_back(attr);
+
+    auto newTypeAttr = TypeAttr::get(newType);
+    auto newNameStr = getFlattenedMemRefName(op.getConstantAttrName(), type);
+    auto newName = rewriter.getStringAttr(newNameStr);
+    globalNameMap[op.getSymNameAttr()] = newName;
+
+    RankedTensorType tensorType = RankedTensorType::get(
+        {static_cast<int64_t>(flattenedVals.size())}, type.getElementType());
+    auto newInitValue = DenseElementsAttr::get(tensorType, flattenedVals);
+
+    rewriter.replaceOpWithNewOp<memref::GlobalOp>(
+        op, newName, op.getSymVisibilityAttr(), newTypeAttr, newInitValue,
+        op.getConstantAttr(), op.getAlignmentAttr());
+
+    return success();
+  }
+};
+
+struct GetGlobalOpConversion : public OpConversionPattern<memref::GetGlobalOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto *symbolTableOp = op->getParentWithTrait<mlir::OpTrait::SymbolTable>();
+    auto globalOp = dyn_cast_or_null<memref::GlobalOp>(
SymbolTable::lookupSymbolIn(symbolTableOp, op.getNameAttr()));
+
+    MemRefType type = globalOp.getType();
+    if (isUniDimensional(type) || !type.hasStaticShape())
+      return failure();
+
+    MemRefType newType = getFlattenedMemRefType(type);
+    auto originalName = globalOp.getSymNameAttr();
+    auto newNameIt = globalNameMap.find(originalName);
+    if (newNameIt == globalNameMap.end())
+      return failure();
+    auto newName = newNameIt->second;
+
+    rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, newType, newName);
+
+    return success();
+  }
+};
+
 // A generic pattern which will replace an op with a new op of the same type
 // but using the adaptor (type converted) operands.
 template <typename TOp>
@@ -256,7 +333,10 @@ static void populateFlattenMemRefsLegality(ConversionTarget &target) {
       [](memref::StoreOp op) { return op.getIndices().size() == 1; });
   target.addDynamicallyLegalOp<memref::LoadOp>(
       [](memref::LoadOp op) { return op.getIndices().size() == 1; });
-
+  target.addDynamicallyLegalOp<memref::GlobalOp>(
+      [](memref::GlobalOp op) { return isUniDimensional(op.getType()); });
+  target.addDynamicallyLegalOp<memref::GetGlobalOp>(
+      [](memref::GetGlobalOp op) { return isUniDimensional(op.getType()); });
 
   addGenericLegalityConstraint(target);
@@ -323,6 +403,7 @@ struct FlattenMemRefPass
     RewritePatternSet patterns(ctx);
     SetVector<StringRef> rewrittenCallees;
     patterns.add<LoadOpConversion, StoreOpConversion, AllocOpConversion,
+                 GlobalOpConversion, GetGlobalOpConversion,
                  OperandConversionPattern,
                  CondBranchOpConversion,
diff --git a/test/Transforms/flatten_memref.mlir b/test/Transforms/flatten_memref.mlir
index 573b8bf72f38..ccc2cd566706 100644
--- a/test/Transforms/flatten_memref.mlir
+++ b/test/Transforms/flatten_memref.mlir
@@ -185,3 +185,89 @@ func.func @dealloc_copy(%arg : memref<4x4xi32>) -> memref<4x4xi32> {
   memref.dealloc %0 : memref<4x4xi32>
   return %0 : memref<4x4xi32>
 }
+
+// -----
+
+module {
+  // CHECK-LABEL: memref.global "private" constant @constant_10xf32_0 : memref<10xf32> = dense<[0.433561265, 0.0884729773, -0.39487046, -0.190938368, 0.705071926, -0.648731529, -0.00710275536, -0.278010637, -0.573243499, 5.029220e-01]> {alignment = 64 : i64}
+  memref.global
"private" constant @__constant_5x2xf32 : memref<5x2xf32> = dense<[[0.433561265, 0.0884729773], [-0.39487046, -0.190938368], [0.705071926, -0.648731529], [-0.00710275536, -0.278010637], [-0.573243499, 5.029220e-01]]> {alignment = 64 : i64} + + // CHECK-LABEL: func.func @forward() -> f32 { + // CHECK: %[[VAL_0:.*]] = arith.constant 2 : index + // CHECK: %[[VAL_1:.*]] = arith.constant 1 : index + // CHECK: %[[VAL_2:.*]] = memref.get_global @constant_10xf32_0 : memref<10xf32> + // CHECK: %[[VAL_3:.*]] = arith.constant 1 : index + // CHECK: %[[VAL_4:.*]] = arith.shli %[[VAL_0]], %[[VAL_3]] : index + // CHECK: %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_1]] : index + // CHECK: %[[VAL_6:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_5]]] : memref<10xf32> + // CHECK: return %[[VAL_6]] : f32 + // CHECK: } + // CHECK: } + func.func @forward() -> f32 { + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + %0 = memref.get_global @__constant_5x2xf32 : memref<5x2xf32> + %1 = memref.load %0[%c2, %c1] : memref<5x2xf32> + return %1 :f32 + } +} + +// GlobalOp/GetGlobalOp may result in name conflict after flattening + +module { + // CHECK-LABEL: module { + // CHECK: memref.global "private" constant @__constant_1xf32 : memref<1xf32> = dense<-0.344258487> {alignment = 64 : i64} + // CHECK: memref.global "private" constant @constant_2xf32_1 : memref<2xf32> = dense<[-0.154929623, 0.142687559]> {alignment = 64 : i64} + // CHECK: memref.global "private" constant @__constant_2xf32 : memref<2xf32> = dense<[-0.23427248, 0.918611288]> {alignment = 64 : i64} + // CHECK: memref.global "private" constant @constant_2xf32_2 : memref<2xf32> = dense<[0.764538527, 0.83000791]> {alignment = 64 : i64} + memref.global "private" constant @__constant_1xf32 : memref<1xf32> = dense<-0.344258487> {alignment = 64 : i64} + memref.global "private" constant @__constant_1x2xf32 : memref<1x2xf32> = dense<[[-0.154929623, 0.142687559]]> {alignment = 64 : i64} + memref.global "private" constant 
@__constant_2xf32 : memref<2xf32> = dense<[-0.23427248, 0.918611288]> {alignment = 64 : i64} + memref.global "private" constant @__constant_2x1xf32 : memref<2x1xf32> = dense<[[0.764538527], [0.83000791]]> {alignment = 64 : i64} + + // CHECK: func.func @main(%[[VAL_0:.*]]: memref<2xf32>, %[[VAL_1:.*]]: memref<1xf32>) { + // CHECK: %[[VAL_2:.*]] = arith.constant 2 : index + // CHECK: %[[VAL_3:.*]] = arith.constant 1 : index + // CHECK: %[[VAL_4:.*]] = arith.constant 0 : index + // CHECK: %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32 + // CHECK: %[[VAL_6:.*]] = memref.get_global @constant_2xf32_2 : memref<2xf32> + // CHECK: %[[VAL_7:.*]] = memref.get_global @__constant_2xf32 : memref<2xf32> + // CHECK: %[[VAL_8:.*]] = memref.get_global @constant_2xf32_1 : memref<2xf32> + // CHECK: %[[VAL_9:.*]] = memref.get_global @__constant_1xf32 : memref<1xf32> + // CHECK: %[[VAL_10:.*]] = arith.constant 0 : index + // CHECK: %[[VAL_11:.*]] = arith.shli %[[VAL_3]], %[[VAL_10]] : index + // CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_4]] : index + // CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<2xf32> + // CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<2xf32> + // CHECK: %[[VAL_15:.*]] = arith.constant 1 : index + // CHECK: %[[VAL_16:.*]] = arith.shli %[[VAL_4]], %[[VAL_15]] : index + // CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_3]] : index + // CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref<2xf32> + // CHECK: %[[VAL_19:.*]] = arith.mulf %[[VAL_13]], %[[VAL_14]] : f32 + // CHECK: %[[VAL_20:.*]] = arith.addf %[[VAL_18]], %[[VAL_19]] : f32 + // CHECK: memref.store %[[VAL_20]], %[[VAL_9]]{{\[}}%[[VAL_4]]] : memref<1xf32> + // CHECK: memref.copy %[[VAL_9]], %[[VAL_1]] : memref<1xf32> to memref<1xf32> + // CHECK: return + // CHECK: } + // CHECK: } + + func.func @main(%arg0: memref<2x1xf32>, %arg1: memref<1xf32>) { + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index 
+ %c0 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = memref.get_global @__constant_2x1xf32 : memref<2x1xf32> + %1 = memref.get_global @__constant_2xf32 : memref<2xf32> + %2 = memref.get_global @__constant_1x2xf32 : memref<1x2xf32> + %3 = memref.get_global @__constant_1xf32 : memref<1xf32> + %4 = memref.load %0[%c1, %c0] : memref<2x1xf32> + %5 = memref.load %1[%c0] : memref<2xf32> + %6 = memref.load %2[%c0, %c1] : memref<1x2xf32> + %7 = arith.mulf %4, %5 : f32 + %8 = arith.addf %6, %7 : f32 + memref.store %8, %3[%c0] : memref<1xf32> + memref.copy %3, %arg1 : memref<1xf32> to memref<1xf32> + return + } +} +