From 282ac1ff6b1dd0d95e50300e41b07dd762cba730 Mon Sep 17 00:00:00 2001
From: THINK2TRY <48679723+THINK2TRY@users.noreply.github.com>
Date: Sun, 24 Jan 2021 18:20:50 +0800
Subject: [PATCH] [Feature] Add custom dataset and remove dependency on PyG
 (#174)

* Reformat code

* Add custom dataset and remove stpgnn's dependency on PyG

* Fix a bug in experiment
---
 cogdl/data/__init__.py                        |   9 +-
 cogdl/data/data.py                            |   3 +
 cogdl/data/dataloader.py                      |  21 +++-
 cogdl/data/dataset.py                         |  95 ++++++++++++++++
 cogdl/datasets/__init__.py                    |  24 ++++-
 cogdl/datasets/customizezd_data.py            | 102 ++++++++++++++++++
 cogdl/datasets/gatne.py                       |   6 +-
 cogdl/datasets/gcc_data.py                    |   8 +-
 cogdl/datasets/gtn_data.py                    |  11 +-
 cogdl/datasets/han_data.py                    |  11 +-
 cogdl/datasets/kg_data.py                     |  12 +--
 cogdl/datasets/matlab_matrix.py               |   8 +-
 cogdl/datasets/planetoid_data.py              |   6 +-
 cogdl/datasets/pyg_ogb.py                     |  20 ++--
 cogdl/datasets/saint_data.py                  |   8 +-
 ..._strategies_data.py => strategies_data.py} |  83 +++++---------
 cogdl/datasets/test_data.py                   |  37 +++++--
 cogdl/datasets/tu_data.py                     |  24 ++---
 cogdl/experiments.py                          |   7 +-
 cogdl/layers/strategies_layers.py             |   6 +-
 cogdl/models/__init__.py                      |   2 +-
 cogdl/models/emb/dgk.py                       |   1 -
 cogdl/models/nn/__init__.py                   |  31 ++++++
 cogdl/models/nn/{pyg_stpgnn.py => stpgnn.py}  |   0
 cogdl/utils/utils.py                          |   2 +-
 examples/custom_dataset.py                    |  79 +++++---------
 tests/datasets/test_customized_data.py        |  48 +++++++++
 tests/tasks/test_pretrain.py                  |  16 +--
 28 files changed, 481 insertions(+), 199 deletions(-)
 create mode 100644 cogdl/datasets/customizezd_data.py
 rename cogdl/datasets/{pyg_strategies_data.py => strategies_data.py} (94%)
 rename cogdl/models/nn/{pyg_stpgnn.py => stpgnn.py} (100%)
 create mode 100644 tests/datasets/test_customized_data.py

diff --git a/cogdl/data/__init__.py b/cogdl/data/__init__.py
index 2419f173..5f2bdcae 100644
--- a/cogdl/data/__init__.py
+++ b/cogdl/data/__init__.py
@@ -1,11 +1,6 @@
 from .data import Data
 from .batch import Batch
-from .dataset import Dataset
+from .dataset import Dataset, MultiGraphDataset
 from .dataloader import DataLoader
 
-__all__ = [
-    "Data",
-    "Batch",
-    "Dataset",
-    "DataLoader",
-]
+__all__ = ["Data", "Batch", "Dataset", "DataLoader", "MultiGraphDataset"]
diff --git a/cogdl/data/data.py b/cogdl/data/data.py
index 9ef3e9fb..e489064e 100644
--- a/cogdl/data/data.py
+++ b/cogdl/data/data.py
@@ -113,6 +113,9 @@ def __inc__(self, key, value):
         # creating batches.
         return self.num_nodes if bool(re.search("(index|face)", key)) else 0
 
+    def __cat_dim__(self, key, value):
+        return self.cat_dim(key, value)
+
     @property
     def num_edges(self):
         r"""Returns the number of edges in the graph."""
diff --git a/cogdl/data/dataloader.py b/cogdl/data/dataloader.py
index 8e9a0451..795260ca 100644
--- a/cogdl/data/dataloader.py
+++ b/cogdl/data/dataloader.py
@@ -1,7 +1,7 @@
 import torch.utils.data
 from torch.utils.data.dataloader import default_collate
 
-from cogdl.data import Batch
+from cogdl.data import Batch, Data
 
 
 class DataLoader(torch.utils.data.DataLoader):
@@ -18,5 +18,22 @@ class DataLoader(torch.utils.data.DataLoader):
 
     def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
         super(DataLoader, self).__init__(
-            dataset, batch_size, shuffle, collate_fn=lambda data_list: Batch.from_data_list(data_list), **kwargs
+            # dataset, batch_size, shuffle, collate_fn=lambda data_list: Batch.from_data_list(data_list), **kwargs
+            dataset,
+            batch_size,
+            shuffle,
+            collate_fn=self.collate_fn,
+            **kwargs,
         )
+
+    @staticmethod
+    def collate_fn(batch):
+        item = batch[0]
+        if isinstance(item, Data):
+            return Batch.from_data_list(batch)
+        elif isinstance(item, torch.Tensor):
+            return default_collate(batch)
+        elif isinstance(item, float):
+            return torch.tensor(batch, dtype=torch.float)
+
+        raise TypeError("DataLoader found invalid type: {}".format(type(item)))
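
The loader no longer assumes every item is a `Data` object: collation dispatches on the type of the first batch element (`Data` → `Batch`, tensors → `default_collate`, floats → a 1-D float tensor). A short usage sketch, assuming `Data` and `Batch` behave as elsewhere in this patch:

```python
import torch
from cogdl.data import Data, DataLoader

# Four tiny synthetic graphs (toy data for illustration).
graphs = [
    Data(x=torch.randn(3, 8), edge_index=torch.tensor([[0, 1, 2], [1, 2, 0]]))
    for _ in range(4)
]

loader = DataLoader(graphs, batch_size=2, shuffle=False)
for batch in loader:
    # Data items are merged into a Batch; the `batch` vector maps each
    # node back to its source graph.
    print(batch.x.shape, batch.batch)
```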
diff --git a/cogdl/data/dataset.py b/cogdl/data/dataset.py
index b2f51497..7b329649 100644
--- a/cogdl/data/dataset.py
+++ b/cogdl/data/dataset.py
@@ -1,5 +1,7 @@
 import collections
+import copy
 import os.path as osp
+from itertools import repeat, product
 
 import torch.utils.data
 
@@ -130,5 +132,98 @@ def __getitem__(self, idx):  # pragma: no cover
         data = data if self.transform is None else self.transform(data)
         return data
 
+    @property
+    def num_classes(self):
+        r"""The number of classes in the dataset."""
+        y = self.data.y
+        return y.max().item() + 1 if y.dim() == 1 else y.size(1)
+
     def __repr__(self):  # pragma: no cover
         return "{}({})".format(self.__class__.__name__, len(self))
+
+
+class MultiGraphDataset(Dataset):
+    def __init__(self, root=None, transform=None, pre_transform=None, pre_filter=None):
+        super(MultiGraphDataset, self).__init__(root, transform, pre_transform, pre_filter)
+        self.data, self.slices = None, None
+
+    @property
+    def num_classes(self):
+        r"""The number of classes in the dataset."""
+        y = self.data.y
+        return y.max().item() + 1 if y.dim() == 1 else y.size(1)
+
+    def len(self):
+        for item in self.slices.values():
+            return len(item) - 1
+        return 0
+
+    def _get(self, idx):
+        data = self.data.__class__()
+
+        if hasattr(self.data, "__num_nodes__"):
+            data.num_nodes = self.data.__num_nodes__[idx]
+
+        for key in self.data.keys:
+            item, slices = self.data[key], self.slices[key]
+            start, end = slices[idx].item(), slices[idx + 1].item()
+            if torch.is_tensor(item):
+                s = list(repeat(slice(None), item.dim()))
+                s[self.data.__cat_dim__(key, item)] = slice(start, end)
+            elif start + 1 == end:
+                s = slices[start]
+            else:
+                s = slice(start, end)
+            data[key] = item[s]
+        return data
+
+    def get(self, idx):
+        if isinstance(idx, int):
+            return self._get(idx)
+        else:
+            data_list = [self._get(i) for i in idx]
+            data, slices = self.from_data_list(data_list)
+            dataset = copy.copy(self)
+            dataset.data = data
+            dataset.slices = slices
+            return dataset
+
+    @staticmethod
+    def from_data_list(data_list):
+        r""" Borrowed from PyG"""
+
+        keys = data_list[0].keys
+        data = data_list[0].__class__()
+
+        for key in keys:
+            data[key] = []
+        slices = {key: [0] for key in keys}
+
+        for item, key in product(data_list, keys):
+            data[key].append(item[key])
+            if torch.is_tensor(item[key]):
+                s = slices[key][-1] + item[key].size(item.__cat_dim__(key, item[key]))
+            else:
+                s = slices[key][-1] + 1
+            slices[key].append(s)
+
+        if hasattr(data_list[0], "__num_nodes__"):
+            data.__num_nodes__ = []
+            for item in data_list:
+                data.__num_nodes__.append(item.num_nodes)
+
+        for key in keys:
+            item = data_list[0][key]
+            if torch.is_tensor(item):
+                data[key] = torch.cat(data[key], dim=data.__cat_dim__(key, item))
+            elif isinstance(item, int) or isinstance(item, float):
+                data[key] = torch.tensor(data[key])
+
+            slices[key] = torch.tensor(slices[key], dtype=torch.long)
+
+        return data, slices
+
+    def __len__(self):
+        for item in self.slices.values():
+            return len(item) - 1
+        return 0
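
`MultiGraphDataset` stores all graphs concatenated into a single `Data` object, plus a `slices` dict that records, per key, where each graph starts and ends; `_get` cuts those ranges back out. A sketch of the layout, assuming `Data` supports the PyG-style key protocol it uses elsewhere in this patch:

```python
import torch
from cogdl.data import Data, MultiGraphDataset

g1 = Data(x=torch.randn(3, 4), edge_index=torch.tensor([[0, 1], [1, 2]]))
g2 = Data(x=torch.randn(5, 4), edge_index=torch.tensor([[0, 2, 3], [1, 3, 4]]))
data, slices = MultiGraphDataset.from_data_list([g1, g2])

print(data.x.shape)          # (8, 4): node features stacked along dim 0
print(slices["x"])           # tensor([0, 3, 8]): graph i owns rows slices[i]:slices[i+1]
print(slices["edge_index"])  # tensor([0, 2, 5]): column boundaries, since cat dim is 1
```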
diff --git a/cogdl/datasets/__init__.py b/cogdl/datasets/__init__.py
index 5e43a6d9..b060bfb2 100644
--- a/cogdl/datasets/__init__.py
+++ b/cogdl/datasets/__init__.py
@@ -1,6 +1,7 @@
 import importlib
 
 from cogdl.data.dataset import Dataset
+from .customizezd_data import CustomizedGraphClassificationDataset, CustomizedNodeClassificationDataset, BaseDataset
 
 try:
     import torch_geometric
@@ -51,8 +52,12 @@ def try_import_dataset(dataset):
 
 def build_dataset(args):
     if not try_import_dataset(args.dataset):
+        assert hasattr(args, "task")
+        dataset = build_dataset_from_path(args.dataset, args.task)
+        if dataset is not None:
+            return dataset
         exit(1)
-    return DATASET_REGISTRY[args.dataset](args=args)
+    return DATASET_REGISTRY[args.dataset]()
 
 
 def build_dataset_from_name(dataset):
@@ -61,6 +66,15 @@ def build_dataset_from_name(dataset):
     return DATASET_REGISTRY[dataset]()
 
 
+def build_dataset_from_path(data_path, task):
+    if "node_classification" in task:
+        return CustomizedNodeClassificationDataset(data_path)
+    elif "graph_classification" in task:
+        return CustomizedGraphClassificationDataset(data_path)
+    else:
+        return None
+
+
 SUPPORTED_DATASETS = {
     "kdd_icdm": "cogdl.datasets.gcc_data",
     "sigir_cikm": "cogdl.datasets.gcc_data",
@@ -117,8 +131,8 @@ def build_dataset_from_name(dataset):
     "reddit": "cogdl.datasets.saint_data",
     "test_bio": "cogdl.datasets.pyg_strategies_data",
     "test_chem": "cogdl.datasets.pyg_strategies_data",
-    "bio": "cogdl.datasets.pyg_strategies_data",
-    "chem": "cogdl.datasets.pyg_strategies_data",
-    "bace": "cogdl.datasets.pyg_strategies_data",
-    "bbbp": "cogdl.datasets.pyg_strategies_data",
+    "bio": "cogdl.datasets.strategies_data",
+    "chem": "cogdl.datasets.strategies_data",
+    "bace": "cogdl.datasets.strategies_data",
+    "bbbp": "cogdl.datasets.strategies_data",
 }
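
`build_dataset` now falls back to treating an unregistered dataset name as a file path, with `args.task` selecting the wrapper class. A minimal sketch, assuming `mydata.pt` was saved the way `examples/custom_dataset.py` below saves it:

```python
from cogdl.datasets import build_dataset_from_path

dataset = build_dataset_from_path("./mydata.pt", task="node_classification")
print(dataset[0])  # the Data object loaded from the file
```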
diff --git a/cogdl/datasets/customizezd_data.py b/cogdl/datasets/customizezd_data.py
new file mode 100644
index 00000000..89507e83
--- /dev/null
+++ b/cogdl/datasets/customizezd_data.py
@@ -0,0 +1,102 @@
+import torch
+
+from cogdl.data import Dataset, MultiGraphDataset, Batch
+from cogdl.utils import accuracy_evaluator, download_url, multiclass_evaluator, multilabel_evaluator
+
+
+def _get_evaluator(metric):
+    if metric == "accuracy":
+        return accuracy_evaluator()
+    elif metric == "multilabel_f1":
+        return multilabel_evaluator()
+    elif metric == "multiclass_f1":
+        return multiclass_evaluator()
+    else:
+        raise NotImplementedError
+
+
+class BaseDataset(Dataset):
+    def __init__(self):
+        super(BaseDataset, self).__init__("custom")
+
+    def process(self):
+        pass
+
+    def _download(self):
+        pass
+
+    def _process(self):
+        pass
+
+    def get(self, idx):
+        return self.data
+
+
+class CustomizedNodeClassificationDataset(Dataset):
+    """
+    data_path : path to the saved dataset file. The file must contain a preprocessed Data object
+    metric : one of "accuracy", "multilabel_f1" or "multiclass_f1". Default: "accuracy"
+    """
+
+    def __init__(self, data_path, metric="accuracy"):
+        super(CustomizedNodeClassificationDataset, self).__init__(root=data_path)
+        try:
+            self.data = torch.load(data_path)
+        except Exception as e:
+            print(e)
+            exit(1)
+        self.metric = metric
+
+    def download(self):
+        for name in self.raw_file_names:
+            download_url("{}{}&dl=1".format(self.url, name), self.raw_dir, name=name)
+
+    def process(self):
+        pass
+
+    def get(self, idx):
+        assert idx == 0
+        return self.data
+
+    def get_evaluator(self):
+        return _get_evaluator(self.metric)
+
+    def __repr__(self):
+        return "{}()".format(self.name)
+
+    def _download(self):
+        pass
+
+    def _process(self):
+        pass
+
+
+class CustomizedGraphClassificationDataset(MultiGraphDataset):
+    def __init__(self, data_path, metric="accuracy"):
+        super(CustomizedGraphClassificationDataset, self).__init__(root=data_path)
+        try:
+            data = torch.load(data_path)
+            if isinstance(data, list):
+                batch = Batch.from_data_list(data)
+                self.data = batch
+                self.slices = batch.__slices__
+                del self.data.batch
+            else:
+                assert len(data) == 2
+                self.data = data[0]
+                self.slices = data[1]
+        except Exception as e:
+            print(e)
+            exit(1)
+        self.metric = metric
+    def get_evaluator(self):
+        return _get_evaluator(self.metric)
+
+    def __repr__(self):
+        return "{}()".format(self.name)
+
+    def _download(self):
+        pass
+
+    def _process(self):
+        pass
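
The two wrapper classes expect different on-disk layouts: a single `Data` object for node classification, and either a list of `Data` objects or a ready-made `(data, slices)` tuple for graph classification. A sketch of preparing both files (names and shapes are placeholders):

```python
import torch
from cogdl.data import Data

# Node classification: one Data object with features, edges and labels.
x = torch.randn(10, 4)
edge_index = torch.randint(0, 10, (2, 30))
y = torch.randint(0, 3, (10,))
torch.save(Data(x=x, edge_index=edge_index, y=y), "node_data.pt")

# Graph classification: a list of Data objects; the wrapper batches them
# with Batch.from_data_list and keeps the resulting slices.
graphs = [
    Data(x=torch.randn(5, 4), edge_index=torch.randint(0, 5, (2, 8)), y=torch.tensor([i % 2]))
    for i in range(20)
]
torch.save(graphs, "graph_data.pt")
```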
diff --git a/cogdl/datasets/gatne.py b/cogdl/datasets/gatne.py
index 81b52f57..6acf734a 100644
--- a/cogdl/datasets/gatne.py
+++ b/cogdl/datasets/gatne.py
@@ -86,7 +86,7 @@ def __repr__(self):
 
 @register_dataset("amazon")
 class AmazonDataset(GatneDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "amazon"
         path = osp.join("data", dataset)
         super(AmazonDataset, self).__init__(path, dataset)
@@ -94,7 +94,7 @@ def __init__(self, args=None):
 
 @register_dataset("twitter")
 class TwitterDataset(GatneDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "twitter"
         path = osp.join("data", dataset)
         super(TwitterDataset, self).__init__(path, dataset)
@@ -102,7 +102,7 @@ def __init__(self, args=None):
 
 @register_dataset("youtube")
 class YouTubeDataset(GatneDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "youtube"
         path = osp.join("data", dataset)
         super(YouTubeDataset, self).__init__(path, dataset)
diff --git a/cogdl/datasets/gcc_data.py b/cogdl/datasets/gcc_data.py
index 6fa835c5..37a15d49 100644
--- a/cogdl/datasets/gcc_data.py
+++ b/cogdl/datasets/gcc_data.py
@@ -161,7 +161,7 @@ def process(self):
 
 @register_dataset("kdd_icdm")
 class KDD_ICDM_GCCDataset(GCCDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "kdd_icdm"
         path = osp.join("data", dataset)
         super(KDD_ICDM_GCCDataset, self).__init__(path, dataset)
@@ -169,7 +169,7 @@ def __init__(self, args=None):
 
 @register_dataset("sigir_cikm")
 class SIGIR_CIKM_GCCDataset(GCCDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "sigir_cikm"
         path = osp.join("data", dataset)
         super(SIGIR_CIKM_GCCDataset, self).__init__(path, dataset)
@@ -177,7 +177,7 @@ def __init__(self, args=None):
 
 @register_dataset("sigmod_icde")
 class SIGMOD_ICDE_GCCDataset(GCCDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "sigmod_icde"
         path = osp.join("data", dataset)
         super(SIGMOD_ICDE_GCCDataset, self).__init__(path, dataset)
@@ -185,7 +185,7 @@ def __init__(self, args=None):
 
 @register_dataset("usa-airport")
 class USAAirportDataset(Edgelist):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "usa-airport"
         path = osp.join("data", dataset)
         super(USAAirportDataset, self).__init__(path, dataset)
diff --git a/cogdl/datasets/gtn_data.py b/cogdl/datasets/gtn_data.py
index ea2a146f..0fd96c04 100644
--- a/cogdl/datasets/gtn_data.py
+++ b/cogdl/datasets/gtn_data.py
@@ -25,7 +25,6 @@ def __init__(self, root, name):
         self.url = f"https://github.com/cenyk1230/gtn-data/blob/master/{name}.zip?raw=true"
         super(GTNDataset, self).__init__(root)
         self.data = torch.load(self.processed_paths[0])
-        self.num_classes = torch.max(self.data.train_target).item() + 1
         self.num_edge = len(self.data.adj)
         self.num_nodes = self.data.x.shape[0]
 
@@ -38,6 +37,10 @@ def raw_file_names(self):
     def processed_file_names(self):
         return ["data.pt"]
 
+    @property
+    def num_classes(self):
+        return torch.max(self.data.train_target).item() + 1
+
     def read_gtn_data(self, folder):
         edges = pickle.load(open(osp.join(folder, "edges.pkl"), "rb"))
         labels = pickle.load(open(osp.join(folder, "labels.pkl"), "rb"))
@@ -128,7 +131,7 @@ def __repr__(self):
 
 @register_dataset("gtn-acm")
 class ACM_GTNDataset(GTNDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "gtn-acm"
         path = osp.join("data", dataset)
         super(ACM_GTNDataset, self).__init__(path, dataset)
@@ -136,7 +139,7 @@ def __init__(self, args=None):
 
 @register_dataset("gtn-dblp")
 class DBLP_GTNDataset(GTNDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "gtn-dblp"
         path = osp.join("data", dataset)
         super(DBLP_GTNDataset, self).__init__(path, dataset)
@@ -144,7 +147,7 @@ def __init__(self, args=None):
 
 @register_dataset("gtn-imdb")
 class IMDB_GTNDataset(GTNDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "gtn-imdb"
         path = osp.join("data", dataset)
         super(IMDB_GTNDataset, self).__init__(path, dataset)
diff --git a/cogdl/datasets/han_data.py b/cogdl/datasets/han_data.py
index 3dea403b..5aecc0d3 100644
--- a/cogdl/datasets/han_data.py
+++ b/cogdl/datasets/han_data.py
@@ -32,7 +32,6 @@ def __init__(self, root, name):
         self.url = f"https://github.com/cenyk1230/han-data/blob/master/{name}.zip?raw=true"
         super(HANDataset, self).__init__(root)
         self.data = torch.load(self.processed_paths[0])
-        self.num_classes = torch.max(self.data.train_target).item() + 1
         self.num_edge = len(self.data.adj)
         self.num_nodes = self.data.x.shape[0]
 
@@ -45,6 +44,10 @@ def raw_file_names(self):
     def processed_file_names(self):
         return ["data.pt"]
 
+    @property
+    def num_classes(self):
+        return torch.max(self.data.train_target).item() + 1
+
     def read_gtn_data(self, folder):
         data = sio.loadmat(osp.join(folder, "data.mat"))
         if self.name == "han-acm" or self.name == "han-imdb":
@@ -136,7 +139,7 @@ def __repr__(self):
 
 @register_dataset("han-acm")
 class ACM_HANDataset(HANDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "han-acm"
         path = osp.join("data", dataset)
         super(ACM_HANDataset, self).__init__(path, dataset)
@@ -144,7 +147,7 @@ def __init__(self, args=None):
 
 @register_dataset("han-dblp")
 class DBLP_HANDataset(HANDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "han-dblp"
         path = osp.join("data", dataset)
         super(DBLP_HANDataset, self).__init__(path, dataset)
@@ -152,7 +155,7 @@ def __init__(self, args=None):
 
 @register_dataset("han-imdb")
 class IMDB_HANDataset(HANDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "han-imdb"
         path = osp.join("data", dataset)
         super(IMDB_HANDataset, self).__init__(path, dataset)
diff --git a/cogdl/datasets/kg_data.py b/cogdl/datasets/kg_data.py
index c400424c..7dc7845e 100644
--- a/cogdl/datasets/kg_data.py
+++ b/cogdl/datasets/kg_data.py
@@ -308,7 +308,7 @@ def process(self):
 
 @register_dataset("fb13")
 class FB13Datset(KnowledgeGraphDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "FB13"
         path = osp.join("data", dataset)
         super(FB13Datset, self).__init__(path, dataset)
@@ -316,7 +316,7 @@ def __init__(self, args=None):
 
 @register_dataset("fb15k")
 class FB15kDatset(KnowledgeGraphDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "FB15K"
         path = osp.join("data", dataset)
         super(FB15kDatset, self).__init__(path, dataset)
@@ -324,7 +324,7 @@ def __init__(self, args=None):
 
 @register_dataset("fb15k237")
 class FB15k237Datset(KnowledgeGraphDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "FB15K237"
         path = osp.join("data", dataset)
         super(FB15k237Datset, self).__init__(path, dataset)
@@ -332,7 +332,7 @@ def __init__(self, args=None):
 
 @register_dataset("wn18")
 class WN18Datset(KnowledgeGraphDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "WN18"
         path = osp.join("data", dataset)
         super(WN18Datset, self).__init__(path, dataset)
@@ -340,7 +340,7 @@ def __init__(self, args=None):
 
 @register_dataset("wn18rr")
 class WN18RRDataset(KnowledgeGraphDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "WN18RR"
         path = osp.join("data", dataset)
         super(WN18RRDataset, self).__init__(path, dataset)
@@ -350,7 +350,7 @@ def __init__(self, args=None):
 class FB13SDatset(KnowledgeGraphDataset):
     url = "https://raw.githubusercontent.com/cenyk1230/test-data/main"
 
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "FB13-S"
         path = osp.join("data", dataset)
         super(FB13SDatset, self).__init__(path, dataset)
diff --git a/cogdl/datasets/matlab_matrix.py b/cogdl/datasets/matlab_matrix.py
index 13350a45..4bcae52e 100644
--- a/cogdl/datasets/matlab_matrix.py
+++ b/cogdl/datasets/matlab_matrix.py
@@ -68,7 +68,7 @@ def process(self):
 
 @register_dataset("blogcatalog")
 class BlogcatalogDataset(MatlabMatrix):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset, filename = "blogcatalog", "blogcatalog"
         url = "http://leitang.net/code/social-dimension/data/"
         path = osp.join("data", dataset)
@@ -77,7 +77,7 @@ def __init__(self, args=None):
 
 @register_dataset("flickr-ne")
 class FlickrDataset(MatlabMatrix):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset, filename = "flickr", "flickr"
         url = "http://leitang.net/code/social-dimension/data/"
         path = osp.join("data", dataset)
@@ -86,7 +86,7 @@ def __init__(self, args=None):
 
 @register_dataset("wikipedia")
 class WikipediaDataset(MatlabMatrix):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset, filename = "wikipedia", "POS"
         url = "http://snap.stanford.edu/node2vec/"
         path = osp.join("data", dataset)
@@ -95,7 +95,7 @@ def __init__(self, args=None):
 
 @register_dataset("ppi")
 class PPIDataset(MatlabMatrix):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset, filename = "ppi", "Homo_sapiens"
         url = "http://snap.stanford.edu/node2vec/"
         path = osp.join("data", dataset)
diff --git a/cogdl/datasets/planetoid_data.py b/cogdl/datasets/planetoid_data.py
index 7f91b0bf..3e0a260a 100644
--- a/cogdl/datasets/planetoid_data.py
+++ b/cogdl/datasets/planetoid_data.py
@@ -191,7 +191,7 @@ def normalize_feature(data):
 
 @register_dataset("cora")
 class CoraDataset(Planetoid):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "Cora"
         path = osp.join("data", dataset)
         if not osp.exists(path):
@@ -202,7 +202,7 @@ def __init__(self, args=None):
 
 @register_dataset("citeseer")
 class CiteSeerDataset(Planetoid):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "CiteSeer"
         path = osp.join("data", dataset)
         if not osp.exists(path):
@@ -213,7 +213,7 @@ def __init__(self, args=None):
 
 @register_dataset("pubmed")
 class PubMedDataset(Planetoid):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "PubMed"
         path = osp.join("data", dataset)
         if not osp.exists(path):
diff --git a/cogdl/datasets/pyg_ogb.py b/cogdl/datasets/pyg_ogb.py
index 8aa9cd08..771d4d3c 100644
--- a/cogdl/datasets/pyg_ogb.py
+++ b/cogdl/datasets/pyg_ogb.py
@@ -34,7 +34,7 @@ def get(self, idx):
 
 @register_dataset("ogbn-arxiv")
 class OGBArxivDataset(OGBNDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "ogbn-arxiv"
         path = osp.join("data", dataset)
         if not osp.exists(path):
@@ -49,7 +49,7 @@ def __init__(self, args=None):
 
 @register_dataset("ogbn-products")
 class OGBProductsDataset(OGBNDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "ogbn-products"
         path = osp.join("data", dataset)
         if not osp.exists(path):
@@ -59,7 +59,7 @@ def __init__(self, args=None):
 
 @register_dataset("ogbn-proteins")
 class OGBProteinsDataset(OGBNDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "ogbn-proteins"
         path = osp.join("data", dataset)
         if not osp.exists(path):
@@ -69,7 +69,7 @@ def __init__(self, args=None):
 
 @register_dataset("ogbn-mag")
 class OGBMAGDataset(OGBNDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "ogbn-mag"
         path = osp.join("data", dataset)
         if not osp.exists(path):
@@ -79,7 +79,7 @@ def __init__(self, args=None):
 
 @register_dataset("ogbn-papers100M")
 class OGBPapers100MDataset(OGBNDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "ogbn-papers100M"
         path = osp.join("data", dataset)
         if not osp.exists(path):
@@ -109,7 +109,7 @@ def get(self, idx):
 
 @register_dataset("ogbg-molbace")
 class OGBMolbaceDataset(OGBGDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "ogbg-molbace"
         path = osp.join("data", dataset)
         if not osp.exists(path):
@@ -119,7 +119,7 @@ def __init__(self, args=None):
 
 @register_dataset("ogbg-molhiv")
 class OGBMolhivDataset(OGBGDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "ogbg-molhiv"
         path = osp.join("data", dataset)
         if not osp.exists(path):
@@ -129,7 +129,7 @@ def __init__(self, args=None):
 
 @register_dataset("ogbg-molpcba")
 class OGBMolpcbaDataset(OGBGDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "ogbg-molpcba"
         path = osp.join("data", dataset)
         if not osp.exists(path):
@@ -139,7 +139,7 @@ def __init__(self, args=None):
 
 @register_dataset("ogbg-ppa")
 class OGBPpaDataset(OGBGDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "ogbg-ppa"
         path = osp.join("data", dataset)
         if not osp.exists(path):
@@ -149,7 +149,7 @@ def __init__(self, args=None):
 
 @register_dataset("ogbg-code")
 class OGBCodeDataset(OGBGDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "ogbg-code"
         path = osp.join("data", dataset)
         if not osp.exists(path):
diff --git a/cogdl/datasets/saint_data.py b/cogdl/datasets/saint_data.py
index 213332f6..1b92769a 100644
--- a/cogdl/datasets/saint_data.py
+++ b/cogdl/datasets/saint_data.py
@@ -117,7 +117,7 @@ def scale_feats(data):
 
 @register_dataset("yelp")
 class YelpDataset(SAINTDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "Yelp"
         url = "https://cloud.tsinghua.edu.cn/d/7218cc013c9a40159306/files/?p=%2F{}&dl=1"
         path = osp.join("data", dataset)
@@ -129,7 +129,7 @@ def __init__(self, args=None):
 
 @register_dataset("amazon-s")
 class AmazonDataset(SAINTDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "AmazonSaint"
         url = "https://cloud.tsinghua.edu.cn/d/ae4b2c4f59bd41be9b0b/files/?p=%2F{}&dl=1"
         path = osp.join("data", dataset)
@@ -141,7 +141,7 @@ def __init__(self, args=None):
 
 @register_dataset("flickr")
 class FlickrDatset(SAINTDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "Flickr"
         url = "https://cloud.tsinghua.edu.cn/d/d3ebcb5fa2da463b8213/files/?p=%2F{}&dl=1"
         path = osp.join("data", dataset)
@@ -156,7 +156,7 @@ def get_evaluator(self):
 
 @register_dataset("reddit")
 class RedditDataset(SAINTDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "Reddit"
         url = "https://cloud.tsinghua.edu.cn/d/d087e7e766e747ce8073/files/?p=%2F{}&dl=1"
         path = osp.join("data", dataset)
diff --git a/cogdl/datasets/pyg_strategies_data.py b/cogdl/datasets/strategies_data.py
similarity index 94%
rename from cogdl/datasets/pyg_strategies_data.py
rename to cogdl/datasets/strategies_data.py
index e7de8504..07757414 100644
--- a/cogdl/datasets/pyg_strategies_data.py
+++ b/cogdl/datasets/strategies_data.py
@@ -8,11 +8,12 @@
 import numpy as np
 
 import torch
-from torch_geometric.data import InMemoryDataset, Data
 from cogdl.utils import download_url
 import os.path as osp
 from itertools import repeat
 
+from cogdl.data import Data, MultiGraphDataset
+
 # ================
 # Dataset utils
 # ================
@@ -557,7 +558,7 @@ def __repr__(self):
 
 class BatchFinetune(Data):
     def __init__(self, batch=None, **kwargs):
-        super(BatchMasking, self).__init__(**kwargs)
+        super(BatchFinetune, self).__init__(**kwargs)
         self.batch = batch
 
     @staticmethod
@@ -841,10 +842,8 @@ def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
 
 
 @register_dataset("test_bio")
-class TestBioDataset(InMemoryDataset):
-    def __init__(
-        self, data_type="unsupervised", root=None, transform=None, pre_transform=None, pre_filter=None, args=None
-    ):
+class TestBioDataset(MultiGraphDataset):
+    def __init__(self, data_type="unsupervised", root="testbio", transform=None, pre_transform=None, pre_filter=None):
         super(TestBioDataset, self).__init__(root, transform, pre_transform, pre_filter)
         num_nodes = 20
         num_edges = 20
@@ -891,12 +890,16 @@ def cycle_index(num, shift):
             self.slices["go_target_pretrain"] = torch.arange(0, (num_graphs + 1) * pretrain_tasks)
             self.slices["go_target_downstream"] = torch.arange(0, (num_graphs + 1) * downstream_tasks)
 
+    def _download(self):
+        pass
+
+    def _process(self):
+        pass
+
 
 @register_dataset("test_chem")
-class TestChemDataset(InMemoryDataset):
-    def __init__(
-        self, data_type="unsupervised", root=None, transform=None, pre_transform=None, pre_filter=None, args=None
-    ):
+class TestChemDataset(MultiGraphDataset):
+    def __init__(self, data_type="unsupervised", root="testchem", transform=None, pre_transform=None, pre_filter=None):
         super(TestChemDataset, self).__init__(root, transform, pre_transform, pre_filter)
         num_nodes = 10
         num_edges = 10
@@ -943,21 +946,16 @@ def cycle_index(num, shift):
             self.data.y = go_target_pretrain
             self.slices["y"] = torch.arange(0, (num_graphs + 1) * pretrain_tasks, pretrain_tasks)
 
-    def get(self, idx):
-        data = Data()
-        for key in self.data.keys:
-            item, slices = self.data[key], self.slices[key]
-            s = list(repeat(slice(None), item.dim()))
-            s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1])
-            data[key] = item[s]
-        return data
+    def _download(self):
+        pass
+
+    def _process(self):
+        pass
 
 
 @register_dataset("bio")
-class BioDataset(InMemoryDataset):
-    def __init__(
-        self, data_type="unsupervised", empty=False, transform=None, pre_transform=None, pre_filter=None, args=None
-    ):
+class BioDataset(MultiGraphDataset):
+    def __init__(self, data_type="unsupervised", empty=False, transform=None, pre_transform=None, pre_filter=None):
         self.data_type = data_type
         self.url = "https://cloud.tsinghua.edu.cn/f/c865b1d61348489e86ac/?dl=1"
         self.root = osp.join("data", "BIO")
@@ -987,10 +985,8 @@ def process(self):
 
 
 @register_dataset("chem")
-class MoleculeDataset(InMemoryDataset):
-    def __init__(
-        self, data_type="unsupervised", transform=None, pre_transform=None, pre_filter=None, empty=False, args=None
-    ):
+class MoleculeDataset(MultiGraphDataset):
+    def __init__(self, data_type="unsupervised", transform=None, pre_transform=None, pre_filter=None, empty=False):
         self.data_type = data_type
         self.url = "https://cloud.tsinghua.edu.cn/f/2cac04ee904e4b54b4b2/?dl=1"
         self.root = osp.join("data", "CHEM")
@@ -1004,15 +1000,6 @@ def __init__(
             else:
                 self.data, self.slices = torch.load(self.processed_paths[0])
 
-    def get(self, idx):
-        data = Data()
-        for key in self.data.keys:
-            item, slices = self.data[key], self.slices[key]
-            s = list(repeat(slice(None), item.dim()))
-            s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1])
-            data[key] = item[s]
-        return data
-
     @property
     def raw_file_names(self):
         return ["processed.zip"]
@@ -1037,8 +1024,8 @@ def process(self):
 
 
 @register_dataset("bace")
-class BACEDataset(InMemoryDataset):
-    def __init__(self, transform=None, pre_transform=None, pre_filter=None, empty=False, args=None):
+class BACEDataset(MultiGraphDataset):
+    def __init__(self, transform=None, pre_transform=None, pre_filter=None, empty=False):
         self.url = "https://cloud.tsinghua.edu.cn/f/253270b278f4465380f1/?dl=1"
         self.root = osp.join("data", "BACE")
 
@@ -1048,15 +1035,6 @@ def __init__(self, transform=None, pre_transform=None, pre_filter=None, empty=Fa
         if not empty:
             self.data, self.slices = torch.load(self.processed_paths[0])
 
-    def get(self, idx):
-        data = Data()
-        for key in self.data.keys:
-            item, slices = self.data[key], self.slices[key]
-            s = list(repeat(slice(None), item.dim()))
-            s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1])
-            data[key] = item[s]
-        return data
-
     @property
     def raw_file_names(self):
         return ["processed.zip"]
@@ -1076,8 +1054,8 @@ def process(self):
 
 
 @register_dataset("bbbp")
-class BBBPDataset(InMemoryDataset):
-    def __init__(self, transform=None, pre_transform=None, pre_filter=None, empty=False, args=None):
+class BBBPDataset(MultiGraphDataset):
+    def __init__(self, transform=None, pre_transform=None, pre_filter=None, empty=False):
         self.url = "https://cloud.tsinghua.edu.cn/f/ab8ff4d0a68c40a38956/?dl=1"
         self.root = osp.join("data", "BBBP")
 
@@ -1087,15 +1065,6 @@ def __init__(self, transform=None, pre_transform=None, pre_filter=None, empty=Fa
         if not empty:
             self.data, self.slices = torch.load(self.processed_paths[0])
 
-    def get(self, idx):
-        data = Data()
-        for key in self.data.keys:
-            item, slices = self.data[key], self.slices[key]
-            s = list(repeat(slice(None), item.dim()))
-            s[data.__cat_dim__(key, item)] = slice(slices[idx], slices[idx + 1])
-            data[key] = item[s]
-        return data
-
     @property
     def raw_file_names(self):
         return ["processed.zip"]
diff --git a/cogdl/datasets/test_data.py b/cogdl/datasets/test_data.py
index a3e6db21..da9fc801 100644
--- a/cogdl/datasets/test_data.py
+++ b/cogdl/datasets/test_data.py
@@ -3,21 +3,42 @@
 from cogdl.datasets import register_dataset
 from cogdl.data import Dataset, Data
 
+
 @register_dataset("test_small")
 class TestSmallDataset(Dataset):
     r"""small dataset for debug"""
-    def __init__(self, args=None):
-        x = torch.FloatTensor([[-2, -1], [-2, 1], [-1, 0], [0, 0], [0, 1], [1, 0], [2, 1], [3, 0], [2, -1], [4, 0], [4, 1], [5, 0]])
-        edge_index = torch.LongTensor([[0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5, 5, 5, 5, 6, 6, 7, 7, 7, 7, 8, 8, 9, 9 , 9 , 10, 10, 11, 11],
-                                       [1, 2, 0, 2, 0, 1, 3, 2, 4, 5, 3, 3, 6, 7, 8, 5, 7, 5, 6, 8, 9, 5, 7, 7, 10, 11, 9 , 11, 9 , 10]])
+
+    def __init__(self):
+        super(TestSmallDataset, self).__init__("test")
+        x = torch.FloatTensor(
+            [[-2, -1], [-2, 1], [-1, 0], [0, 0], [0, 1], [1, 0], [2, 1], [3, 0], [2, -1], [4, 0], [4, 1], [5, 0]]
+        )
+        edge_index = torch.LongTensor(
+            [
+                [0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5, 5, 5, 5, 6, 6, 7, 7, 7, 7, 8, 8, 9, 9, 9, 10, 10, 11, 11],
+                [1, 2, 0, 2, 0, 1, 3, 2, 4, 5, 3, 3, 6, 7, 8, 5, 7, 5, 6, 8, 9, 5, 7, 7, 10, 11, 9, 11, 9, 10],
+            ]
+        )
         y = torch.LongTensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3])
         self.data = Data(x, edge_index, None, y, None)
-        self.data.train_mask = torch.tensor([True, False, False, True, False, True, False, False, False, True, False, False])
-        self.data.val_mask = torch.tensor([False, True, False, False, False, False, True, False, False, False, False, True])
-        self.data.test_mask = torch.tensor([False, False, True, False, True, False, False, True, True, False, True, False])
+        self.data.train_mask = torch.tensor(
+            [True, False, False, True, False, True, False, False, False, True, False, False]
+        )
+        self.data.val_mask = torch.tensor(
+            [False, True, False, False, False, False, True, False, False, False, False, True]
+        )
+        self.data.test_mask = torch.tensor(
+            [False, False, True, False, True, False, False, True, True, False, True, False]
+        )
         self.num_classes = 4
         self.transform = None
 
     def get(self, idx):
         assert idx == 0
-        return self.data
\ No newline at end of file
+        return self.data
+
+    def _download(self):
+        pass
+
+    def _process(self):
+        pass
diff --git a/cogdl/datasets/tu_data.py b/cogdl/datasets/tu_data.py
index 1ebd2c61..4e7d096c 100644
--- a/cogdl/datasets/tu_data.py
+++ b/cogdl/datasets/tu_data.py
@@ -275,7 +275,7 @@ def get(self, idx):
 
 @register_dataset("mutag")
 class MUTAGDataset(TUDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "MUTAG"
         path = osp.join("data", dataset)
         if not osp.exists(path):
@@ -285,7 +285,7 @@ def __init__(self, args=None):
 
 @register_dataset("imdb-b")
 class ImdbBinaryDataset(TUDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "IMDB-BINARY"
         path = osp.join("data", dataset)
         if not osp.exists(path):
@@ -295,7 +295,7 @@ def __init__(self, args=None):
 
 @register_dataset("imdb-m")
 class ImdbMultiDataset(TUDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "IMDB-MULTI"
         path = osp.join("data", dataset)
         if not osp.exists(path):
@@ -305,7 +305,7 @@ def __init__(self, args=None):
 
 @register_dataset("collab")
 class CollabDataset(TUDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "COLLAB"
         path = osp.join("data", dataset)
         if not osp.exists(path):
@@ -315,7 +315,7 @@ def __init__(self, args=None):
 
 @register_dataset("proteins")
 class ProtainsDataset(TUDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "PROTEINS"
         path = osp.join("data", dataset)
         if not osp.exists(path):
@@ -325,7 +325,7 @@ def __init__(self, args=None):
 
 @register_dataset("reddit-b")
 class RedditBinary(TUDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "REDDIT-BINARY"
         path = osp.join("data", dataset)
         if not osp.exists(path):
@@ -335,7 +335,7 @@ def __init__(self, args=None):
 
 @register_dataset("reddit-multi-5k")
 class RedditMulti5K(TUDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "REDDIT-MULTI-5K"
         path = osp.join("data", dataset)
         if not osp.exists(path):
@@ -345,7 +345,7 @@ def __init__(self, args=None):
 
 @register_dataset("reddit-multi-12k")
 class RedditMulti12K(TUDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "REDDIT-MULTI-12K"
         path = osp.join("data", dataset)
         if not osp.exists(path):
@@ -355,7 +355,7 @@ def __init__(self, args=None):
 
 @register_dataset("ptc-mr")
 class PTCMRDataset(TUDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "PTC_MR"
         path = osp.join("data", dataset)
         if not osp.exists(path):
@@ -365,7 +365,7 @@ def __init__(self, args=None):
 
 @register_dataset("nci1")
 class NCT1Dataset(TUDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "NCI1"
         path = osp.join("data", dataset)
         if not osp.exists(path):
@@ -375,7 +375,7 @@ def __init__(self, args=None):
 
 @register_dataset("nci109")
 class NCT109Dataset(TUDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "NCI109"
         path = osp.join("data", dataset)
         if not osp.exists(path):
@@ -385,7 +385,7 @@ def __init__(self, args=None):
 
 @register_dataset("enzymes")
 class ENZYMES(TUDataset):
-    def __init__(self, args=None):
+    def __init__(self):
         dataset = "ENZYMES"
         path = osp.join("data", dataset)
         if not osp.exists(path):
diff --git a/cogdl/experiments.py b/cogdl/experiments.py
index ba80dadc..d370fbda 100644
--- a/cogdl/experiments.py
+++ b/cogdl/experiments.py
@@ -12,6 +12,7 @@
 from cogdl.tasks import build_task
 from cogdl.utils import set_random_seed, tabulate_results
 from cogdl.configs import BEST_CONFIGS
+from cogdl.datasets import SUPPORTED_DATASETS
 from cogdl.models import SUPPORTED_MODELS
 
 
@@ -116,7 +117,11 @@ def check_task_dataset_model_match(task, variants):
 
     clean_variants = []
     for item in variants:
-        if item.model in SUPPORTED_MODELS and (item.model, item.dataset) not in pairs:
+        if (
+            (item.dataset in SUPPORTED_DATASETS)
+            and (item.model in SUPPORTED_MODELS)
+            and (item.model, item.dataset) not in pairs
+        ):
             print(f"({item.model}, {item.dataset}) is not implemented in task '{task}''.")
             continue
         clean_variants.append(item)
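
The added `item.dataset in SUPPORTED_DATASETS` test is what lets custom datasets through: a name absent from the registry short-circuits the condition, so the variant is kept rather than rejected as unimplemented. A self-contained illustration with stand-in registries (not the real ones):

```python
from types import SimpleNamespace

SUPPORTED_DATASETS = {"cora"}  # stand-ins for the real registries
SUPPORTED_MODELS = {"gcn"}
pairs = {("gcn", "cora")}      # (model, dataset) pairs implemented for the task

def keep(item):
    # mirrors the patched condition: reject only when both names are
    # built-in and the pair is known not to be implemented
    return not (
        item.dataset in SUPPORTED_DATASETS
        and item.model in SUPPORTED_MODELS
        and (item.model, item.dataset) not in pairs
    )

print(keep(SimpleNamespace(model="gcn", dataset="cora")))       # True: implemented pair
print(keep(SimpleNamespace(model="gcn", dataset="mydataset")))  # True: custom dataset passes through
```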
diff --git a/cogdl/layers/strategies_layers.py b/cogdl/layers/strategies_layers.py
index 07705bbc..7d5a0179 100644
--- a/cogdl/layers/strategies_layers.py
+++ b/cogdl/layers/strategies_layers.py
@@ -7,7 +7,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from cogdl.datasets import build_dataset_from_name
-from cogdl.datasets.pyg_strategies_data import (
+from cogdl.datasets.strategies_data import (
     BioDataset,
     ChemExtractSubstructureContextPair,
     DataLoaderFinetune,
@@ -22,7 +22,9 @@
 )
 from cogdl.utils import add_self_loops, batch_mean_pooling, batch_sum_pooling, cycle_index
 from sklearn.metrics import roc_auc_score
-from torch_geometric.data import DataLoader
+
+# from torch_geometric.data import DataLoader
+from cogdl.data import DataLoader
 
 
 class GINConv(nn.Module):
diff --git a/cogdl/models/__init__.py b/cogdl/models/__init__.py
index 1531ae2a..f088fcb9 100644
--- a/cogdl/models/__init__.py
+++ b/cogdl/models/__init__.py
@@ -111,7 +111,7 @@ def build_model(args):
     "fastgcn": "cogdl.models.nn.fastgcn",
     "mlp": "cogdl.models.nn.mlp",
     "sgc": "cogdl.models.nn.sgc",
-    "stpgnn": "cogdl.models.nn.pyg_stpgnn",
+    "stpgnn": "cogdl.models.nn.stpgnn",
     "sortpool": "cogdl.models.nn.sortpool",
     "srgcn": "cogdl.models.nn.pyg_srgcn",
     "asgcn": "cogdl.models.nn.asgcn",
diff --git a/cogdl/models/emb/dgk.py b/cogdl/models/emb/dgk.py
index dd496ef5..f797de53 100644
--- a/cogdl/models/emb/dgk.py
+++ b/cogdl/models/emb/dgk.py
@@ -2,7 +2,6 @@
 from joblib import Parallel, delayed
 import networkx as nx
 import numpy as np
-from gensim.models.doc2vec import Doc2Vec, TaggedDocument
 from gensim.models.word2vec import Word2Vec
 from .. import BaseModel, register_model
 
diff --git a/cogdl/models/nn/__init__.py b/cogdl/models/nn/__init__.py
index e69de29b..9e2c8b2c 100644
--- a/cogdl/models/nn/__init__.py
+++ b/cogdl/models/nn/__init__.py
@@ -0,0 +1,31 @@
+from .compgcn import CompGCN, CompGCNLayer
+from .dgi import DGIModel
+from .disengcn import DisenGCN, DisenGCNLayer
+from .gat import GATLayer
+from .gcn import GraphConvolution, TKipfGCN
+from .gcnii import GCNIILayer, GCNII
+from .gdc_gcn import GDC_GCN
+from .grace import GRACE, GraceEncoder
+from .graphsage import Graphsage, GraphSAGELayer
+from .mvgrl import MVGRL
+from .patchy_san import PatchySAN
+from .ppnp import PPNP
+from .rgcn import RGCNLayer, LinkPredictRGCN, RGCN
+from .sgc import SimpleGraphConvolution, sgc
+
+__all__ = [
+    "CompGCN",
+    "DGIModel",
+    "DisenGCN",
+    "GATLayer",
+    "TKipfGCN",
+    "GCNII",
+    "GDC_GCN",
+    "GRACE",
+    "Graphsage",
+    "MVGRL",
+    "PatchySAN",
+    "PPNP",
+    "RGCN",
+    "sgc",
+]
diff --git a/cogdl/models/nn/pyg_stpgnn.py b/cogdl/models/nn/stpgnn.py
similarity index 100%
rename from cogdl/models/nn/pyg_stpgnn.py
rename to cogdl/models/nn/stpgnn.py
diff --git a/cogdl/utils/utils.py b/cogdl/utils/utils.py
index ddecdc50..bb29e8a8 100644
--- a/cogdl/utils/utils.py
+++ b/cogdl/utils/utils.py
@@ -335,10 +335,10 @@ def coalesce(row, col, value=None):
     if mask.all():
         return row, col.value
     row = row[mask]
-    col = col[mask]
     if value is not None:
         _value = torch.zeros(row.shape[0], dtype=torch.float).to(row.device)
         value = _value.scatter_add_(dim=0, src=value, index=col)
+    col = col[mask]
     return row, col, value
 
 
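
The `coalesce` fix is an ordering bug: `col` is consumed as a scatter index while accumulating `value`, so it must only be masked down afterwards. A toy illustration of the principle (self-contained, not CogDL's exact `coalesce`):

```python
import torch

# Two duplicated edges; `mask` keeps the first occurrence of each.
row = torch.tensor([0, 0, 1, 1])
col = torch.tensor([1, 1, 2, 2])
value = torch.ones(4)
mask = torch.tensor([True, False, True, False])
dedup_pos = torch.tensor([0, 0, 1, 1])  # deduplicated slot of each entry

# Accumulate duplicate weights first, then shrink row/col.
out = torch.zeros(int(mask.sum())).scatter_add_(0, dedup_pos, value)
row, col = row[mask], col[mask]
print(row, col, out)  # tensor([0, 1]) tensor([1, 2]) tensor([2., 2.])
```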
diff --git a/examples/custom_dataset.py b/examples/custom_dataset.py
index b78c7f79..a8ce74b6 100644
--- a/examples/custom_dataset.py
+++ b/examples/custom_dataset.py
@@ -1,65 +1,40 @@
-from cogdl.data.data import Data
 import torch
 
-from cogdl.tasks import build_task
-from cogdl.models import build_model
-from cogdl.options import get_task_model_args
+from cogdl import experiment
+from cogdl.data import Data
+from cogdl.datasets import BaseDataset, register_dataset
 
 
-"""Define your data"""
-
-class MyData(Data):
+@register_dataset("mydataset")
+class MyNodeClassificationDataset(BaseDataset):
     def __init__(self):
-        super(MyData, self).__init__()
+        super(MyNodeClassificationDataset, self).__init__()
+        self.data = self.process()
+
+    def process(self):
         num_nodes = 100
         num_edges = 300
         feat_dim = 30
-        # load or generate data
-        self.edge_index = torch.randint(0, num_nodes, (2, num_edges))
-        self.x = torch.randn(num_nodes, feat_dim)
-        self.y = torch.randint(0, 2, (num_nodes,))
-
-        # set train/val/test mask in node_classification task
-        self.train_mask = torch.zeros(num_nodes).bool()
-        self.train_mask[0 : int(0.3 * num_nodes)] = True
-        self.val_mask = torch.zeros(num_nodes).bool()
-        self.val_mask[int(0.3 * num_nodes) : int(0.7 * num_nodes)] = True
-        self.test_mask = torch.zeros(num_nodes).bool()
-        self.test_mask[int(0.7 * num_nodes) :] = True
-
-
-"""Define your dataset"""
-
-class MyNodeClassificationDataset(object):
-    def __init__(self):
-        self.data = MyData()
-        self.num_classes = self.data.num_classes
-        self.num_features = self.data.num_features
-
-    def __getitem__(self, index):
-        assert index == 0
-        return self.data
 
+        # load or generate your dataset
+        edge_index = torch.randint(0, num_nodes, (2, num_edges))
+        x = torch.randn(num_nodes, feat_dim)
+        y = torch.randint(0, 2, (num_nodes,))
 
-def set_args(args):
-    """Change default setttings"""
-    cuda_available = torch.cuda.is_available()
-    args.cpu = not cuda_available
-    return args
-
-
-def main_dataset():
-    args = get_task_model_args(task="node_classification", model="gcn")
-    # use customized dataset
-    dataset = MyNodeClassificationDataset()
-    args.num_features = dataset.num_features
-    args.num_classes = dataset.num_classes
-    # use model in cogdl
-    model = build_model(args)
-    task = build_task(args, dataset, model)
-    result = task.train()
-    print(result)
+        # set train/val/test mask in node_classification task
+        train_mask = torch.zeros(num_nodes).bool()
+        train_mask[0 : int(0.3 * num_nodes)] = True
+        val_mask = torch.zeros(num_nodes).bool()
+        val_mask[int(0.3 * num_nodes) : int(0.7 * num_nodes)] = True
+        test_mask = torch.zeros(num_nodes).bool()
+        test_mask[int(0.7 * num_nodes) :] = True
+        data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)
+        torch.save(data, "mydata.pt")
+        return data
 
 
 if __name__ == "__main__":
-    main_dataset()
+    # Run with self-loaded dataset
+    experiment(task="node_classification", dataset="mydataset", model="gcn")
+    # Run with a given data path
+    experiment(task="node_classification", dataset="./mydata.pt", model="gcn")
diff --git a/tests/datasets/test_customized_data.py b/tests/datasets/test_customized_data.py
new file mode 100644
index 00000000..0ac2e91b
--- /dev/null
+++ b/tests/datasets/test_customized_data.py
@@ -0,0 +1,48 @@
+import torch
+from cogdl.data import Data
+from cogdl.datasets import BaseDataset, register_dataset, build_dataset, build_dataset_from_name
+from cogdl.utils import build_args_from_dict
+
+
+@register_dataset("mydataset")
+class MyNodeClassificationDataset(BaseDataset):
+    def __init__(self):
+        super(MyNodeClassificationDataset, self).__init__()
+        self.data = self.process()
+
+    def process(self):
+        num_nodes = 100
+        num_edges = 300
+        feat_dim = 30
+
+        # load or generate your dataset
+        edge_index = torch.randint(0, num_nodes, (2, num_edges))
+        x = torch.randn(num_nodes, feat_dim)
+        y = torch.randint(0, 2, (num_nodes,))
+
+        # set train/val/test mask in node_classification task
+        train_mask = torch.zeros(num_nodes).bool()
+        train_mask[0 : int(0.3 * num_nodes)] = True
+        val_mask = torch.zeros(num_nodes).bool()
+        val_mask[int(0.3 * num_nodes) : int(0.7 * num_nodes)] = True
+        test_mask = torch.zeros(num_nodes).bool()
+        test_mask[int(0.7 * num_nodes) :] = True
+        data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)
+        torch.save(data, "mydata.pt")
+        return data
+
+
+def test_customized_dataset():
+    dataset = build_dataset_from_name("mydataset")
+    assert isinstance(dataset[0], Data)
+    assert dataset[0].x.shape[0] == 100
+
+
+def test_build_dataset_from_path():
+    args = build_args_from_dict({"dataset": "mydata.pt", "task": "node_classification"})
+    dataset = build_dataset(args)
+    assert dataset[0].x.shape[0] == 100
+
+
+if __name__ == "__main__":
+    test_customized_dataset()
diff --git a/tests/tasks/test_pretrain.py b/tests/tasks/test_pretrain.py
index 42c5b903..ca85dd6a 100644
--- a/tests/tasks/test_pretrain.py
+++ b/tests/tasks/test_pretrain.py
@@ -153,14 +153,14 @@ def test_bace():
 
 
 if __name__ == "__main__":
-    test_stpgnn_infomax()
-    test_stpgnn_contextpred()
-    test_stpgnn_mask()
-    test_stpgnn_supervised()
-    test_stpgnn_finetune()
-    test_chem_contextpred()
-    test_chem_infomax()
-    test_chem_mask()
+    # test_stpgnn_infomax()
+    # test_stpgnn_contextpred()
+    # test_stpgnn_mask()
+    # test_stpgnn_supervised()
+    # test_stpgnn_finetune()
+    # test_chem_contextpred()
+    # test_chem_infomax()
+    # test_chem_mask()
     test_chem_supervised()
     test_bace()
     test_bbbp()