Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][ODS] Verify type constraints in Types and Attributes #102326

Merged
merged 1 commit into from
Aug 9, 2024

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Aug 7, 2024

When a type/attribute is defined in TableGen, a type constraint can be used for parameters, but the type constraint verification was missing.

Example:

def TestTypeVerification : Test_Type<"TestTypeVerification"> {
  let parameters = (ins AnyTypeOf<[I16, I32]>:$param);
  // ...
}

No verification code was generated to ensure that $param is I16 or I32.

When type constraints a present, a new method will generated for types and attributes: verifyInvariantsImpl. (The naming is similar to op verifiers.) The user-provided verifier is called verify (no change). There is now a new entry point to type/attribute verification: verifyInvariants. This function calls both verifyInvariantsImpl and verify. If neither of those two verifications are present, the verifyInvariants function is not generated.

When a type/attribute is not defined in TableGen, but a verifier is needed, users can implement the verifyInvariants function. (This function was previously called verify.)

Note for LLVM integration: If you have an attribute/type that is not defined in TableGen (i.e., just C++), you have to rename the verification function from verify to verifyInvariants. (Most attributes/types have no verification, in which case there is nothing to do.)

Depends on #102657.

@llvmbot
Copy link
Member

llvmbot commented Aug 7, 2024

@llvm/pr-subscribers-mlir-sparse
@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

When a type/attribute is defined in TableGen, a type constraint can be used for parameters, but the type constraint verification was missing.

Example:

def TestTypeVerification : Test_Type&lt;"TestTypeVerification"&gt; {
  let parameters = (ins AnyTypeOf&lt;[I16, I32]&gt;:$param);
  // ...
}

No verification code was generated to ensure that $param is I16 or I32.

When type constraints a present, a new method will generated for types and attributes: verifyInvariantsImpl. (The naming is similar to op verifiers.) The user-provided verifier is called verify (no change). There is now a new entry point to type/attribute verification: verifyInvariants. This function calls both verifyInvariantsImpl and verify. If neither of those two verifications are present, the verifyInvariants function is not generated.

When a type/attribute is not defined in TableGen, but a verifier is needed, users can implement the verifyInvariants function. (This function was previously called verify.)


Patch is 33.88 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/102326.diff

22 Files Affected:

  • (modified) mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h (+4-3)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h (+7-5)
  • (modified) mlir/include/mlir/Dialect/Quant/QuantTypes.h (+22-21)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h (+8-6)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h (+6-4)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+1)
  • (modified) mlir/include/mlir/IR/Constraints.td (+3)
  • (modified) mlir/include/mlir/IR/StorageUniquerSupport.h (+6-4)
  • (modified) mlir/include/mlir/IR/Types.h (+8-11)
  • (modified) mlir/include/mlir/TableGen/AttrOrTypeDef.h (+6)
  • (modified) mlir/include/mlir/TableGen/Class.h (+2)
  • (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+3-3)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp (+6-6)
  • (modified) mlir/lib/Dialect/Quant/IR/QuantTypes.cpp (+22-17)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp (+4-5)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp (+5-4)
  • (modified) mlir/lib/TableGen/AttrOrTypeDef.cpp (+13)
  • (added) mlir/test/IR/test-verifiers-type.mlir (+9)
  • (modified) mlir/test/lib/Dialect/Test/TestAttrDefs.td (-1)
  • (modified) mlir/test/lib/Dialect/Test/TestTypeDefs.td (+6)
  • (modified) mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp (+92-10)
  • (modified) mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp (+1-1)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 96e1935bd0a84..57acd72610415 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -148,9 +148,10 @@ class MMAMatrixType
 
   /// Verify that shape and elementType are actually allowed for the
   /// MMAMatrixType.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              ArrayRef<int64_t> shape, Type elementType,
-                              StringRef operand);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   ArrayRef<int64_t> shape, Type elementType,
+                   StringRef operand);
 
   /// Get number of dims.
   unsigned getNumDims() const;
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 1befdfa74f67c..2ea589a7c4c3b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -180,11 +180,13 @@ class LLVMStructType
   ArrayRef<Type> getBody() const;
 
   /// Verifies that the type about to be constructed is well-formed.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              StringRef, bool);
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              ArrayRef<Type> types, bool);
-  using Base::verify;
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, StringRef,
+                   bool);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   ArrayRef<Type> types, bool);
+  using Base::verifyInvariants;
 
   /// Hooks for DataLayoutTypeInterface. Should not be called directly. Obtain a
   /// DataLayout instance and query it instead.
diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
index de5aed0a91a20..57a2aa2983365 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
@@ -54,10 +54,10 @@ class QuantizedType : public Type {
   /// The maximum number of bits supported for storage types.
   static constexpr unsigned MaxStorageBits = 32;
 
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              unsigned flags, Type storageType,
-                              Type expressedType, int64_t storageTypeMin,
-                              int64_t storageTypeMax);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+                   Type storageType, Type expressedType, int64_t storageTypeMin,
+                   int64_t storageTypeMax);
 
   /// Support method to enable LLVM-style type casting.
   static bool classof(Type type);
@@ -214,10 +214,10 @@ class AnyQuantizedType
              int64_t storageTypeMax);
 
   /// Verifies construction invariants and issues errors/warnings.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              unsigned flags, Type storageType,
-                              Type expressedType, int64_t storageTypeMin,
-                              int64_t storageTypeMax);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+                   Type storageType, Type expressedType, int64_t storageTypeMin,
+                   int64_t storageTypeMax);
 };
 
 /// Represents a family of uniform, quantized types.
@@ -276,11 +276,11 @@ class UniformQuantizedType
              int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax);
 
   /// Verifies construction invariants and issues errors/warnings.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              unsigned flags, Type storageType,
-                              Type expressedType, double scale,
-                              int64_t zeroPoint, int64_t storageTypeMin,
-                              int64_t storageTypeMax);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+                   Type storageType, Type expressedType, double scale,
+                   int64_t zeroPoint, int64_t storageTypeMin,
+                   int64_t storageTypeMax);
 
   /// Gets the scale term. The scale designates the difference between the real
   /// values corresponding to consecutive quantized values differing by 1.
@@ -338,12 +338,12 @@ class UniformQuantizedPerAxisType
              int64_t storageTypeMin, int64_t storageTypeMax);
 
   /// Verifies construction invariants and issues errors/warnings.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              unsigned flags, Type storageType,
-                              Type expressedType, ArrayRef<double> scales,
-                              ArrayRef<int64_t> zeroPoints,
-                              int32_t quantizedDimension,
-                              int64_t storageTypeMin, int64_t storageTypeMax);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+                   Type storageType, Type expressedType,
+                   ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
+                   int32_t quantizedDimension, int64_t storageTypeMin,
+                   int64_t storageTypeMax);
 
   /// Gets the quantization scales. The scales designate the difference between
   /// the real values corresponding to consecutive quantized values differing
@@ -403,8 +403,9 @@ class CalibratedQuantizedType
              double min, double max);
 
   /// Verifies construction invariants and issues errors/warnings.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              Type expressedType, double min, double max);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   Type expressedType, double min, double max);
   double getMin() const;
   double getMax() const;
 };
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
index 5ebfa9ca5ec25..2bdd7a5bf3dd8 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
@@ -76,9 +76,10 @@ class InterfaceVarABIAttr
   /// Returns `spirv::StorageClass`.
   std::optional<StorageClass> getStorageClass();
 
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              IntegerAttr descriptorSet, IntegerAttr binding,
-                              IntegerAttr storageClass);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   IntegerAttr descriptorSet, IntegerAttr binding,
+                   IntegerAttr storageClass);
 
   static constexpr StringLiteral name = "spirv.interface_var_abi";
 };
@@ -128,9 +129,10 @@ class VerCapExtAttr
   /// Returns the capabilities as an integer array attribute.
   ArrayAttr getCapabilitiesAttr();
 
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              IntegerAttr version, ArrayAttr capabilities,
-                              ArrayAttr extensions);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   IntegerAttr version, ArrayAttr capabilities,
+                   ArrayAttr extensions);
 
   static constexpr StringLiteral name = "spirv.ver_cap_ext";
 };
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 55f0c787b4440..e2d04553d91b8 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -258,8 +258,9 @@ class SampledImageType
   static SampledImageType
   getChecked(function_ref<InFlightDiagnostic()> emitError, Type imageType);
 
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              Type imageType);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   Type imageType);
 
   Type getImageType() const;
 
