Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang] add getElementType() to fir::SquenceType and fir::VectorType #112770

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flang/include/flang/Optimizer/Builder/PPCIntrinsicCall.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ struct VecTypeInfo {
static inline VecTypeInfo getVecTypeFromFirType(mlir::Type firTy) {
assert(mlir::isa<fir::VectorType>(firTy));
VecTypeInfo vecTyInfo;
vecTyInfo.eleTy = mlir::dyn_cast<fir::VectorType>(firTy).getEleTy();
vecTyInfo.eleTy = mlir::dyn_cast<fir::VectorType>(firTy).getElementType();
vecTyInfo.len = mlir::dyn_cast<fir::VectorType>(firTy).getLen();
return vecTyInfo;
}
Expand Down
4 changes: 4 additions & 0 deletions flang/include/flang/Optimizer/Dialect/FIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,8 @@ def fir_SequenceType : FIR_Type<"Sequence", "array"> {
size = size * static_cast<std::uint64_t>(extent);
return size;
}

mlir::Type getElementType() const { return getEleTy(); }
}];
}

Expand Down Expand Up @@ -519,6 +521,8 @@ def fir_VectorType : FIR_Type<"Vector", "vector"> {

let extraClassDeclaration = [{
static bool isValidElementType(mlir::Type t);

mlir::Type getElementType() const { return getEleTy(); }
}];

let skipDefaultBuilders = 1;
Expand Down
7 changes: 4 additions & 3 deletions flang/lib/Lower/ConvertConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,8 @@ genInlinedArrayLit(Fortran::lower::AbstractConverter &converter,
} while (con.IncrementSubscripts(subscripts));
} else if constexpr (T::category == Fortran::common::TypeCategory::Derived) {
do {
mlir::Type eleTy = mlir::cast<fir::SequenceType>(arrayTy).getEleTy();
mlir::Type eleTy =
mlir::cast<fir::SequenceType>(arrayTy).getElementType();
mlir::Value elementVal =
genScalarLit(converter, loc, con.At(subscripts), eleTy,
/*outlineInReadOnlyMemory=*/false);
Expand All @@ -594,7 +595,7 @@ genInlinedArrayLit(Fortran::lower::AbstractConverter &converter,
} else {
llvm::SmallVector<mlir::Attribute> rangeStartIdx;
uint64_t rangeSize = 0;
mlir::Type eleTy = mlir::cast<fir::SequenceType>(arrayTy).getEleTy();
mlir::Type eleTy = mlir::cast<fir::SequenceType>(arrayTy).getElementType();
do {
auto getElementVal = [&]() {
return builder.createConvert(loc, eleTy,
Expand Down Expand Up @@ -643,7 +644,7 @@ genOutlineArrayLit(Fortran::lower::AbstractConverter &converter,
mlir::Location loc, mlir::Type arrayTy,
const Fortran::evaluate::Constant<T> &constant) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
mlir::Type eleTy = mlir::cast<fir::SequenceType>(arrayTy).getEleTy();
mlir::Type eleTy = mlir::cast<fir::SequenceType>(arrayTy).getElementType();
llvm::StringRef globalName = converter.getUniqueLitName(
loc, std::make_unique<Fortran::lower::SomeExpr>(toEvExpr(constant)),
eleTy);
Expand Down
8 changes: 4 additions & 4 deletions flang/lib/Lower/ConvertExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1574,7 +1574,7 @@ class ScalarExprLowering {
mlir::Location loc = getLoc();
mlir::Value addr = fir::getBase(array);
mlir::Type arrTy = fir::dyn_cast_ptrEleTy(addr.getType());
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getElementType();
mlir::Type seqTy = builder.getRefType(builder.getVarLenSeqTy(eleTy));
mlir::Type refTy = builder.getRefType(eleTy);
mlir::Value base = builder.createConvert(loc, seqTy, addr);
Expand Down Expand Up @@ -1659,7 +1659,7 @@ class ScalarExprLowering {
mlir::Location loc = getLoc();
mlir::Value addr = fir::getBase(exv);
mlir::Type arrTy = fir::dyn_cast_ptrOrBoxEleTy(addr.getType());
mlir::Type eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
mlir::Type eleTy = mlir::cast<fir::SequenceType>(arrTy).getElementType();
mlir::Type refTy = builder.getRefType(eleTy);
mlir::IndexType idxTy = builder.getIndexType();
llvm::SmallVector<mlir::Value> arrayCoorArgs;
Expand Down Expand Up @@ -4145,7 +4145,7 @@ class ArrayExprLowering {
mlir::Location loc = getLoc();
return [=, builder = &converter.getFirOpBuilder()](IterSpace iters) {
mlir::Type arrTy = fir::dyn_cast_ptrOrBoxEleTy(tmp.getType());
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getElementType();
mlir::Type eleRefTy = builder->getRefType(eleTy);
mlir::IntegerType i1Ty = builder->getI1Type();
// Adjust indices for any shift of the origin of the array.
Expand Down Expand Up @@ -5759,7 +5759,7 @@ class ArrayExprLowering {
return fir::BoxValue(embox, lbounds, nonDeferredLenParams);
};
}
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getElementType();
if (isReferentiallyOpaque()) {
// Semantics are an opaque reference to an array.
// This case forwards a continuation that will generate the address
Expand Down
5 changes: 3 additions & 2 deletions flang/lib/Lower/ConvertExprToHLFIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,8 @@ class HlfirDesignatorBuilder {
return createVectorSubscriptElementAddrOp(partInfo, baseType,
resultExtents);

mlir::Type resultType = mlir::cast<fir::SequenceType>(baseType).getEleTy();
mlir::Type resultType =
mlir::cast<fir::SequenceType>(baseType).getElementType();
if (!resultTypeShape.empty()) {
// Ranked array section. The result shape comes from the array section
// subscripts.
Expand Down Expand Up @@ -811,7 +812,7 @@ class HlfirDesignatorBuilder {
}
}
builder.setInsertionPoint(elementalAddrOp);
return mlir::cast<fir::SequenceType>(baseType).getEleTy();
return mlir::cast<fir::SequenceType>(baseType).getElementType();
}

/// Yield the designator for the final part-ref inside the
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Lower/ConvertVariable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ static fir::GlobalOp defineGlobal(Fortran::lower::AbstractConverter &converter,
// type does not support nested structures.
if (mlir::isa<fir::SequenceType>(symTy) &&
!Fortran::semantics::IsAllocatableOrPointer(sym)) {
mlir::Type eleTy = mlir::cast<fir::SequenceType>(symTy).getEleTy();
mlir::Type eleTy = mlir::cast<fir::SequenceType>(symTy).getElementType();
if (mlir::isa<mlir::IntegerType, mlir::FloatType, mlir::ComplexType,
fir::LogicalType>(eleTy)) {
const auto *details =
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Optimizer/Builder/IntrinsicCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3824,7 +3824,7 @@ IntrinsicLibrary::genReduction(FN func, FD funcDim, llvm::StringRef errMsg,
if (absentDim || rank == 1) {
mlir::Type ty = array.getType();
mlir::Type arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getElementType();
if (fir::isa_complex(eleTy)) {
mlir::Value result = builder.createTemporary(loc, eleTy);
func(builder, loc, array, mask, result);
Expand Down Expand Up @@ -6137,7 +6137,7 @@ IntrinsicLibrary::genReduce(mlir::Type resultType,

mlir::Type ty = array.getType();
mlir::Type arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
mlir::Type eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
mlir::Type eleTy = mlir::cast<fir::SequenceType>(arrTy).getElementType();

// Handle optional arguments
bool absentDim = isStaticallyAbsent(args[2]);
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Optimizer/Builder/PPCIntrinsicCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2797,7 +2797,7 @@ void PPCIntrinsicLibrary::genMmaIntr(llvm::ArrayRef<fir::ExtendedValue> args) {
if (vType != targetType) {
if (mlir::isa<mlir::VectorType>(targetType)) {
// Perform vector type conversion for arguments passed by value.
auto eleTy{mlir::dyn_cast<fir::VectorType>(vType).getEleTy()};
auto eleTy{mlir::dyn_cast<fir::VectorType>(vType).getElementType()};
auto len{mlir::dyn_cast<fir::VectorType>(vType).getLen()};
mlir::VectorType mlirType = mlir::VectorType::get(len, eleTy);
auto v0{builder.createConvert(loc, mlirType, v)};
Expand Down
24 changes: 12 additions & 12 deletions flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1157,7 +1157,7 @@ void fir::runtime::genMaxloc(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value back) {
auto ty = arrayBox.getType();
auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getElementType();
fir::factory::CharacterExprHelper charHelper{builder, loc};
auto [cat, kind] = fir::mlirTypeToCategoryKind(loc, eleTy);
mlir::func::FuncOp func;
Expand Down Expand Up @@ -1189,7 +1189,7 @@ mlir::Value fir::runtime::genMaxval(fir::FirOpBuilder &builder,
mlir::Value maskBox) {
auto ty = arrayBox.getType();
auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getElementType();
auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0);
auto [cat, kind] = fir::mlirTypeToCategoryKind(loc, eleTy);
mlir::func::FuncOp func;
Expand Down Expand Up @@ -1241,7 +1241,7 @@ void fir::runtime::genMinloc(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value back) {
auto ty = arrayBox.getType();
auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getElementType();
auto [cat, kind] = fir::mlirTypeToCategoryKind(loc, eleTy);
mlir::func::FuncOp func;
REAL_INTRINSIC_INSTANCES(Minloc, )
Expand Down Expand Up @@ -1298,7 +1298,7 @@ mlir::Value fir::runtime::genMinval(fir::FirOpBuilder &builder,
mlir::Value maskBox) {
auto ty = arrayBox.getType();
auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getElementType();
auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0);
auto [cat, kind] = fir::mlirTypeToCategoryKind(loc, eleTy);

Expand Down Expand Up @@ -1326,7 +1326,7 @@ void fir::runtime::genNorm2Dim(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::func::FuncOp func;
auto ty = arrayBox.getType();
auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getElementType();
if (eleTy.isF128())
func = fir::runtime::getRuntimeFunc<ForcedNorm2DimReal16>(loc, builder);
else
Expand All @@ -1348,7 +1348,7 @@ mlir::Value fir::runtime::genNorm2(fir::FirOpBuilder &builder,
mlir::func::FuncOp func;
auto ty = arrayBox.getType();
auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getElementType();
auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0);

if (eleTy.isF32())
Expand Down Expand Up @@ -1398,7 +1398,7 @@ mlir::Value fir::runtime::genProduct(fir::FirOpBuilder &builder,
mlir::Value resultBox) {
auto ty = arrayBox.getType();
auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getElementType();
auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0);

auto [cat, kind] = fir::mlirTypeToCategoryKind(loc, eleTy);
Expand Down Expand Up @@ -1482,7 +1482,7 @@ mlir::Value fir::runtime::genSum(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value resultBox) {
auto ty = arrayBox.getType();
auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getElementType();
auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0);

auto [cat, kind] = fir::mlirTypeToCategoryKind(loc, eleTy);
Expand Down Expand Up @@ -1521,7 +1521,7 @@ mlir::Value fir::runtime::genSum(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::func::FuncOp func; \
auto ty = arrayBox.getType(); \
auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty); \
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy(); \
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getElementType(); \
auto dim = builder.createIntegerConstant(loc, builder.getIndexType(), 0); \
\
if (eleTy.isInteger(builder.getKindMap().getIntegerBitsize(1))) \
Expand Down Expand Up @@ -1596,7 +1596,7 @@ void fir::runtime::genReduce(fir::FirOpBuilder &builder, mlir::Location loc,
bool argByRef) {
auto ty = arrayBox.getType();
auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getElementType();
auto dim = builder.createIntegerConstant(loc, builder.getI32Type(), 1);

assert(resultBox && "expect non null value for the result");
Expand Down Expand Up @@ -1646,7 +1646,7 @@ mlir::Value fir::runtime::genReduce(fir::FirOpBuilder &builder,
bool argByRef) {
auto ty = arrayBox.getType();
auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getElementType();
auto dim = builder.createIntegerConstant(loc, builder.getI32Type(), 1);

assert((fir::isa_real(eleTy) || fir::isa_integer(eleTy) ||
Expand Down Expand Up @@ -1687,7 +1687,7 @@ void fir::runtime::genReduceDim(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value resultBox, bool argByRef) {
auto ty = arrayBox.getType();
auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
auto eleTy = mlir::cast<fir::SequenceType>(arrTy).getElementType();
auto [cat, kind] = fir::mlirTypeToCategoryKind(loc, eleTy);

mlir::func::FuncOp func;
Expand Down
8 changes: 4 additions & 4 deletions flang/lib/Optimizer/Builder/Runtime/Transformational.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,11 +365,11 @@ void fir::runtime::genMatmul(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::func::FuncOp func;
auto boxATy = matrixABox.getType();
auto arrATy = fir::dyn_cast_ptrOrBoxEleTy(boxATy);
auto arrAEleTy = mlir::cast<fir::SequenceType>(arrATy).getEleTy();
auto arrAEleTy = mlir::cast<fir::SequenceType>(arrATy).getElementType();
auto [aCat, aKind] = fir::mlirTypeToCategoryKind(loc, arrAEleTy);
auto boxBTy = matrixBBox.getType();
auto arrBTy = fir::dyn_cast_ptrOrBoxEleTy(boxBTy);
auto arrBEleTy = mlir::cast<fir::SequenceType>(arrBTy).getEleTy();
auto arrBEleTy = mlir::cast<fir::SequenceType>(arrBTy).getElementType();
auto [bCat, bKind] = fir::mlirTypeToCategoryKind(loc, arrBEleTy);

#define MATMUL_INSTANCE(ACAT, AKIND, BCAT, BKIND) \
Expand Down Expand Up @@ -417,11 +417,11 @@ void fir::runtime::genMatmulTranspose(fir::FirOpBuilder &builder,
mlir::func::FuncOp func;
auto boxATy = matrixABox.getType();
auto arrATy = fir::dyn_cast_ptrOrBoxEleTy(boxATy);
auto arrAEleTy = mlir::cast<fir::SequenceType>(arrATy).getEleTy();
auto arrAEleTy = mlir::cast<fir::SequenceType>(arrATy).getElementType();
auto [aCat, aKind] = fir::mlirTypeToCategoryKind(loc, arrAEleTy);
auto boxBTy = matrixBBox.getType();
auto arrBTy = fir::dyn_cast_ptrOrBoxEleTy(boxBTy);
auto arrBEleTy = mlir::cast<fir::SequenceType>(arrBTy).getEleTy();
auto arrBEleTy = mlir::cast<fir::SequenceType>(arrBTy).getElementType();
auto [bCat, bKind] = fir::mlirTypeToCategoryKind(loc, arrBEleTy);

#define MATMUL_INSTANCE(ACAT, AKIND, BCAT, BKIND) \
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2619,7 +2619,7 @@ struct CoordinateOpConversion
dims = dimsLeft - 1;
continue;
}
cpnTy = mlir::cast<fir::SequenceType>(cpnTy).getEleTy();
cpnTy = mlir::cast<fir::SequenceType>(cpnTy).getElementType();
// append array range in reverse (FIR arrays are column-major)
offs.append(arrIdx.rbegin(), arrIdx.rend());
arrIdx.clear();
Expand All @@ -2633,7 +2633,7 @@ struct CoordinateOpConversion
arrIdx.push_back(nxtOpnd);
continue;
}
cpnTy = mlir::cast<fir::SequenceType>(cpnTy).getEleTy();
cpnTy = mlir::cast<fir::SequenceType>(cpnTy).getElementType();
offs.push_back(nxtOpnd);
continue;
}
Expand Down
6 changes: 3 additions & 3 deletions flang/lib/Optimizer/Dialect/FIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1359,7 +1359,7 @@ bool fir::ConvertOp::isPointerCompatible(mlir::Type ty) {
static std::optional<mlir::Type> getVectorElementType(mlir::Type ty) {
mlir::Type elemTy;
if (mlir::isa<fir::VectorType>(ty))
elemTy = mlir::dyn_cast<fir::VectorType>(ty).getEleTy();
elemTy = mlir::dyn_cast<fir::VectorType>(ty).getElementType();
else if (mlir::isa<mlir::VectorType>(ty))
elemTy = mlir::dyn_cast<mlir::VectorType>(ty).getElementType();
else
Expand Down Expand Up @@ -1533,7 +1533,7 @@ llvm::LogicalResult fir::CoordinateOp::verify() {
}
if (dimension) {
if (--dimension == 0)
eleTy = mlir::cast<fir::SequenceType>(eleTy).getEleTy();
eleTy = mlir::cast<fir::SequenceType>(eleTy).getElementType();
} else {
if (auto t = mlir::dyn_cast<mlir::TupleType>(eleTy)) {
// FIXME: Generally, we don't know which field of the tuple is being
Expand Down Expand Up @@ -3817,7 +3817,7 @@ void fir::StoreOp::build(mlir::OpBuilder &builder, mlir::OperationState &result,
//===----------------------------------------------------------------------===//

inline fir::CharacterType::KindTy stringLitOpGetKind(fir::StringLitOp op) {
auto eleTy = mlir::cast<fir::SequenceType>(op.getType()).getEleTy();
auto eleTy = mlir::cast<fir::SequenceType>(op.getType()).getElementType();
return mlir::cast<fir::CharacterType>(eleTy).getFKind();
}

Expand Down
Loading