# coding: utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class KT_backbone(nn.Module):
    def __init__(self, skill_dim, answer_dim, hidden_dim, output_dim):
        super(KT_backbone, self).__init__()
        self.skill_dim = skill_dim
        self.answer_dim = answer_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.rnn = nn.LSTM(self.skill_dim + self.answer_dim, self.hidden_dim, batch_first=True)
        self.fc = nn.Linear(self.hidden_dim * 2, self.output_dim)
        self.sig = nn.Sigmoid()

        # The last embedding row of each table is reserved for padding and kept at zero.
        self.skill_emb = nn.Embedding(self.output_dim + 1, self.skill_dim)
        self.skill_emb.weight.data[-1] = 0
        self.answer_emb = nn.Embedding(2 + 1, self.answer_dim)
        self.answer_emb.weight.data[-1] = 0

        self.attention_dim = 80
        self.mlp = nn.Linear(self.hidden_dim, self.attention_dim)
        self.similarity = nn.Linear(self.attention_dim, 1, bias=False)

    def _get_next_pred(self, res, skill):
        # Pick out, for each step t, the predicted probability of the skill answered at step t+1.
        one_hot = torch.eye(self.output_dim, device=res.device)
        one_hot = torch.cat((one_hot, torch.zeros(1, self.output_dim, device=res.device)), dim=0)
        next_skill = skill[:, 1:]
        one_hot_skill = F.embedding(next_skill, one_hot)
        pred = (res * one_hot_skill).sum(dim=-1)
        return pred

    def attention_module(self, lstm_output):
        # Score each time step, accumulate the attention-weighted history up to (but not
        # including) the current step, and concatenate it with the raw LSTM output.
        att_w = self.mlp(lstm_output)
        att_w = torch.tanh(att_w)
        att_w = self.similarity(att_w)

        alphas = nn.Softmax(dim=1)(att_w)
        attn_output = alphas * lstm_output
        attn_output_cum = torch.cumsum(attn_output, dim=1)
        attn_output_cum_1 = attn_output_cum - attn_output

        final_output = torch.cat((attn_output_cum_1, lstm_output), 2)
        return final_output

    def forward(self, skill, answer, perturbation=None):
        skill_embedding = self.skill_emb(skill)
        answer_embedding = self.answer_emb(answer)

        # Order the (skill, answer) concatenation according to the answer's correctness.
        skill_answer = torch.cat((skill_embedding, answer_embedding), 2)
        answer_skill = torch.cat((answer_embedding, skill_embedding), 2)
        answer = answer.unsqueeze(2).expand_as(skill_answer)
        skill_answer_embedding = torch.where(answer == 1, skill_answer, answer_skill)

        # Keep a handle to the interaction embedding so the caller can compute
        # gradients with respect to it (e.g. for adversarial training).
        skill_answer_embedding1 = skill_answer_embedding
        if perturbation is not None:
            skill_answer_embedding += perturbation

        out, _ = self.rnn(skill_answer_embedding)
        out = self.attention_module(out)
        res = self.sig(self.fc(out))
        res = res[:, :-1, :]
        pred_res = self._get_next_pred(res, skill)
        return pred_res, skill_answer_embedding1
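

# A minimal usage sketch. The dimension values, batch size, and random inputs below are
# illustrative assumptions, not values fixed by this file; padding indices are assumed to
# be output_dim for skills and 2 for answers, matching the zero rows set up in __init__.
if __name__ == "__main__":
    num_skills, seq_len, batch_size = 100, 20, 4
    model = KT_backbone(skill_dim=64, answer_dim=32, hidden_dim=80, output_dim=num_skills).to(device)

    # Skill ids lie in [0, num_skills); answers are 0 (wrong) or 1 (correct).
    skill = torch.randint(0, num_skills, (batch_size, seq_len), device=device)
    answer = torch.randint(0, 2, (batch_size, seq_len), device=device)

    pred, features = model(skill, answer)
    # pred: (batch_size, seq_len - 1) probabilities for each next interaction;
    # features: (batch_size, seq_len, skill_dim + answer_dim) interaction embeddings.
    print(pred.shape, features.shape)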