From 628b280af5ae9b85c037a35075f0829ec3ddb998 Mon Sep 17 00:00:00 2001
From: helunwencser
Date: Tue, 23 Jul 2024 13:03:55 -0700
Subject: [PATCH] update phi-3-mini readme doc (#4377)

Summary:
Update the phi-3-mini readme to reflect the latest changes.

It also falls back to using `torch.export._trace._export` as a workaround for
https://github.com/pytorch/pytorch/issues/128394.

Pull Request resolved: https://github.com/pytorch/executorch/pull/4377

Reviewed By: kirklandsign

Differential Revision: D60130144

Pulled By: helunwencser

fbshipit-source-id: 1d8d9a3791b877f43ad9312b4e704d5ca6d7f69e
---
 examples/models/phi-3-mini/README.md           |  4 +-
 .../models/phi-3-mini/export_phi-3-mini.py     | 75 ++++++++++++-------
 2 files changed, 48 insertions(+), 31 deletions(-)

diff --git a/examples/models/phi-3-mini/README.md b/examples/models/phi-3-mini/README.md
index af7ae912ea..a1ae78941b 100644
--- a/examples/models/phi-3-mini/README.md
+++ b/examples/models/phi-3-mini/README.md
@@ -4,8 +4,6 @@ This example demonstrates how to run a [Phi-3-mini](https://huggingface.co/micro
 # Instructions
 ## Step 1: Setup
 1. Follow the [tutorial](https://pytorch.org/executorch/main/getting-started-setup) to set up ExecuTorch. For installation run `./install_requirements.sh --pybind xnnpack`
-2. Phi-3 Mini-128K-Instruct has been integrated in the development version (4.41.0.dev0) of transformers. Make sure that you install transformers with version at least 4.41.0: `pip uninstall -y transformers && pip install git+https://github.com/huggingface/transformers`
-
 ## Step 2: Prepare and run the model
 1. Download the `tokenizer.model` from HuggingFace.
 ```
@@ -15,7 +13,7 @@ wget -O tokenizer.model https://huggingface.co/microsoft/Phi-3-mini-128k-instruc
 ```
 2. Export the model. This step will take a few minutes to finish.
 ```
-python export_model.py
+python3 export_phi-3-mini.py
 ```
 3. Build and run the runner.
 ```
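As a quick sanity check before building the runner, the exported program can also be exercised directly from Python. The sketch below is not part of this patch: it assumes the pybind runtime installed in Step 1 via `./install_requirements.sh --pybind xnnpack` and the `phi-3-mini.pte` written by the export script; the token shape is an illustrative choice within the exported sequence-length bound of 128.

```
import torch

from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Load the program written by export_phi-3-mini.py.
program = _load_for_executorch("phi-3-mini.pte")

# Token ids shaped like the export-time example inputs: batch 1, with a
# sequence length inside the exported dynamic bound (max=128).
tokens = torch.randint(0, 100, (1, 64), dtype=torch.long)

# forward() takes a sequence of inputs and returns a list of outputs; for a
# causal LM the first output holds the logits.
logits = program.forward((tokens,))[0]
print(logits.shape)
```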
diff --git a/examples/models/phi-3-mini/export_phi-3-mini.py b/examples/models/phi-3-mini/export_phi-3-mini.py
index 02d818368b..cb20a36510 100644
--- a/examples/models/phi-3-mini/export_phi-3-mini.py
+++ b/examples/models/phi-3-mini/export_phi-3-mini.py
@@ -5,47 +5,66 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch
-from executorch.extension.llm.export.builder import DType, LLMEdgeManager
-from executorch.extension.llm.export.partitioner_lib import get_xnnpack_partitioner
-from executorch.extension.llm.export.quantizer_lib import (
-    DynamicQuantLinearOptions,
-    get_pt2e_quantizers,
-    PT2EQuantOptions,
+from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
+    DuplicateDynamicQuantChainPass,
+)
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
+from executorch.exir import to_edge
+from torch._export import capture_pre_autograd_graph
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+from torch.ao.quantization.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+    XNNPACKQuantizer,
 )
 from transformers import Phi3ForCausalLM
 
 
 def main() -> None:
-    torch.manual_seed(42)
+    torch.manual_seed(0)
 
     # pyre-ignore: Undefined attribute [16]: Module `transformers` has no attribute `Phi3ForCausalLM`
     model = Phi3ForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
 
-    modelname = "phi-3-mini"
-
-    (
-        LLMEdgeManager(
-            model=model,
-            modelname=modelname,
-            max_seq_len=128,
-            dtype=DType.fp32,
-            use_kv_cache=False,
-            example_inputs=(torch.randint(0, 100, (1, 100), dtype=torch.long),),
-            enable_dynamic_shape=True,
-            verbose=True,
+    example_inputs = (torch.randint(0, 100, (1, 100), dtype=torch.long),)
+    dynamic_shape = {"input_ids": {1: torch.export.Dim("sequence_length", max=128)}}
+
+    xnnpack_quant_config = get_symmetric_quantization_config(
+        is_per_channel=True, is_dynamic=True
+    )
+    xnnpack_quantizer = XNNPACKQuantizer()
+    xnnpack_quantizer.set_global(xnnpack_quant_config)
+
+    with torch.nn.attention.sdpa_kernel(
+        [torch.nn.attention.SDPBackend.MATH]
+    ), torch.no_grad():
+        model = capture_pre_autograd_graph(
+            model, example_inputs, dynamic_shapes=dynamic_shape
         )
-        .set_output_dir(".")
-        .capture_pre_autograd_graph()
-        .pt2e_quantize(
-            get_pt2e_quantizers(PT2EQuantOptions(None, DynamicQuantLinearOptions()))
+        model = prepare_pt2e(model, xnnpack_quantizer)
+        model(*example_inputs)
+        model = convert_pt2e(model, fold_quantize=False)
+        DuplicateDynamicQuantChainPass()(model)
+        # TODO(lunwenh): update it to use export once
+        # https://github.com/pytorch/pytorch/issues/128394 is resolved.
+        model = torch.export._trace._export(
+            model,
+            example_inputs,
+            dynamic_shapes=dynamic_shape,
+            strict=False,
+            pre_dispatch=False,
         )
-        .export_to_edge()
-        .to_backend([get_xnnpack_partitioner()])
-        .to_executorch()
-        .save_to_pte(f"{modelname}.pte")
-    )
+
+    edge_config = get_xnnpack_edge_compile_config()
+    edge_manager = to_edge(model, compile_config=edge_config)
+    edge_manager = edge_manager.to_backend(XnnpackPartitioner(has_dynamic_shapes=True))
+    et_program = edge_manager.to_executorch()
+
+    with open("phi-3-mini.pte", "wb") as file:
+        file.write(et_program.buffer)
 
 
 if __name__ == "__main__":
     main()
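To try the new quantize-then-export recipe without downloading Phi-3-mini, the sketch below applies the same flow to a toy linear model, using only APIs this patch already imports. The `Toy` module, tensor shapes, and `toy.pte` output name are illustrative assumptions; the private `torch.export._trace._export` call mirrors the workaround for pytorch/pytorch#128394 noted above.

```
import torch
from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
    DuplicateDynamicQuantChainPass,
)
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
from executorch.exir import to_edge
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)


class Toy(torch.nn.Module):
    """Stand-in for Phi3ForCausalLM: a single dynamically quantizable linear."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


example_inputs = (torch.randn(1, 16),)

# Same global config as the script: dynamic, per-channel symmetric quantization.
quantizer = XNNPACKQuantizer()
quantizer.set_global(
    get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
)

# Capture, insert observers, run once to calibrate, then convert.
model = capture_pre_autograd_graph(Toy().eval(), example_inputs)
model = prepare_pt2e(model, quantizer)
model(*example_inputs)
model = convert_pt2e(model, fold_quantize=False)
DuplicateDynamicQuantChainPass()(model)

# Private export entry point, standing in for torch.export.export until
# https://github.com/pytorch/pytorch/issues/128394 is resolved.
exported = torch.export._trace._export(
    model, example_inputs, strict=False, pre_dispatch=False
)

# Lower to XNNPACK and serialize, as the script does for phi-3-mini.pte.
edge = to_edge(exported, compile_config=get_xnnpack_edge_compile_config())
edge = edge.to_backend(XnnpackPartitioner())
with open("toy.pte", "wb") as f:
    f.write(edge.to_executorch().buffer)
```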