Skip to content

Commit

Permalink
[mlir][IR] Auto-generate element type verification for VectorType
Browse files Browse the repository at this point in the history
  • Loading branch information
matthias-springer committed Aug 9, 2024
1 parent 7359a6b commit dc1af74
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 15 deletions.
4 changes: 3 additions & 1 deletion mlir/include/mlir/IR/BuiltinTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinDialect.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
include "mlir/IR/CommonTypeConstraints.td"

// TODO: Currently the types defined in this file are prefixed with `Builtin_`.
// This is to differentiate the types here with the ones in OpBase.td. We should
Expand Down Expand Up @@ -1146,7 +1147,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
}];
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
"Type":$elementType,
AnyTypeOf<[AnyInteger, Index, AnyFloat]>:$elementType,
ArrayRefParameter<"bool">:$scalableDims
);
let builders = [
Expand All @@ -1173,6 +1174,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
/// type. In particular, vectors can consist of integer, index, or float
/// primitives.
static bool isValidElementType(Type t) {
// TODO: Auto-generate this function from $elementType.
return ::llvm::isa<IntegerType, IndexType, FloatType>(t);
}

Expand Down
13 changes: 3 additions & 10 deletions mlir/lib/AsmParser/TypeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,31 +458,24 @@ Type Parser::parseTupleType() {
/// static-dim-list ::= decimal-literal (`x` decimal-literal)*
///
VectorType Parser::parseVectorType() {
SMLoc loc = getToken().getLoc();
consumeToken(Token::kw_vector);

if (parseToken(Token::less, "expected '<' in vector type"))
return nullptr;

// Parse the dimensions.
SmallVector<int64_t, 4> dimensions;
SmallVector<bool, 4> scalableDims;
if (parseVectorDimensionList(dimensions, scalableDims))
return nullptr;
if (any_of(dimensions, [](int64_t i) { return i <= 0; }))
return emitError(getToken().getLoc(),
"vector types must have positive constant sizes"),
nullptr;

// Parse the element type.
auto typeLoc = getToken().getLoc();
auto elementType = parseType();
if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
return nullptr;

if (!VectorType::isValidElementType(elementType))
return emitError(typeLoc, "vector elements must be int/index/float type"),
nullptr;

return VectorType::get(dimensions, elementType, scalableDims);
return getChecked<VectorType>(loc, dimensions, elementType, scalableDims);
}

/// Parse a dimension list in a vector type. This populates the dimension list.
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/IR/invalid-builtin-types.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -120,17 +120,17 @@ func.func @illegaltype(i21312312323120) // expected-error {{invalid integer widt
// -----

// Test no nested vector.
// expected-error@+1 {{vector elements must be int/index/float type}}
// expected-error@+1 {{failed to verify 'elementType': integer or index or floating-point}}
func.func @vectors(vector<1 x vector<1xi32>>, vector<2x4xf32>)

// -----

// expected-error @+1 {{vector types must have positive constant sizes}}
// expected-error @+1 {{vector types must have positive constant sizes but got 0}}
func.func @zero_vector_type() -> vector<0xi32>

// -----

// expected-error @+1 {{vector types must have positive constant sizes}}
// expected-error @+1 {{vector types must have positive constant sizes but got 1, 0}}
func.func @zero_in_vector_type() -> vector<1x0xi32>

// -----
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/python/ir/builtin_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def testVectorType():
VectorType.get(shape, none)
except MLIRError as e:
# CHECK: Invalid type:
# CHECK: error: unknown: vector elements must be int/index/float type but got 'none'
# CHECK: error: unknown: failed to verify 'elementType': integer or index or floating-point
print(e)
else:
print("Exception not produced")
Expand Down

0 comments on commit dc1af74

Please sign in to comment.