diff --git a/clang/lib/CIR/CodeGen/CIRGenCUDARuntime.cpp b/clang/lib/CIR/CodeGen/CIRGenCUDARuntime.cpp index 400c41cbb0d4..acbbcd2c5c8b 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCUDARuntime.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenCUDARuntime.cpp @@ -169,3 +169,23 @@ void CIRGenCUDARuntime::emitDeviceStub(CIRGenFunction &cgf, cir::FuncOp fn, else emitDeviceStubBodyLegacy(cgf, fn, args); } + +RValue CIRGenCUDARuntime::emitCUDAKernelCallExpr(CIRGenFunction &cgf, + const CUDAKernelCallExpr *expr, + ReturnValueSlot retValue) { + auto builder = cgm.getBuilder(); + mlir::Location loc = + cgf.currSrcLoc ? cgf.currSrcLoc.value() : builder.getUnknownLoc(); + + cgf.emitIfOnBoolExpr( + expr->getConfig(), + [&](mlir::OpBuilder &b, mlir::Location l) { + CIRGenCallee callee = cgf.emitCallee(expr->getCallee()); + cgf.emitCall(expr->getCallee()->getType(), callee, expr, retValue); + b.create(loc); + }, + loc, [](mlir::OpBuilder &b, mlir::Location l) {}, + std::optional()); + + return RValue::get(nullptr); +} diff --git a/clang/lib/CIR/CodeGen/CIRGenCUDARuntime.h b/clang/lib/CIR/CodeGen/CIRGenCUDARuntime.h index a3145a0baeb3..634f4891b85d 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCUDARuntime.h +++ b/clang/lib/CIR/CodeGen/CIRGenCUDARuntime.h @@ -23,6 +23,8 @@ namespace clang::CIRGen { class CIRGenFunction; class CIRGenModule; class FunctionArgList; +class RValue; +class ReturnValueSlot; class CIRGenCUDARuntime { protected: @@ -40,6 +42,10 @@ class CIRGenCUDARuntime { virtual void emitDeviceStub(CIRGenFunction &cgf, cir::FuncOp fn, FunctionArgList &args); + + virtual RValue emitCUDAKernelCallExpr(CIRGenFunction &cgf, + const CUDAKernelCallExpr *expr, + ReturnValueSlot retValue); }; } // namespace clang::CIRGen diff --git a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp index 38a880548202..4d4dd663e0dc 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp @@ -530,7 +530,10 @@ static CIRGenCallee emitDirectCallee(CIRGenModule &CGM, GlobalDecl GD) { auto CalleePtr = emitFunctionDeclPointer(CGM, GD); - assert(!CGM.getLangOpts().CUDA && "NYI"); + // For HIP, the device stub should be converted to handle. + if (CGM.getLangOpts().HIP && !CGM.getLangOpts().CUDAIsDevice && + FD->hasAttr()) + llvm_unreachable("NYI"); return CIRGenCallee::forDirect(CalleePtr, GD); } @@ -1405,7 +1408,9 @@ RValue CIRGenFunction::emitCallExpr(const clang::CallExpr *E, if (const auto *CE = dyn_cast(E)) return emitCXXMemberCallExpr(CE, ReturnValue); - assert(!dyn_cast(E) && "CUDA NYI"); + if (const auto *CE = dyn_cast(E)) + return CGM.getCUDARuntime().emitCUDAKernelCallExpr(*this, CE, ReturnValue); + if (const auto *CE = dyn_cast(E)) if (const CXXMethodDecl *MD = dyn_cast_or_null(CE->getCalleeDecl())) diff --git a/clang/test/CIR/CodeGen/CUDA/simple.cu b/clang/test/CIR/CodeGen/CUDA/simple.cu index 9675de3fe61a..51a1d3bb2f4b 100644 --- a/clang/test/CIR/CodeGen/CUDA/simple.cu +++ b/clang/test/CIR/CodeGen/CUDA/simple.cu @@ -31,3 +31,18 @@ __global__ void global_fn(int a) {} // CIR-HOST: cir.call @__cudaPopCallConfiguration // CIR-HOST: cir.get_global @_Z24__device_stub__global_fni // CIR-HOST: cir.call @cudaLaunchKernel + +int main() { + global_fn<<<1, 1>>>(1); +} +// CIR-DEVICE-NOT: cir.func @main() + +// CIR-HOST: cir.func @main() +// CIR-HOST: cir.call @_ZN4dim3C1Ejjj +// CIR-HOST: cir.call @_ZN4dim3C1Ejjj +// CIR-HOST: [[Push:%[0-9]+]] = cir.call @__cudaPushCallConfiguration +// CIR-HOST: [[ConfigOK:%[0-9]+]] = cir.cast(int_to_bool, [[Push]] +// CIR-HOST: cir.if [[ConfigOK]] { +// CIR-HOST: [[Arg:%[0-9]+]] = cir.const #cir.int<1> +// CIR-HOST: cir.call @_Z24__device_stub__global_fni([[Arg]]) +// CIR-HOST: }