Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][IR] Auto-generate element type verification for VectorType #102449

Merged

Conversation

matthias-springer
Copy link
Member

#102326 enables verification of type parameters that are type constraints. The element type verification for VectorType (and maybe other builtin types in the future) can now be auto-generated.

Also remove redundant error checking in the vector type parser: element type and dimensions are already checked by the verifier (which is called from getChecked).

Depends on #102326.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir mlir:ods labels Aug 8, 2024
@llvmbot
Copy link
Member

llvmbot commented Aug 8, 2024

@llvm/pr-subscribers-mlir-ods
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

#102326 enables verification of type parameters that are type constraints. The element type verification for VectorType (and maybe other builtin types in the future) can now be auto-generated.

Also remove redundant error checking in the vector type parser: element type and dimensions are already checked by the verifier (which is called from getChecked).

Depends on #102326.


Full diff: https://github.com/llvm/llvm-project/pull/102449.diff

3 Files Affected:

  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+3-1)
  • (modified) mlir/lib/AsmParser/TypeParser.cpp (+3-10)
  • (modified) mlir/test/IR/invalid-builtin-types.mlir (+3-3)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 365edcf68d8b94..4b3add2035263c 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -17,6 +17,7 @@
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/IR/BuiltinDialect.td"
 include "mlir/IR/BuiltinTypeInterfaces.td"
+include "mlir/IR/CommonTypeConstraints.td"
 
 // TODO: Currently the types defined in this file are prefixed with `Builtin_`.
 // This is to differentiate the types here with the ones in OpBase.td. We should
@@ -1146,7 +1147,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
   }];
   let parameters = (ins
     ArrayRefParameter<"int64_t">:$shape,
-    "Type":$elementType,
+    AnyTypeOf<[AnyInteger, Index, AnyFloat]>:$elementType,
     ArrayRefParameter<"bool">:$scalableDims
   );
   let builders = [
@@ -1173,6 +1174,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
     /// type. In particular, vectors can consist of integer, index, or float
     /// primitives.
     static bool isValidElementType(Type t) {
+      // TODO: Auto-generate this function from $elementType.
       return ::llvm::isa<IntegerType, IndexType, FloatType>(t);
     }
 
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 542eaeefe57f12..f070c072c43296 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -458,31 +458,24 @@ Type Parser::parseTupleType() {
 /// static-dim-list ::= decimal-literal (`x` decimal-literal)*
 ///
 VectorType Parser::parseVectorType() {
+  SMLoc loc = getToken().getLoc();
   consumeToken(Token::kw_vector);
 
   if (parseToken(Token::less, "expected '<' in vector type"))
     return nullptr;
 
+  // Parse the dimensions.
   SmallVector<int64_t, 4> dimensions;
   SmallVector<bool, 4> scalableDims;
   if (parseVectorDimensionList(dimensions, scalableDims))
     return nullptr;
-  if (any_of(dimensions, [](int64_t i) { return i <= 0; }))
-    return emitError(getToken().getLoc(),
-                     "vector types must have positive constant sizes"),
-           nullptr;
 
   // Parse the element type.
-  auto typeLoc = getToken().getLoc();
   auto elementType = parseType();
   if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
     return nullptr;
 
-  if (!VectorType::isValidElementType(elementType))
-    return emitError(typeLoc, "vector elements must be int/index/float type"),
-           nullptr;
-
-  return VectorType::get(dimensions, elementType, scalableDims);
+  return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
 }
 
 /// Parse a dimension list in a vector type. This populates the dimension list.
diff --git a/mlir/test/IR/invalid-builtin-types.mlir b/mlir/test/IR/invalid-builtin-types.mlir
index 9884212e916c1f..07854a25000feb 100644
--- a/mlir/test/IR/invalid-builtin-types.mlir
+++ b/mlir/test/IR/invalid-builtin-types.mlir
@@ -120,17 +120,17 @@ func.func @illegaltype(i21312312323120) // expected-error {{invalid integer widt
 // -----
 
 // Test no nested vector.
-// expected-error@+1 {{vector elements must be int/index/float type}}
+// expected-error@+1 {{failed to verify 'elementType': integer or index or floating-point}}
 func.func @vectors(vector<1 x vector<1xi32>>, vector<2x4xf32>)
 
 // -----
 
-// expected-error @+1 {{vector types must have positive constant sizes}}
+// expected-error @+1 {{vector types must have positive constant sizes but got 0}}
 func.func @zero_vector_type() -> vector<0xi32>
 
 // -----
 
-// expected-error @+1 {{vector types must have positive constant sizes}}
+// expected-error @+1 {{vector types must have positive constant sizes but got 1, 0}}
 func.func @zero_in_vector_type() -> vector<1x0xi32>
 
 // -----

@matthias-springer matthias-springer force-pushed the users/matthias-springer/vector_type_verification branch from 73980b0 to 1e3c58a Compare August 8, 2024 11:44
@matthias-springer matthias-springer force-pushed the users/matthias-springer/ods_type_verifier branch from ec0d3fb to 27f3ffa Compare August 9, 2024 18:52
@matthias-springer matthias-springer force-pushed the users/matthias-springer/vector_type_verification branch from 1e3c58a to c399a1c Compare August 9, 2024 18:54
@matthias-springer matthias-springer force-pushed the users/matthias-springer/ods_type_verifier branch from 27f3ffa to a82692f Compare August 9, 2024 19:54
Base automatically changed from users/matthias-springer/ods_type_verifier to main August 9, 2024 20:04
@matthias-springer matthias-springer force-pushed the users/matthias-springer/vector_type_verification branch from c399a1c to dc1af74 Compare August 9, 2024 20:05
@matthias-springer matthias-springer merged commit 7d4aa1f into main Aug 12, 2024
8 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/vector_type_verification branch August 12, 2024 06:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:ods mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants