Skip to content

Commit

Permalink
[Feature] Add custom dataset and remove dependency on PyG (#174)
Browse files Browse the repository at this point in the history
* Reformat code

* Add custom dataset and remove stpgnn's dependency on PyG

* Fix a bug in experiment
  • Loading branch information
THINK2TRY authored Jan 24, 2021
1 parent 6453531 commit 282ac1f
Show file tree
Hide file tree
Showing 28 changed files with 481 additions and 199 deletions.
9 changes: 2 additions & 7 deletions cogdl/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
from .data import Data
from .batch import Batch
from .dataset import Dataset
from .dataset import Dataset, MultiGraphDataset
from .dataloader import DataLoader

__all__ = [
"Data",
"Batch",
"Dataset",
"DataLoader",
]
__all__ = ["Data", "Batch", "Dataset", "DataLoader", "MultiGraphDataset"]
3 changes: 3 additions & 0 deletions cogdl/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ def __inc__(self, key, value):
# creating batches.
return self.num_nodes if bool(re.search("(index|face)", key)) else 0

def __cat_dim__(self, key, value):
    # Alias so code written against the ``__cat_dim__`` name (e.g.
    # MultiGraphDataset._get / from_data_list in this change set) works
    # while this class implements the logic in ``cat_dim``.
    return self.cat_dim(key, value)

@property
def num_edges(self):
r"""Returns the number of edges in the graph."""
Expand Down
21 changes: 19 additions & 2 deletions cogdl/data/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch.utils.data
from torch.utils.data.dataloader import default_collate

from cogdl.data import Batch
from cogdl.data import Batch, Data


class DataLoader(torch.utils.data.DataLoader):
Expand All @@ -18,5 +18,22 @@ class DataLoader(torch.utils.data.DataLoader):

def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
    """Wrap ``torch.utils.data.DataLoader`` with a collate function that
    understands cogdl ``Data`` objects as well as tensors and floats.

    Parameters mirror the torch DataLoader; extra keyword arguments are
    forwarded unchanged.
    """
    # Removed the commented-out legacy lambda collate line; collation now
    # lives in the reusable static method ``collate_fn``.
    super(DataLoader, self).__init__(
        dataset,
        batch_size,
        shuffle,
        collate_fn=self.collate_fn,
        **kwargs,
    )

@staticmethod
def collate_fn(batch):
    """Collate a list of samples.

    ``Data`` samples are merged into a ``Batch``; tensors are stacked via
    torch's default collation; floats become a 1-D float tensor.  Any other
    element type is rejected with a ``TypeError``.
    """
    first = batch[0]
    if isinstance(first, Data):
        return Batch.from_data_list(batch)
    if isinstance(first, torch.Tensor):
        return default_collate(batch)
    if isinstance(first, float):
        return torch.tensor(batch, dtype=torch.float)
    raise TypeError("DataLoader found invalid type: {}".format(type(first)))
95 changes: 95 additions & 0 deletions cogdl/data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import collections
import copy
import os.path as osp
from itertools import repeat, product

import torch.utils.data

Expand Down Expand Up @@ -130,5 +132,98 @@ def __getitem__(self, idx): # pragma: no cover
data = data if self.transform is None else self.transform(data)
return data

@property
def num_classes(self):
    r"""The number of classes in the dataset.

    For 1-D integer labels this is ``max + 1``; for 2-D (multi-label /
    one-hot) targets it is the width of the label matrix.
    """
    labels = self.data.y
    if labels.dim() == 1:
        return labels.max().item() + 1
    return labels.size(1)

def __repr__(self):  # pragma: no cover
    # e.g. "SomeDataset(2708)" — class name plus dataset length.
    return "{}({})".format(self.__class__.__name__, len(self))


class MultiGraphDataset(Dataset):
    """In-memory dataset holding many graphs as one concatenated ``Data``
    object plus per-key ``slices`` boundary tensors (the PyG
    ``InMemoryDataset`` storage layout, re-implemented here to drop the PyG
    dependency).

    Subclasses are expected to populate ``self.data`` (the concatenated
    graph) and ``self.slices`` (dict mapping each data key to a LongTensor
    of slice boundaries).
    """

    def __init__(self, root=None, transform=None, pre_transform=None, pre_filter=None):
        super(MultiGraphDataset, self).__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = None, None

    @property
    def num_classes(self):
        r"""The number of classes in the dataset."""
        y = self.data.y
        return y.max().item() + 1 if y.dim() == 1 else y.size(1)

    def len(self):
        """Number of graphs stored.

        Every slice tensor has ``num_graphs + 1`` boundaries, so any one of
        them determines the count; 0 when ``slices`` is empty.
        """
        for item in self.slices.values():
            return len(item) - 1
        return 0

    def _get(self, idx):
        """Reconstruct the ``idx``-th graph from the concatenated storage."""
        data = self.data.__class__()

        if hasattr(self.data, "__num_nodes__"):
            data.num_nodes = self.data.__num_nodes__[idx]

        for key in self.data.keys:
            item, slices = self.data[key], self.slices[key]
            start, end = slices[idx].item(), slices[idx + 1].item()
            if torch.is_tensor(item):
                # Slice along the concatenation dimension reported by the
                # Data object; all other dims are taken whole.
                s = list(repeat(slice(None), item.dim()))
                s[self.data.__cat_dim__(key, item)] = slice(start, end)
            elif start + 1 == end:
                # Scalar per graph (kept as in the PyG original).
                s = slices[start]
            else:
                s = slice(start, end)
            data[key] = item[s]
        return data

    def get(self, idx):
        """Return graph ``idx`` (int) or a shallow-copied sub-dataset for a
        sequence of indices.

        Bug fix: previously a length-1 index sequence matched neither the
        ``isinstance(idx, int)`` branch nor ``len(idx) > 1`` and silently
        returned ``None``; now every non-int index yields a sub-dataset.
        """
        if isinstance(idx, int):
            return self._get(idx)
        data_list = [self._get(i) for i in idx]
        data, slices = self.from_data_list(data_list)
        dataset = copy.copy(self)
        dataset.data = data
        dataset.slices = slices
        return dataset

    @staticmethod
    def from_data_list(data_list):
        r"""Collate a list of ``Data`` objects into ``(data, slices)``.

        Borrowed from PyG.
        """
        keys = data_list[0].keys
        data = data_list[0].__class__()

        for key in keys:
            data[key] = []
        slices = {key: [0] for key in keys}

        for item, key in product(data_list, keys):
            data[key].append(item[key])
            if torch.is_tensor(item[key]):
                s = slices[key][-1] + item[key].size(item.__cat_dim__(key, item[key]))
            else:
                # Non-tensor values occupy one slot per graph.
                s = slices[key][-1] + 1
            slices[key].append(s)

        if hasattr(data_list[0], "__num_nodes__"):
            data.__num_nodes__ = []
            for item in data_list:
                data.__num_nodes__.append(item.num_nodes)

        for key in keys:
            item = data_list[0][key]
            if torch.is_tensor(item):
                data[key] = torch.cat(data[key], dim=data.__cat_dim__(key, item))
            elif isinstance(item, int) or isinstance(item, float):
                data[key] = torch.tensor(data[key])

            slices[key] = torch.tensor(slices[key], dtype=torch.long)

        return data, slices

    def __len__(self):
        # Keep one source of truth for the graph count.
        return self.len()
24 changes: 19 additions & 5 deletions cogdl/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import importlib

from cogdl.data.dataset import Dataset
from .customizezd_data import CustomizedGraphClassificationDataset, CustomizedNodeClassificationDataset, BaseDataset

try:
import torch_geometric
Expand Down Expand Up @@ -51,8 +52,12 @@ def try_import_dataset(dataset):

def build_dataset(args):
if not try_import_dataset(args.dataset):
assert hasattr(args, "task")
dataset = build_dataset_from_path(args.dataset, args.task)
if dataset is not None:
return dataset
exit(1)
return DATASET_REGISTRY[args.dataset](args=args)
return DATASET_REGISTRY[args.dataset]()


def build_dataset_from_name(dataset):
Expand All @@ -61,6 +66,15 @@ def build_dataset_from_name(dataset):
return DATASET_REGISTRY[dataset]()


def build_dataset_from_path(data_path, task):
    """Construct a customized dataset from a processed data file.

    Returns a node- or graph-classification dataset depending on the task
    name, or ``None`` for tasks with no customized-dataset support.
    """
    if "node_classification" in task:
        return CustomizedNodeClassificationDataset(data_path)
    if "graph_classification" in task:
        return CustomizedGraphClassificationDataset(data_path)
    return None


SUPPORTED_DATASETS = {
"kdd_icdm": "cogdl.datasets.gcc_data",
"sigir_cikm": "cogdl.datasets.gcc_data",
Expand Down Expand Up @@ -117,8 +131,8 @@ def build_dataset_from_name(dataset):
"reddit": "cogdl.datasets.saint_data",
"test_bio": "cogdl.datasets.pyg_strategies_data",
"test_chem": "cogdl.datasets.pyg_strategies_data",
"bio": "cogdl.datasets.pyg_strategies_data",
"chem": "cogdl.datasets.pyg_strategies_data",
"bace": "cogdl.datasets.pyg_strategies_data",
"bbbp": "cogdl.datasets.pyg_strategies_data",
"bio": "cogdl.datasets.strategies_data",
"chem": "cogdl.datasets.strategies_data",
"bace": "cogdl.datasets.strategies_data",
"bbbp": "cogdl.datasets.strategies_data",
}
102 changes: 102 additions & 0 deletions cogdl/datasets/customizezd_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import torch

from cogdl.data import Dataset, MultiGraphDataset, Batch
from cogdl.utils import accuracy_evaluator, download_url, multiclass_evaluator, multilabel_evaluator


def _get_evaluator(metric):
if metric == "accuracy":
return accuracy_evaluator()
elif metric == "multilabel_f1":
return multilabel_evaluator()
elif metric == "multiclass_f1":
return multiclass_evaluator()
else:
raise NotImplementedError


class BaseDataset(Dataset):
    """Minimal in-memory dataset.

    Serves a single pre-built ``data`` object and disables all
    download/processing hooks.
    """

    def __init__(self):
        super(BaseDataset, self).__init__("custom")

    def get(self, idx):
        # The whole stored object is returned regardless of ``idx``.
        return self.data

    def process(self):
        pass

    def _download(self):
        pass

    def _process(self):
        pass


class CustomizedNodeClassificationDataset(Dataset):
    """Node-classification dataset loaded from a pre-processed file.

    data_path : path to load dataset. The dataset must be processed to specific format
    metric: Accuracy, multi-label f1 or multi-class f1. Default: `accuracy`
    """

    def __init__(self, data_path, metric="accuracy"):
        super(CustomizedNodeClassificationDataset, self).__init__(root=data_path)
        try:
            self.data = torch.load(data_path)
        except Exception as e:
            print(e)
            exit(1)
        self.metric = metric

    def download(self):
        # NOTE(review): relies on ``self.url`` which is not defined in this
        # class — presumably set by a subclass or base; verify before use.
        for name in self.raw_file_names:
            download_url(f"{self.url}{name}&dl=1", self.raw_dir, name=name)

    def process(self):
        pass

    def get(self, idx):
        # A single-graph dataset: only index 0 is valid.
        assert idx == 0
        return self.data

    def get_evaluator(self):
        return _get_evaluator(self.metric)

    def __repr__(self):
        return f"{self.name}()"

    def _download(self):
        pass

    def _process(self):
        pass


class CustomizedGraphClassificationDataset(MultiGraphDataset):
    """Graph-classification dataset loaded from a pre-processed file.

    The file at ``data_path`` may contain either a list of ``Data`` objects
    (collated here) or an already-collated ``(data, slices)`` pair.

    data_path : path to load dataset. The dataset must be processed to specific format
    metric: evaluator metric name. Default: `accuracy`
    """

    def __init__(self, data_path, metric="accuracy"):
        super(CustomizedGraphClassificationDataset, self).__init__(root=data_path)
        # Bug fix: ``metric`` was never stored, so get_evaluator() raised
        # AttributeError.  The new parameter defaults to "accuracy",
        # matching CustomizedNodeClassificationDataset, and is
        # backward-compatible for existing single-argument callers.
        self.metric = metric
        try:
            data = torch.load(data_path)
            if isinstance(data, list):
                batch = Batch.from_data_list(data)
                self.data = batch
                self.slices = batch.__slices__
                # The synthetic per-node batch vector is an artifact of
                # collation, not part of the stored graphs.
                del self.data.batch
            else:
                # Bug fix: was ``assert len(data) == 0``, which always
                # failed for the (data, slices) pair unpacked below.
                assert len(data) == 2
                self.data = data[0]
                self.slices = data[1]
        except Exception as e:
            print(e)
            exit(1)

    def get_evaluator(self):
        return _get_evaluator(self.metric)

    def __repr__(self):
        return "{}()".format(self.name)

    def _download(self):
        pass

    def _process(self):
        pass
6 changes: 3 additions & 3 deletions cogdl/datasets/gatne.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,23 +86,23 @@ def __repr__(self):

@register_dataset("amazon")
class AmazonDataset(GatneDataset):
def __init__(self, args=None):
def __init__(self):
dataset = "amazon"
path = osp.join("data", dataset)
super(AmazonDataset, self).__init__(path, dataset)


@register_dataset("twitter")
class TwitterDataset(GatneDataset):
def __init__(self, args=None):
def __init__(self):
dataset = "twitter"
path = osp.join("data", dataset)
super(TwitterDataset, self).__init__(path, dataset)


@register_dataset("youtube")
class YouTubeDataset(GatneDataset):
def __init__(self, args=None):
def __init__(self):
dataset = "youtube"
path = osp.join("data", dataset)
super(YouTubeDataset, self).__init__(path, dataset)
8 changes: 4 additions & 4 deletions cogdl/datasets/gcc_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,31 +161,31 @@ def process(self):

@register_dataset("kdd_icdm")
class KDD_ICDM_GCCDataset(GCCDataset):
def __init__(self, args=None):
def __init__(self):
dataset = "kdd_icdm"
path = osp.join("data", dataset)
super(KDD_ICDM_GCCDataset, self).__init__(path, dataset)


@register_dataset("sigir_cikm")
class SIGIR_CIKM_GCCDataset(GCCDataset):
def __init__(self, args=None):
def __init__(self):
dataset = "sigir_cikm"
path = osp.join("data", dataset)
super(SIGIR_CIKM_GCCDataset, self).__init__(path, dataset)


@register_dataset("sigmod_icde")
class SIGMOD_ICDE_GCCDataset(GCCDataset):
def __init__(self, args=None):
def __init__(self):
dataset = "sigmod_icde"
path = osp.join("data", dataset)
super(SIGMOD_ICDE_GCCDataset, self).__init__(path, dataset)


@register_dataset("usa-airport")
class USAAirportDataset(Edgelist):
def __init__(self, args=None):
def __init__(self):
dataset = "usa-airport"
path = osp.join("data", dataset)
super(USAAirportDataset, self).__init__(path, dataset)
Loading

0 comments on commit 282ac1f

Please sign in to comment.