diff --git a/src/otx/algorithms/anomaly/adapters/anomalib/callbacks/__init__.py b/src/otx/algorithms/anomaly/adapters/anomalib/callbacks/__init__.py index 95822fd7712..85054363f31 100644 --- a/src/otx/algorithms/anomaly/adapters/anomalib/callbacks/__init__.py +++ b/src/otx/algorithms/anomaly/adapters/anomalib/callbacks/__init__.py @@ -16,5 +16,6 @@ from .inference import AnomalyInferenceCallback from .progress import ProgressCallback +from .xpu import XPUCallback -__all__ = ["AnomalyInferenceCallback", "ProgressCallback"] +__all__ = ["AnomalyInferenceCallback", "ProgressCallback", "XPUCallback"] diff --git a/src/otx/algorithms/anomaly/adapters/anomalib/callbacks/xpu.py b/src/otx/algorithms/anomaly/adapters/anomalib/callbacks/xpu.py new file mode 100644 index 00000000000..461696a1528 --- /dev/null +++ b/src/otx/algorithms/anomaly/adapters/anomalib/callbacks/xpu.py @@ -0,0 +1,36 @@ +"""Anomaly XPU device callback.""" +# Copyright (C) 2023 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +import torch +from pytorch_lightning import Callback + + +class XPUCallback(Callback): + """XPU device callback. + + Applies IPEX optimization before training, moves data to XPU. + """ + + def __init__(self, device_idx=0): + self.device = torch.device(f"xpu:{device_idx}") + + def on_fit_start(self, trainer, pl_module): + """Applies IPEX optimization before training.""" + pl_module.to(self.device) + model, optimizer = torch.xpu.optimize(trainer.model, optimizer=trainer.optimizers[0]) + trainer.optimizers = [optimizer] + trainer.model = model + + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): + """Moves train batch tensors to XPU.""" + for k in batch: + if not isinstance(batch[k], list): + batch[k] = batch[k].to(self.device) + + def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + """Moves validation batch tensors to XPU.""" + for k in batch: + if not isinstance(batch[k], list): + batch[k] = batch[k].to(self.device) diff --git a/src/otx/algorithms/anomaly/tasks/train.py b/src/otx/algorithms/anomaly/tasks/train.py index 9e2f57f249f..8016157e2a6 100644 --- a/src/otx/algorithms/anomaly/tasks/train.py +++ b/src/otx/algorithms/anomaly/tasks/train.py @@ -28,7 +28,9 @@ from pytorch_lightning import Trainer, seed_everything from otx.algorithms.anomaly.adapters.anomalib.callbacks import ProgressCallback +from otx.algorithms.anomaly.adapters.anomalib.callbacks.xpu import XPUCallback from otx.algorithms.anomaly.adapters.anomalib.data import OTXAnomalyDataModule +from otx.algorithms.common.utils.utils import is_xpu_available from otx.api.entities.datasets import DatasetEntity from otx.api.entities.model import ModelEntity from otx.api.entities.train_parameters import TrainParameters @@ -88,6 +90,9 @@ def train( ), ] + if is_xpu_available(): + callbacks.append(XPUCallback()) + self.trainer = Trainer(**config.trainer, logger=False, callbacks=callbacks) self.trainer.fit(model=self.model, datamodule=datamodule)