Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated the annotations of Python bindings #92733

Merged
merged 1 commit into from
Jun 11, 2024

Conversation

superbobry
Copy link
Contributor

@superbobry superbobry commented May 20, 2024

  • pybind11 allows passing both str and bytes to std::string parameters.
  • pybind11 allows passing any Sequence to std::vector parameters. I only updated the signatures of *ArrayAttr.get methods.
  • Sliceable implicitly adds __getitem__ and __len__.
  • Dense*ElementsAttr now overload get. This ensures that e.g. DenseIntElementsAttr.get returns DenseIntElementsAttr, and not just DenseElementsAttr.
  • # value = <mlir._mlir_libs._TypeID object> comments are unnecessary.

* pybind11 allows passing both str and bytes to std::string parameters.
* pybind11 allows passing any Sequence to std::vector parameters.
  I only updated the signatures of *ArrayAttr.get methods.
* Sliceable implicitly adds __getitem__ and __len__.
* Dense*ElementsAttr now overload get. This ensures that e.g.
  DenseIntElementsAttr.get returns DenseIntElementsAttr, and not just
  DenseElementsAttr.
* `# value = <mlir._mlir_libs._TypeID object>` comments are unncessary.
@llvmbot
Copy link
Collaborator

llvmbot commented May 20, 2024

@llvm/pr-subscribers-mlir

Author: Sergei Lebedev (superbobry)

Changes
  • pybind11 allows passing both str and bytes to std::string parameters.
  • pybind11 allows passing any Sequence to std::vector parameters. I only updated the signatures of *ArrayAttr.get methods.
  • Sliceable implicitly adds __getitem__ and __len__.
  • Dense*ElementsAttr now overload get. This ensures that e.g. DenseIntElementsAttr.get returns DenseIntElementsAttr, and not just DenseElementsAttr.
  • # value = &lt;mlir._mlir_libs._TypeID object&gt; comments are unncessary.

Full diff: https://github.com/llvm/llvm-project/pull/92733.diff

1 Files Affected:

  • (modified) mlir/python/mlir/_mlir_libs/_mlir/ir.pyi (+72-46)
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 586bf7f8e93fb..1e1b2a8348b1d 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -479,7 +479,7 @@ class AffineExpr:
 
 class Attribute:
     @staticmethod
-    def parse(asm: str, context: Optional[Context] = None) -> Attribute:
+    def parse(asm: str | bytes, context: Optional[Context] = None) -> Attribute:
         """
         Parses an attribute from an assembly form. Raises an MLIRError on failure.
         """
@@ -520,7 +520,7 @@ class Attribute:
 
 class Type:
     @staticmethod
-    def parse(asm: str, context: Optional[Context] = None) -> Type:
+    def parse(asm: str | bytes, context: Optional[Context] = None) -> Type:
         """
         Parses the assembly form of a type.
 
@@ -741,7 +741,7 @@ class AffineMap:
     def results(self) -> "AffineMapExprList": ...
 
 class AffineMapAttr(Attribute):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(affine_map: AffineMap) -> AffineMapAttr:
         """
@@ -779,7 +779,7 @@ class AffineSymbolExpr(AffineExpr):
     def position(self) -> int: ...
 
 class ArrayAttr(Attribute):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(attributes: List, context: Optional[Context] = None) -> ArrayAttr:
         """
@@ -823,7 +823,7 @@ class AttrBuilder:
         """
 
 class BF16Type(Type):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(context: Optional[Context] = None) -> BF16Type:
         """
@@ -909,6 +909,11 @@ class BlockArgument(Value):
     def owner(self) -> Block: ...
 
 class BlockArgumentList:
+    @overload
+    def __getitem__(self, arg0: int) -> BlockArgument: ...
+    @overload
+    def __getitem__(self, arg0: slice) -> BlockArgumentList: ...
+    def __len__(self) -> int: ...
     def __add__(self, arg0: BlockArgumentList) -> List[BlockArgument]: ...
     @property
     def types(self) -> List[Type]: ...
@@ -955,7 +960,7 @@ class BoolAttr(Attribute):
         """
 
 class ComplexType(Type):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(arg0: Type) -> ComplexType:
         """
@@ -1016,7 +1021,7 @@ class Context:
 class DenseBoolArrayAttr(Attribute):
     @staticmethod
     def get(
-        values: List[bool], context: Optional[Context] = None
+        values: Sequence[bool], context: Optional[Context] = None
     ) -> DenseBoolArrayAttr:
         """
         Gets a uniqued dense array attribute
