-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
64 lines (57 loc) · 3.09 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import torch.nn as nn
import torch_geometric
from torch_geometric.datasets import Planetoid, Amazon, Coauthor, CoraFull
class LoadData(nn.Module):
    """Load a PyTorch Geometric node-classification dataset by name and split its nodes.

    The class inherits ``nn.Module`` (kept for backward compatibility), although it
    registers no parameters or submodules of its own.

    Args:
        root: directory forwarded to the dataset constructor as ``root``.
        name: case-insensitive dataset key; one of ``cora``, ``citeseer``,
            ``pubmed``, ``corafull``, ``computers``, ``photo``, ``cs``, ``physics``.
        pre_transform: forwarded unchanged to the dataset constructor.
        transform: forwarded unchanged to the dataset constructor.
    """

    def __init__(self, root, name, pre_transform=None, transform=None):
        super().__init__()
        self.root = root
        self.name = name.lower()
        self.pre_transform = pre_transform
        self.transform = transform
        # Populated by load(); split() loads lazily if this is still None.
        self.dataset = None

    def load(self):
        """Instantiate the dataset selected by ``self.name``.

        The dataset is stored on ``self.dataset`` and also returned.

        Raises:
            ValueError: if ``self.name`` is not a supported dataset key.
        """
        # Lower-cased key -> (dataset class, value for its ``name`` argument).
        # CoraFull's constructor takes no ``name`` argument, hence ``None``.
        # Built here rather than at class level so that the dataset classes
        # are resolved only when load() actually runs.
        registry = {
            "cora": (Planetoid, "Cora"),
            "citeseer": (Planetoid, "CiteSeer"),
            "pubmed": (Planetoid, "PubMed"),
            "corafull": (CoraFull, None),
            "computers": (Amazon, "Computers"),
            "photo": (Amazon, "Photo"),
            "cs": (Coauthor, "CS"),
            "physics": (Coauthor, "Physics"),
        }
        if self.name not in registry:
            raise ValueError(f"{self.name} dataset is not included")
        dataset_cls, pyg_name = registry[self.name]
        kwargs = {"pre_transform": self.pre_transform, "transform": self.transform}
        if pyg_name is not None:
            kwargs["name"] = pyg_name
        self.dataset = dataset_cls(root=self.root, **kwargs)
        return self.dataset

    def split(self, split_type="random", num_train_per_class=20, num_val=500,
              num_test=1000, *, generator=None):
        """Build boolean train/validation/test node masks for graph 0 of the dataset.

        Args:
            split_type: ``"public"`` returns the graph's own ``train_mask``,
                ``val_mask`` and ``test_mask`` when it has them. Any other value
                produces a random split. If ``"public"`` is requested but the graph
                has no masks, a warning is emitted and a random split is used.
            num_train_per_class: maximum number of training nodes sampled from
                each class; a class with fewer nodes contributes all of them.
            num_val: number of validation nodes, drawn from the non-training
                nodes without regard to class.
            num_test: number of test nodes, taken from the shuffled non-training
                nodes that follow the validation slice.
            generator: optional ``torch.Generator`` used for every permutation,
                which makes a random split reproducible. ``None`` uses torch's
                global RNG, exactly as before.

        Returns:
            Tuple ``(train_mask, val_mask, test_mask)``. For the random split
            these are bool tensors created with ``zeros_like(data.y)``.
        """
        import warnings

        # Calling split() before load() used to fail with an AttributeError.
        if getattr(self, "dataset", None) is None:
            self.load()

        # NOTE(review): in PyG, transforms are applied by ``dataset[idx]``,
        # not by ``get(idx)``. If so, these masks index the untransformed
        # graph. Confirm this matches how callers read the graph.
        data = self.dataset.get(0)

        if split_type == "public":
            if hasattr(data, "train_mask"):
                return (data.train_mask, data.val_mask, data.test_mask)
            warnings.warn(
                f"{self.name} has no public split; falling back to a random split.",
                stacklevel=2,
            )

        train_mask = torch.zeros_like(data.y, dtype=torch.bool)
        val_mask = torch.zeros_like(data.y, dtype=torch.bool)
        test_mask = torch.zeros_like(data.y, dtype=torch.bool)

        # Stratified training set: up to num_train_per_class random nodes per class.
        for c in range(self.dataset.num_classes):
            idx = (data.y == c).nonzero(as_tuple=False).view(-1)
            perm = torch.randperm(idx.size(0), generator=generator)
            train_mask[idx[perm[:num_train_per_class]]] = True

        # Validation and test come from one shuffle of all non-training nodes,
        # so the two sets are disjoint (not stratified by class).
        remaining = (~train_mask).nonzero(as_tuple=False).view(-1)
        remaining = remaining[torch.randperm(remaining.size(0), generator=generator)]
        val_mask[remaining[:num_val]] = True
        test_mask[remaining[num_val:num_val + num_test]] = True
        return (train_mask, val_mask, test_mask)