Commit

refactor code
lthoang committed Jan 10, 2024
1 parent d0453ec commit a158b40
Showing 2 changed files with 95 additions and 8 deletions.
4 changes: 0 additions & 4 deletions cornac/eval_methods/base_method.py
@@ -272,7 +272,6 @@ def __init__(
         self.val_set = None
         self.rating_threshold = rating_threshold
         self.exclude_unknowns = exclude_unknowns
-        self.mode = kwargs.get("mode", None)
         self.verbose = verbose
         self.seed = seed
         self.rng = get_rng(seed)
@@ -664,7 +663,6 @@ def eval(
         rating_metrics,
         ranking_metrics,
         verbose,
-        **kwargs,
     ):
         """Running evaluation for rating and ranking metrics respectively."""
         metric_avg_results = OrderedDict()
@@ -756,7 +754,6 @@ def evaluate(self, model, metrics, user_based, show_validation=True):
             rating_metrics=rating_metrics,
             ranking_metrics=ranking_metrics,
             user_based=user_based,
-            mode=self.mode,
             verbose=self.verbose,
         )
         test_time = time.time() - start
@@ -777,7 +774,6 @@ def evaluate(self, model, metrics, user_based, show_validation=True):
                 rating_metrics=rating_metrics,
                 ranking_metrics=ranking_metrics,
                 user_based=user_based,
-                mode=self.mode,
                 verbose=self.verbose,
             )
             val_time = time.time() - start
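In short, this refactor removes the generic mode plumbing from BaseMethod: the base class no longer pulls "mode" out of kwargs, and its eval() no longer accepts **kwargs, since next-item evaluation is the only consumer of mode. A minimal sketch of the resulting ownership, with heavily simplified constructors (illustrative only, not the full cornac classes):

# Simplified sketch: "mode" now lives on the subclass, not on BaseMethod.
class BaseMethod:
    def __init__(self, exclude_unknowns=True, verbose=False, **kwargs):
        self.exclude_unknowns = exclude_unknowns
        self.verbose = verbose
        # no self.mode here anymore; mode is a next-item concern

class NextItemEvaluation(BaseMethod):
    def __init__(self, mode="last", **kwargs):
        super().__init__(**kwargs)
        assert mode in {"last", "next"}
        self.mode = mode  # consumed by this subclass's own eval()/evaluate()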
99 changes: 95 additions & 4 deletions cornac/eval_methods/next_item_evaluation.py
@@ -13,20 +13,22 @@
 # limitations under the License.
 # ============================================================================
 
+import time
+import warnings
 from collections import OrderedDict, defaultdict
 
 import numpy as np
 from tqdm.auto import tqdm
 
 from ..data import SequentialDataset
 from ..experiment.result import Result
 from ..models import NextItemRecommender
 from . import BaseMethod
 
 
-EVALUATION_MODES = [
+EVALUATION_MODES = set([
     "last",
     "next",
-]
+])
 
 def ranking_eval(
     model,
@@ -213,7 +215,7 @@ def __init__(
             mode=mode,
             **kwargs,
         )
-        assert mode in EVALUATION_MODES
+        assert mode in EVALUATION_MODES, "Evaluation mode must be one of %s, but '%s' was provided" % (EVALUATION_MODES, mode)
        self.mode = mode
        self.global_sid_map = kwargs.get("global_sid_map", OrderedDict())
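The new assertion message also reports the offending value. A quick illustrative sketch of the guard (other constructor arguments elided; not a verbatim cornac session):

NextItemEvaluation(mode="next")   # accepted; stored as self.mode
NextItemEvaluation(mode="first")  # AssertionError naming EVALUATION_MODES and 'first'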

Expand Down Expand Up @@ -308,6 +310,95 @@ def eval(

return Result(model.name, metric_avg_results, metric_user_results)

+    def evaluate(self, model, metrics, user_based, show_validation=True):
+        """Evaluate the given model with the given metrics. Supposed to be called by Experiment.
+
+        Parameters
+        ----------
+        model: :obj:`cornac.models.NextItemRecommender`
+            NextItemRecommender model to be evaluated.
+
+        metrics: :obj:`iterable`
+            List of metrics.
+
+        user_based: bool, required
+            Evaluation strategy for the rating metrics. Whether results
+            are averaged over the number of users or the number of ratings.
+
+        show_validation: bool, optional, default: True
+            Whether to show the results on the validation set (if one exists).
+
+        Returns
+        -------
+        res: :obj:`cornac.experiment.Result`
+        """
+        if not isinstance(model, NextItemRecommender):
+            raise ValueError("model must be a NextItemRecommender but '%s' was provided" % type(model))
+
+        if self.train_set is None:
+            raise ValueError("train_set is required but None!")
+        if self.test_set is None:
+            raise ValueError("test_set is required but None!")
+
+        self._reset()
+
+        ###########
+        # FITTING #
+        ###########
+        if self.verbose:
+            print("\n[{}] Training started!".format(model.name))
+
+        start = time.time()
+        model.fit(self.train_set, self.val_set)
+        train_time = time.time() - start
+
+        ##############
+        # EVALUATION #
+        ##############
+        if self.verbose:
+            print("\n[{}] Evaluation started!".format(model.name))
+
+        rating_metrics, ranking_metrics = self.organize_metrics(metrics)
+        if len(rating_metrics) > 0:
+            warnings.warn("NextItemEvaluation only supports ranking metrics. The given rating metrics {} will be ignored!".format([mt.name for mt in rating_metrics]))
+
+        start = time.time()
+        model.transform(self.test_set)
+        test_result = self.eval(
+            model=model,
+            train_set=self.train_set,
+            test_set=self.test_set,
+            val_set=self.val_set,
+            exclude_unknowns=self.exclude_unknowns,
+            ranking_metrics=ranking_metrics,
+            user_based=user_based,
+            mode=self.mode,
+            verbose=self.verbose,
+        )
+        test_time = time.time() - start
+        test_result.metric_avg_results["Train (s)"] = train_time
+        test_result.metric_avg_results["Test (s)"] = test_time
+
+        val_result = None
+        if show_validation and self.val_set is not None:
+            start = time.time()
+            model.transform(self.val_set)
+            val_result = self.eval(
+                model=model,
+                train_set=self.train_set,
+                test_set=self.val_set,
+                val_set=None,
+                exclude_unknowns=self.exclude_unknowns,
+                ranking_metrics=ranking_metrics,
+                user_based=user_based,
+                mode=self.mode,
+                verbose=self.verbose,
+            )
+            val_time = time.time() - start
+            val_result.metric_avg_results["Time (s)"] = val_time
+
+        return test_result, val_result

     @classmethod
     def from_splits(
         cls,
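Finally, a minimal usage sketch of the new override. The SPop model, the Recall metric, and the from_splits argument names below are assumptions for illustration, not part of this commit; any NextItemRecommender and any sequential split accepted by from_splits would do.

# Hedged sketch: train_data/test_data are assumed pre-split sequential
# interactions in a format from_splits accepts; names are illustrative.
from cornac.eval_methods import NextItemEvaluation
from cornac.metrics import Recall
from cornac.models import SPop  # assumed available next-item baseline

next_item_eval = NextItemEvaluation.from_splits(
    train_data=train_data,
    test_data=test_data,
    mode="last",  # must be in EVALUATION_MODES, i.e. {"last", "next"}
)

# evaluate() fits the model, transforms the test (and validation) set, and runs
# ranking metrics only; any rating metrics are ignored with a warning.
test_result, val_result = next_item_eval.evaluate(
    model=SPop(),
    metrics=[Recall(k=10)],
    user_based=True,
)
print(test_result)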
