diff --git a/examples/models/__init__.py b/examples/models/__init__.py index d82b015365..a64686b239 100644 --- a/examples/models/__init__.py +++ b/examples/models/__init__.py @@ -18,6 +18,7 @@ "llama2": ("llama2", "Llama2Model"), "mobilebert": ("mobilebert", "MobileBertModelExample"), "mv2": ("mobilenet_v2", "MV2Model"), + "mv2_untrained": ("mobilenet_v2", "MV2UntrainedModel"), "mv3": ("mobilenet_v3", "MV3Model"), "vit": ("torchvision_vit", "TorchVisionViTModel"), "w2l": ("wav2letter", "Wav2LetterModel"), diff --git a/examples/models/mobilenet_v2/__init__.py b/examples/models/mobilenet_v2/__init__.py index 5225511102..ee1235f81e 100644 --- a/examples/models/mobilenet_v2/__init__.py +++ b/examples/models/mobilenet_v2/__init__.py @@ -4,8 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from .model import MV2Model +from .model import MV2Model, MV2UntrainedModel __all__ = [ MV2Model, + MV2UntrainedModel, ] diff --git a/examples/models/mobilenet_v2/model.py b/examples/models/mobilenet_v2/model.py index f55c148acc..feb3579151 100644 --- a/examples/models/mobilenet_v2/model.py +++ b/examples/models/mobilenet_v2/model.py @@ -29,3 +29,17 @@ def get_eager_model(self) -> torch.nn.Module: def get_example_inputs(self): tensor_size = (1, 3, 224, 224) return (torch.randn(tensor_size),) + + +class MV2UntrainedModel(EagerModelBase): + def __init__(self): + pass + + def get_eager_model(self) -> torch.nn.Module: + # pyre-ignore + mv2 = mobilenet_v2() + return mv2 + + def get_example_inputs(self): + tensor_size = (1, 3, 224, 224) + return (torch.randn(tensor_size),) diff --git a/examples/xnnpack/aot_compiler.py b/examples/xnnpack/aot_compiler.py index b22f90234b..3c7b4f273e 100644 --- a/examples/xnnpack/aot_compiler.py +++ b/examples/xnnpack/aot_compiler.py @@ -56,6 +56,7 @@ required=False, help="Generate and save an ETRecord to the given file location", ) + parser.add_argument("-o", "--output_dir", default=".", help="output directory") args = parser.parse_args() @@ -110,4 +111,4 @@ quant_tag = "q8" if args.quantize else "fp32" model_name = f"{args.model_name}_xnnpack_{quant_tag}" - save_pte_program(exec_prog.buffer, model_name) + save_pte_program(exec_prog.buffer, model_name, args.output_dir)