@@ -462,8 +463,9 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
   static MatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
                                Type columnType, uint32_t columnCount);
 
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              Type columnType, uint32_t columnCount);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   Type columnType, uint32_t columnCount);
 
   /// Returns true if the matrix elements are vectors of float elements.
   static bool isValidColumnType(Type columnType);
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 5b6ec167fa242..4d3e1428e6c40 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -180,6 +180,7 @@ class AnyTypeOf<list<Type> allowedTypeList, string summary = "",
         summary),
     cppClassName> {
   list<Type> allowedTypes = allowedTypeList;
+  string cppType = cppClassName;
 }
 
 // A type that satisfies the constraints of all given types.
diff --git a/mlir/include/mlir/IR/Constraints.td b/mlir/include/mlir/IR/Constraints.td
index a026d58ccffb8..242c850f38f30 100644
--- a/mlir/include/mlir/IR/Constraints.td
+++ b/mlir/include/mlir/IR/Constraints.td
@@ -153,6 +153,9 @@ class TypeConstraint<Pred predicate, string summary = "",
     Constraint<predicate, summary> {
   // The name of the C++ Type class if known, or Type if not.
   string cppClassName = cppClassNameParam;
+  // TODO: This field is sometimes called `cppClassName` and sometimes
+  // `cppType`. Use a single name consistently.
+  string cppType = cppClassNameParam;
 }
 
 // Subclass for constraints on an attribute.
diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index fb64f15162df5..d6ccbbd857994 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -176,8 +176,8 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
   template <typename... Args>
   static ConcreteT get(MLIRContext *ctx, Args &&...args) {
     // Ensure that the invariants are correct for construction.
-    assert(
-        succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...)));
+    assert(succeeded(
+        ConcreteT::verifyInvariants(getDefaultDiagnosticEmitFn(ctx), args...)));
     return UniquerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...);
   }
 
@@ -198,7 +198,7 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
   static ConcreteT getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                               MLIRContext *ctx, Args... args) {
     // If the construction invariants fail then we return a null attribute.
-    if (failed(ConcreteT::verify(emitErrorFn, args...)))
+    if (failed(ConcreteT::verifyInvariants(emitErrorFn, args...)))
       return ConcreteT();
     return UniquerT::template get<ConcreteT>(ctx, args...);
   }
@@ -226,7 +226,9 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
 
   /// Default implementation that just returns success.
   template <typename... Args>
-  static LogicalResult verify(Args... args) {
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitErrorFn,
+                   Args... args) {
     return success();
   }
 
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 60dc8fee0f4a9..91b457deeba2f 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -34,7 +34,7 @@ class AsmState;
 /// Derived type classes are expected to implement several required
 /// implementation hooks:
 ///  * Optional:
-///    - static LogicalResult verify(
+///    - static LogicalResult verifyInvariants(
 ///                                function_ref<InFlightDiagnostic()> emitError,
 ///                                Args... args)
 ///      * This method is invoked when calling the 'TypeBase::get/getChecked'
@@ -97,20 +97,17 @@ class Type {
   bool operator!() const { return impl == nullptr; }
 
   template <typename... Tys>
-  [[deprecated("Use mlir::isa<U>() instead")]]
-  bool isa() const;
+  [[deprecated("Use mlir::isa<U>() instead")]] bool isa() const;
   template <typename... Tys>
-  [[deprecated("Use mlir::isa_and_nonnull<U>() instead")]]
-  bool isa_and_nonnull() const;
+  [[deprecated("Use mlir::isa_and_nonnull<U>() instead")]] bool
+  isa_and_nonnull() const;
   template <typename U>
-  [[deprecated("Use mlir::dyn_cast<U>() instead")]]
-  U dyn_cast() const;
+  [[deprecated("Use mlir::dyn_cast<U>() instead")]] U dyn_cast() const;
   template <typename U>
-  [[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]]
-  U dyn_cast_or_null() const;
+  [[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]] U
+  dyn_cast_or_null() const;
   template <typename U>
-  [[deprecated("Use mlir::cast<U>() instead")]]
-  U cast() const;
+  [[deprecated("Use mlir::cast<U>() instead")]] U cast() const;
 
   /// Return a unique identifier for the concrete type. This is used to support
   /// dynamic type casting.
diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
index 19c3a9183ec2c..c7819ef7c0ffa 100644
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -15,7 +15,9 @@
 #define MLIR_TABLEGEN_ATTRORTYPEDEF_H
 
 #include "mlir/Support/LLVM.h"
+#include "mlir/TableGen/Argument.h"
 #include "mlir/TableGen/Builder.h"
+#include "mlir/TableGen/Predicate.h"
 #include "mlir/TableGen/Trait.h"
 
 namespace llvm {
@@ -85,6 +87,8 @@ class AttrOrTypeParameter {
   /// Get an optional C++ parameter parser.
   std::optional<StringRef> getParser() const;
 
+  std::optional<TypeConstraint> getTypeConstraint() const;
+
   /// Get an optional C++ parameter printer.
   std::optional<StringRef> getPrinter() const;
 
@@ -198,6 +202,8 @@ class AttrOrTypeDef {
   /// method.
   bool genVerifyDecl() const;
 
+  bool genVerifyInvariantsImpl() const;
+
   /// Returns the def's extra class declaration code.
   std::optional<StringRef> getExtraDecls() const;
 
diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h
index 855952d19492d..f750a34a3b2ba 100644
--- a/mlir/include/mlir/TableGen/Class.h
+++ b/mlir/include/mlir/TableGen/Class.h
@@ -67,6 +67,8 @@ class MethodParameter {
 
   /// Get the C++ type.
   StringRef getType() const { return type; }
+  /// Get the C++ parameter name.
+  StringRef getName() const { return name; }
   /// Returns true if the parameter has a default value.
   bool hasDefaultValue() const { return !defaultValue.empty(); }
 
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 7bc2668310ddb..a1f87a637a614 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -148,9 +148,9 @@ bool MMAMatrixType::isValidElementType(Type elementType) {
 }
 
 LogicalResult
-MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
-                      ArrayRef<int64_t> shape, Type elementType,
-                      StringRef operand) {
+MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                                ArrayRef<int64_t> shape, Type elementType,
+                                StringRef operand) {
   if (operand != "AOp" && operand != "BOp" && operand != "COp")
     return emitError() << "operand expected to be one of AOp, BOp or COp";
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index dc7aef8ef7f85..7f10a15ff31ff 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -418,8 +418,7 @@ LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,
 
 bool LLVMStructType::isValidElementType(Type type) {
   return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
-                    LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>(
-      type);
+                    LLVMFunctionType, LLVMTokenType>(type);
 }
 
 LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
@@ -492,14 +491,15 @@ ArrayRef<Type> LLVMStructType::getBody() const {
                         : getImpl()->getTypeList();
 }
 
-LogicalResult LLVMStructType::verify(function_ref<InFlightDiagnostic()>,
-                                     StringRef, bool) {
+LogicalResult
+LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()>, StringRef,
+                                 bool) {
   return success();
 }
 
 LogicalResult
-LLVMStructType::verify(function_ref<InFlightDiagnostic()> emitError,
-                       ArrayRef<Type> types, bool) {
+LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                                 ArrayRef<Type> types, bool) {
   for (Type t : types)
     if (!isValidElementType(t))
       return emitError() << "invalid LLVM structure element type: " << t;
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index 81e3b914755be..c2ba9c04e8771 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -29,9 +29,10 @@ bool QuantizedType::classof(Type type) {
 }
 
 LogicalResult
-QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
-                      unsigned flags, Type storageType, Type expressedType,
-                      int64_t storageTypeMin, int64_t storageTypeMax) {
+QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                                unsigned flags, Type storageType,
+                                Type expressedType, int64_t storageTypeMin,
+                                int64_t storageTypeMax) {
   // Verify that the storage type is integral.
   // This restriction may be lifted at some point in favor of using bf16
   // or f16 as exact representations on hardware where that is advantageous.
@@ -233,11 +234,13 @@ AnyQuantizedType::getChecked(function_ref<InFlightDiagnostic()> emitError,
 }
 
 LogicalResult
-AnyQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
-                         unsigned flags, Type storageType, Type expressedType,
-                         int64_t storageTypeMin, int64_t storageTypeMax) {
-  if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
-                                   storageTypeMin, storageTypeMax))) {
+AnyQuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                                   unsigned flags, Type storageType,
+                                   Type expressedType, int64_t storageTypeMin,
+                                   int64_t storageTypeMax) {
+  if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
+                                             expressedType, storageTypeMin,
+                                             storageTypeMax))) {
     return failure();
   }
 
@@ -268,12 +271,13 @@ UniformQuantizedType UniformQuantizedType::getChecked(
                           storageTypeMin, storageTypeMax);
 }
 