@@ -1113,7 +1118,7 @@ class DenseElementsAttr(Attribute):
 class DenseF32ArrayAttr(Attribute):
     @staticmethod
     def get(
-        values: List[float], context: Optional[Context] = None
+        values: Sequence[float], context: Optional[Context] = None
     ) -> DenseF32ArrayAttr:
         """
         Gets a uniqued dense array attribute
@@ -1141,7 +1146,7 @@ class DenseF32ArrayIterator:
 class DenseF64ArrayAttr(Attribute):
     @staticmethod
     def get(
-        values: List[float], context: Optional[Context] = None
+        values: Sequence[float], context: Optional[Context] = None
     ) -> DenseF64ArrayAttr:
         """
         Gets a uniqued dense array attribute
@@ -1167,6 +1172,14 @@ class DenseF64ArrayIterator:
     def __next__(self) -> float: ...
 
 class DenseFPElementsAttr(DenseElementsAttr):
+    @staticmethod
+    def get(
+        array: Buffer,
+        signless: bool = True,
+        type: Optional[Type] = None,
+        shape: Optional[List[int]] = None,
+        context: Optional[Context] = None,
+    ) -> DenseFPElementsAttr: ...
     @staticmethod
     def isinstance(other: Attribute) -> bool: ...
     def __getitem__(self, arg0: int) -> float: ...
@@ -1180,7 +1193,7 @@ class DenseFPElementsAttr(DenseElementsAttr):
 
 class DenseI16ArrayAttr(Attribute):
     @staticmethod
-    def get(values: List[int], context: Optional[Context] = None) -> DenseI16ArrayAttr:
+    def get(values: Sequence[int], context: Optional[Context] = None) -> DenseI16ArrayAttr:
         """
         Gets a uniqued dense array attribute
         """
@@ -1206,7 +1219,7 @@ class DenseI16ArrayIterator:
 
 class DenseI32ArrayAttr(Attribute):
     @staticmethod
-    def get(values: List[int], context: Optional[Context] = None) -> DenseI32ArrayAttr:
+    def get(values: Sequence[int], context: Optional[Context] = None) -> DenseI32ArrayAttr:
         """
         Gets a uniqued dense array attribute
         """
@@ -1232,7 +1245,7 @@ class DenseI32ArrayIterator:
 
 class DenseI64ArrayAttr(Attribute):
     @staticmethod
-    def get(values: List[int], context: Optional[Context] = None) -> DenseI64ArrayAttr:
+    def get(values: Sequence[int], context: Optional[Context] = None) -> DenseI64ArrayAttr:
         """
         Gets a uniqued dense array attribute
         """
@@ -1258,7 +1271,7 @@ class DenseI64ArrayIterator:
 
 class DenseI8ArrayAttr(Attribute):
     @staticmethod
-    def get(values: List[int], context: Optional[Context] = None) -> DenseI8ArrayAttr:
+    def get(values: Sequence[int], context: Optional[Context] = None) -> DenseI8ArrayAttr:
         """
         Gets a uniqued dense array attribute
         """
@@ -1283,6 +1296,14 @@ class DenseI8ArrayIterator:
     def __next__(self) -> int: ...
 
 class DenseIntElementsAttr(DenseElementsAttr):
+    @staticmethod
+    def get(
+        array: Buffer,
+        signless: bool = True,
+        type: Optional[Type] = None,
+        shape: Optional[List[int]] = None,
+        context: Optional[Context] = None,
+    ) -> DenseIntElementsAttr: ...
     @staticmethod
     def isinstance(other: Attribute) -> bool: ...
     def __getitem__(self, arg0: int) -> int: ...
@@ -1422,7 +1443,7 @@ class Dialects:
     def __getitem__(self, arg0: str) -> Dialect: ...
 
 class DictAttr(Attribute):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(value: Dict = {}, context: Optional[Context] = None) -> DictAttr:
         """
@@ -1453,7 +1474,7 @@ class FloatType(Type):
         """
 
 class F16Type(FloatType):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(context: Optional[Context] = None) -> F16Type:
         """
@@ -1466,7 +1487,7 @@ class F16Type(FloatType):
     def typeid(self) -> TypeID: ...
 
 class F32Type(FloatType):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(context: Optional[Context] = None) -> F32Type:
         """
