Skip to content

Commit

Permalink
Merge pull request #29 from tqtg/master
Browse files Browse the repository at this point in the history
Add tests for base_strategy and ratio
  • Loading branch information
saghiles authored Dec 20, 2018
2 parents d361d42 + 023bb81 commit 28e81b2
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 11 deletions.
22 changes: 12 additions & 10 deletions cornac/eval_strategies/ratio_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@


class RatioSplit(BaseStrategy):

"""Train-Test Split Evaluation Strategy.
Parameters
Expand Down Expand Up @@ -48,15 +47,15 @@ class RatioSplit(BaseStrategy):
Output running log
"""

def __init__(self, data, data_format='UIR', val_size=0.0, test_size=0.2, rating_threshold=1., shuffle=True, random_state=None,
exclude_unknowns=False, verbose=False):
BaseStrategy.__init__(self, data = data, data_format='UIR', rating_threshold=rating_threshold, exclude_unknowns=exclude_unknowns, verbose=verbose)
def __init__(self, data, data_format='UIR', val_size=0.0, test_size=0.2, rating_threshold=1., shuffle=True,
random_state=None, exclude_unknowns=False, verbose=False):
BaseStrategy.__init__(self, data=data, data_format=data_format, rating_threshold=rating_threshold,
exclude_unknowns=exclude_unknowns, verbose=verbose)

self._shuffle = shuffle
self._random_state = random_state
self._train_size, self._val_size, self._test_size = self._validate_sizes(val_size, test_size, len(self._data))
self._split_run = False

self._split_ran = False


@staticmethod
Expand Down Expand Up @@ -93,6 +92,11 @@ def _validate_sizes(val_size, test_size, num_ratings):


def split(self):
if self._split_ran:
if self.verbose:
print('Data is already split!')
return

if self.verbose:
print("Splitting the data")

Expand All @@ -114,15 +118,13 @@ def split(self):
if self._data_format == 'UIR':
self.build_from_uir_format(train_data, val_data, test_data)

self._split_run = True
self._split_ran = True

if self.verbose:
print('Total users = {}'.format(self.total_users))
print('Total items = {}'.format(self.total_items))


def evaluate(self, model, metrics, user_based):
if not self._split_run:
self.split()

self.split()
return BaseStrategy.evaluate(self, model, metrics, user_based)
Empty file.
34 changes: 34 additions & 0 deletions cornac/eval_strategies/tests/test_base_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-

"""
@author: Quoc-Tuan Truong <tuantq.vnu@gmail.com>
"""

from ..base_strategy import BaseStrategy


def test_init():
bs = BaseStrategy(None, verbose=True)

assert not bs.exclude_unknowns
assert 1. == bs.rating_threshold


def test_trainset_none():
bs = BaseStrategy(None, verbose=True)

try:
bs.evaluate(None, {}, False)
except ValueError:
assert True


def test_testset_none():
bs = BaseStrategy(None, train_set=[], verbose=True)

try:
bs.evaluate(None, {}, False)
except ValueError:
assert True


49 changes: 49 additions & 0 deletions cornac/eval_strategies/tests/test_ratio_split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-

"""
@author: Quoc-Tuan Truong <tuantq.vnu@gmail.com>
"""

from ..ratio_split import RatioSplit


def test_validate_size():
train_size, val_size, test_size = RatioSplit._validate_sizes(0.1, 0.2, 10)
assert 7 == train_size
assert 1 == val_size
assert 2 == test_size

train_size, val_size, test_size = RatioSplit._validate_sizes(None, 0.5, 10)
assert 5 == train_size
assert 0 == val_size
assert 5 == test_size

train_size, val_size, test_size = RatioSplit._validate_sizes(None, None, 10)
assert 10 == train_size
assert 0 == val_size
assert 0 == test_size

train_size, val_size, test_size = RatioSplit._validate_sizes(2, 2, 10)
assert 6 == train_size
assert 2 == val_size
assert 2 == test_size

try:
RatioSplit._validate_sizes(-1, 0.2, 10)
except ValueError:
assert True

try:
RatioSplit._validate_sizes(11, 0.2, 10)
except ValueError:
assert True

try:
RatioSplit._validate_sizes(0, 11, 10)
except ValueError:
assert True

try:
RatioSplit._validate_sizes(3, 8, 10)
except ValueError:
assert True
2 changes: 1 addition & 1 deletion cornac/experiment/tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_with_ratio_split():
from ..experiment import Experiment

data = reader.txt_to_uir_triplets('./cornac/data/tests/data.txt')
exp = Experiment(eval_strategy=RatioSplit(data),
exp = Experiment(eval_strategy=RatioSplit(data, verbose=True),
models=[PMF(1, 0)],
metrics=[MAE(), RMSE(), Recall(1), FMeasure(1)],
verbose=True)
Expand Down
6 changes: 6 additions & 0 deletions cornac/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ def safe_indexing(X, indices):


def validate_data_format(data_format):
"""Check the input data format is supported or not
- UIR: (user, item, rating) triplet data
- UIRT: (user, item , rating, timestamp) quadruplet data
:raise ValueError if not supported
"""
data_format = str(data_format).upper()
if not data_format in ['UIR', 'UIRT']:
raise ValueError('{} data format is not supported!'.format(data_format))
Expand Down

0 comments on commit 28e81b2

Please sign in to comment.