Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang] fix C_PTR function result lowering #100082

Merged
merged 2 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 34 additions & 20 deletions flang/lib/Optimizer/Builder/FIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1541,21 +1541,44 @@ mlir::Value fir::factory::genMaxWithZero(fir::FirOpBuilder &builder,
zero);
}

static std::pair<mlir::Value, mlir::Type>
genCPtrOrCFunptrFieldIndex(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Type cptrTy) {
auto recTy = mlir::cast<fir::RecordType>(cptrTy);
assert(recTy.getTypeList().size() == 1);
auto addrFieldName = recTy.getTypeList()[0].first;
mlir::Type addrFieldTy = recTy.getTypeList()[0].second;
auto fieldIndexType = fir::FieldType::get(cptrTy.getContext());
mlir::Value addrFieldIndex = builder.create<fir::FieldIndexOp>(
loc, fieldIndexType, addrFieldName, recTy,
/*typeParams=*/mlir::ValueRange{});
return {addrFieldIndex, addrFieldTy};
}

mlir::Value fir::factory::genCPtrOrCFunptrAddr(fir::FirOpBuilder &builder,
mlir::Location loc,
mlir::Value cPtr,
mlir::Type ty) {
assert(mlir::isa<fir::RecordType>(ty));
auto recTy = mlir::dyn_cast<fir::RecordType>(ty);
assert(recTy.getTypeList().size() == 1);
auto fieldName = recTy.getTypeList()[0].first;
mlir::Type fieldTy = recTy.getTypeList()[0].second;
auto fieldIndexType = fir::FieldType::get(ty.getContext());
mlir::Value field =
builder.create<fir::FieldIndexOp>(loc, fieldIndexType, fieldName, recTy,
/*typeParams=*/mlir::ValueRange{});
return builder.create<fir::CoordinateOp>(loc, builder.getRefType(fieldTy),
cPtr, field);
auto [addrFieldIndex, addrFieldTy] =
genCPtrOrCFunptrFieldIndex(builder, loc, ty);
return builder.create<fir::CoordinateOp>(loc, builder.getRefType(addrFieldTy),
cPtr, addrFieldIndex);
}

mlir::Value fir::factory::genCPtrOrCFunptrValue(fir::FirOpBuilder &builder,
mlir::Location loc,
mlir::Value cPtr) {
mlir::Type cPtrTy = fir::unwrapRefType(cPtr.getType());
if (fir::isa_ref_type(cPtr.getType())) {
mlir::Value cPtrAddr =
fir::factory::genCPtrOrCFunptrAddr(builder, loc, cPtr, cPtrTy);
return builder.create<fir::LoadOp>(loc, cPtrAddr);
}
auto [addrFieldIndex, addrFieldTy] =
genCPtrOrCFunptrFieldIndex(builder, loc, cPtrTy);
auto arrayAttr =
builder.getArrayAttr({builder.getIntegerAttr(builder.getIndexType(), 0)});
return builder.create<fir::ExtractValueOp>(loc, addrFieldTy, cPtr, arrayAttr);
}

fir::BoxValue fir::factory::createBoxValue(fir::FirOpBuilder &builder,
Expand Down Expand Up @@ -1596,15 +1619,6 @@ fir::BoxValue fir::factory::createBoxValue(fir::FirOpBuilder &builder,
return fir::BoxValue(box, lbounds, explicitTypeParams);
}

mlir::Value fir::factory::genCPtrOrCFunptrValue(fir::FirOpBuilder &builder,
mlir::Location loc,
mlir::Value cPtr) {
mlir::Type cPtrTy = fir::unwrapRefType(cPtr.getType());
mlir::Value cPtrAddr =
fir::factory::genCPtrOrCFunptrAddr(builder, loc, cPtr, cPtrTy);
return builder.create<fir::LoadOp>(loc, cPtrAddr);
}

mlir::Value fir::factory::createNullBoxProc(fir::FirOpBuilder &builder,
mlir::Location loc,
mlir::Type boxType) {
Expand Down
108 changes: 56 additions & 52 deletions flang/lib/Optimizer/Transforms/AbstractResult.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,16 @@ static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy,
/*resultTypes=*/{});
}

static mlir::Type getVoidPtrType(mlir::MLIRContext *context) {
return fir::ReferenceType::get(mlir::NoneType::get(context));
}

