Skip to content

Commit

Permalink
[flang] Lower REDUCE intrinsic with DIM argument (#94771)
Browse files Browse the repository at this point in the history
This is a follow up patch to #94652 and handles the lowering of the
reduce intrinsic with DIM argument and non scalar result.
  • Loading branch information
clementval authored Jun 11, 2024
1 parent 1934208 commit 6ffdcfa
Show file tree
Hide file tree
Showing 4 changed files with 443 additions and 1 deletion.
7 changes: 7 additions & 0 deletions flang/include/flang/Optimizer/Builder/Runtime/Reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,13 @@ mlir::Value genReduce(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value maskBox, mlir::Value identity,
mlir::Value ordered);

/// Generate call to `Reduce` intrinsic runtime routine. This is the version
/// that takes arrays of any rank with a dim argument specified.
void genReduceDim(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value arrayBox, mlir::Value operation, mlir::Value dim,
mlir::Value maskBox, mlir::Value identity,
mlir::Value ordered, mlir::Value resultBox);

} // namespace fir::runtime

#endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_REDUCTION_H
12 changes: 11 additions & 1 deletion flang/lib/Optimizer/Builder/IntrinsicCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5790,7 +5790,17 @@ IntrinsicLibrary::genReduce(mlir::Type resultType,
return fir::runtime::genReduce(builder, loc, array, operation, mask,
identity, ordered);
}
TODO(loc, "reduce with array result");
// Handle cases that have an array result.
// Create mutable fir.box to be passed to the runtime for the result.
mlir::Type resultArrayType = builder.getVarLenSeqTy(resultType, rank - 1);
fir::MutableBoxValue resultMutableBox =
fir::factory::createTempMutableBox(builder, loc, resultArrayType);
mlir::Value resultIrBox =
fir::factory::getMutableIRBox(builder, loc, resultMutableBox);
mlir::Value dim = fir::getBase(args[2]);
fir::runtime::genReduceDim(builder, loc, array, operation, dim, mask,
identity, ordered, resultIrBox);
return readAndAddCleanUp(resultMutableBox, resultType, "REDUCE");
}

// REPEAT
Expand Down
204 changes: 204 additions & 0 deletions flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,50 @@ struct ForcedReduceReal16 {
}
};

/// Placeholder for DIM real*10 version of Reduce Intrinsic
struct ForcedReduceReal10Dim {
static constexpr const char *name =
ExpandAndQuoteKey(RTNAME(ReduceReal10Dim));
static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
return [](mlir::MLIRContext *ctx) {
auto ty = mlir::FloatType::getF80(ctx);
auto boxTy =
fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
auto opTy = mlir::FunctionType::get(ctx, {ty, ty}, ty);
auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
auto refTy = fir::ReferenceType::get(ty);
auto refBoxTy = fir::ReferenceType::get(boxTy);
auto i1Ty = mlir::IntegerType::get(ctx, 1);
return mlir::FunctionType::get(
ctx, {refBoxTy, boxTy, opTy, strTy, intTy, intTy, boxTy, refTy, i1Ty},
{});
};
}
};

/// Placeholder for DIM real*16 version of Reduce Intrinsic
struct ForcedReduceReal16Dim {
static constexpr const char *name =
ExpandAndQuoteKey(RTNAME(ReduceReal16Dim));
static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
return [](mlir::MLIRContext *ctx) {
auto ty = mlir::FloatType::getF128(ctx);
auto boxTy =
fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
auto opTy = mlir::FunctionType::get(ctx, {ty, ty}, ty);
auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
auto refTy = fir::ReferenceType::get(ty);
auto refBoxTy = fir::ReferenceType::get(boxTy);
auto i1Ty = mlir::IntegerType::get(ctx, 1);
return mlir::FunctionType::get(
ctx, {refBoxTy, boxTy, opTy, strTy, intTy, intTy, boxTy, refTy, i1Ty},
{});
};
}
};

/// Placeholder for integer*16 version of Reduce Intrinsic
struct ForcedReduceInteger16 {
static constexpr const char *name =
Expand All @@ -525,6 +569,28 @@ struct ForcedReduceInteger16 {
}
};

