Skip to content

Commit

Permalink
Merge pull request #514 from gazpachoking/fix_min_max
Browse files Browse the repository at this point in the history
Fix min/max functions with generators, and 'None' default
  • Loading branch information
jmadler authored Oct 24, 2019
2 parents 615ee1e + f4926e5 commit 228a297
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
19 changes: 13 additions & 6 deletions src/future/builtins/new_min_max.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import itertools

from future import utils
if utils.PY2:
from __builtin__ import max as _builtin_max, min as _builtin_min
else:
from builtins import max as _builtin_max, min as _builtin_min

_SENTINEL = object()


def newmin(*args, **kwargs):
return new_min_max(_builtin_min, *args, **kwargs)
Expand All @@ -29,21 +33,24 @@ def new_min_max(_builtin_func, *args, **kwargs):
if len(args) == 0:
raise TypeError

if len(args) != 1 and kwargs.get('default') is not None:
if len(args) != 1 and kwargs.get('default', _SENTINEL) is not _SENTINEL:
raise TypeError

if len(args) == 1:
iterator = iter(args[0])
try:
next(iter(args[0]))
first = next(iterator)
except StopIteration:
if kwargs.get('default') is not None:
if kwargs.get('default', _SENTINEL) is not _SENTINEL:
return kwargs.get('default')
else:
raise ValueError('iterable is an empty sequence')
raise ValueError('{}() arg is an empty sequence'.format(_builtin_func.__name__))
else:
iterator = itertools.chain([first], iterator)
if kwargs.get('key') is not None:
return _builtin_func(args[0], key=kwargs.get('key'))
return _builtin_func(iterator, key=kwargs.get('key'))
else:
return _builtin_func(args[0])
return _builtin_func(iterator)

if len(args) > 1:
if kwargs.get('key') is not None:
Expand Down
6 changes: 6 additions & 0 deletions tests/test_future/test_builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,7 @@ def test_max(self):
with self.assertRaises(TypeError):
max(1, 2, default=0)
self.assertEqual(max([], default=0), 0)
self.assertIs(max([], default=None), None)

def test_min(self):
self.assertEqual(min('123123'), '1')
Expand All @@ -1123,6 +1124,7 @@ class BadSeq:
def __getitem__(self, index):
raise ValueError
self.assertRaises(ValueError, min, BadSeq())
self.assertEqual(max(x for x in [5, 4, 3]), 5)

for stmt in (
"min(key=int)", # no args
Expand All @@ -1149,11 +1151,15 @@ def __getitem__(self, index):
sorted(data, key=f)[0])
self.assertEqual(min([], default=5), 5)
self.assertEqual(min([], default=0), 0)
self.assertIs(min([], default=None), None)
with self.assertRaises(TypeError):
max(None, default=5)
with self.assertRaises(TypeError):
max(1, 2, default=0)

# Test iterables that can only be looped once #510
self.assertEqual(min(x for x in [5]), 5)

def test_next(self):
it = iter(range(2))
self.assertEqual(next(it), 0)
Expand Down

0 comments on commit 228a297

Please sign in to comment.