Skip to content

Commit

Permalink
[Dataset] Add ogbn-mag dataset (#226)
Browse files Browse the repository at this point in the history
* Add ogbn-mag and fix bugs

* fix bugs

* fix bugs in __getitem__
  • Loading branch information
THINK2TRY authored Apr 23, 2021
1 parent 5c14b9a commit 1dec0dd
Showing 2 changed files with 214 additions and 15 deletions.
46 changes: 39 additions & 7 deletions cogdl/data/data.py
Original file line number Diff line number Diff line change
@@ -130,13 +130,14 @@ def cuda(self, *keys):


class Adjacency(BaseGraph):
def __init__(self, row=None, col=None, row_ptr=None, weight=None, attr=None, num_nodes=None, **kwargs):
def __init__(self, row=None, col=None, row_ptr=None, weight=None, attr=None, num_nodes=None, types=None, **kwargs):
super(Adjacency, self).__init__()
self.row = row
self.col = col
self.row_ptr = row_ptr
self.weight = weight
self.attr = attr
self.types = types
self.__num_nodes__ = num_nodes
self.__normed__ = None
self.__in_norm__ = self.__out_norm__ = None
@@ -294,12 +295,26 @@ def __out_repr__(self):
]
return info

# def __getitem__(self, item):
# assert type(item) == str, f"{item} must be str"
# if item[0] == "_" and item[1] != "_":
# # item = re.search("[_]*(.*)", item).group(1)
# item = item[1:]
# if item.startswith("edge_") and item != "edge_index":
# item = item[5:]
# return getattr(self, item)

def __getitem__(self, item):
assert type(item) == str, f"{item} must be str"
if item[0] == "_" and item[1] != "_":
# item = re.search("[_]*(.*)", item).group(1)
item = item[1:]
return getattr(self, item)
if item.startswith("edge_") and item != "edge_index":
item = item[5:]
if item in self.__dict__:
return self.__dict__[item]
else:
raise KeyError(f"{item} not in Adjacency")

def __copy__(self):
result = self.__class__()
@@ -337,10 +352,7 @@ def from_dict(dictionary):
return data


# Maps the public edge_* attribute names to the Adjacency field names.
# (The superseded pre-commit KEY_MAP definition was a dead duplicate and is removed.)
KEY_MAP = {"edge_weight": "weight", "edge_attr": "attr", "edge_types": "types"}
EDGE_INDEX = "edge_index"
EDGE_WEIGHT = "edge_weight"
EDGE_ATTR = "edge_attr"
@@ -353,7 +365,7 @@ def is_adj_key_train(key):


def is_adj_key(key):
    """Return True if *key* names an adjacency attribute.

    Matches the raw Adjacency field names (including the new ``types``) or any
    ``edge_``-prefixed alias. The stale pre-commit ``return`` line, which made
    the updated return unreachable, has been removed.
    """
    return key in ["row", "col", "row_ptr", "attr", "weight", "types"] or key.startswith("edge_")


def is_read_adj_key(key):
@@ -465,6 +477,10 @@ def edge_weight(self):
def edge_attr(self):
return self._adj.attr

@property
def edge_types(self):
return self._adj.types

@edge_index.setter
def edge_index(self, edge_index):
row, col = edge_index
@@ -482,6 +498,10 @@ def edge_weight(self, edge_weight):
def edge_attr(self, edge_attr):
self._adj.attr = edge_attr

@edge_types.setter
def edge_types(self, edge_types):
self._adj.types = edge_types

@property
def row_indptr(self):
if self._adj.row_ptr is None:
@@ -531,12 +551,20 @@ def __old_keys__(self):
def __getitem__(self, key):
r"""Gets the data of the attribute :obj:`key`."""
if is_adj_key(key):
if key[0] == "_" and key[1] != "_":
key = key[1:]
if key.startswith("edge_") and key != "edge_index":
key = key[5:]
return getattr(self._adj, key)
else:
return getattr(self, key)

