Skip to content

Commit

Permalink
Refine xpu callback
Browse files Browse the repository at this point in the history
  • Loading branch information
sovrasov committed Nov 30, 2023
1 parent 98ffd5d commit 6d84a5d
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions src/otx/algorithms/anomaly/adapters/anomalib/callbacks/xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@ def __init__(self, device_idx=0):

def on_fit_start(self, trainer, pl_module):
"""Applies IPEX optimization before training."""
if is_xpu_available():
pl_module.to(self.device)
model, optimizer = torch.xpu.optimize(trainer.model, optimizer=trainer.optimizers[0], dtype=torch.float32)
trainer.optimizers = [optimizer]
trainer.model = model
pl_module.to(self.device)
model, optimizer = torch.xpu.optimize(trainer.model, optimizer=trainer.optimizers[0])
trainer.optimizers = [optimizer]
trainer.model = model

def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
"""Moves train batch tensors to XPU."""
Expand Down

0 comments on commit 6d84a5d

Please sign in to comment.