Skip to content

Commit

Permalink
Fix inference benchmark (#5341)
Browse files Browse the repository at this point in the history
* update

* changelog

* update

* update
  • Loading branch information
rusty1s authored Sep 2, 2022
1 parent 3c02604 commit 6f0be26
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 44 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.2.0] - 2022-MM-DD
### Added
- Enabled `bf16` support in benchmark scripts ([#5293](https://github.com/pyg-team/pytorch_geometric/pull/5293))
- Enabled `bf16` support in benchmark scripts ([#5293](https://github.com/pyg-team/pytorch_geometric/pull/5293), [#5341](https://github.com/pyg-team/pytorch_geometric/pull/5341))
- Added `Aggregation.set_validate_args` option to skip validation of `dim_size` ([#5290](https://github.com/pyg-team/pytorch_geometric/pull/5290))
- Added `SparseTensor` support to inference benchmark suite ([#5242](https://github.com/pyg-team/pytorch_geometric/pull/5242), [#5258](https://github.com/pyg-team/pytorch_geometric/pull/5258))
- Added experimental mode in inference benchmarks ([#5254](https://github.com/pyg-team/pytorch_geometric/pull/5254))
Expand Down
27 changes: 8 additions & 19 deletions benchmark/inference/inference_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from utils import get_dataset, get_model

import torch_geometric
from torch_geometric import set_experimental_mode
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import PNAConv
from torch_geometric.profile import rename_profile_file, timeit, torch_profile
Expand All @@ -25,7 +25,7 @@ def run(args: argparse.ArgumentParser) -> None:
), f"Dataset {dataset_name} isn't supported."
print(f'Dataset: {dataset_name}')
dataset, num_classes = get_dataset(dataset_name, args.root,
args.use_sparse_tensor)
args.use_sparse_tensor, args.bf16)
data = dataset.to(device)
hetero = True if dataset_name == 'ogbn-mag' else False
mask = ('paper', None) if dataset_name == 'ogbn-mag' else None
Expand All @@ -34,9 +34,6 @@ def run(args: argparse.ArgumentParser) -> None:
amp = torch.cuda.amp.autocast(enabled=False)
else:
amp = torch.cpu.amp.autocast(enabled=args.bf16)
dtype = torch.float
if args.bf16:
dtype = torch.bfloat16

inputs_channels = data[
'paper'].num_features if dataset_name == 'ogbn-mag' \
Expand Down Expand Up @@ -101,26 +98,18 @@ def run(args: argparse.ArgumentParser) -> None:
model.eval()

with amp:
for _ in range(args.warmup):
model.inference(subgraph_loader, device,
progress_bar=True, dtype=dtype)
if args.experimental_mode:
with torch_geometric.experimental_mode():
with timeit():
model.inference(
subgraph_loader, device,
progress_bar=True, dtype=dtype)
else:
with set_experimental_mode(args.experimental_mode):
for _ in range(args.warmup):
model.inference(subgraph_loader, device,
progress_bar=True)
with timeit():
model.inference(subgraph_loader, device,
progress_bar=True,
dtype=dtype)
progress_bar=True)

if args.profile:
with torch_profile():
model.inference(subgraph_loader, device,
progress_bar=True,
dtype=dtype)
progress_bar=True)
rename_profile_file(
model_name, dataset_name, str(batch_size),
str(layers), str(hidden_channels),
Expand Down
10 changes: 8 additions & 2 deletions benchmark/inference/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os.path as osp

import torch
from hetero_gat import HeteroGAT
from hetero_sage import HeteroGraphSAGE
from ogb.nodeproppred import PygNodePropPredDataset
Expand All @@ -19,7 +20,7 @@
}


def get_dataset(name, root, use_sparse_tensor):
def get_dataset(name, root, use_sparse_tensor=False, bf16=False):
path = osp.join(osp.dirname(osp.realpath(__file__)), root, name)
transform = T.ToSparseTensor() if use_sparse_tensor else None
if name == 'ogbn-mag':
Expand All @@ -35,7 +36,12 @@ def get_dataset(name, root, use_sparse_tensor):
elif name == 'Reddit':
dataset = Reddit(root=path, transform=transform)

return dataset[0], dataset.num_classes
data = dataset[0]

if bf16:
data.x = data.x.to(torch.bfloat16)

return data, dataset.num_classes


def get_model(name, params, metadata=None):
Expand Down
19 changes: 18 additions & 1 deletion test/test_experimental.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import pytest

from torch_geometric import experimental_mode, is_experimental_mode_enabled
from torch_geometric import (
experimental_mode,
is_experimental_mode_enabled,
set_experimental_mode,
)


@pytest.mark.parametrize('options', [None, 'scatter_reduce'])
Expand All @@ -9,3 +13,16 @@ def test_experimental_mode(options):
with experimental_mode(options):
assert is_experimental_mode_enabled(options) is True
assert is_experimental_mode_enabled(options) is False

with set_experimental_mode(True, options):
assert is_experimental_mode_enabled(options) is True
assert is_experimental_mode_enabled(options) is False

with set_experimental_mode(False, options):
assert is_experimental_mode_enabled(options) is False
assert is_experimental_mode_enabled(options) is False

set_experimental_mode(True, options)
assert is_experimental_mode_enabled(options) is True
set_experimental_mode(False, options)
assert is_experimental_mode_enabled(options) is False
6 changes: 4 additions & 2 deletions torch_geometric/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from .seed import seed_everything
from .home import get_home_dir, set_home_dir
from .debug import is_debug_enabled, debug, set_debug
from .experimental import is_experimental_mode_enabled, experimental_mode
from .experimental import (is_experimental_mode_enabled, experimental_mode,
set_experimental_mode)


# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/
Expand Down Expand Up @@ -49,8 +50,9 @@ def __dir__(self):
'is_debug_enabled',
'debug',
'set_debug',
'experimental_mode',
'is_experimental_mode_enabled',
'experimental_mode',
'set_experimental_mode',
'torch_geometric',
'__version__',
]
6 changes: 2 additions & 4 deletions torch_geometric/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ def is_debug_enabled():
return __debug_flag__['enabled']


def set_debug_enabled(mode):
def set_debug_enabled(mode: bool):
__debug_flag__['enabled'] = mode


Expand All @@ -27,7 +27,6 @@ def __enter__(self):

def __exit__(self, *args):
set_debug_enabled(self.prev)
return False


class set_debug:
Expand All @@ -39,7 +38,7 @@ class set_debug:
See :class:`debug` above for more details.
"""
def __init__(self, mode):
def __init__(self, mode: bool):
self.prev = is_debug_enabled()
set_debug_enabled(mode)

Expand All @@ -48,4 +47,3 @@ def __enter__(self):

def __exit__(self, *args):
set_debug_enabled(self.prev)
return False
58 changes: 45 additions & 13 deletions torch_geometric/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,30 @@

__experimental_flag__ = {'scatter_reduce': False}

Options = Optional[Union[str, List[str]]]

def is_experimental_mode_enabled(
options: Optional[Union[str, List[str]]] = None) -> bool:
r"""Returns :obj:`True` if the experimental mode is enabled. See
:class:`torch_geometric.experimental_mode` for a list of (optional)
options."""

def get_options(options: Options) -> List[str]:
if options is None:
options = list(__experimental_flag__.keys())
if isinstance(options, str):
options = [options]
return options


def is_experimental_mode_enabled(options: Options = None) -> bool:
r"""Returns :obj:`True` if the experimental mode is enabled. See
:class:`torch_geometric.experimental_mode` for a list of (optional)
options."""
options = get_options(options)
return all([__experimental_flag__[option] for option in options])


def set_experimental_mode_enabled(mode: bool, options: Options = None):
for option in get_options(options):
__experimental_flag__[option] = mode


class experimental_mode:
r"""Context-manager that enables the experimental mode to test new but
potentially unstable features.
Expand All @@ -31,20 +42,41 @@ class experimental_mode:
:meth:`torch.scatter_reduce` instead of
:meth:`torch_scatter.scatter`. Requires :obj:`torch>=1.12`.
"""
def __init__(self, options: Optional[Union[str, List[str]]] = None):
if options is None:
options = list(__experimental_flag__.keys())
if isinstance(options, str):
options = [options]
def __init__(self, options: Options = None):
self.options = get_options(options)
self.previous_state = {
option: __experimental_flag__[option]
for option in options
for option in self.options
}

def __enter__(self) -> None:
for option in self.previous_state.keys():
__experimental_flag__[option] = True
set_experimental_mode_enabled(True, self.options)

def __exit__(self, *args) -> bool:
for option, value in self.previous_state.items():
__experimental_flag__[option] = value


class set_experimental_mode:
r"""Context-manager that sets the experimental mode on or off.
:class:`set_experimental_mode` will enable or disable the experimental mode
based on its argument :attr:`mode`.
It can be used as a context-manager or as a function.
See :class:`experimental_mode` above for more details.
"""
def __init__(self, mode: bool, options: Options = None):
self.options = get_options(options)
self.previous_state = {
option: __experimental_flag__[option]
for option in self.options
}
set_experimental_mode_enabled(mode, self.options)

def __enter__(self):
pass

def __exit__(self, *args):
for option, value in self.previous_state.items():
__experimental_flag__[option] = value
3 changes: 1 addition & 2 deletions torch_geometric/nn/models/basic_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def forward(
@torch.no_grad()
def inference(self, loader: NeighborLoader,
device: Optional[torch.device] = None,
progress_bar: bool = False, dtype=torch.float) -> Tensor:
progress_bar: bool = False) -> Tensor:
r"""Performs layer-wise inference on large-graphs using
:class:`~torch_geometric.loader.NeighborLoader`.
:class:`~torch_geometric.loader.NeighborLoader` should sample the the
Expand All @@ -216,7 +216,6 @@ def inference(self, loader: NeighborLoader,
pbar.set_description('Inference')

x_all = loader.data.x.cpu()
x_all = x_all.to(dtype)
loader.data.n_id = torch.arange(x_all.size(0))

for i in range(self.num_layers):
Expand Down

0 comments on commit 6f0be26

Please sign in to comment.