Skip to content

Commit

Permalink
Changed patching mpire...time.time to simply time.time
Browse files Browse the repository at this point in the history
  • Loading branch information
sybrenjansen committed Feb 14, 2024
1 parent d871676 commit 97229dc
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 21 deletions.
22 changes: 9 additions & 13 deletions tests/test_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_init_comms(self):
self.assertIsNone(comms._progress_bar_complete)

with self.subTest('without initial values', ctx=ctx, n_jobs=n_jobs, order_tasks=order_tasks), \
patch('mpire.comms.time.time', return_value=0.0):
patch('time.time', return_value=0.0):
comms.init_comms()
self._check_comms_are_initialized(comms, n_jobs)

Expand Down Expand Up @@ -106,7 +106,7 @@ def test_init_comms(self):
comms.reset()

with self.subTest('with initial values', ctx=ctx, n_jobs=n_jobs, order_tasks=order_tasks), \
patch('mpire.comms.time.time', return_value=0.0):
patch('time.time', return_value=0.0):
comms.init_comms()
self._check_comms_are_initialized(comms, n_jobs)

Expand Down Expand Up @@ -185,15 +185,15 @@ def test_progress_bar(self):
# 3 task done, but not enough time has passed to send the update
last_updated = 0.0
n_tasks_completed = 0
with patch('mpire.comms.time.time', return_value=0.0):
with patch('time.time', return_value=0.0):
for n in range(1, 4):
last_updated, n_tasks_completed = comms.task_completed_progress_bar(0, last_updated, n_tasks_completed,
force_update=False)
self.assertEqual(n_tasks_completed, n)
self.assertEqual(sum(comms._tasks_completed_array), 0)

# Not enough time has passed, but we'll force the update. Number of tasks done should still be 3
with patch('mpire.comms.time.time', return_value=0.0):
with patch('time.time', return_value=0.0):
last_updated, n_tasks_completed = comms.task_completed_progress_bar(0, last_updated, n_tasks_completed,
force_update=True)
self.assertEqual(comms.get_tasks_completed_progress_bar(), 3)
Expand All @@ -204,7 +204,7 @@ def test_progress_bar(self):
# second. In total we have 3 (from above) + 4 + 4 = 11 tasks done
last_updated = 0.0
n_tasks_completed = 4
with patch('mpire.comms.time.time', side_effect=[1.0, 1.0, 3.0, 4.0]):
with patch('time.time', side_effect=[1.0, 1.0, 3.0, 4.0]):
for expected_last_updated in [1.0, 1.0, 3.0, 4.0]:
last_updated, n_tasks_completed = comms.task_completed_progress_bar(0, last_updated, n_tasks_completed,
force_update=False)
Expand Down Expand Up @@ -611,8 +611,7 @@ def test_timeouts(self):

# Signal workers started
for worker_id in range(5):
with self.subTest(worker_id=worker_id), \
patch('mpire.comms.time.time', side_effect=[1000.0, 2000.0, 3000.0]):
with self.subTest(worker_id=worker_id), patch('time.time', side_effect=[1000.0, 2000.0, 3000.0]):
self.assertListEqual(
comms._workers_time_task_started[worker_id * 3 : worker_id * 3 + 3], [0.0, 0.0, 0.0]
)
Expand All @@ -627,20 +626,17 @@ def test_timeouts(self):
for worker_id in range(5):
# worker_init, times out at > 10
for timeout, has_timed_out in [(8, True), (9, True), (10, True), (11, False)]:
with self.subTest(timeout=timeout, worker_id=worker_id), \
patch('mpire.comms.time.time', return_value=1010.0):
with self.subTest(timeout=timeout, worker_id=worker_id), patch('time.time', return_value=1010.0):
self.assertEqual(comms.has_worker_init_timed_out(worker_id, timeout), has_timed_out)

# task, times out at > 9
for timeout, has_timed_out in [(8, True), (9, True), (10, False), (11, False)]:
with self.subTest(timeout=timeout, worker_id=worker_id), \
patch('mpire.comms.time.time', return_value=2009.0):
with self.subTest(timeout=timeout, worker_id=worker_id), patch('time.time', return_value=2009.0):
self.assertEqual(comms.has_worker_task_timed_out(worker_id, timeout), has_timed_out)

