separate quantize and export_to_edge in builder
Summary: Currently, export_to_edge both applies the quantizers and runs to_edge. Separate the two steps so that quantize can be called on its own from eval_llama.py.

Differential Revision: D57367832
cccclai authored and facebook-github-bot committed May 15, 2024
1 parent aaa2f2e commit faa5d00
Showing 2 changed files with 28 additions and 17 deletions.
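
For orientation, a minimal usage sketch of the split flow (the method names follow the diffs below; the surrounding setup is assumed, not part of this commit):

```python
# Sketch only: assumes a LlamaEdgeManager produced by _prepare_for_llama_export,
# as in export_llama_lib.py below.
manager = _prepare_for_llama_export(modelname, args)

# Quantization and Edge lowering are now two explicit, chainable steps.
builder_exported_to_edge = manager.pt2e_quantize(quantizers).export_to_edge()
```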
37 changes: 23 additions & 14 deletions examples/models/llama2/builder.py
@@ -142,6 +142,7 @@ def __init__(
verbose: bool = False,
):
self.model = model
self.quantized_model = None
self.modelname = modelname
self.weight_type = weight_type
self.dtype = dtype
@@ -251,35 +252,43 @@ def _get_metadata(self):
self.metadata = metadata
return self.metadata

def export_to_edge(
def pt2e_quantize(
self, quantizers: Optional[List[Quantizer]]
) -> "LlamaEdgeManager":
"""
Quantize the model via the PT2E flow and return a LlamaEdgeManager.
Args:
quantizers (Optional[List[Quantizer]]): A list of quantizers.
"""
dynamic_shape = self._get_dynamic_shape()
edge_config = self._get_edge_config()
metadata = self._get_metadata()

# 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
# 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
m = capture_pre_autograd_graph(
self.model = capture_pre_autograd_graph(
self.model, self.example_inputs, dynamic_shapes=dynamic_shape
)
if quantizers:
if self.verbose:
logging.info(f"Applied quantizers: {quantizers}")
composed_quantizer = ComposableQuantizer(quantizers)
m = prepare_pt2e(m, composed_quantizer)
self.model = prepare_pt2e(self.model, composed_quantizer)
# Calibrate
m(*self.example_inputs)
m = convert_pt2e(m)
DuplicateDynamicQuantChainPass()(m)
self.model(*self.example_inputs)
self.model = convert_pt2e(self.model)
DuplicateDynamicQuantChainPass()(self.model)
return self

def export_to_edge(self) -> "LlamaEdgeManager":
"""
Export the model to Edge dialect and retrieve a EdgeManager.
Args:
quantizers (Optional[List[Quantizer]]): A list of quantizers.
"""
dynamic_shape = self._get_dynamic_shape()
edge_config = self._get_edge_config()
metadata = self._get_metadata()

# 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
# 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
self.edge_manager = export_to_edge(
m,
self.model,
self.example_inputs,
dynamic_shapes=dynamic_shape,
edge_constant_methods=metadata,
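
For readers unfamiliar with the PT2E flow that pt2e_quantize wraps, here is a generic, self-contained sketch using plain PyTorch APIs on a toy module; it is not code from this repository and omits the ComposableQuantizer and DuplicateDynamicQuantChainPass steps shown in the diff above.

```python
import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.linear(x)


model = TinyModel().eval()
example_inputs = (torch.randn(1, 8),)

# 1. Capture the pre-autograd ATen graph (same entry point the builder uses).
graph = capture_pre_autograd_graph(model, example_inputs)

# 2. Annotate the graph with a quantizer and insert observers.
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
prepared = prepare_pt2e(graph, quantizer)

# 3. Calibrate so the observers record activation ranges.
prepared(*example_inputs)

# 4. Replace observers with quantize/dequantize ops.
quantized = convert_pt2e(prepared)
```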
8 changes: 5 additions & 3 deletions examples/models/llama2/export_llama_lib.py
@@ -378,9 +378,11 @@ def _export_llama(modelname, args) -> str: # noqa: C901
qnn_quantizer, quant_dtype = get_qnn_quantizer(args)
quantizers.append(qnn_quantizer)

builder_exported_to_edge = _prepare_for_llama_export(
modelname, args
).export_to_edge(quantizers)
builder_quantized = _prepare_for_llama_export(modelname, args).pt2e_quantize(
quantizers
)

builder_exported_to_edge = builder_quantized.export_to_edge()

modelname = builder_exported_to_edge.modelname

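
The quantize-only path this split enables would look roughly as follows (hypothetical eval usage; eval_llama.py is not touched by this commit, so the calling code below is illustrative):

```python
# Hypothetical quantize-only flow for evaluation (not part of this diff).
manager = _prepare_for_llama_export(modelname, args).pt2e_quantize(quantizers)

# pt2e_quantize writes the converted graph back to manager.model, so it can be
# run directly without lowering to the Edge dialect.
logits = manager.model(*manager.example_inputs)
```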
