-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathlearn.py
143 lines (119 loc) · 3.76 KB
/
learn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import argparse
from typing import Dict
import os
import torch
from torch import optim
# Project-local modules: dataset loading, KBC factorization models,
# regularization terms, and the training-loop wrapper.
from datasets import Dataset
from models import CP, ComplEx
from regularizers import F2, N3
from optimizers import KBCOptimizer
# Restrict training to the first GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Enable cuDNN and let it autotune kernels; benchmark mode helps when
# input sizes are fixed across iterations (as with a fixed batch size).
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# Supported datasets; the FB15K_<n> entries are presumably FB15K subsets at
# different sampling percentages — confirm against the data directory.
big_datasets = ['FB15K', 'WN', 'WN18RR', 'FB237', 'YAGO3-10','FB15K_1','FB15K_3','FB15K_5','FB15K_7','FB15K_15','FB15K_10','FB15K_20','FB15K_40','FB15K_60','FB15K_80','FB15K_100']
datasets = big_datasets

parser = argparse.ArgumentParser(
    description="Relational learning contraption"
)
parser.add_argument(
    '--dataset', choices=datasets,
    help="Dataset in {}".format(datasets)
)
models = ['CP', 'ComplEx']
parser.add_argument(
    '--model', choices=models,
    help="Model in {}".format(models)
)
regularizers = ['N3', 'F2']
parser.add_argument(
    '--regularizer', choices=regularizers, default='N3',
    help="Regularizer in {}".format(regularizers)
)
optimizers = ['Adagrad', 'Adam', 'SGD']
parser.add_argument(
    '--optimizer', choices=optimizers, default='Adagrad',
    help="Optimizer in {}".format(optimizers)
)
parser.add_argument(
    '--max_epochs', default=50, type=int,
    help="Number of epochs."
)
parser.add_argument(
    '--valid', default=3, type=float,
    help="Number of epochs before valid."
)
parser.add_argument(
    '--rank', default=1000, type=int,
    help="Factorization rank."
)
# BUG FIX: help text was a copy-paste of --rank's ("Factorization rank.").
parser.add_argument(
    '--batch_size', default=1000, type=int,
    help="Training batch size."
)
parser.add_argument(
    '--reg', default=0, type=float,
    help="Regularization weight"
)
parser.add_argument(
    '--init', default=1e-3, type=float,
    help="Initial scale"
)
parser.add_argument(
    '--learning_rate', default=1e-1, type=float,
    help="Learning rate"
)
parser.add_argument(
    '--decay1', default=0.9, type=float,
    help="decay rate for the first moment estimate in Adam"
)
parser.add_argument(
    '--decay2', default=0.999, type=float,
    help="decay rate for second moment estimate in Adam"
)
# Parse command-line arguments from sys.argv.
args = parser.parse_args()
# Load the dataset and the training triples as an int64 tensor of
# (lhs, rel, rhs) ids.
dataset = Dataset(args.dataset)
examples = torch.from_numpy(dataset.get_train().astype('int64'))
print(dataset.get_shape())

# Lazily construct only the selected model.
model = {
    'CP': lambda: CP(dataset.get_shape(), args.rank, args.init),
    'ComplEx': lambda: ComplEx(dataset.get_shape(), args.rank, args.init),
}[args.model]()

# CONSISTENCY FIX: construct only the selected regularizer, mirroring the
# lazy model dispatch above (previously both F2 and N3 were built eagerly).
regularizer = {
    'F2': lambda: F2(args.reg),
    'N3': lambda: N3(args.reg),
}[args.regularizer]()

device = 'cuda'  # NOTE(review): assumes a CUDA device is available
model.to(device)

# Lazily construct only the selected torch optimizer; Adam additionally
# uses the --decay1/--decay2 moment-estimate decay rates.
optim_method = {
    'Adagrad': lambda: optim.Adagrad(model.parameters(), lr=args.learning_rate),
    'Adam': lambda: optim.Adam(model.parameters(), lr=args.learning_rate, betas=(args.decay1, args.decay2)),
    'SGD': lambda: optim.SGD(model.parameters(), lr=args.learning_rate)
}[args.optimizer]()

# Wrap model/regularizer/optimizer into the project's training-loop driver.
optimizer = KBCOptimizer(model, regularizer, optim_method, args.batch_size)
def avg_both(mrrs: Dict[str, float], hits: Dict[str, torch.FloatTensor]):
    """Average the metrics obtained when predicting the lhs and the rhs.

    :param mrrs: per-side mean reciprocal ranks, keyed by 'lhs' and 'rhs'
    :param hits: per-side hits@[1,3,10] tensors, keyed by 'lhs' and 'rhs'
    :return: dict with the averaged 'MRR' scalar and 'hits@[1,3,10]' tensor
    """
    mean_mrr = 0.5 * (mrrs['lhs'] + mrrs['rhs'])
    mean_hits = 0.5 * (hits['lhs'] + hits['rhs'])
    return {'MRR': mean_mrr, 'hits@[1,3,10]': mean_hits}
# Most recent epoch loss and the per-split metric history over training.
cur_loss = 0
curve = {'train': [], 'valid': [], 'test': []}

for epoch in range(args.max_epochs):
    cur_loss = optimizer.epoch(examples)

    # Periodically evaluate on all splits (every `--valid` epochs).
    if (epoch + 1) % args.valid == 0:
        metrics = {}
        for split in ['valid', 'test', 'train']:
            # Cap train-split evaluation at 50000 queries; -1 means "all".
            n_queries = -1 if split != 'train' else 50000
            metrics[split] = avg_both(*dataset.eval(model, split, n_queries))
            curve[split].append(metrics[split])

        print("\t TRAIN: ", metrics['train'])
        print("\t TEST : ", metrics['test'])
        print("\t VALID : ", metrics['valid'])

# Final, full evaluation on the test split.
results = dataset.eval(model, 'test', -1)
print("\n\nTEST : ", results)