Add XPU mixed precision plugin for lightning #2714

Merged 9 commits on Dec 13, 2023
7 changes: 7 additions & 0 deletions src/otx/algorithms/anomaly/adapters/anomalib/plugins/__init__.py
@@ -0,0 +1,7 @@
"""Plugin for mixed-precision training on XPU."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .xpu_precision import MixedPrecisionXPUPlugin

__all__ = ["MixedPrecisionXPUPlugin"]
109 changes: 109 additions & 0 deletions src/otx/algorithms/anomaly/adapters/anomalib/plugins/xpu_precision.py
@@ -0,0 +1,109 @@
"""Plugin for mixed-precision training on XPU."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, Optional, Union

import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.types import Optimizable
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import Tensor
from torch.optim import LBFGS, Optimizer


class MixedPrecisionXPUPlugin(PrecisionPlugin):
"""Plugin for Automatic Mixed Precision (AMP) training with ``torch.xpu.autocast``.

Args:
scaler: An optional :class:`torch.cuda.amp.GradScaler` to use.
"""

def __init__(self, scaler: Optional[Any] = None) -> None:
self.scaler = scaler

def pre_backward(self, tensor: Tensor, module: "pl.LightningModule") -> Tensor:
"""Apply grad scaler before backward."""
if self.scaler is not None:
tensor = self.scaler.scale(tensor)
return super().pre_backward(tensor, module)

def optimizer_step( # type: ignore[override]
self,
optimizer: Optimizable,
model: "pl.LightningModule",
optimizer_idx: int,
closure: Callable[[], Any],
**kwargs: Any,
) -> Any:
"""Make an optimizer step using scaler if it was passed."""
if self.scaler is None:
# skip scaler logic, as bfloat16 does not require scaler
return super().optimizer_step(
optimizer, model=model, optimizer_idx=optimizer_idx, closure=closure, **kwargs
)
if isinstance(optimizer, LBFGS):
raise MisconfigurationException(
f"Native AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
)
closure_result = closure()

if not _optimizer_handles_unscaling(optimizer):
# Unscaling needs to be performed here in case we are going to apply gradient clipping.
# Optimizers that perform unscaling in their `.step()` method are not supported (e.g., fused Adam).
# Note: `unscale` happens after the closure is executed, but before the `on_before_optimizer_step` hook.
self.scaler.unscale_(optimizer)

self._after_closure(model, optimizer, optimizer_idx)
skipped_backward = closure_result is None
# in manual optimization, the closure does not return a value
if not model.automatic_optimization or not skipped_backward:
# note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
step_output = self.scaler.step(optimizer, **kwargs)
self.scaler.update()
return step_output
return closure_result

def clip_gradients(
self,
optimizer: Optimizer,
clip_val: Union[int, float] = 0.0,
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
) -> None:
"""Handle grad clipping with scaler."""
if clip_val > 0 and _optimizer_handles_unscaling(optimizer):
raise RuntimeError(
f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient clipping"
" because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?"
)
super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm)

@contextmanager
def forward_context(self) -> Generator[None, None, None]:
"""Enable autocast context."""
with torch.xpu.autocast(True):
yield

def state_dict(self) -> Dict[str, Any]:
"""Returns state dict of the plugin."""
if self.scaler is not None:
return self.scaler.state_dict()
return {}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""Loads state dict to the plugin."""
if self.scaler is not None:
self.scaler.load_state_dict(state_dict)


def _optimizer_handles_unscaling(optimizer: Any) -> bool:
"""Determines if a PyTorch optimizer handles unscaling gradients in the step method ratherthan through the scaler.

    Since the current implementation of this function checks a PyTorch internal variable on the optimizer, the return
value will only be reliable for built-in PyTorch optimizers.
"""
return getattr(optimizer, "_step_supports_amp_scaling", False)
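
For reference, a minimal usage sketch of the new plugin outside the OTX task code. This is illustrative only: it assumes an Intel Extension for PyTorch (IPEX) build where torch.xpu and torch.xpu.amp.GradScaler are available, and that the "xpu" accelerator and "xpu_single" strategy used below are registered by OTX as in the train-task change further down; passing scaler=None keeps the scaler-free path (e.g. for bfloat16).

# Illustrative usage sketch (not part of this PR).
# Assumes an IPEX/XPU build exposing torch.xpu and torch.xpu.amp.GradScaler,
# and that the "xpu" accelerator / "xpu_single" strategy are registered.
import pytorch_lightning as pl
import torch

from otx.algorithms.anomaly.adapters.anomalib.plugins import MixedPrecisionXPUPlugin

# A scaler enables fp16 loss scaling; scaler=None (the default) skips it, e.g. for bf16.
scaler = torch.xpu.amp.GradScaler() if hasattr(torch, "xpu") else None
trainer = pl.Trainer(
    accelerator="xpu",
    strategy="xpu_single",
    plugins=[MixedPrecisionXPUPlugin(scaler=scaler)],
    max_epochs=1,
)
# trainer.fit(model, datamodule=datamodule)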
10 changes: 9 additions & 1 deletion src/otx/algorithms/anomaly/tasks/train.py
@@ -29,6 +29,7 @@

from otx.algorithms.anomaly.adapters.anomalib.callbacks import ProgressCallback
from otx.algorithms.anomaly.adapters.anomalib.data import OTXAnomalyDataModule
from otx.algorithms.anomaly.adapters.anomalib.plugins.xpu_precision import MixedPrecisionXPUPlugin
from otx.algorithms.common.utils.utils import is_xpu_available
from otx.api.entities.datasets import DatasetEntity
from otx.api.entities.model import ModelEntity
@@ -89,11 +90,18 @@ def train(
),
]

plugins = []
if config.trainer.plugins is not None:
plugins.extend(config.trainer.plugins)
config.trainer.pop("plugins")

if is_xpu_available():
config.trainer.strategy = "xpu_single"
config.trainer.accelerator = "xpu"
if config.trainer.precision == 16:
plugins.append(MixedPrecisionXPUPlugin())

self.trainer = Trainer(**config.trainer, logger=False, callbacks=callbacks)
self.trainer = Trainer(**config.trainer, logger=False, callbacks=callbacks, plugins=plugins)
self.trainer.fit(model=self.model, datamodule=datamodule)

self.save_model(output_model)
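
Taken together, the train-task change above amounts to the following plugin-selection logic, shown here as a standalone sketch (build_trainer_plugins is a hypothetical helper; the real code mutates config.trainer in place exactly as in the diff):

# Hypothetical helper summarizing the plugin selection added to train.py (illustration only).
from otx.algorithms.anomaly.adapters.anomalib.plugins.xpu_precision import MixedPrecisionXPUPlugin


def build_trainer_plugins(config, xpu_available: bool) -> list:
    plugins = []
    if config.trainer.plugins is not None:
        plugins.extend(config.trainer.plugins)  # keep any plugins supplied via the config
    config.trainer.pop("plugins")  # plugins are now passed to Trainer explicitly
    if xpu_available:
        config.trainer.strategy = "xpu_single"
        config.trainer.accelerator = "xpu"
        if config.trainer.precision == 16:
            # fp16 on XPU is routed through the custom autocast plugin
            plugins.append(MixedPrecisionXPUPlugin())
    return plugins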