Skip to content

Commit

Permalink
Support for input_time in NeighborLoader (pyg-team#5763)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored and JakubPietrakIntel committed Nov 25, 2022
1 parent 45d30ca commit 8076e90
Show file tree
Hide file tree
Showing 14 changed files with 178 additions and 193 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.2.0] - 2022-MM-DD
### Added
- Added support for `input_time` in `NeighborLoader` ([#5763](https://github.com/pyg-team/pytorch_geometric/pull/5763))
- Added `disjoint` mode for temporal `LinkNeighborLoader` ([#5717](https://github.com/pyg-team/pytorch_geometric/pull/5717))
- Added `HeteroData` support for `transforms.Constant` ([#5700](https://github.com/pyg-team/pytorch_geometric/pull/5700))
- Added `np.memmap` support in `NeighborLoader` ([#5696](https://github.com/pyg-team/pytorch_geometric/pull/5696))
Expand Down
3 changes: 2 additions & 1 deletion test/loader/test_hgt_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@ def test_hgt_loader():
assert set(batch.node_types) == {'paper', 'author'}
assert set(batch.edge_types) == set(data.edge_types)

assert len(batch['paper']) == 2
assert len(batch['paper']) == 3
assert batch['paper'].x.size() == (40, ) # 20 + 4 * 5
assert batch['paper'].input_nodes.numel() == batch_size
assert batch['paper'].batch_size == batch_size
assert batch['paper'].x.min() >= 0 and batch['paper'].x.max() < 100

Expand Down
14 changes: 8 additions & 6 deletions test/loader/test_link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ def test_homogeneous_link_neighbor_loader(directed, neg_sampling_ratio):
for batch in loader:
assert isinstance(batch, Data)

assert len(batch) == 5
assert len(batch) == 6
assert batch.x.size(0) <= 100
assert batch.x.min() >= 0 and batch.x.max() < 100
assert batch.input_links.numel() == 20
assert batch.edge_index.min() >= 0
assert batch.edge_index.max() < batch.num_nodes
assert batch.edge_attr.min() >= 0
Expand Down Expand Up @@ -110,7 +111,7 @@ def test_heterogeneous_link_neighbor_loader(directed, neg_sampling_ratio):

for batch in loader:
assert isinstance(batch, HeteroData)
assert len(batch) == 5
assert len(batch) == 6
if neg_sampling_ratio == 0.0:
# Assert only positive samples are present in the original graph:
assert batch['paper', 'author'].edge_label.sum() == 0
Expand All @@ -120,7 +121,6 @@ def test_heterogeneous_link_neighbor_loader(directed, neg_sampling_ratio):
assert len(edge_index | edge_label_index) == len(edge_index)

else:

assert batch['paper', 'author'].edge_label_index.size(1) == 40
assert torch.all(batch['paper', 'author'].edge_label[:20] == 1)
assert torch.all(batch['paper', 'author'].edge_label[20:] == 0)
Expand Down Expand Up @@ -195,7 +195,7 @@ def test_temporal_heterogeneous_link_neighbor_loader():
data['paper', 'author'].edge_index = get_edge_index(100, 200, 1000)
data['author', 'paper'].edge_index = get_edge_index(200, 100, 1000)

with pytest.raises(ValueError, match=r"'edge_label_time' was not set.*"):
with pytest.raises(ValueError, match=r"'edge_label_time' is not set"):
loader = LinkNeighborLoader(
data,
num_neighbors=[-1] * 2,
Expand Down Expand Up @@ -312,7 +312,8 @@ def test_homogeneous_link_neighbor_loader_no_edges():

for batch in loader:
assert isinstance(batch, Data)
assert len(batch) == 3
assert len(batch) == 4
assert batch.input_links.numel() == 20
assert batch.num_nodes <= 40
assert batch.edge_label_index.size(1) == 20
assert batch.num_nodes == batch.edge_label_index.unique().numel()
Expand All @@ -328,8 +329,9 @@ def test_heterogeneous_link_neighbor_loader_no_edges():

for batch in loader:
assert isinstance(batch, HeteroData)
assert len(batch) == 3
assert len(batch) == 4
assert batch['paper'].num_nodes <= 40
assert batch['paper', 'paper'].input_links.numel() == 20
assert batch['paper', 'paper'].edge_label_index.size(1) == 20
assert batch['paper'].num_nodes == batch[
'paper', 'paper'].edge_label_index.unique().numel()
10 changes: 5 additions & 5 deletions test/loader/test_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,9 @@ def test_homogeneous_neighbor_loader(directed):

for batch in loader:
assert isinstance(batch, Data)

assert len(batch) == 4
assert len(batch) == 5
assert batch.x.size(0) <= 100
assert batch.batch_size == 20
assert batch.input_nodes.numel() == batch.batch_size == 20
assert batch.x.min() >= 0 and batch.x.max() < 100
assert batch.edge_index.min() >= 0
assert batch.edge_index.max() < batch.num_nodes
Expand Down Expand Up @@ -118,8 +117,9 @@ def test_heterogeneous_neighbor_loader(directed):
# Test node type selection:
assert set(batch.node_types) == {'paper', 'author'}

assert len(batch['paper']) == 2
assert len(batch['paper']) == 3
assert batch['paper'].x.size(0) <= 100
assert batch['paper'].input_nodes.numel() == batch_size
assert batch['paper'].batch_size == batch_size
assert batch['paper'].x.min() >= 0 and batch['paper'].x.max() < 100

Expand Down Expand Up @@ -498,7 +498,7 @@ def test_pyg_lib_heterogeneous_neighbor_loader():
'author__to__paper': [-1, -1],
}

sample = torch.ops.pyg.hetero_neighbor_sample_cpu
sample = torch.ops.pyg.hetero_neighbor_sample
out1 = sample(node_types, edge_types, colptr_dict, row_dict, seed_dict,
num_neighbors_dict, None, None, True, False, True, False,
"uniform", True)
Expand Down
12 changes: 3 additions & 9 deletions torch_geometric/data/lightning_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,10 @@
from torch_geometric.data import Data, Dataset, HeteroData
from torch_geometric.data.feature_store import FeatureStore
from torch_geometric.data.graph_store import GraphStore
from torch_geometric.loader import LinkNeighborLoader, NeighborLoader
from torch_geometric.loader.dataloader import DataLoader
from torch_geometric.loader.link_neighbor_loader import (
LinkNeighborLoader,
get_edge_label_index,
)
from torch_geometric.loader.neighbor_loader import (
NeighborLoader,
NeighborSampler,
get_input_nodes,
)
from torch_geometric.loader.utils import get_edge_label_index, get_input_nodes
from torch_geometric.sampler import NeighborSampler
from torch_geometric.typing import InputEdges, InputNodes

try:
Expand Down
6 changes: 4 additions & 2 deletions torch_geometric/loader/hgt_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,16 +104,18 @@ def __init__(
**kwargs,
):
node_type, _ = get_input_nodes(data, input_nodes)
node_sampler = HGTSampler(

hgt_sampler = HGTSampler(
data,
num_samples=num_samples,
input_type=node_type,
is_sorted=is_sorted,
share_memory=kwargs.get('num_workers', 0) > 0,
)

super().__init__(
data=data,
node_sampler=node_sampler,
node_sampler=hgt_sampler,
input_nodes=input_nodes,
transform=transform,
filter_per_worker=filter_per_worker,
Expand Down
125 changes: 40 additions & 85 deletions torch_geometric/loader/link_loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Iterator, Tuple, Union
from typing import Any, Callable, Iterator, List, Tuple, Union

import torch

Expand All @@ -7,6 +7,7 @@
from torch_geometric.data.graph_store import GraphStore
from torch_geometric.loader.base import DataLoaderIterator
from torch_geometric.loader.utils import (
InputData,
filter_custom_store,
filter_data,
filter_hetero_data,
Expand Down Expand Up @@ -89,53 +90,57 @@ def __init__(
if 'collate_fn' in kwargs:
del kwargs['collate_fn']

self.data = data

# Initialize sampler with keyword arguments:
# NOTE sampler is an attribute of 'DataLoader', so we use link_sampler
# here:
self.link_sampler = link_sampler

# Store additional arguments:
self.edge_label = edge_label
self.edge_label_index = edge_label_index
self.edge_label_time = edge_label_time
self.transform = transform
self.filter_per_worker = filter_per_worker
self.neg_sampling_ratio = neg_sampling_ratio

# Get input type, or None for homogeneous graphs:
# Get edge type (or `None` for homogeneous graphs):
edge_type, edge_label_index = get_edge_label_index(
data, edge_label_index)
if edge_label is None:
edge_label = torch.zeros(edge_label_index.size(1),
device=edge_label_index.device)
self.input_type = edge_type

super().__init__(
Dataset(edge_label_index, edge_label, edge_label_time),
collate_fn=self.collate_fn,
**kwargs,
self.data = data
self.edge_type = edge_type
self.link_sampler = link_sampler
self.input_data = InputData(edge_label_index[0], edge_label_index[1],
edge_label, edge_label_time)
self.neg_sampling_ratio = neg_sampling_ratio
self.transform = transform
self.filter_per_worker = filter_per_worker

iterator = range(edge_label_index.size(1))
super().__init__(iterator, collate_fn=self.collate_fn, **kwargs)

def collate_fn(self, index: List[int]) -> Any:
r"""Samples a subgraph from a batch of input links."""
input_data: EdgeSamplerInput = self.input_data[index]
out = self.link_sampler.sample_from_edges(
input_data,
negative_sampling_ratio=self.neg_sampling_ratio,
)

if self.filter_per_worker: # Execute `filter_fn` in the worker process
out = self.filter_fn(out)

return out

def filter_fn(
self,
out: Union[SamplerOutput, HeteroSamplerOutput],
) -> Union[Data, HeteroData]:
r"""Joins the sampled nodes with their corresponding features,
returning the resulting (Data or HeteroData) object to be used
downstream."""
returning the resulting :class:`~torch_geometric.data.Data` or
:class:`~torch_geometric.data.HeteroData` object to be used downstream.
"""
if isinstance(out, SamplerOutput):
edge_label_index, edge_label, edge_label_time = out.metadata
data = filter_data(self.data, out.node, out.row, out.col, out.edge,
self.link_sampler.edge_permutation)

data.batch = out.batch
data.edge_label_index = edge_label_index
data.edge_label = edge_label
data.edge_label_time = edge_label_time
data.input_links = out.metadata[0]
data.edge_label_index = out.metadata[1]
data.edge_label = out.metadata[2]
data.edge_label_time = out.metadata[3]

elif isinstance(out, HeteroSamplerOutput):
edge_label_index, edge_label, edge_label_time = out.metadata
if isinstance(self.data, HeteroData):
data = filter_hetero_data(self.data, out.node, out.row,
out.col, out.edge,
Expand All @@ -144,75 +149,25 @@ def filter_fn(
data = filter_custom_store(*self.data, out.node, out.row,
out.col, out.edge)

edge_type = self.input_type
for key, batch in (out.batch or {}).items():
data[key].batch = batch
data[edge_type].edge_label_index = edge_label_index
data[edge_type].edge_label = edge_label
if edge_label_time is not None:
data[edge_type].edge_label_time = edge_label_time
data[self.edge_type].input_links = out.metadata[0]
data[self.edge_type].edge_label_index = out.metadata[1]
data[self.edge_type].edge_label = out.metadata[2]
data[self.edge_type].edge_label_time = out.metadata[3]

else:
raise TypeError(f"'{self.__class__.__name__}'' found invalid "
f"type: '{type(out)}'")

return data if self.transform is None else self.transform(data)

def collate_fn(self, index: EdgeSamplerInput) -> Any:
r"""Samples a subgraph from a batch of input nodes."""
out = self.link_sampler.sample_from_edges(
index,
negative_sampling_ratio=self.neg_sampling_ratio,
)
if self.filter_per_worker:
# We execute `filter_fn` in the worker process.
out = self.filter_fn(out)
return out

def _get_iterator(self) -> Iterator:
if self.filter_per_worker:
return super()._get_iterator()
# We execute `filter_fn` in the main process.

# Execute `filter_fn` in the main process:
return DataLoaderIterator(super()._get_iterator(), self.filter_fn)

def __repr__(self) -> str:
return f'{self.__class__.__name__}()'


###############################################################################


class Dataset(torch.utils.data.Dataset):
def __init__(
self,
edge_label_index: torch.Tensor,
edge_label: torch.Tensor,
edge_label_time: OptTensor = None,
):
# NOTE see documentation of LinkLoader for details on these three
# input parameters:
self.edge_label_index = edge_label_index
self.edge_label = edge_label
self.edge_label_time = edge_label_time

def __getitem__(
self,
idx: int,
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
if self.edge_label_time is None:
return (
self.edge_label_index[0, idx],
self.edge_label_index[1, idx],
self.edge_label[idx],
)
else:
return (
self.edge_label_index[0, idx],
self.edge_label_index[1, idx],
self.edge_label[idx],
self.edge_label_time[idx],
)

def __len__(self) -> int:
return self.edge_label_index.size(1)
19 changes: 7 additions & 12 deletions torch_geometric/loader/link_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,21 +166,16 @@ def __init__(
neighbor_sampler: Optional[NeighborSampler] = None,
**kwargs,
):
# Get input type:
# TODO(manan): this computation is required twice, once here and once
# in LinkLoader:
# TODO(manan): Avoid duplicated computation (here and in LinkLoader):
edge_type, _ = get_edge_label_index(data, edge_label_index)

has_time_attr = time_attr is not None
has_edge_label_time = edge_label_time is not None
if has_edge_label_time != has_time_attr:
if (edge_label_time is not None) != (time_attr is not None):
raise ValueError(
f"Received conflicting 'time_attr' and 'edge_label_time' "
f"arguments: 'time_attr' was "
f"{'set' if has_time_attr else 'not set'} and "
f"'edge_label_time' was "
f"{'set' if has_edge_label_time else 'not set'}. Please "
f"resolve these conflicting arguments.")
f"Received conflicting 'edge_label_time' and 'time_attr' "
f"arguments: 'edge_label_time' is "
f"{'set' if edge_label_time is not None else 'not set'} "
f"while 'input_time' is "
f"{'set' if time_attr is not None else 'not set'}.")

if neighbor_sampler is None:
neighbor_sampler = NeighborSampler(
Expand Down
Loading

0 comments on commit 8076e90

Please sign in to comment.