-
Notifications
You must be signed in to change notification settings - Fork 52
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add KaNCD model * delete useless codes in KaNCD.py * modify default learning rate in KaNCD.py * add example files for KaNCD * add doc for KaNCD * update README.md * add test files * delete blank line at the end of file
- Loading branch information
1 parent
f7e4916
commit 79f300c
Showing
13 changed files
with
736 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
# coding: utf-8 | ||
# 2023/7/3 @ WangFei | ||
|
||
import logging | ||
import torch | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
import torch.nn.functional as F | ||
import numpy as np | ||
from tqdm import tqdm | ||
from sklearn.metrics import roc_auc_score, accuracy_score | ||
from EduCDM import CDM | ||
|
||
|
||
class PosLinear(nn.Linear): | ||
def forward(self, input: torch.Tensor) -> torch.Tensor: | ||
weight = 2 * F.relu(1 * torch.neg(self.weight)) + self.weight | ||
return F.linear(input, weight, self.bias) | ||
|
||
|
||
class Net(nn.Module): | ||
|
||
def __init__(self, exer_n, student_n, knowledge_n, mf_type, dim): | ||
self.knowledge_n = knowledge_n | ||
self.exer_n = exer_n | ||
self.student_n = student_n | ||
self.emb_dim = dim | ||
self.mf_type = mf_type | ||
self.prednet_input_len = self.knowledge_n | ||
self.prednet_len1, self.prednet_len2 = 256, 128 # changeable | ||
|
||
super(Net, self).__init__() | ||
|
||
# prediction sub-net | ||
self.student_emb = nn.Embedding(self.student_n, self.emb_dim) | ||
self.exercise_emb = nn.Embedding(self.exer_n, self.emb_dim) | ||
self.knowledge_emb = nn.Parameter(torch.zeros(self.knowledge_n, self.emb_dim)) | ||
self.e_discrimination = nn.Embedding(self.exer_n, 1) | ||
self.prednet_full1 = PosLinear(self.prednet_input_len, self.prednet_len1) | ||
self.drop_1 = nn.Dropout(p=0.5) | ||
self.prednet_full2 = PosLinear(self.prednet_len1, self.prednet_len2) | ||
self.drop_2 = nn.Dropout(p=0.5) | ||
self.prednet_full3 = PosLinear(self.prednet_len2, 1) | ||
|
||
if mf_type == 'gmf': | ||
self.k_diff_full = nn.Linear(self.emb_dim, 1) | ||
self.stat_full = nn.Linear(self.emb_dim, 1) | ||
elif mf_type == 'ncf1': | ||
self.k_diff_full = nn.Linear(2 * self.emb_dim, 1) | ||
self.stat_full = nn.Linear(2 * self.emb_dim, 1) | ||
elif mf_type == 'ncf2': | ||
self.k_diff_full1 = nn.Linear(2 * self.emb_dim, self.emb_dim) | ||
self.k_diff_full2 = nn.Linear(self.emb_dim, 1) | ||
self.stat_full1 = nn.Linear(2 * self.emb_dim, self.emb_dim) | ||
self.stat_full2 = nn.Linear(self.emb_dim, 1) | ||
|
||
# initialize | ||
for name, param in self.named_parameters(): | ||
if 'weight' in name: | ||
nn.init.xavier_normal_(param) | ||
nn.init.xavier_normal_(self.knowledge_emb) | ||
|
||
def forward(self, stu_id, input_exercise, input_knowledge_point): | ||
# before prednet | ||
stu_emb = self.student_emb(stu_id) | ||
exer_emb = self.exercise_emb(input_exercise) | ||
# get knowledge proficiency | ||
batch, dim = stu_emb.size() | ||
stu_emb = stu_emb.view(batch, 1, dim).repeat(1, self.knowledge_n, 1) | ||
knowledge_emb = self.knowledge_emb.repeat(batch, 1).view(batch, self.knowledge_n, -1) | ||
if self.mf_type == 'mf': # simply inner product | ||
stat_emb = torch.sigmoid((stu_emb * knowledge_emb).sum(dim=-1, keepdim=False)) # batch, knowledge_n | ||
elif self.mf_type == 'gmf': | ||
stat_emb = torch.sigmoid(self.stat_full(stu_emb * knowledge_emb)).view(batch, -1) | ||
elif self.mf_type == 'ncf1': | ||
stat_emb = torch.sigmoid(self.stat_full(torch.cat((stu_emb, knowledge_emb), dim=-1))).view(batch, -1) | ||
elif self.mf_type == 'ncf2': | ||
stat_emb = torch.sigmoid(self.stat_full1(torch.cat((stu_emb, knowledge_emb), dim=-1))) | ||
stat_emb = torch.sigmoid(self.stat_full2(stat_emb)).view(batch, -1) | ||
batch, dim = exer_emb.size() | ||
exer_emb = exer_emb.view(batch, 1, dim).repeat(1, self.knowledge_n, 1) | ||
if self.mf_type == 'mf': | ||
k_difficulty = torch.sigmoid((exer_emb * knowledge_emb).sum(dim=-1, keepdim=False)) # batch, knowledge_n | ||
elif self.mf_type == 'gmf': | ||
k_difficulty = torch.sigmoid(self.k_diff_full(exer_emb * knowledge_emb)).view(batch, -1) | ||
elif self.mf_type == 'ncf1': | ||
k_difficulty = torch.sigmoid(self.k_diff_full(torch.cat((exer_emb, knowledge_emb), dim=-1))).view(batch, -1) | ||
elif self.mf_type == 'ncf2': | ||
k_difficulty = torch.sigmoid(self.k_diff_full1(torch.cat((exer_emb, knowledge_emb), dim=-1))) | ||
k_difficulty = torch.sigmoid(self.k_diff_full2(k_difficulty)).view(batch, -1) | ||
# get exercise discrimination | ||
e_discrimination = torch.sigmoid(self.e_discrimination(input_exercise)) | ||
|
||
# prednet | ||
input_x = e_discrimination * (stat_emb - k_difficulty) * input_knowledge_point | ||
# f = input_x[input_knowledge_point == 1] | ||
input_x = self.drop_1(torch.tanh(self.prednet_full1(input_x))) | ||
input_x = self.drop_2(torch.tanh(self.prednet_full2(input_x))) | ||
output_1 = torch.sigmoid(self.prednet_full3(input_x)) | ||
|
||
return output_1.view(-1) | ||
|
||
|
||
class KaNCD(CDM): | ||
def __init__(self, **kwargs): | ||
super(KaNCD, self).__init__() | ||
mf_type = kwargs['mf_type'] if 'mf_type' in kwargs else 'gmf' | ||
self.net = Net(kwargs['exer_n'], kwargs['student_n'], kwargs['knowledge_n'], mf_type, kwargs['dim']) | ||
|
||
def train(self, train_set, valid_set, lr=0.002, device='cpu', epoch_n=15): | ||
logging.info("traing... (lr={})".format(lr)) | ||
self.net = self.net.to(device) | ||
loss_function = nn.BCELoss() | ||
optimizer = optim.Adam(self.net.parameters(), lr=lr) | ||
for epoch_i in range(epoch_n): | ||
self.net.train() | ||
epoch_losses = [] | ||
batch_count = 0 | ||
for batch_data in tqdm(train_set, "Epoch %s" % epoch_i): | ||
batch_count += 1 | ||
user_info, item_info, knowledge_emb, y = batch_data | ||
user_info: torch.Tensor = user_info.to(device) | ||
item_info: torch.Tensor = item_info.to(device) | ||
knowledge_emb: torch.Tensor = knowledge_emb.to(device) | ||
y: torch.Tensor = y.to(device) | ||
pred = self.net(user_info, item_info, knowledge_emb) | ||
loss = loss_function(pred, y) | ||
optimizer.zero_grad() | ||
loss.backward() | ||
optimizer.step() | ||
|
||
epoch_losses.append(loss.mean().item()) | ||
|
||
print("[Epoch %d] average loss: %.6f" % (epoch_i, float(np.mean(epoch_losses)))) | ||
logging.info("[Epoch %d] average loss: %.6f" % (epoch_i, float(np.mean(epoch_losses)))) | ||
auc, acc = self.eval(valid_set, device) | ||
print("[Epoch %d] auc: %.6f, acc: %.6f" % (epoch_i, auc, acc)) | ||
logging.info("[Epoch %d] auc: %.6f, acc: %.6f" % (epoch_i, auc, acc)) | ||
|
||
return auc, acc | ||
|
||
def eval(self, test_data, device="cpu"): | ||
logging.info('eval ... ') | ||
self.net = self.net.to(device) | ||
self.net.eval() | ||
y_true, y_pred = [], [] | ||
for batch_data in tqdm(test_data, "Evaluating"): | ||
user_id, item_id, knowledge_emb, y = batch_data | ||
user_id: torch.Tensor = user_id.to(device) | ||
item_id: torch.Tensor = item_id.to(device) | ||
knowledge_emb: torch.Tensor = knowledge_emb.to(device) | ||
pred = self.net(user_id, item_id, knowledge_emb) | ||
y_pred.extend(pred.detach().cpu().tolist()) | ||
y_true.extend(y.tolist()) | ||
|
||
return roc_auc_score(y_true, y_pred), accuracy_score(y_true, np.array(y_pred) >= 0.5) | ||
|
||
def save(self, filepath): | ||
torch.save(self.net.state_dict(), filepath) | ||
logging.info("save parameters to %s" % filepath) | ||
|
||
def load(self, filepath): | ||
self.net.load_state_dict(torch.load(filepath, map_location=lambda s, loc: s)) | ||
logging.info("load parameters from %s" % filepath) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# coding: utf-8 | ||
# 2021/4/1 @ WangFei | ||
|
||
from .KaNCD import KaNCD |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,3 +9,4 @@ | |
from .NCDM import NCDM | ||
from .IRT import EMIRT, GDIRT | ||
from .MIRT import MIRT | ||
from .KaNCD import KaNCD |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# KaNCD | ||
|
||
The implementation of the KaNCD model in paper: [NeuralCD: A General Framework for Cognitive Diagnosis](https://ieeexplore.ieee.org/abstract/document/9865139) | ||
|
||
KaNCD is an **K**nowledge-**a**ssociation based extension of the **N**eural**CD**M (alias NCDM in this package) model. In KaNCD, higher-order low dimensional latent traits of students, exercises and knowledge concepts are used respectively. | ||
|
||
The knowledge difficulty vector of an exercise is calculated from the latent trait of the exercise and the latent trait of each knowledge concept. | ||
|
||
![KDM_MF](F:\git_project\EduCDM\EduCDM\docs\_static\KDM_MF.png) | ||
|
||
Similarly, the knowledge proficiency vector of a student is calculated from the latent trait of the student and the latent trait of each knowledge concept. | ||
|
||
![KPM_MF](F:\git_project\EduCDM\EduCDM\docs\_static\KPM_MF.png) | ||
|
||
Please refer to the paper for more details. |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Oops, something went wrong.