From 29e5bf6224f15f7e23f5701fa7b2c8294b3fe13d Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Fri, 21 Jun 2024 15:08:26 +0800 Subject: [PATCH 1/3] [fx importer] support fx importer with lower version torch --- python/torch_mlir/extras/fx_importer.py | 39 +++++++++++++++++-------- 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 2a73325c7d76..7f35903ebd18 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -151,11 +151,16 @@ torch.complex32: "complex", torch.complex64: "complex", torch.complex128: "complex", - torch.float8_e5m2: "f8E5M2", - torch.float8_e4m3fn: "f8E4M3FN", - torch.float8_e5m2fnuz: "f8E5M2FNUZ", - torch.float8_e4m3fnuz: "f8E4M3FNUZ", } +HIGH_VERSION_TORCH_DTYPE_TO_MLIR_TYPE_ASM = { + "float8_e5m2": "f8E5M2", + "float8_e4m3fn": "f8E4M3FN", + "float8_e5m2fnuz": "f8E5M2FNUZ", + "float8_e4m3fnuz": "f8E4M3FNUZ", +} +for type_str, asm in HIGH_VERSION_TORCH_DTYPE_TO_MLIR_TYPE_ASM.getitems(): + if hasattr(torch, type_str): + TORCH_DTYPE_TO_MLIR_TYPE_ASM[getattr(torch, type_str)] = asm TORCH_DTYPE_TO_MLIR_TYPE: Dict[torch.dtype, Callable[[], IrType]] = { torch.float16: lambda: F16Type.get(), @@ -173,11 +178,16 @@ torch.complex32: lambda: ComplexType.get(F16Type.get()), torch.complex64: lambda: ComplexType.get(F32Type.get()), torch.complex128: lambda: ComplexType.get(F64Type.get()), - torch.float8_e5m2: lambda: Float8E5M2Type.get(), - torch.float8_e5m2fnuz: lambda: Float8E5M2FNUZType.get(), - torch.float8_e4m3fn: lambda: Float8E4M3FNType.get(), - torch.float8_e4m3fnuz: lambda: Float8E4M3FNUZType.get(), } +HIGH_VERSION_TORCH_DTYPE_TO_MLIR_TYPE = { + "float8_e5m2": lambda: Float8E5M2Type.get(), + "float8_e4m3fn": lambda: Float8E4M3FNType.get(), + "float8_e5m2fnuz": lambda: Float8E5M2FNUZType.get(), + "float8_e4m3fnuz": lambda: Float8E4M3FNUZType.get(), +} +for type_str, type in HIGH_VERSION_TORCH_DTYPE_TO_MLIR_TYPE.getitems(): + if hasattr(torch, type_str): + TORCH_DTYPE_TO_MLIR_TYPE[getattr(torch, type_str)] = type TORCH_DTYPE_TO_NPY_TYPE = { # torch.qint8: None, # no equivalent np datatype @@ -215,11 +225,16 @@ # torch.quint8: 13, # torch.qint32 14 torch.bfloat16: 15, - torch.float8_e5m2: 23, - torch.float8_e4m3fn: 24, - torch.float8_e5m2fnuz: 25, - torch.float8_e4m3fnuz: 26, } +HIGH_VERSION_TORCH_DTYPE_TO_INT = { + "float8_e5m2": 23, + "float8_e4m3fn": 24, + "float8_e5m2fnuz": 25, + "float8_e4m3fnuz": 26, +} +for type_str, type_int in HIGH_VERSION_TORCH_DTYPE_TO_INT.getitems(): + if hasattr(torch, type_str): + TORCH_DTYPE_TO_INT[getattr(torch, type_str)] = type_int TORCH_MEMORY_FORMAT_TO_INT = { torch.contiguous_format: 0, From 2e66dabfdda4775ef89b858ec6e01d51de656034 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Fri, 21 Jun 2024 15:35:34 +0800 Subject: [PATCH 2/3] fix --- python/torch_mlir/extras/fx_importer.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 7f35903ebd18..303ae637c8f7 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -152,13 +152,14 @@ torch.complex64: "complex", torch.complex128: "complex", } -HIGH_VERSION_TORCH_DTYPE_TO_MLIR_TYPE_ASM = { +# Type entries added only in torch with higher version +OPTIONAL_TORCH_DTYPE_TO_MLIR_TYPE_ASM = { "float8_e5m2": "f8E5M2", "float8_e4m3fn": "f8E4M3FN", "float8_e5m2fnuz": "f8E5M2FNUZ", "float8_e4m3fnuz": "f8E4M3FNUZ", } -for type_str, asm in HIGH_VERSION_TORCH_DTYPE_TO_MLIR_TYPE_ASM.getitems(): +for type_str, asm in OPTIONAL_TORCH_DTYPE_TO_MLIR_TYPE_ASM.items(): if hasattr(torch, type_str): TORCH_DTYPE_TO_MLIR_TYPE_ASM[getattr(torch, type_str)] = asm @@ -179,13 +180,14 @@ torch.complex64: lambda: ComplexType.get(F32Type.get()), torch.complex128: lambda: ComplexType.get(F64Type.get()), } -HIGH_VERSION_TORCH_DTYPE_TO_MLIR_TYPE = { +# Type entries added only in torch with higher version +OPTIONAL_TORCH_DTYPE_TO_MLIR_TYPE = { "float8_e5m2": lambda: Float8E5M2Type.get(), "float8_e4m3fn": lambda: Float8E4M3FNType.get(), "float8_e5m2fnuz": lambda: Float8E5M2FNUZType.get(), "float8_e4m3fnuz": lambda: Float8E4M3FNUZType.get(), } -for type_str, type in HIGH_VERSION_TORCH_DTYPE_TO_MLIR_TYPE.getitems(): +for type_str, type in OPTIONAL_TORCH_DTYPE_TO_MLIR_TYPE.items(): if hasattr(torch, type_str): TORCH_DTYPE_TO_MLIR_TYPE[getattr(torch, type_str)] = type @@ -226,13 +228,14 @@ # torch.qint32 14 torch.bfloat16: 15, } -HIGH_VERSION_TORCH_DTYPE_TO_INT = { +# Type entries added only in torch with higher version +OPTIONAL_TORCH_DTYPE_TO_INT = { "float8_e5m2": 23, "float8_e4m3fn": 24, "float8_e5m2fnuz": 25, "float8_e4m3fnuz": 26, } -for type_str, type_int in HIGH_VERSION_TORCH_DTYPE_TO_INT.getitems(): +for type_str, type_int in OPTIONAL_TORCH_DTYPE_TO_INT.items(): if hasattr(torch, type_str): TORCH_DTYPE_TO_INT[getattr(torch, type_str)] = type_int From 86c12ea4c7c7f0c279baf52401f8dfbd7ae3b7f2 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Mon, 24 Jun 2024 15:28:50 +0800 Subject: [PATCH 3/3] fix --- python/torch_mlir/extras/fx_importer.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 303ae637c8f7..cb86406c55fd 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -159,9 +159,9 @@ "float8_e5m2fnuz": "f8E5M2FNUZ", "float8_e4m3fnuz": "f8E4M3FNUZ", } -for type_str, asm in OPTIONAL_TORCH_DTYPE_TO_MLIR_TYPE_ASM.items(): - if hasattr(torch, type_str): - TORCH_DTYPE_TO_MLIR_TYPE_ASM[getattr(torch, type_str)] = asm +for dtype_str, dtype_asm in OPTIONAL_TORCH_DTYPE_TO_MLIR_TYPE_ASM.items(): + if hasattr(torch, dtype_str): + TORCH_DTYPE_TO_MLIR_TYPE_ASM[getattr(torch, dtype_str)] = dtype_asm TORCH_DTYPE_TO_MLIR_TYPE: Dict[torch.dtype, Callable[[], IrType]] = { torch.float16: lambda: F16Type.get(), @@ -187,9 +187,9 @@ "float8_e5m2fnuz": lambda: Float8E5M2FNUZType.get(), "float8_e4m3fnuz": lambda: Float8E4M3FNUZType.get(), } -for type_str, type in OPTIONAL_TORCH_DTYPE_TO_MLIR_TYPE.items(): - if hasattr(torch, type_str): - TORCH_DTYPE_TO_MLIR_TYPE[getattr(torch, type_str)] = type +for dtype_str, mlir_type in OPTIONAL_TORCH_DTYPE_TO_MLIR_TYPE.items(): + if hasattr(torch, dtype_str): + TORCH_DTYPE_TO_MLIR_TYPE[getattr(torch, dtype_str)] = mlir_type TORCH_DTYPE_TO_NPY_TYPE = { # torch.qint8: None, # no equivalent np datatype @@ -235,9 +235,9 @@ "float8_e5m2fnuz": 25, "float8_e4m3fnuz": 26, } -for type_str, type_int in OPTIONAL_TORCH_DTYPE_TO_INT.items(): - if hasattr(torch, type_str): - TORCH_DTYPE_TO_INT[getattr(torch, type_str)] = type_int +for dtype_str, dtype_int in OPTIONAL_TORCH_DTYPE_TO_INT.items(): + if hasattr(torch, dtype_str): + TORCH_DTYPE_TO_INT[getattr(torch, dtype_str)] = dtype_int TORCH_MEMORY_FORMAT_TO_INT = { torch.contiguous_format: 0,