From f6a76e8df57a71b9ebac92b08f6591d9c656b1f7 Mon Sep 17 00:00:00 2001 From: Lukas Korencik Date: Tue, 14 Nov 2023 16:18:58 +0100 Subject: [PATCH] hl:lowertypes: Prune old code, use newer helpers that do type conversion. --- .../HighLevel/Transforms/HLLowerTypes.cpp | 157 +++--------------- 1 file changed, 20 insertions(+), 137 deletions(-) diff --git a/lib/vast/Dialect/HighLevel/Transforms/HLLowerTypes.cpp b/lib/vast/Dialect/HighLevel/Transforms/HLLowerTypes.cpp index b99345fa57..2bb4187136 100644 --- a/lib/vast/Dialect/HighLevel/Transforms/HLLowerTypes.cpp +++ b/lib/vast/Dialect/HighLevel/Transforms/HLLowerTypes.cpp @@ -27,7 +27,7 @@ VAST_UNRELAX_WARNINGS #include "vast/Conversion/TypeConverters/DataLayout.hpp" #include "vast/Conversion/TypeConverters/HLToStd.hpp" -#include "vast/Conversion/TypeConverters/TypeConverter.hpp" +#include "vast/Conversion/TypeConverters/TypeConvertingPattern.hpp" #include #include @@ -36,135 +36,26 @@ namespace vast::hl { using type_converter_t = conv::tc::HLToStd; - struct LowerHighLevelOpType : mlir::ConversionPattern - { - using Base = mlir::ConversionPattern; - using Base::Base; - - LowerHighLevelOpType(type_converter_t &tc, mcontext_t *mctx) - : Base(tc, mlir::Pattern::MatchAnyOpTypeTag{}, 1, mctx) - {} - - template< typename attrs_list > - maybe_attr_t high_level_typed_attr_conversion(mlir::Attribute attr) const { - using attr_t = typename attrs_list::head; - using rest_t = typename attrs_list::tail; - - if (auto typed = mlir::dyn_cast< attr_t >(attr)) { - if constexpr (std::same_as< attr_t, core::VoidAttr>) { - return Maybe(typed.getType()) - .and_then([&] (auto type) { - return getTypeConverter()->convertType(type); - }) - .and_then([&] (auto type) { - return core::VoidAttr::get(type.getContext(), type); - }) - .template take_wrapped< maybe_attr_t >(); - } else { - return Maybe(typed.getType()) - .and_then([&] (auto type) { - return getTypeConverter()->convertType(type); - }) - .and_then([&] (auto type) { - return attr_t::get(type, typed.getValue()); - }) - .template take_wrapped< maybe_attr_t >(); - } - } - - if constexpr (attrs_list::size != 1) { - return high_level_typed_attr_conversion< rest_t >(attr); - } else { - return std::nullopt; - } - } - - auto convert_high_level_typed_attr() const { - return [&] (mlir::Attribute attr) { - return high_level_typed_attr_conversion< core::typed_attrs >(attr); - }; - } + namespace pattern { - logical_result matchAndRewrite( - operation op, llvm::ArrayRef< mlir_value > ops, - conversion_rewriter &rewriter - ) const override { - if (mlir::isa< FuncOp >(op)) { - return mlir::failure(); - } + struct lower_type : conv::tc::type_converting_pattern< type_converter_t > + { + using parent = conv::tc::type_converting_pattern< type_converter_t >; - auto &tc = static_cast< type_converter_t & >(*getTypeConverter()); - - mlir::SmallVector< mlir_type > rty; - auto status = tc.convertTypes(op->getResultTypes(), rty); - // TODO(lukas): How to use `llvm::formatv` with `operation `? - VAST_CHECK(mlir::succeeded(status), "Was not able to type convert."); - - // We just change type, no need to copy everything - auto lower_op = [&]() { - for (std::size_t i = 0; i < rty.size(); ++i) { - op->getResult(i).setType(rty[i]); - } - - mlir::AttrTypeReplacer replacer; - replacer.addReplacement(conv::tc::convert_type_attr(tc)); - replacer.addReplacement(conv::tc::convert_data_layout_attrs(tc)); - replacer.addReplacement(convert_high_level_typed_attr()); - replacer.recursivelyReplaceElementsIn(op, true /* replace attrs */); - }; - // It has to be done in one "transaction". - rewriter.updateRootInPlace(op, lower_op); - - return mlir::success(); - } - }; + lower_type(type_converter_t &tc, mcontext_t *mctx) : parent(tc, *mctx) {} - struct LowerFuncOpType : mlir::OpConversionPattern< FuncOp > - { - using Base = mlir::OpConversionPattern< FuncOp >; - using Base::Base; - - using Base::getTypeConverter; - - // As the reference how to lower functions, the `StandardToLLVM` - // conversion is used. - // - // But basically we need to copy the function with the converted - // function type -> copy body -> fix arguments of the entry region. - logical_result matchAndRewrite( - FuncOp fn, OpAdaptor adaptor, conversion_rewriter &rewriter - ) const override { - auto fty = adaptor.getFunctionType(); - auto &tc = static_cast< type_converter_t & >(*getTypeConverter()); - - conv::tc::signature_conversion_t sigconvert(fty.getNumInputs()); - if (mlir::failed(tc.convertSignatureArgs(fty.getInputs(), sigconvert))) { - return mlir::failure(); + logical_result matchAndRewrite( + operation op, mlir::ArrayRef< mlir::Value > ops, + conversion_rewriter &rewriter + ) const override { + if (auto func_op = mlir::dyn_cast< hl::FuncOp >(op)) + return replace(func_op, ops, rewriter); + return replace(op, ops, rewriter); } - llvm::SmallVector< mlir_type, 1 > results; - if (mlir::failed(tc.convertTypes(fty.getResults(), results))) { - return mlir::failure(); - } + }; - auto params = sigconvert.getConvertedTypes(); - - auto new_type = core::FunctionType::get( - rewriter.getContext(), params, results, fty.isVarArg() - ); - - // TODO deal with function attribute types - - rewriter.updateRootInPlace(fn, [&] { - fn.setType(new_type); - for (auto [ty, param] : llvm::zip(params, fn.getBody().getArguments())) { - param.setType(ty); - } - }); - - return mlir::success(); - } - }; + } // namespace pattern struct HLLowerTypesPass : HLLowerTypesBase< HLLowerTypesPass > { @@ -172,24 +63,16 @@ namespace vast::hl auto op = this->getOperation(); auto &mctx = this->getContext(); + const auto &dl_analysis = this->getAnalysis< mlir::DataLayoutAnalysis >(); + type_converter_t type_converter(dl_analysis.getAtOrAbove(op), mctx); + mlir::ConversionTarget trg(mctx); - // We want to check *everything* for presence of hl type - // that can be lowered. - auto is_legal = [](operation op) - { - auto is_hl = [](mlir_type t) -> bool { return isHighLevelType(t); }; - - return !has_type_somewhere(op, is_hl); - }; + auto is_legal = type_converter.get_is_type_conversion_legal(); trg.markUnknownOpDynamicallyLegal(is_legal); mlir::RewritePatternSet patterns(&mctx); - const auto &dl_analysis = this->getAnalysis< mlir::DataLayoutAnalysis >(); - type_converter_t type_converter(dl_analysis.getAtOrAbove(op), mctx); - patterns.add< LowerHighLevelOpType, LowerFuncOpType >( - type_converter, patterns.getContext() - ); + patterns.add< pattern::lower_type >(type_converter, patterns.getContext()); if (mlir::failed(mlir::applyPartialConversion(op, trg, std::move(patterns)))) { return signalPassFailure();