Skip to content

Commit

Permalink
Introduce index2ptr and ptr2index helper functions (#6949)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Mar 18, 2023
1 parent f18f647 commit b5ecfd9
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 24 deletions.
20 changes: 7 additions & 13 deletions torch_geometric/data/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import Tensor

from torch_geometric.typing import EdgeTensorType, EdgeType, OptTensor
from torch_geometric.utils import index_sort
from torch_geometric.utils.mixin import CastMixin
from torch_geometric.utils.sparse import index2ptr, ptr2index

# The output of converting between two types in the GraphStore is a Tuple of
# dictionaries: row, col, and perm. The dictionaries are keyed by the edge
Expand Down Expand Up @@ -262,40 +262,34 @@ def _edge_to_layout(
store: bool = False,
) -> Tuple[Tensor, Tensor, OptTensor]:

ind2ptr = torch._convert_indices_from_coo_to_csr

def ptr2ind(ptr: Tensor) -> Tensor:
ind = torch.arange(ptr.numel() - 1, device=ptr.device)
return ind.repeat_interleave(ptr[1:] - ptr[:-1])

(row, col), perm = self.get_edge_index(attr), None

if layout == EdgeLayout.COO: # COO output requested:
if attr.layout == EdgeLayout.CSR: # CSR->COO
row = ptr2ind(row)
row = ptr2index(row)
elif attr.layout == EdgeLayout.CSC: # CSC->COO
col = ptr2ind(col)
col = ptr2index(col)

elif layout == EdgeLayout.CSR: # CSR output requested:
if attr.layout == EdgeLayout.CSC: # CSC->COO
col = ptr2ind(col)
col = ptr2index(col)

if attr.layout != EdgeLayout.CSR: # COO->CSR
num_rows = attr.size[0] if attr.size else int(row.max()) + 1
row, perm = index_sort(row, max_value=num_rows)
col = col[perm]
row = ind2ptr(row, num_rows)
row = index2ptr(row, num_rows)

else: # CSC output requested:
if attr.layout == EdgeLayout.CSR: # CSR->COO
row = ptr2ind(row)
row = ptr2index(row)

if attr.layout != EdgeLayout.CSC: # COO->CSC
num_cols = attr.size[1] if attr.size else int(col.max()) + 1
if not attr.is_sorted: # Not sorted by destination.
col, perm = index_sort(col, max_value=num_cols)
row = row[perm]
col = ind2ptr(col, num_cols)
col = index2ptr(col, num_cols)

if attr.layout != layout and store:
attr = copy.copy(attr)
Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/nn/conv/cugraph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch import Tensor

from torch_geometric.utils import index_sort
from torch_geometric.utils.sparse import index2ptr

try: # pragma: no cover
from pylibcugraphops import (
Expand Down Expand Up @@ -52,7 +53,7 @@ def to_csc(
col, perm = index_sort(col, max_value=num_target_nodes)
row = row[perm]

colptr = torch._convert_indices_from_coo_to_csr(col, num_target_nodes)
colptr = index2ptr(col, num_target_nodes)

if edge_attr is not None:
return (row, colptr), edge_attr[perm]
Expand Down
5 changes: 3 additions & 2 deletions torch_geometric/nn/conv/message_passing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
is_torch_sparse_tensor,
to_edge_index,
)
from torch_geometric.utils.sparse import ptr2index

from .utils.inspector import Inspector, func_body_repr, func_header_repr
from .utils.jit import class_from_module_repr
Expand Down Expand Up @@ -251,10 +252,10 @@ def __lift__(self, src, edge_index, dim):
if dim == 0:
index = edge_index.col_indices()
else:
index = ptr2ind(edge_index.crow_indices())
index = ptr2index(edge_index.crow_indices())
elif edge_index.layout == torch.sparse_csc:
if dim == 0:
index = ptr2ind(edge_index.ccol_indices())
index = ptr2index(edge_index.ccol_indices())
else:
index = edge_index.row_indices()
else:
Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/nn/conv/rgcn_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
torch_sparse,
)
from torch_geometric.utils import index_sort, scatter, spmm
from torch_geometric.utils.sparse import index2ptr

from ..inits import glorot, zeros

Expand Down Expand Up @@ -238,8 +239,7 @@ def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]],
edge_type, perm = index_sort(
edge_type, max_value=self.num_relations)
edge_index = edge_index[:, perm]
edge_type_ptr = torch._convert_indices_from_coo_to_csr(
edge_type, self.num_relations)
edge_type_ptr = index2ptr(edge_type, self.num_relations)
out = self.propagate(edge_index, x=x_l,
edge_type_ptr=edge_type_ptr, size=size)
else:
Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/nn/dense/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from torch_geometric.nn import inits
from torch_geometric.typing import pyg_lib
from torch_geometric.utils import index_sort
from torch_geometric.utils.sparse import index2ptr


def is_uninitialized_parameter(x: Any) -> bool:
Expand Down Expand Up @@ -259,8 +260,7 @@ def forward(self, x: Tensor, type_vec: Tensor) -> Tensor:
type_vec, perm = index_sort(type_vec, self.num_types)
x = x[perm]

type_vec_ptr = torch._convert_indices_from_coo_to_csr(
type_vec, self.num_types)
type_vec_ptr = index2ptr(type_vec, self.num_types)
out = pyg_lib.ops.segment_matmul(x, type_vec_ptr, self.weight)
if self.bias is not None:
out += self.bias[type_vec]
Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/sampler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torch_geometric.data.storage import EdgeStorage
from torch_geometric.typing import NodeType, OptTensor
from torch_geometric.utils import index_sort
from torch_geometric.utils.sparse import index2ptr

# Edge Layout Conversion ######################################################

Expand Down Expand Up @@ -70,8 +71,7 @@ def to_csc(
if not is_sorted:
row, col, perm = sort_csc(row, col, src_node_time)

colptr = torch._convert_indices_from_coo_to_csr(
col, data.size(1), out_int32=col.dtype == torch.int32)
colptr = index2ptr(col, data.size(1))

else:
row = torch.empty(0, dtype=torch.long, device=device)
Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/transforms/gdc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
scatter,
to_dense_adj,
)
from torch_geometric.utils.sparse import index2ptr


@functional_transform('gdc')
Expand Down Expand Up @@ -305,8 +306,7 @@ def diffusion_matrix_approx(
edge_index_np = edge_index.cpu().numpy()

# Assumes sorted and coalesced edge indices:
indptr = torch._convert_indices_from_coo_to_csr(
edge_index[0], num_nodes).cpu().numpy()
indptr = index2ptr(edge_index[0], num_nodes).cpu().numpy()
out_degree = indptr[1:] - indptr[:-1]

neighbors, neighbor_weights = self.__calc_ppr__(
Expand Down
13 changes: 13 additions & 0 deletions torch_geometric/utils/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,16 @@ def to_edge_index(adj: Union[Tensor, SparseTensor]) -> Tuple[Tensor, Tensor]:
return adj.indices(), adj.values()

return adj._indices(), adj._values()


# Helper functions ############################################################


def ptr2index(ptr: Tensor) -> Tensor:
ind = torch.arange(ptr.numel() - 1, dtype=ptr.dtype, device=ptr.device)
return ind.repeat_interleave(ptr[1:] - ptr[:-1])


def index2ptr(index: Tensor, size: int) -> Tensor:
return torch._convert_indices_from_coo_to_csr(
index, size, out_int32=index.dtype == torch.int32)

0 comments on commit b5ecfd9

Please sign in to comment.