def __setitem__(self, key, value):
if is_adj_key(key):
if key[0] == "_" and key[1] != "_":
key = key[1:]
if key.startswith("edge_") and key != "edge_index":
key = key[5:]
self._adj[key] = value
else:
setattr(self, key, value)
@@ -654,6 +682,10 @@ def csr_subgraph(self, node_idx):
data = Graph(row_ptr=indptr, col=indices, weight=edge_weight)
for key in self.__keys__():
data[key] = self[key][nodes_idx]
for key in self._adj.keys:
if "row" in key or "col" in key:
continue
data._adj[key] = self._adj[key][edges]
return data

def subgraph(self, node_idx):
183 changes: 175 additions & 8 deletions cogdl/datasets/ogb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import os
import numpy as np
import torch

from ogb.nodeproppred import NodePropPredDataset
from ogb.nodeproppred import Evaluator as NodeEvaluator
from ogb.graphproppred import GraphPropPredDataset

from . import register_dataset
@@ -82,6 +85,15 @@ def __init__(self):
path = "data"
super(OGBArxivDataset, self).__init__(path, dataset)

def get_evaluator(self):
evaluator = NodeEvaluator(name="ogbn-arxiv")

def wrap(y_pred, y_true):
input_dict = {"y_true": y_true, "y_pred": y_pred}
return evaluator.eval(input_dict)

return wrap


@register_dataset("ogbn-products")
class OGBProductsDataset(OGBNDataset):
@@ -99,14 +111,6 @@ def __init__(self):
super(OGBProteinsDataset, self).__init__(path, dataset)


@register_dataset("ogbn-mag")
class OGBMAGDataset(OGBNDataset):
def __init__(self):
dataset = "ogbn-mag"
path = "data"
super(OGBMAGDataset, self).__init__(path, dataset)


@register_dataset("ogbn-papers100M")
class OGBPapers100MDataset(OGBNDataset):
def __init__(self):
@@ -115,6 +119,169 @@ def __init__(self):
super(OGBPapers100MDataset, self).__init__(path, dataset)


@register_dataset("ogbn-mag")
class MAGDataset(Dataset):
def __init__(self):
self.name = "ogbn-mag"
name = "_".join(self.name.split("-"))
self.root = "./data/" + name
super(MAGDataset, self).__init__(self.root)
data = torch.load(self.processed_paths[0])
(self.data, self.node_type_dict, self.edge_type_dict, self.num_nodes_dict) = data
self.paper_feat = torch.as_tensor(np.load(self.processed_paths[1]))
self.other_feat = torch.as_tensor(np.load(self.processed_paths[2]))
if self.other_feat.shape[0] == self.data.num_nodes:
self.other_feat = self.other_feat[self.paper_feat.shape[0] :]

def __len__(self):
return 1

def get(self, idx):
assert idx == 0
return self.data

def _download(self):
pass

def process(self):
dataset = NodePropPredDataset(name=self.name, root="./data")
node_type_dict = {"paper": 0, "author": 1, "field_of_study": 2, "institution": 3}
edge_type_dict = {"cites": 0, "affiliated_with": 1, "writes": 2, "has_topic": 3}
num_nodes_dict = dataset[0][0]["num_nodes_dict"]
num_nodes = torch.as_tensor(
[0]
+ [
num_nodes_dict["paper"],
num_nodes_dict["author"],
num_nodes_dict["field_of_study"],
num_nodes_dict["institution"],
]
)
cum_num_nodes = torch.cumsum(num_nodes, dim=-1)
node_types = torch.repeat_interleave(torch.arange(0, 4), num_nodes[1:])

edge_index_dict = dataset[0][0]["edge_index_dict"]

edge_index = [None] * len(edge_type_dict)
edge_attr = [None] * len(edge_type_dict)

i = 0
for k, v in edge_index_dict.items():
head, edge_type, tail = k
head_offset = cum_num_nodes[node_type_dict[head]].item()
tail_offset = cum_num_nodes[node_type_dict[tail]].item()
src = v[0] + head_offset
tgt = v[1] + tail_offset
edge_tps = np.full(src.shape, edge_type_dict[edge_type])

