Skip to content

Commit

Permalink
correction for TruncatedEpochScheme (#177)
Browse files Browse the repository at this point in the history
  • Loading branch information
Melanie Ducoffe committed Jun 24, 2015
1 parent 88ccff7 commit ee50df1
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
19 changes: 18 additions & 1 deletion fuel/schemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import Iterable

import numpy
from picklable_itertools import chain, repeat, imap, iter_
from picklable_itertools import chain, repeat, imap, iter_, islice
from picklable_itertools.extras import partition_all
from six import add_metaclass
from six.moves import xrange
Expand Down Expand Up @@ -280,3 +280,20 @@ def cross_validation(scheme_class, num_examples, num_folds, strict=True,
yield (train, valid)
else:
yield (train, valid, end - begin)


@add_metaclass(ABCMeta)
class TruncatedEpochScheme(IterationScheme):
"""limiting the number of excursion in an iterator
Returns elements from an iterator with an early stopping
"""
def __init__(self, iteration_scheme, times):
self.iteration_scheme = iteration_scheme
if times < 1:
raise ValueError("times is a positive number")
self.times = times

def get_request_iterator(self):
it = self.iteration_scheme.get_request_iterator()
return islice(it, start=0, stop=self.times)
7 changes: 6 additions & 1 deletion tests/test_schemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from fuel.schemes import (ConstantScheme, SequentialExampleScheme,
SequentialScheme, ShuffledExampleScheme,
ShuffledScheme, ConcatenatedScheme,
cross_validation)
cross_validation, TruncatedEpochScheme)


def iterator_requester(scheme):
Expand Down Expand Up @@ -145,3 +145,8 @@ def test_cross_validation():
assert list(valid.get_request_iterator()) == [[4, 5], [6, 7]]

assert_raises(StopIteration, next, cross)


def test_truncated_epoch_scheme():
limited_scheme = TruncatedEpochScheme(SequentialScheme(20, 2), times=5)
assert sum(list(limited_scheme.get_request_iterator()), []) == list(range(10))

0 comments on commit ee50df1

Please sign in to comment.