
Remove pt2_quant flag
Summary:
Cleaning up this flag has been on the to-do list for a while.

The flag is only used in `export_program` to put the model in eval mode properly. Now that we only allow `nn.Module` inputs, there are just two cases: a plain `nn.Module`, which has `eval()`, and a `GraphModule`, which must use `torch.ao.quantization.move_exported_model_to_eval` (the call we previously guarded with the `pt2_quant` flag).

Now that the flag is not needed, remove it everywhere!

We also promote the `quantize_and_export_program` function to `__init__.py` as a compiler API, because it can be quite useful.
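The two-case dispatch described above can be sketched without the flag. Everything below (`Module`, `GraphModule`, `move_exported_model_to_eval`, `set_eval_mode`) is a torch-free stand-in for the real torch types and API, not the actual implementation:

```python
class Module:
    """Illustrative stand-in for torch.nn.Module: supports eval()."""

    def __init__(self):
        self.training = True

    def eval(self):
        self.training = False
        return self


class GraphModule(Module):
    """Illustrative stand-in for torch.fx.GraphModule: eval() is unsupported."""

    def eval(self):
        raise NotImplementedError("use move_exported_model_to_eval instead")


def move_exported_model_to_eval(gm):
    """Stand-in for torch.ao.quantization.move_exported_model_to_eval."""
    gm.training = False
    return gm


def set_eval_mode(model):
    # No pt2_quant flag needed: dispatch on the runtime type instead.
    if isinstance(model, GraphModule):
        return move_exported_model_to_eval(model)
    return model.eval()
```

The point of the commit is exactly this: the runtime type already carries the information the boolean flag used to encode, so the flag is redundant.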

Differential Revision: D57491621
mcremon-meta authored and facebook-github-bot committed May 20, 2024
1 parent 0f21c66 commit 553b669
Showing 2 changed files with 8 additions and 10 deletions.
14 changes: 7 additions & 7 deletions backends/cadence/aot/compiler.py

```diff
@@ -16,17 +16,18 @@
 def export_program(
-    model: Callable,
+    model: torch.nn.Module,
     inputs: Any,
-    pt2_quant: bool = False,
 ) -> ExportedProgram:
-    # we don't support training mode. Make it eval
+    # We don't support training mode. Make it eval
     if hasattr(model, "eval"):
-        if pt2_quant:
-            # pyre-fixme[6]: Incompatible parameter type.
+        # If the model is already a GraphModule (most likely from quantization),
+        # it can't call eval. Call the suggested torch.ao.quantization API instead,
+        # which only does dropout and batchnorm.
+        if isinstance(model, torch.fx.GraphModule):
             torch.ao.quantization.move_exported_model_to_eval(model)
         else:
-            # pyre-fixme[16]: Anonymous callable has no attribute `eval`.
             model.eval()

     # if it's already an ExportedProgram, just return it
@@ -46,11 +47,10 @@ def export_program(
 def export_to_edge(
     model: Callable,
     inputs: Any,
-    pt2_quant: bool = False,
     dump_graphs: bool = False,
 ) -> Tuple[EdgeProgramManager, ExportedProgram]:
     # Export the model into an ExportedProgram.
-    expo_program = export_program(model, inputs, pt2_quant)
+    expo_program = export_program(model, inputs)

     if dump_graphs:
         logging.info(f"Exported graph:\n{expo_program.graph_module.graph}")
```
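Because the flag is removed outright rather than deprecated, any out-of-tree call site still passing `pt2_quant=` now fails with a `TypeError`. A minimal stub (not the real `export_program`) illustrating the breakage and the one-line fix:

```python
def export_program(model, inputs):
    """Stub with the new signature: the pt2_quant keyword is gone."""
    return (model, inputs)


# Old-style call sites fail loudly and must be updated:
try:
    export_program("model", (1, 2), pt2_quant=True)
except TypeError:
    print("old call site: update needed")

# New-style call site: just drop the keyword.
result = export_program("model", (1, 2))
```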
4 changes: 1 addition & 3 deletions backends/cadence/aot/export_example.py

```diff
@@ -69,9 +69,7 @@ def export_model(
     QuantFusion(patterns)(converted_model)

     # Get edge program (note: the name will change to export_to_cadence in future PRs)
-    edge_prog_manager, expo_prog = export_to_edge(
-        converted_model, example_inputs, pt2_quant=True
-    )
+    edge_prog_manager, expo_prog = export_to_edge(converted_model, example_inputs)

     # Run a couple required passes for quant/dequant ops
     cadence_prog_manager = edge_prog_manager.transform(
```
