From 282ac1ff6b1dd0d95e50300e41b07dd762cba730 Mon Sep 17 00:00:00 2001
From: THINK2TRY <48679723+THINK2TRY@users.noreply.github.com>
Date: Sun, 24 Jan 2021 18:20:50 +0800
Subject: [PATCH] [Feature] Add custom dataset and remove dependency on PyG (#174)

* Reformat code

* Add custom dataset and remove stpgnn's dependency on PyG

* Fix a bug in experiment
---
 cogdl/data/__init__.py                        |   9 +-
 cogdl/data/data.py                            |   3 +
 cogdl/data/dataloader.py                      |  21 +++-
 cogdl/data/dataset.py                         |  95 ++++++++++++++++
 cogdl/datasets/__init__.py                    |  24 ++++-
 cogdl/datasets/customized_data.py             | 103 +++++++++++++++++
 cogdl/datasets/gatne.py                       |   6 +-
 cogdl/datasets/gcc_data.py                    |   8 +-
 cogdl/datasets/gtn_data.py                    |  11 +-
 cogdl/datasets/han_data.py                    |  11 +-
 cogdl/datasets/kg_data.py                     |  12 +--
 cogdl/datasets/matlab_matrix.py               |   8 +-
 cogdl/datasets/planetoid_data.py              |   6 +-
 cogdl/datasets/pyg_ogb.py                     |  20 ++--
 cogdl/datasets/saint_data.py                  |   8 +-
 ..._strategies_data.py => strategies_data.py} |  83 +++++---------
 cogdl/datasets/test_data.py                   |  37 +++++--
 cogdl/datasets/tu_data.py                     |  24 ++---
 cogdl/experiments.py                          |   7 +-
 cogdl/layers/strategies_layers.py             |   6 +-
 cogdl/models/__init__.py                      |   2 +-
 cogdl/models/emb/dgk.py                       |   1 -
 cogdl/models/nn/__init__.py                   |  31 ++++++
 cogdl/models/nn/{pyg_stpgnn.py => stpgnn.py}  |   0
 cogdl/utils/utils.py                          |   2 +-
 examples/custom_dataset.py                    |  79 +++++---------
 tests/datasets/test_customized_data.py        |  49 ++++++++
 tests/tasks/test_pretrain.py                  |  16 +--
 28 files changed, 483 insertions(+), 199 deletions(-)
 create mode 100644 cogdl/datasets/customized_data.py
 rename cogdl/datasets/{pyg_strategies_data.py => strategies_data.py} (94%)
 rename cogdl/models/nn/{pyg_stpgnn.py => stpgnn.py} (100%)
 create mode 100644 tests/datasets/test_customized_data.py

diff --git a/cogdl/data/__init__.py b/cogdl/data/__init__.py
index 2419f173..5f2bdcae 100644
--- a/cogdl/data/__init__.py
+++ b/cogdl/data/__init__.py
@@ -1,11 +1,6 @@
 from .data import Data
 from .batch import Batch
-from .dataset import Dataset
+from .dataset import Dataset, MultiGraphDataset
 from .dataloader import DataLoader
 
-__all__ = [
-    "Data",
-    "Batch",
-    "Dataset",
-    "DataLoader",
-]
+__all__ = ["Data", "Batch", "Dataset", "DataLoader", "MultiGraphDataset"]
diff --git a/cogdl/data/data.py b/cogdl/data/data.py
index 9ef3e9fb..e489064e 100644
--- a/cogdl/data/data.py
+++ b/cogdl/data/data.py
@@ -113,6 +113,9 @@ def __inc__(self, key, value):
         # creating batches.
         return self.num_nodes if bool(re.search("(index|face)", key)) else 0
 
+    def __cat_dim__(self, key, value):
+        return self.cat_dim(key, value)
+
     @property
     def num_edges(self):
         r"""Returns the number of edges in the graph."""
diff --git a/cogdl/data/dataloader.py b/cogdl/data/dataloader.py
index 8e9a0451..795260ca 100644
--- a/cogdl/data/dataloader.py
+++ b/cogdl/data/dataloader.py
@@ -1,7 +1,7 @@
 import torch.utils.data
 from torch.utils.data.dataloader import default_collate
 
-from cogdl.data import Batch
+from cogdl.data import Batch, Data
 
 
 class DataLoader(torch.utils.data.DataLoader):
@@ -18,5 +18,22 @@ class DataLoader(torch.utils.data.DataLoader):
 
     def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
         super(DataLoader, self).__init__(
-            dataset, batch_size, shuffle, collate_fn=lambda data_list: Batch.from_data_list(data_list), **kwargs
+            # dataset, batch_size, shuffle, collate_fn=lambda data_list: Batch.from_data_list(data_list), **kwargs
+            dataset,
+            batch_size,
+            shuffle,
+            collate_fn=self.collate_fn,
+            **kwargs,
         )
+
+    @staticmethod
+    def collate_fn(batch):
+        item = batch[0]
+        if isinstance(item, Data):
+            return Batch.from_data_list(batch)
+        elif isinstance(item, torch.Tensor):
+            return default_collate(batch)
+        elif isinstance(item, float):
+            return torch.tensor(batch, dtype=torch.float)
+
+        raise TypeError("DataLoader found invalid type: {}".format(type(item)))
diff --git a/cogdl/data/dataset.py b/cogdl/data/dataset.py
index b2f51497..7b329649 100644
--- a/cogdl/data/dataset.py
+++ b/cogdl/data/dataset.py
@@ -1,5 +1,7 @@
 import collections
+import copy
 import os.path as osp
+from itertools import repeat, product
 
 import torch.utils.data
 
@@ -130,5 +132,98 @@ def __getitem__(self, idx):  # pragma: no cover
         data = data if self.transform is None else self.transform(data)
         return data
 
+    @property
+    def num_classes(self):
+        r"""The number of classes in the dataset."""
+        y = self.data.y
+        return y.max().item() + 1 if y.dim() == 1 else y.size(1)
+
     def __repr__(self):  # pragma: no cover
         return "{}({})".format(self.__class__.__name__, len(self))
+
+
+class MultiGraphDataset(Dataset):
+    def __init__(self, root=None, transform=None, pre_transform=None, pre_filter=None):
+        super(MultiGraphDataset, self).__init__(root, transform, pre_transform, pre_filter)
+        self.data, self.slices = None, None
+
+    @property
+    def num_classes(self):
+        r"""The number of classes in the dataset."""
+        y = self.data.y
+        return y.max().item() + 1 if y.dim() == 1 else y.size(1)
+
+    def len(self):
+        for item in self.slices.values():
+            return len(item) - 1
+        return 0
+
+    def _get(self, idx):
+        data = self.data.__class__()
+
+        if hasattr(self.data, "__num_nodes__"):
+            data.num_nodes = self.data.__num_nodes__[idx]
+
+        for key in self.data.keys:
+            item, slices = self.data[key], self.slices[key]
+            start, end = slices[idx].item(), slices[idx + 1].item()
+            if torch.is_tensor(item):
+                s = list(repeat(slice(None), item.dim()))
+                s[self.data.__cat_dim__(key, item)] = slice(start, end)
+            elif start + 1 == end:
+                s = slices[start]
+            else:
+                s = slice(start, end)
+            data[key] = item[s]
+        return data
+
+    def get(self, idx):
+        if isinstance(idx, int):
+            return self._get(idx)
+        else:
+            data_list = [self._get(i) for i in idx]
+            data, slices = self.from_data_list(data_list)
+            dataset = copy.copy(self)
+            dataset.data = data
+            dataset.slices = slices
+            return dataset
+
+    @staticmethod
+    def from_data_list(data_list):
+        r"""Borrowed from PyG."""
+
+        keys = data_list[0].keys
+        data = data_list[0].__class__()
+
+        for key in keys:
+            data[key] = []
+        slices = {key: [0] for key in keys}
+
+        for item, key in product(data_list, keys):
+            data[key].append(item[key])
+            if torch.is_tensor(item[key]):
+                s = slices[key][-1] + item[key].size(item.__cat_dim__(key, item[key]))
+            else:
+                s = slices[key][-1] + 1
+            slices[key].append(s)
+
+        if hasattr(data_list[0], "__num_nodes__"):
+            data.__num_nodes__ = []
+            for item in data_list:
+                data.__num_nodes__.append(item.num_nodes)
+
+        for key in keys:
+            item = data_list[0][key]
+            if torch.is_tensor(item):
+                data[key] = torch.cat(data[key], dim=data.__cat_dim__(key, item))
+            elif isinstance(item, int) or isinstance(item, float):
+                data[key] = torch.tensor(data[key])
+
+            slices[key] = torch.tensor(slices[key], dtype=torch.long)
+
+        return data, slices
+
+    def __len__(self):
+        for item in self.slices.values():
+            return len(item) - 1
+        return 0
diff --git a/cogdl/datasets/__init__.py b/cogdl/datasets/__init__.py
index 5e43a6d9..b060bfb2 100644
--- a/cogdl/datasets/__init__.py
+++ b/cogdl/datasets/__init__.py
@@ -1,6 +1,7 @@
 import importlib
 
 from cogdl.data.dataset import Dataset
+from .customized_data import CustomizedGraphClassificationDataset, CustomizedNodeClassificationDataset, BaseDataset
 
 try:
     import torch_geometric
@@ -51,8 +52,12 @@ def try_import_dataset(dataset):
 
 def build_dataset(args):
     if not try_import_dataset(args.dataset):
+        assert hasattr(args, "task")
+        dataset = build_dataset_from_path(args.dataset, args.task)
+        if dataset is not None:
+            return dataset
         exit(1)
-    return DATASET_REGISTRY[args.dataset](args=args)
+    return DATASET_REGISTRY[args.dataset]()
 
 
 def build_dataset_from_name(dataset):
@@ -61,6 +66,15 @@ def build_dataset_from_name(dataset):
     return DATASET_REGISTRY[dataset]()
 
 
+def build_dataset_from_path(data_path, task):
+    if "node_classification" in task:
+        return CustomizedNodeClassificationDataset(data_path)
+    elif "graph_classification" in task:
+        return CustomizedGraphClassificationDataset(data_path)
+    else:
+        return None
+
+
 SUPPORTED_DATASETS = {
     "kdd_icdm": "cogdl.datasets.gcc_data",
     "sigir_cikm": "cogdl.datasets.gcc_data",
@@ -117,8 +131,8 @@ def build_dataset_from_name(dataset):
     "reddit": "cogdl.datasets.saint_data",
-    "test_bio": "cogdl.datasets.pyg_strategies_data",
-    "test_chem": "cogdl.datasets.pyg_strategies_data",
-    "bio": "cogdl.datasets.pyg_strategies_data",
-    "chem": "cogdl.datasets.pyg_strategies_data",
-    "bace": "cogdl.datasets.pyg_strategies_data",
-    "bbbp": "cogdl.datasets.pyg_strategies_data",
+    "test_bio": "cogdl.datasets.strategies_data",
+    "test_chem": "cogdl.datasets.strategies_data",
+    "bio": "cogdl.datasets.strategies_data",
+    "chem": "cogdl.datasets.strategies_data",
+    "bace": "cogdl.datasets.strategies_data",
+    "bbbp": "cogdl.datasets.strategies_data",
 }