# worker_exit, times out at > 8
for timeout, has_timed_out in [(8, True), (9, False), (10, False), (11, False)]:
with self.subTest(timeout=timeout, worker_id=worker_id), \
patch('mpire.comms.time.time', return_value=3008.0):
with self.subTest(timeout=timeout, worker_id=worker_id), patch('time.time', return_value=3008.0):
self.assertEqual(comms.has_worker_exit_timed_out(worker_id, timeout), has_timed_out)

# Reset
Expand Down
13 changes: 5 additions & 8 deletions tests/test_insights.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,16 +212,14 @@ def test_update_start_up_time(self):
insights = WorkerInsights(mp.get_context(DEFAULT_START_METHOD), n_jobs=5)

# Shouldn't do anything when insights haven't been enabled
with self.subTest(insights_enabled=False), \
patch('mpire.insights.time.time', side_effect=[1.0, 2.0, 3.0, 7.0, 8.0]):
with self.subTest(insights_enabled=False), patch('time.time', side_effect=[1.0, 2.0, 3.0, 7.0, 8.0]):
for worker_id in range(5):
insights.update_start_up_time(worker_id, 1.0)
self.assertIsNone(insights.worker_start_up_time)

insights.reset_insights(enable_insights=True)

with self.subTest(insights_enabled=True), \
patch('mpire.insights.time.time', side_effect=[1.0, 2.0, 3.0, 7.0, 8.0]):
with self.subTest(insights_enabled=True), patch('time.time', side_effect=[1.0, 2.0, 3.0, 7.0, 8.0]):
for worker_id in range(5):
insights.update_start_up_time(worker_id, 1.0)
self.assertListEqual(list(insights.worker_start_up_time), [0.0, 1.0, 2.0, 6.0, 7.0])
Expand Down Expand Up @@ -267,8 +265,7 @@ def test_update_task_insights_not_forced(self):
max_task_duration_last_updated = 1.0

# The first three worker IDs won't send an update because the two seconds hasn't passed yet.
with self.subTest(insights_enabled=True), \
patch('mpire.insights.time.time', side_effect=[0.0, 2.0, 3.0, 7.0, 8.0]):
with self.subTest(insights_enabled=True), patch('time.time', side_effect=[0.0, 2.0, 3.0, 7.0, 8.0]):
last_updated_times = []
for worker_id, max_task_duration_list in [
(0, [(0.1, '0'), (0.2, '1'), (0.3, '2'), (0.4, '3'), (0.5, '4')]),
Expand Down Expand Up @@ -300,7 +297,7 @@ def test_update_task_insights_forced(self):
max_task_duration_last_updated = 0.0

# Shouldn't do anything when insights haven't been enabled
with self.subTest(insights_enabled=False), patch('mpire.insights.time.time', side_effect=[1.0, 2.0]):
with self.subTest(insights_enabled=False), patch('time.time', side_effect=[1.0, 2.0]):
for worker_id in range(2):
max_task_duration_last_updated = insights.update_task_insights(
worker_id, max_task_duration_last_updated, [(0.1, '1'), (0.2, '2')], force_update=True
Expand All @@ -312,7 +309,7 @@ def test_update_task_insights_forced(self):
insights.reset_insights(enable_insights=True)
max_task_duration_last_updated = 0.0

with self.subTest(insights_enabled=True), patch('mpire.insights.time.time', side_effect=[1, 2]):
with self.subTest(insights_enabled=True), patch('time.time', side_effect=[1, 2]):
for worker_id, max_task_duration_list in [
(0, [(5.0, '5'), (6.0, '6'), (7.0, '7'), (8.0, '8'), (9.0, '9')]),
(1, [(0.0, '0'), (1.0, '1'), (2.0, '2'), (3.0, '3'), (4.0, '4')])
Expand Down

0 comments on commit 97229dc

Please sign in to comment.