Add HRDR model #518

Merged
merged 3 commits on Jul 21, 2023
Changes from 2 commits
1 change: 1 addition & 0 deletions README.md
@@ -112,6 +112,7 @@ The recommender models supported by Cornac are listed below. Why don't you join
| | [Causal Inference for Visual Debiasing in Visually-Aware Recommendation (CausalRec)](cornac/models/causalrec), [paper](https://arxiv.org/abs/2107.02390) | [requirements.txt](cornac/models/causalrec/requirements.txt) | [causalrec_clothing.py](examples/causalrec_clothing.py)
| | [Explainable Recommendation with Comparative Constraints on Product Aspects (ComparER)](cornac/models/comparer), [paper](https://dl.acm.org/doi/pdf/10.1145/3437963.3441754) | N/A | [PreferredAI/ComparER](https://github.com/PreferredAI/ComparER)
| 2020 | [Adversarial Training Towards Robust Multimedia Recommender System (AMR)](cornac/models/amr), [paper](https://ieeexplore.ieee.org/document/8618394) | [requirements.txt](cornac/models/amr/requirements.txt) | [amr_clothing.py](examples/amr_clothing.py)
| | [Hybrid neural recommendation with joint deep representation learning of ratings and reviews (HRDR)](cornac/models/hrdr), [paper](https://www.sciencedirect.com/science/article/abs/pii/S0925231219313207) | [requirements.txt](cornac/models/hrdr/requirements.txt) | [hrdr_example.py](examples/hrdr_example.py)
| 2019 | [Embarrassingly Shallow Autoencoders for Sparse Data (EASEᴿ)](cornac/models/ease), [paper](https://arxiv.org/pdf/1905.03375.pdf) | N/A | [ease_movielens.py](examples/ease_movielens.py)
| 2018 | [Collaborative Context Poisson Factorization (C2PF)](cornac/models/c2pf), [paper](https://www.ijcai.org/proceedings/2018/0370.pdf) | N/A | [c2pf_example.py](examples/c2pf_example.py)
| | [Multi-Task Explainable Recommendation (MTER)](cornac/models/mter), [paper](https://arxiv.org/pdf/1806.03568.pdf) | N/A | [mter_example.py](examples/mter_example.py)
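The new README row points at examples/hrdr_example.py, which is not part of this diff. For orientation, here is a minimal sketch of how a review-based Cornac model is typically wired into an experiment, in the style of the existing NARRE example; the HRDR keyword arguments are assumptions, since the recommender wrapper (recom_hrdr.py) is not shown in this diff:

import cornac
from cornac.data import ReviewModality
from cornac.data.text import BaseTokenizer
from cornac.datasets import amazon_digital_music
from cornac.eval_methods import RatioSplit

# Ratings plus raw review text (the dataset choice is illustrative).
feedback = amazon_digital_music.load_feedback()
reviews = amazon_digital_music.load_review()

review_modality = ReviewModality(
    data=reviews,
    tokenizer=BaseTokenizer(stop_words="english"),
    max_vocab=4000,
    max_doc_freq=0.5,
)
ratio_split = RatioSplit(
    data=feedback,
    test_size=0.1,
    exclude_unknowns=True,
    review_text=review_modality,
    seed=123,
)

# Hyperparameters mirror the Model constructor in hrdr.py below; the
# wrapper's actual signature may differ.
model = cornac.models.HRDR(embedding_size=100, n_factors=32, seed=123)

cornac.Experiment(
    eval_method=ratio_split,
    models=[model],
    metrics=[cornac.metrics.RMSE(), cornac.metrics.MAE()],
).run()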
1 change: 1 addition & 0 deletions cornac/models/__init__.py
@@ -36,6 +36,7 @@
from .global_avg import GlobalAvg
from .hft import HFT
from .hpf import HPF
from .hrdr import HRDR
from .ibpr import IBPR
from .knn import ItemKNN
from .knn import UserKNN
1 change: 1 addition & 0 deletions cornac/models/hrdr/__init__.py
@@ -0,0 +1 @@
from .recom_hrdr import HRDR
197 changes: 197 additions & 0 deletions cornac/models/hrdr/hrdr.py
@@ -0,0 +1,197 @@
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, initializers, Input
from tensorflow.keras.preprocessing.sequence import pad_sequences

from ...utils import get_rng
from ...utils.init_utils import uniform
from ..narre.narre import TextProcessor, AddGlobalBias
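# HRDR reuses NARRE's building blocks: TextProcessor is the CNN-based
# review encoder and AddGlobalBias adds a learnable global offset.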

def get_data(batch_ids, train_set, max_text_length, by="user", max_num_review=None):
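    """Gather mini-batch inputs for HRDR.

    For each id in batch_ids (users if by == "user", else items), collect up
    to max_num_review review token sequences (post-padded to equal length)
    plus the review count, and build a dense rating vector over all items
    (users). Returns (batch_reviews, batch_num_reviews, batch_ratings).
    """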
batch_reviews, batch_num_reviews = [], []
review_group = (
train_set.review_text.user_review
if by == "user"
else train_set.review_text.item_review
)
for idx in batch_ids:
review_ids = []
for inc, (jdx, review_idx) in enumerate(review_group[idx].items()):
if max_num_review is not None and inc == max_num_review:
break
review_ids.append(review_idx)
reviews = train_set.review_text.batch_seq(
review_ids, max_length=max_text_length
)
batch_reviews.append(reviews)
batch_num_reviews.append(len(reviews))
batch_reviews = pad_sequences(batch_reviews, padding="post")
batch_num_reviews = np.array(batch_num_reviews).astype(np.int32)
batch_ratings = (
np.zeros((len(batch_ids), train_set.num_items), dtype=np.float32)
if by == "user"
else np.zeros((len(batch_ids), train_set.num_users), dtype=np.float32)
)
rating_group = train_set.user_data if by == "user" else train_set.item_data
for batch_inc, idx in enumerate(batch_ids):
jds, ratings = rating_group[idx]
for jdx, rating in zip(jds, ratings):
batch_ratings[batch_inc, jdx] = rating
return batch_reviews, batch_num_reviews, batch_ratings

class Model:
def __init__(self, n_users, n_items, vocab, global_mean,
n_factors=32, embedding_size=100, id_embedding_size=32,
attention_size=16, kernel_sizes=[3], n_filters=64,
n_user_mlp_factors=128, n_item_mlp_factors=128,
dropout_rate=0.5, max_text_length=50,
pretrained_word_embeddings=None, verbose=False, seed=None):
self.n_users = n_users
self.n_items = n_items
self.n_vocab = vocab.size
self.global_mean = global_mean
self.n_factors = n_factors
self.embedding_size = embedding_size
self.id_embedding_size = id_embedding_size
self.attention_size = attention_size
self.kernel_sizes = kernel_sizes
self.n_filters = n_filters
self.n_user_mlp_factors = n_user_mlp_factors
self.n_item_mlp_factors = n_item_mlp_factors
self.dropout_rate = dropout_rate
self.max_text_length = max_text_length
self.verbose = verbose
        self.rng = get_rng(seed)  # get_rng(None) falls back to a default RNG
        if seed is not None:
            tf.random.set_seed(seed)

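        # Word-embedding matrix: rows 0-3 are reserved for special tokens
        # (zeroed below); the rest start uniform in [-0.5, 0.5] and are
        # overwritten with pretrained vectors where available.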
embedding_matrix = uniform(shape=(self.n_vocab, self.embedding_size), low=-0.5, high=0.5, random_state=self.rng)
embedding_matrix[:4, :] = np.zeros((4, self.embedding_size))
if pretrained_word_embeddings is not None:
oov_count = 0
for word, idx in vocab.tok2idx.items():
embedding_vector = pretrained_word_embeddings.get(word)
if embedding_vector is not None:
embedding_matrix[idx] = embedding_vector
else:
oov_count += 1
if self.verbose:
print("Number of OOV words: %d" % oov_count)

embedding_matrix = initializers.Constant(embedding_matrix)
i_user_id = Input(shape=(1,), dtype="int32", name="input_user_id")
i_item_id = Input(shape=(1,), dtype="int32", name="input_item_id")
        i_user_rating = Input(shape=(self.n_items,), dtype="float32", name="input_user_rating")
        i_item_rating = Input(shape=(self.n_users,), dtype="float32", name="input_item_rating")
i_user_review = Input(shape=(None, self.max_text_length), dtype="int32", name="input_user_review")
i_item_review = Input(shape=(None, self.max_text_length), dtype="int32", name="input_item_review")
i_user_num_reviews = Input(shape=(1,), dtype="int32", name="input_user_number_of_review")
i_item_num_reviews = Input(shape=(1,), dtype="int32", name="input_item_number_of_review")

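        # Review-word embeddings share the matrix built above; mask_zero=True
        # lets downstream layers ignore padded token positions.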
l_user_review_embedding = layers.Embedding(self.n_vocab, self.embedding_size, embeddings_initializer=embedding_matrix, mask_zero=True, name="layer_user_review_embedding")
l_item_review_embedding = layers.Embedding(self.n_vocab, self.embedding_size, embeddings_initializer=embedding_matrix, mask_zero=True, name="layer_item_review_embedding")
l_user_embedding = layers.Embedding(self.n_users, self.id_embedding_size, embeddings_initializer="uniform", name="user_embedding")
l_item_embedding = layers.Embedding(self.n_items, self.id_embedding_size, embeddings_initializer="uniform", name="item_embedding")

user_bias = layers.Embedding(self.n_users, 1, embeddings_initializer=tf.initializers.Constant(0.1), name="user_bias")
item_bias = layers.Embedding(self.n_items, 1, embeddings_initializer=tf.initializers.Constant(0.1), name="item_bias")

user_text_processor = TextProcessor(self.max_text_length, filters=self.n_filters, kernel_sizes=self.kernel_sizes, dropout_rate=self.dropout_rate, name='user_text_processor')
item_text_processor = TextProcessor(self.max_text_length, filters=self.n_filters, kernel_sizes=self.kernel_sizes, dropout_rate=self.dropout_rate, name='item_text_processor')

user_review_h = user_text_processor(l_user_review_embedding(i_user_review))
item_review_h = item_text_processor(l_item_review_embedding(i_item_review))

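        # Rating-based MLPs: encode each user's (item's) dense rating vector
        # into an n_filters-dim feature, the same width as the review features
        # so the two can be multiplied inside the attention below.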
l_user_mlp = keras.models.Sequential([
layers.Dense(self.n_user_mlp_factors, input_dim=self.n_items, activation="relu"),
layers.Dense(self.n_user_mlp_factors // 2, activation="relu"),
layers.Dense(self.n_filters, activation="relu"),
layers.BatchNormalization(),
])
l_item_mlp = keras.models.Sequential([
layers.Dense(self.n_item_mlp_factors, input_dim=self.n_users, activation="relu"),
layers.Dense(self.n_item_mlp_factors // 2, activation="relu"),
layers.Dense(self.n_filters, activation="relu"),
layers.BatchNormalization(),
])
user_rating_h = l_user_mlp(i_user_rating)
item_rating_h = l_item_mlp(i_item_rating)
        # Review-level attention: scores come from review features gated by
        # the rating-based features; a sequence mask hides padded reviews.
a_user = layers.Dense(1, activation=None, use_bias=True)(
layers.Dense(self.attention_size, activation="relu", use_bias=True)(
tf.multiply(
user_review_h,
tf.expand_dims(user_rating_h, 1)
)
)
)
a_user_masking = tf.expand_dims(tf.sequence_mask(tf.reshape(i_user_num_reviews, [-1]), maxlen=i_user_review.shape[1]), -1)
user_attention = layers.Softmax(axis=1, name="user_attention")(a_user, a_user_masking)

a_item = layers.Dense(1, activation=None, use_bias=True)(
layers.Dense(self.attention_size, activation="relu", use_bias=True)(
tf.multiply(
item_review_h,
tf.expand_dims(item_rating_h, 1)
)
)
)
a_item_masking = tf.expand_dims(tf.sequence_mask(tf.reshape(i_item_num_reviews, [-1]), maxlen=i_item_review.shape[1]), -1)
item_attention = layers.Softmax(axis=1, name="item_attention")(a_item, a_item_masking)

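        # ou / oi: attention-weighted sums of the review features, projected
        # down to n_factors dimensions.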
ou = layers.Dense(self.n_factors, use_bias=True, name="ou")(
layers.Dropout(rate=self.dropout_rate)(
tf.reduce_sum(layers.Multiply()([user_attention, user_review_h]), 1)
)
)
oi = layers.Dense(self.n_factors, use_bias=True, name="oi")(
layers.Dropout(rate=self.dropout_rate)(
tf.reduce_sum(layers.Multiply()([item_attention, item_review_h]), 1)
)
)

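        # Final representations concatenate rating features, review features,
        # and the id embedding for each user (pu) and item (qi).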
pu = layers.Concatenate(axis=-1, name="pu")([
tf.expand_dims(user_rating_h, 1),
tf.expand_dims(ou, axis=1),
l_user_embedding(i_user_id)
])

qi = layers.Concatenate(axis=-1, name="qi")([
tf.expand_dims(item_rating_h, 1),
tf.expand_dims(oi, axis=1),
l_item_embedding(i_item_id)
])

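        # Prediction: linear layer over the elementwise product pu * qi, plus
        # user, item, and global biases (an extended latent factor model).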
W1 = layers.Dense(1, activation=None, use_bias=False, name="W1")
add_global_bias = AddGlobalBias(init_value=self.global_mean, name="global_bias")
r = layers.Add(name="prediction")([
W1(tf.multiply(pu, qi)),
user_bias(i_user_id),
item_bias(i_item_id)
])
r = add_global_bias(r)
        self.graph = keras.Model(
            inputs=[
                i_user_id, i_item_id,
                i_user_rating, i_user_review, i_user_num_reviews,
                i_item_rating, i_item_review, i_item_num_reviews,
            ],
            outputs=r,
        )
if self.verbose:
self.graph.summary()

def get_weights(self, train_set, batch_size=64, max_num_review=32):
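        """Export learned parameters by running the trained graph over all
        users and items: final representations P (users) and Q (items), the
        prediction weights W1, user/item biases bu/bi, the global bias mu,
        and per-item review attention scores A.
        """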
        user_attention_review_pooling = keras.Model(
            inputs=[
                self.graph.get_layer('input_user_id').input,
                self.graph.get_layer('input_user_rating').input,
                self.graph.get_layer('input_user_review').input,
                self.graph.get_layer('input_user_number_of_review').input,
            ],
            outputs=self.graph.get_layer('pu').output,
        )
        item_attention_pooling = keras.Model(
            inputs=[
                self.graph.get_layer('input_item_id').input,
                self.graph.get_layer('input_item_rating').input,
                self.graph.get_layer('input_item_review').input,
                self.graph.get_layer('input_item_number_of_review').input,
            ],
            outputs=[
                self.graph.get_layer('qi').output,
                self.graph.get_layer('item_attention').output,
            ],
        )
P = np.zeros((self.n_users, self.n_filters + self.n_factors + self.id_embedding_size), dtype=np.float32)
Q = np.zeros((self.n_items, self.n_filters + self.n_factors + self.id_embedding_size), dtype=np.float32)
A = np.zeros((self.n_items, max_num_review), dtype=np.float32)
for batch_users in train_set.user_iter(batch_size):
user_reviews, user_num_reviews, user_ratings = get_data(batch_users, train_set, self.max_text_length, by='user', max_num_review=max_num_review)
pu = user_attention_review_pooling([batch_users, user_ratings, user_reviews, user_num_reviews], training=False)
P[batch_users] = pu.numpy().reshape(len(batch_users), self.n_filters + self.n_factors + self.id_embedding_size)
for batch_items in train_set.item_iter(batch_size):
item_reviews, item_num_reviews, item_ratings = get_data(batch_items, train_set, self.max_text_length, by='item', max_num_review=max_num_review)
qi, item_attention = item_attention_pooling([batch_items, item_ratings, item_reviews, item_num_reviews], training=False)
Q[batch_items] = qi.numpy().reshape(len(batch_items), self.n_filters + self.n_factors + self.id_embedding_size)
A[batch_items, :item_attention.shape[1]] = item_attention.numpy().reshape(item_attention.shape[:2])
W1 = self.graph.get_layer('W1').get_weights()[0]
bu = self.graph.get_layer('user_bias').get_weights()[0]
bi = self.graph.get_layer('item_bias').get_weights()[0]
mu = self.graph.get_layer('global_bias').get_weights()[0][0]
return P, Q, W1, bu, bi, mu, A