diff --git a/cogdl/datasets/customized_data.py b/cogdl/datasets/customized_data.py
new file mode 100644
index 00000000..89507e83
--- /dev/null
+++ b/cogdl/datasets/customized_data.py
@@ -0,0 +1,103 @@
+import torch
+
+from cogdl.data import Dataset, MultiGraphDataset, Batch
+from cogdl.utils import accuracy_evaluator, download_url, multiclass_evaluator, multilabel_evaluator
+
+
+def _get_evaluator(metric):
+    if metric == "accuracy":
+        return accuracy_evaluator()
+    elif metric == "multilabel_f1":
+        return multilabel_evaluator()
+    elif metric == "multiclass_f1":
+        return multiclass_evaluator()
+    else:
+        raise NotImplementedError
+
+
+class BaseDataset(Dataset):
+    def __init__(self):
+        super(BaseDataset, self).__init__("custom")
+
+    def process(self):
+        pass
+
+    def _download(self):
+        pass
+
+    def _process(self):
+        pass
+
+    def get(self, idx):
+        return self.data
+
+class CustomizedNodeClassificationDataset(Dataset):
+    """
+    data_path : path to the saved dataset file. The dataset must already be processed into the expected format.
+    metric : evaluation metric, one of "accuracy", "multilabel_f1", or "multiclass_f1". Default: "accuracy".
+    """
+
+    def __init__(self, data_path, metric="accuracy"):
+        super(CustomizedNodeClassificationDataset, self).__init__(root=data_path)
+        try:
+            self.data = torch.load(data_path)
+        except Exception as e:
+            print(e)
+            exit(1)
+        self.metric = metric
+
+    def download(self):
+        for name in self.raw_file_names:
+            download_url("{}{}&dl=1".format(self.url, name), self.raw_dir, name=name)
+
+    def process(self):
+        pass
+
+    def get(self, idx):
+        assert idx == 0
+        return self.data
+
+    def get_evaluator(self):
+        return _get_evaluator(self.metric)
+
+    def __repr__(self):
+        return "{}()".format(self.name)
+
+    def _download(self):
+        pass
+
+    def _process(self):
+        pass
+
+
+class CustomizedGraphClassificationDataset(MultiGraphDataset):
+    def __init__(self, data_path, metric="accuracy"):
+        super(CustomizedGraphClassificationDataset, self).__init__(root=data_path)
+        try:
+            data = torch.load(data_path)
+            if isinstance(data, list):
+                batch = Batch.from_data_list(data)
+                self.data = batch
+                self.slices = batch.__slices__
+                del self.data.batch
+            else:
+                assert len(data) == 2
+                self.data = data[0]
+                self.slices = data[1]
+        except Exception as e:
+            print(e)
+            exit(1)
+        self.metric = metric
+
+    def get_evaluator(self):
+        return _get_evaluator(self.metric)
+
+    def __repr__(self):
+        return "{}()".format(self.name)
+
+    def _download(self):
+        pass
+
+    def _process(self):
+        pass
diff --git a/cogdl/datasets/gatne.py b/cogdl/datasets/gatne.py
index 81b52f57..6acf734a 100644
--- a/cogdl/datasets/gatne.py
+++ b/cogdl/datasets/gatne.py
@@ -86,7 +86,7 @@ def __repr__(self):
 
 @register_dataset("amazon")
 class AmazonDataset(GatneDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "amazon"
         path = osp.join("data", dataset)
         super(AmazonDataset, self).__init__(path, dataset)
@@ -94,7 +94,7 @@ def __init__(self, args=None):
 
 @register_dataset("twitter")
 class TwitterDataset(GatneDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "twitter"
         path = osp.join("data", dataset)
         super(TwitterDataset, self).__init__(path, dataset)
@@ -102,7 +102,7 @@ def __init__(self, args=None):
 
 @register_dataset("youtube")
 class YouTubeDataset(GatneDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "youtube"
         path = osp.join("data", dataset)
         super(YouTubeDataset, self).__init__(path, dataset)
diff --git a/cogdl/datasets/gcc_data.py b/cogdl/datasets/gcc_data.py
index 6fa835c5..37a15d49 100644
--- a/cogdl/datasets/gcc_data.py
+++ b/cogdl/datasets/gcc_data.py
@@ -161,7 +161,7 @@ def process(self):
 
 @register_dataset("kdd_icdm")
 class KDD_ICDM_GCCDataset(GCCDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "kdd_icdm"
         path = osp.join("data", dataset)
         super(KDD_ICDM_GCCDataset, self).__init__(path, dataset)
@@ -169,7 +169,7 @@ def __init__(self, args=None):
 
 @register_dataset("sigir_cikm")
 class SIGIR_CIKM_GCCDataset(GCCDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "sigir_cikm"
         path = osp.join("data", dataset)
         super(SIGIR_CIKM_GCCDataset, self).__init__(path, dataset)
@@ -177,7 +177,7 @@ def __init__(self, args=None):
 
 @register_dataset("sigmod_icde")
 class SIGMOD_ICDE_GCCDataset(GCCDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "sigmod_icde"
         path = osp.join("data", dataset)
         super(SIGMOD_ICDE_GCCDataset, self).__init__(path, dataset)
@@ -185,7 +185,7 @@ def __init__(self, args=None):
 
 @register_dataset("usa-airport")
 class USAAirportDataset(Edgelist):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "usa-airport"
         path = osp.join("data", dataset)
         super(USAAirportDataset, self).__init__(path, dataset)
diff --git a/cogdl/datasets/gtn_data.py b/cogdl/datasets/gtn_data.py
index ea2a146f..0fd96c04 100644
--- a/cogdl/datasets/gtn_data.py
+++ b/cogdl/datasets/gtn_data.py
@@ -25,7 +25,6 @@ def __init__(self, root, name):
         self.url = f"https://github.com/cenyk1230/gtn-data/blob/master/{name}.zip?raw=true"
         super(GTNDataset, self).__init__(root)
         self.data = torch.load(self.processed_paths[0])
-        self.num_classes = torch.max(self.data.train_target).item() + 1
         self.num_edge = len(self.data.adj)
         self.num_nodes = self.data.x.shape[0]
 
@@ -38,6 +37,10 @@ def raw_file_names(self):
     def processed_file_names(self):
         return ["data.pt"]
 
+    @property
+    def num_classes(self):
+        return torch.max(self.data.train_target).item() + 1
+
     def read_gtn_data(self, folder):
         edges = pickle.load(open(osp.join(folder, "edges.pkl"), "rb"))
         labels = pickle.load(open(osp.join(folder, "labels.pkl"), "rb"))
@@ -128,7 +131,7 @@ def __repr__(self):
 
 @register_dataset("gtn-acm")
 class ACM_GTNDataset(GTNDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "gtn-acm"
         path = osp.join("data", dataset)
         super(ACM_GTNDataset, self).__init__(path, dataset)
@@ -136,7 +139,7 @@ def __init__(self, args=None):
 
 @register_dataset("gtn-dblp")
 class DBLP_GTNDataset(GTNDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "gtn-dblp"
         path = osp.join("data", dataset)
         super(DBLP_GTNDataset, self).__init__(path, dataset)
@@ -144,7 +147,7 @@ def __init__(self, args=None):
 
 @register_dataset("gtn-imdb")
 class IMDB_GTNDataset(GTNDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "gtn-imdb"
         path = osp.join("data", dataset)
         super(IMDB_GTNDataset, self).__init__(path, dataset)
diff --git a/cogdl/datasets/han_data.py b/cogdl/datasets/han_data.py
index 3dea403b..5aecc0d3 100644
--- a/cogdl/datasets/han_data.py
+++ b/cogdl/datasets/han_data.py
@@ -32,7 +32,6 @@ def __init__(self, root, name):
         self.url = f"https://github.com/cenyk1230/han-data/blob/master/{name}.zip?raw=true"
         super(HANDataset, self).__init__(root)
         self.data = torch.load(self.processed_paths[0])
-        self.num_classes = torch.max(self.data.train_target).item() + 1
         self.num_edge = len(self.data.adj)
         self.num_nodes = self.data.x.shape[0]
 
@@ -45,6 +44,10 @@ def raw_file_names(self):
     def processed_file_names(self):
         return ["data.pt"]
 
+    @property
+    def num_classes(self):
+        return torch.max(self.data.train_target).item() + 1
+
     def read_gtn_data(self, folder):
         data = sio.loadmat(osp.join(folder, "data.mat"))
         if self.name == "han-acm" or self.name == "han-imdb":
@@ -136,7 +139,7 @@ def __repr__(self):
 
 @register_dataset("han-acm")
 class ACM_HANDataset(HANDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "han-acm"
         path = osp.join("data", dataset)
         super(ACM_HANDataset, self).__init__(path, dataset)
@@ -144,7 +147,7 @@ def __init__(self, args=None):
 
 @register_dataset("han-dblp")
 class DBLP_HANDataset(HANDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "han-dblp"
         path = osp.join("data", dataset)
         super(DBLP_HANDataset, self).__init__(path, dataset)
@@ -152,7 +155,7 @@ def __init__(self, args=None):
 
 @register_dataset("han-imdb")
 class IMDB_HANDataset(HANDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "han-imdb"
         path = osp.join("data", dataset)
         super(IMDB_HANDataset, self).__init__(path, dataset)
diff --git a/cogdl/datasets/kg_data.py b/cogdl/datasets/kg_data.py
index c400424c..7dc7845e 100644
--- a/cogdl/datasets/kg_data.py
+++ b/cogdl/datasets/kg_data.py
@@ -308,7 +308,7 @@ def process(self):
 
 @register_dataset("fb13")
 class FB13Datset(KnowledgeGraphDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "FB13"
         path = osp.join("data", dataset)
         super(FB13Datset, self).__init__(path, dataset)
@@ -316,7 +316,7 @@ def __init__(self, args=None):
 
 @register_dataset("fb15k")
 class FB15kDatset(KnowledgeGraphDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "FB15K"
         path = osp.join("data", dataset)
         super(FB15kDatset, self).__init__(path, dataset)
@@ -324,7 +324,7 @@ def __init__(self, args=None):
 
 @register_dataset("fb15k237")
 class FB15k237Datset(KnowledgeGraphDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "FB15K237"
         path = osp.join("data", dataset)
         super(FB15k237Datset, self).__init__(path, dataset)
@@ -332,7 +332,7 @@ def __init__(self, args=None):
 
 @register_dataset("wn18")
 class WN18Datset(KnowledgeGraphDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "WN18"
         path = osp.join("data", dataset)
         super(WN18Datset, self).__init__(path, dataset)
@@ -340,7 +340,7 @@ def __init__(self, args=None):
 
 @register_dataset("wn18rr")
 class WN18RRDataset(KnowledgeGraphDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "WN18RR"
         path = osp.join("data", dataset)
         super(WN18RRDataset, self).__init__(path, dataset)
@@ -350,7 +350,7 @@ def __init__(self, args=None):
 class FB13SDatset(KnowledgeGraphDataset):
     url = "https://raw.githubusercontent.com/cenyk1230/test-data/main"
 
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "FB13-S"
         path = osp.join("data", dataset)
         super(FB13SDatset, self).__init__(path, dataset)
diff --git a/cogdl/datasets/matlab_matrix.py b/cogdl/datasets/matlab_matrix.py
index 13350a45..4bcae52e 100644
--- a/cogdl/datasets/matlab_matrix.py
+++ b/cogdl/datasets/matlab_matrix.py
@@ -68,7 +68,7 @@ def process(self):
 
 @register_dataset("blogcatalog")
 class BlogcatalogDataset(MatlabMatrix):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset, filename = "blogcatalog", "blogcatalog"
         url = "http://leitang.net/code/social-dimension/data/"
         path = osp.join("data", dataset)
@@ -77,7 +77,7 @@ def __init__(self, args=None):
 
 @register_dataset("flickr-ne")
 class FlickrDataset(MatlabMatrix):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset, filename = "flickr", "flickr"
         url = "http://leitang.net/code/social-dimension/data/"
         path = osp.join("data", dataset)
@@ -86,7 +86,7 @@ def __init__(self, args=None):
 
 @register_dataset("wikipedia")
 class WikipediaDataset(MatlabMatrix):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset, filename = "wikipedia", "POS"
         url = "http://snap.stanford.edu/node2vec/"
         path = osp.join("data", dataset)
@@ -95,7 +95,7 @@ def __init__(self, args=None):
 
 @register_dataset("ppi")
 class PPIDataset(MatlabMatrix):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset, filename = "ppi", "Homo_sapiens"
         url = "http://snap.stanford.edu/node2vec/"
         path = osp.join("data", dataset)
diff --git a/cogdl/datasets/planetoid_data.py b/cogdl/datasets/planetoid_data.py
index 7f91b0bf..3e0a260a 100644
--- a/cogdl/datasets/planetoid_data.py
+++ b/cogdl/datasets/planetoid_data.py
@@ -191,7 +191,7 @@ def normalize_feature(data):
@register_dataset("cora") class CoraDataset(Planetoid): - def __init__(self, args=None): + def __init__(self): dataset = "Cora" path = osp.join("data", dataset) if not osp.exists(path): @@ -202,7 +202,7 @@ def __init__(self, args=None): @register_dataset("citeseer") class CiteSeerDataset(Planetoid): - def __init__(self, args=None): + def __init__(self): dataset = "CiteSeer" path = osp.join("data", dataset) if not osp.exists(path): @@ -213,7 +213,7 @@ def __init__(self, args=None): @register_dataset("pubmed") class PubMedDataset(Planetoid): - def __init__(self, args=None): + def __init__(self): dataset = "PubMed" path = osp.join("data", dataset) if not osp.exists(path): diff --git a/cogdl/datasets/pyg_ogb.py b/cogdl/datasets/pyg_ogb.py index 8aa9cd08..771d4d3c 100644 --- a/cogdl/datasets/pyg_ogb.py +++ b/cogdl/datasets/pyg_ogb.py @@ -34,7 +34,7 @@ def get(self, idx): @register_dataset("ogbn-arxiv") class OGBArxivDataset(OGBNDataset): - def __init__(self, args=None): + def __init__(self): dataset = "ogbn-arxiv" path = osp.join("data", dataset) if not osp.exists(path): @@ -49,7 +49,7 @@ def __init__(self, args=None): @register_dataset("ogbn-products") class OGBProductsDataset(OGBNDataset): - def __init__(self, args=None): + def __init__(self): dataset = "ogbn-products" path = osp.join("data", dataset) if not osp.exists(path): @@ -59,7 +59,7 @@ def __init__(self, args=None): @register_dataset("ogbn-proteins") class OGBProteinsDataset(OGBNDataset): - def __init__(self, args=None): + def __init__(self): dataset = "ogbn-proteins" path = osp.join("data", dataset) if not osp.exists(path): @@ -69,7 +69,7 @@ def __init__(self, args=None): @register_dataset("ogbn-mag") class OGBMAGDataset(OGBNDataset): - def __init__(self, args=None): + def __init__(self): dataset = "ogbn-mag" path = osp.join("data", dataset) if not osp.exists(path): @@ -79,7 +79,7 @@ def __init__(self, args=None): @register_dataset("ogbn-papers100M") class OGBPapers100MDataset(OGBNDataset): - def __init__(self, args=None): + def __init__(self): dataset = "ogbn-papers100M" path = osp.join("data", dataset) if not osp.exists(path): @@ -109,7 +109,7 @@ def get(self, idx): @register_dataset("ogbg-molbace") class OGBMolbaceDataset(OGBGDataset): - def __init__(self, args=None): + def __init__(self): dataset = "ogbg-molbace" path = osp.join("data", dataset) if not osp.exists(path): @@ -119,7 +119,7 @@ def __init__(self, args=None): @register_dataset("ogbg-molhiv") class OGBMolhivDataset(OGBGDataset): - def __init__(self, args=None): + def __init__(self): dataset = "ogbg-molhiv" path = osp.join("data", dataset) if not osp.exists(path): @@ -129,7 +129,7 @@ def __init__(self, args=None): @register_dataset("ogbg-molpcba") class OGBMolpcbaDataset(OGBGDataset): - def __init__(self, args=None): + def __init__(self): dataset = "ogbg-molpcba" path = osp.join("data", dataset) if not osp.exists(path): @@ -139,7 +139,7 @@ def __init__(self, args=None): @register_dataset("ogbg-ppa") class OGBPpaDataset(OGBGDataset): - def __init__(self, args=None): + def __init__(self): dataset = "ogbg-ppa" path = osp.join("data", dataset) if not osp.exists(path): @@ -149,7 +149,7 @@ def __init__(self, args=None): @register_dataset("ogbg-code") class OGBCodeDataset(OGBGDataset): - def __init__(self, args=None): + def __init__(self): dataset = "ogbg-code" path = osp.join("data", dataset) if not osp.exists(path): diff --git a/cogdl/datasets/saint_data.py b/cogdl/datasets/saint_data.py index 213332f6..1b92769a 100644 --- a/cogdl/datasets/saint_data.py +++ 
b/cogdl/datasets/saint_data.py @@ -117,7 +117,7 @@ def scale_feats(data): @register_dataset("yelp") class YelpDataset(SAINTDataset): - def __init__(self, args=None): + def __init__(self): dataset = "Yelp" url = "https://cloud.tsinghua.edu.cn/d/7218cc013c9a40159306/files/?p=%2F{}&dl=1" path = osp.join("data", dataset) @@ -129,7 +129,7 @@ def __init__(self, args=None): @register_dataset("amazon-s") class AmazonDataset(SAINTDataset): - def __init__(self, args=None): + def __init__(self): dataset = "AmazonSaint" url = "https://cloud.tsinghua.edu.cn/d/ae4b2c4f59bd41be9b0b/files/?p=%2F{}&dl=1" path = osp.join("data", dataset) @@ -141,7 +141,7 @@ def __init__(self, args=None): @register_dataset("flickr") class FlickrDatset(SAINTDataset): - def __init__(self, args=None): + def __init__(self): dataset = "Flickr" url = "https://cloud.tsinghua.edu.cn/d/d3ebcb5fa2da463b8213/files/?p=%2F{}&dl=1" path = osp.join("data", dataset) @@ -156,7 +156,7 @@ def get_evaluator(self): @register_dataset("reddit") class RedditDataset(SAINTDataset): - def __init__(self, args=None): + def __init__(self): dataset = "Reddit" url = "https://cloud.tsinghua.edu.cn/d/d087e7e766e747ce8073/files/?p=%2F{}&dl=1" path = osp.join("data", dataset) diff --git a/cogdl/datasets/pyg_strategies_data.py b/cogdl/datasets/strategies_data.py similarity index 94% rename from cogdl/datasets/pyg_strategies_data.py rename to cogdl/datasets/strategies_data.py index e7de8504..07757414 100644 --- a/cogdl/datasets/pyg_strategies_data.py +++ b/cogdl/datasets/strategies_data.py @@ -8,11 +8,12 @@ import numpy as np import torch -from torch_geometric.data import InMemoryDataset, Data from cogdl.utils import download_url import os.path as osp from itertools import repeat +from cogdl.data import Data, MultiGraphDataset + # ================ # Dataset utils # ================ @@ -557,7 +558,7 @@ def __repr__(self): class BatchFinetune(Data): def __init__(self, batch=None, **kwargs): - super(BatchMasking, self).__init__(**kwargs) + super(BatchFinetune, self).__init__(**kwargs) self.batch = batch @staticmethod @@ -841,10 +842,8 @@ def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs): @register_dataset("test_bio") -class TestBioDataset(InMemoryDataset): - def __init__( - self, data_type="unsupervised", root=None, transform=None, pre_transform=None, pre_filter=None, args=None - ): +class TestBioDataset(MultiGraphDataset): + def __init__(self, data_type="unsupervised", root="testbio", transform=None, pre_transform=None, pre_filter=None): super(TestBioDataset, self).__init__(root, transform, pre_transform, pre_filter) num_nodes = 20 num_edges = 20 @@ -891,12 +890,16 @@ def cycle_index(num, shift): self.slices["go_target_pretrain"] = torch.arange(0, (num_graphs + 1) * pretrain_tasks) self.slices["go_target_downstream"] = torch.arange(0, (num_graphs + 1) * downstream_tasks) + def _download(self): + pass + + def _process(self): + pass + @register_dataset("test_chem") -class TestChemDataset(InMemoryDataset): - def __init__( - self, data_type="unsupervised", root=None, transform=None, pre_transform=None, pre_filter=None, args=None - ): +class TestChemDataset(MultiGraphDataset): + def __init__(self, data_type="unsupervised", root="testchem", transform=None, pre_transform=None, pre_filter=None): super(TestChemDataset, self).__init__(root, transform, pre_transform, pre_filter) num_nodes = 10 num_edges = 10 @@ -943,21 +946,16 @@ def cycle_index(num, shift): self.data.y = go_target_pretrain self.slices["y"] = torch.arange(0, (num_graphs + 1) * pretrain_tasks, 
pretrain_tasks) - def get(self, idx): - data = Data() - for key in self.data.keys: - item, slices = self.data[key], self.slices[key] - s = list(repeat(slice(None), item.dim())) - s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1]) - data[key] = item[s] - return data + def _download(self): + pass + + def _process(self): + pass @register_dataset("bio") -class BioDataset(InMemoryDataset): - def __init__( - self, data_type="unsupervised", empty=False, transform=None, pre_transform=None, pre_filter=None, args=None - ): +class BioDataset(MultiGraphDataset): + def __init__(self, data_type="unsupervised", empty=False, transform=None, pre_transform=None, pre_filter=None): self.data_type = data_type self.url = "https://cloud.tsinghua.edu.cn/f/c865b1d61348489e86ac/?dl=1" self.root = osp.join("data", "BIO") @@ -987,10 +985,8 @@ def process(self): @register_dataset("chem") -class MoleculeDataset(InMemoryDataset): - def __init__( - self, data_type="unsupervised", transform=None, pre_transform=None, pre_filter=None, empty=False, args=None - ): +class MoleculeDataset(MultiGraphDataset): + def __init__(self, data_type="unsupervised", transform=None, pre_transform=None, pre_filter=None, empty=False): self.data_type = data_type self.url = "https://cloud.tsinghua.edu.cn/f/2cac04ee904e4b54b4b2/?dl=1" self.root = osp.join("data", "CHEM") @@ -1004,15 +1000,6 @@ def __init__( else: self.data, self.slices = torch.load(self.processed_paths[0]) - def get(self, idx): - data = Data() - for key in self.data.keys: - item, slices = self.data[key], self.slices[key] - s = list(repeat(slice(None), item.dim())) - s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1]) - data[key] = item[s] - return data - @property def raw_file_names(self): return ["processed.zip"] @@ -1037,8 +1024,8 @@ def process(self): @register_dataset("bace") -class BACEDataset(InMemoryDataset): - def __init__(self, transform=None, pre_transform=None, pre_filter=None, empty=False, args=None): +class BACEDataset(MultiGraphDataset): + def __init__(self, transform=None, pre_transform=None, pre_filter=None, empty=False): self.url = "https://cloud.tsinghua.edu.cn/f/253270b278f4465380f1/?dl=1" self.root = osp.join("data", "BACE") @@ -1048,15 +1035,6 @@ def __init__(self, transform=None, pre_transform=None, pre_filter=None, empty=Fa if not empty: self.data, self.slices = torch.load(self.processed_paths[0]) - def get(self, idx): - data = Data() - for key in self.data.keys: - item, slices = self.data[key], self.slices[key] - s = list(repeat(slice(None), item.dim())) - s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1]) - data[key] = item[s] - return data - @property def raw_file_names(self): return ["processed.zip"] @@ -1076,8 +1054,8 @@ def process(self): @register_dataset("bbbp") -class BBBPDataset(InMemoryDataset): - def __init__(self, transform=None, pre_transform=None, pre_filter=None, empty=False, args=None): +class BBBPDataset(MultiGraphDataset): + def __init__(self, transform=None, pre_transform=None, pre_filter=None, empty=False): self.url = "https://cloud.tsinghua.edu.cn/f/ab8ff4d0a68c40a38956/?dl=1" self.root = osp.join("data", "BBBP") @@ -1087,15 +1065,6 @@ def __init__(self, transform=None, pre_transform=None, pre_filter=None, empty=Fa if not empty: self.data, self.slices = torch.load(self.processed_paths[0]) - def get(self, idx): - data = Data() - for key in self.data.keys: - item, slices = self.data[key], self.slices[key] - s = list(repeat(slice(None), item.dim())) - s[data.__cat_dim__(key, item)] = 
slice(slices[idx], slices[idx + 1]) - data[key] = item[s] - return data - @property def raw_file_names(self): return ["processed.zip"] diff --git a/cogdl/datasets/test_data.py b/cogdl/datasets/test_data.py index a3e6db21..da9fc801 100644 --- a/cogdl/datasets/test_data.py +++ b/cogdl/datasets/test_data.py @@ -3,21 +3,42 @@ from cogdl.datasets import register_dataset from cogdl.data import Dataset, Data + @register_dataset("test_small") class TestSmallDataset(Dataset): r"""small dataset for debug""" - def __init__(self, args=None): - x = torch.FloatTensor([[-2, -1], [-2, 1], [-1, 0], [0, 0], [0, 1], [1, 0], [2, 1], [3, 0], [2, -1], [4, 0], [4, 1], [5, 0]]) - edge_index = torch.LongTensor([[0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5, 5, 5, 5, 6, 6, 7, 7, 7, 7, 8, 8, 9, 9 , 9 , 10, 10, 11, 11], - [1, 2, 0, 2, 0, 1, 3, 2, 4, 5, 3, 3, 6, 7, 8, 5, 7, 5, 6, 8, 9, 5, 7, 7, 10, 11, 9 , 11, 9 , 10]]) + + def __init__(self): + super(TestSmallDataset, self).__init__("test") + x = torch.FloatTensor( + [[-2, -1], [-2, 1], [-1, 0], [0, 0], [0, 1], [1, 0], [2, 1], [3, 0], [2, -1], [4, 0], [4, 1], [5, 0]] + ) + edge_index = torch.LongTensor( + [ + [0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5, 5, 5, 5, 6, 6, 7, 7, 7, 7, 8, 8, 9, 9, 9, 10, 10, 11, 11], + [1, 2, 0, 2, 0, 1, 3, 2, 4, 5, 3, 3, 6, 7, 8, 5, 7, 5, 6, 8, 9, 5, 7, 7, 10, 11, 9, 11, 9, 10], + ] + ) y = torch.LongTensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3]) self.data = Data(x, edge_index, None, y, None) - self.data.train_mask = torch.tensor([True, False, False, True, False, True, False, False, False, True, False, False]) - self.data.val_mask = torch.tensor([False, True, False, False, False, False, True, False, False, False, False, True]) - self.data.test_mask = torch.tensor([False, False, True, False, True, False, False, True, True, False, True, False]) + self.data.train_mask = torch.tensor( + [True, False, False, True, False, True, False, False, False, True, False, False] + ) + self.data.val_mask = torch.tensor( + [False, True, False, False, False, False, True, False, False, False, False, True] + ) + self.data.test_mask = torch.tensor( + [False, False, True, False, True, False, False, True, True, False, True, False] + ) self.num_classes = 4 self.transform = None def get(self, idx): assert idx == 0 - return self.data \ No newline at end of file + return self.data + + def _download(self): + pass + + def _process(self): + pass diff --git a/cogdl/datasets/tu_data.py b/cogdl/datasets/tu_data.py index 1ebd2c61..4e7d096c 100644 --- a/cogdl/datasets/tu_data.py +++ b/cogdl/datasets/tu_data.py @@ -275,7 +275,7 @@ def get(self, idx): @register_dataset("mutag") class MUTAGDataset(TUDataset): - def __init__(self, args=None): + def __init__(self): dataset = "MUTAG" path = osp.join("data", dataset) if not osp.exists(path): @@ -285,7 +285,7 @@ def __init__(self, args=None): @register_dataset("imdb-b") class ImdbBinaryDataset(TUDataset): - def __init__(self, args=None): + def __init__(self): dataset = "IMDB-BINARY" path = osp.join("data", dataset) if not osp.exists(path): @@ -295,7 +295,7 @@ def __init__(self, args=None): @register_dataset("imdb-m") class ImdbMultiDataset(TUDataset): - def __init__(self, args=None): + def __init__(self): dataset = "IMDB-MULTI" path = osp.join("data", dataset) if not osp.exists(path): @@ -305,7 +305,7 @@ def __init__(self, args=None): @register_dataset("collab") class CollabDataset(TUDataset): - def __init__(self, args=None): + def __init__(self): dataset = "COLLAB" path = osp.join("data", dataset) if not osp.exists(path): @@ -315,7 +315,7 @@ def 
__init__(self, args=None): @register_dataset("proteins") class ProtainsDataset(TUDataset): - def __init__(self, args=None): + def __init__(self): dataset = "PROTEINS" path = osp.join("data", dataset) if not osp.exists(path): @@ -325,7 +325,7 @@ def __init__(self, args=None): @register_dataset("reddit-b") class RedditBinary(TUDataset): - def __init__(self, args=None): + def __init__(self): dataset = "REDDIT-BINARY" path = osp.join("data", dataset) if not osp.exists(path): @@ -335,7 +335,7 @@ def __init__(self, args=None): @register_dataset("reddit-multi-5k") class RedditMulti5K(TUDataset): - def __init__(self, args=None): + def __init__(self): dataset = "REDDIT-MULTI-5K" path = osp.join("data", dataset) if not osp.exists(path): @@ -345,7 +345,7 @@ def __init__(self, args=None): @register_dataset("reddit-multi-12k") class RedditMulti12K(TUDataset): - def __init__(self, args=None): + def __init__(self): dataset = "REDDIT-MULTI-12K" path = osp.join("data", dataset) if not osp.exists(path): @@ -355,7 +355,7 @@ def __init__(self, args=None): @register_dataset("ptc-mr") class PTCMRDataset(TUDataset): - def __init__(self, args=None): + def __init__(self): dataset = "PTC_MR" path = osp.join("data", dataset) if not osp.exists(path): @@ -365,7 +365,7 @@ def __init__(self, args=None): @register_dataset("nci1") class NCT1Dataset(TUDataset): - def __init__(self, args=None): + def __init__(self): dataset = "NCI1" path = osp.join("data", dataset) if not osp.exists(path): @@ -375,7 +375,7 @@ def __init__(self, args=None): @register_dataset("nci109") class NCT109Dataset(TUDataset): - def __init__(self, args=None): + def __init__(self): dataset = "NCI109" path = osp.join("data", dataset) if not osp.exists(path): @@ -385,7 +385,7 @@ def __init__(self, args=None): @register_dataset("enzymes") class ENZYMES(TUDataset): - def __init__(self, args=None): + def __init__(self): dataset = "ENZYMES" path = osp.join("data", dataset) if not osp.exists(path): diff --git a/cogdl/experiments.py b/cogdl/experiments.py index ba80dadc..d370fbda 100644 --- a/cogdl/experiments.py +++ b/cogdl/experiments.py @@ -12,6 +12,7 @@ from cogdl.tasks import build_task from cogdl.utils import set_random_seed, tabulate_results from cogdl.configs import BEST_CONFIGS +from cogdl.datasets import SUPPORTED_DATASETS from cogdl.models import SUPPORTED_MODELS @@ -116,7 +117,11 @@ def check_task_dataset_model_match(task, variants): clean_variants = [] for item in variants: - if item.model in SUPPORTED_MODELS and (item.model, item.dataset) not in pairs: + if ( + (item.dataset in SUPPORTED_DATASETS) + and (item.model in SUPPORTED_MODELS) + and (item.model, item.dataset) not in pairs + ): print(f"({item.model}, {item.dataset}) is not implemented in task '{task}''.") continue clean_variants.append(item) diff --git a/cogdl/layers/strategies_layers.py b/cogdl/layers/strategies_layers.py index 07705bbc..7d5a0179 100644 --- a/cogdl/layers/strategies_layers.py +++ b/cogdl/layers/strategies_layers.py @@ -7,7 +7,7 @@ import torch.nn as nn import torch.nn.functional as F from cogdl.datasets import build_dataset_from_name -from cogdl.datasets.pyg_strategies_data import ( +from cogdl.datasets.strategies_data import ( BioDataset, ChemExtractSubstructureContextPair, DataLoaderFinetune, @@ -22,7 +22,9 @@ ) from cogdl.utils import add_self_loops, batch_mean_pooling, batch_sum_pooling, cycle_index from sklearn.metrics import roc_auc_score -from torch_geometric.data import DataLoader + +# from torch_geometric.data import DataLoader +from cogdl.data import DataLoader 
 class GINConv(nn.Module):
diff --git a/cogdl/models/__init__.py b/cogdl/models/__init__.py
index 1531ae2a..f088fcb9 100644
--- a/cogdl/models/__init__.py
+++ b/cogdl/models/__init__.py
@@ -111,7 +111,7 @@ def build_model(args):
     "fastgcn": "cogdl.models.nn.fastgcn",
     "mlp": "cogdl.models.nn.mlp",
     "sgc": "cogdl.models.nn.sgc",
-    "stpgnn": "cogdl.models.nn.pyg_stpgnn",
+    "stpgnn": "cogdl.models.nn.stpgnn",
     "sortpool": "cogdl.models.nn.sortpool",
     "srgcn": "cogdl.models.nn.pyg_srgcn",
     "asgcn": "cogdl.models.nn.asgcn",
diff --git a/cogdl/models/emb/dgk.py b/cogdl/models/emb/dgk.py
index dd496ef5..f797de53 100644
--- a/cogdl/models/emb/dgk.py
+++ b/cogdl/models/emb/dgk.py
@@ -2,7 +2,6 @@
 from joblib import Parallel, delayed
 import networkx as nx
 import numpy as np
-from gensim.models.doc2vec import Doc2Vec, TaggedDocument
 from gensim.models.word2vec import Word2Vec
 
 from .. import BaseModel, register_model
diff --git a/cogdl/models/nn/__init__.py b/cogdl/models/nn/__init__.py
index e69de29b..9e2c8b2c 100644
--- a/cogdl/models/nn/__init__.py
+++ b/cogdl/models/nn/__init__.py
@@ -0,0 +1,31 @@
+from .compgcn import CompGCN, CompGCNLayer
+from .dgi import DGIModel
+from .disengcn import DisenGCN, DisenGCNLayer
+from .gat import GATLayer
+from .gcn import GraphConvolution, TKipfGCN
+from .gcnii import GCNIILayer, GCNII
+from .gdc_gcn import GDC_GCN
+from .grace import GRACE, GraceEncoder
+from .graphsage import Graphsage, GraphSAGELayer
+from .mvgrl import MVGRL
+from .patchy_san import PatchySAN
+from .ppnp import PPNP
+from .rgcn import RGCNLayer, LinkPredictRGCN, RGCN
+from .sgc import SimpleGraphConvolution, sgc
+
+__all__ = [
+    "CompGCN",
+    "DGIModel",
+    "DisenGCN",
+    "GATLayer",
+    "TKipfGCN",
+    "GCNII",
+    "GDC_GCN",
+    "GRACE",
+    "Graphsage",
+    "MVGRL",
+    "PatchySAN",
+    "PPNP",
+    "RGCN",
+    "sgc",
+]
diff --git a/cogdl/models/nn/pyg_stpgnn.py b/cogdl/models/nn/stpgnn.py
similarity index 100%
rename from cogdl/models/nn/pyg_stpgnn.py
rename to cogdl/models/nn/stpgnn.py
diff --git a/cogdl/utils/utils.py b/cogdl/utils/utils.py
index ddecdc50..bb29e8a8 100644
--- a/cogdl/utils/utils.py
+++ b/cogdl/utils/utils.py
@@ -335,10 +335,10 @@ def coalesce(row, col, value=None):
 
     if mask.all():
        return row, col, value
     row = row[mask]
-    col = col[mask]
     if value is not None:
         _value = torch.zeros(row.shape[0], dtype=torch.float).to(row.device)
         value = _value.scatter_add_(dim=0, src=value, index=col)
+    col = col[mask]
 
     return row, col, value
diff --git a/examples/custom_dataset.py b/examples/custom_dataset.py
index b78c7f79..a8ce74b6 100644
--- a/examples/custom_dataset.py
+++ b/examples/custom_dataset.py
@@ -1,65 +1,40 @@
-from cogdl.data.data import Data
 import torch
-from cogdl.tasks import build_task
-from cogdl.models import build_model
-from cogdl.options import get_task_model_args
+from cogdl import experiment
+from cogdl.data import Data
+from cogdl.datasets import BaseDataset, register_dataset
 
-"""Define your data"""
-
-class MyData(Data):
+@register_dataset("mydataset")
+class MyNodeClassificationDataset(BaseDataset):
     def __init__(self):
-        super(MyData, self).__init__()
+        super(MyNodeClassificationDataset, self).__init__()
+        self.data = self.process()
+
+    def process(self):
         num_nodes = 100
         num_edges = 300
         feat_dim = 30
 
-        # load or generate data
-        self.edge_index = torch.randint(0, num_nodes, (2, num_edges))
-        self.x = torch.randn(num_nodes, feat_dim)
-        self.y = torch.randint(0, 2, (num_nodes,))
-
-        # set train/val/test mask in node_classification task
-        self.train_mask = torch.zeros(num_nodes).bool()
-        self.train_mask[0 : int(0.3 * num_nodes)] = True
-        self.val_mask = torch.zeros(num_nodes).bool()
-        self.val_mask[int(0.3 * num_nodes) : int(0.7 * num_nodes)] = True
-        self.test_mask = torch.zeros(num_nodes).bool()
-        self.test_mask[int(0.7 * num_nodes) :] = True
-
-
-"""Define your dataset"""
-
-class MyNodeClassificationDataset(object):
-    def __init__(self):
-        self.data = MyData()
-        self.num_classes = self.data.num_classes
-        self.num_features = self.data.num_features
-
-    def __getitem__(self, index):
-        assert index == 0
-        return self.data
+        # load or generate your dataset
+        edge_index = torch.randint(0, num_nodes, (2, num_edges))
+        x = torch.randn(num_nodes, feat_dim)
+        y = torch.randint(0, 2, (num_nodes,))
 
-def set_args(args):
-    """Change default settings"""
-    cuda_available = torch.cuda.is_available()
-    args.cpu = not cuda_available
-    return args
-
-
-def main_dataset():
-    args = get_task_model_args(task="node_classification", model="gcn")
-    # use customized dataset
-    dataset = MyNodeClassificationDataset()
-    args.num_features = dataset.num_features
-    args.num_classes = dataset.num_classes
-    # use model in cogdl
-    model = build_model(args)
-    task = build_task(args, dataset, model)
-    result = task.train()
-    print(result)
+        # set train/val/test mask in node_classification task
+        train_mask = torch.zeros(num_nodes).bool()
+        train_mask[0 : int(0.3 * num_nodes)] = True
+        val_mask = torch.zeros(num_nodes).bool()
+        val_mask[int(0.3 * num_nodes) : int(0.7 * num_nodes)] = True
+        test_mask = torch.zeros(num_nodes).bool()
+        test_mask[int(0.7 * num_nodes) :] = True
+        data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)
+        torch.save(data, "mydata.pt")
+        return data
 
 
 if __name__ == "__main__":
-    main_dataset()
+    # Run with the registered dataset
+    experiment(task="node_classification", dataset="mydataset", model="gcn")
+    # Run with the given data path
+    experiment(task="node_classification", dataset="./mydata.pt", model="gcn")
diff --git a/tests/datasets/test_customized_data.py b/tests/datasets/test_customized_data.py
new file mode 100644
index 00000000..0ac2e91b
--- /dev/null
+++ b/tests/datasets/test_customized_data.py
@@ -0,0 +1,49 @@
+import torch
+from cogdl.data import Data
+from cogdl.datasets import BaseDataset, register_dataset, build_dataset, build_dataset_from_name
+from cogdl.utils import build_args_from_dict
+
+
+@register_dataset("mydataset")
+class MyNodeClassificationDataset(BaseDataset):
+    def __init__(self):
+        super(MyNodeClassificationDataset, self).__init__()
+        self.data = self.process()
+
+    def process(self):
+        num_nodes = 100
+        num_edges = 300
+        feat_dim = 30
+
+        # load or generate your dataset
+        edge_index = torch.randint(0, num_nodes, (2, num_edges))
+        x = torch.randn(num_nodes, feat_dim)
+        y = torch.randint(0, 2, (num_nodes,))
+
+        # set train/val/test mask in node_classification task
+        train_mask = torch.zeros(num_nodes).bool()
+        train_mask[0 : int(0.3 * num_nodes)] = True
+        val_mask = torch.zeros(num_nodes).bool()
+        val_mask[int(0.3 * num_nodes) : int(0.7 * num_nodes)] = True
+        test_mask = torch.zeros(num_nodes).bool()
+        test_mask[int(0.7 * num_nodes) :] = True
+        data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)
+        torch.save(data, "mydata.pt")
+        return data
+
+
+def test_customized_dataset():
+    dataset = build_dataset_from_name("mydataset")
+    assert isinstance(dataset[0], Data)
+    assert dataset[0].x.shape[0] == 100
+
+
+def test_build_dataset_from_path():
+    args = build_args_from_dict({"dataset": "mydata.pt", "task": "node_classification"})
+    dataset = build_dataset(args)
+    assert dataset[0].x.shape[0] == 100
+
+
+if __name__ == "__main__":
+    test_customized_dataset()
+    test_build_dataset_from_path()
diff --git a/tests/tasks/test_pretrain.py b/tests/tasks/test_pretrain.py
index 42c5b903..ca85dd6a 100644
--- a/tests/tasks/test_pretrain.py
+++ b/tests/tasks/test_pretrain.py
@@ -153,14 +153,14 @@ def test_bace():
 
 
 if __name__ == "__main__":
-    test_stpgnn_infomax()
-    test_stpgnn_contextpred()
-    test_stpgnn_mask()
-    test_stpgnn_supervised()
-    test_stpgnn_finetune()
-    test_chem_contextpred()
-    test_chem_infomax()
-    test_chem_mask()
+    # test_stpgnn_infomax()
+    # test_stpgnn_contextpred()
+    # test_stpgnn_mask()
+    # test_stpgnn_supervised()
+    # test_stpgnn_finetune()
+    # test_chem_contextpred()
+    # test_chem_infomax()
+    # test_chem_mask()
     test_chem_supervised()
     test_bace()
     test_bbbp()