Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sparse tensor support in full batch mode #6843

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for sparse tensor in full batch mode inference benchmark ([#6843](https://github.com/pyg-team/pytorch_geometric/pull/6843))
- Enabled `NeighborLoader` to return number of sampled nodes and edges per hop ([#6834](https://github.com/pyg-team/pytorch_geometric/pull/6834))
- Added `ZipLoader` to execute multiple `NodeLoader` or `LinkLoader` instances ([#6829](https://github.com/pyg-team/pytorch_geometric/issues/6829))
- Added common `utils.select` and `utils.narrow` functionality to support filtering of both tensors and lists ([#6162](https://github.com/pyg-team/pytorch_geometric/issues/6162))
Expand Down
23 changes: 18 additions & 5 deletions benchmark/inference/inference_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@

import torch

from benchmark.utils import emit_itt, get_dataset, get_model, get_split_masks
from benchmark.utils import (
emit_itt,
get_dataset_with_transformation,
get_model,
get_split_masks,
)
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import PNAConv
from torch_geometric.profile import rename_profile_file, timeit, torch_profile
Expand All @@ -18,7 +23,11 @@
@torch.no_grad()
def full_batch_inference(model, data):
model.eval()
return model(data.x, data.edge_index)
if hasattr(data, 'adj_t'):
edge_index = data.adj_t
else:
edge_index = data.edge_index
return model(data.x, edge_index)


def test(y, loader):
Expand All @@ -41,9 +50,10 @@ def run(args: argparse.ArgumentParser):
print(f'Dataset: {dataset_name}')
load_time = timeit() if args.measure_load_time else nullcontext()
with load_time:
dataset, num_classes = get_dataset(dataset_name, args.root,
args.use_sparse_tensor,
args.bf16)
result = get_dataset_with_transformation(dataset_name, args.root,
args.use_sparse_tensor,
args.bf16)
dataset, num_classes, transformation = result
data = dataset.to(device)
hetero = True if dataset_name == 'ogbn-mag' else False
mask = ('paper', None) if dataset_name == 'ogbn-mag' else None
Expand Down Expand Up @@ -162,6 +172,9 @@ def run(args: argparse.ArgumentParser):
itt = emit_itt(
) if args.vtune_profile else nullcontext()

if args.full_batch and args.use_sparse_tensor:
data = transformation(data)

with cpu_affinity, amp, timeit() as time:
for _ in range(args.warmup):
if args.full_batch:
Expand Down
3 changes: 2 additions & 1 deletion benchmark/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from .utils import emit_itt
from .utils import get_dataset
from .utils import get_dataset, get_dataset_with_transformation
from .utils import get_model
from .utils import get_split_masks

__all__ = [
'emit_itt',
'get_dataset',
'get_dataset_with_transformation',
'get_model',
'get_split_masks',
]
11 changes: 9 additions & 2 deletions benchmark/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def emit_itt(*args, **kwargs):
}


def get_dataset(name, root, use_sparse_tensor=False, bf16=False):
def get_dataset_with_transformation(name, root, use_sparse_tensor=False,
bf16=False):
path = osp.join(osp.dirname(osp.realpath(__file__)), root, name)
transform = T.ToSparseTensor(
remove_edge_index=False) if use_sparse_tensor else None
Expand Down Expand Up @@ -68,7 +69,13 @@ def get_dataset(name, root, use_sparse_tensor=False, bf16=False):
if bf16:
data.x = data.x.to(torch.bfloat16)

return data, dataset.num_classes
return data, dataset.num_classes, transform


def get_dataset(name, root, use_sparse_tensor=False, bf16=False):
data, num_classes, _ = get_dataset_with_transformation(
name, root, use_sparse_tensor, bf16)
return data, num_classes


def get_model(name, params, metadata=None):
Expand Down