Skip to content

Commit

Permalink
[clang][CodeGen] sret args should always point to the alloca AS, …
Browse files Browse the repository at this point in the history
…so use that (llvm#114062)

`sret` arguments are always going to reside in the stack/`alloca`
address space, which makes the current formulation where their AS is
derived from the pointee somewhat quaint. This patch ensures that `sret`
ends up pointing to the `alloca` AS in IR function signatures, and also
guards agains trying to pass a casted `alloca`d pointer to a `sret` arg,
which can happen for most languages, when compiled for targets that have
a non-zero `alloca` AS (e.g. AMDGCN) / map `LangAS::default` to a
non-zero value (SPIR-V). A target could still choose to do something
different here, by e.g. overriding `classifyReturnType` behaviour.

In a broader sense, this patch extends non-aliased indirect args to also
carry an AS, which leads to changing the `getIndirect()` interface. At
the moment we're only using this for (indirect) returns, but it allows
for future handling of indirect args themselves. We default to using the
AllocaAS as that matches what Clang is currently doing, however if, in
the future, a target would opt for e.g. placing indirect returns in some
other storage, with another AS, this will require revisiting.

---------

Co-authored-by: Matt Arsenault <arsenm2@gmail.com>
Co-authored-by: Matt Arsenault <Matthew.Arsenault@amd.com>
  • Loading branch information
3 people authored Feb 14, 2025
1 parent 2bcf62b commit 39ec9de
Show file tree
Hide file tree
Showing 35 changed files with 371 additions and 164 deletions.
11 changes: 6 additions & 5 deletions clang/include/clang/CodeGen/CGFunctionInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,15 +206,16 @@ class ABIArgInfo {
static ABIArgInfo getIgnore() {
return ABIArgInfo(Ignore);
}
static ABIArgInfo getIndirect(CharUnits Alignment, bool ByVal = true,
bool Realign = false,
static ABIArgInfo getIndirect(CharUnits Alignment, unsigned AddrSpace,
bool ByVal = true, bool Realign = false,
llvm::Type *Padding = nullptr) {
auto AI = ABIArgInfo(Indirect);
AI.setIndirectAlign(Alignment);
AI.setIndirectByVal(ByVal);
AI.setIndirectRealign(Realign);
AI.setSRetAfterThis(false);
AI.setPaddingType(Padding);
AI.setIndirectAddrSpace(AddrSpace);
return AI;
}

Expand All @@ -232,7 +233,7 @@ class ABIArgInfo {

static ABIArgInfo getIndirectInReg(CharUnits Alignment, bool ByVal = true,
bool Realign = false) {
auto AI = getIndirect(Alignment, ByVal, Realign);
auto AI = getIndirect(Alignment, 0, ByVal, Realign);
AI.setInReg(true);
return AI;
}
Expand Down Expand Up @@ -422,12 +423,12 @@ class ABIArgInfo {
}

unsigned getIndirectAddrSpace() const {
assert(isIndirectAliased() && "Invalid kind!");
assert((isIndirect() || isIndirectAliased()) && "Invalid kind!");
return IndirectAttr.AddrSpace;
}

void setIndirectAddrSpace(unsigned AddrSpace) {
assert(isIndirectAliased() && "Invalid kind!");
assert((isIndirect() || isIndirectAliased()) && "Invalid kind!");
IndirectAttr.AddrSpace = AddrSpace;
}

Expand Down
8 changes: 4 additions & 4 deletions clang/lib/CodeGen/ABIInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,11 @@ bool ABIInfo::isPromotableIntegerTypeForABI(QualType Ty) const {
return false;
}

ABIArgInfo ABIInfo::getNaturalAlignIndirect(QualType Ty, bool ByVal,
bool Realign,
ABIArgInfo ABIInfo::getNaturalAlignIndirect(QualType Ty, unsigned AddrSpace,
bool ByVal, bool Realign,
llvm::Type *Padding) const {
return ABIArgInfo::getIndirect(getContext().getTypeAlignInChars(Ty), ByVal,
Realign, Padding);
return ABIArgInfo::getIndirect(getContext().getTypeAlignInChars(Ty),
AddrSpace, ByVal, Realign, Padding);
}

ABIArgInfo ABIInfo::getNaturalAlignIndirectInReg(QualType Ty,
Expand Down
3 changes: 2 additions & 1 deletion clang/lib/CodeGen/ABIInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ class ABIInfo {
/// A convenience method to return an indirect ABIArgInfo with an
/// expected alignment equal to the ABI alignment of the given type.
CodeGen::ABIArgInfo
getNaturalAlignIndirect(QualType Ty, bool ByVal = true, bool Realign = false,
getNaturalAlignIndirect(QualType Ty, unsigned AddrSpace, bool ByVal = true,
bool Realign = false,
llvm::Type *Padding = nullptr) const;

CodeGen::ABIArgInfo getNaturalAlignIndirectInReg(QualType Ty,
Expand Down
15 changes: 9 additions & 6 deletions clang/lib/CodeGen/ABIInfoImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ ABIArgInfo DefaultABIInfo::classifyArgumentType(QualType Ty) const {
// Records with non-trivial destructors/copy-constructors should not be
// passed by value.
if (CGCXXABI::RecordArgABI RAA = getRecordArgABI(Ty, getCXXABI()))
return getNaturalAlignIndirect(Ty, RAA == CGCXXABI::RAA_DirectInMemory);
return getNaturalAlignIndirect(Ty, getDataLayout().getAllocaAddrSpace(),
RAA == CGCXXABI::RAA_DirectInMemory);

return getNaturalAlignIndirect(Ty);
return getNaturalAlignIndirect(Ty, getDataLayout().getAllocaAddrSpace());
}

// Treat an enum type as its underlying type.
Expand All @@ -36,7 +37,7 @@ ABIArgInfo DefaultABIInfo::classifyArgumentType(QualType Ty) const {
Context.getTypeSize(Context.getTargetInfo().hasInt128Type()
? Context.Int128Ty
: Context.LongLongTy))
return getNaturalAlignIndirect(Ty);
return getNaturalAlignIndirect(Ty, getDataLayout().getAllocaAddrSpace());

return (isPromotableIntegerTypeForABI(Ty)
? ABIArgInfo::getExtend(Ty, CGT.ConvertType(Ty))
Expand All @@ -48,7 +49,7 @@ ABIArgInfo DefaultABIInfo::classifyReturnType(QualType RetTy) const {
return ABIArgInfo::getIgnore();

if (isAggregateTypeForABI(RetTy))
return getNaturalAlignIndirect(RetTy);
return getNaturalAlignIndirect(RetTy, getDataLayout().getAllocaAddrSpace());

// Treat an enum type as its underlying type.
if (const EnumType *EnumTy = RetTy->getAs<EnumType>())
Expand All @@ -59,7 +60,8 @@ ABIArgInfo DefaultABIInfo::classifyReturnType(QualType RetTy) const {
getContext().getTypeSize(getContext().getTargetInfo().hasInt128Type()
? getContext().Int128Ty
: getContext().LongLongTy))
return getNaturalAlignIndirect(RetTy);
return getNaturalAlignIndirect(RetTy,
getDataLayout().getAllocaAddrSpace());

return (isPromotableIntegerTypeForABI(RetTy) ? ABIArgInfo::getExtend(RetTy)
: ABIArgInfo::getDirect());
Expand Down Expand Up @@ -126,7 +128,8 @@ bool CodeGen::classifyReturnType(const CGCXXABI &CXXABI, CGFunctionInfo &FI,
if (const auto *RT = Ty->getAs<RecordType>())
if (!isa<CXXRecordDecl>(RT->getDecl()) &&
!RT->getDecl()->canPassInRegisters()) {
FI.getReturnInfo() = Info.getNaturalAlignIndirect(Ty);
FI.getReturnInfo() = Info.getNaturalAlignIndirect(
Ty, Info.getDataLayout().getAllocaAddrSpace());
return true;
}

Expand Down
32 changes: 20 additions & 12 deletions clang/lib/CodeGen/CGCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1671,10 +1671,8 @@ CodeGenTypes::GetFunctionType(const CGFunctionInfo &FI) {

// Add type for sret argument.
if (IRFunctionArgs.hasSRetArg()) {
QualType Ret = FI.getReturnType();
unsigned AddressSpace = CGM.getTypes().getTargetAddressSpace(Ret);
ArgTypes[IRFunctionArgs.getSRetArgNo()] =
llvm::PointerType::get(getLLVMContext(), AddressSpace);
ArgTypes[IRFunctionArgs.getSRetArgNo()] = llvm::PointerType::get(
getLLVMContext(), FI.getReturnInfo().getIndirectAddrSpace());
}

// Add type for inalloca argument.
Expand Down Expand Up @@ -5144,7 +5142,6 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
// If the call returns a temporary with struct return, create a temporary
// alloca to hold the result, unless one is given to us.
Address SRetPtr = Address::invalid();
RawAddress SRetAlloca = RawAddress::invalid();
llvm::Value *UnusedReturnSizePtr = nullptr;
if (RetAI.isIndirect() || RetAI.isInAlloca() || RetAI.isCoerceAndExpand()) {
// For virtual function pointer thunks and musttail calls, we must always
Expand All @@ -5158,11 +5155,11 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
} else if (!ReturnValue.isNull()) {
SRetPtr = ReturnValue.getAddress();
} else {
SRetPtr = CreateMemTemp(RetTy, "tmp", &SRetAlloca);
SRetPtr = CreateMemTempWithoutCast(RetTy, "tmp");
if (HaveInsertPoint() && ReturnValue.isUnused()) {
llvm::TypeSize size =
CGM.getDataLayout().getTypeAllocSize(ConvertTypeForMem(RetTy));
UnusedReturnSizePtr = EmitLifetimeStart(size, SRetAlloca.getPointer());
UnusedReturnSizePtr = EmitLifetimeStart(size, SRetPtr.getBasePointer());
}
}
if (IRFunctionArgs.hasSRetArg()) {
Expand Down Expand Up @@ -5397,11 +5394,22 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
V->getType()->isIntegerTy())
V = Builder.CreateZExt(V, ArgInfo.getCoerceToType());

// If the argument doesn't match, perform a bitcast to coerce it. This
// can happen due to trivial type mismatches.
// The only plausible mismatch here would be for pointer address spaces,
// which can happen e.g. when passing a sret arg that is in the AllocaAS
// to a function that takes a pointer to and argument in the DefaultAS.
// We assume that the target has a reasonable mapping for the DefaultAS
// (it can be casted to from incoming specific ASes), and insert an AS
// cast to address the mismatch.
if (FirstIRArg < IRFuncTy->getNumParams() &&
V->getType() != IRFuncTy->getParamType(FirstIRArg))
V = Builder.CreateBitCast(V, IRFuncTy->getParamType(FirstIRArg));
V->getType() != IRFuncTy->getParamType(FirstIRArg)) {
assert(V->getType()->isPointerTy() && "Only pointers can mismatch!");
auto FormalAS = CallInfo.arguments()[ArgNo]
.type.getQualifiers()
.getAddressSpace();
auto ActualAS = I->Ty.getAddressSpace();
V = getTargetHooks().performAddrSpaceCast(
*this, V, ActualAS, FormalAS, IRFuncTy->getParamType(FirstIRArg));
}

if (ArgHasMaybeUndefAttr)
V = Builder.CreateFreeze(V);
Expand Down Expand Up @@ -5737,7 +5745,7 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
// pop this cleanup later on. Being eager about this is OK, since this
// temporary is 'invisible' outside of the callee.
if (UnusedReturnSizePtr)
pushFullExprCleanup<CallLifetimeEnd>(NormalEHLifetimeMarker, SRetAlloca,
pushFullExprCleanup<CallLifetimeEnd>(NormalEHLifetimeMarker, SRetPtr,
UnusedReturnSizePtr);

llvm::BasicBlock *InvokeDest = CannotThrow ? nullptr : getInvokeDest();
Expand Down
19 changes: 13 additions & 6 deletions clang/lib/CodeGen/CGExprAgg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,18 +296,25 @@ void AggExprEmitter::withReturnValueSlot(
(RequiresDestruction && Dest.isIgnored());

Address RetAddr = Address::invalid();
RawAddress RetAllocaAddr = RawAddress::invalid();

EHScopeStack::stable_iterator LifetimeEndBlock;
llvm::Value *LifetimeSizePtr = nullptr;
llvm::IntrinsicInst *LifetimeStartInst = nullptr;
if (!UseTemp) {
RetAddr = Dest.getAddress();
// It is possible for the existing slot we are using directly to have been
// allocated in the correct AS for an indirect return, and then cast to
// the default AS (this is the behaviour of CreateMemTemp), however we know
// that the return address is expected to point to the uncasted AS, hence we
// strip possible pointer casts here.
if (Dest.getAddress().isValid())
RetAddr = Dest.getAddress().withPointer(
Dest.getAddress().getBasePointer()->stripPointerCasts(),
Dest.getAddress().isKnownNonNull());
} else {
RetAddr = CGF.CreateMemTemp(RetTy, "tmp", &RetAllocaAddr);
RetAddr = CGF.CreateMemTempWithoutCast(RetTy, "tmp");
llvm::TypeSize Size =
CGF.CGM.getDataLayout().getTypeAllocSize(CGF.ConvertTypeForMem(RetTy));
LifetimeSizePtr = CGF.EmitLifetimeStart(Size, RetAllocaAddr.getPointer());
LifetimeSizePtr = CGF.EmitLifetimeStart(Size, RetAddr.getBasePointer());
if (LifetimeSizePtr) {
LifetimeStartInst =
cast<llvm::IntrinsicInst>(std::prev(Builder.GetInsertPoint()));
Expand All @@ -316,7 +323,7 @@ void AggExprEmitter::withReturnValueSlot(
"Last insertion wasn't a lifetime.start?");

CGF.pushFullExprCleanup<CodeGenFunction::CallLifetimeEnd>(
NormalEHLifetimeMarker, RetAllocaAddr, LifetimeSizePtr);
NormalEHLifetimeMarker, RetAddr, LifetimeSizePtr);
LifetimeEndBlock = CGF.EHStack.stable_begin();
}
}
Expand All @@ -337,7 +344,7 @@ void AggExprEmitter::withReturnValueSlot(
// Since we're not guaranteed to be in an ExprWithCleanups, clean up
// eagerly.
CGF.DeactivateCleanupBlock(LifetimeEndBlock, LifetimeStartInst);
CGF.EmitLifetimeEnd(LifetimeSizePtr, RetAllocaAddr.getPointer());
CGF.EmitLifetimeEnd(LifetimeSizePtr, RetAddr.getBasePointer());
}
}

Expand Down
4 changes: 3 additions & 1 deletion clang/lib/CodeGen/ItaniumCXXABI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1350,7 +1350,9 @@ bool ItaniumCXXABI::classifyReturnType(CGFunctionInfo &FI) const {
// If C++ prohibits us from making a copy, return by address.
if (!RD->canPassInRegisters()) {
auto Align = CGM.getContext().getTypeAlignInChars(FI.getReturnType());
FI.getReturnInfo() = ABIArgInfo::getIndirect(Align, /*ByVal=*/false);
FI.getReturnInfo() = ABIArgInfo::getIndirect(
Align, /*AddrSpace=*/CGM.getDataLayout().getAllocaAddrSpace(),
/*ByVal=*/false);
return true;
}
return false;
Expand Down
4 changes: 3 additions & 1 deletion clang/lib/CodeGen/MicrosoftCXXABI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1172,7 +1172,9 @@ bool MicrosoftCXXABI::classifyReturnType(CGFunctionInfo &FI) const {

if (isIndirectReturn) {
CharUnits Align = CGM.getContext().getTypeAlignInChars(FI.getReturnType());
FI.getReturnInfo() = ABIArgInfo::getIndirect(Align, /*ByVal=*/false);
FI.getReturnInfo() = ABIArgInfo::getIndirect(
Align, /*AddrSpace=*/CGM.getDataLayout().getAllocaAddrSpace(),
/*ByVal=*/false);

// MSVC always passes `this` before the `sret` parameter.
FI.getReturnInfo().setSRetAfterThis(FI.isInstanceMethod());
Expand Down
16 changes: 11 additions & 5 deletions clang/lib/CodeGen/SwiftCallingConv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -796,11 +796,14 @@ bool swiftcall::mustPassRecordIndirectly(CodeGenModule &CGM,

static ABIArgInfo classifyExpandedType(SwiftAggLowering &lowering,
bool forReturn,
CharUnits alignmentForIndirect) {
CharUnits alignmentForIndirect,
unsigned IndirectAS) {
if (lowering.empty()) {
return ABIArgInfo::getIgnore();
} else if (lowering.shouldPassIndirectly(forReturn)) {
return ABIArgInfo::getIndirect(alignmentForIndirect, /*byval*/ false);
return ABIArgInfo::getIndirect(alignmentForIndirect,
/*AddrSpace=*/IndirectAS,
/*byval=*/false);
} else {
auto types = lowering.getCoerceAndExpandTypes();
return ABIArgInfo::getCoerceAndExpand(types.first, types.second);
Expand All @@ -809,18 +812,21 @@ static ABIArgInfo classifyExpandedType(SwiftAggLowering &lowering,

static ABIArgInfo classifyType(CodeGenModule &CGM, CanQualType type,
bool forReturn) {
unsigned IndirectAS = CGM.getDataLayout().getAllocaAddrSpace();
if (auto recordType = dyn_cast<RecordType>(type)) {
auto record = recordType->getDecl();
auto &layout = CGM.getContext().getASTRecordLayout(record);

if (mustPassRecordIndirectly(CGM, record))
return ABIArgInfo::getIndirect(layout.getAlignment(), /*byval*/ false);
return ABIArgInfo::getIndirect(layout.getAlignment(),
/*AddrSpace=*/IndirectAS, /*byval=*/false);

SwiftAggLowering lowering(CGM);
lowering.addTypedData(recordType->getDecl(), CharUnits::Zero(), layout);
lowering.finish();

return classifyExpandedType(lowering, forReturn, layout.getAlignment());
return classifyExpandedType(lowering, forReturn, layout.getAlignment(),
IndirectAS);
}

// Just assume that all of our target ABIs can support returning at least
Expand All @@ -836,7 +842,7 @@ static ABIArgInfo classifyType(CodeGenModule &CGM, CanQualType type,
lowering.finish();

CharUnits alignment = CGM.getContext().getTypeAlignInChars(type);
return classifyExpandedType(lowering, forReturn, alignment);
return classifyExpandedType(lowering, forReturn, alignment, IndirectAS);
}

// Member pointer types need to be expanded, but it's a simple form of
Expand Down
24 changes: 15 additions & 9 deletions clang/lib/CodeGen/Targets/AArch64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,15 +327,17 @@ ABIArgInfo AArch64ABIInfo::coerceIllegalVector(QualType Ty, unsigned &NSRN,
return ABIArgInfo::getDirect(ResType);
}

return getNaturalAlignIndirect(Ty, /*ByVal=*/false);
return getNaturalAlignIndirect(Ty, getDataLayout().getAllocaAddrSpace(),
/*ByVal=*/false);
}

ABIArgInfo AArch64ABIInfo::coerceAndExpandPureScalableAggregate(
QualType Ty, bool IsNamedArg, unsigned NVec, unsigned NPred,
const SmallVectorImpl<llvm::Type *> &UnpaddedCoerceToSeq, unsigned &NSRN,
unsigned &NPRN) const {
if (!IsNamedArg || NSRN + NVec > 8 || NPRN + NPred > 4)
return getNaturalAlignIndirect(Ty, /*ByVal=*/false);
return getNaturalAlignIndirect(Ty, getDataLayout().getAllocaAddrSpace(),
/*ByVal=*/false);
NSRN += NVec;
NPRN += NPred;

Expand Down Expand Up @@ -375,7 +377,8 @@ ABIArgInfo AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadicFn,

if (const auto *EIT = Ty->getAs<BitIntType>())
if (EIT->getNumBits() > 128)
return getNaturalAlignIndirect(Ty, false);
return getNaturalAlignIndirect(Ty, getDataLayout().getAllocaAddrSpace(),
false);

if (Ty->isVectorType())
NSRN = std::min(NSRN + 1, 8u);
Expand Down Expand Up @@ -411,8 +414,9 @@ ABIArgInfo AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadicFn,
// Structures with either a non-trivial destructor or a non-trivial
// copy constructor are always indirect.
if (CGCXXABI::RecordArgABI RAA = getRecordArgABI(Ty, getCXXABI())) {
return getNaturalAlignIndirect(Ty, /*ByVal=*/RAA ==
CGCXXABI::RAA_DirectInMemory);
return getNaturalAlignIndirect(
Ty, /*AddrSpace=*/getDataLayout().getAllocaAddrSpace(),
/*ByVal=*/RAA == CGCXXABI::RAA_DirectInMemory);
}

// Empty records:
Expand Down Expand Up @@ -489,7 +493,8 @@ ABIArgInfo AArch64ABIInfo::classifyArgumentType(QualType Ty, bool IsVariadicFn,
: llvm::ArrayType::get(BaseTy, Size / Alignment));
}

return getNaturalAlignIndirect(Ty, /*ByVal=*/false);
return getNaturalAlignIndirect(Ty, getDataLayout().getAllocaAddrSpace(),
/*ByVal=*/false);
}

ABIArgInfo AArch64ABIInfo::classifyReturnType(QualType RetTy,
Expand All @@ -507,7 +512,7 @@ ABIArgInfo AArch64ABIInfo::classifyReturnType(QualType RetTy,

// Large vector types should be returned via memory.
if (RetTy->isVectorType() && getContext().getTypeSize(RetTy) > 128)
return getNaturalAlignIndirect(RetTy);
return getNaturalAlignIndirect(RetTy, getDataLayout().getAllocaAddrSpace());

if (!passAsAggregateType(RetTy)) {
// Treat an enum type as its underlying type.
Expand All @@ -516,7 +521,8 @@ ABIArgInfo AArch64ABIInfo::classifyReturnType(QualType RetTy,

if (const auto *EIT = RetTy->getAs<BitIntType>())
if (EIT->getNumBits() > 128)
return getNaturalAlignIndirect(RetTy);
return getNaturalAlignIndirect(RetTy,
getDataLayout().getAllocaAddrSpace());

return (isPromotableIntegerTypeForABI(RetTy) && isDarwinPCS()
? ABIArgInfo::getExtend(RetTy)
Expand Down Expand Up @@ -575,7 +581,7 @@ ABIArgInfo AArch64ABIInfo::classifyReturnType(QualType RetTy,
return ABIArgInfo::getDirect(llvm::IntegerType::get(getVMContext(), Size));
}

return getNaturalAlignIndirect(RetTy);
return getNaturalAlignIndirect(RetTy, getDataLayout().getAllocaAddrSpace());
}

/// isIllegalVectorType - check whether the vector type is legal for AArch64.
Expand Down
Loading

0 comments on commit 39ec9de

Please sign in to comment.