Skip to content

Commit

Permalink
[flang][cuda] Call CUFGetDeviceAddress to get global device address f…
Browse files Browse the repository at this point in the history
…rom host address (#112989)
  • Loading branch information
Renaud-K authored Oct 19, 2024
1 parent f7b6dc8 commit 864902e
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 15 deletions.
2 changes: 2 additions & 0 deletions flang/include/flang/Optimizer/Transforms/CufOpConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ class LLVMTypeConverter;

namespace mlir {
class DataLayout;
class SymbolTable;
}

namespace cuf {

void populateCUFToFIRConversionPatterns(const fir::LLVMTypeConverter &converter,
mlir::DataLayout &dl,
const mlir::SymbolTable &symtab,
mlir::RewritePatternSet &patterns);

} // namespace cuf
Expand Down
96 changes: 81 additions & 15 deletions flang/lib/Optimizer/Transforms/CufOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,69 @@ static bool hasDoubleDescriptors(OpTy op) {
return false;
}

static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
mlir::Location loc, mlir::Type toTy,
mlir::Value val) {
if (val.getType() != toTy)
return rewriter.create<fir::ConvertOp>(loc, toTy, val);
return val;
}

mlir::Value getDeviceAddress(mlir::PatternRewriter &rewriter,
mlir::OpOperand &operand,
const mlir::SymbolTable &symtab) {
mlir::Value v = operand.get();
auto declareOp = v.getDefiningOp<fir::DeclareOp>();
if (!declareOp)
return v;

auto addrOfOp = declareOp.getMemref().getDefiningOp<fir::AddrOfOp>();
if (!addrOfOp)
return v;

auto globalOp = symtab.lookup<fir::GlobalOp>(
addrOfOp.getSymbol().getRootReference().getValue());

if (!globalOp)
return v;

bool isDevGlobal{false};
auto attr = globalOp.getDataAttrAttr();
if (attr) {
switch (attr.getValue()) {
case cuf::DataAttribute::Device:
case cuf::DataAttribute::Managed:
case cuf::DataAttribute::Pinned:
isDevGlobal = true;
break;
default:
break;
}
}
if (!isDevGlobal)
return v;
mlir::OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(operand.getOwner());
auto loc = declareOp.getLoc();
auto mod = declareOp->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, mod);

mlir::func::FuncOp callee =
fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc, builder);
auto fTy = callee.getFunctionType();
auto toTy = fTy.getInput(0);
mlir::Value inputArg =
createConvertOp(rewriter, loc, toTy, declareOp.getResult());
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
builder, loc, fTy, inputArg, sourceFile, sourceLine)};
auto call = rewriter.create<fir::CallOp>(loc, callee, args);

return call->getResult(0);
}

template <typename OpTy>
static mlir::LogicalResult convertOpToCall(OpTy op,
mlir::PatternRewriter &rewriter,
Expand Down Expand Up @@ -363,18 +426,14 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
}
};

static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
mlir::Location loc, mlir::Type toTy,
mlir::Value val) {
if (val.getType() != toTy)
return rewriter.create<fir::ConvertOp>(loc, toTy, val);
return val;
}

struct CufDataTransferOpConversion
: public mlir::OpRewritePattern<cuf::DataTransferOp> {
using OpRewritePattern::OpRewritePattern;

CufDataTransferOpConversion(mlir::MLIRContext *context,
const mlir::SymbolTable &symtab)
: OpRewritePattern(context), symtab{symtab} {}

mlir::LogicalResult
matchAndRewrite(cuf::DataTransferOp op,
mlir::PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -445,9 +504,11 @@ struct CufDataTransferOpConversion
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));

llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
builder, loc, fTy, op.getDst(), op.getSrc(), bytes, modeValue,
sourceFile, sourceLine)};
mlir::Value dst = getDeviceAddress(rewriter, op.getDstMutable(), symtab);
mlir::Value src = getDeviceAddress(rewriter, op.getSrcMutable(), symtab);
llvm::SmallVector<mlir::Value> args{
fir::runtime::createArguments(builder, loc, fTy, dst, src, bytes,
modeValue, sourceFile, sourceLine)};
builder.create<fir::CallOp>(loc, func, args);
rewriter.eraseOp(op);
return mlir::success();
Expand Down Expand Up @@ -552,6 +613,9 @@ struct CufDataTransferOpConversion
}
return mlir::success();
}