-LogicalResult UniformQuantizedType::verify(
+LogicalResult UniformQuantizedType::verifyInvariants(
     function_ref<InFlightDiagnostic()> emitError, unsigned flags,
     Type storageType, Type expressedType, double scale, int64_t zeroPoint,
     int64_t storageTypeMin, int64_t storageTypeMax) {
-  if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
-                                   storageTypeMin, storageTypeMax))) {
+  if (failed(QuantizedType::verifyInvariants(emit...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Aug 7, 2024

@llvm/pr-subscribers-mlir-llvm

Author: Matthias Springer (matthias-springer)

Changes

When a type/attribute is defined in TableGen, a type constraint can be used for parameters, but the type constraint verification was missing.

Example:

def TestTypeVerification : Test_Type&lt;"TestTypeVerification"&gt; {
  let parameters = (ins AnyTypeOf&lt;[I16, I32]&gt;:$param);
  // ...
}

No verification code was generated to ensure that $param is I16 or I32.

When type constraints a present, a new method will generated for types and attributes: verifyInvariantsImpl. (The naming is similar to op verifiers.) The user-provided verifier is called verify (no change). There is now a new entry point to type/attribute verification: verifyInvariants. This function calls both verifyInvariantsImpl and verify. If neither of those two verifications are present, the verifyInvariants function is not generated.

When a type/attribute is not defined in TableGen, but a verifier is needed, users can implement the verifyInvariants function. (This function was previously called verify.)


Patch is 33.88 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/102326.diff

22 Files Affected:

  • (modified) mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h (+4-3)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h (+7-5)
  • (modified) mlir/include/mlir/Dialect/Quant/QuantTypes.h (+22-21)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h (+8-6)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h (+6-4)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+1)
  • (modified) mlir/include/mlir/IR/Constraints.td (+3)
  • (modified) mlir/include/mlir/IR/StorageUniquerSupport.h (+6-4)
  • (modified) mlir/include/mlir/IR/Types.h (+8-11)
  • (modified) mlir/include/mlir/TableGen/AttrOrTypeDef.h (+6)
  • (modified) mlir/include/mlir/TableGen/Class.h (+2)
  • (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+3-3)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp (+6-6)
  • (modified) mlir/lib/Dialect/Quant/IR/QuantTypes.cpp (+22-17)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp (+4-5)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp (+5-4)
  • (modified) mlir/lib/TableGen/AttrOrTypeDef.cpp (+13)
  • (added) mlir/test/IR/test-verifiers-type.mlir (+9)
  • (modified) mlir/test/lib/Dialect/Test/TestAttrDefs.td (-1)
  • (modified) mlir/test/lib/Dialect/Test/TestTypeDefs.td (+6)
  • (modified) mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp (+92-10)
  • (modified) mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp (+1-1)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 96e1935bd0a84..57acd72610415 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -148,9 +148,10 @@ class MMAMatrixType
 
   /// Verify that shape and elementType are actually allowed for the
   /// MMAMatrixType.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              ArrayRef<int64_t> shape, Type elementType,
-                              StringRef operand);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   ArrayRef<int64_t> shape, Type elementType,
+                   StringRef operand);
 
   /// Get number of dims.
   unsigned getNumDims() const;
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 1befdfa74f67c..2ea589a7c4c3b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -180,11 +180,13 @@ class LLVMStructType
   ArrayRef<Type> getBody() const;
 
   /// Verifies that the type about to be constructed is well-formed.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              StringRef, bool);
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              ArrayRef<Type> types, bool);
-  using Base::verify;
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, StringRef,
+                   bool);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   ArrayRef<Type> types, bool);
+  using Base::verifyInvariants;
 
   /// Hooks for DataLayoutTypeInterface. Should not be called directly. Obtain a
   /// DataLayout instance and query it instead.
diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
index de5aed0a91a20..57a2aa2983365 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
@@ -54,10 +54,10 @@ class QuantizedType : public Type {
   /// The maximum number of bits supported for storage types.
   static constexpr unsigned MaxStorageBits = 32;
 
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              unsigned flags, Type storageType,
-                              Type expressedType, int64_t storageTypeMin,
-                              int64_t storageTypeMax);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+                   Type storageType, Type expressedType, int64_t storageTypeMin,
+                   int64_t storageTypeMax);
 
   /// Support method to enable LLVM-style type casting.
   static bool classof(Type type);
@@ -214,10 +214,10 @@ class AnyQuantizedType
              int64_t storageTypeMax);
 
   /// Verifies construction invariants and issues errors/warnings.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              unsigned flags, Type storageType,
-                              Type expressedType, int64_t storageTypeMin,
-                              int64_t storageTypeMax);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+                   Type storageType, Type expressedType, int64_t storageTypeMin,
+                   int64_t storageTypeMax);
 };
 
 /// Represents a family of uniform, quantized types.
@@ -276,11 +276,11 @@ class UniformQuantizedType
              int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax);
 
   /// Verifies construction invariants and issues errors/warnings.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              unsigned flags, Type storageType,
-                              Type expressedType, double scale,
-                              int64_t zeroPoint, int64_t storageTypeMin,
-                              int64_t storageTypeMax);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+                   Type storageType, Type expressedType, double scale,
+                   int64_t zeroPoint, int64_t storageTypeMin,
+                   int64_t storageTypeMax);
 
   /// Gets the scale term. The scale designates the difference between the real
   /// values corresponding to consecutive quantized values differing by 1.
@@ -338,12 +338,12 @@ class UniformQuantizedPerAxisType
              int64_t storageTypeMin, int64_t storageTypeMax);
 
   /// Verifies construction invariants and issues errors/warnings.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              unsigned flags, Type storageType,
-                              Type expressedType, ArrayRef<double> scales,
-                              ArrayRef<int64_t> zeroPoints,
-                              int32_t quantizedDimension,
-                              int64_t storageTypeMin, int64_t storageTypeMax);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+                   Type storageType, Type expressedType,
+                   ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
+                   int32_t quantizedDimension, int64_t storageTypeMin,
+                   int64_t storageTypeMax);
 
   /// Gets the quantization scales. The scales designate the difference between
   /// the real values corresponding to consecutive quantized values differing
@@ -403,8 +403,9 @@ class CalibratedQuantizedType
              double min, double max);
 
   /// Verifies construction invariants and issues errors/warnings.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              Type expressedType, double min, double max);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   Type expressedType, double min, double max);
   double getMin() const;
   double getMax() const;
 };
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
index 5ebfa9ca5ec25..2bdd7a5bf3dd8 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
@@ -76,9 +76,10 @@ class InterfaceVarABIAttr
   /// Returns `spirv::StorageClass`.
   std::optional<StorageClass> getStorageClass();
 
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              IntegerAttr descriptorSet, IntegerAttr binding,
-                              IntegerAttr storageClass);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   IntegerAttr descriptorSet, IntegerAttr binding,
+                   IntegerAttr storageClass);
 
   static constexpr StringLiteral name = "spirv.interface_var_abi";
 };
@@ -128,9 +129,10 @@ class VerCapExtAttr
   /// Returns the capabilities as an integer array attribute.
   ArrayAttr getCapabilitiesAttr();
 
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              IntegerAttr version, ArrayAttr capabilities,
-                              ArrayAttr extensions);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   IntegerAttr version, ArrayAttr capabilities,
+                   ArrayAttr extensions);
 
   static constexpr StringLiteral name = "spirv.ver_cap_ext";
 };
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 55f0c787b4440..e2d04553d91b8 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -258,8 +258,9 @@ class SampledImageType
   static SampledImageType
   getChecked(function_ref<InFlightDiagnostic()> emitError, Type imageType);
 
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              Type imageType);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   Type imageType);
 
   Type getImageType() const;
 
@@ -462,8 +463,9 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
   static MatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
                                Type columnType, uint32_t columnCount);
 
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              Type columnType, uint32_t columnCount);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   Type columnType, uint32_t columnCount);
 
   /// Returns true if the matrix elements are vectors of float elements.
   static bool isValidColumnType(Type columnType);
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 5b6ec167fa242..4d3e1428e6c40 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -180,6 +180,7 @@ class AnyTypeOf<list<Type> allowedTypeList, string summary = "",
         summary),
     cppClassName> {
   list<Type> allowedTypes = allowedTypeList;
+  string cppType = cppClassName;
 }
 
 // A type that satisfies the constraints of all given types.
diff --git a/mlir/include/mlir/IR/Constraints.td b/mlir/include/mlir/IR/Constraints.td
index a026d58ccffb8..242c850f38f30 100644
--- a/mlir/include/mlir/IR/Constraints.td
+++ b/mlir/include/mlir/IR/Constraints.td
@@ -153,6 +153,9 @@ class TypeConstraint<Pred predicate, string summary = "",
     Constraint<predicate, summary> {
   // The name of the C++ Type class if known, or Type if not.
   string cppClassName = cppClassNameParam;
+  // TODO: This field is sometimes called `cppClassName` and sometimes
+  // `cppType`. Use a single name consistently.
+  string cppType = cppClassNameParam;
 }
 
 // Subclass for constraints on an attribute.
diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index fb64f15162df5..d6ccbbd857994 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -176,8 +176,8 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
   template <typename... Args>
   static ConcreteT get(MLIRContext *ctx, Args &&...args) {
     // Ensure that the invariants are correct for construction.
-    assert(
-        succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...)));
+    assert(succeeded(
+        ConcreteT::verifyInvariants(getDefaultDiagnosticEmitFn(ctx), args...)));
     return UniquerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...);
   }
 
@@ -198,7 +198,7 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
   static ConcreteT getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                               MLIRContext *ctx, Args... args) {
     // If the construction invariants fail then we return a null attribute.
-    if (failed(ConcreteT::verify(emitErrorFn, args...)))
+    if (failed(ConcreteT::verifyInvariants(emitErrorFn, args...)))
       return ConcreteT();
     return UniquerT::template get<ConcreteT>(ctx, args...);
   }
@@ -226,7 +226,9 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
 
   /// Default implementation that just returns success.
   template <typename... Args>
-  static LogicalResult verify(Args... args) {
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitErrorFn,
+                   Args... args) {
     return success();
   }
 
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 60dc8fee0f4a9..91b457deeba2f 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -34,7 +34,7 @@ class AsmState;
 /// Derived type classes are expected to implement several required
 /// implementation hooks:
 ///  * Optional:
-///    - static LogicalResult verify(
+///    - static LogicalResult verifyInvariants(
 ///                                function_ref<InFlightDiagnostic()> emitError,
 ///                                Args... args)
 ///      * This method is invoked when calling the 'TypeBase::get/getChecked'
@@ -97,20 +97,17 @@ class Type {
   bool operator!() const { return impl == nullptr; }
 
   template <typename... Tys>
-  [[deprecated("Use mlir::isa<U>() instead")]]
-  bool isa() const;
+  [[deprecated("Use mlir::isa<U>() instead")]] bool isa() const;
   template <typename... Tys>
-  [[deprecated("Use mlir::isa_and_nonnull<U>() instead")]]
-  bool isa_and_nonnull() const;
+  [[deprecated("Use mlir::isa_and_nonnull<U>() instead")]] bool
+  isa_and_nonnull() const;
   template <typename U>
-  [[deprecated("Use mlir::dyn_cast<U>() instead")]]
-  U dyn_cast() const;
+  [[deprecated("Use mlir::dyn_cast<U>() instead")]] U dyn_cast() const;
   template <typename U>
-  [[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]]
-  U dyn_cast_or_null() const;
+  [[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]] U
+  dyn_cast_or_null() const;
   template <typename U>
-  [[deprecated("Use mlir::cast<U>() instead")]]
-  U cast() const;
+  [[deprecated("Use mlir::cast<U>() instead")]] U cast() const;
 
   /// Return a unique identifier for the concrete type. This is used to support
   /// dynamic type casting.
diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
index 19c3a9183ec2c..c7819ef7c0ffa 100644
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -15,7 +15,9 @@
 #define MLIR_TABLEGEN_ATTRORTYPEDEF_H
 
 #include "mlir/Support/LLVM.h"
+#include "mlir/TableGen/Argument.h"
 #include "mlir/TableGen/Builder.h"
+#include "mlir/TableGen/Predicate.h"
 #include "mlir/TableGen/Trait.h"
 
 namespace llvm {
@@ -85,6 +87,8 @@ class AttrOrTypeParameter {
   /// Get an optional C++ parameter parser.
   std::optional<StringRef> getParser() const;
 
+  std::optional<TypeConstraint> getTypeConstraint() const;
+
   /// Get an optional C++ parameter printer.
   std::optional<StringRef> getPrinter() const;
 
@@ -198,6 +202,8 @@ class AttrOrTypeDef {
   /// method.
   bool genVerifyDecl() const;
 
+  bool genVerifyInvariantsImpl() const;
+
   /// Returns the def's extra class declaration code.
   std::optional<StringRef> getExtraDecls() const;
 
diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h
index 855952d19492d..f750a34a3b2ba 100644
--- a/mlir/include/mlir/TableGen/Class.h
+++ b/mlir/include/mlir/TableGen/Class.h
@@ -67,6 +67,8 @@ class MethodParameter {
 
   /// Get the C++ type.
   StringRef getType() const { return type; }
+  /// Get the C++ parameter name.
+  StringRef getName() const { return name; }
   /// Returns true if the parameter has a default value.
   bool hasDefaultValue() const { return !defaultValue.empty(); }
 
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 7bc2668310ddb..a1f87a637a614 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -148,9 +148,9 @@ bool MMAMatrixType::isValidElementType(Type elementType) {
 }
 
 LogicalResult
-MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
-                      ArrayRef<int64_t> shape, Type elementType,
-                      StringRef operand) {
+MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                                ArrayRef<int64_t> shape, Type elementType,
+                                StringRef operand) {
   if (operand != "AOp" && operand != "BOp" && operand != "COp")
     return emitError() << "operand expected to be one of AOp, BOp or COp";
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index dc7aef8ef7f85..7f10a15ff31ff 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -418,8 +418,7 @@ LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,
 
 bool LLVMStructType::isValidElementType(Type type) {
   return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
-                    LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>(
-      type);
+                    LLVMFunctionType, LLVMTokenType>(type);
 }
 
 LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
@@ -492,14 +491,15 @@ ArrayRef<Type> LLVMStructType::getBody() const {
                         : getImpl()->getTypeList();
 }
 
-LogicalResult LLVMStructType::verify(function_ref<InFlightDiagnostic()>,
-                                     StringRef, bool) {
+LogicalResult
+LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()>, StringRef,
+                                 bool) {
   return success();
 }
 
 LogicalResult
-LLVMStructType::verify(function_ref<InFlightDiagnostic()> emitError,
-                       ArrayRef<Type> types, bool) {
+LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                                 ArrayRef<Type> types, bool) {
   for (Type t : types)
     if (!isValidElementType(t))
       return emitError() << "invalid LLVM structure element type: " << t;
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index 81e3b914755be..c2ba9c04e8771 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -29,9 +29,10 @@ bool QuantizedType::classof(Type type) {
 }
 
 LogicalResult
-QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
-                      unsigned flags, Type storageType, Type expressedType,
-                      int64_t storageTypeMin, int64_t storageTypeMax) {
+QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                                unsigned flags, Type storageType,
+                                Type expressedType, int64_t storageTypeMin,
+                                int64_t storageTypeMax) {
   // Verify that the storage type is integral.
   // This restriction may be lifted at some point in favor of using bf16
   // or f16 as exact representations on hardware where that is advantageous.
@@ -233,11 +234,13 @@ AnyQuantizedType::getChecked(function_ref<InFlightDiagnostic()> emitError,
 }
 
 LogicalResult
-AnyQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
-                         unsigned flags, Type storageType, Type expressedType,
-                         int64_t storageTypeMin, int64_t storageTypeMax) {
-  if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
-                                   storageTypeMin, storageTypeMax))) {
+AnyQuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                                   unsigned flags, Type storageType,
+                                   Type expressedType, int64_t storageTypeMin,
+                                   int64_t storageTypeMax) {
+  if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
+                                             expressedType, storageTypeMin,
+                                             storageTypeMax))) {
     return failure();
   }
 
@@ -268,12 +271,13 @@ UniformQuantizedType UniformQuantizedType::getChecked(
                           storageTypeMin, storageTypeMax);
 }
 
