-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_teacher.py
85 lines (73 loc) · 3.21 KB
/
main_teacher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
import time
import torch
import argparse
import numpy as np
from sas import *
from utils import *
import pickle
def trans_to_cuda(variable):
    """Return *variable* moved to the GPU if CUDA is available, otherwise unchanged.

    On machines without CUDA, ``.cuda()`` is never called, so *variable* need not
    provide that method there.
    """
    use_gpu = torch.cuda.is_available()
    return variable.cuda() if use_gpu else variable
def trans_to_cpu(variable):
    """Return *variable* moved to host memory when CUDA is available, else unchanged.

    When CUDA is not available, ``.cpu()`` is skipped entirely and the very same
    object is returned.
    """
    if not torch.cuda.is_available():
        return variable
    return variable.cpu()
# Command-line options, parsed once at import time into the module-level ``opt``.
# main() reads opt.dataset and opt.num_epochs directly; the remaining options
# only reach the model/training code through ``opt`` (which is passed to SAS and
# train_test), so their precise meaning is defined there.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='Tmall',
                    help='dataset folder under ../datasets/ holding pickled train.txt and test.txt')
parser.add_argument('--act', default='gelu')
parser.add_argument('--batch_size', default=100, type=int)
parser.add_argument('--lr', default=0.001, type=float)
parser.add_argument('--num_layer', default=1, type=int)
parser.add_argument('--maxlen', default=50, type=int)
parser.add_argument('--hidden_units', default=128, type=int)
parser.add_argument('--inner_units', default=128, type=int)
parser.add_argument('--num_blocks', default=3, type=int)
parser.add_argument('--num_epochs', default=201, type=int,
                    help='number of train/test epochs run by main()')
parser.add_argument('--num_heads', default=1, type=int)
parser.add_argument('--dropout_rate', default=0.5, type=float)
parser.add_argument('--l2_emb', default=0.0, type=float)
# NOTE(review): main() loops over opt.num_epochs, not opt.epoch. --epoch is kept
# only because opt is forwarded to SAS/train_test, which might read it; verify,
# and remove the option if nothing does.
parser.add_argument('--epoch', type=int, default=30,
                    help='not the training length (main() runs --num_epochs epochs); only forwarded via opt')
opt = parser.parse_args()
print(opt)
def _load_pickle(path):
    """Unpickle and return the object stored in *path*, closing the file afterwards.

    NOTE(review): ``pickle.load`` can execute arbitrary code while loading; only
    point --dataset at trusted, locally prepared files.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


# Per-dataset ``n_node`` values handed to Data and SAS (presumably the number of
# item ids incl. reserved ones — TODO confirm). Any other dataset name falls back
# to _DEFAULT_N_NODE. The "+ 2" / "+ 4" offsets are kept verbatim.
_N_NODE_BY_DATASET = {
    'Tmall': 40727 + 2,
    'retailrocket': 36968 + 4,
}
_DEFAULT_N_NODE = 309


def main():
    """Train the SAS model on ``opt.dataset`` and print the best scores per epoch.

    Loads the pickled ``train.txt`` / ``test.txt`` from ``../datasets/<opt.dataset>/``,
    wraps them in ``Data``, builds ``SAS`` (moved to the GPU when CUDA is available)
    and runs ``opt.num_epochs`` rounds of ``train_test``. After every epoch, the
    ``hit%d`` / ``mrr%d`` entries returned by ``train_test`` are averaged with
    ``np.mean`` and scaled to percentages. The epoch's loss is then printed, together
    with the best Recall@K and MRR@K seen so far (and the epoch of each), for K in
    5, 10 and 20.
    """
    data_dir = os.path.join('..', 'datasets', opt.dataset)
    train_data = _load_pickle(os.path.join(data_dir, 'train.txt'))
    test_data = _load_pickle(os.path.join(data_dir, 'test.txt'))
    n_node = _N_NODE_BY_DATASET.get(opt.dataset, _DEFAULT_N_NODE)

    train_data = Data(train_data, train_data, shuffle=True, n_node=n_node, train=1)
    # NOTE(review): ``train_data`` was rebound to the Data wrapper on the line
    # above, so the test split receives a Data instance as its second argument,
    # while the training split received the raw pickled data. Confirm which of
    # the two Data(..., train=0) expects before changing this.
    test_data = Data(test_data, train_data, shuffle=False, n_node=n_node, train=0)
    model = trans_to_cuda(SAS(n_node, opt))

    top_K = [5, 10, 20]
    # best_results['metric%d' % K] -> [best Recall@K, best MRR@K]
    # best_results['epoch%d' % K]  -> [epoch of best Recall@K, epoch of best MRR@K]
    best_results = {}
    for K in top_K:
        best_results['epoch%d' % K] = [0, 0]
        best_results['metric%d' % K] = [0, 0]

    for epoch in range(opt.num_epochs):
        print('-------------------------------------------------------')
        print('epoch: ', epoch)
        metrics, total_loss = train_test(model, train_data, test_data, epoch, opt)
        # To checkpoint the model: torch.save(model.state_dict(), '../sas_teacher_rr_nn.pkl')
        for K in top_K:
            hit = np.mean(metrics['hit%d' % K]) * 100
            mrr = np.mean(metrics['mrr%d' % K]) * 100
            # Overwrite the raw entries so print(metrics) shows the averages.
            metrics['hit%d' % K] = hit
            metrics['mrr%d' % K] = mrr
            best_metric = best_results['metric%d' % K]
            best_epoch = best_results['epoch%d' % K]
            # Strict '<': on ties the earliest epoch is kept.
            if best_metric[0] < hit:
                best_metric[0] = hit
                best_epoch[0] = epoch
            if best_metric[1] < mrr:
                best_metric[1] = mrr
                best_epoch[1] = epoch
        print(metrics)
        for K in top_K:
            best_metric = best_results['metric%d' % K]
            best_epoch = best_results['epoch%d' % K]
            print('train_loss:\t%.4f\tRecall@%d: %.4f\tMRR%d: %.4f\tEpoch: %d, %d' %
                  (total_loss, K, best_metric[0], K, best_metric[1],
                   best_epoch[0], best_epoch[1]))
# Script entry point: training runs only when this file is executed directly.
# (Command-line parsing happens at module level, so it also runs on import.)
if __name__ == '__main__':
    main()