private:
const mlir::SymbolTable &symtab;
};

class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
Expand All @@ -565,13 +629,15 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(op);
if (!module)
return signalPassFailure();
mlir::SymbolTable symtab(module);

std::optional<mlir::DataLayout> dl =
fir::support::getOrSetDataLayout(module, /*allowDefaultLayout=*/false);
fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false,
/*forceUnifiedTBAATree=*/false, *dl);
target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect>();
cuf::populateCUFToFIRConversionPatterns(typeConverter, *dl, patterns);
cuf::populateCUFToFIRConversionPatterns(typeConverter, *dl, symtab,
patterns);
if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
mlir::emitError(mlir::UnknownLoc::get(ctx),
Expand All @@ -584,9 +650,9 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {

void cuf::populateCUFToFIRConversionPatterns(
const fir::LLVMTypeConverter &converter, mlir::DataLayout &dl,
mlir::RewritePatternSet &patterns) {
const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) {
patterns.insert<CufAllocOpConversion>(patterns.getContext(), &dl, &converter);
patterns.insert<CufAllocateOpConversion, CufDeallocateOpConversion,
CufFreeOpConversion, CufDataTransferOpConversion>(
patterns.getContext());
CufFreeOpConversion>(patterns.getContext());
patterns.insert<CufDataTransferOpConversion>(patterns.getContext(), symtab);
}
43 changes: 43 additions & 0 deletions flang/test/Fir/CUDA/cuda-data-transfer.fir
Original file line number Diff line number Diff line change
Expand Up @@ -189,4 +189,47 @@ func.func @_QPsub7() {
// CHECK: %[[SRC:.*]] = fir.convert %[[IHOST]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %[[BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none

fir.global @_QMmtestsEn(dense<[3, 4, 5, 6, 7]> : tensor<5xi32>) {data_attr = #cuf.cuda<device>} : !fir.array<5xi32>
func.func @_QPsub8() attributes {fir.bindc_name = "t"} {
%c5 = arith.constant 5 : index
%0 = fir.alloca !fir.array<5xi32> {bindc_name = "m", uniq_name = "_QFEm"}
%1 = fir.shape %c5 : (index) -> !fir.shape<1>
%2 = fir.declare %0(%1) {uniq_name = "_QFEm"} : (!fir.ref<!fir.array<5xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<5xi32>>
%3 = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
%4 = fir.declare %3(%1) {data_attr = #cuf.cuda<device>, uniq_name = "_QMmtestsEn"} : (!fir.ref<!fir.array<5xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<5xi32>>
cuf.data_transfer %4 to %2 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.ref<!fir.array<5xi32>>, !fir.ref<!fir.array<5xi32>>
return
}

// CHECK-LABEL: func.func @_QPsub8()
// CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32>
// CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]]
// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
// CHECK: %[[DECL:.*]] = fir.declare %[[GBL]]
// CHECK: %[[HOST:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[SRC:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
// CHECK: %[[DST:.*]] = fir.convert %[[LOCAL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none


func.func @_QPsub9() {
%c5 = arith.constant 5 : index
%0 = fir.alloca !fir.array<5xi32> {bindc_name = "m", uniq_name = "_QFtest9Em"}
%1 = fir.shape %c5 : (index) -> !fir.shape<1>
%2 = fir.declare %0(%1) {uniq_name = "_QFtest9Em"} : (!fir.ref<!fir.array<5xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<5xi32>>
%3 = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
%4 = fir.declare %3(%1) {data_attr = #cuf.cuda<device>, uniq_name = "_QMmtestsEn"} : (!fir.ref<!fir.array<5xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<5xi32>>
cuf.data_transfer %2 to %4 {transfer_kind = #cuf.cuda_transfer<host_device>} : !fir.ref<!fir.array<5xi32>>, !fir.ref<!fir.array<5xi32>>
return
}

// CHECK-LABEL: func.func @_QPsub9()
// CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32>
// CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]]
// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
// CHECK: %[[DECL:.*]] = fir.declare %[[GBL]]
// CHECK: %[[HOST:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[DST:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
// CHECK: %[[SRC:.*]] = fir.convert %[[LOCAL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none
} // end of module

0 comments on commit 864902e

Please sign in to comment.