From 6d130eccb3abf66df4f33a78843a4d5f152dae23 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 20 May 2024 10:19:59 +0000 Subject: [PATCH] Updated the annotations of Python bindings * pybind11 allows passing both str and bytes to std::string parameters. * pybind11 allows passing any Sequence to std::vector parameters. I only updated the signatures of *ArrayAttr.get methods. * Sliceable implicitly adds __getitem__ and __len__. * Dense*ElementsAttr now overload get. This ensures that e.g. DenseIntElementsAttr.get returns DenseIntElementsAttr, and not just DenseElementsAttr. * `# value = ` comments are unncessary. --- mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 118 ++++++++++++++--------- 1 file changed, 72 insertions(+), 46 deletions(-) diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 586bf7f8e93fba..1e1b2a8348b1d7 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -479,7 +479,7 @@ class AffineExpr: class Attribute: @staticmethod - def parse(asm: str, context: Optional[Context] = None) -> Attribute: + def parse(asm: str | bytes, context: Optional[Context] = None) -> Attribute: """ Parses an attribute from an assembly form. Raises an MLIRError on failure. """ @@ -520,7 +520,7 @@ class Attribute: class Type: @staticmethod - def parse(asm: str, context: Optional[Context] = None) -> Type: + def parse(asm: str | bytes, context: Optional[Context] = None) -> Type: """ Parses the assembly form of a type. @@ -741,7 +741,7 @@ class AffineMap: def results(self) -> "AffineMapExprList": ... class AffineMapAttr(Attribute): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(affine_map: AffineMap) -> AffineMapAttr: """ @@ -779,7 +779,7 @@ class AffineSymbolExpr(AffineExpr): def position(self) -> int: ... class ArrayAttr(Attribute): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(attributes: List, context: Optional[Context] = None) -> ArrayAttr: """ @@ -823,7 +823,7 @@ class AttrBuilder: """ class BF16Type(Type): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(context: Optional[Context] = None) -> BF16Type: """ @@ -909,6 +909,11 @@ class BlockArgument(Value): def owner(self) -> Block: ... class BlockArgumentList: + @overload + def __getitem__(self, arg0: int) -> BlockArgument: ... + @overload + def __getitem__(self, arg0: slice) -> BlockArgumentList: ... + def __len__(self) -> int: ... def __add__(self, arg0: BlockArgumentList) -> List[BlockArgument]: ... @property def types(self) -> List[Type]: ... @@ -955,7 +960,7 @@ class BoolAttr(Attribute): """ class ComplexType(Type): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(arg0: Type) -> ComplexType: """ @@ -1016,7 +1021,7 @@ class Context: class DenseBoolArrayAttr(Attribute): @staticmethod def get( - values: List[bool], context: Optional[Context] = None + values: Sequence[bool], context: Optional[Context] = None ) -> DenseBoolArrayAttr: """ Gets a uniqued dense array attribute @@ -1113,7 +1118,7 @@ class DenseElementsAttr(Attribute): class DenseF32ArrayAttr(Attribute): @staticmethod def get( - values: List[float], context: Optional[Context] = None + values: Sequence[float], context: Optional[Context] = None ) -> DenseF32ArrayAttr: """ Gets a uniqued dense array attribute @@ -1141,7 +1146,7 @@ class DenseF32ArrayIterator: class DenseF64ArrayAttr(Attribute): @staticmethod def get( - values: List[float], context: Optional[Context] = None + values: Sequence[float], context: Optional[Context] = None ) -> DenseF64ArrayAttr: """ Gets a uniqued dense array attribute @@ -1167,6 +1172,14 @@ class DenseF64ArrayIterator: def __next__(self) -> float: ... class DenseFPElementsAttr(DenseElementsAttr): + @staticmethod + def get( + array: Buffer, + signless: bool = True, + type: Optional[Type] = None, + shape: Optional[List[int]] = None, + context: Optional[Context] = None, + ) -> DenseFPElementsAttr: ... @staticmethod def isinstance(other: Attribute) -> bool: ... def __getitem__(self, arg0: int) -> float: ... @@ -1180,7 +1193,7 @@ class DenseFPElementsAttr(DenseElementsAttr): class DenseI16ArrayAttr(Attribute): @staticmethod - def get(values: List[int], context: Optional[Context] = None) -> DenseI16ArrayAttr: + def get(values: Sequence[int], context: Optional[Context] = None) -> DenseI16ArrayAttr: """ Gets a uniqued dense array attribute """ @@ -1206,7 +1219,7 @@ class DenseI16ArrayIterator: class DenseI32ArrayAttr(Attribute): @staticmethod - def get(values: List[int], context: Optional[Context] = None) -> DenseI32ArrayAttr: + def get(values: Sequence[int], context: Optional[Context] = None) -> DenseI32ArrayAttr: """ Gets a uniqued dense array attribute """ @@ -1232,7 +1245,7 @@ class DenseI32ArrayIterator: class DenseI64ArrayAttr(Attribute): @staticmethod - def get(values: List[int], context: Optional[Context] = None) -> DenseI64ArrayAttr: + def get(values: Sequence[int], context: Optional[Context] = None) -> DenseI64ArrayAttr: """ Gets a uniqued dense array attribute """ @@ -1258,7 +1271,7 @@ class DenseI64ArrayIterator: class DenseI8ArrayAttr(Attribute): @staticmethod - def get(values: List[int], context: Optional[Context] = None) -> DenseI8ArrayAttr: + def get(values: Sequence[int], context: Optional[Context] = None) -> DenseI8ArrayAttr: """ Gets a uniqued dense array attribute """ @@ -1283,6 +1296,14 @@ class DenseI8ArrayIterator: def __next__(self) -> int: ... class DenseIntElementsAttr(DenseElementsAttr): + @staticmethod + def get( + array: Buffer, + signless: bool = True, + type: Optional[Type] = None, + shape: Optional[List[int]] = None, + context: Optional[Context] = None, + ) -> DenseIntElementsAttr: ... @staticmethod def isinstance(other: Attribute) -> bool: ... def __getitem__(self, arg0: int) -> int: ... @@ -1422,7 +1443,7 @@ class Dialects: def __getitem__(self, arg0: str) -> Dialect: ... class DictAttr(Attribute): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(value: Dict = {}, context: Optional[Context] = None) -> DictAttr: """ @@ -1453,7 +1474,7 @@ class FloatType(Type): """ class F16Type(FloatType): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(context: Optional[Context] = None) -> F16Type: """ @@ -1466,7 +1487,7 @@ class F16Type(FloatType): def typeid(self) -> TypeID: ... class F32Type(FloatType): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(context: Optional[Context] = None) -> F32Type: """ @@ -1479,7 +1500,7 @@ class F32Type(FloatType): def typeid(self) -> TypeID: ... class F64Type(FloatType): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(context: Optional[Context] = None) -> F64Type: """ @@ -1513,7 +1534,7 @@ class FlatSymbolRefAttr(Attribute): """ class Float8E4M3B11FNUZType(FloatType): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(context: Optional[Context] = None) -> Float8E4M3B11FNUZType: """ @@ -1526,7 +1547,7 @@ class Float8E4M3B11FNUZType(FloatType): def typeid(self) -> TypeID: ... class Float8E4M3FNType(FloatType): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(context: Optional[Context] = None) -> Float8E4M3FNType: """ @@ -1539,7 +1560,7 @@ class Float8E4M3FNType(FloatType): def typeid(self) -> TypeID: ... class Float8E4M3FNUZType(FloatType): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(context: Optional[Context] = None) -> Float8E4M3FNUZType: """ @@ -1552,7 +1573,7 @@ class Float8E4M3FNUZType(FloatType): def typeid(self) -> TypeID: ... class Float8E5M2FNUZType(FloatType): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(context: Optional[Context] = None) -> Float8E5M2FNUZType: """ @@ -1565,7 +1586,7 @@ class Float8E5M2FNUZType(FloatType): def typeid(self) -> TypeID: ... class Float8E5M2Type(FloatType): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(context: Optional[Context] = None) -> Float8E5M2Type: """ @@ -1578,7 +1599,7 @@ class Float8E5M2Type(FloatType): def typeid(self) -> TypeID: ... class FloatAttr(Attribute): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(type: Type, value: float, loc: Optional[Location] = None) -> FloatAttr: """ @@ -1612,7 +1633,7 @@ class FloatAttr(Attribute): """ class FloatTF32Type(FloatType): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(context: Optional[Context] = None) -> FloatTF32Type: """ @@ -1625,7 +1646,7 @@ class FloatTF32Type(FloatType): def typeid(self) -> TypeID: ... class FunctionType(Type): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get( inputs: List[Type], results: List[Type], context: Optional[Context] = None @@ -1650,7 +1671,7 @@ class FunctionType(Type): def typeid(self) -> TypeID: ... class IndexType(Type): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(context: Optional[Context] = None) -> IndexType: """ @@ -1766,7 +1787,7 @@ class InsertionPoint: """ class IntegerAttr(Attribute): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(type: Type, value: int) -> IntegerAttr: """ @@ -1855,7 +1876,7 @@ class IntegerSetConstraintList: def __len__(self) -> int: ... class IntegerType(Type): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get_signed(width: int, context: Optional[Context] = None) -> IntegerType: """ @@ -1967,7 +1988,7 @@ class Location: """ class MemRefType(ShapedType): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get( shape: List[int], @@ -2007,7 +2028,7 @@ class Module: Creates an empty module """ @staticmethod - def parse(asm: str, context: Optional[Context] = None) -> Module: + def parse(asm: str | bytes, context: Optional[Context] = None) -> Module: """ Parses a module's assembly format from a string. @@ -2064,7 +2085,7 @@ class NamedAttribute: """ class NoneType(Type): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(context: Optional[Context] = None) -> NoneType: """ @@ -2130,7 +2151,12 @@ class OpResultList: class OpSuccessors: def __add__(self, arg0: OpSuccessors) -> List[Block]: ... + @overload + def __getitem__(self, arg0: int) -> Block: ... + @overload + def __getitem__(self, arg0: slice) -> OpSuccessors: ... def __setitem__(self, arg0: int, arg1: Block) -> None: ... + def __len__(self) -> int: ... class OpView(_OperationBase): _ODS_OPERAND_SEGMENTS: ClassVar[None] = ... @@ -2154,7 +2180,7 @@ class OpView(_OperationBase): @classmethod def parse( cls: _Type[_TOperation], - source: str, + source: str | bytes, *, source_name: str = "", context: Optional[Context] = None, @@ -2174,7 +2200,7 @@ class OpView(_OperationBase): """ class OpaqueAttr(Attribute): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get( dialect_namespace: str, @@ -2204,7 +2230,7 @@ class OpaqueAttr(Attribute): def typeid(self) -> TypeID: ... class OpaqueType(Type): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get( dialect_namespace: str, buffer: str, context: Optional[Context] = None @@ -2262,7 +2288,7 @@ class Operation(_OperationBase): """ @staticmethod def parse( - source: str, *, source_name: str = "", context: Optional[Context] = None + source: str | bytes, *, source_name: str = "", context: Optional[Context] = None ) -> Operation: """ Parses an operation. Supports both text assembly format and binary bytecode format. @@ -2290,7 +2316,7 @@ class OperationList: def __len__(self) -> int: ... class RankedTensorType(ShapedType): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get( shape: List[int], @@ -2443,7 +2469,7 @@ class ShapedTypeComponents: """ class StridedLayoutAttr(Attribute): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get( offset: int, strides: List[int], context: Optional[Context] = None @@ -2477,9 +2503,9 @@ class StridedLayoutAttr(Attribute): def typeid(self) -> TypeID: ... class StringAttr(Attribute): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod - def get(value: str, context: Optional[Context] = None) -> StringAttr: + def get(value: str | bytes, context: Optional[Context] = None) -> StringAttr: """ Gets a uniqued string attribute """ @@ -2554,9 +2580,9 @@ class SymbolTable: def insert(self, operation: _OperationBase) -> Attribute: ... class TupleType(Type): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod - def get_Tuple(elements: List[Type], context: Optional[Context] = None) -> TupleType: + def get_tuple(elements: List[Type], context: Optional[Context] = None) -> TupleType: """ Create a Tuple type """ @@ -2576,7 +2602,7 @@ class TupleType(Type): def typeid(self) -> TypeID: ... class TypeAttr(Attribute): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(value: Type, context: Optional[Context] = None) -> TypeAttr: """ @@ -2603,7 +2629,7 @@ class TypeID: def _CAPIPtr(self) -> object: ... class UnitAttr(Attribute): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(context: Optional[Context] = None) -> UnitAttr: """ @@ -2618,7 +2644,7 @@ class UnitAttr(Attribute): def typeid(self) -> TypeID: ... class UnrankedMemRefType(ShapedType): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get( element_type: Type, memory_space: Attribute, loc: Optional[Location] = None @@ -2638,7 +2664,7 @@ class UnrankedMemRefType(ShapedType): def typeid(self) -> TypeID: ... class UnrankedTensorType(ShapedType): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(element_type: Type, loc: Optional[Location] = None) -> UnrankedTensorType: """ @@ -2651,7 +2677,7 @@ class UnrankedTensorType(ShapedType): def typeid(self) -> TypeID: ... class VectorType(ShapedType): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get( shape: List[int],