From 6da38a30fad08ad301f2f58c079eedb8a82f07ed Mon Sep 17 00:00:00 2001
From: Vladislav Sovrasov
Date: Wed, 13 Dec 2023 11:54:52 +0100
Subject: [PATCH] Add XPU mixed precision plugin for lightning (#2714)

* Update anomaly XPU integration

* Update strategy and accelerator

* Cleanup in strategy

* Fix mypy

* remove XPU callback

* Add XPU mixed precision lightning training

* Fix linters

* Handle default plugins value
---
 .../adapters/anomalib/plugins/__init__.py     |   7 ++
 .../anomalib/plugins/xpu_precision.py         | 109 ++++++++++++++++++
 src/otx/algorithms/anomaly/tasks/train.py     |  10 +-
 3 files changed, 125 insertions(+), 1 deletion(-)
 create mode 100644 src/otx/algorithms/anomaly/adapters/anomalib/plugins/__init__.py
 create mode 100644 src/otx/algorithms/anomaly/adapters/anomalib/plugins/xpu_precision.py

diff --git a/src/otx/algorithms/anomaly/adapters/anomalib/plugins/__init__.py b/src/otx/algorithms/anomaly/adapters/anomalib/plugins/__init__.py
new file mode 100644
index 00000000000..df24d838d85
--- /dev/null
+++ b/src/otx/algorithms/anomaly/adapters/anomalib/plugins/__init__.py
@@ -0,0 +1,7 @@
+"""Plugin for mixed-precision training on XPU."""
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from .xpu_precision import MixedPrecisionXPUPlugin
+
+__all__ = ["MixedPrecisionXPUPlugin"]
diff --git a/src/otx/algorithms/anomaly/adapters/anomalib/plugins/xpu_precision.py b/src/otx/algorithms/anomaly/adapters/anomalib/plugins/xpu_precision.py
new file mode 100644
index 00000000000..bfd9f5d3b93
--- /dev/null
+++ b/src/otx/algorithms/anomaly/adapters/anomalib/plugins/xpu_precision.py
@@ -0,0 +1,109 @@
+"""Plugin for mixed-precision training on XPU."""
+# Copyright (C) 2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+
+from contextlib import contextmanager
+from typing import Any, Callable, Dict, Generator, Optional, Union
+
+import pytorch_lightning as pl
+import torch
+from lightning_fabric.utilities.types import Optimizable
+from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
+from pytorch_lightning.utilities import GradClipAlgorithmType
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from torch import Tensor
+from torch.optim import LBFGS, Optimizer
+
+
+class MixedPrecisionXPUPlugin(PrecisionPlugin):
+    """Plugin for Automatic Mixed Precision (AMP) training with ``torch.xpu.autocast``.
+
+    Args:
+        scaler: An optional :class:`torch.cuda.amp.GradScaler` to use.
+    """
+
+    def __init__(self, scaler: Optional[Any] = None) -> None:
+        self.scaler = scaler
+
+    def pre_backward(self, tensor: Tensor, module: "pl.LightningModule") -> Tensor:
+        """Apply grad scaler before backward."""
+        if self.scaler is not None:
+            tensor = self.scaler.scale(tensor)
+        return super().pre_backward(tensor, module)
+
+    def optimizer_step(  # type: ignore[override]
+        self,
+        optimizer: Optimizable,
+        model: "pl.LightningModule",
+        optimizer_idx: int,
+        closure: Callable[[], Any],
+        **kwargs: Any,
+    ) -> Any:
+        """Make an optimizer step using scaler if it was passed."""
+        if self.scaler is None:
+            # skip scaler logic, as bfloat16 does not require scaler
+            return super().optimizer_step(
+                optimizer, model=model, optimizer_idx=optimizer_idx, closure=closure, **kwargs
+            )
+        if isinstance(optimizer, LBFGS):
+            raise MisconfigurationException(
+                f"Native AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
+            )
+        closure_result = closure()
+
+        if not _optimizer_handles_unscaling(optimizer):
+            # Unscaling needs to be performed here in case we are going to apply gradient clipping.
+            # Optimizers that perform unscaling in their `.step()` method are not supported (e.g., fused Adam).
+            # Note: `unscale` happens after the closure is executed, but before the `on_before_optimizer_step` hook.
+            self.scaler.unscale_(optimizer)
+
+        self._after_closure(model, optimizer, optimizer_idx)
+        skipped_backward = closure_result is None
+        # in manual optimization, the closure does not return a value
+        if not model.automatic_optimization or not skipped_backward:
+            # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
+            step_output = self.scaler.step(optimizer, **kwargs)
+            self.scaler.update()
+            return step_output
+        return closure_result
+
+    def clip_gradients(
+        self,
+        optimizer: Optimizer,
+        clip_val: Union[int, float] = 0.0,
+        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
+    ) -> None:
+        """Handle grad clipping with scaler."""
+        if clip_val > 0 and _optimizer_handles_unscaling(optimizer):
+            raise RuntimeError(
+                f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient clipping"
+                " because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?"
+            )
+        super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm)
+
+    @contextmanager
+    def forward_context(self) -> Generator[None, None, None]:
+        """Enable autocast context."""
+        with torch.xpu.autocast(True):
+            yield
+
+    def state_dict(self) -> Dict[str, Any]:
+        """Returns state dict of the plugin."""
+        if self.scaler is not None:
+            return self.scaler.state_dict()
+        return {}
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        """Loads state dict to the plugin."""
+        if self.scaler is not None:
+            self.scaler.load_state_dict(state_dict)
+
+
+def _optimizer_handles_unscaling(optimizer: Any) -> bool:
+    """Determines if a PyTorch optimizer handles unscaling gradients in the step method rather than through the scaler.
+
+    Since the current implementation of this function checks a PyTorch internal variable on the optimizer, the return
+    value will only be reliable for built-in PyTorch optimizers.
+    """
+    return getattr(optimizer, "_step_supports_amp_scaling", False)
diff --git a/src/otx/algorithms/anomaly/tasks/train.py b/src/otx/algorithms/anomaly/tasks/train.py
index 34d5af57a34..67af58a944e 100644
--- a/src/otx/algorithms/anomaly/tasks/train.py
+++ b/src/otx/algorithms/anomaly/tasks/train.py
@@ -29,6 +29,7 @@
 from otx.algorithms.anomaly.adapters.anomalib.callbacks import ProgressCallback
 from otx.algorithms.anomaly.adapters.anomalib.data import OTXAnomalyDataModule
+from otx.algorithms.anomaly.adapters.anomalib.plugins.xpu_precision import MixedPrecisionXPUPlugin
 from otx.algorithms.common.utils.utils import is_xpu_available
 from otx.api.entities.datasets import DatasetEntity
 from otx.api.entities.model import ModelEntity
@@ -89,11 +90,18 @@ def train(
             ),
         ]
 
+        plugins = []
+        if config.trainer.plugins is not None:
+            plugins.extend(config.trainer.plugins)
+            config.trainer.pop("plugins")
+
         if is_xpu_available():
             config.trainer.strategy = "xpu_single"
             config.trainer.accelerator = "xpu"
+            if config.trainer.precision == 16:
+                plugins.append(MixedPrecisionXPUPlugin())
 
-        self.trainer = Trainer(**config.trainer, logger=False, callbacks=callbacks)
+        self.trainer = Trainer(**config.trainer, logger=False, callbacks=callbacks, plugins=plugins)
         self.trainer.fit(model=self.model, datamodule=datamodule)
 
         self.save_model(output_model)
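
Reviewer note (illustrative only, not part of the patch): the sketch below shows how the new plugin ends up being used by train.py. When an XPU is available and the configured trainer precision is 16, a MixedPrecisionXPUPlugin is constructed without a GradScaler and passed to the Lightning Trainer alongside any plugins already present in the config. The "xpu" accelerator and "xpu_single" strategy names are assumed to be registered by the OTX XPU integration referenced in the commit message.

    import pytorch_lightning as pl

    from otx.algorithms.anomaly.adapters.anomalib.plugins import MixedPrecisionXPUPlugin
    from otx.algorithms.common.utils.utils import is_xpu_available

    plugins = []
    if is_xpu_available():
        # As in train.py, no GradScaler is passed: with scaler=None the plugin skips
        # all scaler logic and only wraps forward passes in torch.xpu.autocast.
        plugins.append(MixedPrecisionXPUPlugin())

    trainer = pl.Trainer(
        accelerator="xpu",      # assumes OTX has registered its XPU accelerator
        strategy="xpu_single",  # assumes OTX has registered its single-XPU strategy
        precision=16,
        plugins=plugins,
        logger=False,
    )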