Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyunqu committed Jun 28, 2024
1 parent 329efd7 commit 12912a9
Showing 1 changed file with 8 additions and 13 deletions.
21 changes: 8 additions & 13 deletions compiler/lib/Conversion/ToLinalg/HloToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1297,7 +1297,7 @@ inline Value mapFPToSIConvertOpToStdScalarOp(Location loc,
dyn_cast<FloatType>(convertedSourceType).getFloatSemantics())));
Value isInf = b->create<mlir::arith::CmpFOp>(loc, arith::CmpFPredicate::OEQ,
args.front(), infValue);
Value isNan = b->create<mlir::arith::CmpFOp>(loc, arith::CmpFPredicate::UNO,
Value isNan = b->create<mlir::arith::CmpFOp>(loc, arith::CmpFPredicate::UNE,
args.front(), args.front());
Value maxIntval = b->create<arith::ConstantOp>(
loc,
Expand All @@ -1324,25 +1324,20 @@ class FPToSIConvertOpConverter : public OpConversionPattern<mhlo::ConvertOp> {
matchAndRewrite(mhlo::ConvertOp op, typename mhlo::ConvertOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
auto loc = op.getLoc();
RankedTensorType type = dyn_cast<RankedTensorType>(op.getType());
if (!type || !type.hasStaticShape()) {
RankedTensorType inputType =
dyn_cast<RankedTensorType>(op.getOperand().getType());
RankedTensorType outType = dyn_cast<RankedTensorType>(op.getType());
if (!inputType || !outType) {
return failure();
}
// Apply only if convert type is FPToInt32
if (!mlir::arith::FPToSIOp::areCastCompatible(op.getOperand().getType(),
op.getType())) {
return failure();
}
auto targetType = op.getType().getElementType();
if (isa<IntegerType>(targetType) &&
(cast<IntegerType>(targetType).getWidth() != 32 ||
cast<IntegerType>(targetType).isUnsigned())) {
if (!inputType.getElementType().isF32() ||
!outType.getElementType().isSignlessInteger(32)) {
return failure();
}
// Find input/output values and types.
std::optional<ShapedType> resultTy =
this->typeConverter->convertType(op->getResultTypes().front())
.template dyn_cast<ShapedType>();
dyn_cast<ShapedType>(this->typeConverter->convertType(op.getType()));
Value emptyTensor =
getEmptyTensorFor(rewriter, loc, *resultTy, op, adaptor.getOperands());
// Mapped inputs are cast to the same shape as the init tensor.
Expand Down

0 comments on commit 12912a9

Please sign in to comment.