diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp index 2961df96b3cab..fbe79d0e45e5a 100644 --- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp +++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp @@ -1541,21 +1541,44 @@ mlir::Value fir::factory::genMaxWithZero(fir::FirOpBuilder &builder, zero); } +static std::pair +genCPtrOrCFunptrFieldIndex(fir::FirOpBuilder &builder, mlir::Location loc, + mlir::Type cptrTy) { + auto recTy = mlir::cast(cptrTy); + assert(recTy.getTypeList().size() == 1); + auto addrFieldName = recTy.getTypeList()[0].first; + mlir::Type addrFieldTy = recTy.getTypeList()[0].second; + auto fieldIndexType = fir::FieldType::get(cptrTy.getContext()); + mlir::Value addrFieldIndex = builder.create( + loc, fieldIndexType, addrFieldName, recTy, + /*typeParams=*/mlir::ValueRange{}); + return {addrFieldIndex, addrFieldTy}; +} + mlir::Value fir::factory::genCPtrOrCFunptrAddr(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value cPtr, mlir::Type ty) { - assert(mlir::isa(ty)); - auto recTy = mlir::dyn_cast(ty); - assert(recTy.getTypeList().size() == 1); - auto fieldName = recTy.getTypeList()[0].first; - mlir::Type fieldTy = recTy.getTypeList()[0].second; - auto fieldIndexType = fir::FieldType::get(ty.getContext()); - mlir::Value field = - builder.create(loc, fieldIndexType, fieldName, recTy, - /*typeParams=*/mlir::ValueRange{}); - return builder.create(loc, builder.getRefType(fieldTy), - cPtr, field); + auto [addrFieldIndex, addrFieldTy] = + genCPtrOrCFunptrFieldIndex(builder, loc, ty); + return builder.create(loc, builder.getRefType(addrFieldTy), + cPtr, addrFieldIndex); +} + +mlir::Value fir::factory::genCPtrOrCFunptrValue(fir::FirOpBuilder &builder, + mlir::Location loc, + mlir::Value cPtr) { + mlir::Type cPtrTy = fir::unwrapRefType(cPtr.getType()); + if (fir::isa_ref_type(cPtr.getType())) { + mlir::Value cPtrAddr = + fir::factory::genCPtrOrCFunptrAddr(builder, loc, cPtr, cPtrTy); + return builder.create(loc, cPtrAddr); + } + auto [addrFieldIndex, addrFieldTy] = + genCPtrOrCFunptrFieldIndex(builder, loc, cPtrTy); + auto arrayAttr = + builder.getArrayAttr({builder.getIntegerAttr(builder.getIndexType(), 0)}); + return builder.create(loc, addrFieldTy, cPtr, arrayAttr); } fir::BoxValue fir::factory::createBoxValue(fir::FirOpBuilder &builder, @@ -1596,15 +1619,6 @@ fir::BoxValue fir::factory::createBoxValue(fir::FirOpBuilder &builder, return fir::BoxValue(box, lbounds, explicitTypeParams); } -mlir::Value fir::factory::genCPtrOrCFunptrValue(fir::FirOpBuilder &builder, - mlir::Location loc, - mlir::Value cPtr) { - mlir::Type cPtrTy = fir::unwrapRefType(cPtr.getType()); - mlir::Value cPtrAddr = - fir::factory::genCPtrOrCFunptrAddr(builder, loc, cPtr, cPtrTy); - return builder.create(loc, cPtrAddr); -} - mlir::Value fir::factory::createNullBoxProc(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Type boxType) { diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp index 3906aa553cb34..ff37310224e85 100644 --- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp +++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp @@ -59,14 +59,16 @@ static mlir::FunctionType getNewFunctionType(mlir::FunctionType funcTy, /*resultTypes=*/{}); } +static mlir::Type getVoidPtrType(mlir::MLIRContext *context) { + return fir::ReferenceType::get(mlir::NoneType::get(context)); +} + /// This is for function result types that are of type C_PTR from ISO_C_BINDING. /// Follow the ABI for interoperability with C. static mlir::FunctionType getCPtrFunctionType(mlir::FunctionType funcTy) { - auto resultType = funcTy.getResult(0); - assert(fir::isa_builtin_cptr_type(resultType)); - llvm::SmallVector outputTypes; - auto recTy = mlir::dyn_cast(resultType); - outputTypes.emplace_back(recTy.getTypeList()[0].second); + assert(fir::isa_builtin_cptr_type(funcTy.getResult(0))); + llvm::SmallVector outputTypes{ + getVoidPtrType(funcTy.getContext())}; return mlir::FunctionType::get(funcTy.getContext(), funcTy.getInputs(), outputTypes); } @@ -109,15 +111,11 @@ class CallConversion : public mlir::OpRewritePattern { saveResult.getTypeparams()); llvm::SmallVector newResultTypes; - // TODO: This should be generalized for derived types, and it is - // architecture and OS dependent. bool isResultBuiltinCPtr = fir::isa_builtin_cptr_type(result.getType()); - Op newOp; - if (isResultBuiltinCPtr) { - auto recTy = mlir::dyn_cast(result.getType()); - newResultTypes.emplace_back(recTy.getTypeList()[0].second); - } + if (isResultBuiltinCPtr) + newResultTypes.emplace_back(getVoidPtrType(result.getContext())); + Op newOp; // fir::CallOp specific handling. if constexpr (std::is_same_v) { if (op.getCallee()) { @@ -175,7 +173,7 @@ class CallConversion : public mlir::OpRewritePattern { FirOpBuilder builder(rewriter, module); mlir::Value saveAddr = fir::factory::genCPtrOrCFunptrAddr( builder, loc, save, result.getType()); - rewriter.create(loc, newOp->getResult(0), saveAddr); + builder.createStoreWithConvert(loc, newOp->getResult(0), saveAddr); } op->dropAllReferences(); rewriter.eraseOp(op); @@ -210,42 +208,52 @@ class ReturnOpConversion : public mlir::OpRewritePattern { mlir::PatternRewriter &rewriter) const override { auto loc = ret.getLoc(); rewriter.setInsertionPoint(ret); - auto returnedValue = ret.getOperand(0); - bool replacedStorage = false; - if (auto *op = returnedValue.getDefiningOp()) - if (auto load = mlir::dyn_cast(op)) { - auto resultStorage = load.getMemref(); - // The result alloca may be behind a fir.declare, if any. - if (auto declare = mlir::dyn_cast_or_null( - resultStorage.getDefiningOp())) - resultStorage = declare.getMemref(); - // TODO: This should be generalized for derived types, and it is - // architecture and OS dependent. - if (fir::isa_builtin_cptr_type(returnedValue.getType())) { - rewriter.eraseOp(load); - auto module = ret->getParentOfType(); - FirOpBuilder builder(rewriter, module); - mlir::Value retAddr = fir::factory::genCPtrOrCFunptrAddr( - builder, loc, resultStorage, returnedValue.getType()); - mlir::Value retValue = rewriter.create( - loc, fir::unwrapRefType(retAddr.getType()), retAddr); - rewriter.replaceOpWithNewOp( - ret, mlir::ValueRange{retValue}); - return mlir::success(); - } - resultStorage.replaceAllUsesWith(newArg); - replacedStorage = true; - if (auto *alloc = resultStorage.getDefiningOp()) - if (alloc->use_empty()) - rewriter.eraseOp(alloc); + mlir::Value resultValue = ret.getOperand(0); + fir::LoadOp resultLoad; + mlir::Value resultStorage; + // Identify result local storage. + if (auto load = resultValue.getDefiningOp()) { + resultLoad = load; + resultStorage = load.getMemref(); + // The result alloca may be behind a fir.declare, if any. + if (auto declare = resultStorage.getDefiningOp()) + resultStorage = declare.getMemref(); + } + // Replace old local storage with new storage argument, unless + // the derived type is C_PTR/C_FUN_PTR, in which case the return + // type is updated to return void* (no new argument is passed). + if (fir::isa_builtin_cptr_type(resultValue.getType())) { + auto module = ret->getParentOfType(); + FirOpBuilder builder(rewriter, module); + mlir::Value cptr = resultValue; + if (resultLoad) { + // Replace whole derived type load by component load. + cptr = resultLoad.getMemref(); + rewriter.setInsertionPoint(resultLoad); } - // The result storage may have been optimized out by a memory to - // register pass, this is possible for fir.box results, or fir.record - // with no length parameters. Simply store the result in the result storage. - // at the return point. - if (!replacedStorage) - rewriter.create(loc, returnedValue, newArg); - rewriter.replaceOpWithNewOp(ret); + mlir::Value newResultValue = + fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr); + newResultValue = builder.createConvert( + loc, getVoidPtrType(ret.getContext()), newResultValue); + rewriter.setInsertionPoint(ret); + rewriter.replaceOpWithNewOp( + ret, mlir::ValueRange{newResultValue}); + } else if (resultStorage) { + resultStorage.replaceAllUsesWith(newArg); + rewriter.replaceOpWithNewOp(ret); + } else { + // The result storage may have been optimized out by a memory to + // register pass, this is possible for fir.box results, or fir.record + // with no length parameters. Simply store the result in the result + // storage. at the return point. + rewriter.create(loc, resultValue, newArg); + rewriter.replaceOpWithNewOp(ret); + } + // Delete result old local storage if unused. + if (resultStorage) + if (auto alloc = resultStorage.getDefiningOp()) + if (alloc->use_empty()) + rewriter.eraseOp(alloc); return mlir::success(); } @@ -263,8 +271,6 @@ class AddrOfOpConversion : public mlir::OpRewritePattern { mlir::PatternRewriter &rewriter) const override { auto oldFuncTy = mlir::cast(addrOf.getType()); mlir::FunctionType newFuncTy; - // TODO: This should be generalized for derived types, and it is - // architecture and OS dependent. if (oldFuncTy.getNumResults() != 0 && fir::isa_builtin_cptr_type(oldFuncTy.getResult(0))) newFuncTy = getCPtrFunctionType(oldFuncTy); @@ -298,8 +304,6 @@ class AbstractResultOpt // Convert function type itself if it has an abstract result. auto funcTy = mlir::cast(func.getFunctionType()); if (hasAbstractResult(funcTy)) { - // TODO: This should be generalized for derived types, and it is - // architecture and OS dependent. if (fir::isa_builtin_cptr_type(funcTy.getResult(0))) { func.setType(getCPtrFunctionType(funcTy)); patterns.insert(context, mlir::Value{}); diff --git a/flang/test/Fir/abstract-results.fir b/flang/test/Fir/abstract-results.fir index 82f1cd33073fd..93e63dc657f0c 100644 --- a/flang/test/Fir/abstract-results.fir +++ b/flang/test/Fir/abstract-results.fir @@ -87,8 +87,8 @@ func.func @boxfunc_callee() -> !fir.box> { // FUNC-BOX: return } -// FUNC-REF-LABEL: func @retcptr() -> i64 -// FUNC-BOX-LABEL: func @retcptr() -> i64 +// FUNC-REF-LABEL: func @retcptr() -> !fir.ref +// FUNC-BOX-LABEL: func @retcptr() -> !fir.ref func.func @retcptr() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> { %0 = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = "rec", uniq_name = "_QFrecErec"} %1 = fir.load %0 : !fir.ref> @@ -98,12 +98,14 @@ func.func @retcptr() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__addres // FUNC-REF: %[[FIELD:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> // FUNC-REF: %[[ADDR:.*]] = fir.coordinate_of %[[ALLOC]], %[[FIELD]] : (!fir.ref>, !fir.field) -> !fir.ref // FUNC-REF: %[[VAL:.*]] = fir.load %[[ADDR]] : !fir.ref - // FUNC-REF: return %[[VAL]] : i64 + // FUNC-REF: %[[CAST:.*]] = fir.convert %[[VAL]] : (i64) -> !fir.ref + // FUNC-REF: return %[[CAST]] : !fir.ref // FUNC-BOX: %[[ALLOC:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = "rec", uniq_name = "_QFrecErec"} // FUNC-BOX: %[[FIELD:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> // FUNC-BOX: %[[ADDR:.*]] = fir.coordinate_of %[[ALLOC]], %[[FIELD]] : (!fir.ref>, !fir.field) -> !fir.ref // FUNC-BOX: %[[VAL:.*]] = fir.load %[[ADDR]] : !fir.ref - // FUNC-BOX: return %[[VAL]] : i64 + // FUNC-BOX: %[[CAST:.*]] = fir.convert %[[VAL]] : (i64) -> !fir.ref + // FUNC-BOX: return %[[CAST]] : !fir.ref } // FUNC-REF-LABEL: func private @arrayfunc_callee_declare( @@ -311,8 +313,8 @@ func.func @test_address_of() { } -// FUNC-REF-LABEL: func.func private @returns_null() -> i64 -// FUNC-BOX-LABEL: func.func private @returns_null() -> i64 +// FUNC-REF-LABEL: func.func private @returns_null() -> !fir.ref +// FUNC-BOX-LABEL: func.func private @returns_null() -> !fir.ref func.func private @returns_null() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> // FUNC-REF-LABEL: func @test_address_of_cptr @@ -323,12 +325,12 @@ func.func @test_address_of_cptr() { fir.call @_QMtest_c_func_modPsubr(%1) : (() -> ()) -> () return - // FUNC-REF: %[[VAL_0:.*]] = fir.address_of(@returns_null) : () -> i64 - // FUNC-REF: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> i64) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) + // FUNC-REF: %[[VAL_0:.*]] = fir.address_of(@returns_null) : () -> !fir.ref + // FUNC-REF: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> !fir.ref) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) // FUNC-REF: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> ()) // FUNC-REF: fir.call @_QMtest_c_func_modPsubr(%[[VAL_2]]) : (() -> ()) -> () - // FUNC-BOX: %[[VAL_0:.*]] = fir.address_of(@returns_null) : () -> i64 - // FUNC-BOX: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> i64) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) + // FUNC-BOX: %[[VAL_0:.*]] = fir.address_of(@returns_null) : () -> !fir.ref + // FUNC-BOX: %[[VAL_1:.*]] = fir.convert %[[VAL_0]] : (() -> !fir.ref) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) // FUNC-BOX: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> ()) // FUNC-BOX: fir.call @_QMtest_c_func_modPsubr(%[[VAL_2]]) : (() -> ()) -> () } @@ -380,18 +382,20 @@ func.func @test_indirect_calls_return_cptr(%arg0: () -> ()) { // FUNC-REF: %[[VAL_0:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = ".result"} // FUNC-REF: %[[VAL_1:.*]] = fir.convert %[[ARG0]] : (() -> ()) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) - // FUNC-REF: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> i64) - // FUNC-REF: %[[VAL_3:.*]] = fir.call %[[VAL_2]]() : () -> i64 + // FUNC-REF: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> !fir.ref) + // FUNC-REF: %[[VAL_3:.*]] = fir.call %[[VAL_2]]() : () -> !fir.ref // FUNC-REF: %[[VAL_4:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> // FUNC-REF: %[[VAL_5:.*]] = fir.coordinate_of %[[VAL_0]], %[[VAL_4]] : (!fir.ref>, !fir.field) -> !fir.ref - // FUNC-REF: fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref + // FUNC-REF: %[[CAST:.*]] = fir.convert %[[VAL_3]] : (!fir.ref) -> i64 + // FUNC-REF: fir.store %[[CAST]] to %[[VAL_5]] : !fir.ref // FUNC-BOX: %[[VAL_0:.*]] = fir.alloca !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> {bindc_name = ".result"} // FUNC-BOX: %[[VAL_1:.*]] = fir.convert %[[ARG0]] : (() -> ()) -> (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) - // FUNC-BOX: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> i64) - // FUNC-BOX: %[[VAL_3:.*]] = fir.call %[[VAL_2]]() : () -> i64 + // FUNC-BOX: %[[VAL_2:.*]] = fir.convert %[[VAL_1]] : (() -> !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}>) -> (() -> !fir.ref) + // FUNC-BOX: %[[VAL_3:.*]] = fir.call %[[VAL_2]]() : () -> !fir.ref // FUNC-BOX: %[[VAL_4:.*]] = fir.field_index __address, !fir.type<_QM__fortran_builtinsT__builtin_c_ptr{__address:i64}> // FUNC-BOX: %[[VAL_5:.*]] = fir.coordinate_of %[[VAL_0]], %[[VAL_4]] : (!fir.ref>, !fir.field) -> !fir.ref - // FUNC-BOX: fir.store %[[VAL_3]] to %[[VAL_5]] : !fir.ref + // FUNC-BOX: %[[CAST:.*]] = fir.convert %[[VAL_3]] : (!fir.ref) -> i64 + // FUNC-BOX: fir.store %[[CAST]] to %[[VAL_5]] : !fir.ref } // ----------------------- Test GlobalOp rewrite ------------------------