Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions clang/lib/CIR/CodeGen/CIRGenCUDARuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,16 @@ void CIRGenCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
loc, cir::PointerType::get(voidPtrArrayTy), voidPtrArrayTy, "kernel_args",
CharUnits::fromQuantity(16));

mlir::Value kernelArgsDecayed =
builder.createCast(cir::CastKind::array_to_ptrdecay, kernelArgs,
cir::PointerType::get(cgm.VoidPtrTy));

// Store arguments into kernelArgs
for (auto [i, arg] : llvm::enumerate(args)) {
mlir::Value index =
builder.getConstInt(loc, llvm::APInt(/*numBits=*/32, i));
mlir::Value storePos = builder.createPtrStride(loc, kernelArgs, index);
mlir::Value storePos =
builder.createPtrStride(loc, kernelArgsDecayed, index);
builder.CIRBaseBuilderTy::createStore(
loc, cgf.GetAddrOfLocalVar(arg).getPointer(), storePos);
}
Expand Down Expand Up @@ -166,10 +171,6 @@ void CIRGenCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
// mlir::Value func = builder.createBitcast(kernel, cgm.VoidPtrTy);
CallArgList launchArgs;

mlir::Value kernelArgsDecayed =
builder.createCast(cir::CastKind::array_to_ptrdecay, kernelArgs,
cir::PointerType::get(cgm.VoidPtrTy));

launchArgs.add(RValue::get(kernel), launchFD->getParamDecl(0)->getType());
launchArgs.add(
RValue::getAggregate(Address(gridDim, CharUnits::fromQuantity(8))),
Expand All @@ -182,7 +183,8 @@ void CIRGenCUDARuntime::emitDeviceStubBodyNew(CIRGenFunction &cgf,
launchArgs.add(
RValue::get(builder.CIRBaseBuilderTy::createLoad(loc, sharedMem)),
launchFD->getParamDecl(4)->getType());
launchArgs.add(RValue::get(stream), launchFD->getParamDecl(5)->getType());
launchArgs.add(RValue::get(builder.CIRBaseBuilderTy::createLoad(loc, stream)),
launchFD->getParamDecl(5)->getType());

mlir::Type launchTy = cgm.getTypes().convertType(launchFD->getType());
mlir::Operation *launchFn =
Expand Down Expand Up @@ -219,13 +221,16 @@ RValue CIRGenCUDARuntime::emitCUDAKernelCallExpr(CIRGenFunction &cgf,

cgf.emitIfOnBoolExpr(
expr->getConfig(),
[&](mlir::OpBuilder &b, mlir::Location l) {
b.create<cir::YieldOp>(loc);
},
loc,
[&](mlir::OpBuilder &b, mlir::Location l) {
CIRGenCallee callee = cgf.emitCallee(expr->getCallee());
cgf.emitCall(expr->getCallee()->getType(), callee, expr, retValue);
b.create<cir::YieldOp>(loc);
},
loc, [](mlir::OpBuilder &b, mlir::Location l) {},
std::optional<mlir::Location>());
loc);

return RValue::get(nullptr);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ void X86_64ABIInfo::computeInfo(LowerFunctionInfo &FI) const {
if (cir::MissingFeatures::vectorType())
cir_cconv_unreachable("NYI");
} else {
cir_cconv_unreachable("Indirect results are NYI");
it->info = getIndirectResult(it->type, FreeIntRegs);
}
}
}
Expand Down
17 changes: 11 additions & 6 deletions clang/test/CIR/CodeGen/CUDA/simple.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,16 @@ __global__ void global_fn(int a) {}
// Check for device stub emission.

// CIR-HOST: @_Z24__device_stub__global_fni{{.*}}extra([[Kernel]])
// CIR-HOST: cir.alloca {{.*}}"kernel_args"
// CIR-HOST: %[[#CIRKernelArgs:]] = cir.alloca {{.*}}"kernel_args"
// CIR-HOST: %[[#Decayed:]] = cir.cast(array_to_ptrdecay, %[[#CIRKernelArgs]]
// CIR-HOST: cir.call @__cudaPopCallConfiguration
// CIR-HOST: cir.get_global @_Z24__device_stub__global_fni
// CIR-HOST: cir.call @cudaLaunchKernel

// LLVM-HOST: void @_Z24__device_stub__global_fni
// LLVM-HOST: %[[#KernelArgs:]] = alloca [1 x ptr], i64 1, align 16
// LLVM-HOST: %[[#GEP1:]] = getelementptr ptr, ptr %[[#KernelArgs]], i32 0
// LLVM-HOST: %[[#GEP2:]] = getelementptr ptr, ptr %[[#GEP1]], i64 0
// LLVM-HOST: call i32 @__cudaPopCallConfiguration
// LLVM-HOST: call i32 @cudaLaunchKernel(ptr @_Z24__device_stub__global_fni

Expand All @@ -48,6 +52,7 @@ int main() {
// CIR-HOST: [[Push:%[0-9]+]] = cir.call @__cudaPushCallConfiguration
// CIR-HOST: [[ConfigOK:%[0-9]+]] = cir.cast(int_to_bool, [[Push]]
// CIR-HOST: cir.if [[ConfigOK]] {
// CIR-HOST: } else {
// CIR-HOST: [[Arg:%[0-9]+]] = cir.const #cir.int<1>
// CIR-HOST: cir.call @_Z24__device_stub__global_fni([[Arg]])
// CIR-HOST: }
Expand All @@ -58,9 +63,9 @@ int main() {
// LLVM-HOST: call void @_ZN4dim3C1Ejjj
// LLVM-HOST: call void @_ZN4dim3C1Ejjj
// LLVM-HOST: [[LLVMConfigOK:%[0-9]+]] = call i32 @__cudaPushCallConfiguration
// LLVM-HOST: br [[LLVMConfigOK]], label %[[Good:[0-9]+]], label [[Bad:[0-9]+]]
// LLVM-HOST: [[Good]]:
// LLVM-HOST: br [[LLVMConfigOK]], label %[[#Good:]], label [[#Bad:]]
// LLVM-HOST: [[#Good]]:
// LLVM-HOST: br label [[#End:]]
// LLVM-HOST: [[#Bad]]:
// LLVM-HOST: call void @_Z24__device_stub__global_fni
// LLVM-HOST: br label [[Bad]]
// LLVM-HOST: [[Bad]]:
// LLVM-HOST: ret i32
// LLVM-HOST: br label [[#End]]
Loading