diff --git a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h index 9cb43689d1ce64..9ef053f1e66361 100644 --- a/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h +++ b/mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h @@ -13,8 +13,8 @@ namespace mlir { class RewritePatternSet; class TypeConverter; -void populateArithToEmitCPatterns(TypeConverter &typeConverter, - RewritePatternSet &patterns); +void populateArithToEmitCPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter); } // namespace mlir #endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H diff --git a/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitC.h b/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitC.h index 5c7f87e470306a..ac6fe2da7d42ff 100644 --- a/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitC.h +++ b/mlir/include/mlir/Conversion/FuncToEmitC/FuncToEmitC.h @@ -9,10 +9,12 @@ #ifndef MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITC_H #define MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITC_H +#include "mlir/Transforms/DialectConversion.h" namespace mlir { class RewritePatternSet; -void populateFuncToEmitCPatterns(RewritePatternSet &patterns); +void populateFuncToEmitCPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter); } // namespace mlir #endif // MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITC_H diff --git a/mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h b/mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h index 22df7f1c5dcf29..acc39e6acf726f 100644 --- a/mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h +++ b/mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h @@ -9,6 +9,7 @@ #ifndef MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H #define MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H +#include "mlir/Transforms/DialectConversion.h" #include namespace mlir { @@ -19,7 +20,8 @@ class RewritePatternSet; #include "mlir/Conversion/Passes.h.inc" /// Collect a set of patterns to convert SCF operations to the EmitC dialect. -void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns); +void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter); } // namespace mlir #endif // MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td index 0f080ac4433273..f6a8bd4ef59e10 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td @@ -76,6 +76,7 @@ def EmitC_ArrayType : EmitC_Type<"Array", "array", [ShapedTypeInterface]> { static bool isValidElementType(Type type) { return type.isIntOrIndexOrFloat() || + emitc::isAnySizeTType(type) || llvm::isa(type); } }]; diff --git a/mlir/include/mlir/Dialect/EmitC/Transforms/TypeConversions.h b/mlir/include/mlir/Dialect/EmitC/Transforms/TypeConversions.h index da16b336b8bc37..1c05a927e948ae 100644 --- a/mlir/include/mlir/Dialect/EmitC/Transforms/TypeConversions.h +++ b/mlir/include/mlir/Dialect/EmitC/Transforms/TypeConversions.h @@ -6,8 +6,12 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Transforms/DialectConversion.h" +#ifndef MLIR_DIALECT_EMITC_TRANSFORMS_TYPECONVERSIONS_H +#define MLIR_DIALECT_EMITC_TRANSFORMS_TYPECONVERSIONS_H namespace mlir { -void populateEmitCSizeTypeConversionPatterns(mlir::TypeConverter &converter); +class TypeConverter; +void populateEmitCSizeTypeConversions(TypeConverter &converter); } // namespace mlir + +#endif // MLIR_DIALECT_EMITC_TRANSFORMS_TYPECONVERSIONS_H diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index 62067c5e256449..04344b2c2ab3c1 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -563,11 +563,11 @@ class ItoFCastOpConversion : public OpConversionPattern { // Pattern population //===----------------------------------------------------------------------===// -void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter, - RewritePatternSet &patterns) { +void mlir::populateArithToEmitCPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter) { MLIRContext *ctx = patterns.getContext(); - mlir::populateEmitCSizeTypeConversionPatterns(typeConverter); + mlir::populateEmitCSizeTypeConversions(typeConverter); // clang-format off patterns.add< diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp index 76e7707ce7109e..7713506ef396ba 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp @@ -16,6 +16,7 @@ #include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -43,9 +44,11 @@ void ConvertArithToEmitC::runOnOperation() { RewritePatternSet patterns(&getContext()); TypeConverter typeConverter; - typeConverter.addConversion([](Type type) { return type; }); - - populateArithToEmitCPatterns(typeConverter, patterns); + // Fallback converter + // See note https://mlir.llvm.org/docs/DialectConversion/#type-converter + // Type converters are called most to least recently inserted + typeConverter.addConversion([](Type t) { return t; }); + populateArithToEmitCPatterns(patterns, typeConverter); if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp index 6a8ecb7b00473a..29c64487a2bc09 100644 --- a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp +++ b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitC.cpp @@ -36,10 +36,17 @@ class CallOpConversion final : public OpConversionPattern { return rewriter.notifyMatchFailure( callOp, "only functions with zero or one result can be converted"); + // Convert the original function results. + Type resultTy = nullptr; + if (callOp.getNumResults()) { + resultTy = typeConverter->convertType(callOp.getResult(0).getType()); + if (!resultTy) + return rewriter.notifyMatchFailure( + callOp, "function return type conversion failed"); + } + rewriter.replaceOpWithNewOp( - callOp, - callOp.getNumResults() ? callOp.getResult(0).getType() : nullptr, - adaptor.getOperands(), callOp->getAttrs()); + callOp, resultTy, adaptor.getOperands(), callOp->getAttrs()); return success(); } @@ -53,13 +60,34 @@ class FuncOpConversion final : public OpConversionPattern { matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (funcOp.getFunctionType().getNumResults() > 1) + FunctionType type = funcOp.getFunctionType(); + if (!type) + return failure(); + + if (type.getNumResults() > 1) return rewriter.notifyMatchFailure( funcOp, "only functions with zero or one result can be converted"); + const TypeConverter *converter = getTypeConverter(); + + // Convert function signature + TypeConverter::SignatureConversion signatureConversion(type.getNumInputs()); + SmallVector convertedResults; + if (failed(converter->convertSignatureArgs(type.getInputs(), + signatureConversion)) || + failed(converter->convertTypes(type.getResults(), convertedResults)) || + failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(), + *converter, &signatureConversion))) + return rewriter.notifyMatchFailure(funcOp, "signature conversion failed"); + + // Convert the function type + auto convertedFunctionType = FunctionType::get( + rewriter.getContext(), signatureConversion.getConvertedTypes(), + convertedResults); + // Create the converted `emitc.func` op. emitc::FuncOp newFuncOp = rewriter.create( - funcOp.getLoc(), funcOp.getName(), funcOp.getFunctionType()); + funcOp.getLoc(), funcOp.getName(), convertedFunctionType); // Copy over all attributes other than the function name and type. for (const auto &namedAttr : funcOp->getAttrs()) { @@ -113,8 +141,10 @@ class ReturnOpConversion final : public OpConversionPattern { // Pattern population //===----------------------------------------------------------------------===// -void mlir::populateFuncToEmitCPatterns(RewritePatternSet &patterns) { +void mlir::populateFuncToEmitCPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter) { MLIRContext *ctx = patterns.getContext(); - patterns.add(ctx); + patterns.add( + typeConverter, ctx); } diff --git a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitCPass.cpp b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitCPass.cpp index 0b97f2641ad08d..65aa5bde96e5ad 100644 --- a/mlir/lib/Conversion/FuncToEmitC/FuncToEmitCPass.cpp +++ b/mlir/lib/Conversion/FuncToEmitC/FuncToEmitCPass.cpp @@ -14,6 +14,7 @@ #include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -33,13 +34,20 @@ struct ConvertFuncToEmitC } // namespace void ConvertFuncToEmitC::runOnOperation() { + TypeConverter typeConverter; + // Fallback converter + // See note https://mlir.llvm.org/docs/DialectConversion/#type-converter + // Type converters are called most to least recently inserted + typeConverter.addConversion([](Type t) { return t; }); + populateEmitCSizeTypeConversions(typeConverter); + ConversionTarget target(getContext()); target.addLegalDialect(); target.addIllegalOp(); RewritePatternSet patterns(&getContext()); - populateFuncToEmitCPatterns(patterns); + populateFuncToEmitCPatterns(patterns, typeConverter); if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp index 4e5d1912d15729..63ae47b32a5900 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp @@ -14,6 +14,7 @@ #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -33,12 +34,13 @@ struct ConvertMemRefToEmitCPass // Fallback for other types. converter.addConversion([](Type type) -> std::optional { - if (isa(type)) - return {}; - return type; + if (emitc::isSupportedEmitCType(type)) + return type; + return {}; }); populateMemRefToEmitCTypeConversion(converter); + populateEmitCSizeTypeConversions(converter); RewritePatternSet patterns(&getContext()); populateMemRefToEmitCConversionPatterns(patterns, converter); diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp index 367142a5207427..55b46b346949c7 100644 --- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp +++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" @@ -21,6 +22,7 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/OneToNTypeConversion.h" #include "mlir/Transforms/Passes.h" namespace mlir { @@ -39,21 +41,22 @@ struct SCFToEmitCPass : public impl::SCFToEmitCBase { // Lower scf::for to emitc::for, implementing result values using // emitc::variable's updated within the loop body. -struct ForLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct ForLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(ForOp forOp, - PatternRewriter &rewriter) const override; + LogicalResult + matchAndRewrite(ForOp forOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; }; // Create an uninitialized emitc::variable op for each result of the given op. template -static SmallVector createVariablesForResults(T op, - PatternRewriter &rewriter) { - SmallVector resultVariables; - +static LogicalResult +createVariablesForResults(T op, const TypeConverter *typeConverter, + ConversionPatternRewriter &rewriter, + SmallVector &resultVariables) { if (!op.getNumResults()) - return resultVariables; + return success(); Location loc = op->getLoc(); MLIRContext *context = op.getContext(); @@ -62,26 +65,29 @@ static SmallVector createVariablesForResults(T op, rewriter.setInsertionPoint(op); for (OpResult result : op.getResults()) { - Type resultType = result.getType(); + Type resultType = typeConverter->convertType(result.getType()); + if (!resultType) + return rewriter.notifyMatchFailure(op, "result type conversion failed"); emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, ""); emitc::VariableOp var = rewriter.create(loc, resultType, noInit); resultVariables.push_back(var); } - return resultVariables; + return success(); } // Create a series of assign ops assigning given values to given variables at // the current insertion point of given rewriter. static void assignValues(ValueRange values, SmallVector &variables, - PatternRewriter &rewriter, Location loc) { + ConversionPatternRewriter &rewriter, Location loc) { for (auto [value, var] : llvm::zip(values, variables)) rewriter.create(loc, var, value); } static void lowerYield(SmallVector &resultVariables, - PatternRewriter &rewriter, scf::YieldOp yield) { + ConversionPatternRewriter &rewriter, + scf::YieldOp yield) { Location loc = yield.getLoc(); ValueRange operands = yield.getOperands(); @@ -94,21 +100,28 @@ static void lowerYield(SmallVector &resultVariables, rewriter.eraseOp(yield); } -LogicalResult ForLowering::matchAndRewrite(ForOp forOp, - PatternRewriter &rewriter) const { +LogicalResult +ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { Location loc = forOp.getLoc(); // Create an emitc::variable op for each result. These variables will be // assigned to by emitc::assign ops within the loop body. - SmallVector resultVariables = - createVariablesForResults(forOp, rewriter); - SmallVector iterArgsVariables = - createVariablesForResults(forOp, rewriter); + SmallVector resultVariables; + if (failed(createVariablesForResults(forOp, getTypeConverter(), rewriter, + resultVariables))) + return rewriter.notifyMatchFailure(forOp, + "create variables for results failed"); + SmallVector iterArgsVariables; + if (failed(createVariablesForResults(forOp, getTypeConverter(), rewriter, + iterArgsVariables))) + return rewriter.notifyMatchFailure(forOp, + "create variables for iter args failed"); assignValues(forOp.getInits(), iterArgsVariables, rewriter, loc); emitc::ForOp loweredFor = rewriter.create( - loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()); + loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep()); Block *loweredBody = loweredFor.getBody(); @@ -119,7 +132,8 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp, replacingValues.push_back(loweredFor.getInductionVar()); replacingValues.append(iterArgsVariables.begin(), iterArgsVariables.end()); - rewriter.mergeBlocks(forOp.getBody(), loweredBody, replacingValues); + Block *adaptorBody = &(adaptor.getRegion().front()); + rewriter.mergeBlocks(adaptorBody, loweredBody, replacingValues); lowerYield(iterArgsVariables, rewriter, cast(loweredBody->getTerminator())); @@ -132,23 +146,28 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp, // Lower scf::if to emitc::if, implementing result values as emitc::variable's // updated within the then and else regions. -struct IfLowering : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +struct IfLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(IfOp ifOp, - PatternRewriter &rewriter) const override; + LogicalResult + matchAndRewrite(IfOp ifOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; }; } // namespace -LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, - PatternRewriter &rewriter) const { +LogicalResult +IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { Location loc = ifOp.getLoc(); // Create an emitc::variable op for each result. These variables will be // assigned to by emitc::assign ops within the then & else regions. - SmallVector resultVariables = - createVariablesForResults(ifOp, rewriter); + SmallVector resultVariables; + if (failed(createVariablesForResults(ifOp, getTypeConverter(), rewriter, + resultVariables))) + return rewriter.notifyMatchFailure(ifOp, + "create variables for results failed"); // Utility function to lower the contents of an scf::if region to an emitc::if // region. The contents of the scf::if regions is moved into the respective @@ -162,13 +181,13 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, lowerYield(resultVariables, rewriter, cast(terminator)); }; - Region &thenRegion = ifOp.getThenRegion(); - Region &elseRegion = ifOp.getElseRegion(); + Region &thenRegion = adaptor.getThenRegion(); + Region &elseRegion = adaptor.getElseRegion(); bool hasElseBlock = !elseRegion.empty(); auto loweredIf = - rewriter.create(loc, ifOp.getCondition(), false, false); + rewriter.create(loc, adaptor.getCondition(), false, false); Region &loweredThenRegion = loweredIf.getThenRegion(); lowerRegion(thenRegion, loweredThenRegion); @@ -182,14 +201,21 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, return success(); } -void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); - patterns.add(patterns.getContext()); +void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add(typeConverter, patterns.getContext()); + patterns.add(typeConverter, patterns.getContext()); } void SCFToEmitCPass::runOnOperation() { RewritePatternSet patterns(&getContext()); - populateSCFToEmitCConversionPatterns(patterns); + TypeConverter typeConverter; + // Fallback converter + // See note https://mlir.llvm.org/docs/DialectConversion/#type-converter + // Type converters are called most to least recently inserted + typeConverter.addConversion([](Type t) { return t; }); + populateEmitCSizeTypeConversions(typeConverter); + populateSCFToEmitCConversionPatterns(patterns, typeConverter); // Configure conversion to lower out SCF operations. ConversionTarget target(getContext()); diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 71d354885a900f..3411a37dbae902 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -431,7 +431,8 @@ void ForOp::print(OpAsmPrinter &p) { << getUpperBound() << " step " << getStep(); p << ' '; - if (Type t = getInductionVar().getType(); !t.isIndex()) + if (Type t = getInductionVar().getType(); + !(t.isIndex() || emitc::isAnySizeTType(t))) p << " : " << t << ' '; p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, diff --git a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp index d896f95b0ab8f7..02987df56f0422 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp +++ b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp @@ -9,10 +9,31 @@ #include "mlir/Dialect/EmitC/Transforms/TypeConversions.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/Transforms/DialectConversion.h" +#include using namespace mlir; -void mlir::populateEmitCSizeTypeConversionPatterns(TypeConverter &converter) { +namespace { + +std::optional materializeAsUnrealizedCast(OpBuilder &builder, + Type resultType, + ValueRange inputs, + Location loc) { + if (inputs.size() != 1) + return std::nullopt; + + return builder.create(loc, resultType, inputs) + .getResult(0); +} + +} // namespace + +void mlir::populateEmitCSizeTypeConversions(TypeConverter &converter) { converter.addConversion( [](IndexType type) { return emitc::SizeTType::get(type.getContext()); }); + + converter.addSourceMaterialization(materializeAsUnrealizedCast); + converter.addTargetMaterialization(materializeAsUnrealizedCast); + converter.addArgumentMaterialization(materializeAsUnrealizedCast); } diff --git a/mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir b/mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir index 5c96cf1ce0d34c..854d8f3604f442 100644 --- a/mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir +++ b/mlir/test/Conversion/FuncToEmitC/func-to-emitc.mlir @@ -58,3 +58,33 @@ func.func @call(%arg0: i32) -> i32 { // CHECK-LABEL: emitc.func private @return_i32(i32) -> i32 attributes {specifiers = ["extern"]} func.func private @return_i32(%arg0: i32) -> i32 + +// ----- + +// CHECK-LABEL: emitc.func @use_index +// CHECK-SAME: (%[[Arg0:.*]]: !emitc.size_t) -> !emitc.size_t +// CHECK: emitc.return %[[Arg0]] : !emitc.size_t +func.func @use_index(%arg0: index) -> index { + return %arg0 : index +} + +// ----- + +// CHECK-LABEL: emitc.func private @prototype_index(!emitc.size_t) -> !emitc.size_t attributes {specifiers = ["extern"]} +func.func private @prototype_index(%arg0: index) -> index + +// CHECK-LABEL: emitc.func @call(%arg0: !emitc.size_t) -> !emitc.size_t +// CHECK-NEXT: %0 = emitc.call @prototype_index(%arg0) : (!emitc.size_t) -> !emitc.size_t +// CHECK-NEXT: emitc.return %0 : !emitc.size_t +func.func @call(%arg0: index) -> index { + %0 = call @prototype_index(%arg0) : (index) -> (index) + return %0 : index +} + +// ----- + +// CHECK-LABEL: emitc.func @index_args_only(%arg0: !emitc.size_t) -> f32 +func.func @index_args_only(%i: index) -> f32 { + %0 = arith.constant 0.0 : f32 + return %0 : f32 +} diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir index bc40ef48268eb0..ffb0e10d80893a 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir @@ -1,12 +1,15 @@ // RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s // CHECK-LABEL: memref_store -// CHECK-SAME: %[[v:.*]]: f32, %[[i:.*]]: index, %[[j:.*]]: index +// CHECK-SAME: %[[v:.*]]: f32, %[[argi:.*]]: index, %[[argj:.*]]: index func.func @memref_store(%v : f32, %i: index, %j: index) { + // CHECK: %[[i:.*]] = builtin.unrealized_conversion_cast %[[argi]] : index to !emitc.size_t + // CHECK: %[[j:.*]] = builtin.unrealized_conversion_cast %[[argj]] : index to !emitc.size_t + // CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32> %0 = memref.alloca() : memref<4x8xf32> - // CHECK: %[[SUBSCRIPT:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : (!emitc.array<4x8xf32>, index, index) -> f32 + // CHECK: %[[SUBSCRIPT:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : (!emitc.array<4x8xf32>, !emitc.size_t, !emitc.size_t) -> f32 // CHECK: emitc.assign %[[v]] : f32 to %[[SUBSCRIPT:.*]] : f32 memref.store %v, %0[%i, %j] : memref<4x8xf32> return @@ -15,12 +18,15 @@ func.func @memref_store(%v : f32, %i: index, %j: index) { // ----- // CHECK-LABEL: memref_load -// CHECK-SAME: %[[i:.*]]: index, %[[j:.*]]: index +// CHECK-SAME: %[[argi:.*]]: index, %[[argj:.*]]: index func.func @memref_load(%i: index, %j: index) -> f32 { + // CHECK: %[[i:.*]] = builtin.unrealized_conversion_cast %[[argi]] : index to !emitc.size_t + // CHECK: %[[j:.*]] = builtin.unrealized_conversion_cast %[[argj]] : index to !emitc.size_t + // CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32> %0 = memref.alloca() : memref<4x8xf32> - // CHECK: %[[LOAD:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : (!emitc.array<4x8xf32>, index, index) -> f32 + // CHECK: %[[LOAD:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : (!emitc.array<4x8xf32>, !emitc.size_t, !emitc.size_t) -> f32 // CHECK: %[[VAR:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32 // CHECK: emitc.assign %[[LOAD]] : f32 to %[[VAR]] : f32 %1 = memref.load %0[%i, %j] : memref<4x8xf32> @@ -45,3 +51,25 @@ module @globals { return } } + +// ----- + +// CHECK-LABEL: memref_index_values +// CHECK-SAME: %[[argi:.*]]: index, %[[argj:.*]]: index +// CHECK-SAME: -> index +func.func @memref_index_values(%i: index, %j: index) -> index { + // CHECK: %[[i:.*]] = builtin.unrealized_conversion_cast %[[argi]] : index to !emitc.size_t + // CHECK: %[[j:.*]] = builtin.unrealized_conversion_cast %[[argj]] : index to !emitc.size_t + + // CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8x!emitc.size_t> + %0 = memref.alloca() : memref<4x8xindex> + + // CHECK: %[[LOAD:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : (!emitc.array<4x8x!emitc.size_t>, !emitc.size_t, !emitc.size_t) -> !emitc.size_t + // CHECK: %[[VAR:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.size_t + // CHECK: emitc.assign %[[LOAD]] : !emitc.size_t to %[[VAR]] : !emitc.size_t + %1 = memref.load %0[%i, %j] : memref<4x8xindex> + + // CHECK: %[[CAST_RET:.*]] = builtin.unrealized_conversion_cast %[[VAR]] : !emitc.size_t to index + // CHECK: return %[[CAST_RET]] : index + return %1 : index +} diff --git a/mlir/test/Conversion/SCFToEmitC/for.mlir b/mlir/test/Conversion/SCFToEmitC/for.mlir index 7f90310af21894..ca8a6fb1f59056 100644 --- a/mlir/test/Conversion/SCFToEmitC/for.mlir +++ b/mlir/test/Conversion/SCFToEmitC/for.mlir @@ -7,7 +7,10 @@ func.func @simple_std_for_loop(%arg0 : index, %arg1 : index, %arg2 : index) { return } // CHECK-LABEL: func.func @simple_std_for_loop( -// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) { +// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) { +// CHECK-NEXT: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t +// CHECK-NEXT: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t +// CHECK-NEXT: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t // CHECK-NEXT: emitc.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] { // CHECK-NEXT: %[[VAL_4:.*]] = arith.constant 1 : index // CHECK-NEXT: } @@ -24,7 +27,10 @@ func.func @simple_std_2_for_loops(%arg0 : index, %arg1 : index, %arg2 : index) { return } // CHECK-LABEL: func.func @simple_std_2_for_loops( -// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) { +// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) { +// CHECK-NEXT: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t +// CHECK-NEXT: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t +// CHECK-NEXT: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t // CHECK-NEXT: emitc.for %[[VAL_3:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] { // CHECK-NEXT: %[[VAL_4:.*]] = arith.constant 1 : index // CHECK-NEXT: emitc.for %[[VAL_5:.*]] = %[[VAL_0]] to %[[VAL_1]] step %[[VAL_2]] { @@ -44,7 +50,10 @@ func.func @for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> (f32, f32) return %result#0, %result#1 : f32, f32 } // CHECK-LABEL: func.func @for_yield( -// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) -> (f32, f32) { +// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -> (f32, f32) { +// CHECK-NEXT: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t +// CHECK-NEXT: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t +// CHECK-NEXT: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t // CHECK-NEXT: %[[VAL_3:.*]] = arith.constant 0.000000e+00 : f32 // CHECK-NEXT: %[[VAL_4:.*]] = arith.constant 1.000000e+00 : f32 // CHECK-NEXT: %[[VAL_5:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32 @@ -75,7 +84,10 @@ func.func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32 return %r : f32 } // CHECK-LABEL: func.func @nested_for_yield( -// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index) -> f32 { +// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -> f32 { +// CHECK-NEXT: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t +// CHECK-NEXT: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t +// CHECK-NEXT: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t // CHECK-NEXT: %[[VAL_3:.*]] = arith.constant 1.000000e+00 : f32 // CHECK-NEXT: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32 // CHECK-NEXT: %[[VAL_5:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32 diff --git a/mlir/test/Conversion/SCFToEmitC/nest-for-if.mlir b/mlir/test/Conversion/SCFToEmitC/nest-for-if.mlir new file mode 100644 index 00000000000000..572ec657483abc --- /dev/null +++ b/mlir/test/Conversion/SCFToEmitC/nest-for-if.mlir @@ -0,0 +1,33 @@ +// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-emitc %s | FileCheck %s + +// CHECK-LABEL: func.func @nest_for_in_if +// CHECK-SAME: %[[ARG_0:.*]]: i1, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index, %[[ARG_3:.*]]: index, %[[ARG_4:.*]]: f32 +// CHECK-NEXT: %[[CAST_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t +// CHECK-NEXT: %[[CAST_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t +// CHECK-NEXT: %[[CAST_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_3]] : index to !emitc.size_t +// CHECK-NEXT: emitc.if %[[ARG_0]] { +// CHECK-NEXT: emitc.for %[[ARG_5:.*]] = %[[CAST_0]] to %[[CAST_1]] step %[[CAST_2]] { +// CHECK-NEXT: %[[CST_1:.*]] = arith.constant 1 : index +// CHECK-NEXT: %[[CAST_3:.*]] = builtin.unrealized_conversion_cast %[[CST_1]] : index to !emitc.size_t +// CHECK-NEXT: emitc.for %[[ARG_6:.*]] = %[[CAST_0]] to %[[CAST_1]] step %[[CAST_3]] { +// CHECK-NEXT: %[[CST_2:.*]] = arith.constant 1 : index +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } else { +// CHECK-NEXT: %3 = emitc.call_opaque "func_false"(%[[ARG_4]]) : (f32) -> i32 +// CHECK-NEXT: } +// CHECK-NEXT: return +// CHECK-NEXT: } +func.func @nest_for_in_if(%arg0: i1, %arg1: index, %arg2: index, %arg3: index, %arg4: f32) { + scf.if %arg0 { + scf.for %i0 = %arg1 to %arg2 step %arg3 { + %c1 = arith.constant 1 : index + scf.for %i1 = %arg1 to %arg2 step %c1 { + %c1_0 = arith.constant 1 : index + } + } + } else { + %0 = emitc.call_opaque "func_false"(%arg4) : (f32) -> i32 + } + return +}