Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][cuda] Support data transfer from descriptor to a pointer #115023

Merged
merged 1 commit into from
Nov 5, 2024

Conversation

clementval
Copy link
Contributor

Data transfer from a variable with a descriptor to a pointer. We create a descriptor for the pointer so we can use the flang runtime to perform the transfer. The Assign function handles all corner cases. We add a new entry points CUFDataTransferDescDescNoRealloc to avoid reallocation since the variable on the LHS is not an allocatable.

@llvmbot llvmbot added flang:runtime flang Flang issues not falling into any other category flang:fir-hlfir labels Nov 5, 2024
@llvmbot
Copy link
Member

llvmbot commented Nov 5, 2024

@llvm/pr-subscribers-flang-runtime

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Data transfer from a variable with a descriptor to a pointer. We create a descriptor for the pointer so we can use the flang runtime to perform the transfer. The Assign function handles all corner cases. We add a new entry points CUFDataTransferDescDescNoRealloc to avoid reallocation since the variable on the LHS is not an allocatable.


Full diff: https://github.com/llvm/llvm-project/pull/115023.diff

4 Files Affected:

  • (modified) flang/include/flang/Runtime/CUDA/memory.h (+4)
  • (modified) flang/lib/Optimizer/Transforms/CUFOpConversion.cpp (+14-37)
  • (modified) flang/runtime/CUDA/memory.cpp (+18)
  • (modified) flang/test/Fir/CUDA/cuda-data-transfer.fir (+10-12)