@@ -1479,7 +1500,7 @@ class F32Type(FloatType):
     def typeid(self) -> TypeID: ...
 
 class F64Type(FloatType):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(context: Optional[Context] = None) -> F64Type:
         """
@@ -1513,7 +1534,7 @@ class FlatSymbolRefAttr(Attribute):
         """
 
 class Float8E4M3B11FNUZType(FloatType):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(context: Optional[Context] = None) -> Float8E4M3B11FNUZType:
         """
@@ -1526,7 +1547,7 @@ class Float8E4M3B11FNUZType(FloatType):
     def typeid(self) -> TypeID: ...
 
 class Float8E4M3FNType(FloatType):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(context: Optional[Context] = None) -> Float8E4M3FNType:
         """
@@ -1539,7 +1560,7 @@ class Float8E4M3FNType(FloatType):
     def typeid(self) -> TypeID: ...
 
 class Float8E4M3FNUZType(FloatType):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(context: Optional[Context] = None) -> Float8E4M3FNUZType:
         """
@@ -1552,7 +1573,7 @@ class Float8E4M3FNUZType(FloatType):
     def typeid(self) -> TypeID: ...
 
 class Float8E5M2FNUZType(FloatType):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(context: Optional[Context] = None) -> Float8E5M2FNUZType:
         """
@@ -1565,7 +1586,7 @@ class Float8E5M2FNUZType(FloatType):
     def typeid(self) -> TypeID: ...
 
 class Float8E5M2Type(FloatType):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(context: Optional[Context] = None) -> Float8E5M2Type:
         """
@@ -1578,7 +1599,7 @@ class Float8E5M2Type(FloatType):
     def typeid(self) -> TypeID: ...
 
 class FloatAttr(Attribute):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(type: Type, value: float, loc: Optional[Location] = None) -> FloatAttr:
         """
@@ -1612,7 +1633,7 @@ class FloatAttr(Attribute):
         """
 
 class FloatTF32Type(FloatType):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(context: Optional[Context] = None) -> FloatTF32Type:
         """