/// Placeholder for DIM integer*16 version of Reduce Intrinsic
struct ForcedReduceInteger16Dim {
static constexpr const char *name =
ExpandAndQuoteKey(RTNAME(ReduceInteger16Dim));
static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
return [](mlir::MLIRContext *ctx) {
auto ty = mlir::IntegerType::get(ctx, 128);
auto boxTy =
fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
auto opTy = mlir::FunctionType::get(ctx, {ty, ty}, ty);
auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
auto refTy = fir::ReferenceType::get(ty);
auto refBoxTy = fir::ReferenceType::get(boxTy);
auto i1Ty = mlir::IntegerType::get(ctx, 1);
return mlir::FunctionType::get(
ctx, {refBoxTy, boxTy, opTy, strTy, intTy, intTy, boxTy, refTy, i1Ty},
{});
};
}
};

/// Placeholder for complex(10) version of Reduce Intrinsic
struct ForcedReduceComplex10 {
static constexpr const char *name =
Expand All @@ -546,6 +612,28 @@ struct ForcedReduceComplex10 {
}
};

/// Placeholder for Dim complex(10) version of Reduce Intrinsic
struct ForcedReduceComplex10Dim {
static constexpr const char *name =
ExpandAndQuoteKey(RTNAME(CppReduceComplex10Dim));
static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
return [](mlir::MLIRContext *ctx) {
auto ty = mlir::ComplexType::get(mlir::FloatType::getF80(ctx));
auto boxTy =
fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
auto opTy = mlir::FunctionType::get(ctx, {ty, ty}, ty);
auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
auto refTy = fir::ReferenceType::get(ty);
auto refBoxTy = fir::ReferenceType::get(boxTy);
auto i1Ty = mlir::IntegerType::get(ctx, 1);
return mlir::FunctionType::get(
ctx, {refBoxTy, boxTy, opTy, strTy, intTy, intTy, boxTy, refTy, i1Ty},
{});
};
}
};

/// Placeholder for complex(16) version of Reduce Intrinsic
struct ForcedReduceComplex16 {
static constexpr const char *name =
Expand All @@ -567,6 +655,28 @@ struct ForcedReduceComplex16 {
}
};

/// Placeholder for Dim complex(16) version of Reduce Intrinsic
struct ForcedReduceComplex16Dim {
static constexpr const char *name =
ExpandAndQuoteKey(RTNAME(CppReduceComplex16Dim));
static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
return [](mlir::MLIRContext *ctx) {
auto ty = mlir::ComplexType::get(mlir::FloatType::getF128(ctx));
auto boxTy =
fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
auto opTy = mlir::FunctionType::get(ctx, {ty, ty}, ty);
auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
auto refTy = fir::ReferenceType::get(ty);
auto refBoxTy = fir::ReferenceType::get(boxTy);
auto i1Ty = mlir::IntegerType::get(ctx, 1);
return mlir::FunctionType::get(
ctx, {refBoxTy, boxTy, opTy, strTy, intTy, intTy, boxTy, refTy, i1Ty},
{});
};
}
};

/// Generate call to specialized runtime function that takes a mask and
/// dim argument. The All, Any, and Count intrinsics use this pattern.
template <typename FN>
Expand Down Expand Up @@ -1461,3 +1571,97 @@ mlir::Value fir::runtime::genReduce(fir::FirOpBuilder &builder,
maskBox, identity, ordered);
return builder.create<fir::CallOp>(loc, func, args).getResult(0);
}

