Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor NeighborSampler to be input-type agnostic #6173

Merged
merged 13 commits into from
Dec 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Added `summary` for PyG/PyTorch models ([#5859](https://github.com/pyg-team/pytorch_geometric/pull/5859), [#6161](https://github.com/pyg-team/pytorch_geometric/pull/6161))
- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6033](https://github.com/pyg-team/pytorch_geometric/pull/6033))
- Add inputs_channels back in training benchmark ([#6154](https://github.com/pyg-team/pytorch_geometric/pull/6154))
- Add `inputs_channels` back in training benchmark ([#6154](https://github.com/pyg-team/pytorch_geometric/pull/6154))
- Added support for dropping nodes in `utils.to_dense_batch` in case `max_num_nodes` is smaller than the number of nodes ([#6124](https://github.com/pyg-team/pytorch_geometric/pull/6124))
- Added the RandLA-Net architecture as an example ([#5117](https://github.com/pyg-team/pytorch_geometric/pull/5117))
### Changed
- Refactored `NeighborSampler` to be input-type agnostic ([#6173](https://github.com/pyg-team/pytorch_geometric/pull/6173))
- Infer correct CUDA device ID in `profileit` decorator ([#6164](https://github.com/pyg-team/pytorch_geometric/pull/6164))
- Correctly use edge weights in `GDC` example ([#6159](https://github.com/pyg-team/pytorch_geometric/pull/6159))
- [Breaking Change] Moved PyTorch Lightning data modules to `torch_geometric.data.lightning` ([#6140](https://github.com/pyg-team/pytorch_geometric/pull/6140))
Expand Down
5 changes: 0 additions & 5 deletions torch_geometric/data/lightning/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
NeighborLoader,
NodeLoader,
)
from torch_geometric.loader.utils import get_edge_label_index, get_input_nodes
from torch_geometric.sampler import BaseSampler, NeighborSampler
from torch_geometric.typing import InputEdges, InputNodes

Expand Down Expand Up @@ -318,7 +317,6 @@ def __init__(
if loader == 'neighbor':
sampler_args = dict(inspect.signature(NeighborSampler).parameters)
sampler_args.pop('data')
sampler_args.pop('input_type')
sampler_args.pop('share_memory')
sampler_kwargs = {
key: kwargs.get(key, param.default)
Expand All @@ -327,7 +325,6 @@ def __init__(

self.neighbor_sampler = NeighborSampler(
data=data,
input_type=get_input_nodes(data, input_train_nodes)[0],
share_memory=num_workers > 0,
**sampler_kwargs,
)
Expand Down Expand Up @@ -555,15 +552,13 @@ def __init__(
if loader in ['neighbor', 'link_neighbor']:
sampler_args = dict(inspect.signature(NeighborSampler).parameters)
sampler_args.pop('data')
sampler_args.pop('input_type')
sampler_args.pop('share_memory')
sampler_kwargs = {
key: kwargs.get(key, param.default)
for key, param in sampler_args.items()
}
self.neighbor_sampler = NeighborSampler(
data=data,
input_type=get_edge_label_index(data, input_train_edges)[0],
share_memory=num_workers > 0,
**sampler_kwargs,
)
Expand Down
4 changes: 0 additions & 4 deletions torch_geometric/loader/hgt_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from torch_geometric.data import FeatureStore, GraphStore, HeteroData
from torch_geometric.loader import NodeLoader
from torch_geometric.loader.utils import get_input_nodes
from torch_geometric.sampler import HGTSampler
from torch_geometric.typing import NodeType

Expand Down Expand Up @@ -106,12 +105,9 @@ def __init__(
filter_per_worker: bool = False,
**kwargs,
):
node_type, _ = get_input_nodes(data, input_nodes)

hgt_sampler = HGTSampler(
data,
num_samples=num_samples,
input_type=node_type,
is_sorted=is_sorted,
share_memory=kwargs.get('num_workers', 0) > 0,
)
Expand Down
41 changes: 20 additions & 21 deletions torch_geometric/loader/link_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,18 @@
from torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData
from torch_geometric.loader.base import DataLoaderIterator
from torch_geometric.loader.utils import (
InputData,
filter_custom_store,
filter_data,
filter_hetero_data,
get_edge_label_index,
)
from torch_geometric.sampler import (
BaseSampler,
EdgeSamplerInput,
HeteroSamplerOutput,
SamplerOutput,
)
from torch_geometric.sampler.base import (
EdgeSamplerInput,
NegativeSamplingConfig,
)
from torch_geometric.sampler.base import NegativeSamplingConfig
from torch_geometric.typing import InputEdges, OptTensor


Expand Down Expand Up @@ -134,11 +131,10 @@ def __init__(
neg_sampling = NegativeSamplingConfig("binary", neg_sampling_ratio)

# Get edge type (or `None` for homogeneous graphs):
edge_type, edge_label_index = get_edge_label_index(
input_type, edge_label_index = get_edge_label_index(
data, edge_label_index)

self.data = data
self.edge_type = edge_type
self.link_sampler = link_sampler
self.neg_sampling = NegativeSamplingConfig.cast(neg_sampling)
self.transform = transform
Expand All @@ -158,11 +154,13 @@ def __init__(
"instead to differentiate between positive and "
"negative samples.")

self.input_data = InputData(
edge_label_index[0].clone(),
edge_label_index[1].clone(),
edge_label,
edge_label_time,
self.input_data = EdgeSamplerInput(
input_id=None,
row=edge_label_index[0].clone(),
col=edge_label_index[1].clone(),
label=edge_label,
time=edge_label_time,
input_type=input_type,
)

iterator = range(edge_label_index.size(1))
Expand Down Expand Up @@ -217,18 +215,19 @@ def filter_fn(
for key, batch in (out.batch or {}).items():
data[key].batch = batch

data[self.edge_type].input_id = out.metadata[0]
input_type = self.input_data.input_type
data[input_type].input_id = out.metadata[0]

if self.neg_sampling is None or self.neg_sampling.is_binary():
data[self.edge_type].edge_label_index = out.metadata[1]
data[self.edge_type].edge_label = out.metadata[2]
data[self.edge_type].edge_label_time = out.metadata[3]
data[input_type].edge_label_index = out.metadata[1]
data[input_type].edge_label = out.metadata[2]
data[input_type].edge_label_time = out.metadata[3]
elif self.neg_sampling.is_triplet():
data[self.edge_type[0]].src_index = out.metadata[1]
data[self.edge_type[-1]].dst_pos_index = out.metadata[2]
data[self.edge_type[-1]].dst_neg_index = out.metadata[3]
data[self.edge_type[0]].seed_time = out.metadata[4]
data[self.edge_type[-1]].seed_time = out.metadata[4]
data[input_type[0]].src_index = out.metadata[1]
data[input_type[-1]].dst_pos_index = out.metadata[2]
data[input_type[-1]].dst_neg_index = out.metadata[3]
data[input_type[0]].seed_time = out.metadata[4]
data[input_type[-1]].seed_time = out.metadata[4]

else:
raise TypeError(f"'{self.__class__.__name__}'' found invalid "
Expand Down
5 changes: 0 additions & 5 deletions torch_geometric/loader/link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData
from torch_geometric.loader.link_loader import LinkLoader
from torch_geometric.loader.utils import get_edge_label_index
from torch_geometric.sampler import NeighborSampler
from torch_geometric.sampler.base import NegativeSamplingConfig
from torch_geometric.typing import InputEdges, NumNeighbors, OptTensor
Expand Down Expand Up @@ -191,9 +190,6 @@ def __init__(
neighbor_sampler: Optional[NeighborSampler] = None,
**kwargs,
):
# TODO(manan): Avoid duplicated computation (here and in NodeLoader):
edge_type, _ = get_edge_label_index(data, edge_label_index)

if (edge_label_time is not None) != (time_attr is not None):
raise ValueError(
f"Received conflicting 'edge_label_time' and 'time_attr' "
Expand All @@ -210,7 +206,6 @@ def __init__(
directed=directed,
disjoint=disjoint,
temporal_strategy=temporal_strategy,
input_type=edge_type,
time_attr=time_attr,
is_sorted=is_sorted,
share_memory=kwargs.get('num_workers', 0) > 0,
Expand Down
5 changes: 0 additions & 5 deletions torch_geometric/loader/neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData
from torch_geometric.loader.node_loader import NodeLoader
from torch_geometric.loader.utils import get_input_nodes
from torch_geometric.sampler import NeighborSampler
from torch_geometric.typing import InputNodes, NumNeighbors, OptTensor

Expand Down Expand Up @@ -187,9 +186,6 @@ def __init__(
neighbor_sampler: Optional[NeighborSampler] = None,
**kwargs,
):
# TODO(manan): Avoid duplicated computation (here and in NodeLoader):
node_type, _ = get_input_nodes(data, input_nodes)

if input_time is not None and time_attr is None:
raise ValueError("Received conflicting 'input_time' and "
"'time_attr' arguments: 'input_time' is set "
Expand All @@ -203,7 +199,6 @@ def __init__(
directed=directed,
disjoint=disjoint,
temporal_strategy=temporal_strategy,
input_type=node_type,
time_attr=time_attr,
is_sorted=is_sorted,
share_memory=kwargs.get('num_workers', 0) > 0,
Expand Down
19 changes: 12 additions & 7 deletions torch_geometric/loader/node_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData
from torch_geometric.loader.base import DataLoaderIterator, WorkerInitWrapper
from torch_geometric.loader.utils import (
InputData,
filter_custom_store,
filter_data,
filter_hetero_data,
Expand All @@ -18,9 +17,9 @@
from torch_geometric.sampler import (
BaseSampler,
HeteroSamplerOutput,
NodeSamplerInput,
SamplerOutput,
)
from torch_geometric.sampler.base import NodeSamplerInput
from torch_geometric.typing import InputNodes, OptTensor


Expand Down Expand Up @@ -87,15 +86,20 @@ def __init__(
del kwargs['collate_fn']

# Get node type (or `None` for homogeneous graphs):
node_type, input_nodes = get_input_nodes(data, input_nodes)
input_type, input_nodes = get_input_nodes(data, input_nodes)

self.data = data
self.node_type = node_type
self.node_sampler = node_sampler
self.input_data = InputData(input_nodes, input_time)
self.transform = transform
self.filter_per_worker = filter_per_worker

self.input_data = NodeSamplerInput(
input_id=None,
node=input_nodes,
time=input_time,
input_type=input_type,
)

# TODO: Unify DL affinitization in `BaseDataLoader` class
# CPU Affinitization for loader and compute cores
self.num_workers = kwargs.get('num_workers', 0)
Expand Down Expand Up @@ -144,8 +148,9 @@ def filter_fn(

for key, batch in (out.batch or {}).items():
data[key].batch = batch
data[self.node_type].input_id = out.metadata
data[self.node_type].batch_size = out.metadata.size(0)

data[self.input_data.input_type].input_id = out.metadata
data[self.input_data.input_type].batch_size = out.metadata.size(0)

else:
raise TypeError(f"'{self.__class__.__name__}'' found invalid "
Expand Down
16 changes: 1 addition & 15 deletions torch_geometric/loader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import math
import os
from collections.abc import Sequence
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Dict, Optional, Tuple, Union

import numpy as np
import torch
Expand All @@ -27,20 +27,6 @@
)


class InputData:
def __init__(self, *args):
self.args = args

def __getitem__(self, index: Union[Tensor, List[int]]) -> Any:
if not isinstance(index, Tensor):
index = torch.tensor(index, dtype=torch.long)

outs = [index]
for arg in self.args:
outs.append(arg[index] if arg is not None else None)
return tuple(outs)


def index_select(value: FeatureTensorType, index: Tensor,
dim: int = 0) -> Tensor:

Expand Down
5 changes: 4 additions & 1 deletion torch_geometric/sampler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from .base import BaseSampler, SamplerOutput, HeteroSamplerOutput
from .base import (BaseSampler, NodeSamplerInput, EdgeSamplerInput,
SamplerOutput, HeteroSamplerOutput)
from .neighbor_sampler import NeighborSampler
from .hgt_sampler import HGTSampler

__all__ = classes = [
'BaseSampler',
'NodeSamplerInput',
'EdgeSamplerInput',
'SamplerOutput',
'HeteroSamplerOutput',
'NeighborSampler',
Expand Down
Loading