diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td index 365edcf68d8b9..4b3add2035263 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -17,6 +17,7 @@ include "mlir/IR/AttrTypeBase.td" include "mlir/IR/BuiltinDialect.td" include "mlir/IR/BuiltinTypeInterfaces.td" +include "mlir/IR/CommonTypeConstraints.td" // TODO: Currently the types defined in this file are prefixed with `Builtin_`. // This is to differentiate the types here with the ones in OpBase.td. We should @@ -1146,7 +1147,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", }]; let parameters = (ins ArrayRefParameter<"int64_t">:$shape, - "Type":$elementType, + AnyTypeOf<[AnyInteger, Index, AnyFloat]>:$elementType, ArrayRefParameter<"bool">:$scalableDims ); let builders = [ @@ -1173,6 +1174,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", /// type. In particular, vectors can consist of integer, index, or float /// primitives. static bool isValidElementType(Type t) { + // TODO: Auto-generate this function from $elementType. return ::llvm::isa(t); } diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp index 542eaeefe57f1..f070c072c4329 100644 --- a/mlir/lib/AsmParser/TypeParser.cpp +++ b/mlir/lib/AsmParser/TypeParser.cpp @@ -458,31 +458,24 @@ Type Parser::parseTupleType() { /// static-dim-list ::= decimal-literal (`x` decimal-literal)* /// VectorType Parser::parseVectorType() { + SMLoc loc = getToken().getLoc(); consumeToken(Token::kw_vector); if (parseToken(Token::less, "expected '<' in vector type")) return nullptr; + // Parse the dimensions. SmallVector dimensions; SmallVector scalableDims; if (parseVectorDimensionList(dimensions, scalableDims)) return nullptr; - if (any_of(dimensions, [](int64_t i) { return i <= 0; })) - return emitError(getToken().getLoc(), - "vector types must have positive constant sizes"), - nullptr; // Parse the element type. - auto typeLoc = getToken().getLoc(); auto elementType = parseType(); if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) return nullptr; - if (!VectorType::isValidElementType(elementType)) - return emitError(typeLoc, "vector elements must be int/index/float type"), - nullptr; - - return VectorType::get(dimensions, elementType, scalableDims); + return getChecked(loc, dimensions, elementType, scalableDims); } /// Parse a dimension list in a vector type. This populates the dimension list. diff --git a/mlir/test/IR/invalid-builtin-types.mlir b/mlir/test/IR/invalid-builtin-types.mlir index 9884212e916c1..07854a25000fe 100644 --- a/mlir/test/IR/invalid-builtin-types.mlir +++ b/mlir/test/IR/invalid-builtin-types.mlir @@ -120,17 +120,17 @@ func.func @illegaltype(i21312312323120) // expected-error {{invalid integer widt // ----- // Test no nested vector. -// expected-error@+1 {{vector elements must be int/index/float type}} +// expected-error@+1 {{failed to verify 'elementType': integer or index or floating-point}} func.func @vectors(vector<1 x vector<1xi32>>, vector<2x4xf32>) // ----- -// expected-error @+1 {{vector types must have positive constant sizes}} +// expected-error @+1 {{vector types must have positive constant sizes but got 0}} func.func @zero_vector_type() -> vector<0xi32> // ----- -// expected-error @+1 {{vector types must have positive constant sizes}} +// expected-error @+1 {{vector types must have positive constant sizes but got 1, 0}} func.func @zero_in_vector_type() -> vector<1x0xi32> // ----- diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py index 2161f110ac31e..f95cccc54105e 100644 --- a/mlir/test/python/ir/builtin_types.py +++ b/mlir/test/python/ir/builtin_types.py @@ -345,7 +345,7 @@ def testVectorType(): VectorType.get(shape, none) except MLIRError as e: # CHECK: Invalid type: - # CHECK: error: unknown: vector elements must be int/index/float type but got 'none' + # CHECK: error: unknown: failed to verify 'elementType': integer or index or floating-point print(e) else: print("Exception not produced")