/// This is for function result types that are of type C_PTR from ISO_C_BINDING.
/// Follow the ABI for interoperability with C.
static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) {
auto resultType = funcTy.getResult(0);
assert(fir::isa_builtin_cptr_type(resultType));
llvm::SmallVector<mlir::Type> outputTypes;
auto recTy = mlir::dyn_cast<fir::RecordType>(resultType);
outputTypes.emplace_back(recTy.getTypeList()[0].second);
assert(fir::isa_builtin_cptr_type(funcTy.getResult(0)));
llvm::SmallVector<mlir::Type> outputTypes{
getVoidPtrType(funcTy.getContext())};
return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(),
outputTypes);
}
Expand Down Expand Up @@ -109,15 +111,11 @@ class CallConversion : public mlir::OpRewritePattern<Op> {
saveResult.getTypeparams());

llvm::SmallVector<mlir::Type> newResultTypes;
// TODO: This should be generalized for derived types, and it is
// architecture and OS dependent.
bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType());
Op newOp;
if (isResultBuiltinCPtr) {
auto recTy = mlir::dyn_cast<fir::RecordType>(result.getType());
newResultTypes.emplace_back(recTy.getTypeList()[0].second);
}
if (isResultBuiltinCPtr)
newResultTypes.emplace_back(getVoidPtrType(result.getContext()));

Op newOp;
// fir::CallOp specific handling.
if constexpr (std::is_same_v<Op, fir::CallOp>) {
if (op.getCallee()) {
Expand Down Expand Up @@ -175,7 +173,7 @@ class CallConversion : public mlir::OpRewritePattern<Op> {
FirOpBuilder builder(rewriter, module);
mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr(
builder, loc, save, result.getType());
rewriter.create<fir::StoreOp>(loc, newOp->getResult(0), saveAddr);
builder.createStoreWithConvert(loc, newOp->getResult(0), saveAddr);
}
op->dropAllReferences();
rewriter.eraseOp(op);
Expand Down Expand Up @@ -210,42 +208,52 @@ class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
mlir::PatternRewriter &rewriter) const override {
auto loc = ret.getLoc();
rewriter.setInsertionPoint(ret);
auto returnedValue = ret.getOperand(0);
bool replacedStorage = false;
if (auto *op = returnedValue.getDefiningOp())
if (auto load = mlir::dyn_cast<fir::LoadOp>(op)) {
auto resultStorage = load.getMemref();
// The result alloca may be behind a fir.declare, if any.
if (auto declare = mlir::dyn_cast_or_null<fir::DeclareOp>(
resultStorage.getDefiningOp()))
resultStorage = declare.getMemref();
// TODO: This should be generalized for derived types, and it is
// architecture and OS dependent.
if (fir::isa_builtin_cptr_type(returnedValue.getType())) {
rewriter.eraseOp(load);
auto module = ret->getParentOfType<mlir::ModuleOp>();
FirOpBuilder builder(rewriter, module);
mlir::Value retAddr = fir::factory::genCPtrOrCFunptrAddr(
builder, loc, resultStorage, returnedValue.getType());
mlir::Value retValue = rewriter.create<fir::LoadOp>(
loc, fir::unwrapRefType(retAddr.getType()), retAddr);
rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(
ret, mlir::ValueRange{retValue});
return mlir::success();
}
resultStorage.replaceAllUsesWith(newArg);
replacedStorage = true;
if (auto *alloc = resultStorage.getDefiningOp())
if (alloc->use_empty())
rewriter.eraseOp(alloc);
mlir::Value resultValue = ret.getOperand(0);
fir::LoadOp resultLoad;
mlir::Value resultStorage;
// Identify result local storage.
if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) {
resultLoad = load;
resultStorage = load.getMemref();
// The result alloca may be behind a fir.declare, if any.
if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>())
resultStorage = declare.getMemref();
}
// Replace old local storage with new storage argument, unless
// the derived type is C_PTR/C_FUN_PTR, in which case the return
// type is updated to return void* (no new argument is passed).
if (fir::isa_builtin_cptr_type(resultValue.getType())) {
auto module = ret->getParentOfType<mlir::ModuleOp>();
FirOpBuilder builder(rewriter, module);
mlir::Value cptr = resultValue;
if (resultLoad) {
// Replace whole derived type load by component load.
cptr = resultLoad.getMemref();
rewriter.setInsertionPoint(resultLoad);
}
// The result storage may have been optimized out by a memory to
// register pass, this is possible for fir.box results, or fir.record
// with no length parameters. Simply store the result in the result storage.
// at the return point.
if (!replacedStorage)
rewriter.create<fir::StoreOp>(loc, returnedValue, newArg);
rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
mlir::Value newResultValue =
fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr);
newResultValue = builder.createConvert(
loc, getVoidPtrType(ret.getContext()), newResultValue);
rewriter.setInsertionPoint(ret);
rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(
ret, mlir::ValueRange{newResultValue});
} else if (resultStorage) {
resultStorage.replaceAllUsesWith(newArg);
rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
} else {
// The result storage may have been optimized out by a memory to
// register pass, this is possible for fir.box results, or fir.record
// with no length parameters. Simply store the result in the result
// storage. at the return point.
rewriter.create<fir::StoreOp>(loc, resultValue, newArg);
rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
}
// Delete result old local storage if unused.
if (resultStorage)
if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>())
if (alloc->use_empty())
rewriter.eraseOp(alloc);
return mlir::success();
}

Expand All @@ -263,8 +271,6 @@ class AddrOfOpConversion : public mlir::OpRewritePattern<fir::AddrOfOp> {
mlir::PatternRewriter &rewriter) const override {
auto oldFuncTy = mlir::cast<mlir::FunctionType>(addrOf.getType());
mlir::FunctionType newFuncTy;
// TODO: This should be generalized for derived types, and it is
// architecture and OS dependent.
if (oldFuncTy.getNumResults() != 0 &&
fir::isa_builtin_cptr_type(oldFuncTy.getResult(0)))
newFuncTy = getCPtrFunctionType(oldFuncTy);
Expand Down Expand Up @@ -298,8 +304,6 @@ class AbstractResultOpt
// Convert function type itself if it has an abstract result.
auto funcTy = mlir::cast<mlir::FunctionType>(func.getFunctionType());
if (hasAbstractResult(funcTy)) {
// TODO: This should be generalized for derived types, and it is
// architecture and OS dependent.
if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) {
func.setType(getCPtrFunctionType(funcTy));
patterns.insert<ReturnOpConversion>(context, mlir::Value{});
Expand Down
36 changes: 20 additions & 16 deletions flang/test/Fir/abstract-results.fir
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ func.func @boxfunc_callee() -> !fir.box<!fir.heap<f64>> {
// FUNC-BOX: return
}

// FUNC-REF-LABEL: func @retcptr() -> i64
// FUNC-BOX-LABEL: func @retcptr() -> i64
// FUNC-REF-LABEL: func @retcptr() -> !fir.ref<none>
// FUNC-BOX-LABEL: func @retcptr() -> !fir.ref<none>
func.func @retcptr() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {
%0 = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = "rec", uniq_name = "_QFrecErec"}
%1 = fir.load %0 : !fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>
Expand All @@ -98,12 +98,14 @@ func.func @retcptr() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__addres
// FUNC-REF: %[[FIELD:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
// FUNC-REF: %[[ADDR:.*]] = fir.coordinate_of %[[ALLOC]], %[[FIELD]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
// FUNC-REF: %[[VAL:.*]] = fir.load %[[ADDR]] : !fir.ref<i64>
// FUNC-REF: return %[[VAL]] : i64
// FUNC-REF: %[[CAST:.*]] = fir.convert %[[VAL]] : (i64) -> !fir.ref<none>
// FUNC-REF: return %[[CAST]] : !fir.ref<none>
// FUNC-BOX: %[[ALLOC:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = "rec", uniq_name = "_QFrecErec"}
// FUNC-BOX: %[[FIELD:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
// FUNC-BOX: %[[ADDR:.*]] = fir.coordinate_of %[[ALLOC]], %[[FIELD]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
// FUNC-BOX: %[[VAL:.*]] = fir.load %[[ADDR]] : !fir.ref<i64>
// FUNC-BOX: return %[[VAL]] : i64
// FUNC-BOX: %[[CAST:.*]] = fir.convert %[[VAL]] : (i64) -> !fir.ref<none>
// FUNC-BOX: return %[[CAST]] : !fir.ref<none>
}

// FUNC-REF-LABEL: func private @arrayfunc_callee_declare(
Expand Down Expand Up @@ -311,8 +313,8 @@ func.func @test_address_of() {

}

// FUNC-REF-LABEL: func.func private @returns_null() -> i64
// FUNC-BOX-LABEL: func.func private @returns_null() -> i64
// FUNC-REF-LABEL: func.func private @returns_null() -> !fir.ref<none>
// FUNC-BOX-LABEL: func.func private @returns_null() -> !fir.ref<none>
func.func private @returns_null() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>

// FUNC-REF-LABEL: func @test_address_of_cptr
Expand All @@ -323,12 +325,12 @@ func.func @test_address_of_cptr() {
fir.call @_QMtest_c_func_modPsubr(%1) : (() -> ()) -> ()
return

// FUNC-REF: %[[VAL_0:.*]] = fir.address_of(@returns_null) : () -> i64
// FUNC-REF: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> i64) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
// FUNC-REF: %[[VAL_0:.*]] = fir.address_of(@returns_null) : () -> !fir.ref<none>
// FUNC-REF: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> !fir.ref<none>) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
// FUNC-REF: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> ())
// FUNC-REF: fir.call @_QMtest_c_func_modPsubr(%[[VAL_2]]) : (() -> ()) -> ()
// FUNC-BOX: %[[VAL_0:.*]] = fir.address_of(@returns_null) : () -> i64
// FUNC-BOX: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> i64) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
// FUNC-BOX: %[[VAL_0:.*]] = fir.address_of(@returns_null) : () -> !fir.ref<none>
// FUNC-BOX: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> !fir.ref<none>) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
// FUNC-BOX: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> ())
// FUNC-BOX: fir.call @_QMtest_c_func_modPsubr(%[[VAL_2]]) : (() -> ()) -> ()
}
Expand Down Expand Up @@ -380,18 +382,20 @@ func.func @test_indirect_calls_return_cptr(%arg0: () -> ()) {

// FUNC-REF: %[[VAL_0:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = ".result"}
// FUNC-REF: %[[VAL_1:.*]] = fir.convert %[[ARG0]] : (() -> ()) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
// FUNC-REF: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> i64)
// FUNC-REF: %[[VAL_3:.*]] = fir.call %[[VAL_2]]() : () -> i64
// FUNC-REF: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> !fir.ref<none>)
// FUNC-REF: %[[VAL_3:.*]] = fir.call %[[VAL_2]]() : () -> !fir.ref<none>
// FUNC-REF: %[[VAL_4:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
// FUNC-REF: %[[VAL_5:.*]] = fir.coordinate_of %[[VAL_0]], %[[VAL_4]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
// FUNC-REF: fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref<i64>
// FUNC-REF: %[[CAST:.*]] = fir.convert %[[VAL_3]] : (!fir.ref<none>) -> i64
// FUNC-REF: fir.store %[[CAST]] to %[[VAL_5]] : !fir.ref<i64>
// FUNC-BOX: %[[VAL_0:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = ".result"}
// FUNC-BOX: %[[VAL_1:.*]] = fir.convert %[[ARG0]] : (() -> ()) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>)
// FUNC-BOX: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> i64)
// FUNC-BOX: %[[VAL_3:.*]] = fir.call %[[VAL_2]]() : () -> i64
// FUNC-BOX: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> !fir.ref<none>)
// FUNC-BOX: %[[VAL_3:.*]] = fir.call %[[VAL_2]]() : () -> !fir.ref<none>
// FUNC-BOX: %[[VAL_4:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>
// FUNC-BOX: %[[VAL_5:.*]] = fir.coordinate_of %[[VAL_0]], %[[VAL_4]] : (!fir.ref<!fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>>, !fir.field) -> !fir.ref<i64>
// FUNC-BOX: fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref<i64>
// FUNC-BOX: %[[CAST:.*]] = fir.convert %[[VAL_3]] : (!fir.ref<none>) -> i64
// FUNC-BOX: fir.store %[[CAST]] to %[[VAL_5]] : !fir.ref<i64>
}

// ----------------------- Test GlobalOp rewrite ------------------------
Expand Down
Loading