Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable index_sort #6554

Merged
merged 8 commits into from
Feb 1, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added accelerated `index_sort` function from `pyg-lib` for faster sorting ([#6554](https://github.com/pyg-team/pytorch_geometric/pull/6554))
- Fix incorrect device in `EquilibriumAggregation` ([#6560](https://github.com/pyg-team/pytorch_geometric/pull/6560))
- Added bipartite graph support in `dense_to_sparse()` ([#6546](https://github.com/pyg-team/pytorch_geometric/pull/6546))
- Add CPU affinity support for more data loaders ([#6534](https://github.com/pyg-team/pytorch_geometric/pull/6534))
Expand Down
8 changes: 6 additions & 2 deletions benchmark/inference/inference_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@ def run(args: argparse.ArgumentParser) -> None:
assert dataset_name in supported_sets.keys(
), f"Dataset {dataset_name} isn't supported."
print(f'Dataset: {dataset_name}')
dataset, num_classes = get_dataset(dataset_name, args.root,
args.use_sparse_tensor, args.bf16)
load_time = timeit() if args.measure_load_time else nullcontext()
with load_time:
dataset, num_classes = get_dataset(dataset_name, args.root,
args.use_sparse_tensor,
args.bf16)
data = dataset.to(device)
hetero = True if dataset_name == 'ogbn-mag' else False
mask = ('paper', None) if dataset_name == 'ogbn-mag' else None
Expand Down Expand Up @@ -166,4 +169,5 @@ def run(args: argparse.ArgumentParser) -> None:
help="Use DataLoader affinitzation.")
add('--loader-cores', nargs='+', default=[], type=int,
help="List of CPU core IDs to use for DataLoader workers.")
add('--measure-load-time', action='store_true')
run(argparser.parse_args())
7 changes: 5 additions & 2 deletions benchmark/training/training_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,10 @@ def run(args: argparse.ArgumentParser) -> None:
assert dataset_name in supported_sets.keys(
), f"Dataset {dataset_name} isn't supported."
print(f'Dataset: {dataset_name}')
data, num_classes = get_dataset(dataset_name, args.root,
args.use_sparse_tensor, args.bf16)
load_time = timeit() if args.measure_load_time else nullcontext()
with load_time:
data, num_classes = get_dataset(dataset_name, args.root,
args.use_sparse_tensor, args.bf16)
hetero = True if dataset_name == 'ogbn-mag' else False
mask = ('paper', data['paper'].train_mask
) if dataset_name == 'ogbn-mag' else data.train_mask
Expand Down Expand Up @@ -219,6 +221,7 @@ def run(args: argparse.ArgumentParser) -> None:
help="Use DataLoader affinitzation.")
add('--loader-cores', nargs='+', default=[], type=int,
help="List of CPU core IDs to use for DataLoader workers.")
add('--measure-load-time', action='store_true')
args = argparser.parse_args()

run(args)
9 changes: 5 additions & 4 deletions torch_geometric/data/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from torch import Tensor

from torch_geometric.typing import EdgeTensorType, EdgeType, OptTensor
from torch_geometric.utils import index_sort
from torch_geometric.utils.mixin import CastMixin

# The output of converting between two types in the GraphStore is a Tuple of
Expand Down Expand Up @@ -276,20 +277,20 @@ def _edge_to_layout(
col = ptr2ind(col, row.numel())

if attr.layout != EdgeLayout.CSR: # COO->CSR
row, perm = row.sort() # Cannot be sorted by destination.
col = col[perm]
num_rows = attr.size[0] if attr.size else int(row.max()) + 1
row, perm = index_sort(row, max_value=num_rows)
col = col[perm]
row = ind2ptr(row, num_rows)

else: # CSC output requested:
if attr.layout == EdgeLayout.CSR: # CSR->COO
row = ptr2ind(row, col.numel())

if attr.layout != EdgeLayout.CSC: # COO->CSC
num_cols = attr.size[1] if attr.size else int(col.max()) + 1
if not attr.is_sorted: # Not sorted by destination.
col, perm = col.sort()
col, perm = index_sort(col, max_value=num_cols)
row = row[perm]
num_cols = attr.size[1] if attr.size else int(col.max()) + 1
col = ind2ptr(col, num_cols)

if attr.layout != layout and store:
Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/datasets/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
download_url,
extract_tar,
)
from torch_geometric.utils import index_sort


class Entities(InMemoryDataset):
Expand Down Expand Up @@ -147,7 +148,7 @@ def process(self):
edges.append([dst, src, 2 * rel + 1])

edges = torch.tensor(edges, dtype=torch.long).t().contiguous()
perm = (N * R * edges[0] + R * edges[1] + edges[2]).argsort()
_, perm = index_sort(N * R * edges[0] + R * edges[1] + edges[2])
edges = edges[:, perm]

edge_index, edge_type = edges[:2], edges[2]
Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/datasets/particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch

from torch_geometric.data import Data, Dataset
from torch_geometric.utils import scatter
from torch_geometric.utils import index_sort, scatter


class TrackingData(Data):
Expand Down Expand Up @@ -85,7 +85,7 @@ def get(self, idx):
weight = torch.from_numpy(y['weight'].values).to(torch.float)

# Sort.
perm = (particle_id * hit_id.size(0) + hit_id).argsort()
_, perm = index_sort(particle_id * hit_id.size(0) + hit_id)
hit_id = hit_id[perm]
particle_id = particle_id[perm]
weight = weight[perm]
Expand Down
5 changes: 3 additions & 2 deletions torch_geometric/datasets/word_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch

from torch_geometric.data import Data, InMemoryDataset, download_url
from torch_geometric.utils import index_sort


class WordNet18(InMemoryDataset):
Expand Down Expand Up @@ -81,7 +82,7 @@ def process(self):
test_mask[srcs[0].size(0) + srcs[1].size(0):] = True

num_nodes = max(int(src.max()), int(dst.max())) + 1
perm = (num_nodes * src + dst).argsort()
_, perm = index_sort(num_nodes * src + dst)

edge_index = torch.stack([src[perm], dst[perm]], dim=0)
edge_type = edge_type[perm]
Expand Down Expand Up @@ -191,7 +192,7 @@ def process(self):
test_mask[srcs[0].size(0) + srcs[1].size(0):] = True

num_nodes = max(int(src.max()), int(dst.max())) + 1
perm = (num_nodes * src + dst).argsort()
_, perm = index_sort(num_nodes * src + dst)

edge_index = torch.stack([src[perm], dst[perm]], dim=0)
edge_type = edge_type[perm]
Expand Down
5 changes: 3 additions & 2 deletions torch_geometric/nn/conv/rgcn_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch_geometric.typing
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import Adj, OptTensor, pyg_lib
from torch_geometric.utils import scatter
from torch_geometric.utils import index_sort, scatter

from ..inits import glorot, zeros

Expand Down Expand Up @@ -230,7 +230,8 @@ def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]],
and isinstance(edge_index, Tensor)):
if not self.is_sorted:
if (edge_type[1:] < edge_type[:-1]).any():
edge_type, perm = edge_type.sort()
edge_type, perm = index_sort(
edge_type, max_value=self.num_relations)
edge_index = edge_index[:, perm]
edge_type_ptr = torch.ops.torch_sparse.ind2ptr(
edge_type, self.num_relations)
Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/nn/dense/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch_geometric.typing
from torch_geometric.nn import inits
from torch_geometric.typing import pyg_lib
from torch_geometric.utils import index_sort


def is_uninitialized_parameter(x: Any) -> bool:
Expand Down Expand Up @@ -255,7 +256,7 @@ def forward(self, x: Tensor, type_vec: Tensor) -> Tensor:
perm: Optional[Tensor] = None
if not self.is_sorted:
if (type_vec[1:] < type_vec[:-1]).any():
type_vec, perm = type_vec.sort()
type_vec, perm = index_sort(type_vec, self.num_types)
x = x[perm]

type_vec_ptr = torch.ops.torch_sparse.ind2ptr(
Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/sampler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch_geometric.data import Data, HeteroData
from torch_geometric.data.storage import EdgeStorage
from torch_geometric.typing import NodeType, OptTensor
from torch_geometric.utils import index_sort

# Edge Layout Conversion ######################################################

Expand All @@ -17,7 +18,7 @@ def sort_csc(
src_node_time: OptTensor = None,
) -> Tuple[Tensor, Tensor, Tensor]:
if src_node_time is None:
col, perm = col.sort()
col, perm = index_sort(col)
return row[perm], col, perm
else:
# We use `np.lexsort` to sort based on multiple keys.
Expand Down
2 changes: 2 additions & 0 deletions torch_geometric/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .scatter import scatter
from .sort import index_sort
from .degree import degree
from .softmax import softmax
from .dropout import dropout_adj, dropout_node, dropout_edge, dropout_path
Expand Down Expand Up @@ -44,6 +45,7 @@

__all__ = [
'scatter',
'index_sort',
'degree',
'softmax',
'dropout_node',
Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/utils/coalesce.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from torch import Tensor

from torch_geometric.utils import scatter
from torch_geometric.utils import index_sort, scatter

from .num_nodes import maybe_num_nodes

Expand Down Expand Up @@ -94,7 +94,7 @@ def coalesce(
idx[1:].mul_(num_nodes).add_(edge_index[int(sort_by_row)])

if not is_sorted:
idx[1:], perm = idx[1:].sort()
idx[1:], perm = index_sort(idx[1:], max_value=num_nodes * num_nodes)
edge_index = edge_index[:, perm]
if isinstance(edge_attr, Tensor):
edge_attr = edge_attr[perm]
Expand Down
28 changes: 28 additions & 0 deletions torch_geometric/utils/sort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Optional, Tuple

import torch

from torch_geometric.typing import WITH_PYG_LIB, pyg_lib

# Gate for the accelerated sorting path: truthy only when `WITH_PYG_LIB` is
# truthy and an `index_sort` attribute is present on `torch.ops.pyg`. The
# `and` short-circuits, so `torch.ops.pyg` is not probed when `WITH_PYG_LIB`
# is falsy.
WITH_INDEX_SORT = WITH_PYG_LIB and hasattr(torch.ops.pyg, 'index_sort')


def index_sort(
    inputs: torch.Tensor,
    max_value: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Sorts the one-dimensional index vector :obj:`inputs` in ascending
    order.

    When :obj:`WITH_INDEX_SORT` is set, the call is forwarded to
    :obj:`pyg_lib.ops.index_sort`; otherwise it falls back to
    :meth:`torch.Tensor.sort`, in which case :obj:`max_value` is ignored.

    Args:
        inputs (torch.Tensor): A vector that is expected to be
            one-dimensional and to hold only positive integer values.
        max_value (int, optional): An upper bound on the values stored in
            :obj:`inputs`. It may be an over-estimate, but must not be
            smaller than the real maximum. Supplying it can let the
            underlying algorithm run faster. (default: :obj:`None`)

    Returns:
        A pair of tensors, which call sites unpack as the sorted values
        followed by the permutation that sorts :obj:`inputs`.
    """
    if WITH_INDEX_SORT:
        # Accelerated kernel; the only path that consumes `max_value`.
        return pyg_lib.ops.index_sort(inputs, max_value=max_value)

    # Plain PyTorch fallback.
    return inputs.sort()
4 changes: 3 additions & 1 deletion torch_geometric/utils/sort_edge_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import torch
from torch import Tensor

from torch_geometric.utils import index_sort

from .num_nodes import maybe_num_nodes


Expand Down Expand Up @@ -71,7 +73,7 @@ def sort_edge_index(
idx = edge_index[1 - int(sort_by_row)] * num_nodes
idx += edge_index[int(sort_by_row)]

perm = idx.argsort()
_, perm = index_sort(idx, max_value=num_nodes * num_nodes)

edge_index = edge_index[:, perm]

Expand Down