_src = np.concatenate([src, tgt])
_tgt = np.concatenate([tgt, src])
if edge_type == "cites":
re_tps = np.full(src.shape, edge_type_dict[edge_type])
else:
re_tps = np.full(src.shape, len(edge_type_dict))
edge_type_dict[edge_type + "_re"] = len(edge_type_dict)
edge_index[i] = np.vstack([_src, _tgt])
edge_tps = np.concatenate([edge_tps, re_tps])
edge_attr[i] = edge_tps
i += 1
edge_index = np.concatenate(edge_index, axis=-1)
edge_index = torch.from_numpy(edge_index)
edge_attr = torch.from_numpy(np.concatenate(edge_attr))

assert edge_index.shape[1] == edge_attr.shape[0]

split_index = dataset.get_idx_split()
train_index = torch.from_numpy(split_index["train"]["paper"])
val_index = torch.from_numpy(split_index["valid"]["paper"])
test_index = torch.from_numpy(split_index["test"]["paper"])
y = torch.as_tensor(dataset[0][1]["paper"]).view(-1)

paper_feat = dataset[0][0]["node_feat_dict"]["paper"]
data = Graph(
y=y,
edge_index=edge_index,
edge_types=edge_attr,
train_mask=train_index,
val_mask=val_index,
test_mask=test_index,
node_types=node_types,
)
# self.save_edges(data)
torch.save((data, node_type_dict, edge_type_dict, num_nodes_dict), self.processed_paths[0])
np.save(self.processed_paths[1], paper_feat)

def get_evaluator(self):
evaluator = NodeEvaluator(name="ogbn-mag")

def wrap(y_pred, y_true):
y_pred = y_pred.argmax(dim=-1, keepdim=True)
y_true = y_true.view(-1, 1)
input_dict = {"y_true": y_true, "y_pred": y_pred}
return evaluator.eval(input_dict)["acc"]

return wrap

@property
def processed_file_names(self):
return ["data.pt", "paper_feat.npy", "other_feat.npy"]

@property
def num_node_types(self):
return len(self.node_type_dict)

@property
def num_edge_types(self):
return len(self.edge_type_dict)

@property
def num_papers(self):
return self.num_nodes_dict["paper"]

@property
def num_authors(self):
return self.num_nodes_dict["author"]

@property
def num_institutions(self):
return self.num_nodes_dict["institution"]

@property
def num_field_of_study(self):
return self.num_nodes_dict["field_of_study"]

def save_edges(self, data):
edge_index = data.edge_index.numpy().transpose()
edge_types = data.edge_types.numpy()
os.makedirs("./ogbn_mag_kg", exist_ok=True)
with open("./ogbn_mag_kg/train.txt", "w") as f:
for i in range(edge_index.shape[0]):
edge = edge_index[i]
tp = edge_types[i]
f.write(f"{edge[0]}\t{edge[1]}\t{tp}\n")

with open("./ogbn_mag_kg/valid.txt", "w") as f:
val_num = np.random.randint(0, edge_index.shape[0], (10000,))
for i in val_num:
edge = edge_index[i]
tp = edge_types[i]
f.write(f"{edge[0]}\t{edge[1]}\t{tp}\n")

with open("./ogbn_mag_kg/test.txt", "w") as f:
val_num = np.random.randint(0, edge_index.shape[0], (20000,))
for i in val_num:
edge = edge_index[i]
tp = edge_types[i]
f.write(f"{edge[0]}\t{edge[1]}\t{tp}\n")

with open("./ogbn_mag_kg/entities.dict", "w") as f:
for i in range(np.max(edge_index)):
f.write(f"{i}\t{i}\n")
with open("./ogbn_mag_kg/relations.dict", "w") as f:
for i in range(np.max(edge_types)):
f.write(f"{i}\t{i}\n")


class OGBGDataset(Dataset):
def __init__(self, root, name):
super(OGBGDataset, self).__init__(root)

0 comments on commit 1dec0dd

Please sign in to comment.