diff --git a/examples/models/llama2/builder.py b/examples/models/llama2/builder.py
index 98fc9bee239..b6087c1d3af 100644
--- a/examples/models/llama2/builder.py
+++ b/examples/models/llama2/builder.py
@@ -142,6 +142,7 @@ def __init__(
         verbose: bool = False,
     ):
         self.model = model
+        self.pre_autograd_graph_module: Optional[torch.fx.GraphModule] = None
         self.modelname = modelname
         self.weight_type = weight_type
         self.dtype = dtype
@@ -251,25 +252,27 @@ def _get_metadata(self):
         self.metadata = metadata
         return self.metadata
 
-    def export_to_edge(
+    def pt2e_quantize(
         self, quantizers: Optional[List[Quantizer]]
     ) -> "LlamaEdgeManager":
         """
-        Export the model to Edge dialect and retrieve a EdgeManager.
+        Quantize the model via pt2e flow and retrieve LlamaEdgeManager including the quantized model.
         Args:
            quantizers (Optional[List[Quantizer]]): A list of quantizers.
         """
+        assert (
+            self.edge_manager is None
+        ), "export_to_edge is already called, please call pt2e_quantize before export_to_edge"
+        logging.info(f"Using pt2e {quantizers} to quantizing the model...")
         dynamic_shape = self._get_dynamic_shape()
-        edge_config = self._get_edge_config()
-        metadata = self._get_metadata()
 
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
-        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-            m = capture_pre_autograd_graph(
-                self.model, self.example_inputs, dynamic_shapes=dynamic_shape
-            )
-            if quantizers:
+        if quantizers:
+            with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
+                m = capture_pre_autograd_graph(
+                    self.model, self.example_inputs, dynamic_shapes=dynamic_shape
+                )
                 if self.verbose:
                     logging.info(f"Applied quantizers: {quantizers}")
                 composed_quantizer = ComposableQuantizer(quantizers)
@@ -278,8 +281,29 @@ def export_to_edge(
                 m(*self.example_inputs)
                 m = convert_pt2e(m)
                 DuplicateDynamicQuantChainPass()(m)
+                self.pre_autograd_graph_module = m
+                return self
+        else:
+            logging.info("No quantizer provided, passing...")
+            return self
+
+    def export_to_edge(self) -> "LlamaEdgeManager":
+        """
+        Export the model to Edge dialect and retrieve a LlamaEdgeManager.
+        """
+        dynamic_shape = self._get_dynamic_shape()
+        edge_config = self._get_edge_config()
+        metadata = self._get_metadata()
+
+        # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
+        # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
+        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
+            if self.pre_autograd_graph_module is None:
+                self.pre_autograd_graph_module = capture_pre_autograd_graph(
+                    self.model, self.example_inputs, dynamic_shapes=dynamic_shape
+                )
             self.edge_manager = export_to_edge(
-                m,
+                self.pre_autograd_graph_module,
                 self.example_inputs,
                 dynamic_shapes=dynamic_shape,
                 edge_constant_methods=metadata,
diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py
index f4e8043af7f..964a78a9e85 100644
--- a/examples/models/llama2/export_llama_lib.py
+++ b/examples/models/llama2/export_llama_lib.py
@@ -378,9 +378,11 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
             qnn_quantizer, quant_dtype = get_qnn_quantizer(args)
             quantizers.append(qnn_quantizer)
 
-    builder_exported_to_edge = _prepare_for_llama_export(
-        modelname, args
-    ).export_to_edge(quantizers)
+    builder_exported_to_edge = (
+        _prepare_for_llama_export(modelname, args)
+        .pt2e_quantize(quantizers)
+        .export_to_edge()
+    )
 
     modelname = builder_exported_to_edge.modelname
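
For reference, a minimal usage sketch of the two-step flow wired up in _export_llama above; modelname, args, and the quantizers list are placeholders for values assembled earlier in that function (e.g. via get_qnn_quantizer) and are not part of this diff:

    # Sketch only: modelname, args, and quantizers stand in for values built
    # earlier in _export_llama; they are assumptions, not defined here.
    manager = _prepare_for_llama_export(modelname, args)  # a LlamaEdgeManager
    manager = manager.pt2e_quantize(quantizers)           # caches the quantized graph; must precede export_to_edge (see the assert)
    builder_exported_to_edge = manager.export_to_edge()   # reuses pre_autograd_graph_module, or re-traces if no quantizer ran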