Add XPU mixed precision plugin for lightning (#2714)
* Update anomaly XPU integration

* Update strategy and accelerator

* Cleanup in strategy

* Fix mypy

* remove XPU callback

* Add XPU mixed precision lightning training

* Fix linters

* Handle default plugins value
sovrasov authored Dec 13, 2023
1 parent 0062179 commit 6da38a3
Showing 3 changed files with 125 additions and 1 deletion.
7 changes: 7 additions & 0 deletions src/otx/algorithms/anomaly/adapters/anomalib/plugins/__init__.py
@@ -0,0 +1,7 @@
"""Plugin for mixed-precision training on XPU."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .xpu_precision import MixedPrecisionXPUPlugin

__all__ = ["MixedPrecisionXPUPlugin"]
109 changes: 109 additions & 0 deletions src/otx/algorithms/anomaly/adapters/anomalib/plugins/xpu_precision.py
@@ -0,0 +1,109 @@
"""Plugin for mixed-precision training on XPU."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, Optional, Union

import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.types import Optimizable
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import Tensor
from torch.optim import LBFGS, Optimizer


class MixedPrecisionXPUPlugin(PrecisionPlugin):
    """Plugin for Automatic Mixed Precision (AMP) training with ``torch.xpu.autocast``.

    Args:
        scaler: An optional :class:`torch.cuda.amp.GradScaler` to use.
    """

    def __init__(self, scaler: Optional[Any] = None) -> None:
        super().__init__()
        self.scaler = scaler

    def pre_backward(self, tensor: Tensor, module: "pl.LightningModule") -> Tensor:
        """Apply grad scaler before backward."""
        if self.scaler is not None:
            tensor = self.scaler.scale(tensor)
        return super().pre_backward(tensor, module)

    def optimizer_step(  # type: ignore[override]
        self,
        optimizer: Optimizable,
        model: "pl.LightningModule",
        optimizer_idx: int,
        closure: Callable[[], Any],
        **kwargs: Any,
    ) -> Any:
        """Make an optimizer step using scaler if it was passed."""
        if self.scaler is None:
            # skip scaler logic, as bfloat16 does not require scaler
            return super().optimizer_step(
                optimizer, model=model, optimizer_idx=optimizer_idx, closure=closure, **kwargs
            )
        if isinstance(optimizer, LBFGS):
            raise MisconfigurationException(
                f"Native AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
            )
        closure_result = closure()

        if not _optimizer_handles_unscaling(optimizer):
            # Unscaling needs to be performed here in case we are going to apply gradient clipping.
            # Optimizers that perform unscaling in their `.step()` method are not supported (e.g., fused Adam).
            # Note: `unscale` happens after the closure is executed, but before the `on_before_optimizer_step` hook.
            self.scaler.unscale_(optimizer)

        self._after_closure(model, optimizer, optimizer_idx)
        skipped_backward = closure_result is None
        # in manual optimization, the closure does not return a value
        if not model.automatic_optimization or not skipped_backward:
            # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
            step_output = self.scaler.step(optimizer, **kwargs)
            self.scaler.update()
            return step_output
        return closure_result

    def clip_gradients(
        self,
        optimizer: Optimizer,
        clip_val: Union[int, float] = 0.0,
        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
    ) -> None:
        """Handle grad clipping with scaler."""
        if clip_val > 0 and _optimizer_handles_unscaling(optimizer):
            raise RuntimeError(
                f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient clipping"
                " because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?"
            )
        super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm)

    @contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        """Enable autocast context."""
        with torch.xpu.autocast(True):
            yield

    def state_dict(self) -> Dict[str, Any]:
        """Returns state dict of the plugin."""
        if self.scaler is not None:
            return self.scaler.state_dict()
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Loads state dict to the plugin."""
        if self.scaler is not None:
            self.scaler.load_state_dict(state_dict)


def _optimizer_handles_unscaling(optimizer: Any) -> bool:
    """Determines if a PyTorch optimizer handles unscaling gradients in the step method rather than through the scaler.

    Since the current implementation of this function checks a PyTorch internal variable on the optimizer, the return
    value will only be reliable for built-in PyTorch optimizers.
    """
    return getattr(optimizer, "_step_supports_amp_scaling", False)
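
Not part of the diff: a minimal sketch of driving the plugin above in isolation. It assumes an IPEX/XPU-enabled PyTorch build; `torch.xpu.amp.GradScaler` is an assumption about that API and is only needed for the fp16-style path, while the bf16-style path simply omits the scaler.

# Minimal usage sketch (not from this commit); assumes an IPEX/XPU-enabled build.
import torch

from otx.algorithms.anomaly.adapters.anomalib.plugins import MixedPrecisionXPUPlugin

# bf16-style usage: no scaler, so optimizer_step() falls back to the parent plugin.
bf16_plugin = MixedPrecisionXPUPlugin()

# fp16-style usage: pre_backward() scales the loss, and optimizer_step()
# unscales, steps and updates through the scaler.
# torch.xpu.amp.GradScaler is an assumption about the IPEX API.
fp16_plugin = MixedPrecisionXPUPlugin(scaler=torch.xpu.amp.GradScaler())

# forward_context() wraps the forward pass in torch.xpu.autocast.
with fp16_plugin.forward_context():
    pass  # the model forward would run here under mixed precision

# Checkpointing: the plugin state is just the scaler state (an empty dict without one).
fp16_plugin.load_state_dict(fp16_plugin.state_dict())
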
10 changes: 9 additions & 1 deletion src/otx/algorithms/anomaly/tasks/train.py
@@ -29,6 +29,7 @@

from otx.algorithms.anomaly.adapters.anomalib.callbacks import ProgressCallback
from otx.algorithms.anomaly.adapters.anomalib.data import OTXAnomalyDataModule
from otx.algorithms.anomaly.adapters.anomalib.plugins.xpu_precision import MixedPrecisionXPUPlugin
from otx.algorithms.common.utils.utils import is_xpu_available
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.model import ModelEntity
@@ -89,11 +90,18 @@ def train(
            ),
        ]

        plugins = []
        if config.trainer.plugins is not None:
            plugins.extend(config.trainer.plugins)
            config.trainer.pop("plugins")

        if is_xpu_available():
            config.trainer.strategy = "xpu_single"
            config.trainer.accelerator = "xpu"
            if config.trainer.precision == 16:
                plugins.append(MixedPrecisionXPUPlugin())

-        self.trainer = Trainer(**config.trainer, logger=False, callbacks=callbacks)
+        self.trainer = Trainer(**config.trainer, logger=False, callbacks=callbacks, plugins=plugins)
        self.trainer.fit(model=self.model, datamodule=datamodule)

        self.save_model(output_model)
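
For context (not part of the diff): a standalone sketch of how the wiring above might look outside OTX's config machinery. The "xpu" accelerator and "xpu_single" strategy names assume the custom accelerator and strategy registered elsewhere in this integration, and `precision` stands in for `config.trainer.precision`.

# Minimal sketch mirroring the plugin selection in train.py above; assumes an XPU
# device and that OTX's XPU accelerator/strategy are registered.
from pytorch_lightning import Trainer

from otx.algorithms.anomaly.adapters.anomalib.plugins import MixedPrecisionXPUPlugin
from otx.algorithms.common.utils.utils import is_xpu_available

precision = 16  # stands in for config.trainer.precision
trainer_kwargs: dict = {"logger": False}
plugins: list = []

if is_xpu_available():
    trainer_kwargs["accelerator"] = "xpu"
    trainer_kwargs["strategy"] = "xpu_single"
    if precision == 16:
        # fp16 on XPU goes through the custom plugin rather than the built-in CUDA AMP plugin.
        plugins.append(MixedPrecisionXPUPlugin())

trainer = Trainer(**trainer_kwargs, plugins=plugins)
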