diff --git a/flang/include/flang/Runtime/CUDA/memory.h b/flang/include/flang/Runtime/CUDA/memory.h
index 4ac2528c1aedbc..713bdf536aaf90 100644
--- a/flang/include/flang/Runtime/CUDA/memory.h
+++ b/flang/include/flang/Runtime/CUDA/memory.h
@@ -44,6 +44,10 @@ void RTDECL(CUFDataTransferPtrDesc)(void *dst, Descriptor *src,
 void RTDECL(CUFDataTransferDescDesc)(Descriptor *dst, Descriptor *src,
     unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
 
+/// Data transfer from a descriptor to a descriptor.
+void RTDECL(CUFDataTransferDescDescNoRealloc)(Descriptor *dst, Descriptor *src,
+    unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
+
 /// Data transfer from a descriptor to a global descriptor.
 void RTDECL(CUFDataTransferGlobalDescDesc)(Descriptor *dst, Descriptor *src,
     unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index 89d0af1fcd136f..6187ca03d2c411 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -581,50 +581,27 @@ struct CUFDataTransferOpConversion
       builder.create<fir::CallOp>(loc, func, args);
       rewriter.eraseOp(op);
     } else {
-      // Type used to compute the width.
-      mlir::Type computeType = dstTy;
-      auto seqTy = mlir::dyn_cast<fir::SequenceType>(dstTy);
-      if (mlir::isa<fir::BaseBoxType>(dstTy)) {
-        computeType = srcTy;
-        seqTy = mlir::dyn_cast<fir::SequenceType>(srcTy);
-      }
-      int width = computeWidth(loc, computeType, kindMap);
+      // Transfer from a descriptor.
 
-      mlir::Value nbElement;
-      mlir::Type idxTy = rewriter.getIndexType();
-      if (!op.getShape()) {
-        nbElement = rewriter.create<mlir::arith::ConstantOp>(
-            loc, idxTy,
-            rewriter.getIntegerAttr(idxTy, seqTy.getConstantArraySize()));
-      } else {
-        auto shapeOp =
-            mlir::dyn_cast<fir::ShapeOp>(op.getShape().getDefiningOp());
-        nbElement =
-            createConvertOp(rewriter, loc, idxTy, shapeOp.getExtents()[0]);
-        for (unsigned i = 1; i < shapeOp.getExtents().size(); ++i) {
-          auto operand =
-              createConvertOp(rewriter, loc, idxTy, shapeOp.getExtents()[i]);
-          nbElement =
-              rewriter.create<mlir::arith::MulIOp>(loc, nbElement, operand);
-        }
-      }
+      mlir::Value addr = getDeviceAddress(rewriter, op.getDstMutable(), symtab);
+      mlir::Type boxTy = fir::BoxType::get(dstTy);
+      llvm::SmallVector<mlir::Value> lenParams;
+      mlir::Value box =
+          builder.createBox(loc, boxTy, addr, getShapeFromDecl(op.getDst()),
+                            /*slice=*/nullptr, lenParams,
+                            /*tdesc=*/nullptr);
+      mlir::Value memBox = builder.createTemporary(loc, box.getType());
+      builder.create<fir::StoreOp>(loc, box, memBox);
 
-      mlir::Value widthValue = rewriter.create<mlir::arith::ConstantOp>(
-          loc, idxTy, rewriter.getIntegerAttr(idxTy, width));
-      mlir::Value bytes =
-          rewriter.create<mlir::arith::MulIOp>(loc, nbElement, widthValue);
+      mlir::func::FuncOp func = fir::runtime::getRuntimeFunc<mkRTKey(
+          CUFDataTransferDescDescNoRealloc)>(loc, builder);
 
-      mlir::func::FuncOp func =
-          fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferPtrDesc)>(
-              loc, builder);
       auto fTy = func.getFunctionType();
       mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
       mlir::Value sourceLine =
-          fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));
-      mlir::Value dst = op.getDst();
-      mlir::Value src = op.getSrc();
+          fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
       llvm::SmallVector<mlir::Value> args{
-          fir::runtime::createArguments(builder, loc, fTy, dst, src, bytes,
+          fir::runtime::createArguments(builder, loc, fTy, memBox, op.getSrc(),
                                         modeValue, sourceFile, sourceLine)};
       builder.create<fir::CallOp>(loc, func, args);
       rewriter.eraseOp(op);
diff --git a/flang/runtime/CUDA/memory.cpp b/flang/runtime/CUDA/memory.cpp
index 2d499f93fbaece..7b40b837e7666e 100644
--- a/flang/runtime/CUDA/memory.cpp
+++ b/flang/runtime/CUDA/memory.cpp
@@ -120,6 +120,24 @@ void RTDECL(CUFDataTransferDescDesc)(Descriptor *dstDesc, Descriptor *srcDesc,
       *dstDesc, *srcDesc, terminator, MaybeReallocate, memmoveFct);
 }
 
+void RTDECL(CUFDataTransferDescDescNoRealloc)(Descriptor *dstDesc,
+    Descriptor *srcDesc, unsigned mode, const char *sourceFile,
+    int sourceLine) {
+  MemmoveFct memmoveFct;
+  Terminator terminator{sourceFile, sourceLine};
+  if (mode == kHostToDevice) {
+    memmoveFct = &MemmoveHostToDevice;
+  } else if (mode == kDeviceToHost) {
+    memmoveFct = &MemmoveDeviceToHost;
+  } else if (mode == kDeviceToDevice) {
+    memmoveFct = &MemmoveDeviceToDevice;
+  } else {
+    terminator.Crash("host to host copy not supported");
+  }
+  Fortran::runtime::Assign(
+      *dstDesc, *srcDesc, terminator, NoAssignFlags, memmoveFct);
+}
+
 void RTDECL(CUFDataTransferGlobalDescDesc)(Descriptor *dstDesc,
     Descriptor *srcDesc, unsigned mode, const char *sourceFile,
     int sourceLine) {
diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir
index 6a33190168024f..d9588942b21e81 100644
--- a/flang/test/Fir/CUDA/cuda-data-transfer.fir
+++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir
@@ -73,6 +73,7 @@ func.func @_QPsub4() {
   return
 }
 // CHECK-LABEL: func.func @_QPsub4()
+// CHECK: %[[TEMP_BOX1:.*]] = fir.alloca !fir.box<!fir.array<10xi32>>
 // CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<!fir.array<10xi32>>
 // CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub4Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>)
 // CHECK: %[[AHOST:.*]]:2 = hlfir.declare %{{.*}}(%[[AHOST_SHAPE:.*]]) {uniq_name = "_QFsub4Eahost"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>)
@@ -81,13 +82,11 @@ func.func @_QPsub4() {
 // CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
 // CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<10xi32>>>) -> !fir.ref<!fir.box<none>>
 // CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[AHOST_BOX]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
-// CHECK: %[[NBELEM:.*]] = arith.constant 10 : index
-// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
-// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
-// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#0 : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[AHOST]]#0(%[[AHOST_SHAPE]]) : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<10xi32>>
+// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX1]] : !fir.ref<!fir.box<!fir.array<10xi32>>>
+// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX1]] : (!fir.ref<!fir.box<!fir.array<10xi32>>>) -> !fir.ref<!fir.box<none>>
 // CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
-// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
-// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<!fir.box<none>>, i64, i32, !fir.ref<i8>, i32) -> none
+// CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc(%[[AHOST_BOX]], %[[ADEV_BOX]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 
 func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
   %0 = fir.dummy_scope : !fir.dscope
@@ -115,6 +114,7 @@ func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
 }
 
 // CHECK-LABEL: func.func @_QPsub5
+// CHECK: %[[TEMP_BOX1:.*]] = fir.alloca !fir.box<!fir.array<?x?xi32>>
 // CHECK: %[[TEMP_BOX:.*]] = fir.alloca !fir.box<!fir.array<?x?xi32>>
 // CHECK: %[[ADEV:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFsub5Eadev"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>)
 // CHECK: %[[SHAPE:.*]] = fir.shape %[[I1:.*]], %[[I2:.*]] : (index, index) -> !fir.shape<2>
@@ -124,13 +124,11 @@ func.func @_QPsub5(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}) {
 // CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
 // CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<!fir.array<?x?xi32>>>) -> !fir.ref<!fir.box<none>>
 // CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[AHOST_BOX]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
-// CHECK: %[[NBELEM:.*]] = arith.muli %[[I1]], %[[I2]] : index
-// CHECK: %[[WIDTH:.*]] = arith.constant 4 : index
-// CHECK: %[[BYTES:.*]] = arith.muli %[[NBELEM]], %[[WIDTH]] : index
-// CHECK: %[[AHOST_PTR:.*]] = fir.convert %[[AHOST]]#1 : (!fir.ref<!fir.array<?x?xi32>>) -> !fir.llvm_ptr<i8>
+// CHECK: %[[EMBOX:.*]] = fir.embox %[[AHOST]]#1(%[[SHAPE]]) : (!fir.ref<!fir.array<?x?xi32>>, !fir.shape<2>) -> !fir.box<!fir.array<?x?xi32>>
+// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX1]] : !fir.ref<!fir.box<!fir.array<?x?xi32>>>
+// CHECK: %[[AHOST_BOX:.*]] = fir.convert %[[TEMP_BOX1]] : (!fir.ref<!fir.box<!fir.array<?x?xi32>>>) -> !fir.ref<!fir.box<none>>
 // CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xi32>>>>) -> !fir.ref<!fir.box<none>>
-// CHECK: %[[BYTES_CONV:.*]] = fir.convert %[[BYTES]] : (index) -> i64
-// CHECK: fir.call @_FortranACUFDataTransferPtrDesc(%[[AHOST_PTR]], %[[ADEV_BOX]], %[[BYTES_CONV]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<!fir.box<none>>, i64, i32, !fir.ref<i8>, i32) -> none
+// CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc(%[[AHOST_BOX]], %[[ADEV_BOX]], %c1{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
 
 func.func @_QPsub6() {
   %0 = cuf.alloc i32 {bindc_name = "idev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub6Eidev"} -> !fir.ref<i32>

@clementval clementval merged commit db69d69 into llvm:main Nov 5, 2024
10 of 11 checks passed
@clementval clementval deleted the cuf_data_transfer_ptr_desc branch November 5, 2024 19:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:runtime flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants