From 3556065288a7f87b590c2229e4fb676f6b6dc1cb Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Wed, 7 Feb 2024 16:55:42 +0000 Subject: [PATCH 1/4] build: manually update PyTorch version Set PyTorch and TorchVision version to nightly release 2024-02-07. Signed-Off By: Vivek Khandelwal --- pytorch-hash.txt | 2 +- pytorch-requirements.txt | 2 +- test/python/fx_importer/basic_test.py | 18 ++++++------------ torchvision-requirements.txt | 2 +- 4 files changed, 9 insertions(+), 15 deletions(-) diff --git a/pytorch-hash.txt b/pytorch-hash.txt index 16be42d6c147..d78b0d0694d4 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -72fcb9ad662bb941a266e3d747835382634c2be6 +3cbc8e89fd09b0ffb4914187b438f15c121e2302 diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index 1de47ff9a195..540e78dccd49 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torch==2.3.0.dev20240122 +torch==2.3.0.dev20240207 diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index 36c554862506..2b11ff30c605 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -13,6 +13,7 @@ from torch_mlir import fx +torch.manual_seed(0) def run(f): print(f"{f.__name__}") @@ -23,20 +24,13 @@ def run(f): @run # CHECK-LABEL: test_import_frozen_exported_program -# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> -# CHECK-DAG: %[[a:.+]] = torch.vtensor.literal(dense_resource : tensor<1x4xf32>) : !torch.vtensor<[1,4],f32> -# CHECK-DAG: %[[b:.+]] = torch.vtensor.literal(dense_resource : tensor<3x1xf32>) : !torch.vtensor<[3,1],f32> -# CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense<{{.*>+}} : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32> -# CHECK-DAG: %[[tanh:.+]] = torch.aten.tanh %[[ARG0]] -# CHECK-DAG: %[[mul_a:.+]] = torch.aten.mul.Tensor %[[tanh]], %[[a]] -# CHECK-DAG: %[[mul_b:.+]] = torch.aten.mul.Tensor %[[mul_a]], %[[b]] +# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[1,4],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[3,1],f32>, %[[ARG2:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> +# CHECK-DAG: %[[tanh:.+]] = torch.aten.tanh %[[ARG2]] +# CHECK-DAG: %[[mul_a:.+]] = torch.aten.mul.Tensor %[[tanh]], %[[ARG0]] +# CHECK-DAG: %[[mul_b:.+]] = torch.aten.mul.Tensor %[[mul_a]], %[[ARG1]] +# CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense<0.568431258> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32> # CHECK-DAG: %[[mul_p:.+]] = torch.aten.mul.Tensor %[[mul_b]], %[[p]] # CHECK: return %[[mul_p]] -# -# Validate dialect resources exist. -# CHECK: dialect_resources: -# CHECK-DAG: torch_tensor_1_4_torch.float32 -# CHECK-DAG: torch_tensor_3_1_torch.float32 def test_import_frozen_exported_program(): # Tests the basic structural premises of import_frozen_exported_program, # namely that free tensors (buffers) and parameters are treated as diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index fad713123493..4f775c549c6c 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torchvision==0.18.0.dev20240122 +torchvision==0.18.0.dev20240207 From a45390701093d34d38cdbfa2e2dc14064fb45c37 Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Tue, 13 Feb 2024 05:50:12 -0800 Subject: [PATCH 2/4] bring basic_test.py up to current HEAD --- test/python/fx_importer/basic_test.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index 2b11ff30c605..36c554862506 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -13,7 +13,6 @@ from torch_mlir import fx -torch.manual_seed(0) def run(f): print(f"{f.__name__}") @@ -24,13 +23,20 @@ def run(f): @run # CHECK-LABEL: test_import_frozen_exported_program -# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[1,4],f32>, %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[3,1],f32>, %[[ARG2:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> -# CHECK-DAG: %[[tanh:.+]] = torch.aten.tanh %[[ARG2]] -# CHECK-DAG: %[[mul_a:.+]] = torch.aten.mul.Tensor %[[tanh]], %[[ARG0]] -# CHECK-DAG: %[[mul_b:.+]] = torch.aten.mul.Tensor %[[mul_a]], %[[ARG1]] -# CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense<0.568431258> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32> +# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> +# CHECK-DAG: %[[a:.+]] = torch.vtensor.literal(dense_resource : tensor<1x4xf32>) : !torch.vtensor<[1,4],f32> +# CHECK-DAG: %[[b:.+]] = torch.vtensor.literal(dense_resource : tensor<3x1xf32>) : !torch.vtensor<[3,1],f32> +# CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense<{{.*>+}} : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32> +# CHECK-DAG: %[[tanh:.+]] = torch.aten.tanh %[[ARG0]] +# CHECK-DAG: %[[mul_a:.+]] = torch.aten.mul.Tensor %[[tanh]], %[[a]] +# CHECK-DAG: %[[mul_b:.+]] = torch.aten.mul.Tensor %[[mul_a]], %[[b]] # CHECK-DAG: %[[mul_p:.+]] = torch.aten.mul.Tensor %[[mul_b]], %[[p]] # CHECK: return %[[mul_p]] +# +# Validate dialect resources exist. +# CHECK: dialect_resources: +# CHECK-DAG: torch_tensor_1_4_torch.float32 +# CHECK-DAG: torch_tensor_3_1_torch.float32 def test_import_frozen_exported_program(): # Tests the basic structural premises of import_frozen_exported_program, # namely that free tensors (buffers) and parameters are treated as From f9d63a790fc52b276655ec0a81e0cc1f1d8441b9 Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Tue, 13 Feb 2024 11:43:19 -0800 Subject: [PATCH 3/4] fixes to adapt to https://github.com/pytorch/pytorch/pull/118969 --- python/torch_mlir/extras/fx_importer.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 5328e8730cc3..67e5cc40322c 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -341,6 +341,7 @@ def import_frozen_exported_program(self, prog: torch.export.ExportedProgram): """ sig = prog.graph_signature state_dict = prog.state_dict + constants = prog.constants arg_replacements: dict[str, Any] = {} # Lift buffers. for input_name, state_name in sig.inputs_to_buffers.items(): @@ -350,6 +351,14 @@ def import_frozen_exported_program(self, prog: torch.export.ExportedProgram): raise AssertionError("Could not find state mapping for buffer") from e arg_replacements[input_name] = state_value + # Lift tensor constants. + for input_name, state_name in sig.inputs_to_lifted_tensor_constants.items(): + try: + state_value = constants[state_name] + except KeyError as e: + raise AssertionError("Could not find state mapping for tensor constants") from e + arg_replacements[input_name] = state_value + # Lift parameters. for input_name, state_name in sig.inputs_to_parameters.items(): try: From 06a893e5e4a528ba0b345da8f4f099cec0cf14ad Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Tue, 13 Feb 2024 12:10:40 -0800 Subject: [PATCH 4/4] make the constants vs state_dict for buffers conditional --- python/torch_mlir/extras/fx_importer.py | 34 ++++++++++++++----------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 67e5cc40322c..3bd7fcddca6c 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -341,23 +341,27 @@ def import_frozen_exported_program(self, prog: torch.export.ExportedProgram): """ sig = prog.graph_signature state_dict = prog.state_dict - constants = prog.constants arg_replacements: dict[str, Any] = {} - # Lift buffers. - for input_name, state_name in sig.inputs_to_buffers.items(): - try: - state_value = state_dict[state_name] - except KeyError as e: - raise AssertionError("Could not find state mapping for buffer") from e - arg_replacements[input_name] = state_value - # Lift tensor constants. - for input_name, state_name in sig.inputs_to_lifted_tensor_constants.items(): - try: - state_value = constants[state_name] - except KeyError as e: - raise AssertionError("Could not find state mapping for tensor constants") from e - arg_replacements[input_name] = state_value + # If there is no "constants" attribute, consult the "state_dict". Otherwise, only look + # at "constants". Relevant upstream patch: https://github.com/pytorch/pytorch/pull/118969 + if hasattr(prog, "constants"): + constants = prog.constants + # Lift tensor constants. + for input_name, state_name in sig.inputs_to_lifted_tensor_constants.items(): + try: + state_value = constants[state_name] + except KeyError as e: + raise AssertionError("Could not find state mapping for tensor constants") from e + arg_replacements[input_name] = state_value + else: + # Lift buffers. + for input_name, state_name in sig.inputs_to_buffers.items(): + try: + state_value = state_dict[state_name] + except KeyError as e: + raise AssertionError("Could not find state mapping for buffer") from e + arg_replacements[input_name] = state_value # Lift parameters. for input_name, state_name in sig.inputs_to_parameters.items():