Update anomaly XPU integration (#2697)
* Update anomaly XPU integration

* Update strategy and accelerator

* Cleanup in strategy

* Fix mypy

* remove XPU callback
sovrasov authored Dec 8, 2023
1 parent ae090ba commit 0062179
Showing 8 changed files with 143 additions and 40 deletions.
4 changes: 4 additions & 0 deletions src/otx/algorithms/anomaly/adapters/__init__.py
@@ -13,3 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.


from .anomalib.accelerators.xpu import XPUAccelerator # noqa: F401
from .anomalib.strategies import SingleXPUStrategy # noqa: F401
8 changes: 8 additions & 0 deletions src/otx/algorithms/anomaly/adapters/anomalib/accelerators/__init__.py
@@ -0,0 +1,8 @@
"""Lightning accelerator for XPU device."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from .xpu import XPUAccelerator

__all__ = ["XPUAccelerator"]
60 changes: 60 additions & 0 deletions src/otx/algorithms/anomaly/adapters/anomalib/accelerators/xpu.py
@@ -0,0 +1,60 @@
"""Lightning accelerator for XPU device."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from typing import Any, Dict, Union

import torch
from pytorch_lightning.accelerators import AcceleratorRegistry
from pytorch_lightning.accelerators.accelerator import Accelerator

from otx.algorithms.common.utils.utils import is_xpu_available


class XPUAccelerator(Accelerator):
"""Support for a XPU, optimized for large-scale machine learning."""

accelerator_name = "xpu"

def setup_device(self, device: torch.device) -> None:
"""Sets up the specified device."""
if device.type != "xpu":
raise RuntimeError(f"Device should be xpu, got {device} instead")

torch.xpu.set_device(device)

@staticmethod
def parse_devices(devices: Any) -> Any:
"""Parses devices for multi-GPU training."""
if isinstance(devices, list):
return devices
return [devices]

@staticmethod
def get_parallel_devices(devices: Any) -> Any:
"""Generates a list of parrallel devices."""
return [torch.device("xpu", idx) for idx in devices]

@staticmethod
def auto_device_count() -> int:
"""Returns number of XPU devices available."""
return torch.xpu.device_count()

@staticmethod
def is_available() -> bool:
"""Checks if XPU available."""
return is_xpu_available()

def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
"""Returns XPU devices stats."""
return {}

def teardown(self) -> None:
"""Cleans-up XPU-related resources."""
pass


AcceleratorRegistry.register(
XPUAccelerator.accelerator_name, XPUAccelerator, description="Accelerator supports XPU devices"
)
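
For context, registering the class under the name "xpu" is what allows Lightning to resolve the accelerator from a plain string. Below is a minimal usage sketch, not part of this commit; the import path used to trigger registration and the devices/max_epochs values are assumptions for illustration.

# Hedged sketch: picking up the registered XPU accelerator by name.
# Assumes an XPU-enabled PyTorch build so that XPUAccelerator.is_available() returns True.
from pytorch_lightning import Trainer

import otx.algorithms.anomaly.adapters  # noqa: F401  # side effect: registers "xpu"

trainer = Trainer(accelerator="xpu", devices=1, max_epochs=1)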
3 changes: 1 addition & 2 deletions src/otx/algorithms/anomaly/adapters/anomalib/callbacks/__init__.py
@@ -16,6 +16,5 @@

from .inference import AnomalyInferenceCallback
from .progress import ProgressCallback
from .xpu import XPUCallback

__all__ = ["AnomalyInferenceCallback", "ProgressCallback", "XPUCallback"]
__all__ = ["AnomalyInferenceCallback", "ProgressCallback"]
36 changes: 0 additions & 36 deletions src/otx/algorithms/anomaly/adapters/anomalib/callbacks/xpu.py

This file was deleted.

8 changes: 8 additions & 0 deletions src/otx/algorithms/anomaly/adapters/anomalib/strategies/__init__.py
@@ -0,0 +1,8 @@
"""Lightning strategy for single XPU device."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from .xpu_single import SingleXPUStrategy

__all__ = ["SingleXPUStrategy"]
60 changes: 60 additions & 0 deletions src/otx/algorithms/anomaly/adapters/anomalib/strategies/xpu_single.py
@@ -0,0 +1,60 @@
"""Lightning strategy for single XPU devic."""
# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from typing import Optional

import pytorch_lightning as pl
import torch
from lightning_fabric.plugins import CheckpointIO
from lightning_fabric.utilities.types import _DEVICE
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies import StrategyRegistry
from pytorch_lightning.strategies.single_device import SingleDeviceStrategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from otx.algorithms.common.utils.utils import is_xpu_available


class SingleXPUStrategy(SingleDeviceStrategy):
"""Strategy for training on single XPU device."""

strategy_name = "xpu_single"

def __init__(
self,
device: _DEVICE = "xpu:0",
accelerator: Optional["pl.accelerators.Accelerator"] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[PrecisionPlugin] = None,
):

if not is_xpu_available():
raise MisconfigurationException("`SingleXPUStrategy` requires XPU devices to run")

super().__init__(
accelerator=accelerator,
device=device,
checkpoint_io=checkpoint_io,
precision_plugin=precision_plugin,
)

@property
def is_distributed(self) -> bool:
"""Returns true if the strategy supports distributed training."""
return False

def setup_optimizers(self, trainer: "pl.Trainer") -> None:
"""Sets up optimizers."""
super().setup_optimizers(trainer)
if len(self.optimizers) != 1: # type: ignore
raise RuntimeError("XPU strategy doesn't support multiple optimizers")
model, optimizer = torch.xpu.optimize(trainer.model, optimizer=self.optimizers[0]) # type: ignore
self.optimizers = [optimizer]
trainer.model = model


StrategyRegistry.register(
SingleXPUStrategy.strategy_name, SingleXPUStrategy, description="Strategy that enables training on single XPU"
)
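
Both new modules import is_xpu_available() from otx.algorithms.common.utils.utils, which is outside this diff. A plausible minimal version is sketched below purely as an assumption about its behavior, not the actual OTX helper.

import torch


def is_xpu_available() -> bool:
    # Assumption: an XPU build of PyTorch exposes a torch.xpu module whose
    # is_available() reports whether any XPU device is present.
    return hasattr(torch, "xpu") and torch.xpu.is_available()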
4 changes: 2 additions & 2 deletions src/otx/algorithms/anomaly/tasks/train.py
@@ -28,7 +28,6 @@
from pytorch_lightning import Trainer, seed_everything

from otx.algorithms.anomaly.adapters.anomalib.callbacks import ProgressCallback
from otx.algorithms.anomaly.adapters.anomalib.callbacks.xpu import XPUCallback
from otx.algorithms.anomaly.adapters.anomalib.data import OTXAnomalyDataModule
from otx.algorithms.common.utils.utils import is_xpu_available
from otx.api.entities.datasets import DatasetEntity
@@ -91,7 +90,8 @@ def train(
        ]

        if is_xpu_available():
            callbacks.append(XPUCallback())
            config.trainer.strategy = "xpu_single"
            config.trainer.accelerator = "xpu"

        self.trainer = Trainer(**config.trainer, logger=False, callbacks=callbacks)
        self.trainer.fit(model=self.model, datamodule=datamodule)
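
With this change, XPU support is driven by the trainer configuration rather than a callback. On an XPU machine the resulting Trainer construction is conceptually equivalent to the sketch below (illustrative only; config.trainer normally carries additional options from the model config).

# Sketch of the effective call after the config override above; assumes no other
# config.trainer keys conflict with the two values set in the if-branch.
trainer = Trainer(
    strategy="xpu_single",  # resolved to SingleXPUStrategy via StrategyRegistry
    accelerator="xpu",      # resolved to XPUAccelerator via AcceleratorRegistry
    logger=False,
    callbacks=callbacks,
)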
