From d2f10e8b303b2853ab3dbce0d615c82631ea5536 Mon Sep 17 00:00:00 2001
From: yyhslyz <164326235@qq.com>
Date: Mon, 1 Feb 2021 20:06:15 +0800
Subject: [PATCH 1/2] [Feature] remove ogb's dependency on pyg
---
cogdl/configs.py | 30 +++-
cogdl/datasets/README.md | 6 +-
cogdl/datasets/__init__.py | 20 +--
cogdl/datasets/ogb.py | 195 ++++++++++++++++++++++++
cogdl/datasets/pyg_ogb.py | 153 -------------------
cogdl/match.yml | 2 +
cogdl/models/nn/gat.py | 96 ++++++++++--
cogdl/models/nn/gcn.py | 7 +-
cogdl/tasks/graph_classification.py | 1 -
cogdl/tasks/node_classification.py | 1 +
tests/datasets/test_ogb.py | 9 +-
tests/tasks/test_node_classification.py | 11 +-
12 files changed, 336 insertions(+), 195 deletions(-)
create mode 100644 cogdl/datasets/ogb.py
delete mode 100644 cogdl/datasets/pyg_ogb.py
diff --git a/cogdl/configs.py b/cogdl/configs.py
index 37cb7c51..49b5bd73 100644
--- a/cogdl/configs.py
+++ b/cogdl/configs.py
@@ -226,6 +226,11 @@
"num_layers": 5,
"dropout": 0.0,
},
+ "nci1": {
+ "num_layers": 5,
+ "dropout": 0.3,
+ "hidden_size": 64,
+ },
},
"infograph": {
"general": {
@@ -240,6 +245,14 @@
"imdb-b": {"degree_feature": True},
"imdb-m": {"degree_feature": True},
"collab": {"degree_feature": True},
+ "nci1": {"num_layers": 3},
+ },
+ "sortpool": {
+ "nci1": {
+ "dropout": 0.3,
+ "hidden_size": 64,
+ "num_layers": 5,
+ },
},
"patchy_san": {
"general": {
@@ -253,7 +266,22 @@
"collab": {"degree_feature": True},
},
},
- "unsupervised_graph_classification": {},
+ "unsupervised_graph_classification": {
+ "graph2vec": {
+ "general": {},
+ "nci1": {
+ "lr": 0.001,
+ "window_size": 8,
+ "epoch": 10,
+ "iteration": 4,
+ },
+ "reddit-b": {
+ "lr": 0.01,
+ "degree_feature": True,
+ "hidden_size": 128,
+ },
+ }
+ },
"link_prediction": {},
"multiplex_link_prediction": {
"gatne": {
diff --git a/cogdl/datasets/README.md b/cogdl/datasets/README.md
index 623669e9..66fb6d83 100644
--- a/cogdl/datasets/README.md
+++ b/cogdl/datasets/README.md
@@ -27,7 +27,7 @@ CogDL now supports the following datasets for different tasks:
Transductive |
- Cora |
+ Cora |
2,708 |
5,429 |
1,433 |
@@ -123,7 +123,8 @@ CogDL now supports the following datasets for different tasks:
-Network Embedding(Unsupervsed Node classification)
+Network Embedding(Unsupervised Node classification)
+
Dataset |
@@ -184,7 +185,6 @@ CogDL now supports the following datasets for different tasks:
-
Heterogenous Graph
diff --git a/cogdl/datasets/__init__.py b/cogdl/datasets/__init__.py
index 716eb1d7..6347040b 100644
--- a/cogdl/datasets/__init__.py
+++ b/cogdl/datasets/__init__.py
@@ -81,16 +81,16 @@ def build_dataset_from_path(data_path, task):
"sigmod_icde": "cogdl.datasets.gcc_data",
"usa-airport": "cogdl.datasets.gcc_data",
"test_small": "cogdl.datasets.test_data",
- "ogbn-arxiv": "cogdl.datasets.pyg_ogb",
- "ogbn-products": "cogdl.datasets.pyg_ogb",
- "ogbn-proteins": "cogdl.datasets.pyg_ogb",
- "ogbn-mag": "cogdl.datasets.pyg_ogb",
- "ogbn-papers100M": "cogdl.datasets.pyg_ogb",
- "ogbg-molbace": "cogdl.datasets.pyg_ogb",
- "ogbg-molhiv": "cogdl.datasets.pyg_ogb",
- "ogbg-molpcba": "cogdl.datasets.pyg_ogb",
- "ogbg-ppa": "cogdl.datasets.pyg_ogb",
- "ogbg-code": "cogdl.datasets.pyg_ogb",
+ "ogbn-arxiv": "cogdl.datasets.ogb",
+ "ogbn-products": "cogdl.datasets.ogb",
+ "ogbn-proteins": "cogdl.datasets.ogb",
+ "ogbn-mag": "cogdl.datasets.ogb",
+ "ogbn-papers100M": "cogdl.datasets.ogb",
+ "ogbg-molbace": "cogdl.datasets.ogb",
+ "ogbg-molhiv": "cogdl.datasets.ogb",
+ "ogbg-molpcba": "cogdl.datasets.ogb",
+ "ogbg-ppa": "cogdl.datasets.ogb",
+ "ogbg-code": "cogdl.datasets.ogb",
"amazon": "cogdl.datasets.gatne",
"twitter": "cogdl.datasets.gatne",
"youtube": "cogdl.datasets.gatne",
diff --git a/cogdl/datasets/ogb.py b/cogdl/datasets/ogb.py
new file mode 100644
index 00000000..08e0c284
--- /dev/null
+++ b/cogdl/datasets/ogb.py
@@ -0,0 +1,195 @@
+import os.path as osp
+
+import torch
+
+from ogb.nodeproppred import NodePropPredDataset
+from ogb.graphproppred import GraphPropPredDataset
+
+from . import register_dataset
+from cogdl.data import Dataset, Data, DataLoader
+from cogdl.utils import cross_entropy_loss, accuracy, remove_self_loops
+
+
+def coalesce(row, col, edge_attr=None):
+ row = torch.tensor(row)
+ col = torch.tensor(col)
+ if edge_attr is not None:
+ edge_attr = torch.tensor(edge_attr)
+ num = col.shape[0] + 1
+ idx = torch.full((num,), -1, dtype=torch.long)
+ idx[1:] = row * num + col
+ mask = idx[1:] > idx[:-1]
+
+ if mask.all():
+ return row, col, edge_attr
+ row = row[mask]
+ col = col[mask]
+ if edge_attr is not None:
+ edge_attr = edge_attr[mask]
+ return row, col, edge_attr
+
+
+class OGBNDataset(Dataset):
+ def __init__(self, root, name):
+ dataset = NodePropPredDataset(name, root)
+ graph, y = dataset[0]
+ x = torch.tensor(graph["node_feat"])
+ y = torch.tensor(y.squeeze())
+ row, col, edge_attr = coalesce(graph["edge_index"][0], graph["edge_index"][1], graph["edge_feat"])
+ edge_index = torch.stack([row, col], dim=0)
+ edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
+ row = torch.cat([edge_index[0], edge_index[1]])
+ col = torch.cat([edge_index[1], edge_index[0]])
+ edge_index = torch.stack([row, col], dim=0)
+ if edge_attr is not None:
+ edge_attr = torch.cat([edge_attr, edge_attr], dim=0)
+
+ self.data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
+ self.data.num_nodes = graph["num_nodes"]
+ assert self.data.num_nodes == self.data.x.shape[0]
+
+ # split
+ split_index = dataset.get_idx_split()
+ self.data.train_mask = torch.zeros(self.data.num_nodes, dtype=torch.bool)
+ self.data.test_mask = torch.zeros(self.data.num_nodes, dtype=torch.bool)
+ self.data.val_mask = torch.zeros(self.data.num_nodes, dtype=torch.bool)
+ self.data.train_mask[split_index["train"]] = True
+ self.data.test_mask[split_index["test"]] = True
+ self.data.val_mask[split_index["valid"]] = True
+
+ self.transform = None
+
+ def get(self, idx):
+ assert idx == 0
+ return self.data
+
+ def get_loss_fn(self):
+ return cross_entropy_loss
+
+ def get_evaluator(self):
+ return accuracy
+
+
+@register_dataset("ogbn-arxiv")
+class OGBArxivDataset(OGBNDataset):
+ def __init__(self):
+ dataset = "ogbn-arxiv"
+ path = "data"
+ super(OGBArxivDataset, self).__init__(path, dataset)
+
+
+@register_dataset("ogbn-products")
+class OGBProductsDataset(OGBNDataset):
+ def __init__(self):
+ dataset = "ogbn-products"
+ path = "data"
+ super(OGBProductsDataset, self).__init__(path, dataset)
+
+
+@register_dataset("ogbn-proteins")
+class OGBProteinsDataset(OGBNDataset):
+ def __init__(self):
+ dataset = "ogbn-proteins"
+ path = "data"
+ super(OGBProteinsDataset, self).__init__(path, dataset)
+
+
+@register_dataset("ogbn-mag")
+class OGBMAGDataset(OGBNDataset):
+ def __init__(self):
+ dataset = "ogbn-mag"
+ path = "data"
+ super(OGBMAGDataset, self).__init__(path, dataset)
+
+
+@register_dataset("ogbn-papers100M")
+class OGBPapers100MDataset(OGBNDataset):
+ def __init__(self):
+ dataset = "ogbn-papers100M"
+ path = "data"
+ super(OGBPapers100MDataset, self).__init__(path, dataset)
+
+
+class OGBGDataset(Dataset):
+ def __init__(self, root, name):
+ self.name = name
+ self.dataset = GraphPropPredDataset(self.name, root)
+
+ self.graphs = []
+ self.all_nodes = 0
+ self.all_edges = 0
+ for i in range(len(self.dataset.graphs)):
+ graph, label = self.dataset[i]
+ data = Data(
+ x=torch.tensor(graph["node_feat"], dtype=torch.float),
+ edge_index=torch.tensor(graph["edge_index"]),
+ edge_attr=None if "edge_feat" not in graph else torch.tensor(graph["edge_feat"], dtype=torch.float),
+ y=torch.tensor(label),
+ )
+ data.num_nodes = graph["num_nodes"]
+ self.graphs.append(data)
+
+ self.all_nodes += graph["num_nodes"]
+ self.all_edges += graph["edge_index"].shape[1]
+
+ self.transform = None
+
+ def get_loader(self, args):
+ split_index = self.dataset.get_idx_split()
+ train_loader = DataLoader(self.get_subset(split_index["train"]), batch_size=args.batch_size, shuffle=True)
+ valid_loader = DataLoader(self.get_subset(split_index["valid"]), batch_size=args.batch_size, shuffle=False)
+ test_loader = DataLoader(self.get_subset(split_index["test"]), batch_size=args.batch_size, shuffle=False)
+ return train_loader, valid_loader, test_loader
+
+ def get_subset(self, subset):
+ datalist = []
+ for idx in subset:
+ datalist.append(self.graphs[idx])
+ return datalist
+
+ def get(self, idx):
+ return self.graphs[idx]
+
+ @property
+ def num_classes(self):
+ return int(self.dataset.num_classes)
+
+
+@register_dataset("ogbg-molbace")
+class OGBMolbaceDataset(OGBGDataset):
+ def __init__(self):
+ dataset = "ogbg-molbace"
+ path = "data"
+ super(OGBMolbaceDataset, self).__init__(path, dataset)
+
+
+@register_dataset("ogbg-molhiv")
+class OGBMolhivDataset(OGBGDataset):
+ def __init__(self):
+ dataset = "ogbg-molhiv"
+ path = "data"
+ super(OGBMolhivDataset, self).__init__(path, dataset)
+
+
+@register_dataset("ogbg-molpcba")
+class OGBMolpcbaDataset(OGBGDataset):
+ def __init__(self):
+ dataset = "ogbg-molpcba"
+ path = "data"
+ super(OGBMolpcbaDataset, self).__init__(path, dataset)
+
+
+@register_dataset("ogbg-ppa")
+class OGBPpaDataset(OGBGDataset):
+ def __init__(self):
+ dataset = "ogbg-ppa"
+ path = "data"
+ super(OGBPpaDataset, self).__init__(path, dataset)
+
+
+@register_dataset("ogbg-code")
+class OGBCodeDataset(OGBGDataset):
+ def __init__(self):
+ dataset = "ogbg-code"
+ path = "data"
+ super(OGBCodeDataset, self).__init__(path, dataset)
diff --git a/cogdl/datasets/pyg_ogb.py b/cogdl/datasets/pyg_ogb.py
deleted file mode 100644
index cd7fedc5..00000000
--- a/cogdl/datasets/pyg_ogb.py
+++ /dev/null
@@ -1,153 +0,0 @@
-import os.path as osp
-
-import torch
-from torch_geometric.data import DataLoader
-from torch_sparse import coalesce
-
-from ogb.nodeproppred import PygNodePropPredDataset
-from ogb.graphproppred import PygGraphPropPredDataset
-
-from . import register_dataset
-
-
-class OGBNDataset(PygNodePropPredDataset):
- def __init__(self, root, name):
- super(OGBNDataset, self).__init__(name, root)
-
- self.data.num_nodes = self.data.num_nodes[0]
- # split
- split_index = self.get_idx_split()
- self.data["train_mask"] = torch.zeros(self.data.num_nodes, dtype=torch.bool)
- self.data["test_mask"] = torch.zeros(self.data.num_nodes, dtype=torch.bool)
- self.data["val_mask"] = torch.zeros(self.data.num_nodes, dtype=torch.bool)
- self.data["train_mask"][split_index["train"]] = True
- self.data["test_mask"][split_index["test"]] = True
- self.data["val_mask"][split_index["valid"]] = True
-
- self.data.y = self.data.y.squeeze()
-
- def get(self, idx):
- assert idx == 0
- return self.data
-
-
-@register_dataset("ogbn-arxiv")
-class OGBArxivDataset(OGBNDataset):
- def __init__(self):
- dataset = "ogbn-arxiv"
- path = osp.join("data", dataset)
- if not osp.exists(path):
- PygNodePropPredDataset(dataset, path)
- super(OGBArxivDataset, self).__init__(path, dataset)
-
- # to_symmetric
- rev_edge_index = self.data.edge_index[[1, 0]]
- edge_index = torch.cat([self.data.edge_index, rev_edge_index], dim=1).to(dtype=torch.int64)
- self.data.edge_index, self.data.edge_attr = coalesce(edge_index, None, self.data.num_nodes, self.data.num_nodes)
-
-
-@register_dataset("ogbn-products")
-class OGBProductsDataset(OGBNDataset):
- def __init__(self):
- dataset = "ogbn-products"
- path = osp.join("data", dataset)
- if not osp.exists(path):
- PygNodePropPredDataset(dataset, path)
- super(OGBProductsDataset, self).__init__(path, dataset)
-
-
-@register_dataset("ogbn-proteins")
-class OGBProteinsDataset(OGBNDataset):
- def __init__(self):
- dataset = "ogbn-proteins"
- path = osp.join("data", dataset)
- if not osp.exists(path):
- PygNodePropPredDataset(dataset, path)
- super(OGBProteinsDataset, self).__init__(path, dataset)
-
-
-@register_dataset("ogbn-mag")
-class OGBMAGDataset(OGBNDataset):
- def __init__(self):
- dataset = "ogbn-mag"
- path = osp.join("data", dataset)
- if not osp.exists(path):
- PygNodePropPredDataset(dataset, path)
- super(OGBMAGDataset, self).__init__(path, dataset)
-
-
-@register_dataset("ogbn-papers100M")
-class OGBPapers100MDataset(OGBNDataset):
- def __init__(self):
- dataset = "ogbn-papers100M"
- path = osp.join("data", dataset)
- if not osp.exists(path):
- PygNodePropPredDataset(dataset, path)
- super(OGBPapers100MDataset, self).__init__(path, dataset)
-
-
-class OGBGDataset(PygGraphPropPredDataset):
- def __init__(self, root, name):
- super(OGBGDataset, self).__init__(name, root)
- self.name = name
-
- def get_loader(self, args):
- split_index = self.get_idx_split()
- dataset = PygGraphPropPredDataset(self.name, osp.join("data", self.name))
- train_loader = DataLoader(dataset[split_index["train"]], batch_size=args.batch_size, shuffle=True)
- valid_loader = DataLoader(dataset[split_index["valid"]], batch_size=args.batch_size, shuffle=False)
- test_loader = DataLoader(dataset[split_index["test"]], batch_size=args.batch_size, shuffle=False)
- return train_loader, valid_loader, test_loader
-
- def get(self, idx):
- return self.data
-
-
-@register_dataset("ogbg-molbace")
-class OGBMolbaceDataset(OGBGDataset):
- def __init__(self):
- dataset = "ogbg-molbace"
- path = osp.join("data", dataset)
- if not osp.exists(path):
- PygGraphPropPredDataset(dataset, path)
- super(OGBMolbaceDataset, self).__init__(path, dataset)
-
-
-@register_dataset("ogbg-molhiv")
-class OGBMolhivDataset(OGBGDataset):
- def __init__(self):
- dataset = "ogbg-molhiv"
- path = osp.join("data", dataset)
- if not osp.exists(path):
- PygGraphPropPredDataset(dataset, path)
- super(OGBMolhivDataset, self).__init__(path, dataset)
-
-
-@register_dataset("ogbg-molpcba")
-class OGBMolpcbaDataset(OGBGDataset):
- def __init__(self):
- dataset = "ogbg-molpcba"
- path = osp.join("data", dataset)
- if not osp.exists(path):
- PygGraphPropPredDataset(dataset, path)
- super(OGBMolpcbaDataset, self).__init__(path, dataset)
-
-
-@register_dataset("ogbg-ppa")
-class OGBPpaDataset(OGBGDataset):
- def __init__(self):
- dataset = "ogbg-ppa"
- path = osp.join("data", dataset)
- if not osp.exists(path):
- PygGraphPropPredDataset(dataset, path)
- super(OGBPpaDataset, self).__init__(path, dataset)
-
-
-@register_dataset("ogbg-code")
-class OGBCodeDataset(OGBGDataset):
- def __init__(self):
- dataset = "ogbg-code"
- path = osp.join("data", dataset)
- if not osp.exists(path):
- PygGraphPropPredDataset(dataset, path)
- super(OGBCodeDataset, self).__init__(path, dataset)
diff --git a/cogdl/match.yml b/cogdl/match.yml
index eb86f508..8d62c9bd 100644
--- a/cogdl/match.yml
+++ b/cogdl/match.yml
@@ -96,6 +96,8 @@ graph_classification:
- reddit-b
- reddit-multi-5k
- reddit-multi-12k
+ - ogbg-molbace
+ - ogbg-molhiv
unsupervised_graph_classification:
- model:
- infograph
diff --git a/cogdl/models/nn/gat.py b/cogdl/models/nn/gat.py
index 15267493..5b16fc86 100644
--- a/cogdl/models/nn/gat.py
+++ b/cogdl/models/nn/gat.py
@@ -12,7 +12,9 @@ class GATLayer(nn.Module):
Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903
"""
- def __init__(self, in_features, out_features, nhead=1, alpha=0.2, dropout=0.6, concat=True, fast_mode=False):
+ def __init__(
+ self, in_features, out_features, nhead=1, alpha=0.2, dropout=0.6, concat=True, residual=False, fast_mode=False
+ ):
super(GATLayer, self).__init__()
self.in_features = in_features
self.out_features = out_features
@@ -28,6 +30,12 @@ def __init__(self, in_features, out_features, nhead=1, alpha=0.2, dropout=0.6, c
self.dropout = nn.Dropout(dropout)
self.leakyrelu = nn.LeakyReLU(self.alpha)
+
+ if residual:
+ out_features = out_features * nhead if concat else out_features
+ self.residual = nn.Linear(in_features, out_features)
+ else:
+ self.register_buffer("residual", None)
self.reset_parameters()
def reset_parameters(self):
@@ -84,12 +92,17 @@ def forward(self, x, edge):
assert not torch.isnan(hidden).any()
h_prime.append(spmm(edge, edge_weight, hidden))
+ if self.residual:
+ res = self.residual(x)
+ else:
+ res = 0
+
if self.concat:
# if this layer is not last layer,
- out = torch.cat(h_prime, dim=1)
+ out = torch.cat(h_prime, dim=1) + res
else:
# if this layer is last layer,
- out = sum(h_prime) / self.nhead
+ out = sum(h_prime) / self.nhead + res
return out
def __repr__(self):
@@ -115,11 +128,14 @@ def add_args(parser):
"""Add model-specific arguments to the parser."""
# fmt: off
parser.add_argument("--num-features", type=int)
+ parser.add_argument("--num-layers", type=int, default=2)
+ parser.add_argument("--residual", action="store_true")
parser.add_argument("--num-classes", type=int)
parser.add_argument("--hidden-size", type=int, default=8)
parser.add_argument("--dropout", type=float, default=0.6)
parser.add_argument("--alpha", type=float, default=0.2)
- parser.add_argument("--nheads", type=int, default=8)
+ parser.add_argument("--nhead", type=int, default=8)
+ parser.add_argument("--last-nhead", type=int, default=1)
parser.add_argument("--fast-mode", action="store_true", default=False)
# fmt: on
@@ -129,31 +145,81 @@ def build_model_from_args(cls, args):
args.num_features,
args.hidden_size,
args.num_classes,
+ args.num_layers,
args.dropout,
args.alpha,
- args.nheads,
+ args.nhead,
+ args.residual,
+ args.last_nhead,
args.fast_mode,
)
- def __init__(self, in_feats, hidden_size, out_features, dropout, alpha, nheads, fast_mode=False):
+ def __init__(
+ self,
+ in_feats,
+ hidden_size,
+ out_features,
+ num_layers,
+ dropout,
+ alpha,
+ nhead,
+ residual,
+ last_nhead,
+ fast_mode=False,
+ ):
"""Sparse version of GAT."""
super(GAT, self).__init__()
self.dropout = dropout
-
- self.attention = GATLayer(
- in_feats, hidden_size, dropout=dropout, alpha=alpha, nhead=nheads, concat=True, fast_mode=fast_mode
+ self.attentions = nn.ModuleList()
+ self.attentions.append(
+ GATLayer(
+ in_feats,
+ hidden_size,
+ nhead=nhead,
+ dropout=dropout,
+ alpha=alpha,
+ concat=True,
+ residual=residual,
+ fast_mode=fast_mode,
+ )
)
- self.out_att = GATLayer(
- hidden_size * nheads, out_features, dropout=dropout, alpha=alpha, nhead=1, concat=False, fast_mode=False
+ for i in range(num_layers - 2):
+ self.attentions.append(
+ GATLayer(
+ hidden_size * nhead,
+ hidden_size,
+ nhead=nhead,
+ dropout=dropout,
+ alpha=alpha,
+ concat=True,
+ residual=residual,
+ fast_mode=fast_mode,
+ )
+ )
+ self.attentions.append(
+ GATLayer(
+ hidden_size * nhead,
+ out_features,
+ dropout=dropout,
+ alpha=alpha,
+ concat=False,
+ nhead=last_nhead,
+ residual=False,
+ fast_mode=fast_mode,
+ )
)
+ self.num_layers = num_layers
+ self.last_nhead = last_nhead
+ self.residual = residual
def forward(self, x, edge_index):
edge_index, _ = add_remaining_self_loops(edge_index)
- x = F.dropout(x, p=self.dropout, training=self.training)
- x = F.elu(self.attention(x, edge_index))
- x = F.dropout(x, p=self.dropout, training=self.training)
- x = F.elu(self.out_att(x, edge_index))
+ for i, layer in enumerate(self.attentions):
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ x = layer(x, edge_index)
+ if i != self.num_layers - 1:
+ x = F.elu(x)
return x
def predict(self, data):
diff --git a/cogdl/models/nn/gcn.py b/cogdl/models/nn/gcn.py
index 1ca4409a..00b3f0b2 100644
--- a/cogdl/models/nn/gcn.py
+++ b/cogdl/models/nn/gcn.py
@@ -34,13 +34,8 @@ def reset_parameters(self):
def forward(self, input, edge_index, edge_attr=None):
if edge_attr is None:
edge_attr = torch.ones(edge_index.shape[1]).float().to(input.device)
- adj = torch.sparse_coo_tensor(
- edge_index,
- edge_attr,
- (input.shape[0], input.shape[0]),
- ).to(input.device)
support = torch.mm(input, self.weight)
- output = torch.spmm(adj, support)
+ output = spmm(edge_index, edge_attr, support)
if self.bias is not None:
return output + self.bias
else:
diff --git a/cogdl/tasks/graph_classification.py b/cogdl/tasks/graph_classification.py
index 68826a1f..e73f0a73 100644
--- a/cogdl/tasks/graph_classification.py
+++ b/cogdl/tasks/graph_classification.py
@@ -79,7 +79,6 @@ def __init__(self, args, dataset=None, model=None):
self.device = "cpu" if not torch.cuda.is_available() or args.cpu else args.device_id[0]
if args.dataset.startswith("ogbg"):
- self.data = dataset.data
self.train_loader, self.val_loader, self.test_loader = dataset.get_loader(args)
model = build_model(args) if model is None else model
else:
diff --git a/cogdl/tasks/node_classification.py b/cogdl/tasks/node_classification.py
index db27aff1..d2ffefeb 100644
--- a/cogdl/tasks/node_classification.py
+++ b/cogdl/tasks/node_classification.py
@@ -23,6 +23,7 @@ class NodeClassification(BaseTask):
def add_args(parser: argparse.ArgumentParser):
"""Add task-specific arguments to the parser."""
# fmt: off
+ parser.add_argument("--missing-rate", type=int, default=0, help="missing rate, from 0 to 100")
# fmt: on
def __init__(
diff --git a/tests/datasets/test_ogb.py b/tests/datasets/test_ogb.py
index 6886e4e8..51a3997c 100644
--- a/tests/datasets/test_ogb.py
+++ b/tests/datasets/test_ogb.py
@@ -8,17 +8,16 @@ def test_ogbn_arxiv():
dataset = build_dataset(args)
data = dataset.data
assert data.num_nodes == 169343
- assert data.num_edges == 2315598
+ assert data.num_edges == 1136420
def test_ogbg_molhiv():
args = build_args_from_dict({"dataset": "ogbg-molhiv"})
assert args.dataset == "ogbg-molhiv"
dataset = build_dataset(args)
- data = dataset.data
- assert data.edge_index.shape[1] == 2259376
- assert data.x.shape[0] == 1049163
- assert data.y.shape[0] == 41127
+ assert dataset.all_edges == 2259376
+ assert dataset.all_nodes == 1049163
+ assert len(dataset.graphs) == 41127
if __name__ == "__main__":
diff --git a/tests/tasks/test_node_classification.py b/tests/tasks/test_node_classification.py
index 36412086..2881a8fb 100644
--- a/tests/tasks/test_node_classification.py
+++ b/tests/tasks/test_node_classification.py
@@ -66,13 +66,22 @@ def test_gat_cora():
args.dataset = "cora"
args.model = "gat"
args.alpha = 0.2
- args.nheads = 8
+ args.nhead = 8
+ args.residual = False
+ args.last_nhead = 2
+ args.num_layers = 2
for i in [True, False]:
args.fast_mode = i
task = build_task(args)
ret = task.train()
assert 0 <= ret["Acc"] <= 1
+ args.num_layers = 3
+ args.residual = True
+ task = build_task(args)
+ ret = task.train()
+ assert 0 <= ret["Acc"] <= 1
+
def test_mlp_pubmed():
args = get_default_args()
From fc443537c2f3be8e7a29a3f28733535e120dbf87 Mon Sep 17 00:00:00 2001
From: yyhslyz <164326235@qq.com>
Date: Mon, 1 Feb 2021 20:14:22 +0800
Subject: [PATCH 2/2] fix typo
---
cogdl/tasks/node_classification.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/cogdl/tasks/node_classification.py b/cogdl/tasks/node_classification.py
index d2ffefeb..db27aff1 100644
--- a/cogdl/tasks/node_classification.py
+++ b/cogdl/tasks/node_classification.py
@@ -23,7 +23,6 @@ class NodeClassification(BaseTask):
def add_args(parser: argparse.ArgumentParser):
"""Add task-specific arguments to the parser."""
# fmt: off
- parser.add_argument("--missing-rate", type=int, default=0, help="missing rate, from 0 to 100")
# fmt: on
def __init__(