diff --git a/CHANGELOG.md b/CHANGELOG.md index ab2417fdb89f..0b036bb3819d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [2.2.0] - 2022-MM-DD ### Added +- Added support for `input_time` in `NeighborLoader` ([#5763](https://github.com/pyg-team/pytorch_geometric/pull/5763)) - Added `disjoint` mode for temporal `LinkNeighborLoader` ([#5717](https://github.com/pyg-team/pytorch_geometric/pull/5717)) - Added `HeteroData` support for `transforms.Constant` ([#5700](https://github.com/pyg-team/pytorch_geometric/pull/5700)) - Added `np.memmap` support in `NeighborLoader` ([#5696](https://github.com/pyg-team/pytorch_geometric/pull/5696)) diff --git a/test/loader/test_hgt_loader.py b/test/loader/test_hgt_loader.py index 50afe304a392..bf477eb81d0d 100644 --- a/test/loader/test_hgt_loader.py +++ b/test/loader/test_hgt_loader.py @@ -60,8 +60,9 @@ def test_hgt_loader(): assert set(batch.node_types) == {'paper', 'author'} assert set(batch.edge_types) == set(data.edge_types) - assert len(batch['paper']) == 2 + assert len(batch['paper']) == 3 assert batch['paper'].x.size() == (40, ) # 20 + 4 * 5 + assert batch['paper'].input_nodes.numel() == batch_size assert batch['paper'].batch_size == batch_size assert batch['paper'].x.min() >= 0 and batch['paper'].x.max() < 100 diff --git a/test/loader/test_link_neighbor_loader.py b/test/loader/test_link_neighbor_loader.py index 9399b3016bd8..60751e6a24d5 100644 --- a/test/loader/test_link_neighbor_loader.py +++ b/test/loader/test_link_neighbor_loader.py @@ -51,9 +51,10 @@ def test_homogeneous_link_neighbor_loader(directed, neg_sampling_ratio): for batch in loader: assert isinstance(batch, Data) - assert len(batch) == 5 + assert len(batch) == 6 assert batch.x.size(0) <= 100 assert batch.x.min() >= 0 and batch.x.max() < 100 + assert batch.input_links.numel() == 20 assert batch.edge_index.min() >= 0 assert batch.edge_index.max() < batch.num_nodes assert batch.edge_attr.min() >= 0 @@ -110,7 +111,7 @@ def test_heterogeneous_link_neighbor_loader(directed, neg_sampling_ratio): for batch in loader: assert isinstance(batch, HeteroData) - assert len(batch) == 5 + assert len(batch) == 6 if neg_sampling_ratio == 0.0: # Assert only positive samples are present in the original graph: assert batch['paper', 'author'].edge_label.sum() == 0 @@ -120,7 +121,6 @@ def test_heterogeneous_link_neighbor_loader(directed, neg_sampling_ratio): assert len(edge_index | edge_label_index) == len(edge_index) else: - assert batch['paper', 'author'].edge_label_index.size(1) == 40 assert torch.all(batch['paper', 'author'].edge_label[:20] == 1) assert torch.all(batch['paper', 'author'].edge_label[20:] == 0) @@ -195,7 +195,7 @@ def test_temporal_heterogeneous_link_neighbor_loader(): data['paper', 'author'].edge_index = get_edge_index(100, 200, 1000) data['author', 'paper'].edge_index = get_edge_index(200, 100, 1000) - with pytest.raises(ValueError, match=r"'edge_label_time' was not set.*"): + with pytest.raises(ValueError, match=r"'edge_label_time' is not set"): loader = LinkNeighborLoader( data, num_neighbors=[-1] * 2, @@ -312,7 +312,8 @@ def test_homogeneous_link_neighbor_loader_no_edges(): for batch in loader: assert isinstance(batch, Data) - assert len(batch) == 3 + assert len(batch) == 4 + assert batch.input_links.numel() == 20 assert batch.num_nodes <= 40 assert batch.edge_label_index.size(1) == 20 assert batch.num_nodes == batch.edge_label_index.unique().numel() @@ -328,8 +329,9 @@ def test_heterogeneous_link_neighbor_loader_no_edges(): for batch in loader: assert isinstance(batch, HeteroData) - assert len(batch) == 3 + assert len(batch) == 4 assert batch['paper'].num_nodes <= 40 + assert batch['paper', 'paper'].input_links.numel() == 20 assert batch['paper', 'paper'].edge_label_index.size(1) == 20 assert batch['paper'].num_nodes == batch[ 'paper', 'paper'].edge_label_index.unique().numel() diff --git a/test/loader/test_neighbor_loader.py b/test/loader/test_neighbor_loader.py index 3668a4ac9e1b..aff809e6188f 100644 --- a/test/loader/test_neighbor_loader.py +++ b/test/loader/test_neighbor_loader.py @@ -48,10 +48,9 @@ def test_homogeneous_neighbor_loader(directed): for batch in loader: assert isinstance(batch, Data) - - assert len(batch) == 4 + assert len(batch) == 5 assert batch.x.size(0) <= 100 - assert batch.batch_size == 20 + assert batch.input_nodes.numel() == batch.batch_size == 20 assert batch.x.min() >= 0 and batch.x.max() < 100 assert batch.edge_index.min() >= 0 assert batch.edge_index.max() < batch.num_nodes @@ -118,8 +117,9 @@ def test_heterogeneous_neighbor_loader(directed): # Test node type selection: assert set(batch.node_types) == {'paper', 'author'} - assert len(batch['paper']) == 2 + assert len(batch['paper']) == 3 assert batch['paper'].x.size(0) <= 100 + assert batch['paper'].input_nodes.numel() == batch_size assert batch['paper'].batch_size == batch_size assert batch['paper'].x.min() >= 0 and batch['paper'].x.max() < 100 @@ -498,7 +498,7 @@ def test_pyg_lib_heterogeneous_neighbor_loader(): 'author__to__paper': [-1, -1], } - sample = torch.ops.pyg.hetero_neighbor_sample_cpu + sample = torch.ops.pyg.hetero_neighbor_sample out1 = sample(node_types, edge_types, colptr_dict, row_dict, seed_dict, num_neighbors_dict, None, None, True, False, True, False, "uniform", True) diff --git a/torch_geometric/data/lightning_datamodule.py b/torch_geometric/data/lightning_datamodule.py index fce51019d0ef..ffdf369e948e 100644 --- a/torch_geometric/data/lightning_datamodule.py +++ b/torch_geometric/data/lightning_datamodule.py @@ -8,16 +8,10 @@ from torch_geometric.data import Data, Dataset, HeteroData from torch_geometric.data.feature_store import FeatureStore from torch_geometric.data.graph_store import GraphStore +from torch_geometric.loader import LinkNeighborLoader, NeighborLoader from torch_geometric.loader.dataloader import DataLoader -from torch_geometric.loader.link_neighbor_loader import ( - LinkNeighborLoader, - get_edge_label_index, -) -from torch_geometric.loader.neighbor_loader import ( - NeighborLoader, - NeighborSampler, - get_input_nodes, -) +from torch_geometric.loader.utils import get_edge_label_index, get_input_nodes +from torch_geometric.sampler import NeighborSampler from torch_geometric.typing import InputEdges, InputNodes try: diff --git a/torch_geometric/loader/hgt_loader.py b/torch_geometric/loader/hgt_loader.py index e41429e21b89..b078ae7c935f 100644 --- a/torch_geometric/loader/hgt_loader.py +++ b/torch_geometric/loader/hgt_loader.py @@ -104,16 +104,18 @@ def __init__( **kwargs, ): node_type, _ = get_input_nodes(data, input_nodes) - node_sampler = HGTSampler( + + hgt_sampler = HGTSampler( data, num_samples=num_samples, input_type=node_type, is_sorted=is_sorted, share_memory=kwargs.get('num_workers', 0) > 0, ) + super().__init__( data=data, - node_sampler=node_sampler, + node_sampler=hgt_sampler, input_nodes=input_nodes, transform=transform, filter_per_worker=filter_per_worker, diff --git a/torch_geometric/loader/link_loader.py b/torch_geometric/loader/link_loader.py index 157267b917e5..cf5088870627 100644 --- a/torch_geometric/loader/link_loader.py +++ b/torch_geometric/loader/link_loader.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Iterator, Tuple, Union +from typing import Any, Callable, Iterator, List, Tuple, Union import torch @@ -7,6 +7,7 @@ from torch_geometric.data.graph_store import GraphStore from torch_geometric.loader.base import DataLoaderIterator from torch_geometric.loader.utils import ( + InputData, filter_custom_store, filter_data, filter_hetero_data, @@ -89,53 +90,57 @@ def __init__( if 'collate_fn' in kwargs: del kwargs['collate_fn'] - self.data = data - - # Initialize sampler with keyword arguments: - # NOTE sampler is an attribute of 'DataLoader', so we use link_sampler - # here: - self.link_sampler = link_sampler - - # Store additional arguments: - self.edge_label = edge_label - self.edge_label_index = edge_label_index - self.edge_label_time = edge_label_time - self.transform = transform - self.filter_per_worker = filter_per_worker - self.neg_sampling_ratio = neg_sampling_ratio - - # Get input type, or None for homogeneous graphs: + # Get edge type (or `None` for homogeneous graphs): edge_type, edge_label_index = get_edge_label_index( data, edge_label_index) if edge_label is None: edge_label = torch.zeros(edge_label_index.size(1), device=edge_label_index.device) - self.input_type = edge_type - super().__init__( - Dataset(edge_label_index, edge_label, edge_label_time), - collate_fn=self.collate_fn, - **kwargs, + self.data = data + self.edge_type = edge_type + self.link_sampler = link_sampler + self.input_data = InputData(edge_label_index[0], edge_label_index[1], + edge_label, edge_label_time) + self.neg_sampling_ratio = neg_sampling_ratio + self.transform = transform + self.filter_per_worker = filter_per_worker + + iterator = range(edge_label_index.size(1)) + super().__init__(iterator, collate_fn=self.collate_fn, **kwargs) + + def collate_fn(self, index: List[int]) -> Any: + r"""Samples a subgraph from a batch of input nodes.""" + input_data: EdgeSamplerInput = self.input_data[index] + out = self.link_sampler.sample_from_edges( + input_data, + negative_sampling_ratio=self.neg_sampling_ratio, ) + if self.filter_per_worker: # Execute `filter_fn` in the worker process + out = self.filter_fn(out) + + return out + def filter_fn( self, out: Union[SamplerOutput, HeteroSamplerOutput], ) -> Union[Data, HeteroData]: r"""Joins the sampled nodes with their corresponding features, - returning the resulting (Data or HeteroData) object to be used - downstream.""" + returning the resulting :class:`~torch_geometric.data.Data` or + :class:`~torch_geometric.data.HeteroData` object to be used downstream. + """ if isinstance(out, SamplerOutput): - edge_label_index, edge_label, edge_label_time = out.metadata data = filter_data(self.data, out.node, out.row, out.col, out.edge, self.link_sampler.edge_permutation) + data.batch = out.batch - data.edge_label_index = edge_label_index - data.edge_label = edge_label - data.edge_label_time = edge_label_time + data.input_links = out.metadata[0] + data.edge_label_index = out.metadata[1] + data.edge_label = out.metadata[2] + data.edge_label_time = out.metadata[3] elif isinstance(out, HeteroSamplerOutput): - edge_label_index, edge_label, edge_label_time = out.metadata if isinstance(self.data, HeteroData): data = filter_hetero_data(self.data, out.node, out.row, out.col, out.edge, @@ -144,13 +149,12 @@ def filter_fn( data = filter_custom_store(*self.data, out.node, out.row, out.col, out.edge) - edge_type = self.input_type for key, batch in (out.batch or {}).items(): data[key].batch = batch - data[edge_type].edge_label_index = edge_label_index - data[edge_type].edge_label = edge_label - if edge_label_time is not None: - data[edge_type].edge_label_time = edge_label_time + data[self.edge_type].input_links = out.metadata[0] + data[self.edge_type].edge_label_index = out.metadata[1] + data[self.edge_type].edge_label = out.metadata[2] + data[self.edge_type].edge_label_time = out.metadata[3] else: raise TypeError(f"'{self.__class__.__name__}'' found invalid " @@ -158,61 +162,12 @@ def filter_fn( return data if self.transform is None else self.transform(data) - def collate_fn(self, index: EdgeSamplerInput) -> Any: - r"""Samples a subgraph from a batch of input nodes.""" - out = self.link_sampler.sample_from_edges( - index, - negative_sampling_ratio=self.neg_sampling_ratio, - ) - if self.filter_per_worker: - # We execute `filter_fn` in the worker process. - out = self.filter_fn(out) - return out - def _get_iterator(self) -> Iterator: if self.filter_per_worker: return super()._get_iterator() - # We execute `filter_fn` in the main process. + + # Execute `filter_fn` in the main process: return DataLoaderIterator(super()._get_iterator(), self.filter_fn) def __repr__(self) -> str: return f'{self.__class__.__name__}()' - - -############################################################################### - - -class Dataset(torch.utils.data.Dataset): - def __init__( - self, - edge_label_index: torch.Tensor, - edge_label: torch.Tensor, - edge_label_time: OptTensor = None, - ): - # NOTE see documentation of LinkLoader for details on these three - # input parameters: - self.edge_label_index = edge_label_index - self.edge_label = edge_label - self.edge_label_time = edge_label_time - - def __getitem__( - self, - idx: int, - ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[ - torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: - if self.edge_label_time is None: - return ( - self.edge_label_index[0, idx], - self.edge_label_index[1, idx], - self.edge_label[idx], - ) - else: - return ( - self.edge_label_index[0, idx], - self.edge_label_index[1, idx], - self.edge_label[idx], - self.edge_label_time[idx], - ) - - def __len__(self) -> int: - return self.edge_label_index.size(1) diff --git a/torch_geometric/loader/link_neighbor_loader.py b/torch_geometric/loader/link_neighbor_loader.py index d37bf5353eeb..bb883ddd0de1 100644 --- a/torch_geometric/loader/link_neighbor_loader.py +++ b/torch_geometric/loader/link_neighbor_loader.py @@ -166,21 +166,16 @@ def __init__( neighbor_sampler: Optional[NeighborSampler] = None, **kwargs, ): - # Get input type: - # TODO(manan): this computation is required twice, once here and once - # in LinkLoader: + # TODO(manan): Avoid duplicated computation (here and in NodeLoader): edge_type, _ = get_edge_label_index(data, edge_label_index) - has_time_attr = time_attr is not None - has_edge_label_time = edge_label_time is not None - if has_edge_label_time != has_time_attr: + if (edge_label_time is not None) != (time_attr is not None): raise ValueError( - f"Received conflicting 'time_attr' and 'edge_label_time' " - f"arguments: 'time_attr' was " - f"{'set' if has_time_attr else 'not set'} and " - f"'edge_label_time' was " - f"{'set' if has_edge_label_time else 'not set'}. Please " - f"resolve these conflicting arguments.") + f"Received conflicting 'edge_label_time' and 'time_attr' " + f"arguments: 'edge_label_time' is " + f"{'set' if edge_label_time is not None else 'not set'} " + f"while 'input_time' is " + f"{'set' if time_attr is not None else 'not set'}.") if neighbor_sampler is None: neighbor_sampler = NeighborSampler( diff --git a/torch_geometric/loader/neighbor_loader.py b/torch_geometric/loader/neighbor_loader.py index 5d3b05457cda..94cd9b8a633c 100644 --- a/torch_geometric/loader/neighbor_loader.py +++ b/torch_geometric/loader/neighbor_loader.py @@ -6,7 +6,7 @@ from torch_geometric.loader.node_loader import NodeLoader from torch_geometric.loader.utils import get_input_nodes from torch_geometric.sampler import NeighborSampler -from torch_geometric.typing import InputNodes, NumNeighbors +from torch_geometric.typing import InputNodes, NumNeighbors, OptTensor class NeighborLoader(NodeLoader): @@ -122,6 +122,11 @@ class NeighborLoader(NodeLoader): If set to :obj:`None`, all nodes will be considered. In heterogeneous graphs, needs to be passed as a tuple that holds the node type and node indices. (default: :obj:`None`) + input_time (torch.Tensor, optional): Optional values to override the + timestamp for the input nodes given in :obj:`input_nodes`. If not + set, will use the timestamps in :obj:`time_attr` as default (if + present). The :obj:`time_attr` needs to be set for this to work. + (default: :obj:`None`) replace (bool, optional): If set to :obj:`True`, will sample with replacement. (default: :obj:`False`) directed (bool, optional): If set to :obj:`False`, will include all @@ -164,6 +169,7 @@ def __init__( data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]], num_neighbors: NumNeighbors, input_nodes: InputNodes = None, + input_time: OptTensor = None, replace: bool = False, directed: bool = True, temporal_strategy: str = 'uniform', @@ -174,11 +180,14 @@ def __init__( neighbor_sampler: Optional[NeighborSampler] = None, **kwargs, ): - # Get input type: - # TODO(manan): this computation is repeated twice, once here and once - # in NodeLoader: + # TODO(manan): Avoid duplicated computation (here and in NodeLoader): node_type, _ = get_input_nodes(data, input_nodes) + if input_time is not None and time_attr is None: + raise ValueError("Received conflicting 'input_time' and " + "'time_attr' arguments: 'input_time' is set " + "while 'time_attr' is not set.") + if neighbor_sampler is None: neighbor_sampler = NeighborSampler( data, @@ -192,12 +201,11 @@ def __init__( share_memory=kwargs.get('num_workers', 0) > 0, ) - # A NeighborLoader is simply a NodeLoader that uses the NeighborSampler - # sampling implementation: super().__init__( data=data, node_sampler=neighbor_sampler, input_nodes=input_nodes, + input_time=input_time, transform=transform, filter_per_worker=filter_per_worker, **kwargs, diff --git a/torch_geometric/loader/node_loader.py b/torch_geometric/loader/node_loader.py index a0d94a000f97..e0b2916afa06 100644 --- a/torch_geometric/loader/node_loader.py +++ b/torch_geometric/loader/node_loader.py @@ -7,6 +7,7 @@ from torch_geometric.data.graph_store import GraphStore from torch_geometric.loader.base import DataLoaderIterator from torch_geometric.loader.utils import ( + InputData, filter_custom_store, filter_data, filter_hetero_data, @@ -18,7 +19,7 @@ NodeSamplerInput, SamplerOutput, ) -from torch_geometric.typing import InputNodes +from torch_geometric.typing import InputNodes, OptTensor class NodeLoader(torch.utils.data.DataLoader): @@ -43,6 +44,11 @@ class NodeLoader(torch.utils.data.DataLoader): If set to :obj:`None`, all nodes will be considered. In heterogeneous graphs, needs to be passed as a tuple that holds the node type and node indices. (default: :obj:`None`) + input_time (torch.Tensor, optional): Optional values to override the + timestamp for the input nodes given in :obj:`input_nodes`. If not + set, will use the timestamps in :obj:`time_attr` as default (if + present). The :obj:`time_attr` needs to be set for this to work. + (default: :obj:`None`) transform (Callable, optional): A function/transform that takes in a sampled mini-batch and returns a transformed version. (default: :obj:`None`) @@ -63,6 +69,7 @@ def __init__( data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]], node_sampler: BaseSampler, input_nodes: InputNodes = None, + input_time: OptTensor = None, transform: Callable = None, filter_per_worker: bool = False, **kwargs, @@ -73,35 +80,44 @@ def __init__( if 'collate_fn' in kwargs: del kwargs['collate_fn'] - self.data = data + # Get node type (or `None` for homogeneous graphs): + node_type, input_nodes = get_input_nodes(data, input_nodes) - # NOTE sampler is an attribute of 'DataLoader', so we use node_sampler - # here: + self.data = data + self.node_type = node_type self.node_sampler = node_sampler - - # Store additional arguments: - self.input_nodes = input_nodes + self.input_data = InputData(input_nodes, input_time) self.transform = transform self.filter_per_worker = filter_per_worker - # Get input type, or None for homogeneous graphs: - node_type, input_nodes = get_input_nodes(self.data, input_nodes) - self.input_type = node_type + iterator = range(input_nodes.size(0)) + super().__init__(iterator, collate_fn=self.collate_fn, **kwargs) + + def collate_fn(self, index: NodeSamplerInput) -> Any: + r"""Samples a subgraph from a batch of input nodes.""" + input_data: NodeSamplerInput = self.input_data[index] - super().__init__(input_nodes, collate_fn=self.collate_fn, **kwargs) + out = self.node_sampler.sample_from_nodes(input_data) + + if self.filter_per_worker: # Execute `filter_fn` in the worker process + out = self.filter_fn(out) + + return out def filter_fn( self, out: Union[SamplerOutput, HeteroSamplerOutput], ) -> Union[Data, HeteroData]: r"""Joins the sampled nodes with their corresponding features, - returning the resulting (Data or HeteroData) object to be used - downstream.""" + returning the resulting :class:`~torch_geometric.data.Data` or + :class:`~torch_geometric.data.HeteroData` object to be used downstream. + """ if isinstance(out, SamplerOutput): data = filter_data(self.data, out.node, out.row, out.col, out.edge, self.node_sampler.edge_permutation) data.batch = out.batch - data.batch_size = out.metadata + data.input_nodes = out.metadata + data.batch_size = out.metadata.size(0) elif isinstance(out, HeteroSamplerOutput): if isinstance(self.data, HeteroData): @@ -114,7 +130,8 @@ def filter_fn( for key, batch in (out.batch or {}).items(): data[key].batch = batch - data[self.input_type].batch_size = out.metadata + data[self.node_type].input_nodes = out.metadata + data[self.node_type].batch_size = out.metadata.size(0) else: raise TypeError(f"'{self.__class__.__name__}'' found invalid " @@ -122,21 +139,11 @@ def filter_fn( return data if self.transform is None else self.transform(data) - def collate_fn(self, index: NodeSamplerInput) -> Any: - r"""Samples a subgraph from a batch of input nodes.""" - if isinstance(index, (list, tuple)): - index = torch.tensor(index) - - out = self.node_sampler.sample_from_nodes(index) - if self.filter_per_worker: - # We execute `filter_fn` in the worker process. - out = self.filter_fn(out) - return out - def _get_iterator(self) -> Iterator: if self.filter_per_worker: return super()._get_iterator() - # We execute `filter_fn` in the main process. + + # Execute `filter_fn` in the main process: return DataLoaderIterator(super()._get_iterator(), self.filter_fn) def __repr__(self) -> str: diff --git a/torch_geometric/loader/utils.py b/torch_geometric/loader/utils.py index 7306ec7d507a..bc21424f4756 100644 --- a/torch_geometric/loader/utils.py +++ b/torch_geometric/loader/utils.py @@ -1,7 +1,7 @@ import copy import math from collections.abc import Sequence -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -20,6 +20,20 @@ ) +class InputData: + def __init__(self, *args): + self.args = args + + def __getitem__(self, index: Union[Tensor, List[int]]) -> Any: + if not isinstance(index, Tensor): + index = torch.tensor(index, dtype=torch.long) + + outs = [index] + for arg in self.args: + outs.append(arg[index] if arg is not None else None) + return tuple(outs) + + def index_select(value: FeatureTensorType, index: Tensor, dim: int = 0) -> Tensor: if isinstance(value, Tensor): @@ -192,18 +206,20 @@ def get_input_nodes( def to_index(tensor): if isinstance(tensor, Tensor) and tensor.dtype == torch.bool: return tensor.nonzero(as_tuple=False).view(-1) + if not isinstance(tensor, Tensor): + return torch.tensor(tensor, dtype=torch.long) return tensor if isinstance(data, Data): if input_nodes is None: - return None, range(data.num_nodes) + return None, torch.arange(data.num_nodes) return None, to_index(input_nodes) elif isinstance(data, HeteroData): assert input_nodes is not None if isinstance(input_nodes, str): - return input_nodes, range(data[input_nodes].num_nodes) + return input_nodes, torch.arange(data[input_nodes].num_nodes) assert isinstance(input_nodes, (list, tuple)) assert len(input_nodes) == 2 @@ -211,7 +227,7 @@ def to_index(tensor): node_type, input_nodes = input_nodes if input_nodes is None: - return node_type, range(data[node_type].num_nodes) + return node_type, torch.arange(data[node_type].num_nodes) return node_type, to_index(input_nodes) else: # Tuple[FeatureStore, GraphStore] @@ -222,7 +238,7 @@ def to_index(tensor): return None, to_index(input_nodes) if isinstance(input_nodes, str): - return input_nodes, range( + return input_nodes, torch.arange( remote_backend_utils.num_nodes(feature_store, graph_store, input_nodes)) @@ -232,7 +248,7 @@ def to_index(tensor): node_type, input_nodes = input_nodes if input_nodes is None: - return node_type, range( + return node_type, torch.arange( remote_backend_utils.num_nodes(feature_store, graph_store, input_nodes)) return node_type, to_index(input_nodes) diff --git a/torch_geometric/sampler/base.py b/torch_geometric/sampler/base.py index a2081423ea32..e5f1d02e9cc7 100644 --- a/torch_geometric/sampler/base.py +++ b/torch_geometric/sampler/base.py @@ -2,19 +2,23 @@ from dataclasses import dataclass from typing import Any, Dict, Optional, Tuple, Union -import torch +from torch import Tensor from torch_geometric.typing import EdgeType, NodeType, OptTensor -# An input to a node-based sampler is a tensor of node indices: -NodeSamplerInput = torch.Tensor +# An input to a node-based sampler consists of two tensors: +# * The example indices +# * The node indices +# * The timestamps of the given node indices (optional) +NodeSamplerInput = Tuple[Tensor, Tensor, OptTensor] # An input to an edge-based sampler consists of four tensors: +# * The example indices # * The row of the edge index in COO format # * The column of the edge index in COO format # * The labels of the edges -# * (Optionally) the time attribute corresponding to the edge label -EdgeSamplerInput = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, OptTensor] +# * The time attribute corresponding to the edge label (optional) +EdgeSamplerInput = Tuple[Tensor, Tensor, Tensor, Tensor, OptTensor] # A sampler output contains the following information. @@ -40,11 +44,11 @@ # There exist both homogeneous and heterogeneous versions. @dataclass class SamplerOutput: - node: torch.Tensor - row: torch.Tensor - col: torch.Tensor - edge: torch.Tensor - batch: Optional[torch.Tensor] = None + node: Tensor + row: Tensor + col: Tensor + edge: Tensor + batch: OptTensor = None # TODO(manan): refine this further; it does not currently define a proper # API for the expected output of a sampler. metadata: Optional[Any] = None @@ -52,11 +56,11 @@ class SamplerOutput: @dataclass class HeteroSamplerOutput: - node: Dict[NodeType, torch.Tensor] - row: Dict[EdgeType, torch.Tensor] - col: Dict[EdgeType, torch.Tensor] - edge: Dict[EdgeType, torch.Tensor] - batch: Optional[Dict[NodeType, torch.Tensor]] = None + node: Dict[NodeType, Tensor] + row: Dict[EdgeType, Tensor] + col: Dict[EdgeType, Tensor] + edge: Dict[EdgeType, Tensor] + batch: Optional[Dict[NodeType, Tensor]] = None # TODO(manan): refine this further; it does not currently define a proper # API for the expected output of a sampler. metadata: Optional[Any] = None diff --git a/torch_geometric/sampler/hgt_sampler.py b/torch_geometric/sampler/hgt_sampler.py index 661a2d517c3d..e72b669bc345 100644 --- a/torch_geometric/sampler/hgt_sampler.py +++ b/torch_geometric/sampler/hgt_sampler.py @@ -60,7 +60,8 @@ def sample_from_nodes( index: NodeSamplerInput, **kwargs, ) -> HeteroSamplerOutput: - input_node_dict = {self.input_type: torch.tensor(index)} + index, input_nodes, _ = index + input_node_dict = {self.input_type: input_nodes} sample_fn = torch.ops.torch_sparse.hgt_sample out = sample_fn( self.colptr_dict, @@ -76,7 +77,7 @@ def sample_from_nodes( col=remap_keys(col, self.to_edge_type), edge=remap_keys(edge, self.to_edge_type), batch=batch, - metadata=len(index), + metadata=index, ) def sample_from_edges( diff --git a/torch_geometric/sampler/neighbor_sampler.py b/torch_geometric/sampler/neighbor_sampler.py index 8056800bf3fe..9b904a8af336 100644 --- a/torch_geometric/sampler/neighbor_sampler.py +++ b/torch_geometric/sampler/neighbor_sampler.py @@ -220,9 +220,7 @@ def _sample( Note that the 'metadata' field of the output is not filled; it is the job of the caller to appropriately fill out this field for downstream loaders.""" - - # TODO(manan): remote backends only support heterogeneous graphs for - # now: + # TODO(manan): remote backends only support heterogeneous graphs: if self.data_cls == 'custom' or issubclass(self.data_cls, HeteroData): if _WITH_PYG_LIB: # TODO (matthias) Add `disjoint` option to `NeighborSampler` @@ -328,18 +326,19 @@ def sample_from_nodes( index: NodeSamplerInput, **kwargs, ) -> Union[SamplerOutput, HeteroSamplerOutput]: - if isinstance(index, (list, tuple)): - index = torch.tensor(index) + index, input_nodes, input_time = index - # Tuple[FeatureStore, GraphStore] currently only supports heterogeneous - # sampling: if self.data_cls == 'custom' or issubclass(self.data_cls, HeteroData): - output = self._sample(seed={self.input_type: index}) - output.metadata = index.numel() + seed_time_dict = None + if input_time is not None: + seed_time_dict = {self.input_type: input_time} + output = self._sample(seed={self.input_type: input_nodes}, + seed_time_dict=seed_time_dict) + output.metadata = index elif issubclass(self.data_cls, Data): - output = self._sample(seed=index) - output.metadata = index.numel() + output = self._sample(seed=input_nodes, seed_time=input_time) + output.metadata = index else: raise TypeError(f"'{self.__class__.__name__}'' found invalid " @@ -354,11 +353,9 @@ def sample_from_edges( index: EdgeSamplerInput, **kwargs, ) -> Union[SamplerOutput, HeteroSamplerOutput]: + index, row, col, edge_label, edge_label_time = index + edge_label_index = torch.stack([row, col], dim=0) negative_sampling_ratio = kwargs.get('negative_sampling_ratio', 0.0) - query = [torch.stack(s, dim=0) for s in zip(*index)] - edge_label_index = torch.stack(query[:2], dim=0) - edge_label = query[2] - edge_label_time = query[3] if len(query) == 4 else None out = add_negative_samples(edge_label_index, edge_label, edge_label_time, self.num_src_nodes, @@ -424,7 +421,8 @@ def sample_from_edges( for key, batch in output.batch.items(): output.batch[key] = batch % num_seed_edges - output.metadata = (edge_label_index, edge_label, edge_label_time) + output.metadata = (index, edge_label_index, edge_label, + edge_label_time) elif issubclass(self.data_cls, Data): if self.disjoint_sampling: @@ -444,7 +442,8 @@ def sample_from_edges( if self.disjoint_sampling: output.batch = output.batch % num_seed_edges - output.metadata = (edge_label_index, edge_label, edge_label_time) + output.metadata = (index, edge_label_index, edge_label, + edge_label_time) else: raise TypeError(f"'{self.__class__.__name__}'' found invalid "