void fir::runtime::genReduceDim(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value arrayBox, mlir::Value operation,
mlir::Value dim, mlir::Value maskBox,
mlir::Value identity, mlir::Value ordered,
mlir::Value resultBox) {
mlir::func::FuncOp func;
auto ty = arrayBox.getType();
auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();

mlir::MLIRContext *ctx = builder.getContext();
fir::factory::CharacterExprHelper charHelper{builder, loc};

if (eleTy.isF16())
func = fir::runtime::getRuntimeFunc<mkRTKey(ReduceReal2Dim)>(loc, builder);
else if (eleTy.isBF16())
func = fir::runtime::getRuntimeFunc<mkRTKey(ReduceReal3Dim)>(loc, builder);
else if (eleTy.isF32())
func = fir::runtime::getRuntimeFunc<mkRTKey(ReduceReal4Dim)>(loc, builder);
else if (eleTy.isF64())
func = fir::runtime::getRuntimeFunc<mkRTKey(ReduceReal8Dim)>(loc, builder);
else if (eleTy.isF80())
func = fir::runtime::getRuntimeFunc<ForcedReduceReal10Dim>(loc, builder);
else if (eleTy.isF128())
func = fir::runtime::getRuntimeFunc<ForcedReduceReal16Dim>(loc, builder);
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1)))
func =
fir::runtime::getRuntimeFunc<mkRTKey(ReduceInteger1Dim)>(loc, builder);
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(2)))
func =
fir::runtime::getRuntimeFunc<mkRTKey(ReduceInteger2Dim)>(loc, builder);
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(4)))
func =
fir::runtime::getRuntimeFunc<mkRTKey(ReduceInteger4Dim)>(loc, builder);
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(8)))
func =
fir::runtime::getRuntimeFunc<mkRTKey(ReduceInteger8Dim)>(loc, builder);
else if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(16)))
func = fir::runtime::getRuntimeFunc<ForcedReduceInteger16Dim>(loc, builder);
else if (eleTy == fir::ComplexType::get(ctx, 2))
func = fir::runtime::getRuntimeFunc<mkRTKey(CppReduceComplex2Dim)>(loc,
builder);
else if (eleTy == fir::ComplexType::get(ctx, 3))
func = fir::runtime::getRuntimeFunc<mkRTKey(CppReduceComplex3Dim)>(loc,
builder);
else if (eleTy == fir::ComplexType::get(ctx, 4))
func = fir::runtime::getRuntimeFunc<mkRTKey(CppReduceComplex4Dim)>(loc,
builder);
else if (eleTy == fir::ComplexType::get(ctx, 8))
func = fir::runtime::getRuntimeFunc<mkRTKey(CppReduceComplex8Dim)>(loc,
builder);
else if (eleTy == fir::ComplexType::get(ctx, 10))
func = fir::runtime::getRuntimeFunc<ForcedReduceComplex10Dim>(loc, builder);
else if (eleTy == fir::ComplexType::get(ctx, 16))
func = fir::runtime::getRuntimeFunc<ForcedReduceComplex16Dim>(loc, builder);
else if (eleTy == fir::LogicalType::get(ctx, 1))
func =
fir::runtime::getRuntimeFunc<mkRTKey(ReduceLogical1Dim)>(loc, builder);
else if (eleTy == fir::LogicalType::get(ctx, 2))
func =
fir::runtime::getRuntimeFunc<mkRTKey(ReduceLogical2Dim)>(loc, builder);
else if (eleTy == fir::LogicalType::get(ctx, 4))
func =
fir::runtime::getRuntimeFunc<mkRTKey(ReduceLogical4Dim)>(loc, builder);
else if (eleTy == fir::LogicalType::get(ctx, 8))
func =
fir::runtime::getRuntimeFunc<mkRTKey(ReduceLogical8Dim)>(loc, builder);
else if (fir::isa_char(eleTy) && charHelper.getCharacterKind(eleTy) == 1)
func = fir::runtime::getRuntimeFunc<mkRTKey(ReduceCharacter1Dim)>(loc,
builder);
else if (fir::isa_char(eleTy) && charHelper.getCharacterKind(eleTy) == 2)
func = fir::runtime::getRuntimeFunc<mkRTKey(ReduceCharacter2Dim)>(loc,
builder);
else if (fir::isa_char(eleTy) && charHelper.getCharacterKind(eleTy) == 4)
func = fir::runtime::getRuntimeFunc<mkRTKey(ReduceCharacter4Dim)>(loc,
builder);
else if (fir::isa_derived(eleTy))
func = fir::runtime::getRuntimeFunc<mkRTKey(ReduceDerivedTypeDim)>(loc,
builder);
else
fir::intrinsicTypeTODO(builder, eleTy, loc, "REDUCE");

auto fTy = func.getFunctionType();
auto sourceFile = fir::factory::locationToFilename(builder, loc);

auto sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(4));
auto opAddr = builder.create<fir::BoxAddrOp>(loc, fTy.getInput(2), operation);
auto args = fir::runtime::createArguments(
builder, loc, fTy, resultBox, arrayBox, opAddr, sourceFile, sourceLine,
dim, maskBox, identity, ordered);
builder.create<fir::CallOp>(loc, func, args);
}
Loading

0 comments on commit 6ffdcfa

Please sign in to comment.