@@ -1625,7 +1646,7 @@ class FloatTF32Type(FloatType):
     def typeid(self) -> TypeID: ...
 
 class FunctionType(Type):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(
         inputs: List[Type], results: List[Type], context: Optional[Context] = None
@@ -1650,7 +1671,7 @@ class FunctionType(Type):
     def typeid(self) -> TypeID: ...
 
 class IndexType(Type):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(context: Optional[Context] = None) -> IndexType:
         """
@@ -1766,7 +1787,7 @@ class InsertionPoint:
         """
 
 class IntegerAttr(Attribute):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(type: Type, value: int) -> IntegerAttr:
         """
@@ -1855,7 +1876,7 @@ class IntegerSetConstraintList:
     def __len__(self) -> int: ...
 
 class IntegerType(Type):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get_signed(width: int, context: Optional[Context] = None) -> IntegerType:
         """
@@ -1967,7 +1988,7 @@ class Location:
         """
 
 class MemRefType(ShapedType):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(
         shape: List[int],
@@ -2007,7 +2028,7 @@ class Module:
         Creates an empty module
         """
     @staticmethod
-    def parse(asm: str, context: Optional[Context] = None) -> Module:
+    def parse(asm: str | bytes, context: Optional[Context] = None) -> Module:
         """
         Parses a module's assembly format from a string.
 
@@ -2064,7 +2085,7 @@ class NamedAttribute:
         """
 
 class NoneType(Type):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(context: Optional[Context] = None) -> NoneType:
         """
@@ -2130,7 +2151,12 @@ class OpResultList:
 
 class OpSuccessors:
     def __add__(self, arg0: OpSuccessors) -> List[Block]: ...
+    @overload
+    def __getitem__(self, arg0: int) -> Block: ...
+    @overload
+    def __getitem__(self, arg0: slice) -> OpSuccessors: ...
     def __setitem__(self, arg0: int, arg1: Block) -> None: ...
+    def __len__(self) -> int: ...
 
 class OpView(_OperationBase):
     _ODS_OPERAND_SEGMENTS: ClassVar[None] = ...
@@ -2154,7 +2180,7 @@ class OpView(_OperationBase):
     @classmethod
     def parse(
         cls: _Type[_TOperation],
-        source: str,
+        source: str | bytes,
         *,
         source_name: str = "",
         context: Optional[Context] = None,
@@ -2174,7 +2200,7 @@ class OpView(_OperationBase):
         """
 
 class OpaqueAttr(Attribute):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(
         dialect_namespace: str,
@@ -2204,7 +2230,7 @@ class OpaqueAttr(Attribute):
     def typeid(self) -> TypeID: ...
 
 class OpaqueType(Type):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(
         dialect_namespace: str, buffer: str, context: Optional[Context] = None
@@ -2262,7 +2288,7 @@ class Operation(_OperationBase):
         """
     @staticmethod
     def parse(
-        source: str, *, source_name: str = "", context: Optional[Context] = None
+        source: str | bytes, *, source_name: str = "", context: Optional[Context] = None
     ) -> Operation:
         """
         Parses an operation. Supports both text assembly format and binary bytecode format.
@@ -2290,7 +2316,7 @@ class OperationList:
     def __len__(self) -> int: ...
 
 class RankedTensorType(ShapedType):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(
         shape: List[int],
@@ -2443,7 +2469,7 @@ class ShapedTypeComponents:
         """
 
 class StridedLayoutAttr(Attribute):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(
         offset: int, strides: List[int], context: Optional[Context] = None
@@ -2477,9 +2503,9 @@ class StridedLayoutAttr(Attribute):
     def typeid(self) -> TypeID: ...
 
 class StringAttr(Attribute):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
-    def get(value: str, context: Optional[Context] = None) -> StringAttr:
+    def get(value: str | bytes, context: Optional[Context] = None) -> StringAttr:
         """
         Gets a uniqued string attribute
         """
@@ -2554,9 +2580,9 @@ class SymbolTable:
     def insert(self, operation: _OperationBase) -> Attribute: ...
 
 class TupleType(Type):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
-    def get_Tuple(elements: List[Type], context: Optional[Context] = None) -> TupleType:
+    def get_tuple(elements: List[Type], context: Optional[Context] = None) -> TupleType:
         """
         Create a Tuple type
         """
@@ -2576,7 +2602,7 @@ class TupleType(Type):
     def typeid(self) -> TypeID: ...
 
 class TypeAttr(Attribute):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(value: Type, context: Optional[Context] = None) -> TypeAttr:
         """
@@ -2603,7 +2629,7 @@ class TypeID:
     def _CAPIPtr(self) -> object: ...
 
 class UnitAttr(Attribute):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(context: Optional[Context] = None) -> UnitAttr:
         """
@@ -2618,7 +2644,7 @@ class UnitAttr(Attribute):
     def typeid(self) -> TypeID: ...
 
 class UnrankedMemRefType(ShapedType):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(
         element_type: Type, memory_space: Attribute, loc: Optional[Location] = None
@@ -2638,7 +2664,7 @@ class UnrankedMemRefType(ShapedType):
     def typeid(self) -> TypeID: ...
 
 class UnrankedTensorType(ShapedType):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(element_type: Type, loc: Optional[Location] = None) -> UnrankedTensorType:
         """
@@ -2651,7 +2677,7 @@ class UnrankedTensorType(ShapedType):
     def typeid(self) -> TypeID: ...
 
 class VectorType(ShapedType):
-    static_typeid: ClassVar[TypeID]  # value = <mlir._mlir_libs._TypeID object>
+    static_typeid: ClassVar[TypeID]
     @staticmethod
     def get(
         shape: List[int],

Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@superbobry
Copy link
Contributor Author

Are there any blockers for merging this?

@stellaraccident can you review when you have a moment, please?

@makslevental
Copy link
Contributor

makslevental commented Jun 11, 2024

@superbobry

ALZHSj

ie I think you can merge.

@superbobry
Copy link
Contributor Author

I don't think I can, unfortunately:

Only those with write access to this repository can merge pull requests.

@makslevental makslevental merged commit bc5ced5 into llvm:main Jun 11, 2024
7 checks passed
@makslevental
Copy link
Contributor

I don't think I can, unfortunately:

Done (didn't realize you didn't have merge access)

@HerrCai0907 HerrCai0907 mentioned this pull request Jun 13, 2024
@superbobry superbobry deleted the piper_export_cl_635399316 branch October 1, 2024 10:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants