Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transform function for Recommender #522

Merged
merged 2 commits into from
Jul 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions cornac/eval_methods/base_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def rating_eval(model, metrics, test_set, user_based=False, verbose=False):

verbose: bool, optional, default: False
Output evaluation progress.

Returns
-------
res: (List, List)
Expand Down Expand Up @@ -79,7 +79,7 @@ def rating_eval(model, metrics, test_set, user_based=False, verbose=False):
miniters=100,
total=len(u_indices),
),
dtype='float',
dtype="float",
)

gt_mat = test_set.csr_matrix
Expand Down Expand Up @@ -177,7 +177,7 @@ def pos_items(csr_row):
if len(test_pos_items) == 0:
continue

u_gt_pos = np.zeros(test_set.num_items, dtype='int')
u_gt_pos = np.zeros(test_set.num_items, dtype="int")
u_gt_pos[test_pos_items] = 1

val_pos_items = [] if val_mat is None else pos_items(val_mat.getrow(user_idx))
Expand All @@ -187,7 +187,7 @@ def pos_items(csr_row):
else pos_items(train_mat.getrow(user_idx))
)

u_gt_neg = np.ones(test_set.num_items, dtype='int')
u_gt_neg = np.ones(test_set.num_items, dtype="int")
u_gt_neg[test_pos_items + val_pos_items + train_pos_items] = 0

item_indices = None if exclude_unknowns else np.arange(test_set.num_items)
Expand Down Expand Up @@ -585,8 +585,8 @@ def _build_modalities(self):

def add_modalities(self, **kwargs):
"""
Add successfully built modalities to all datasets. This is handy for
seperately built modalities that are not invoked in the build method.
Add successfully built modalities to all datasets. This is handy for
seperately built modalities that are not invoked in the build method.
"""
self.user_feature = kwargs.get("user_feature", None)
self.user_text = kwargs.get("user_text", None)
Expand Down Expand Up @@ -671,11 +671,11 @@ def evaluate(self, model, metrics, user_based, show_validation=True):
metrics: :obj:`iterable`
List of metrics.

user_based: bool, required
Evaluation strategy for the rating metrics. Whether results
user_based: bool, required
Evaluation strategy for the rating metrics. Whether results
are averaging based on number of users or number of ratings.

show_validation: bool, optional, default: True
show_validation: bool, optional, default: True
Whether to show the results on validation set (if exists).

Returns
Expand Down Expand Up @@ -707,6 +707,7 @@ def evaluate(self, model, metrics, user_based, show_validation=True):
print("\n[{}] Evaluation started!".format(model.name))

start = time.time()
model.transform(self.test_set)
test_result = self._eval(
model=model,
test_set=self.test_set,
Expand All @@ -720,6 +721,7 @@ def evaluate(self, model, metrics, user_based, show_validation=True):
val_result = None
if show_validation and self.val_set is not None:
start = time.time()
model.transform(self.val_set)
val_result = self._eval(
model=model, test_set=self.val_set, val_set=None, user_based=user_based
)
Expand Down Expand Up @@ -790,4 +792,3 @@ def from_splits(
return method.build(
train_data=train_data, test_data=test_data, val_data=val_data
)

27 changes: 19 additions & 8 deletions cornac/models/recommender.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@


class Recommender:
"""Generic class for a recommender model. All recommendation models should inherit from this class
"""Generic class for a recommender model. All recommendation models should inherit from this class

Parameters
----------------
name: str, required
Expand Down Expand Up @@ -138,9 +138,9 @@ def load(model_path, trainable=False):
provided, the latest model will be loaded.

trainable: boolean, optional, default: False
Set it to True if you would like to finetune the model. By default,
Set it to True if you would like to finetune the model. By default,
the model parameters are assumed to be fixed after being loaded.

Returns
-------
self : object
Expand Down Expand Up @@ -176,14 +176,27 @@ def fit(self, train_set, val_set=None):
self.val_set = None if val_set is None else val_set.reset()
return self

def transform(self, test_set):
"""Transform test set into cached results accelerating the score function.
This function is supposed to be called in the `cornac.eval_methods.BaseMethod`
before evaluation step. It is optional for this function to be implemented.

Parameters
----------
test_set: :obj:`cornac.data.Dataset`, required
User-Item preference data as well as additional modalities.

"""
pass

def score(self, user_idx, item_idx=None):
"""Predict the scores/ratings of a user for an item.

Parameters
----------
user_idx: int, required
The index of the user for whom to perform score prediction.

item_idx: int, optional, default: None
The index of the item for which to perform score prediction.
If None, scores for all known items will be returned.
Expand All @@ -197,9 +210,7 @@ def score(self, user_idx, item_idx=None):
raise NotImplementedError("The algorithm is not able to make score prediction!")

def default_score(self):
"""Overwrite this function if your algorithm has special treatment for cold-start problem

"""
"""Overwrite this function if your algorithm has special treatment for cold-start problem"""
return self.train_set.global_mean

def rate(self, user_idx, item_idx, clipping=True):
Expand Down