-
Notifications
You must be signed in to change notification settings - Fork 627
/
Copy pathenmf.py
124 lines (85 loc) · 4.25 KB
/
enmf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# -*- coding: utf-8 -*-
# @Time : 2020/12/31
# @Author : Zihan Lin
# @Email : zhlin@ruc.edu.cn
r"""
ENMF
################################################
Reference:
Chong Chen et al. "Efficient Neural Matrix Factorization without Sampling for Recommendation." in TOIS 2020.
Reference code:
https://github.com/chenchongthu/ENMF
"""
import torch
import torch.nn as nn
from recbole.model.init import xavier_normal_initialization
from recbole.utils import InputType
from recbole.model.abstract_recommender import GeneralRecommender
class ENMF(GeneralRecommender):
    r"""ENMF is an efficient non-sampling model for general recommendation.

    To run a non-sampling model, set the ``neg_sampling`` parameter to ``None``.
    """

    input_type = InputType.POINTWISE

    def __init__(self, config, dataset):
        """Build the embedding tables, the prediction layer and the history matrix.

        Args:
            config: configuration object indexed by ``'embedding_size'``,
                ``'dropout_prob'``, ``'reg_weight'`` and ``'negative_weight'``.
            dataset: dataset object providing ``history_item_matrix()``.
        """
        super(ENMF, self).__init__(config, dataset)

        # hyper-parameters
        self.embedding_size = config['embedding_size']
        self.dropout_prob = config['dropout_prob']
        self.reg_weight = config['reg_weight']
        self.negative_weight = config['negative_weight']

        # get all users' history interaction information.
        # the matrix is padded up to the maximum number of interactions of a user
        self.history_item_matrix, _, self.history_lens = dataset.history_item_matrix()
        self.history_item_matrix = self.history_item_matrix.to(self.device)

        # layers
        self.user_embedding = nn.Embedding(self.n_users, self.embedding_size, padding_idx=0)
        self.item_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0)
        self.H_i = nn.Linear(self.embedding_size, 1, bias=False)
        self.dropout = nn.Dropout(self.dropout_prob)

        # NOTE(review): if this initializer overwrites the whole embedding weight,
        # the ``padding_idx`` rows are no longer zero afterwards -- confirm.
        self.apply(xavier_normal_initialization)

    def reg_loss(self):
        """Calculate the regularization loss of the user and item embedding tables.

        The prediction layer ``H_i`` is not included.

        Returns:
            torch.Tensor: ``reg_weight`` times the sum of the 2-norms of both
            embedding weight matrices.
        """
        l2_reg = self.user_embedding.weight.norm(2) + self.item_embedding.weight.norm(2)
        loss_l2 = self.reg_weight * l2_reg
        return loss_l2

    def forward(self, user):
        """Score every item in the interaction history of each given user.

        Args:
            user (torch.Tensor): user ids, shape: [B]

        Returns:
            torch.Tensor: scores of the history items, shape: [B, max_len]
        """
        user_embedding = self.user_embedding(user)  # shape: [B, embedding_size]
        user_embedding = self.dropout(user_embedding)  # shape: [B, embedding_size]

        user_inter = self.history_item_matrix[user]  # shape: [B, max_len]
        item_embedding = self.item_embedding(user_inter)  # shape: [B, max_len, embedding_size]

        # NOTE(review): padded history slots are not masked here; they only score 0
        # when the padding row of ``item_embedding`` is all-zero -- confirm.
        score = torch.mul(user_embedding.unsqueeze(1), item_embedding)  # shape: [B, max_len, embedding_size]
        score = self.H_i(score)  # shape: [B, max_len, 1]

        # Drop only the trailing singleton axis. A bare ``squeeze()`` would also
        # remove the batch axis when B == 1 (or the history axis when max_len == 1).
        score = score.squeeze(-1)  # shape: [B, max_len]
        return score

    def calculate_loss(self, interaction):
        """Compute the non-sampling training loss for a batch of users.

        Args:
            interaction: batch object indexed by ``self.USER_ID``.

        Returns:
            torch.Tensor: scalar loss (all-pairs term + observed-pairs term + reg loss).
        """
        user = interaction[self.USER_ID]
        pos_score = self.forward(user)

        item_weight = self.item_embedding.weight  # shape: [n_items, embedding_size]
        user_weight = self.user_embedding.weight  # shape: [n_users, embedding_size]

        # The sum over rows of the outer product w w^T equals W^T W. Computing it as
        # one matmul avoids materialising an [n, embedding_size, embedding_size] tensor.
        # shape: [embedding_size, embedding_size]
        item_sum = torch.matmul(item_weight.t(), item_weight)

        # NOTE(review): this runs over every row of the user table, not only the
        # users of the current batch -- confirm against the reference implementation.
        # shape: [embedding_size, embedding_size]
        user_sum = torch.matmul(user_weight.t(), user_weight)

        # shape: [embedding_size, embedding_size]
        H_sum = torch.matmul(self.H_i.weight.t(), self.H_i.weight)

        t = torch.sum(item_sum * user_sum * H_sum)

        loss = self.negative_weight * t
        loss = loss + torch.sum((1 - self.negative_weight) * torch.square(pos_score) - 2 * pos_score)
        loss = loss + self.reg_loss()
        return loss

    def predict(self, interaction):
        """Score the given (user, item) pairs.

        Args:
            interaction: batch object indexed by ``self.USER_ID`` and ``self.ITEM_ID``.

        Returns:
            torch.Tensor: one score per pair, shape: [B]
        """
        user = interaction[self.USER_ID]
        item = interaction[self.ITEM_ID]

        u_e = self.user_embedding(user)
        i_e = self.item_embedding(item)

        score = torch.mul(u_e, i_e)  # shape: [B, embedding_dim]
        score = self.H_i(score)  # shape: [B, 1]
        return score.squeeze(1)

    def full_sort_predict(self, interaction):
        """Score every item for each user of the batch.

        Args:
            interaction: batch object indexed by ``self.USER_ID``.

        Returns:
            torch.Tensor: flattened scores, shape: [B * n_item]
        """
        user = interaction[self.USER_ID]
        u_e = self.user_embedding(user)  # shape: [B, embedding_dim]

        all_i_e = self.item_embedding.weight  # shape: [n_item, embedding_dim]

        score = torch.mul(u_e.unsqueeze(1), all_i_e.unsqueeze(0))  # shape: [B, n_item, embedding_dim]
        score = self.H_i(score).squeeze(2)  # shape: [B, n_item]
        return score.view(-1)