From 2c63244e76722ba3dad8daadaca7b7e326ac74d5 Mon Sep 17 00:00:00 2001
From: Damian Szwichtenberg
Date: Fri, 3 Mar 2023 07:19:23 +0100
Subject: [PATCH 1/3] Add sparse tensor support in full batch mode

---
 benchmark/inference/inference_benchmark.py | 19 ++++++++++++++-----
 benchmark/utils/__init__.py                |  3 ++-
 benchmark/utils/utils.py                   |  9 +++++++--
 3 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/benchmark/inference/inference_benchmark.py b/benchmark/inference/inference_benchmark.py
index 58cd56807964..49280a4208d4 100644
--- a/benchmark/inference/inference_benchmark.py
+++ b/benchmark/inference/inference_benchmark.py
@@ -3,7 +3,8 @@
 
 import torch
 
-from benchmark.utils import emit_itt, get_dataset, get_model, get_split_masks
+from benchmark.utils import (emit_itt, get_dataset_with_transformation,
+                             get_model, get_split_masks)
 from torch_geometric.loader import NeighborLoader
 from torch_geometric.nn import PNAConv
 from torch_geometric.profile import rename_profile_file, timeit, torch_profile
@@ -18,7 +19,11 @@
 @torch.no_grad()
 def full_batch_inference(model, data):
     model.eval()
-    return model(data.x, data.edge_index)
+    if hasattr(data, 'adj_t'):
+        edge_index = data.adj_t
+    else:
+        edge_index = data.edge_index
+    return model(data.x, edge_index)
 
 
 def test(y, loader):
@@ -41,9 +46,10 @@ def run(args: argparse.ArgumentParser):
         print(f'Dataset: {dataset_name}')
         load_time = timeit() if args.measure_load_time else nullcontext()
         with load_time:
-            dataset, num_classes = get_dataset(dataset_name, args.root,
-                                               args.use_sparse_tensor,
-                                               args.bf16)
+            result = get_dataset_with_transformation(dataset_name, args.root,
+                                                     args.use_sparse_tensor,
+                                                     args.bf16)
+            dataset, num_classes, transformation = result
         data = dataset.to(device)
         hetero = True if dataset_name == 'ogbn-mag' else False
         mask = ('paper', None) if dataset_name == 'ogbn-mag' else None
@@ -162,6 +168,9 @@ def run(args: argparse.ArgumentParser):
                     itt = emit_itt(
                     ) if args.vtune_profile else nullcontext()
 
+                    if args.full_batch and args.use_sparse_tensor:
+                        data = transformation(data)
+
                     with cpu_affinity, amp, timeit() as time:
                         for _ in range(args.warmup):
                             if args.full_batch:
diff --git a/benchmark/utils/__init__.py b/benchmark/utils/__init__.py
index 77cad79c3389..c0e324ca190a 100644
--- a/benchmark/utils/__init__.py
+++ b/benchmark/utils/__init__.py
@@ -1,11 +1,12 @@
 from .utils import emit_itt
-from .utils import get_dataset
+from .utils import get_dataset, get_dataset_with_transformation
 from .utils import get_model
 from .utils import get_split_masks
 
 __all__ = [
     'emit_itt',
     'get_dataset',
+    'get_dataset_with_transformation',
     'get_model',
     'get_split_masks',
 ]
diff --git a/benchmark/utils/utils.py b/benchmark/utils/utils.py
index 163e75efa996..4038ffb0968a 100644
--- a/benchmark/utils/utils.py
+++ b/benchmark/utils/utils.py
@@ -32,7 +32,7 @@ def emit_itt(*args, **kwargs):
 }
 
 
-def get_dataset(name, root, use_sparse_tensor=False, bf16=False):
+def get_dataset_with_transformation(name, root, use_sparse_tensor=False, bf16=False):
     path = osp.join(osp.dirname(osp.realpath(__file__)), root, name)
     transform = T.ToSparseTensor(
         remove_edge_index=False) if use_sparse_tensor else None
@@ -68,7 +68,12 @@ def get_dataset(name, root, use_sparse_tensor=False, bf16=False):
     if bf16:
         data.x = data.x.to(torch.bfloat16)
 
-    return data, dataset.num_classes
+    return data, dataset.num_classes, transform
+
+
+def get_dataset(name, root, use_sparse_tensor=False, bf16=False):
+    data, num_classes, _ = get_dataset_with_transformation(name, root, use_sparse_tensor, bf16)
+    return data, num_classes
 
 
 def get_model(name, params, metadata=None):

From a4dd36296164a629d5ffe3c93be8af10c5d92fa1 Mon Sep 17 00:00:00 2001
From: Damian Szwichtenberg
Date: Fri, 3 Mar 2023 07:26:30 +0100
Subject: [PATCH 2/3] Update CHANGELOG.md

---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 91b6fc261be4..1475c789bf70 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added support for sparse tensor in full batch mode inference benchmark ([#6843](https://github.com/pyg-team/pytorch_geometric/pull/6843))
 - Enabled `NeighborLoader` to return number of sampled nodes and edges per hop ([#6834](https://github.com/pyg-team/pytorch_geometric/pull/6834))
 - Added `ZipLoader` to execute multiple `NodeLoader` or `LinkLoader` instances ([#6829](https://github.com/pyg-team/pytorch_geometric/issues/6829))
 - Added common `utils.select` and `utils.narrow` functionality to support filtering of both tensors and lists ([#6162](https://github.com/pyg-team/pytorch_geometric/issues/6162))

From 4e5a621e9c51fec27463a3b1b488826166fbac23 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 3 Mar 2023 07:10:55 +0000
Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 benchmark/inference/inference_benchmark.py | 8 ++++++--
 benchmark/utils/utils.py                   | 6 ++++--
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/benchmark/inference/inference_benchmark.py b/benchmark/inference/inference_benchmark.py
index 49280a4208d4..b9fda243f989 100644
--- a/benchmark/inference/inference_benchmark.py
+++ b/benchmark/inference/inference_benchmark.py
@@ -3,8 +3,12 @@
 
 import torch
 
-from benchmark.utils import (emit_itt, get_dataset_with_transformation,
-                             get_model, get_split_masks)
+from benchmark.utils import (
+    emit_itt,
+    get_dataset_with_transformation,
+    get_model,
+    get_split_masks,
+)
 from torch_geometric.loader import NeighborLoader
 from torch_geometric.nn import PNAConv
 from torch_geometric.profile import rename_profile_file, timeit, torch_profile
diff --git a/benchmark/utils/utils.py b/benchmark/utils/utils.py
index 4038ffb0968a..cfd9393f5946 100644
--- a/benchmark/utils/utils.py
+++ b/benchmark/utils/utils.py
@@ -32,7 +32,8 @@ def emit_itt(*args, **kwargs):
 }
 
 
-def get_dataset_with_transformation(name, root, use_sparse_tensor=False, bf16=False):
+def get_dataset_with_transformation(name, root, use_sparse_tensor=False,
+                                    bf16=False):
     path = osp.join(osp.dirname(osp.realpath(__file__)), root, name)
     transform = T.ToSparseTensor(
         remove_edge_index=False) if use_sparse_tensor else None
@@ -72,7 +73,8 @@ def get_dataset_with_transformation(name, root, use_sparse_tensor=False, bf16=Fa
 
 
 def get_dataset(name, root, use_sparse_tensor=False, bf16=False):
-    data, num_classes, _ = get_dataset_with_transformation(name, root, use_sparse_tensor, bf16)
+    data, num_classes, _ = get_dataset_with_transformation(
+        name, root, use_sparse_tensor, bf16)
     return data, num_classes
 
 
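
For reference, a minimal usage sketch of how the pieces introduced by these patches are expected to fit together when running full-batch inference with sparse tensors. This is not something the patches add: the 'Reddit' dataset name, the './data' root, and the plain torch_geometric.nn.GCN model are placeholders chosen for illustration (the benchmark script itself builds its models via get_model() and drives everything from argparse flags), and it assumes a repository checkout where benchmark.utils is importable.

# Usage sketch only; see the hedging note above for what is assumed.
import torch

from benchmark.utils import get_dataset_with_transformation

from torch_geometric.nn import GCN

device = torch.device('cpu')
use_sparse_tensor = True

# The helper now returns the ToSparseTensor transform (or None) alongside the
# data object and the number of classes.
data, num_classes, transform = get_dataset_with_transformation(
    'Reddit', './data', use_sparse_tensor=use_sparse_tensor)
data = data.to(device)

# In full-batch mode the transform is applied once, up front, so the model
# can consume data.adj_t instead of data.edge_index.
if use_sparse_tensor and transform is not None:
    data = transform(data)

model = GCN(in_channels=data.num_features, hidden_channels=64, num_layers=2,
            out_channels=num_classes).to(device)
model.eval()

with torch.no_grad():
    # Mirrors full_batch_inference(): prefer adj_t when it is present.
    edge_index = data.adj_t if hasattr(data, 'adj_t') else data.edge_index
    out = model(data.x, edge_index)
print(out.shape)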