diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 15cb4685e18a1..0bd5d20f7877f 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -133,8 +133,8 @@ def test_group_by_key(self): def gen_data(N, step): for i in range(1, N + 1, step): - for j in range(i * 10): - yield (i, j) + for j in range(i): + yield (i, [j]) def gen_gs(N, step=1): return shuffle.GroupByKey(gen_data(N, step)) @@ -143,20 +143,17 @@ def gen_gs(N, step=1): self.assertEqual(2, len(list(gen_gs(2)))) self.assertEqual(100, len(list(gen_gs(100)))) self.assertEqual(range(1, 101), [k for k, _ in gen_gs(100)]) - self.assertTrue(all(k * 10 == len(list(vs)) for k, vs in gen_gs(100))) + self.assertTrue(all(range(k) == list(vs) for k, vs in gen_gs(100))) - for k, vs in gen_gs(5002, 100): - if k % 1000 == 1: - self.assertEqual(range(k), list(itertools.islice(vs, k))) - self.assertEqual(k * 10, sum(1 for _ in vs)) - self.assertEqual(range(k * 9, k * 10), list(itertools.islice(vs, k * 9, k * 10))) - self.assertEqual(k * 10, sum(1 for _ in vs)) + for k, vs in gen_gs(50002, 10000): + self.assertEqual(k, len(vs)) + self.assertEqual(range(k), list(vs)) ser = PickleSerializer() - l = ser.loads(ser.dumps(list(gen_gs(5002, 1000)))) + l = ser.loads(ser.dumps(list(gen_gs(50002, 30000)))) for k, vs in l: - self.assertEqual(k * 10, len(vs)) - self.assertEqual(range(k * 10), list(vs)) + self.assertEqual(k, len(vs)) + self.assertEqual(range(k), list(vs)) class SorterTests(unittest.TestCase):