-LogicalResult UniformQuantizedType::verify(
+LogicalResult UniformQuantizedType::verifyInvariants(
     function_ref<InFlightDiagnostic()> emitError, unsigned flags,
     Type storageType, Type expressedType, double scale, int64_t zeroPoint,
     int64_t storageTypeMin, int64_t storageTypeMax) {
-  if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
-                                   storageTypeMin, storageTypeMax))) {
+  if (failed(QuantizedType::verifyInvariants(emit...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Aug 7, 2024

@llvm/pr-subscribers-mlir-quant

Author: Matthias Springer (matthias-springer)

Changes

When a type/attribute is defined in TableGen, a type constraint can be used for parameters, but the type constraint verification was missing.

Example:

def TestTypeVerification : Test_Type&lt;"TestTypeVerification"&gt; {
  let parameters = (ins AnyTypeOf&lt;[I16, I32]&gt;:$param);
  // ...
}

No verification code was generated to ensure that $param is I16 or I32.

When type constraints a present, a new method will generated for types and attributes: verifyInvariantsImpl. (The naming is similar to op verifiers.) The user-provided verifier is called verify (no change). There is now a new entry point to type/attribute verification: verifyInvariants. This function calls both verifyInvariantsImpl and verify. If neither of those two verifications are present, the verifyInvariants function is not generated.

When a type/attribute is not defined in TableGen, but a verifier is needed, users can implement the verifyInvariants function. (This function was previously called verify.)


Patch is 33.88 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/102326.diff

22 Files Affected:

  • (modified) mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h (+4-3)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h (+7-5)
  • (modified) mlir/include/mlir/Dialect/Quant/QuantTypes.h (+22-21)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h (+8-6)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h (+6-4)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+1)
  • (modified) mlir/include/mlir/IR/Constraints.td (+3)
  • (modified) mlir/include/mlir/IR/StorageUniquerSupport.h (+6-4)
  • (modified) mlir/include/mlir/IR/Types.h (+8-11)
  • (modified) mlir/include/mlir/TableGen/AttrOrTypeDef.h (+6)
  • (modified) mlir/include/mlir/TableGen/Class.h (+2)
  • (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+3-3)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp (+6-6)
  • (modified) mlir/lib/Dialect/Quant/IR/QuantTypes.cpp (+22-17)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp (+4-5)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp (+5-4)
  • (modified) mlir/lib/TableGen/AttrOrTypeDef.cpp (+13)
  • (added) mlir/test/IR/test-verifiers-type.mlir (+9)
  • (modified) mlir/test/lib/Dialect/Test/TestAttrDefs.td (-1)
  • (modified) mlir/test/lib/Dialect/Test/TestTypeDefs.td (+6)
  • (modified) mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp (+92-10)
  • (modified) mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp (+1-1)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 96e1935bd0a84..57acd72610415 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -148,9 +148,10 @@ class MMAMatrixType
 
   /// Verify that shape and elementType are actually allowed for the
   /// MMAMatrixType.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              ArrayRef<int64_t> shape, Type elementType,
-                              StringRef operand);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   ArrayRef<int64_t> shape, Type elementType,
+                   StringRef operand);
 
   /// Get number of dims.
   unsigned getNumDims() const;
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
index 1befdfa74f67c..2ea589a7c4c3b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -180,11 +180,13 @@ class LLVMStructType
   ArrayRef<Type> getBody() const;
 
   /// Verifies that the type about to be constructed is well-formed.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              StringRef, bool);
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              ArrayRef<Type> types, bool);
-  using Base::verify;
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, StringRef,
+                   bool);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   ArrayRef<Type> types, bool);
+  using Base::verifyInvariants;
 
   /// Hooks for DataLayoutTypeInterface. Should not be called directly. Obtain a
   /// DataLayout instance and query it instead.
diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
index de5aed0a91a20..57a2aa2983365 100644
--- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h
+++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h
@@ -54,10 +54,10 @@ class QuantizedType : public Type {
   /// The maximum number of bits supported for storage types.
   static constexpr unsigned MaxStorageBits = 32;
 
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              unsigned flags, Type storageType,
-                              Type expressedType, int64_t storageTypeMin,
-                              int64_t storageTypeMax);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+                   Type storageType, Type expressedType, int64_t storageTypeMin,
+                   int64_t storageTypeMax);
 
   /// Support method to enable LLVM-style type casting.
   static bool classof(Type type);
@@ -214,10 +214,10 @@ class AnyQuantizedType
              int64_t storageTypeMax);
 
   /// Verifies construction invariants and issues errors/warnings.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              unsigned flags, Type storageType,
-                              Type expressedType, int64_t storageTypeMin,
-                              int64_t storageTypeMax);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+                   Type storageType, Type expressedType, int64_t storageTypeMin,
+                   int64_t storageTypeMax);
 };
 
 /// Represents a family of uniform, quantized types.
@@ -276,11 +276,11 @@ class UniformQuantizedType
              int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax);
 
   /// Verifies construction invariants and issues errors/warnings.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              unsigned flags, Type storageType,
-                              Type expressedType, double scale,
-                              int64_t zeroPoint, int64_t storageTypeMin,
-                              int64_t storageTypeMax);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+                   Type storageType, Type expressedType, double scale,
+                   int64_t zeroPoint, int64_t storageTypeMin,
+                   int64_t storageTypeMax);
 
   /// Gets the scale term. The scale designates the difference between the real
   /// values corresponding to consecutive quantized values differing by 1.
@@ -338,12 +338,12 @@ class UniformQuantizedPerAxisType
              int64_t storageTypeMin, int64_t storageTypeMax);
 
   /// Verifies construction invariants and issues errors/warnings.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              unsigned flags, Type storageType,
-                              Type expressedType, ArrayRef<double> scales,
-                              ArrayRef<int64_t> zeroPoints,
-                              int32_t quantizedDimension,
-                              int64_t storageTypeMin, int64_t storageTypeMax);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError, unsigned flags,
+                   Type storageType, Type expressedType,
+                   ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
+                   int32_t quantizedDimension, int64_t storageTypeMin,
+                   int64_t storageTypeMax);
 
   /// Gets the quantization scales. The scales designate the difference between
   /// the real values corresponding to consecutive quantized values differing
@@ -403,8 +403,9 @@ class CalibratedQuantizedType
              double min, double max);
 
   /// Verifies construction invariants and issues errors/warnings.
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              Type expressedType, double min, double max);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   Type expressedType, double min, double max);
   double getMin() const;
   double getMax() const;
 };
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
index 5ebfa9ca5ec25..2bdd7a5bf3dd8 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h
@@ -76,9 +76,10 @@ class InterfaceVarABIAttr
   /// Returns `spirv::StorageClass`.
   std::optional<StorageClass> getStorageClass();
 
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              IntegerAttr descriptorSet, IntegerAttr binding,
-                              IntegerAttr storageClass);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   IntegerAttr descriptorSet, IntegerAttr binding,
+                   IntegerAttr storageClass);
 
   static constexpr StringLiteral name = "spirv.interface_var_abi";
 };
@@ -128,9 +129,10 @@ class VerCapExtAttr
   /// Returns the capabilities as an integer array attribute.
   ArrayAttr getCapabilitiesAttr();
 
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              IntegerAttr version, ArrayAttr capabilities,
-                              ArrayAttr extensions);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   IntegerAttr version, ArrayAttr capabilities,
+                   ArrayAttr extensions);
 
   static constexpr StringLiteral name = "spirv.ver_cap_ext";
 };
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 55f0c787b4440..e2d04553d91b8 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -258,8 +258,9 @@ class SampledImageType
   static SampledImageType
   getChecked(function_ref<InFlightDiagnostic()> emitError, Type imageType);
 
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              Type imageType);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   Type imageType);
 
   Type getImageType() const;
 
@@ -462,8 +463,9 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
   static MatrixType getChecked(function_ref<InFlightDiagnostic()> emitError,
                                Type columnType, uint32_t columnCount);
 
-  static LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
-                              Type columnType, uint32_t columnCount);
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                   Type columnType, uint32_t columnCount);
 
   /// Returns true if the matrix elements are vectors of float elements.
   static bool isValidColumnType(Type columnType);
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 5b6ec167fa242..4d3e1428e6c40 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -180,6 +180,7 @@ class AnyTypeOf<list<Type> allowedTypeList, string summary = "",
         summary),
     cppClassName> {
   list<Type> allowedTypes = allowedTypeList;
+  string cppType = cppClassName;
 }
 
 // A type that satisfies the constraints of all given types.
diff --git a/mlir/include/mlir/IR/Constraints.td b/mlir/include/mlir/IR/Constraints.td
index a026d58ccffb8..242c850f38f30 100644
--- a/mlir/include/mlir/IR/Constraints.td
+++ b/mlir/include/mlir/IR/Constraints.td
@@ -153,6 +153,9 @@ class TypeConstraint<Pred predicate, string summary = "",
     Constraint<predicate, summary> {
   // The name of the C++ Type class if known, or Type if not.
   string cppClassName = cppClassNameParam;
+  // TODO: This field is sometimes called `cppClassName` and sometimes
+  // `cppType`. Use a single name consistently.
+  string cppType = cppClassNameParam;
 }
 
 // Subclass for constraints on an attribute.
diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h
index fb64f15162df5..d6ccbbd857994 100644
--- a/mlir/include/mlir/IR/StorageUniquerSupport.h
+++ b/mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -176,8 +176,8 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
   template <typename... Args>
   static ConcreteT get(MLIRContext *ctx, Args &&...args) {
     // Ensure that the invariants are correct for construction.
-    assert(
-        succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...)));
+    assert(succeeded(
+        ConcreteT::verifyInvariants(getDefaultDiagnosticEmitFn(ctx), args...)));
     return UniquerT::template get<ConcreteT>(ctx, std::forward<Args>(args)...);
   }
 
@@ -198,7 +198,7 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
   static ConcreteT getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
                               MLIRContext *ctx, Args... args) {
     // If the construction invariants fail then we return a null attribute.
-    if (failed(ConcreteT::verify(emitErrorFn, args...)))
+    if (failed(ConcreteT::verifyInvariants(emitErrorFn, args...)))
       return ConcreteT();
     return UniquerT::template get<ConcreteT>(ctx, args...);
   }
@@ -226,7 +226,9 @@ class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
 
   /// Default implementation that just returns success.
   template <typename... Args>
