diff --git a/compiler/src/iree/compiler/Tools/BUILD.bazel b/compiler/src/iree/compiler/Tools/BUILD.bazel index 95c620ca1e1b..e57ef6b004b5 100644 --- a/compiler/src/iree/compiler/Tools/BUILD.bazel +++ b/compiler/src/iree/compiler/Tools/BUILD.bazel @@ -88,6 +88,8 @@ iree_compiler_cc_library( "@llvm-project//mlir:AffineTransforms", "@llvm-project//mlir:ArmNeon2dToIntr", "@llvm-project//mlir:ArmNeonDialect", + "@llvm-project//mlir:ArmSVEDialect", + "@llvm-project//mlir:ArmSMEDialect", "@llvm-project//mlir:BufferizationDialect", "@llvm-project//mlir:ComplexDialect", "@llvm-project//mlir:ControlFlowDialect", diff --git a/compiler/src/iree/compiler/Tools/CMakeLists.txt b/compiler/src/iree/compiler/Tools/CMakeLists.txt index ee8d8201785a..0caa69dfd246 100644 --- a/compiler/src/iree/compiler/Tools/CMakeLists.txt +++ b/compiler/src/iree/compiler/Tools/CMakeLists.txt @@ -78,6 +78,8 @@ iree_cc_library( MLIRAffineDialect MLIRAffineTransforms MLIRArmNeonDialect + MLIRArmSVEDialect + MLIRArmSMEDialect MLIRArmNeon2dToIntr MLIRBufferizationDialect MLIRComplexDialect diff --git a/compiler/src/iree/compiler/Tools/init_mlir_dialects.h b/compiler/src/iree/compiler/Tools/init_mlir_dialects.h index e259f83e604b..2148ff2e150a 100644 --- a/compiler/src/iree/compiler/Tools/init_mlir_dialects.h +++ b/compiler/src/iree/compiler/Tools/init_mlir_dialects.h @@ -16,6 +16,8 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmSME/IR/ArmSME.h" +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" #include "mlir/Dialect/Complex/IR/Complex.h" @@ -50,6 +52,7 @@ #include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" + #include "mlir/IR/Dialect.h" #ifdef IREE_HAVE_C_OUTPUT_FORMAT @@ -80,6 +83,8 @@ inline void registerMlirDialects(DialectRegistry ®istry) { quant::QuantizationDialect, spirv::SPIRVDialect, arm_neon::ArmNeonDialect, + arm_sve::ArmSVEDialect, + arm_sme::ArmSMEDialect, func::FuncDialect, mlir::arith::ArithDialect, vector::VectorDialect,