Skip to content

Commit

Permalink
Flatten memref Global and its corresponding GetGlobal operations (#7758)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiahanxie353 authored Oct 31, 2024
1 parent e1e10ae commit d22a695
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 3 deletions.
87 changes: 84 additions & 3 deletions lib/Transforms/FlattenMemRefs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"

namespace circt {
Expand All @@ -46,6 +47,21 @@ struct FunctionRewrite {
FunctionType type;
};

// Counter consumed by getFlattenedMemRefName: each call takes the current
// value as the numeric suffix of the generated name and then increments it.
static std::atomic<unsigned> globalCounter(0);
// Maps the symbol name of a rewritten memref.global to the symbol name of its
// flattened replacement. Written by GlobalOpConversion and read by
// GetGlobalOpConversion.
// NOTE(review): this is file-level mutable state and, unlike the counter, no
// synchronization is visible here. Entries are never erased in this chunk and
// the key is the bare symbol name (not qualified by the enclosing module) --
// confirm the pass cannot run concurrently or on several modules that reuse
// the same global names.
static DenseMap<StringAttr, StringAttr> globalNameMap;

/// Builds the one-dimensional counterpart of `type`: a memref whose single
/// dimension equals the total element count of `type` and whose element type
/// is unchanged. Only the shape and element type are forwarded to
/// MemRefType::get.
static MemRefType getFlattenedMemRefType(MemRefType type) {
  SmallVector<int64_t> flatShape{type.getNumElements()};
  Type elementType = type.getElementType();
  return MemRefType::get(flatShape, elementType);
}

/// Produces a fresh name for a flattened global, formatted as
/// `<baseName>_<numElements>x<elementType>_<id>`. `<id>` is taken from
/// `globalCounter`, which is advanced by one on every call, so two calls never
/// return the same string even for identical arguments.
static std::string getFlattenedMemRefName(StringAttr baseName,
                                          MemRefType type) {
  // fetch_add returns the value held before the increment, matching a
  // post-increment on the atomic.
  const unsigned suffix = globalCounter.fetch_add(1);
  return llvm::formatv("{0}_{1}x{2}_{3}", baseName, type.getNumElements(),
                       type.getElementType(), suffix)
      .str();
}

// Flatten indices by generating the product of the i'th index and the [0:i-1]
// shapes, for each index, and then summing these.
static Value flattenIndices(ConversionPatternRewriter &rewriter, Operation *op,
Expand Down Expand Up @@ -154,13 +170,74 @@ struct AllocOpConversion : public OpConversionPattern<memref::AllocOp> {
MemRefType type = op.getType();
if (isUniDimensional(type) || !type.hasStaticShape())
return failure();
MemRefType newType = MemRefType::get(
SmallVector<int64_t>{type.getNumElements()}, type.getElementType());
MemRefType newType = getFlattenedMemRefType(type);
rewriter.replaceOpWithNewOp<memref::AllocOp>(op, newType);
return success();
}
};

/// Rewrites a multi-dimensional, statically shaped `memref.global` into a
/// global of the flattened (rank-1) type that holds the same element values in
/// the same iteration order.
///
/// The replacement receives a newly generated symbol name; the mapping from the
/// original symbol name to the new one is stored in `globalNameMap` so that
/// GetGlobalOpConversion can retarget the corresponding `memref.get_global`
/// operations. Visibility, the `constant` marker and the alignment are copied
/// from the original op.
///
/// The pattern does not match (returns failure) when the global is already
/// uni-dimensional, has a dynamic shape, or has no constant dense initializer.
struct GlobalOpConversion : public OpConversionPattern<memref::GlobalOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::GlobalOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType type = op.getType();
    if (isUniDimensional(type) || !type.hasStaticShape())
      return failure();
    MemRefType newType = getFlattenedMemRefType(type);

    // dyn_cast_or_null yields a null attribute when there is no constant
    // initial value or when it is not a DenseElementsAttr. Reading values from
    // a null attribute would crash, so report a match failure instead.
    auto cstAttr =
        llvm::dyn_cast_or_null<DenseElementsAttr>(op.getConstantInitValue());
    if (!cstAttr)
      return rewriter.notifyMatchFailure(
          op, "expected a constant global with a dense elements initializer");

    // Collect the initializer's elements in iteration order; they become the
    // contents of the rank-1 initializer.
    SmallVector<Attribute> flattenedVals;
    flattenedVals.reserve(cstAttr.getNumElements());
    for (auto attr : cstAttr.getValues<Attribute>())
      flattenedVals.push_back(attr);

    auto newTypeAttr = TypeAttr::get(newType);
    // NOTE(review): the base name passed here is the name of the op's
    // `constant` attribute, not the global's own symbol name, which is why
    // the generated symbols look like `constant_<N>x<type>_<id>`. The tests
    // depend on that spelling, so it is left as is -- confirm it is intended.
    auto newNameStr = getFlattenedMemRefName(op.getConstantAttrName(), type);
    auto newName = rewriter.getStringAttr(newNameStr);
    // Remember the rename so get_global users can be pointed at the new symbol.
    globalNameMap[op.getSymNameAttr()] = newName;

    RankedTensorType tensorType = RankedTensorType::get(
        {static_cast<int64_t>(flattenedVals.size())}, type.getElementType());
    auto newInitValue = DenseElementsAttr::get(tensorType, flattenedVals);

    rewriter.replaceOpWithNewOp<memref::GlobalOp>(
        op, newName, op.getSymVisibilityAttr(), newTypeAttr, newInitValue,
        op.getConstantAttr(), op.getAlignmentAttr());

    return success();
  }
};

/// Retargets a `memref.get_global` that refers to a multi-dimensional,
/// statically shaped global so that it refers to the flattened replacement
/// created by GlobalOpConversion and yields the flattened (rank-1) type.
///
/// The referenced global is resolved in the closest enclosing symbol table.
/// The pattern does not match (returns failure) when that lookup does not
/// produce a `memref.global`, when the global is already uni-dimensional or
/// dynamically shaped, or when no rename has been recorded for it in
/// `globalNameMap`.
struct GetGlobalOpConversion : public OpConversionPattern<memref::GetGlobalOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::GetGlobalOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Without an enclosing symbol table there is nothing to look the symbol up
    // in; bail out rather than handing a null op to the lookup.
    auto *symbolTableOp = op->getParentWithTrait<mlir::OpTrait::SymbolTable>();
    if (!symbolTableOp)
      return rewriter.notifyMatchFailure(
          op, "no enclosing symbol table to resolve the global in");

    // The lookup returns null for an unknown symbol and the cast returns null
    // for a symbol that is not a memref.global; either way the op below must
    // not be dereferenced.
    auto globalOp = dyn_cast_or_null<memref::GlobalOp>(
        SymbolTable::lookupSymbolIn(symbolTableOp, op.getNameAttr()));
    if (!globalOp)
      return rewriter.notifyMatchFailure(
          op, "referenced symbol does not resolve to a memref.global");

    MemRefType type = globalOp.getType();
    if (isUniDimensional(type) || !type.hasStaticShape())
      return failure();

    MemRefType newType = getFlattenedMemRefType(type);
    // The new symbol name is only known once GlobalOpConversion has recorded
    // it for this global.
    auto originalName = globalOp.getSymNameAttr();
    auto newNameIt = globalNameMap.find(originalName);
    if (newNameIt == globalNameMap.end())
      return rewriter.notifyMatchFailure(
          op, "no flattened replacement recorded for the referenced global");
    auto newName = newNameIt->second;

    rewriter.replaceOpWithNewOp<memref::GetGlobalOp>(op, newType, newName);

    return success();
  }
};

// A generic pattern which will replace an op with a new op of the same type
// but using the adaptor (type converted) operands.
template <typename TOp>
Expand Down Expand Up @@ -256,7 +333,10 @@ static void populateFlattenMemRefsLegality(ConversionTarget &target) {
[](memref::StoreOp op) { return op.getIndices().size() == 1; });
target.addDynamicallyLegalOp<memref::LoadOp>(
[](memref::LoadOp op) { return op.getIndices().size() == 1; });

target.addDynamicallyLegalOp<memref::GlobalOp>(
[](memref::GlobalOp op) { return isUniDimensional(op.getType()); });
target.addDynamicallyLegalOp<memref::GetGlobalOp>(
[](memref::GetGlobalOp op) { return isUniDimensional(op.getType()); });
addGenericLegalityConstraint<mlir::cf::CondBranchOp, mlir::cf::BranchOp,
func::CallOp, func::ReturnOp, memref::DeallocOp,
memref::CopyOp>(target);
Expand Down Expand Up @@ -323,6 +403,7 @@ struct FlattenMemRefPass
RewritePatternSet patterns(ctx);
SetVector<StringRef> rewrittenCallees;
patterns.add<LoadOpConversion, StoreOpConversion, AllocOpConversion,
GlobalOpConversion, GetGlobalOpConversion,
OperandConversionPattern<func::ReturnOp>,
OperandConversionPattern<memref::DeallocOp>,
CondBranchOpConversion,
Expand Down
86 changes: 86 additions & 0 deletions test/Transforms/flatten_memref.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,89 @@ func.func @dealloc_copy(%arg : memref<4x4xi32>) -> memref<4x4xi32> {
memref.dealloc %0 : memref<4x4xi32>
return %0 : memref<4x4xi32>
}

// -----

// Flattening of a constant 5x2 global together with its user: the global is
// expected to become a renamed memref<10xf32> holding the same ten values in
// row-major order, the get_global is expected to reference the new symbol, and
// the two-index load [2, 1] is expected to become a single-index load computed
// as (2 << 1) + 1.
module {
// CHECK-LABEL: memref.global "private" constant @constant_10xf32_0 : memref<10xf32> = dense<[0.433561265, 0.0884729773, -0.39487046, -0.190938368, 0.705071926, -0.648731529, -0.00710275536, -0.278010637, -0.573243499, 5.029220e-01]> {alignment = 64 : i64}
memref.global "private" constant @__constant_5x2xf32 : memref<5x2xf32> = dense<[[0.433561265, 0.0884729773], [-0.39487046, -0.190938368], [0.705071926, -0.648731529], [-0.00710275536, -0.278010637], [-0.573243499, 5.029220e-01]]> {alignment = 64 : i64}

// CHECK-LABEL: func.func @forward() -> f32 {
// CHECK: %[[VAL_0:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_2:.*]] = memref.get_global @constant_10xf32_0 : memref<10xf32>
// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_4:.*]] = arith.shli %[[VAL_0]], %[[VAL_3]] : index
// CHECK: %[[VAL_5:.*]] = arith.addi %[[VAL_4]], %[[VAL_1]] : index
// CHECK: %[[VAL_6:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_5]]] : memref<10xf32>
// CHECK: return %[[VAL_6]] : f32
// CHECK: }
// CHECK: }
func.func @forward() -> f32 {
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%0 = memref.get_global @__constant_5x2xf32 : memref<5x2xf32>
%1 = memref.load %0[%c2, %c1] : memref<5x2xf32>
return %1 :f32
}
}

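// Flattening GlobalOp/GetGlobalOp can produce name conflicts: differently
// shaped globals (e.g. 1x2 and 2x1) flatten to the same type, and a global of
// that flattened type may already exist, so every flattened global must get a
// unique name.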

// Four constant globals, two of which (1xf32 and 2xf32) are already
// one-dimensional and are expected to keep their names and types. The 1x2 and
// 2x1 globals both flatten to memref<2xf32>; the expectations below require
// them to receive distinct generated names that also differ from the existing
// @__constant_2xf32. The function's memref<2x1xf32> argument is expected to be
// flattened to memref<2xf32> as well, and every get_global must reference the
// matching (possibly renamed) global.
module {
// CHECK-LABEL: module {
// CHECK: memref.global "private" constant @__constant_1xf32 : memref<1xf32> = dense<-0.344258487> {alignment = 64 : i64}
// CHECK: memref.global "private" constant @constant_2xf32_1 : memref<2xf32> = dense<[-0.154929623, 0.142687559]> {alignment = 64 : i64}
// CHECK: memref.global "private" constant @__constant_2xf32 : memref<2xf32> = dense<[-0.23427248, 0.918611288]> {alignment = 64 : i64}
// CHECK: memref.global "private" constant @constant_2xf32_2 : memref<2xf32> = dense<[0.764538527, 0.83000791]> {alignment = 64 : i64}
memref.global "private" constant @__constant_1xf32 : memref<1xf32> = dense<-0.344258487> {alignment = 64 : i64}
memref.global "private" constant @__constant_1x2xf32 : memref<1x2xf32> = dense<[[-0.154929623, 0.142687559]]> {alignment = 64 : i64}
memref.global "private" constant @__constant_2xf32 : memref<2xf32> = dense<[-0.23427248, 0.918611288]> {alignment = 64 : i64}
memref.global "private" constant @__constant_2x1xf32 : memref<2x1xf32> = dense<[[0.764538527], [0.83000791]]> {alignment = 64 : i64}

// CHECK: func.func @main(%[[VAL_0:.*]]: memref<2xf32>, %[[VAL_1:.*]]: memref<1xf32>) {
// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_6:.*]] = memref.get_global @constant_2xf32_2 : memref<2xf32>
// CHECK: %[[VAL_7:.*]] = memref.get_global @__constant_2xf32 : memref<2xf32>
// CHECK: %[[VAL_8:.*]] = memref.get_global @constant_2xf32_1 : memref<2xf32>
// CHECK: %[[VAL_9:.*]] = memref.get_global @__constant_1xf32 : memref<1xf32>
// CHECK: %[[VAL_10:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_11:.*]] = arith.shli %[[VAL_3]], %[[VAL_10]] : index
// CHECK: %[[VAL_12:.*]] = arith.addi %[[VAL_11]], %[[VAL_4]] : index
// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<2xf32>
// CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<2xf32>
// CHECK: %[[VAL_15:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_16:.*]] = arith.shli %[[VAL_4]], %[[VAL_15]] : index
// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_16]], %[[VAL_3]] : index
// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_17]]] : memref<2xf32>
// CHECK: %[[VAL_19:.*]] = arith.mulf %[[VAL_13]], %[[VAL_14]] : f32
// CHECK: %[[VAL_20:.*]] = arith.addf %[[VAL_18]], %[[VAL_19]] : f32
// CHECK: memref.store %[[VAL_20]], %[[VAL_9]]{{\[}}%[[VAL_4]]] : memref<1xf32>
// CHECK: memref.copy %[[VAL_9]], %[[VAL_1]] : memref<1xf32> to memref<1xf32>
// CHECK: return
// CHECK: }
// CHECK: }

func.func @main(%arg0: memref<2x1xf32>, %arg1: memref<1xf32>) {
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f32
%0 = memref.get_global @__constant_2x1xf32 : memref<2x1xf32>
%1 = memref.get_global @__constant_2xf32 : memref<2xf32>
%2 = memref.get_global @__constant_1x2xf32 : memref<1x2xf32>
%3 = memref.get_global @__constant_1xf32 : memref<1xf32>
%4 = memref.load %0[%c1, %c0] : memref<2x1xf32>
%5 = memref.load %1[%c0] : memref<2xf32>
%6 = memref.load %2[%c0, %c1] : memref<1x2xf32>
%7 = arith.mulf %4, %5 : f32
%8 = arith.addf %6, %7 : f32
memref.store %8, %3[%c0] : memref<1xf32>
memref.copy %3, %arg1 : memref<1xf32> to memref<1xf32>
return
}
}

0 comments on commit d22a695

Please sign in to comment.