From ee50df1a2d1f98eb121c76da1ac90e5c00871ee0 Mon Sep 17 00:00:00 2001 From: Melanie Ducoffe Date: Wed, 24 Jun 2015 12:27:27 -0400 Subject: [PATCH] correction for TruncatedEpochScheme (#177) --- fuel/schemes.py | 19 ++++++++++++++++++- tests/test_schemes.py | 7 ++++++- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/fuel/schemes.py b/fuel/schemes.py index ba9ddea88..0b2f7df3b 100644 --- a/fuel/schemes.py +++ b/fuel/schemes.py @@ -2,7 +2,7 @@ from collections import Iterable import numpy -from picklable_itertools import chain, repeat, imap, iter_ +from picklable_itertools import chain, repeat, imap, iter_, islice from picklable_itertools.extras import partition_all from six import add_metaclass from six.moves import xrange @@ -280,3 +280,20 @@ def cross_validation(scheme_class, num_examples, num_folds, strict=True, yield (train, valid) else: yield (train, valid, end - begin) + + +@add_metaclass(ABCMeta) +class TruncatedEpochScheme(IterationScheme): + """limiting the number of excursion in an iterator + + Returns elements from an iterator with an early stopping + """ + def __init__(self, iteration_scheme, times): + self.iteration_scheme = iteration_scheme + if times < 1: + raise ValueError("times is a positive number") + self.times = times + + def get_request_iterator(self): + it = self.iteration_scheme.get_request_iterator() + return islice(it, start=0, stop=self.times) diff --git a/tests/test_schemes.py b/tests/test_schemes.py index addc64a79..a06dd8269 100644 --- a/tests/test_schemes.py +++ b/tests/test_schemes.py @@ -4,7 +4,7 @@ from fuel.schemes import (ConstantScheme, SequentialExampleScheme, SequentialScheme, ShuffledExampleScheme, ShuffledScheme, ConcatenatedScheme, - cross_validation) + cross_validation, TruncatedEpochScheme) def iterator_requester(scheme): @@ -145,3 +145,8 @@ def test_cross_validation(): assert list(valid.get_request_iterator()) == [[4, 5], [6, 7]] assert_raises(StopIteration, next, cross) + + +def test_truncated_epoch_scheme(): + limited_scheme = TruncatedEpochScheme(SequentialScheme(20, 2), times=5) + assert sum(list(limited_scheme.get_request_iterator()), []) == list(range(10))