diff --git a/cornac/eval_methods/base_method.py b/cornac/eval_methods/base_method.py
index 8c3a5424..7904a9e4 100644
--- a/cornac/eval_methods/base_method.py
+++ b/cornac/eval_methods/base_method.py
@@ -272,7 +272,6 @@ def __init__(
         self.val_set = None
         self.rating_threshold = rating_threshold
         self.exclude_unknowns = exclude_unknowns
-        self.mode = kwargs.get("mode", None)
         self.verbose = verbose
         self.seed = seed
         self.rng = get_rng(seed)
@@ -664,7 +663,6 @@ def eval(
     rating_metrics,
     ranking_metrics,
     verbose,
-    **kwargs,
 ):
     """Running evaluation for rating and ranking metrics respectively."""
     metric_avg_results = OrderedDict()
@@ -756,7 +754,6 @@ def evaluate(self, model, metrics, user_based, show_validation=True):
             rating_metrics=rating_metrics,
             ranking_metrics=ranking_metrics,
             user_based=user_based,
-            mode=self.mode,
             verbose=self.verbose,
         )
         test_time = time.time() - start
@@ -777,7 +774,6 @@ def evaluate(self, model, metrics, user_based, show_validation=True):
             rating_metrics=rating_metrics,
             ranking_metrics=ranking_metrics,
             user_based=user_based,
-            mode=self.mode,
             verbose=self.verbose,
         )
         val_time = time.time() - start
diff --git a/cornac/eval_methods/next_item_evaluation.py b/cornac/eval_methods/next_item_evaluation.py
index d27c4342..05de988f 100644
--- a/cornac/eval_methods/next_item_evaluation.py
+++ b/cornac/eval_methods/next_item_evaluation.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 # ============================================================================
 
+import time
+import warnings
 from collections import OrderedDict, defaultdict
 
 import numpy as np
@@ -20,13 +22,13 @@
 from ..data import SequentialDataset
 from ..experiment.result import Result
+from ..models import NextItemRecommender
 from . import BaseMethod
 
-
-EVALUATION_MODES = [
+EVALUATION_MODES = set([
     "last",
     "next",
-]
+])
 
 
 def ranking_eval(
     model,
@@ -213,7 +215,7 @@ def __init__(
             mode=mode,
             **kwargs,
         )
-        assert mode in EVALUATION_MODES
+        assert mode in EVALUATION_MODES, "mode must be one of %s, but '%s' was provided" % (EVALUATION_MODES, mode)
         self.mode = mode
 
         self.global_sid_map = kwargs.get("global_sid_map", OrderedDict())
@@ -308,6 +310,95 @@ def eval(
 
         return Result(model.name, metric_avg_results, metric_user_results)
 
+    def evaluate(self, model, metrics, user_based, show_validation=True):
+        """Evaluate the given model with the given metrics. Intended to be called by Experiment.
+
+        Parameters
+        ----------
+        model: :obj:`cornac.models.NextItemRecommender`
+            NextItemRecommender model to be evaluated.
+
+        metrics: :obj:`iterable`
+            List of metrics.
+
+        user_based: bool, required
+            Evaluation strategy for the metrics. Whether results
+            are averaged over the number of users or the number of ratings.
+
+        show_validation: bool, optional, default: True
+            Whether to show the results on the validation set (if it exists).
+
+        Returns
+        -------
+        res: :obj:`cornac.experiment.Result`
+        """
+        if not isinstance(model, NextItemRecommender):
+            raise ValueError("model must be a NextItemRecommender but '%s' was provided" % type(model))
+
+        if self.train_set is None:
+            raise ValueError("train_set is required but None!")
+        if self.test_set is None:
+            raise ValueError("test_set is required but None!")
+
+        self._reset()
+
+        ###########
+        # FITTING #
+        ###########
+        if self.verbose:
+            print("\n[{}] Training started!".format(model.name))
+
+        start = time.time()
+        model.fit(self.train_set, self.val_set)
+        train_time = time.time() - start
+
+        ##############
+        # EVALUATION #
+        ##############
+        if self.verbose:
+            print("\n[{}] Evaluation started!".format(model.name))
+
+        rating_metrics, ranking_metrics = self.organize_metrics(metrics)
+        if len(rating_metrics) > 0:
+            warnings.warn("NextItemEvaluation only supports ranking metrics. The given rating metrics {} will be ignored!".format([mt.name for mt in rating_metrics]))
+
+        start = time.time()
+        model.transform(self.test_set)
+        test_result = self.eval(
+            model=model,
+            train_set=self.train_set,
+            test_set=self.test_set,
+            val_set=self.val_set,
+            exclude_unknowns=self.exclude_unknowns,
+            ranking_metrics=ranking_metrics,
+            user_based=user_based,
+            mode=self.mode,
+            verbose=self.verbose,
+        )
+        test_time = time.time() - start
+        test_result.metric_avg_results["Train (s)"] = train_time
+        test_result.metric_avg_results["Test (s)"] = test_time
+
+        val_result = None
+        if show_validation and self.val_set is not None:
+            start = time.time()
+            model.transform(self.val_set)
+            val_result = self.eval(
+                model=model,
+                train_set=self.train_set,
+                test_set=self.val_set,
+                val_set=None,
+                exclude_unknowns=self.exclude_unknowns,
+                ranking_metrics=ranking_metrics,
+                user_based=user_based,
+                mode=self.mode,
+                verbose=self.verbose,
+            )
+            val_time = time.time() - start
+            val_result.metric_avg_results["Time (s)"] = val_time
+
+        return test_result, val_result
+
     @classmethod
     def from_splits(
         cls,
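
For context, a minimal sketch of how the new evaluate() path would be exercised end to end. This is not part of the patch: the toy session data, the SPop model choice, and the from_splits argument names are assumptions based on Cornac's public API and may need adjusting.

    import cornac
    from cornac.eval_methods import NextItemEvaluation
    from cornac.metrics import NDCG, Recall
    from cornac.models import SPop

    # Hypothetical toy data in (session_id, item_id, timestamp) format.
    data = [
        ("s1", "a", 1), ("s1", "b", 2), ("s1", "c", 3),
        ("s2", "b", 1), ("s2", "c", 2), ("s2", "d", 3),
    ]

    # mode must be "last" or "next"; the strengthened assert above now
    # reports which value was rejected instead of failing without a message.
    next_item_eval = NextItemEvaluation.from_splits(
        train_data=data, test_data=data, fmt="SIT", mode="last", verbose=True
    )

    # Experiment dispatches to NextItemEvaluation.evaluate(), which fits the
    # model, warns about (and ignores) any rating metrics, and returns
    # (test_result, val_result).
    cornac.Experiment(
        eval_method=next_item_eval,
        models=[SPop()],
        metrics=[Recall(k=2), NDCG(k=2)],
    ).run()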