From 395c5f631b86e7b977b15b381797f96f245f4541 Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Wed, 15 May 2024 14:07:13 +0200 Subject: [PATCH] Fix Export docstring in CLI (#2058) * fix docs Signed-off-by: Ashwin Vaidya * fix docs Signed-off-by: Ashwin Vaidya * Update src/anomalib/engine/engine.py Co-authored-by: Samet Akcay * lowercase docstring Signed-off-by: Ashwin Vaidya * fix docstring Signed-off-by: Ashwin Vaidya --------- Signed-off-by: Ashwin Vaidya Co-authored-by: Samet Akcay --- src/anomalib/cli/utils/help_formatter.py | 2 ++ src/anomalib/engine/engine.py | 20 ++++++++++++-------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/anomalib/cli/utils/help_formatter.py b/src/anomalib/cli/utils/help_formatter.py index ea4ef825b6..3b6c89b6a2 100644 --- a/src/anomalib/cli/utils/help_formatter.py +++ b/src/anomalib/cli/utils/help_formatter.py @@ -20,6 +20,7 @@ "validate": {"model", "model.help", "data", "data.help", "ckpt_path", "config"}, "test": {"model", "model.help", "data", "data.help", "ckpt_path", "config"}, "predict": {"model", "model.help", "data", "data.help", "ckpt_path", "config"}, + "export": {"model", "model.help", "export_type", "ckpt_path", "config"}, } try: @@ -31,6 +32,7 @@ "validate": Engine.validate, "test": Engine.test, "predict": Engine.predict, + "export": Engine.export, } except ImportError: print("To use other subcommand using `anomalib install`") diff --git a/src/anomalib/engine/engine.py b/src/anomalib/engine/engine.py index 8e7e679650..6dbaa15a10 100644 --- a/src/anomalib/engine/engine.py +++ b/src/anomalib/engine/engine.py @@ -865,14 +865,14 @@ def train( def export( self, model: AnomalyModule, - export_type: ExportType, + export_type: ExportType | str, export_root: str | Path | None = None, input_size: tuple[int, int] | None = None, transform: Transform | None = None, ov_args: dict[str, Any] | None = None, ckpt_path: str | Path | None = None, ) -> Path | None: - """Export the model in PyTorch, ONNX or OpenVINO format. + r"""Export the model in PyTorch, ONNX or OpenVINO format. Args: model (AnomalyModule): Trained model. @@ -882,7 +882,8 @@ def export( input_size (tuple[int, int] | None, optional): A statis input shape for the model, which is exported to ONNX and OpenVINO format. Defaults to None. transform (Transform | None, optional): Input transform to include in the exported model. If not provided, - the engine will try to use the transform from the datamodule or dataset. Defaults to None. + the engine will try to use the default transform from the model. + Defaults to ``None``. ov_args (dict[str, Any] | None, optional): This is optional and used only for OpenVINO's model optimizer. Defaults to None. ckpt_path (str | Path | None): Checkpoint path. If provided, the model will be loaded from this path. @@ -897,22 +898,25 @@ def export( CLI Usage: 1. To export as a torch ``.pt`` file you can run the following command. ```python - anomalib export --model Padim --export_mode TORCH --data MVTec + anomalib export --model Padim --export_mode torch --ckpt_path ``` 2. To export as an ONNX ``.onnx`` file you can run the following command. ```python - anomalib export --model Padim --export_mode ONNX --data Visa --input_size "[256,256]" + anomalib export --model Padim --export_mode onnx --ckpt_path \ + --input_size "[256,256]" ``` 3. To export as an OpenVINO ``.xml`` and ``.bin`` file you can run the following command. ```python - anomalib export --model Padim --export_mode OPENVINO --data Visa --input_size "[256,256]" + anomalib export --model Padim --export_mode openvino --ckpt_path \ + --input_size "[256,256]" ``` 4. You can also overrride OpenVINO model optimizer by adding the ``--ov_args.`` arguments. ```python - anomalib export --model Padim --export_mode OPENVINO --data Visa --input_size "[256,256]" \ - --ov_args.compress_to_fp16 False + anomalib export --model Padim --export_mode openvino --ckpt_path \ + --input_size "[256,256]" --ov_args.compress_to_fp16 False ``` """ + export_type = ExportType(export_type) self._setup_trainer(model) if ckpt_path: ckpt_path = Path(ckpt_path).resolve()