# -*- coding: utf-8 -*-
# @Time : 2022/3/28
# @Author : Zihan Lin
# @Email : zhlin@ruc.edu.cn
r"""
CLFM
################################################
Reference:
    Sheng Gao et al. "Cross-Domain Recommendation via Cluster-Level Latent Factor Model." in PKDD 2013.
"""
import torch
import torch.nn as nn

from recbole.model.init import xavier_normal_initialization
from recbole.utils import InputType
from recbole.model.loss import EmbLoss

from recbole_cdr.model.crossdomain_recommender import CrossDomainRecommender


class CLFM(CrossDomainRecommender):
    r"""CLFM factorizes the interaction matrices of both domains
    with domain-shared embeddings and domain-specific embeddings.
    """
    input_type = InputType.POINTWISE

    def __init__(self, config, dataset):
        super(CLFM, self).__init__(config, dataset)

        self.SOURCE_LABEL = dataset.source_domain_dataset.label_field
        self.TARGET_LABEL = dataset.target_domain_dataset.label_field

        # load parameters info
        self.user_embedding_size = config['user_embedding_size']
        self.source_item_embedding_size = config['source_item_embedding_size']
        self.target_item_embedding_size = config['target_item_embedding_size']
        self.share_embedding_size = config['share_embedding_size']
        self.alpha = config['alpha']
        self.reg_weight = config['reg_weight']

        assert 0 <= self.share_embedding_size <= self.source_item_embedding_size and \
            0 <= self.share_embedding_size <= self.target_item_embedding_size, \
            "The shared embedding size must not exceed the dimension of either " \
            "the source or the target item embedding"

        # define layers and loss
        self.source_user_embedding = nn.Embedding(self.total_num_users, self.user_embedding_size)
        self.target_user_embedding = nn.Embedding(self.total_num_users, self.user_embedding_size)
        self.source_item_embedding = nn.Embedding(self.total_num_items, self.source_item_embedding_size)
        self.target_item_embedding = nn.Embedding(self.total_num_items, self.target_item_embedding_size)

        if self.share_embedding_size > 0:
            self.shared_linear = nn.Linear(self.user_embedding_size, self.share_embedding_size, bias=False)
        if self.source_item_embedding_size - self.share_embedding_size > 0:
            self.source_only_linear = nn.Linear(
                self.user_embedding_size, self.source_item_embedding_size - self.share_embedding_size, bias=False
            )
        if self.target_item_embedding_size - self.share_embedding_size > 0:
            self.target_only_linear = nn.Linear(
                self.user_embedding_size, self.target_item_embedding_size - self.share_embedding_size, bias=False
            )

        self.sigmoid = nn.Sigmoid()
        self.loss = nn.BCELoss()
        self.source_reg_loss = EmbLoss()
        self.target_reg_loss = EmbLoss()

        self.apply(xavier_normal_initialization)
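
    # CLFM scores the two domains with the separate forward passes below, so
    # the generic ``forward`` required by the base recommender is a no-op.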
    def forward(self, user, item):
        pass
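
    # Both domain-specific forward passes project the user embedding into a
    # shared subspace (``shared_linear``) and, if any dimensions remain, a
    # domain-only subspace; the concatenated factors are dotted with the item
    # embedding and squashed through a sigmoid.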
    def source_forward(self, user, item):
        user_embedding = self.source_user_embedding(user)
        item_embedding = self.source_item_embedding(item)
        factors = []
        if self.share_embedding_size > 0:
            share_factors = self.shared_linear(user_embedding)
            factors.append(share_factors)
        if self.source_item_embedding_size - self.share_embedding_size > 0:
            only_factors = self.source_only_linear(user_embedding)
            factors.append(only_factors)
        factors = torch.cat(factors, dim=1)
        output = self.sigmoid(torch.mul(factors, item_embedding).sum(dim=1))
        return output

    def target_forward(self, user, item):
        user_embedding = self.target_user_embedding(user)
        item_embedding = self.target_item_embedding(item)
        factors = []
        if self.share_embedding_size > 0:
            share_factors = self.shared_linear(user_embedding)
            factors.append(share_factors)
        if self.target_item_embedding_size - self.share_embedding_size > 0:
            only_factors = self.target_only_linear(user_embedding)
            factors.append(only_factors)
        factors = torch.cat(factors, dim=1)
        output = self.sigmoid(torch.mul(factors, item_embedding).sum(dim=1))
        return output
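
    # The joint objective is a convex combination of the two domains' BCE
    # losses, weighted by ``alpha``, each with an L2 embedding penalty.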
    def calculate_loss(self, interaction):
        source_user = interaction[self.SOURCE_USER_ID]
        source_item = interaction[self.SOURCE_ITEM_ID]
        source_label = interaction[self.SOURCE_LABEL]

        target_user = interaction[self.TARGET_USER_ID]
        target_item = interaction[self.TARGET_ITEM_ID]
        target_label = interaction[self.TARGET_LABEL]

        p_source = self.source_forward(source_user, source_item)
        p_target = self.target_forward(target_user, target_item)

        loss_s = self.loss(p_source, source_label) + self.reg_weight * self.source_reg_loss(
            self.source_user_embedding(source_user),
            self.source_item_embedding(source_item))
        loss_t = self.loss(p_target, target_label) + self.reg_weight * self.target_reg_loss(
            self.target_user_embedding(target_user),
            self.target_item_embedding(target_item))

        loss = loss_s * self.alpha + loss_t * (1 - self.alpha)
        return loss
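
    # Evaluation is performed on the target domain only.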
    def predict(self, interaction):
        user = interaction[self.TARGET_USER_ID]
        item = interaction[self.TARGET_ITEM_ID]
        p = self.target_forward(user, item)
        return p
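
    # Full-sort scoring slices the item embedding table to the target domain's
    # items and skips the sigmoid; since the sigmoid is monotonic, the ranking
    # induced by the raw dot products is unchanged.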
    def full_sort_predict(self, interaction):
        user = interaction[self.TARGET_USER_ID]
        user_embedding = self.target_user_embedding(user)
        all_item_embedding = self.target_item_embedding.weight[:self.target_num_items]
        factors = []
        if self.share_embedding_size > 0:
            share_factors = self.shared_linear(user_embedding)
            factors.append(share_factors)
        if self.target_item_embedding_size - self.share_embedding_size > 0:
            only_factors = self.target_only_linear(user_embedding)
            factors.append(only_factors)
        factors = torch.cat(factors, dim=1)
        score = torch.matmul(factors, all_item_embedding.transpose(0, 1))
        return score.view(-1)
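

# A minimal usage sketch (an assumption, not part of the original file): if the
# package exposes the usual RecBole-CDR quick-start entry point
# ``run_recbole_cdr``, CLFM could be trained with a cross-domain config like:
#
#     from recbole_cdr.quick_start import run_recbole_cdr
#     run_recbole_cdr(model='CLFM', config_file_list=['clfm_example.yaml'])
#
# Here ``clfm_example.yaml`` is a hypothetical config file naming the source
# and target datasets and the embedding sizes used above.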