Add XPU mixed precision plugin for lightning (#2714)

* Update anomaly XPU integration
* Update strategy and accelerator
* Cleanup in strategy
* Fix mypy
* remove XPU callback
* Add XPU mixed precision lightning training
* Fix linters
* Handle default plugins value
Showing 3 changed files with 125 additions and 1 deletion.
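Since MixedPrecisionXPUPlugin is a Lightning precision plugin, the natural way to consume it is to hand an instance to the Trainer; the matching XPU strategy and accelerator changes mentioned in the commit message live in files outside this diff. Below is a minimal sketch, assuming the plugin is passed through the Trainer's plugins argument and that XPU device setup happens elsewhere; the no-scaler (bf16-style) construction is an illustration, not something taken from this commit.

# Hedged usage sketch (not part of this commit): pass the plugin to a Trainer.
import pytorch_lightning as pl

from otx.algorithms.anomaly.adapters.anomalib.plugins import MixedPrecisionXPUPlugin

# scaler=None keeps the bf16 path: autocast only, no GradScaler bookkeeping.
trainer = pl.Trainer(
    plugins=[MixedPrecisionXPUPlugin()],
    max_epochs=1,
    # The XPU accelerator/strategy referenced in the commit message are
    # configured in other files and are omitted here.
)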
src/otx/algorithms/anomaly/adapters/anomalib/plugins/__init__.py (7 additions, 0 deletions)
@@ -0,0 +1,7 @@
"""Plugin for mixed-precision training on XPU.""" | ||
# Copyright (C) 2023 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from .xpu_precision import MixedPrecisionXPUPlugin | ||
|
||
__all__ = ["MixedPrecisionXPUPlugin"] |
src/otx/algorithms/anomaly/adapters/anomalib/plugins/xpu_precision.py (109 additions, 0 deletions)
@@ -0,0 +1,109 @@
"""Plugin for mixed-precision training on XPU.""" | ||
# Copyright (C) 2023 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
from contextlib import contextmanager | ||
from typing import Any, Callable, Dict, Generator, Optional, Union | ||
|
||
import pytorch_lightning as pl | ||
import torch | ||
from lightning_fabric.utilities.types import Optimizable | ||
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin | ||
from pytorch_lightning.utilities import GradClipAlgorithmType | ||
from pytorch_lightning.utilities.exceptions import MisconfigurationException | ||
from torch import Tensor | ||
from torch.optim import LBFGS, Optimizer | ||
|
||
|
||
class MixedPrecisionXPUPlugin(PrecisionPlugin): | ||
"""Plugin for Automatic Mixed Precision (AMP) training with ``torch.xpu.autocast``. | ||
Args: | ||
scaler: An optional :class:`torch.cuda.amp.GradScaler` to use. | ||
""" | ||
|
||
def __init__(self, scaler: Optional[Any] = None) -> None: | ||
self.scaler = scaler | ||
|
||
def pre_backward(self, tensor: Tensor, module: "pl.LightningModule") -> Tensor: | ||
"""Apply grad scaler before backward.""" | ||
if self.scaler is not None: | ||
tensor = self.scaler.scale(tensor) | ||
return super().pre_backward(tensor, module) | ||
|
||
def optimizer_step( # type: ignore[override] | ||
self, | ||
optimizer: Optimizable, | ||
model: "pl.LightningModule", | ||
optimizer_idx: int, | ||
closure: Callable[[], Any], | ||
**kwargs: Any, | ||
) -> Any: | ||
"""Make an optimizer step using scaler if it was passed.""" | ||
if self.scaler is None: | ||
# skip scaler logic, as bfloat16 does not require scaler | ||
return super().optimizer_step( | ||
optimizer, model=model, optimizer_idx=optimizer_idx, closure=closure, **kwargs | ||
) | ||
if isinstance(optimizer, LBFGS): | ||
raise MisconfigurationException( | ||
f"Native AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})." | ||
) | ||
closure_result = closure() | ||
|
||
if not _optimizer_handles_unscaling(optimizer): | ||
# Unscaling needs to be performed here in case we are going to apply gradient clipping. | ||
# Optimizers that perform unscaling in their `.step()` method are not supported (e.g., fused Adam). | ||
# Note: `unscale` happens after the closure is executed, but before the `on_before_optimizer_step` hook. | ||
self.scaler.unscale_(optimizer) | ||
|
||
self._after_closure(model, optimizer, optimizer_idx) | ||
skipped_backward = closure_result is None | ||
# in manual optimization, the closure does not return a value | ||
if not model.automatic_optimization or not skipped_backward: | ||
# note: the scaler will skip the `optimizer.step` if nonfinite gradients are found | ||
step_output = self.scaler.step(optimizer, **kwargs) | ||
self.scaler.update() | ||
return step_output | ||
return closure_result | ||
|
||
def clip_gradients( | ||
self, | ||
optimizer: Optimizer, | ||
clip_val: Union[int, float] = 0.0, | ||
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, | ||
) -> None: | ||
"""Handle grad clipping with scaler.""" | ||
if clip_val > 0 and _optimizer_handles_unscaling(optimizer): | ||
raise RuntimeError( | ||
f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient clipping" | ||
" because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?" | ||
) | ||
super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm) | ||
|
||
@contextmanager | ||
def forward_context(self) -> Generator[None, None, None]: | ||
"""Enable autocast context.""" | ||
with torch.xpu.autocast(True): | ||
yield | ||
|
||
def state_dict(self) -> Dict[str, Any]: | ||
"""Returns state dict of the plugin.""" | ||
if self.scaler is not None: | ||
return self.scaler.state_dict() | ||
return {} | ||
|
||
def load_state_dict(self, state_dict: Dict[str, Any]) -> None: | ||
"""Loads state dict to the plugin.""" | ||
if self.scaler is not None: | ||
self.scaler.load_state_dict(state_dict) | ||
|
||
|
||
def _optimizer_handles_unscaling(optimizer: Any) -> bool: | ||
"""Determines if a PyTorch optimizer handles unscaling gradients in the step method ratherthan through the scaler. | ||
Since, the current implementation of this function checks a PyTorch internal variable on the optimizer, the return | ||
value will only be reliable for built-in PyTorch optimizers. | ||
""" | ||
return getattr(optimizer, "_step_supports_amp_scaling", False) |
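To make the two code paths above concrete, here is a small illustration (not part of the diff) of how the scaler argument selects between them and of the PyTorch-internal flag that _optimizer_handles_unscaling inspects; the torch.cuda.amp.GradScaler choice simply follows the class docstring and stands in for whatever scaler an XPU strategy would actually provide.

# Hedged illustration (not from this commit) of the plugin's two modes and of
# the internal optimizer flag checked by _optimizer_handles_unscaling.
import torch
from torch.optim import SGD

from otx.algorithms.anomaly.adapters.anomalib.plugins.xpu_precision import (
    MixedPrecisionXPUPlugin,
    _optimizer_handles_unscaling,
)

bf16_plugin = MixedPrecisionXPUPlugin()  # scaler=None: optimizer_step skips all scaler logic
fp16_plugin = MixedPrecisionXPUPlugin(torch.cuda.amp.GradScaler())  # scaler type per the docstring; placeholder only

model = torch.nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.1)
print(_optimizer_handles_unscaling(optimizer))  # False: the plugin unscales before clipping

optimizer._step_supports_amp_scaling = True  # simulate a fused optimizer's internal flag
print(_optimizer_handles_unscaling(optimizer))  # True: clip_gradients with clip_val > 0 would raise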