
Add support for XPU device in PrefetchLoader #3

Closed · wants to merge 2 commits
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

+- Added support for XPU device in `PrefetchLoader` ([#7918](https://github.com/pyg-team/pytorch_geometric/pull/7918))
- Added support for floating-point slicing in `Dataset`, *e.g.*, `dataset[:0.9]` ([#7915](https://github.com/pyg-team/pytorch_geometric/pull/7915))
- Added nightly GPU tests ([#7895](https://github.com/pyg-team/pytorch_geometric/pull/7895))
- Added the `HalfHop` graph upsampling augmentation ([#7827](https://github.com/pyg-team/pytorch_geometric/pull/7827))
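As a quick illustration of the floating-point slicing entry above (#7915): a float index slices a `Dataset` by fraction rather than by position. A minimal sketch; `FakeDataset` and the 90/10 split are illustrative choices, not part of this PR:

```python
from torch_geometric.datasets import FakeDataset

dataset = FakeDataset(num_graphs=100)

# A float slice selects the corresponding fraction of the dataset:
train_dataset = dataset[:0.9]  # first 90% of the graphs
test_dataset = dataset[0.9:]   # remaining 10%
```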
75 changes: 57 additions & 18 deletions torch_geometric/loader/prefetch.py
@@ -4,6 +4,7 @@

import torch
from torch.utils.data import DataLoader
+from torch_geometric.typing import WITH_IPEX


class PrefetchLoader:
Expand All @@ -15,51 +16,89 @@ class PrefetchLoader:
         device (torch.device, optional): The device to load the data to.
             (default: :obj:`None`)
     """
+    class PrefetchLoaderDevice:
+        def __init__(self, device: Optional[torch.device] = None):
+            cuda_present = torch.cuda.is_available()
+            xpu_present = torch.xpu.is_available() if WITH_IPEX else False
+
+            if device is None:
+                if cuda_present:
+                    device = 'cuda'
+                elif xpu_present:
+                    device = 'xpu'
+                else:
+                    device = 'cpu'
+
+            self.device = torch.device(device)
+
+            if ((self.device.type == 'cuda' and not cuda_present) or
+                    (self.device.type == 'xpu' and not xpu_present)):
+                print(f'Requested device [{self.device.type}] is not '
+                      'available - falling back to CPU')
+                self.device = torch.device('cpu')
+
+            self.is_gpu = self.device.type in ['cuda', 'xpu']
+            self.stream = None
+            self.stream_context = nullcontext
+
+            if self.is_gpu:
+                gpu_module = (torch.cuda if self.device.type == 'cuda'
+                              else torch.xpu)
+            else:
+                gpu_module = None
+
+            self.gpu_module = gpu_module
+
+        def maybe_init_stream(self) -> None:
+            if self.is_gpu:
+                self.stream = self.gpu_module.Stream()
+                self.stream_context = partial(self.gpu_module.stream,
+                                              stream=self.stream)
+
+        def maybe_wait_stream(self) -> None:
+            if self.stream is not None:
+                self.gpu_module.current_stream().wait_stream(self.stream)
+
+        def get_device(self) -> torch.device:
+            return self.device
+
+        def get_stream_context(self) -> Any:
+            return self.stream_context()
+
     def __init__(
         self,
         loader: DataLoader,
         device: Optional[torch.device] = None,
     ):
-        if device is None:
-            device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
         self.loader = loader
-        self.device = torch.device(device)
-
-        self.is_cuda = torch.cuda.is_available() and self.device.type == 'cuda'
+        self.device_mgr = self.PrefetchLoaderDevice(device)

     def non_blocking_transfer(self, batch: Any) -> Any:
-        if not self.is_cuda:
+        if not self.device_mgr.is_gpu:
             return batch
         if isinstance(batch, (list, tuple)):
             return [self.non_blocking_transfer(v) for v in batch]
         if isinstance(batch, dict):
             return {k: self.non_blocking_transfer(v) for k, v in batch.items()}

-        batch = batch.pin_memory()
-        return batch.to(self.device, non_blocking=True)
+        device = self.device_mgr.get_device()
+        batch = batch.pin_memory(device)
+        return batch.to(device, non_blocking=True)

     def __iter__(self) -> Any:
         first = True
-        if self.is_cuda:
-            stream = torch.cuda.Stream()
-            stream_context = partial(torch.cuda.stream, stream=stream)
-        else:
-            stream = None
-            stream_context = nullcontext
+        self.device_mgr.maybe_init_stream()

         for next_batch in self.loader:

-            with stream_context():
+            with self.device_mgr.get_stream_context():
                 next_batch = self.non_blocking_transfer(next_batch)

             if not first:
                 yield batch  # noqa
             else:
                 first = False

-            if stream is not None:
-                torch.cuda.current_stream().wait_stream(stream)
+            self.device_mgr.maybe_wait_stream()

             batch = next_batch

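For context, a minimal usage sketch of `PrefetchLoader` after this change; `FakeDataset`, the batch size, and the explicit `'xpu'` request are illustrative assumptions, not part of the diff:

```python
import torch

from torch_geometric.datasets import FakeDataset
from torch_geometric.loader import DataLoader, PrefetchLoader

loader = DataLoader(FakeDataset(num_graphs=100), batch_size=16)

# With device=None, the loader now prefers 'cuda', then 'xpu' (when
# intel-extension-for-pytorch is importable), and falls back to 'cpu'.
# Explicitly requesting an unavailable device also falls back to CPU.
prefetch_loader = PrefetchLoader(loader, device=torch.device('xpu'))

for batch in prefetch_loader:
    ...  # each batch arrives already transferred to the selected device
```

The side stream set up in `maybe_init_stream` lets the host-to-device copy of the next batch overlap with compute on the current one, which is why `__iter__` yields each batch with a one-iteration delay.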
10 changes: 10 additions & 0 deletions torch_geometric/typing.py
@@ -204,6 +204,16 @@ def masked_select_nnz(src: SparseTensor, mask: Tensor,
                            layout: Optional[str] = None) -> SparseTensor:
         raise ImportError("'masked_select_nnz' requires 'torch-sparse'")

+try:
+    import intel_extension_for_pytorch  # noqa
+    WITH_IPEX = True
+except (ImportError, OSError) as e:
+    if isinstance(e, OSError):
+        warnings.warn("An issue occurred while importing "
+                      "'intel-extension-for-pytorch'. "
+                      f"Disabling its usage. Stacktrace: {e}")
+    WITH_IPEX = False


class MockTorchCSCTensor:
     def __init__(
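The `WITH_IPEX` flag follows the same guarded optional-import pattern as the other `WITH_*` flags in `torch_geometric.typing`: the import is attempted once at module load, and callers only touch `torch.xpu` behind the flag. A standalone sketch of the pattern; `select_device` is a hypothetical helper for illustration:

```python
import warnings

import torch

# Guarded optional import: the module stays usable whether or not the
# extension is installed. OSError covers installs whose native libraries
# fail to load, not just a missing package.
try:
    import intel_extension_for_pytorch  # noqa
    WITH_IPEX = True
except (ImportError, OSError) as e:
    if isinstance(e, OSError):
        warnings.warn("An issue occurred while importing "
                      f"'intel-extension-for-pytorch'. Disabling its usage. "
                      f"Stacktrace: {e}")
    WITH_IPEX = False


def select_device() -> torch.device:
    # Hypothetical helper: check the flag before touching `torch.xpu`,
    # so the attribute is never accessed when the extension is missing.
    if torch.cuda.is_available():
        return torch.device('cuda')
    if WITH_IPEX and torch.xpu.is_available():
        return torch.device('xpu')
    return torch.device('cpu')
```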