-  static LogicalResult verify(Args... args) {
+  static LogicalResult
+  verifyInvariants(function_ref<InFlightDiagnostic()> emitErrorFn,
+                   Args... args) {
     return success();
   }
 
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 60dc8fee0f4a9..91b457deeba2f 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -34,7 +34,7 @@ class AsmState;
 /// Derived type classes are expected to implement several required
 /// implementation hooks:
 ///  * Optional:
-///    - static LogicalResult verify(
+///    - static LogicalResult verifyInvariants(
 ///                                function_ref<InFlightDiagnostic()> emitError,
 ///                                Args... args)
 ///      * This method is invoked when calling the 'TypeBase::get/getChecked'
@@ -97,20 +97,17 @@ class Type {
   bool operator!() const { return impl == nullptr; }
 
   template <typename... Tys>
-  [[deprecated("Use mlir::isa<U>() instead")]]
-  bool isa() const;
+  [[deprecated("Use mlir::isa<U>() instead")]] bool isa() const;
   template <typename... Tys>
-  [[deprecated("Use mlir::isa_and_nonnull<U>() instead")]]
-  bool isa_and_nonnull() const;
+  [[deprecated("Use mlir::isa_and_nonnull<U>() instead")]] bool
+  isa_and_nonnull() const;
   template <typename U>
-  [[deprecated("Use mlir::dyn_cast<U>() instead")]]
-  U dyn_cast() const;
+  [[deprecated("Use mlir::dyn_cast<U>() instead")]] U dyn_cast() const;
   template <typename U>
-  [[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]]
-  U dyn_cast_or_null() const;
+  [[deprecated("Use mlir::dyn_cast_or_null<U>() instead")]] U
+  dyn_cast_or_null() const;
   template <typename U>
-  [[deprecated("Use mlir::cast<U>() instead")]]
-  U cast() const;
+  [[deprecated("Use mlir::cast<U>() instead")]] U cast() const;
 
   /// Return a unique identifier for the concrete type. This is used to support
   /// dynamic type casting.
diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
index 19c3a9183ec2c..c7819ef7c0ffa 100644
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -15,7 +15,9 @@
 #define MLIR_TABLEGEN_ATTRORTYPEDEF_H
 
 #include "mlir/Support/LLVM.h"
+#include "mlir/TableGen/Argument.h"
 #include "mlir/TableGen/Builder.h"
+#include "mlir/TableGen/Predicate.h"
 #include "mlir/TableGen/Trait.h"
 
 namespace llvm {
@@ -85,6 +87,8 @@ class AttrOrTypeParameter {
   /// Get an optional C++ parameter parser.
   std::optional<StringRef> getParser() const;
 
+  std::optional<TypeConstraint> getTypeConstraint() const;
+
   /// Get an optional C++ parameter printer.
   std::optional<StringRef> getPrinter() const;
 
@@ -198,6 +202,8 @@ class AttrOrTypeDef {
   /// method.
   bool genVerifyDecl() const;
 
+  bool genVerifyInvariantsImpl() const;
+
   /// Returns the def's extra class declaration code.
   std::optional<StringRef> getExtraDecls() const;
 
diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h
index 855952d19492d..f750a34a3b2ba 100644
--- a/mlir/include/mlir/TableGen/Class.h
+++ b/mlir/include/mlir/TableGen/Class.h
@@ -67,6 +67,8 @@ class MethodParameter {
 
   /// Get the C++ type.
   StringRef getType() const { return type; }
+  /// Get the C++ parameter name.
+  StringRef getName() const { return name; }
   /// Returns true if the parameter has a default value.
   bool hasDefaultValue() const { return !defaultValue.empty(); }
 
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 7bc2668310ddb..a1f87a637a614 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -148,9 +148,9 @@ bool MMAMatrixType::isValidElementType(Type elementType) {
 }
 
 LogicalResult
-MMAMatrixType::verify(function_ref<InFlightDiagnostic()> emitError,
-                      ArrayRef<int64_t> shape, Type elementType,
-                      StringRef operand) {
+MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                                ArrayRef<int64_t> shape, Type elementType,
+                                StringRef operand) {
   if (operand != "AOp" && operand != "BOp" && operand != "COp")
     return emitError() << "operand expected to be one of AOp, BOp or COp";
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
index dc7aef8ef7f85..7f10a15ff31ff 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -418,8 +418,7 @@ LogicalResult LLVMPointerType::verifyEntries(DataLayoutEntryListRef entries,
 
 bool LLVMStructType::isValidElementType(Type type) {
   return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType,
-                    LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>(
-      type);
+                    LLVMFunctionType, LLVMTokenType>(type);
 }
 
 LLVMStructType LLVMStructType::getIdentified(MLIRContext *context,
@@ -492,14 +491,15 @@ ArrayRef<Type> LLVMStructType::getBody() const {
                         : getImpl()->getTypeList();
 }
 
-LogicalResult LLVMStructType::verify(function_ref<InFlightDiagnostic()>,
-                                     StringRef, bool) {
+LogicalResult
+LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()>, StringRef,
+                                 bool) {
   return success();
 }
 
 LogicalResult
-LLVMStructType::verify(function_ref<InFlightDiagnostic()> emitError,
-                       ArrayRef<Type> types, bool) {
+LLVMStructType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                                 ArrayRef<Type> types, bool) {
   for (Type t : types)
     if (!isValidElementType(t))
       return emitError() << "invalid LLVM structure element type: " << t;
diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
index 81e3b914755be..c2ba9c04e8771 100644
--- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp
@@ -29,9 +29,10 @@ bool QuantizedType::classof(Type type) {
 }
 
 LogicalResult
-QuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
-                      unsigned flags, Type storageType, Type expressedType,
-                      int64_t storageTypeMin, int64_t storageTypeMax) {
+QuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                                unsigned flags, Type storageType,
+                                Type expressedType, int64_t storageTypeMin,
+                                int64_t storageTypeMax) {
   // Verify that the storage type is integral.
   // This restriction may be lifted at some point in favor of using bf16
   // or f16 as exact representations on hardware where that is advantageous.
@@ -233,11 +234,13 @@ AnyQuantizedType::getChecked(function_ref<InFlightDiagnostic()> emitError,
 }
 
 LogicalResult
-AnyQuantizedType::verify(function_ref<InFlightDiagnostic()> emitError,
-                         unsigned flags, Type storageType, Type expressedType,
-                         int64_t storageTypeMin, int64_t storageTypeMax) {
-  if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
-                                   storageTypeMin, storageTypeMax))) {
+AnyQuantizedType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+                                   unsigned flags, Type storageType,
+                                   Type expressedType, int64_t storageTypeMin,
+                                   int64_t storageTypeMax) {
+  if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType,
+                                             expressedType, storageTypeMin,
+                                             storageTypeMax))) {
     return failure();
   }
 
@@ -268,12 +271,13 @@ UniformQuantizedType UniformQuantizedType::getChecked(
                           storageTypeMin, storageTypeMax);
 }
 
-LogicalResult UniformQuantizedType::verify(
+LogicalResult UniformQuantizedType::verifyInvariants(
     function_ref<InFlightDiagnostic()> emitError, unsigned flags,
     Type storageType, Type expressedType, double scale, int64_t zeroPoint,
     int64_t storageTypeMin, int64_t storageTypeMax) {
-  if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType,
-                                   storageTypeMin, storageTypeMax))) {
+  if (failed(QuantizedType::verifyInvariants(emit...
[truncated]

@matthias-springer matthias-springer force-pushed the users/matthias-springer/ods_type_verifier branch from 8df1c42 to 2134ae9 Compare August 7, 2024 16:04
@matthias-springer matthias-springer force-pushed the users/matthias-springer/ods_type_verifier branch from 2134ae9 to bb21067 Compare August 7, 2024 16:16
Copy link
Contributor

@River707 River707 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the plan for attribute constraints? Are you intending to do that in a followup? Or can you just add it here? Would be good to have parity between the two.

@River707
Copy link
Contributor

River707 commented Aug 7, 2024

This looks quite nice though! It's good to have more parity between ops/non-ops. Can you also add some documentation to the attr/type definition page detailing that constraints can also be used in parameters? (not sure if there is a good spot for that)

@matthias-springer matthias-springer force-pushed the users/matthias-springer/ods_type_verifier branch from bb21067 to f339d44 Compare August 7, 2024 18:45
@matthias-springer
Copy link
Member Author

What is the plan for attribute constraints? Are you intending to do that in a followup? Or can you just add it here? Would be good to have parity between the two.

I'm going to take a look tmr. I think it's very similar and can be done in this PR.

Copy link
Member

@jpienaar jpienaar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good! It's nice and compact :-) River made good point and good to have together as you proposed.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/ods_type_verifier branch from f339d44 to 3549817 Compare August 8, 2024 09:10
@llvmbot llvmbot added the mlir:sparse Sparse compiler in MLIR label Aug 8, 2024
@matthias-springer matthias-springer force-pushed the users/matthias-springer/ods_type_verifier branch 2 times, most recently from 36730cf to 33db3e6 Compare August 8, 2024 09:47
@matthias-springer
Copy link
Member Author

matthias-springer commented Aug 8, 2024

I took a look at attribute constraints. My implementation already generates the verifier, but a different part of the TableGen setup does not support it.

Example:

def TestAttrVerification : Test_Attr<"TestAttrVerification"> {
  let parameters = (ins I32ArrayAttr:$param);
  let mnemonic = "attr_verification";
  let assemblyFormat = "`<` $param `>`";
}
llvm-project/mlir/include/mlir/IR/CommonAttrConstraints.td:600:5: error: Missing `cppType` field in Attribute/Type parameter: I32ArrayAttr
def I32ArrayAttr : TypedArrayAttrBase<I32Attr,

The reason is that AttributeConstraint is slightly different from TypeConstraint: it does not have a cppType/cppClassName field. Adding that could be a larger change, so I'd rather do that in a separate commit.

I added another test to attr-or-type-format.td that shows that a verifier is generated for parameters that are attributes (AttrDef).

@matthias-springer matthias-springer force-pushed the users/matthias-springer/ods_type_verifier branch from ec0d3fb to 27f3ffa Compare August 9, 2024 18:52
@matthias-springer matthias-springer changed the base branch from main to users/matthias-springer/td_type_cpp_type August 9, 2024 18:52
Base automatically changed from users/matthias-springer/td_type_cpp_type to main August 9, 2024 19:53
When a type/attribute is defined in TableGen, a type constraint can be used for parameters, but the type constraint verification was missing.

Example:
```
def TestTypeVerification : Test_Type<"TestTypeVerification"> {
  let parameters = (ins AnyTypeOf<[I16, I32]>:$param);
  // ...
}
```

No verification code was generated to ensure that `$param` is I16 or I32.

When type constraints a present, a new method will generated for types and attributes: `verifyInvariantsImpl`. (The naming is similar to op verifiers.) The user-provided verifier is called `verify` (no change). There is now a new entry point to type/attribute verification: `verifyInvariants`. This function calls both `verifyInvariantsImpl` and `verify`. If neither of those two verifications are present, the `verifyInvariants` function is not generated.

When a type/attribute is not defined in TableGen, but a verifier is needed, users can implement the `verifyInvariants` function. (This function was previously called `verify`.)
@matthias-springer matthias-springer force-pushed the users/matthias-springer/ods_type_verifier branch from 27f3ffa to a82692f Compare August 9, 2024 19:54
@matthias-springer matthias-springer merged commit 7359a6b into main Aug 9, 2024
5 of 7 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/ods_type_verifier branch August 9, 2024 20:04
kutemeikito added a commit to kutemeikito/llvm-project that referenced this pull request Aug 10, 2024
* 'main' of https://github.com/llvm/llvm-project: (700 commits)
  [SandboxIR][NFC] SingleLLVMInstructionImpl class (llvm#102687)
  [ThinLTO]Clean up 'import-assume-unique-local' flag. (llvm#102424)
  [nsan] Make #include more conventional
  [SandboxIR][NFC] Use Tracker.emplaceIfTracking()
  [libc]  Moved range_reduction_double ifdef statement (llvm#102659)
  [libc] Fix CFP long double and add tests (llvm#102660)
  [TargetLowering] Handle vector types in expandFixedPointMul (llvm#102635)
  [compiler-rt][NFC] Replace environment variable with %t (llvm#102197)
  [UnitTests] Convert a test to use opaque pointers (llvm#102668)
  [CodeGen][NFCI] Don't re-implement parts of ASTContext::getIntWidth (llvm#101765)
  [SandboxIR] Clean up tracking code with the help of emplaceIfTracking() (llvm#102406)
  [mlir][bazel] remove extra blanks in mlir-tblgen test
  [NVPTX][NFC] Update tests to use bfloat type (llvm#101493)
  [mlir] Add support for parsing nested PassPipelineOptions (llvm#101118)
  [mlir][bazel] add missing td dependency in mlir-tblgen test
  [flang][cuda] Fix lib dependency
  [libc] Clean up remaining use of *_WIDTH macros in printf (llvm#102679)
  [flang][cuda] Convert cuf.alloc for box to fir.alloca in device context (llvm#102662)
  [SandboxIR] Implement the InsertElementInst class (llvm#102404)
  [libc] Fix use of cpp::numeric_limits<...>::digits (llvm#102674)
  [mlir][ODS] Verify type constraints in Types and Attributes (llvm#102326)
  [LTO] enable `ObjCARCContractPass` only on optimized build  (llvm#101114)
  [mlir][ODS] Consistent `cppType` / `cppClassName` usage (llvm#102657)
  [lldb] Move definition of SBSaveCoreOptions dtor out of header (llvm#102539)
  [libc] Use cpp::numeric_limits in preference to C23 <limits.h> macros (llvm#102665)
  [clang] Implement -fptrauth-auth-traps. (llvm#102417)
  [LLVM][rtsan] rtsan transform to preserve CFGAnalyses (llvm#102651)
  Revert "[AMDGPU] Move `AMDGPUAttributorPass` to full LTO post link stage (llvm#102086)"
  [RISCV][GISel] Add missing tests for G_CTLZ/CTTZ instruction selection. NFC
  Return available function types for BindingDecls. (llvm#102196)
  [clang] Wire -fptrauth-returns to "ptrauth-returns" fn attribute. (llvm#102416)
  [RISCV] Remove riscv-experimental-rv64-legal-i32. (llvm#102509)
  [RISCV] Move PseudoVSET(I)VLI expansion to use PseudoInstExpansion. (llvm#102496)
  [NVPTX] support switch statement with brx.idx (reland) (llvm#102550)
  [libc][newhdrgen]sorted function names in yaml (llvm#102544)
  [GlobalIsel] Combine G_ADD and G_SUB with constants (llvm#97771)
  Suppress spurious warnings due to R_RISCV_SET_ULEB128
  [scudo] Separated committed and decommitted entries. (llvm#101409)
  [MIPS] Fix missing ANDI optimization (llvm#97689)
  [Clang] Add env var for nvptx-arch/amdgpu-arch timeout (llvm#102521)
  [asan] Switch allocator to dynamic base address (llvm#98511)
  [AMDGPU] Move `AMDGPUAttributorPass` to full LTO post link stage (llvm#102086)
  [libc][math][c23] Add fadd{l,f128} C23 math functions (llvm#102531)
  [mlir][bazel] revert bazel rule change for DLTITransformOps
  [msan] Support vst{2,3,4}_lane instructions (llvm#101215)
  Revert "[MLIR][DLTI][Transform] Introduce transform.dlti.query (llvm#101561)"
  [X86] pr57673.ll - generate MIR test checks
  [mlir][vector][test] Split tests from vector-transfer-flatten.mlir (llvm#102584)
  [mlir][bazel] add bazel rule for DLTITransformOps
  OpenMPOpt: Remove dead include
  [IR] Add method to GlobalVariable to change type of initializer. (llvm#102553)
  [flang][cuda] Force default allocator in device code (llvm#102238)
  [llvm] Construct SmallVector<SDValue> with ArrayRef (NFC) (llvm#102578)
  [MLIR][DLTI][Transform] Introduce transform.dlti.query (llvm#101561)
  [AMDGPU][AsmParser][NFC] Remove a misleading comment. (llvm#102604)
  [Arm][AArch64][Clang] Respect function's branch protection attributes. (llvm#101978)
  [mlir] Verifier: steal bit to track seen instead of set. (llvm#102626)
  [Clang] Fix Handling of Init Capture with Parameter Packs in LambdaScopeForCallOperatorInstantiationRAII (llvm#100766)
  [X86] Convert truncsat clamping patterns to use SDPatternMatch. NFC.
  [gn] Give two scripts argparse.RawDescriptionHelpFormatter
  [bazel] Add missing dep for the SPIRVToLLVM target
  [Clang] Simplify specifying passes via -Xoffload-linker (llvm#102483)
  [bazel] Port for d45de80
  [SelectionDAG] Use unaligned store/load to move AVX registers onto stack for `insertelement` (llvm#82130)
  [Clang][OMPX] Add the code generation for multi-dim `num_teams` (llvm#101407)
  [ARM] Regenerate big-endian-vmov.ll. NFC
  [AMDGPU][AsmParser][NFCI] All NamedIntOperands to be of the i32 type. (llvm#102616)
  [libc][math][c23] Add totalorderl function. (llvm#102564)
  [mlir][spirv] Support `memref` in `convert-to-spirv` pass (llvm#102534)
  [MLIR][GPU-LLVM] Convert `gpu.func` to `llvm.func` (llvm#101664)
  Fix a unit test input file (llvm#102567)
  [llvm-readobj][COFF] Dump hybrid objects for ARM64X files. (llvm#102245)
  AMDGPU/NewPM: Port SIFixSGPRCopies to new pass manager (llvm#102614)
  [MemoryBuiltins] Simplify getCalledFunction() helper (NFC)
  [AArch64] Add invalid 1 x vscale costs for reductions and reduction-operations. (llvm#102105)
  [MemoryBuiltins] Handle allocator attributes on call-site
  LSV/test/AArch64: add missing lit.local.cfg; fix build (llvm#102607)
  Revert "Enable logf128 constant folding for hosts with 128bit floats (llvm#96287)"
  [RISCV] Add Syntacore SCR5 RV32/64 processors definition (llvm#102285)
  [InstCombine] Remove unnecessary RUN line from test (NFC)
  [flang][OpenMP] Handle multiple ranges in `num_teams` clause (llvm#102535)
  [mlir][vector] Add tests for scalable vectors in one-shot-bufferize.mlir (llvm#102361)
  [mlir][vector] Disable `vector.matrix_multiply` for scalable vectors (llvm#102573)
  [clang] Implement CWG2627 Bit-fields and narrowing conversions (llvm#78112)
  [NFC] Use references to avoid copying (llvm#99863)
  Revert "[mlir][ArmSME] Pattern to swap shape_cast(tranpose) with transpose(shape_cast) (llvm#100731)" (llvm#102457)
  [IRBuilder] Generate nuw GEPs for struct member accesses (llvm#99538)
  [bazel] Port for 9b06e25
  [CodeGen][NewPM] Improve start/stop pass error message CodeGenPassBuilder (llvm#102591)
  [AArch64] Implement TRBMPAM_EL1 system register (llvm#102485)
  [InstCombine] Fixing wrong select folding in vectors with undef elements (llvm#102244)
  [AArch64] Sink operands to fmuladd. (llvm#102297)
  LSV: document hang reported in llvm#37865 (llvm#102479)
  Enable logf128 constant folding for hosts with 128bit floats (llvm#96287)
  [RISCV][clang] Remove bfloat base type in non-zvfbfmin vcreate (llvm#102146)
  [RISCV][clang] Add missing `zvfbfmin` to `vget_v` intrinsic (llvm#102149)
  [mlir][vector] Add mask elimination transform (llvm#99314)
  [Clang][Interp] Fix display of syntactically-invalid note for member function calls (llvm#102170)
  [bazel] Port for 3fffa6d
  [DebugInfo][RemoveDIs] Use iterator-inserters in clang (llvm#102006)
  ...

Signed-off-by: Edwiin Kusuma Jaya <kutemeikito0905@gmail.com>
matthias-springer added a commit that referenced this pull request Aug 12, 2024
…02449)

#102326 enables verification of type parameters that are type
constraints. The element type verification for `VectorType` (and maybe
other builtin types in the future) can now be auto-generated.

Also remove redundant error checking in the vector type parser: element
type and dimensions are already checked by the verifier (which is called
from `getChecked`).

Depends on #102326.
bchetioui added a commit to bchetioui/llvm-project that referenced this pull request Aug 13, 2024
…on of verifyInvariants.

PR llvm#102326 changed the prototype of the default implementation of
verify to include emitErrorFn.

This breaks automatic derivation in consumer attributes, such as
https://github.com/tensorflow/runtime/blob/60277ba976739502e45ad26585e071568fa44af1/include/tfrt/core_runtime/opdefs/attributes.h#L53.

This PR simply restores the signature to what it was prior to PR llvm#102326.
bchetioui added a commit to bchetioui/llvm-project that referenced this pull request Aug 13, 2024
…on of verifyInvariants.

PR llvm#102326 changed the prototype of the default implementation of
verify to include emitErrorFn.

This breaks automatic derivation in consumer attributes, such as
https://github.com/tensorflow/runtime/blob/60277ba976739502e45ad26585e071568fa44af1/include/tfrt/core_runtime/opdefs/attributes.h#L53.

This PR simply restores the signature to what it was prior to PR llvm#102326.
bchetioui added a commit that referenced this pull request Aug 13, 2024
…n of verifyInvariants. (#103023)

PR #102326 changed the prototype of the default implementation of verify
to include emitErrorFn.

This breaks automatic derivation in consumer attributes, such as
https://github.com/tensorflow/runtime/blob/60277ba976739502e45ad26585e071568fa44af1/include/tfrt/core_runtime/opdefs/attributes.h#L53.

This PR simply restores the signature to what it was prior to PR
#102326.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants