From d2f10e8b303b2853ab3dbce0d615c82631ea5536 Mon Sep 17 00:00:00 2001 From: yyhslyz <164326235@qq.com> Date: Mon, 1 Feb 2021 20:06:15 +0800 Subject: [PATCH 1/2] [Feature] remove ogb's dependency on pyg --- cogdl/configs.py | 30 +++- cogdl/datasets/README.md | 6 +- cogdl/datasets/__init__.py | 20 +-- cogdl/datasets/ogb.py | 195 ++++++++++++++++++++++++ cogdl/datasets/pyg_ogb.py | 153 ------------------- cogdl/match.yml | 2 + cogdl/models/nn/gat.py | 96 ++++++++++-- cogdl/models/nn/gcn.py | 7 +- cogdl/tasks/graph_classification.py | 1 - cogdl/tasks/node_classification.py | 1 + tests/datasets/test_ogb.py | 9 +- tests/tasks/test_node_classification.py | 11 +- 12 files changed, 336 insertions(+), 195 deletions(-) create mode 100644 cogdl/datasets/ogb.py delete mode 100644 cogdl/datasets/pyg_ogb.py diff --git a/cogdl/configs.py b/cogdl/configs.py index 37cb7c51..49b5bd73 100644 --- a/cogdl/configs.py +++ b/cogdl/configs.py @@ -226,6 +226,11 @@ "num_layers": 5, "dropout": 0.0, }, + "nci1": { + "num_layers": 5, + "dropout": 0.3, + "hidden_size": 64, + }, }, "infograph": { "general": { @@ -240,6 +245,14 @@ "imdb-b": {"degree_feature": True}, "imdb-m": {"degree_feature": True}, "collab": {"degree_feature": True}, + "nci1": {"num_layers": 3}, + }, + "sortpool": { + "nci1": { + "dropout": 0.3, + "hidden_size": 64, + "num_layers": 5, + }, }, "patchy_san": { "general": { @@ -253,7 +266,22 @@ "collab": {"degree_feature": True}, }, }, - "unsupervised_graph_classification": {}, + "unsupervised_graph_classification": { + "graph2vec": { + "general": {}, + "nci1": { + "lr": 0.001, + "window_size": 8, + "epoch": 10, + "iteration": 4, + }, + "reddit-b": { + "lr": 0.01, + "degree_feature": True, + "hidden_size": 128, + }, + } + }, "link_prediction": {}, "multiplex_link_prediction": { "gatne": { diff --git a/cogdl/datasets/README.md b/cogdl/datasets/README.md index 623669e9..66fb6d83 100644 --- a/cogdl/datasets/README.md +++ b/cogdl/datasets/README.md @@ -27,7 +27,7 @@ CogDL now supports the following datasets for different tasks: Transductive - Cora + Cora 2,708 5,429 1,433 @@ -123,7 +123,8 @@ CogDL now supports the following datasets for different tasks: -

Network Embedding(Unsupervsed Node classification)

+

Network Embedding(Unsupervised Node classification)

+ @@ -184,7 +185,6 @@ CogDL now supports the following datasets for different tasks:
Dataset
-

Heterogenous Graph

