-
Notifications
You must be signed in to change notification settings - Fork 12.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[flang][cuda] Support data transfer from descriptor to a pointer #115023
Conversation
@llvm/pr-subscribers-flang-runtime @llvm/pr-subscribers-flang-fir-hlfir Author: Valentin Clement (バレンタイン クレメン) (clementval) ChangesData transfer from a variable with a descriptor to a pointer. We create a descriptor for the pointer so we can use the flang runtime to perform the transfer. The Assign function handles all corner cases. We add a new entry points Full diff: https://github.com/llvm/llvm-project/pull/115023.diff 4 Files Affected:
diff --git a/flang/include/flang/Runtime/CUDA/memory.h b/flang/include/flang/Runtime/CUDA/memory.h
index 4ac2528c1aedbc..713bdf536aaf90 100644
--- a/flang/include/flang/Runtime/CUDA/memory.h
+++ b/flang/include/flang/Runtime/CUDA/memory.h
@@ -44,6 +44,10 @@ void RTDECL(CUFDataTransferPtrDesc)(void *dst, Descriptor *src,
void RTDECL(CUFDataTransferDescDesc)(Descriptor *dst, Descriptor *src,
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
+/// Data transfer from a descriptor to a descriptor.
+void RTDECL(CUFDataTransferDescDescNoRealloc)(Descriptor *dst, Descriptor *src,
+ unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
+
/// Data transfer from a descriptor to a global descriptor.
void RTDECL(CUFDataTransferGlobalDescDesc)(Descriptor *dst, Descriptor *src,
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 89d0af1fcd136f..6187ca03d2c411 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -581,50 +581,27 @@ struct CUFDataTransferOpConversion
builder.create<fir::CallOp>(loc, func, args);
rewriter.eraseOp(op);
} else {
- // Type used to compute the width.
- mlir::Type computeType = dstTy;
- auto seqTy = mlir::dyn_cast<fir::SequenceType>(dstTy);
- if (mlir::isa<fir::BaseBoxType>(dstTy)) {
- computeType = srcTy;
- seqTy = mlir::dyn_cast<fir::SequenceType>(srcTy);
- }
- int width = computeWidth(loc, computeType, kindMap);
+ // Transfer from a descriptor.
- mlir::Value nbElement;
- mlir::Type idxTy = rewriter.getIndexType();
- if (!op.getShape()) {
- nbElement = rewriter.create<mlir::arith::ConstantOp>(
- loc, idxTy,
- rewriter.getIntegerAttr(idxTy, seqTy.getConstantArraySize()));
- } else {
- auto shapeOp =
- mlir::dyn_cast<fir::ShapeOp>(op.getShape().getDefiningOp());
- nbElement =
- createConvertOp(rewriter, loc, idxTy, shapeOp.getExtents()[0]);
- for (unsigned i = 1; i < shapeOp.getExtents().size(); ++i) {
- auto operand =
- createConvertOp(rewriter, loc, idxTy, shapeOp.getExtents()[i]);
- nbElement =
- rewriter.create<mlir::arith::MulIOp>(loc, nbElement, operand);
- }
- }
+ mlir::Value addr = getDeviceAddress(rewriter, op.getDstMutable(), symtab);
+ mlir::Type boxTy = fir::BoxType::get(dstTy);
+ llvm::SmallVector<mlir::Value> lenParams;
+ mlir::Value box =
+ builder.createBox(loc, boxTy, addr, getShapeFromDecl(op.getDst()),
+ /*slice=*/nullptr, lenParams,
+ /*tdesc=*/nullptr);
+ mlir::Value memBox = builder.createTemporary(loc, box.getType());
+ builder.create<fir::StoreOp>(loc, box, memBox);
- mlir::Value widthValue = rewriter.create<mlir::arith::ConstantOp>(
- loc, idxTy, rewriter.getIntegerAttr(idxTy, width));
- mlir::Value bytes =
- rewriter.create<mlir::arith::MulIOp>(loc, nbElement, widthValue);
+ mlir::func::FuncOp func = fir::runtime::getRuntimeFunc<mkRTKey(
+ CUFDataTransferDescDescNoRealloc)>(loc, builder);
- mlir::func::FuncOp func =
- fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrDesc)>(
- loc, builder);
auto fTy = func.getFunctionType();
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
mlir::Value sourceLine =
- fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
- mlir::Value dst = op.getDst();
- mlir::Value src = op.getSrc();
+ fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
llvm::SmallVector<mlir::Value> args{
- fir::runtime::createArguments(builder, loc, fTy, dst, src, bytes,
+ fir::runtime::createArguments(builder, loc, fTy, memBox, op.getSrc(),
modeValue, sourceFile, sourceLine)};
builder.create<fir::CallOp>(loc, func, args);
rewriter.eraseOp(op);
diff --git a/flang/runtime/CUDA/memory.cpp b/flang/runtime/CUDA/memory.cpp
index 2d499f93fbaece..7b40b837e7666e 100644
--- a/flang/runtime/CUDA/memory.cpp
+++ b/flang/runtime/CUDA/memory.cpp
@@ -120,6 +120,24 @@ void RTDECL(CUFDataTransferDescDesc)(Descriptor *dstDesc, Descriptor *srcDesc,
*dstDesc, *srcDesc, terminator, MaybeReallocate, memmoveFct);
}
+void RTDECL(CUFDataTransferDescDescNoRealloc)(Descriptor *dstDesc,
+ Descriptor *srcDesc, unsigned mode, const char *sourceFile,
+ int sourceLine) {
+ MemmoveFct memmoveFct;
+ Terminator terminator{sourceFile, sourceLine};
+ if (mode == kHostToDevice) {
+ memmoveFct = &MemmoveHostToDevice;
+ } else if (mode == kDeviceToHost) {
+ memmoveFct = &MemmoveDeviceToHost;
+ } else if (mode == kDeviceToDevice) {
+ memmoveFct = &MemmoveDeviceToDevice;
+ } else {
+ terminator.Crash("host to host copy not supported");
+ }
+ Fortran::runtime::Assign(
+ *dstDesc, *srcDesc, terminator, NoAssignFlags, memmoveFct);
+}
+
void RTDECL(CUFDataTransferGlobalDescDesc)(Descriptor *dstDesc,
Descriptor *srcDesc, unsigned mode, const char *sourceFile,
int sourceLine) {
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index 6a33190168024f..d9588942b21e81 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -73,6 +73,7 @@ func.func @_QPsub4() {
return
}
// CHECK-LABEL: func.func @_QPsub4()
+// CHECK: %[[TEMP_BOX1:.*]] = fir.alloca !fir.box<!fir.array<10xi32>>
// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<!fir.array<10xi32>>
// CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub4Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
// CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%[[AHOST_SHAPE:.*]]) {uniq_name = "_QFsub4Eahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
@@ -81,13 +82,11 @@ func.func @_QPsub4() {
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<10xi32>>>) -> !fir.ref<!fir.box<none>>
// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[AHOST_BOX]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
-// CHECK: %[[NBELEM:.*]] = arith.constant 10 : index
-// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
-// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
-// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[AHOST]]#0(%[[AHOST_SHAPE]]) : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<10xi32>>
+// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX1]] : !fir.ref<!fir.box<!fir.array<10xi32>>>
+// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX1]] : (!fir.ref<!fir.box<!fir.array<10xi32>>>) -> !fir.ref<!fir.box<none>>
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
-// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
-// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<!fir.box<none>>, i64, i32, !fir.ref<i8>, i32) -> none
+// CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc(%[[AHOST_BOX]], %[[ADEV_BOX]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
%0 = fir.dummy_scope : !fir.dscope
@@ -115,6 +114,7 @@ func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
}
// CHECK-LABEL: func.func @_QPsub5
+// CHECK: %[[TEMP_BOX1:.*]] = fir.alloca !fir.box<!fir.array<?x?xi32>>
// CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<!fir.array<?x?xi32>>
// CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub5Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>)
// CHECK: %[[SHAPE:.*]] = fir.shape %[[I1:.*]], %[[I2:.*]] : (index, index) -> !fir.shape<2>
@@ -124,13 +124,11 @@ func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<?x?xi32>>>) -> !fir.ref<!fir.box<none>>
// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[AHOST_BOX]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
-// CHECK: %[[NBELEM:.*]] = arith.muli %[[I1]], %[[I2]] : index
-// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
-// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
-// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#1 : (!fir.ref<!fir.array<?x?xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[AHOST]]#1(%[[SHAPE]]) : (!fir.ref<!fir.array<?x?xi32>>, !fir.shape<2>) -> !fir.box<!fir.array<?x?xi32>>
+// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX1]] : !fir.ref<!fir.box<!fir.array<?x?xi32>>>
+// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX1]] : (!fir.ref<!fir.box<!fir.array<?x?xi32>>>) -> !fir.ref<!fir.box<none>>
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
-// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
-// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<!fir.box<none>>, i64, i32, !fir.ref<i8>, i32) -> none
+// CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc(%[[AHOST_BOX]], %[[ADEV_BOX]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
func.func @_QPsub6() {
%0 = cuf.alloc i32 {bindc_name = "idev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub6Eidev"} -> !fir.ref<i32>
|
Data transfer from a variable with a descriptor to a pointer. We create a descriptor for the pointer so we can use the flang runtime to perform the transfer. The Assign function handles all corner cases. We add a new entry points
CUFDataTransferDescDescNoRealloc
to avoid reallocation since the variable on the LHS is not an allocatable.