Update anomaly XPU integration (#2697)
* Update anomaly XPU integration
* Update strategy and accelerator
* Cleanup in strategy
* Fix mypy
* Remove XPU callback
Showing 8 changed files with 143 additions and 40 deletions.
src/otx/algorithms/anomaly/adapters/anomalib/accelerators/__init__.py (new file: 8 additions, 0 deletions)

"""Lightning accelerator for XPU device.""" | ||
# Copyright (C) 2023 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
|
||
from .xpu import XPUAccelerator | ||
|
||
__all__ = ["XPUAccelerator"] |
src/otx/algorithms/anomaly/adapters/anomalib/accelerators/xpu.py (new file: 60 additions, 0 deletions)

"""Lightning accelerator for XPU device.""" | ||
# Copyright (C) 2023 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
|
||
from typing import Any, Dict, Union | ||
|
||
import torch | ||
from pytorch_lightning.accelerators import AcceleratorRegistry | ||
from pytorch_lightning.accelerators.accelerator import Accelerator | ||
|
||
from otx.algorithms.common.utils.utils import is_xpu_available | ||
|
||
|
||
class XPUAccelerator(Accelerator): | ||
"""Support for a XPU, optimized for large-scale machine learning.""" | ||
|
||
accelerator_name = "xpu" | ||
|
||
def setup_device(self, device: torch.device) -> None: | ||
"""Sets up the specified device.""" | ||
if device.type != "xpu": | ||
raise RuntimeError(f"Device should be xpu, got {device} instead") | ||
|
||
torch.xpu.set_device(device) | ||
|
||
@staticmethod | ||
def parse_devices(devices: Any) -> Any: | ||
"""Parses devices for multi-GPU training.""" | ||
if isinstance(devices, list): | ||
return devices | ||
return [devices] | ||
|
||
@staticmethod | ||
def get_parallel_devices(devices: Any) -> Any: | ||
"""Generates a list of parrallel devices.""" | ||
return [torch.device("xpu", idx) for idx in devices] | ||
|
||
@staticmethod | ||
def auto_device_count() -> int: | ||
"""Returns number of XPU devices available.""" | ||
return torch.xpu.device_count() | ||
|
||
@staticmethod | ||
def is_available() -> bool: | ||
"""Checks if XPU available.""" | ||
return is_xpu_available() | ||
|
||
def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]: | ||
"""Returns XPU devices stats.""" | ||
return {} | ||
|
||
def teardown(self) -> None: | ||
"""Cleans-up XPU-related resources.""" | ||
pass | ||
|
||
|
||
AcceleratorRegistry.register( | ||
XPUAccelerator.accelerator_name, XPUAccelerator, description="Accelerator supports XPU devices" | ||
) |
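
Since importing the module runs the AcceleratorRegistry.register call above, the accelerator becomes addressable by its registered name. A minimal usage sketch, not part of the commit, assuming an XPU-enabled PyTorch build (for example via Intel's intel_extension_for_pytorch) so that torch.xpu exists:

import pytorch_lightning as pl

# Importing the package triggers AcceleratorRegistry.register(...),
# after which the name "xpu" resolves to XPUAccelerator.
import otx.algorithms.anomaly.adapters.anomalib.accelerators  # noqa: F401

# On a machine without an XPU, accelerator selection fails because
# XPUAccelerator.is_available() returns False.
trainer = pl.Trainer(accelerator="xpu", devices=1)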
src/otx/algorithms/anomaly/adapters/anomalib/callbacks/xpu.py (deleted: 36 deletions)

This file was deleted; its contents are not rendered in the diff.
src/otx/algorithms/anomaly/adapters/anomalib/strategies/__init__.py (new file: 8 additions, 0 deletions)

"""Lightning strategy for single XPU device.""" | ||
# Copyright (C) 2023 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
|
||
from .xpu_single import SingleXPUStrategy | ||
|
||
__all__ = ["SingleXPUStrategy"] |
src/otx/algorithms/anomaly/adapters/anomalib/strategies/xpu_single.py (new file: 60 additions, 0 deletions)

"""Lightning strategy for single XPU devic.""" | ||
# Copyright (C) 2023 Intel Corporation | ||
# SPDX-License-Identifier: Apache-2.0 | ||
# | ||
|
||
from typing import Optional | ||
|
||
import pytorch_lightning as pl | ||
import torch | ||
from lightning_fabric.plugins import CheckpointIO | ||
from lightning_fabric.utilities.types import _DEVICE | ||
from pytorch_lightning.plugins.precision import PrecisionPlugin | ||
from pytorch_lightning.strategies import StrategyRegistry | ||
from pytorch_lightning.strategies.single_device import SingleDeviceStrategy | ||
from pytorch_lightning.utilities.exceptions import MisconfigurationException | ||
|
||
from otx.algorithms.common.utils.utils import is_xpu_available | ||
|
||
|
||
class SingleXPUStrategy(SingleDeviceStrategy): | ||
"""Strategy for training on single XPU device.""" | ||
|
||
strategy_name = "xpu_single" | ||
|
||
def __init__( | ||
self, | ||
device: _DEVICE = "xpu:0", | ||
accelerator: Optional["pl.accelerators.Accelerator"] = None, | ||
checkpoint_io: Optional[CheckpointIO] = None, | ||
precision_plugin: Optional[PrecisionPlugin] = None, | ||
): | ||
|
||
if not is_xpu_available(): | ||
raise MisconfigurationException("`SingleXPUStrategy` requires XPU devices to run") | ||
|
||
super().__init__( | ||
accelerator=accelerator, | ||
device=device, | ||
checkpoint_io=checkpoint_io, | ||
precision_plugin=precision_plugin, | ||
) | ||
|
||
@property | ||
def is_distributed(self) -> bool: | ||
"""Returns true if the strategy supports distributed training.""" | ||
return False | ||
|
||
def setup_optimizers(self, trainer: "pl.Trainer") -> None: | ||
"""Sets up optimizers.""" | ||
super().setup_optimizers(trainer) | ||
if len(self.optimizers) != 1: # type: ignore | ||
raise RuntimeError("XPU strategy doesn't support multiple optimizers") | ||
model, optimizer = torch.xpu.optimize(trainer.model, optimizer=self.optimizers[0]) # type: ignore | ||
self.optimizers = [optimizer] | ||
trainer.model = model | ||
|
||
|
||
StrategyRegistry.register( | ||
SingleXPUStrategy.strategy_name, SingleXPUStrategy, description="Strategy that enables training on single XPU" | ||
) |
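
Taken together, the two new registry entries let the strategy and accelerator be requested by name. A hypothetical wiring sketch, not taken from the commit, under the same XPU-build assumption as above:

import pytorch_lightning as pl

# Importing both packages runs their registry registrations.
import otx.algorithms.anomaly.adapters.anomalib.accelerators  # noqa: F401
import otx.algorithms.anomaly.adapters.anomalib.strategies  # noqa: F401

# SingleXPUStrategy defaults to device "xpu:0" and raises
# MisconfigurationException in __init__ when no XPU is available.
trainer = pl.Trainer(accelerator="xpu", strategy="xpu_single", devices=1)

Note that the torch.xpu.optimize call in setup_optimizers is provided by Intel's PyTorch extension rather than stock PyTorch; it returns an optimized (model, optimizer) pair, which is why the strategy restricts training to a single optimizer.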