[fx importer] support fx importer with lower version torch (llvm#3486)
qingyunqu authored Jun 24, 2024
1 parent fc19709 commit 61f37ae
Showing 1 changed file with 30 additions and 12 deletions.
python/torch_mlir/extras/fx_importer.py (+30 −12)
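What the change fixes: the dtype tables previously referenced torch.float8_e5m2 and friends directly, so importing fx_importer.py on a torch build that predates those dtypes failed with AttributeError before any model was processed. The commit moves the float8 entries into companion OPTIONAL_* dicts keyed by attribute name (a string) and copies each one into the main table only when hasattr(torch, name) confirms the running torch defines it. A minimal sketch of the guard pattern, runnable against any torch version (OPTIONAL_DTYPES and SUPPORTED are illustrative names, not the importer's):

import torch

# Probing by name keeps a module importable even when the running torch
# predates the float8 dtypes; a direct torch.float8_e5m2 reference would
# raise AttributeError at import time on such builds.
OPTIONAL_DTYPES = {"float8_e5m2": "f8E5M2", "float8_e4m3fn": "f8E4M3FN"}

SUPPORTED = {}
for name, asm in OPTIONAL_DTYPES.items():
    if hasattr(torch, name):  # True only when this torch defines the dtype
        SUPPORTED[getattr(torch, name)] = asm

print(SUPPORTED)  # {} on older torch; float8 entries on newer builds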
@@ -151,11 +151,17 @@
     torch.complex32: "complex<f16>",
     torch.complex64: "complex<f32>",
     torch.complex128: "complex<f64>",
-    torch.float8_e5m2: "f8E5M2",
-    torch.float8_e4m3fn: "f8E4M3FN",
-    torch.float8_e5m2fnuz: "f8E5M2FNUZ",
-    torch.float8_e4m3fnuz: "f8E4M3FNUZ",
 }
+# Type entries added only in torch with higher version
+OPTIONAL_TORCH_DTYPE_TO_MLIR_TYPE_ASM = {
+    "float8_e5m2": "f8E5M2",
+    "float8_e4m3fn": "f8E4M3FN",
+    "float8_e5m2fnuz": "f8E5M2FNUZ",
+    "float8_e4m3fnuz": "f8E4M3FNUZ",
+}
+for dtype_str, dtype_asm in OPTIONAL_TORCH_DTYPE_TO_MLIR_TYPE_ASM.items():
+    if hasattr(torch, dtype_str):
+        TORCH_DTYPE_TO_MLIR_TYPE_ASM[getattr(torch, dtype_str)] = dtype_asm
 
 TORCH_DTYPE_TO_MLIR_TYPE: Dict[torch.dtype, Callable[[], IrType]] = {
     torch.float16: lambda: F16Type.get(),
@@ -173,11 +179,17 @@
     torch.complex32: lambda: ComplexType.get(F16Type.get()),
     torch.complex64: lambda: ComplexType.get(F32Type.get()),
     torch.complex128: lambda: ComplexType.get(F64Type.get()),
-    torch.float8_e5m2: lambda: Float8E5M2Type.get(),
-    torch.float8_e5m2fnuz: lambda: Float8E5M2FNUZType.get(),
-    torch.float8_e4m3fn: lambda: Float8E4M3FNType.get(),
-    torch.float8_e4m3fnuz: lambda: Float8E4M3FNUZType.get(),
 }
+# Type entries added only in torch with higher version
+OPTIONAL_TORCH_DTYPE_TO_MLIR_TYPE = {
+    "float8_e5m2": lambda: Float8E5M2Type.get(),
+    "float8_e4m3fn": lambda: Float8E4M3FNType.get(),
+    "float8_e5m2fnuz": lambda: Float8E5M2FNUZType.get(),
+    "float8_e4m3fnuz": lambda: Float8E4M3FNUZType.get(),
+}
+for dtype_str, mlir_type in OPTIONAL_TORCH_DTYPE_TO_MLIR_TYPE.items():
+    if hasattr(torch, dtype_str):
+        TORCH_DTYPE_TO_MLIR_TYPE[getattr(torch, dtype_str)] = mlir_type
 
 TORCH_DTYPE_TO_NPY_TYPE = {
     # torch.qint8: None, # no equivalent np datatype
@@ -215,11 +227,17 @@
     # torch.quint8: 13,
     # torch.qint32 14
     torch.bfloat16: 15,
-    torch.float8_e5m2: 23,
-    torch.float8_e4m3fn: 24,
-    torch.float8_e5m2fnuz: 25,
-    torch.float8_e4m3fnuz: 26,
 }
+# Type entries added only in torch with higher version
+OPTIONAL_TORCH_DTYPE_TO_INT = {
+    "float8_e5m2": 23,
+    "float8_e4m3fn": 24,
+    "float8_e5m2fnuz": 25,
+    "float8_e4m3fnuz": 26,
+}
+for dtype_str, dtype_int in OPTIONAL_TORCH_DTYPE_TO_INT.items():
+    if hasattr(torch, dtype_str):
+        TORCH_DTYPE_TO_INT[getattr(torch, dtype_str)] = dtype_int
 
 TORCH_MEMORY_FORMAT_TO_INT = {
     torch.contiguous_format: 0,

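Two notes on the pattern. First, the entries in TORCH_DTYPE_TO_MLIR_TYPE stay wrapped in lambdas (matching the table's Callable[[], IrType] signature) because the MLIR *Type.get() constructors require an active MLIR Context, which does not exist at module-import time; the guard only decides whether a thunk gets registered, it never invokes one. Second, on an older torch the tables now simply lack float8 rows, so the failure moves from import time to lookup time: converting a float8 tensor surfaces as a missing table entry instead of the whole importer refusing to load. A hypothetical caller-side wrapper (dtype_to_asm is illustrative, not part of the importer's API) that turns the bare KeyError into a clearer message:

import torch

def dtype_to_asm(table, dtype):
    # Report an unsupported dtype explicitly rather than leaking a KeyError,
    # e.g. a float8 dtype when the installed torch predates it.
    try:
        return table[dtype]
    except KeyError:
        raise NotImplementedError(f"no MLIR type mapping for {dtype}") from None

table = {torch.float32: "f32"}  # float8 rows absent on an old torch
print(dtype_to_asm(table, torch.float32))  # -> f32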