From d81fd1c96e3321f18237f0f08e6d776eb0d814cd Mon Sep 17 00:00:00 2001 From: Damian Szwichtenberg Date: Thu, 17 Aug 2023 12:21:48 +0200 Subject: [PATCH] Add support for XPU device in PrefetchLoader --- torch_geometric/loader/prefetch.py | 77 ++++++++++++++++++++++-------- torch_geometric/typing.py | 10 ++++ 2 files changed, 68 insertions(+), 19 deletions(-) diff --git a/torch_geometric/loader/prefetch.py b/torch_geometric/loader/prefetch.py index 3bbfd69c69785..005898141015b 100644 --- a/torch_geometric/loader/prefetch.py +++ b/torch_geometric/loader/prefetch.py @@ -4,6 +4,7 @@ import torch from torch.utils.data import DataLoader +from torch_geometric.typing import WITH_IPEX class PrefetchLoader: @@ -15,42 +16,81 @@ class PrefetchLoader: device (torch.device, optional): The device to load the data to. (default: :obj:`None`) """ + class PrefetchLoaderDevice: + def __init__(self, device: Optional[torch.device] = None): + cuda_present = torch.cuda.is_available() + xpu_present = torch.xpu.is_available() if WITH_IPEX else False + + if device is None: + if cuda_present: + device = 'cuda' + elif xpu_present: + device = 'xpu' + else: + device = 'cpu' + + self.device = torch.device(device) + + if ((self.device.type == 'cuda' and not cuda_present) or + (self.device.type == 'xpu' and not xpu_present)): + print(f'Requested device[{self.device.type}] is not available ' + '- fallback to CPU') + self.device = torch.device('cpu') + + self.is_gpu = self.device.type in ['cuda', 'xpu'] + self.stream = None + self.stream_context = nullcontext + + if self.is_gpu: + gpu_module = torch.cuda if self.device.type == 'cuda' else torch.xpu + else: + gpu_module = None + + self.gpu_module = gpu_module + + def maybe_init_stream(self) -> None: + if self.is_gpu: + self.stream = self.gpu_module.Stream() + self.stream_context = partial(self.gpu_module.stream, stream=self.stream) + + def maybe_wait_stream(self) -> None: + if self.stream is not None: + self.gpu_module.current_stream().wait_stream(self.stream) + + def get_device(self) -> torch.device: + return self.device + + def get_stream_context(self) -> Any: + return self.stream_context() + + def __init__( self, loader: DataLoader, device: Optional[torch.device] = None, ): - if device is None: - device = 'cuda' if torch.cuda.is_available() else 'cpu' - self.loader = loader - self.device = torch.device(device) - - self.is_cuda = torch.cuda.is_available() and self.device.type == 'cuda' + self.device_mgr = self.PrefetchLoaderDevice(device) def non_blocking_transfer(self, batch: Any) -> Any: - if not self.is_cuda: + if not self.device_mgr.is_gpu: return batch if isinstance(batch, (list, tuple)): return [self.non_blocking_transfer(v) for v in batch] if isinstance(batch, dict): return {k: self.non_blocking_transfer(v) for k, v in batch.items()} - batch = batch.pin_memory() - return batch.to(self.device, non_blocking=True) + device = self.device_mgr.get_device() + batch = batch.pin_memory(device) + return batch.to(device, non_blocking=True) def __iter__(self) -> Any: first = True - if self.is_cuda: - stream = torch.cuda.Stream() - stream_context = partial(torch.cuda.stream, stream=stream) - else: - stream = None - stream_context = nullcontext + self.device_mgr.maybe_init_stream() for next_batch in self.loader: - with stream_context(): + with self.device_mgr.get_stream_context(): next_batch = self.non_blocking_transfer(next_batch) if not first: @@ -58,8 +98,7 @@ def __iter__(self) -> Any: else: first = False - if stream is not None: - torch.cuda.current_stream().wait_stream(stream) + self.device_mgr.maybe_wait_stream() batch = next_batch @@ -69,4 +108,4 @@ def __len__(self) -> int: return len(self.loader) def __repr__(self) -> str: - return f'{self.__class__.__name__}({self.loader})' + return f'{self.__class__.__name__}({self.loader})' \ No newline at end of file diff --git a/torch_geometric/typing.py b/torch_geometric/typing.py index 1a6f1a61dbf01..b1714e7529811 100644 --- a/torch_geometric/typing.py +++ b/torch_geometric/typing.py @@ -188,6 +188,16 @@ def masked_select_nnz(src: SparseTensor, mask: Tensor, layout: Optional[str] = None) -> SparseTensor: raise ImportError("'masked_select_nnz' requires 'torch-sparse'") +try: + import intel_extension_for_pytorch # noqa + WITH_IPEX = True +except (ImportError, OSError) as e: + if isinstance(e, OSError): + warnings.warn("An issue occurred while importing" + "'intel-extension-for-pytorch'. " + f"Disabling its usage. Stacktrace: {e}") + WITH_IPEX = False + class MockTorchCSCTensor: def __init__(