diff --git a/cogdl/datasets/__init__.py b/cogdl/datasets/__init__.py index 716eb1d7..6347040b 100644 --- a/cogdl/datasets/__init__.py +++ b/cogdl/datasets/__init__.py @@ -81,16 +81,16 @@ def build_dataset_from_path(data_path, task): "sigmod_icde": "cogdl.datasets.gcc_data", "usa-airport": "cogdl.datasets.gcc_data", "test_small": "cogdl.datasets.test_data", - "ogbn-arxiv": "cogdl.datasets.pyg_ogb", - "ogbn-products": "cogdl.datasets.pyg_ogb", - "ogbn-proteins": "cogdl.datasets.pyg_ogb", - "ogbn-mag": "cogdl.datasets.pyg_ogb", - "ogbn-papers100M": "cogdl.datasets.pyg_ogb", - "ogbg-molbace": "cogdl.datasets.pyg_ogb", - "ogbg-molhiv": "cogdl.datasets.pyg_ogb", - "ogbg-molpcba": "cogdl.datasets.pyg_ogb", - "ogbg-ppa": "cogdl.datasets.pyg_ogb", - "ogbg-code": "cogdl.datasets.pyg_ogb", + "ogbn-arxiv": "cogdl.datasets.ogb", + "ogbn-products": "cogdl.datasets.ogb", + "ogbn-proteins": "cogdl.datasets.ogb", + "ogbn-mag": "cogdl.datasets.ogb", + "ogbn-papers100M": "cogdl.datasets.ogb", + "ogbg-molbace": "cogdl.datasets.ogb", + "ogbg-molhiv": "cogdl.datasets.ogb", + "ogbg-molpcba": "cogdl.datasets.ogb", + "ogbg-ppa": "cogdl.datasets.ogb", + "ogbg-code": "cogdl.datasets.ogb", "amazon": "cogdl.datasets.gatne", "twitter": "cogdl.datasets.gatne", "youtube": "cogdl.datasets.gatne", diff --git a/cogdl/datasets/ogb.py b/cogdl/datasets/ogb.py new file mode 100644 index 00000000..08e0c284 --- /dev/null +++ b/cogdl/datasets/ogb.py @@ -0,0 +1,195 @@ +import os.path as osp + +import torch + +from ogb.nodeproppred import NodePropPredDataset +from ogb.graphproppred import GraphPropPredDataset + +from . import register_dataset +from cogdl.data import Dataset, Data, DataLoader +from cogdl.utils import cross_entropy_loss, accuracy, remove_self_loops + + +def coalesce(row, col, edge_attr=None): + row = torch.tensor(row) + col = torch.tensor(col) + if edge_attr is not None: + edge_attr = torch.tensor(edge_attr) + num = col.shape[0] + 1 + idx = torch.full((num,), -1, dtype=torch.float) + idx[1:] = row * num + col + mask = idx[1:] > idx[:-1] + + if mask.all(): + return row, col, edge_attr + row = row[mask] + col = col[mask] + if edge_attr is not None: + edge_attr = edge_attr[mask] + return row, col, edge_attr + + +class OGBNDataset(Dataset): + def __init__(self, root, name): + dataset = NodePropPredDataset(name, root) + graph, y = dataset[0] + x = torch.tensor(graph["node_feat"]) + y = torch.tensor(y.squeeze()) + row, col, edge_attr = coalesce(graph["edge_index"][0], graph["edge_index"][1], graph["edge_feat"]) + edge_index = torch.stack([row, col], dim=0) + edge_index, edge_attr = remove_self_loops(edge_index, edge_attr) + row = torch.cat([edge_index[0], edge_index[1]]) + col = torch.cat([edge_index[1], edge_index[0]]) + edge_index = torch.stack([row, col], dim=0) + if edge_attr is not None: + edge_attr = torch.cat([edge_attr, edge_attr], dim=0) + + self.data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y) + self.data.num_nodes = graph["num_nodes"] + assert self.data.num_nodes == self.data.x.shape[0] + + # split + split_index = dataset.get_idx_split() + self.data.train_mask = torch.zeros(self.data.num_nodes, dtype=torch.bool) + self.data.test_mask = torch.zeros(self.data.num_nodes, dtype=torch.bool) + self.data.val_mask = torch.zeros(self.data.num_nodes, dtype=torch.bool) + self.data.train_mask[split_index["train"]] = True + self.data.test_mask[split_index["test"]] = True + self.data.val_mask[split_index["valid"]] = True + + self.transform = None + + def get(self, idx): + assert idx == 0 + return self.data + + def get_loss_fn(self): + return cross_entropy_loss + + def get_evaluator(self): + return accuracy + + +@register_dataset("ogbn-arxiv") +class OGBArxivDataset(OGBNDataset): + def __init__(self): + dataset = "ogbn-arxiv" + path = "data" + super(OGBArxivDataset, self).__init__(path, dataset) + + +@register_dataset("ogbn-products") +class OGBProductsDataset(OGBNDataset): + def __init__(self): + dataset = "ogbn-products" + path = "data" + super(OGBProductsDataset, self).__init__(path, dataset) + + +@register_dataset("ogbn-proteins") +class OGBProteinsDataset(OGBNDataset): + def __init__(self): + dataset = "ogbn-proteins" + path = "data" + super(OGBProteinsDataset, self).__init__(path, dataset) + + +@register_dataset("ogbn-mag") +class OGBMAGDataset(OGBNDataset): + def __init__(self): + dataset = "ogbn-mag" + path = "data" + super(OGBMAGDataset, self).__init__(path, dataset) + + +@register_dataset("ogbn-papers100M") +class OGBPapers100MDataset(OGBNDataset): + def __init__(self): + dataset = "ogbn-papers100M" + path = "data" + super(OGBPapers100MDataset, self).__init__(path, dataset) + + +class OGBGDataset(Dataset): + def __init__(self, root, name): + self.name = name + self.dataset = GraphPropPredDataset(self.name, root) + + self.graphs = [] + self.all_nodes = 0 + self.all_edges = 0 + for i in range(len(self.dataset.graphs)): + graph, label = self.dataset[i] + data = Data( + x=torch.tensor(graph["node_feat"], dtype=torch.float), + edge_index=torch.tensor(graph["edge_index"]), + edge_attr=None if "edge_feat" not in graph else torch.tensor(graph["edge_feat"], dtype=torch.float), + y=torch.tensor(label), + ) + data.num_nodes = graph["num_nodes"] + self.graphs.append(data) + + self.all_nodes += graph["num_nodes"] + self.all_edges += graph["edge_index"].shape[1] + + self.transform = None + + def get_loader(self, args): + split_index = self.dataset.get_idx_split() + train_loader = DataLoader(self.get_subset(split_index["train"]), batch_size=args.batch_size, shuffle=True) + valid_loader = DataLoader(self.get_subset(split_index["valid"]), batch_size=args.batch_size, shuffle=False) + test_loader = DataLoader(self.get_subset(split_index["test"]), batch_size=args.batch_size, shuffle=False) + return train_loader, valid_loader, test_loader + + def get_subset(self, subset): + datalist = [] + for idx in subset: + datalist.append(self.graphs[idx]) + return datalist + + def get(self, idx): + return self.graphs[idx] + + @property + def num_classes(self): + return int(self.dataset.num_classes) + + +@register_dataset("ogbg-molbace") +class OGBMolbaceDataset(OGBGDataset): + def __init__(self): + dataset = "ogbg-molbace" + path = "data" + super(OGBMolbaceDataset, self).__init__(path, dataset) + + +@register_dataset("ogbg-molhiv") +class OGBMolhivDataset(OGBGDataset): + def __init__(self): + dataset = "ogbg-molhiv" + path = "data" + super(OGBMolhivDataset, self).__init__(path, dataset) + + +@register_dataset("ogbg-molpcba") +class OGBMolpcbaDataset(OGBGDataset): + def __init__(self): + dataset = "ogbg-molpcba" + path = "data" + super(OGBMolpcbaDataset, self).__init__(path, dataset) + + +@register_dataset("ogbg-ppa") +class OGBPpaDataset(OGBGDataset): + def __init__(self): + dataset = "ogbg-ppa" + path = "data" + super(OGBPpaDataset, self).__init__(path, dataset) + + +@register_dataset("ogbg-code") +class OGBCodeDataset(OGBGDataset): + def __init__(self): + dataset = "ogbg-code" + path = "data" + super(OGBCodeDataset, self).__init__(path, dataset) diff --git a/cogdl/datasets/pyg_ogb.py b/cogdl/datasets/pyg_ogb.py deleted file mode 100644 index cd7fedc5..00000000 --- a/cogdl/datasets/pyg_ogb.py +++ /dev/null @@ -1,153 +0,0 @@ -import os.path as osp - -import torch -from torch_geometric.data import DataLoader -from torch_sparse import coalesce - -from ogb.nodeproppred import PygNodePropPredDataset -from ogb.graphproppred import PygGraphPropPredDataset - -from . import register_dataset - - -class OGBNDataset(PygNodePropPredDataset): - def __init__(self, root, name): - super(OGBNDataset, self).__init__(name, root) - - self.data.num_nodes = self.data.num_nodes[0] - # split - split_index = self.get_idx_split() - self.data["train_mask"] = torch.zeros(self.data.num_nodes, dtype=torch.bool) - self.data["test_mask"] = torch.zeros(self.data.num_nodes, dtype=torch.bool) - self.data["val_mask"] = torch.zeros(self.data.num_nodes, dtype=torch.bool) - self.data["train_mask"][split_index["train"]] = True - self.data["test_mask"][split_index["test"]] = True - self.data["val_mask"][split_index["valid"]] = True - - self.data.y = self.data.y.squeeze() - - def get(self, idx): - assert idx == 0 - return self.data - - -@register_dataset("ogbn-arxiv") -class OGBArxivDataset(OGBNDataset): - def __init__(self): - dataset = "ogbn-arxiv" - path = osp.join("data", dataset) - if not osp.exists(path): - PygNodePropPredDataset(dataset, path) - super(OGBArxivDataset, self).__init__(path, dataset) - - # to_symmetric - rev_edge_index = self.data.edge_index[[1, 0]] - edge_index = torch.cat([self.data.edge_index, rev_edge_index], dim=1).to(dtype=torch.int64) - self.data.edge_index, self.data.edge_attr = coalesce(edge_index, None, self.data.num_nodes, self.data.num_nodes) - - -@register_dataset("ogbn-products") -class OGBProductsDataset(OGBNDataset): - def __init__(self): - dataset = "ogbn-products" - path = osp.join("data", dataset) - if not osp.exists(path): - PygNodePropPredDataset(dataset, path) - super(OGBProductsDataset, self).__init__(path, dataset) - - -@register_dataset("ogbn-proteins") -class OGBProteinsDataset(OGBNDataset): - def __init__(self): - dataset = "ogbn-proteins" - path = osp.join("data", dataset) - if not osp.exists(path): - PygNodePropPredDataset(dataset, path) - super(OGBProteinsDataset, self).__init__(path, dataset) - - -@register_dataset("ogbn-mag") -class OGBMAGDataset(OGBNDataset): - def __init__(self): - dataset = "ogbn-mag" - path = osp.join("data", dataset) - if not osp.exists(path): - PygNodePropPredDataset(dataset, path) - super(OGBMAGDataset, self).__init__(path, dataset) - - -@register_dataset("ogbn-papers100M") -class OGBPapers100MDataset(OGBNDataset): - def __init__(self): - dataset = "ogbn-papers100M" - path = osp.join("data", dataset) - if not osp.exists(path): - PygNodePropPredDataset(dataset, path) - super(OGBPapers100MDataset, self).__init__(path, dataset) - - -class OGBGDataset(PygGraphPropPredDataset): - def __init__(self, root, name): - super(OGBGDataset, self).__init__(name, root) - self.name = name - - def get_loader(self, args): - split_index = self.get_idx_split() - dataset = PygGraphPropPredDataset(self.name, osp.join("data", self.name)) - train_loader = DataLoader(dataset[split_index["train"]], batch_size=args.batch_size, shuffle=True) - valid_loader = DataLoader(dataset[split_index["valid"]], batch_size=args.batch_size, shuffle=False) - test_loader = DataLoader(dataset[split_index["test"]], batch_size=args.batch_size, shuffle=False) - return train_loader, valid_loader, test_loader - - def get(self, idx): - return self.data - - -@register_dataset("ogbg-molbace") -class OGBMolbaceDataset(OGBGDataset): - def __init__(self): - dataset = "ogbg-molbace" - path = osp.join("data", dataset) - if not osp.exists(path): - PygGraphPropPredDataset(dataset, path) - super(OGBMolbaceDataset, self).__init__(path, dataset) - - -@register_dataset("ogbg-molhiv") -class OGBMolhivDataset(OGBGDataset): - def __init__(self): - dataset = "ogbg-molhiv" - path = osp.join("data", dataset) - if not osp.exists(path): - PygGraphPropPredDataset(dataset, path) - super(OGBMolhivDataset, self).__init__(path, dataset) - - -@register_dataset("ogbg-molpcba") -class OGBMolpcbaDataset(OGBGDataset): - def __init__(self): - dataset = "ogbg-molpcba" - path = osp.join("data", dataset) - if not osp.exists(path): - PygGraphPropPredDataset(dataset, path) - super(OGBMolpcbaDataset, self).__init__(path, dataset) - - -@register_dataset("ogbg-ppa") -class OGBPpaDataset(OGBGDataset): - def __init__(self): - dataset = "ogbg-ppa" - path = osp.join("data", dataset) - if not osp.exists(path): - PygGraphPropPredDataset(dataset, path) - super(OGBPpaDataset, self).__init__(path, dataset) - - -@register_dataset("ogbg-code") -class OGBCodeDataset(OGBGDataset): - def __init__(self): - dataset = "ogbg-code" - path = osp.join("data", dataset) - if not osp.exists(path): - PygGraphPropPredDataset(dataset, path) - super(OGBCodeDataset, self).__init__(path, dataset) diff --git a/cogdl/match.yml b/cogdl/match.yml index eb86f508..8d62c9bd 100644 --- a/cogdl/match.yml +++ b/cogdl/match.yml @@ -96,6 +96,8 @@ graph_classification: - reddit-b - reddit-multi-5k - reddit-multi-12k + - ogbg-molbace + - ogbg-molhiv unsupervised_graph_classification: - model: - infograph diff --git a/cogdl/models/nn/gat.py b/cogdl/models/nn/gat.py index 15267493..5b16fc86 100644 --- a/cogdl/models/nn/gat.py +++ b/cogdl/models/nn/gat.py @@ -12,7 +12,9 @@ class GATLayer(nn.Module): Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903 """ - def __init__(self, in_features, out_features, nhead=1, alpha=0.2, dropout=0.6, concat=True, fast_mode=False): + def __init__( + self, in_features, out_features, nhead=1, alpha=0.2, dropout=0.6, concat=True, residual=False, fast_mode=False + ): super(GATLayer, self).__init__() self.in_features = in_features self.out_features = out_features @@ -28,6 +30,12 @@ def __init__(self, in_features, out_features, nhead=1, alpha=0.2, dropout=0.6, c self.dropout = nn.Dropout(dropout) self.leakyrelu = nn.LeakyReLU(self.alpha) + + if residual: + out_features = out_features * nhead if concat else out_features + self.residual = nn.Linear(in_features, out_features) + else: + self.register_buffer("residual", None) self.reset_parameters() def reset_parameters(self): @@ -84,12 +92,17 @@ def forward(self, x, edge): assert not torch.isnan(hidden).any() h_prime.append(spmm(edge, edge_weight, hidden)) + if self.residual: + res = self.residual(x) + else: + res = 0 + if self.concat: # if this layer is not last layer, - out = torch.cat(h_prime, dim=1) + out = torch.cat(h_prime, dim=1) + res else: # if this layer is last layer, - out = sum(h_prime) / self.nhead + out = sum(h_prime) / self.nhead + res return out def __repr__(self): @@ -115,11 +128,14 @@ def add_args(parser): """Add model-specific arguments to the parser.""" # fmt: off parser.add_argument("--num-features", type=int) + parser.add_argument("--num-layers", type=int, default=2) + parser.add_argument("--residual", action="store_true") parser.add_argument("--num-classes", type=int) parser.add_argument("--hidden-size", type=int, default=8) parser.add_argument("--dropout", type=float, default=0.6) parser.add_argument("--alpha", type=float, default=0.2) - parser.add_argument("--nheads", type=int, default=8) + parser.add_argument("--nhead", type=int, default=8) + parser.add_argument("--last-nhead", type=int, default=1) parser.add_argument("--fast-mode", action="store_true", default=False) # fmt: on @@ -129,31 +145,81 @@ def build_model_from_args(cls, args): args.num_features, args.hidden_size, args.num_classes, + args.num_layers, args.dropout, args.alpha, - args.nheads, + args.nhead, + args.residual, + args.last_nhead, args.fast_mode, ) - def __init__(self, in_feats, hidden_size, out_features, dropout, alpha, nheads, fast_mode=False): + def __init__( + self, + in_feats, + hidden_size, + out_features, + num_layers, + dropout, + alpha, + nhead, + residual, + last_nhead, + fast_mode=False, + ): """Sparse version of GAT.""" super(GAT, self).__init__() self.dropout = dropout - - self.attention = GATLayer( - in_feats, hidden_size, dropout=dropout, alpha=alpha, nhead=nheads, concat=True, fast_mode=fast_mode + self.attentions = nn.ModuleList() + self.attentions.append( + GATLayer( + in_feats, + hidden_size, + nhead=nhead, + dropout=dropout, + alpha=alpha, + concat=True, + residual=residual, + fast_mode=fast_mode, + ) ) - self.out_att = GATLayer( - hidden_size * nheads, out_features, dropout=dropout, alpha=alpha, nhead=1, concat=False, fast_mode=False + for i in range(num_layers - 2): + self.attentions.append( + GATLayer( + hidden_size * nhead, + hidden_size, + nhead=nhead, + dropout=dropout, + alpha=alpha, + concat=True, + residual=residual, + fast_mode=fast_mode, + ) + ) + self.attentions.append( + GATLayer( + hidden_size * nhead, + out_features, + dropout=dropout, + alpha=alpha, + concat=False, + nhead=last_nhead, + residual=False, + fast_mode=fast_mode, + ) ) + self.num_layers = num_layers + self.last_nhead = last_nhead + self.residual = residual def forward(self, x, edge_index): edge_index, _ = add_remaining_self_loops(edge_index) - x = F.dropout(x, p=self.dropout, training=self.training) - x = F.elu(self.attention(x, edge_index)) - x = F.dropout(x, p=self.dropout, training=self.training) - x = F.elu(self.out_att(x, edge_index)) + for i, layer in enumerate(self.attentions): + x = F.dropout(x, p=self.dropout, training=self.training) + x = layer(x, edge_index) + if i != self.num_layers - 1: + x = F.elu(x) return x def predict(self, data): diff --git a/cogdl/models/nn/gcn.py b/cogdl/models/nn/gcn.py index 1ca4409a..00b3f0b2 100644 --- a/cogdl/models/nn/gcn.py +++ b/cogdl/models/nn/gcn.py @@ -34,13 +34,8 @@ def reset_parameters(self): def forward(self, input, edge_index, edge_attr=None): if edge_attr is None: edge_attr = torch.ones(edge_index.shape[1]).float().to(input.device) - adj = torch.sparse_coo_tensor( - edge_index, - edge_attr, - (input.shape[0], input.shape[0]), - ).to(input.device) support = torch.mm(input, self.weight) - output = torch.spmm(adj, support) + output = spmm(edge_index, edge_attr, support) if self.bias is not None: return output + self.bias else: diff --git a/cogdl/tasks/graph_classification.py b/cogdl/tasks/graph_classification.py index 68826a1f..e73f0a73 100644 --- a/cogdl/tasks/graph_classification.py +++ b/cogdl/tasks/graph_classification.py @@ -79,7 +79,6 @@ def __init__(self, args, dataset=None, model=None): self.device = "cpu" if not torch.cuda.is_available() or args.cpu else args.device_id[0] if args.dataset.startswith("ogbg"): - self.data = dataset.data self.train_loader, self.val_loader, self.test_loader = dataset.get_loader(args) model = build_model(args) if model is None else model else: diff --git a/cogdl/tasks/node_classification.py b/cogdl/tasks/node_classification.py index db27aff1..d2ffefeb 100644 --- a/cogdl/tasks/node_classification.py +++ b/cogdl/tasks/node_classification.py @@ -23,6 +23,7 @@ class NodeClassification(BaseTask): def add_args(parser: argparse.ArgumentParser): """Add task-specific arguments to the parser.""" # fmt: off + parser.add_argument("--missing-rate", type=int, default=0, help="missing rate, from 0 to 100") # fmt: on def __init__( diff --git a/tests/datasets/test_ogb.py b/tests/datasets/test_ogb.py index 6886e4e8..51a3997c 100644 --- a/tests/datasets/test_ogb.py +++ b/tests/datasets/test_ogb.py @@ -8,17 +8,16 @@ def test_ogbn_arxiv(): dataset = build_dataset(args) data = dataset.data assert data.num_nodes == 169343 - assert data.num_edges == 2315598 + assert data.num_edges == 1136420 def test_ogbg_molhiv(): args = build_args_from_dict({"dataset": "ogbg-molhiv"}) assert args.dataset == "ogbg-molhiv" dataset = build_dataset(args) - data = dataset.data - assert data.edge_index.shape[1] == 2259376 - assert data.x.shape[0] == 1049163 - assert data.y.shape[0] == 41127 + assert dataset.all_edges == 2259376 + assert dataset.all_nodes == 1049163 + assert len(dataset.graphs) == 41127 if __name__ == "__main__": diff --git a/tests/tasks/test_node_classification.py b/tests/tasks/test_node_classification.py index 36412086..2881a8fb 100644 --- a/tests/tasks/test_node_classification.py +++ b/tests/tasks/test_node_classification.py @@ -66,13 +66,22 @@ def test_gat_cora(): args.dataset = "cora" args.model = "gat" args.alpha = 0.2 - args.nheads = 8 + args.nhead = 8 + args.residual = False + args.last_nhead = 2 + args.num_layers = 2 for i in [True, False]: args.fast_mode = i task = build_task(args) ret = task.train() assert 0 <= ret["Acc"] <= 1 + args.num_layers = 3 + args.residual = True + task = build_task(args) + ret = task.train() + assert 0 <= ret["Acc"] <= 1 + def test_mlp_pubmed(): args = get_default_args() From fc443537c2f3be8e7a29a3f28733535e120dbf87 Mon Sep 17 00:00:00 2001 From: yyhslyz <164326235@qq.com> Date: Mon, 1 Feb 2021 20:14:22 +0800 Subject: [PATCH 2/2] fix typo --- cogdl/tasks/node_classification.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cogdl/tasks/node_classification.py b/cogdl/tasks/node_classification.py index d2ffefeb..db27aff1 100644 --- a/cogdl/tasks/node_classification.py +++ b/cogdl/tasks/node_classification.py @@ -23,7 +23,6 @@ class NodeClassification(BaseTask): def add_args(parser: argparse.ArgumentParser): """Add task-specific arguments to the parser.""" # fmt: off - parser.add_argument("--missing-rate", type=int, default=0, help="missing rate, from 0 to 100") # fmt: on def __init__(