diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index 36a5b30855..338b90b26d 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Any, Callable, Tuple +from typing import Any, Tuple import torch @@ -16,25 +16,20 @@ def export_program( - model: Callable, + model: torch.nn.Module, inputs: Any, - pt2_quant: bool = False, ) -> ExportedProgram: - # we don't support training mode. Make it eval - if hasattr(model, "eval"): - if pt2_quant: - # pyre-fixme[6]: Incompatible parameter type. - torch.ao.quantization.move_exported_model_to_eval(model) - else: - # pyre-fixme[16]: Anonymous callable has no attribute `eval`. - model.eval() - - # if it's already an ExportedProgram, just return it - if isinstance(model, ExportedProgram): - return model - assert isinstance(model, torch.nn.Module), "model should be an nn.Module" + # If the model is already a GraphModule (most likely from quantization), call the + # suggested torch.ao.quantization API instead, which only does dropout and batchnorm. + if isinstance(model, torch.fx.GraphModule): + torch.ao.quantization.move_exported_model_to_eval(model) + else: + # We don't support training mode. Make it eval + if hasattr(model, "eval"): + model.eval() + # Prevent mkldnn decompositions torch._C._set_mkldnn_enabled(False) @@ -44,13 +39,12 @@ def export_program( # Export the model and lower it it edge IR. def export_to_edge( - model: Callable, + model: torch.nn.Module, inputs: Any, - pt2_quant: bool = False, dump_graphs: bool = False, ) -> Tuple[EdgeProgramManager, ExportedProgram]: # Export the model into an ExportedProgram. - expo_program = export_program(model, inputs, pt2_quant) + expo_program = export_program(model, inputs) if dump_graphs: logging.info(f"Exported graph:\n{expo_program.graph_module.graph}") diff --git a/backends/cadence/aot/export_example.py b/backends/cadence/aot/export_example.py index bf96de2afd..4eab801e7c 100644 --- a/backends/cadence/aot/export_example.py +++ b/backends/cadence/aot/export_example.py @@ -69,9 +69,7 @@ def export_model( QuantFusion(patterns)(converted_model) # Get edge program (note: the name will change to export_to_cadence in future PRs) - edge_prog_manager, expo_prog = export_to_edge( - converted_model, example_inputs, pt2_quant=True - ) + edge_prog_manager, expo_prog = export_to_edge(converted_model, example_inputs) # Run a couple required passes for quant/dequant ops cadence_prog_manager = edge_prog_manager.transform(