This repository has been archived by the owner on Apr 26, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtest_task_runners.py
354 lines (294 loc) · 11 KB
/
test_task_runners.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
import asyncio
import logging
import subprocess
import sys
import time
import warnings
from functools import partial
from uuid import uuid4
import prefect
import prefect.engine
import pytest
import ray
import ray.cluster_utils
from prefect import flow, get_run_logger, task
from prefect.states import State, StateType
from prefect.tasks import TaskRun
from prefect.testing.fixtures import ( # noqa: F401
hosted_api_server,
use_hosted_api_server,
)
from prefect.testing.standard_test_suites import TaskRunnerStandardTestSuite
from prefect.testing.utilities import exceptions_equal
from ray.exceptions import TaskCancelledError
import tests
from prefect_ray import RayTaskRunner
from prefect_ray.context import remote_options
@pytest.fixture(scope="session")
def event_loop(request):
    """
    Provide one asyncio event loop for the whole test session so that
    session- and module-scoped fixtures can use it;
    see https://github.com/pytest-dev/pytest-asyncio/issues/68

    On Windows a proactor event loop policy is installed first, because a
    non-default loop is needed there for subprocess support.
    """
    on_windows = sys.platform == "win32"
    py38_or_newer = sys.version_info >= (3, 8)

    if on_windows and py38_or_newer:
        asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())

    loop_policy = asyncio.get_event_loop_policy()

    if not (py38_or_newer or on_windows):
        from prefect.utilities.compat import ThreadedChildWatcher

        # Before Python 3.8 the default child watcher is `SafeChildWatcher`,
        # which is not compatible with threaded event loops and can cause test
        # errors, so a `ThreadedChildWatcher` is installed explicitly.
        loop_policy.set_child_watcher(ThreadedChildWatcher())

    session_loop = loop_policy.new_event_loop()

    # Send asyncio's warnings to a stream handler and enable loop debug mode
    # so that long running tasks are reported
    asyncio_log = logging.getLogger("asyncio")
    asyncio_log.setLevel("WARNING")
    asyncio_log.addHandler(logging.StreamHandler())
    session_loop.set_debug(True)
    session_loop.slow_callback_duration = 0.25

    try:
        yield session_loop
    finally:
        session_loop.close()

    # Workaround for failures in pytest_asyncio 0.17;
    # see https://github.com/pytest-dev/pytest-asyncio/issues/257
    loop_policy.set_event_loop(session_loop)
@pytest.fixture(scope="module")
def machine_ray_instance():
    """
    Start a ray head node on the current machine via the `ray` CLI and yield
    the `ray://127.0.0.1:10001` address; `ray stop` is run on teardown.
    """
    start_head_command = ["ray", "start", "--head", "--include-dashboard", "False"]
    working_directory = str(prefect.__development_base_path__)

    # `check_call` raises if the head node fails to start, so dependent tests
    # error out during setup rather than running without an instance
    subprocess.check_call(start_head_command, cwd=working_directory)

    try:
        yield "ray://127.0.0.1:10001"
    finally:
        # The exit status of `ray stop` is deliberately not checked
        subprocess.run(["ray", "stop"])
@pytest.fixture
def default_ray_task_runner():
    """
    Yield a `RayTaskRunner` constructed with no arguments, with
    `ResourceWarning`s ignored while the fixture is active.
    """
    # Ray does not properly close resources and we do not want their warnings to
    # bubble into our test suite
    # https://github.com/ray-project/ray/pull/22419
    warning_scope = warnings.catch_warnings()
    with warning_scope:
        warnings.simplefilter("ignore", ResourceWarning)
        runner = RayTaskRunner()
        yield runner
@pytest.fixture
def ray_task_runner_with_existing_cluster(
    machine_ray_instance, use_hosted_api_server, hosted_api_server  # noqa: F811
):
    """
    Yield a ray task runner pointed at the ray instance started in a separate
    process by `machine_ray_instance`.

    The address uses the `ray://` scheme, so this exercises a client-based
    connection.
    """
    # Ship the 'tests' module to the workers or they will not be able to
    # deserialize test tasks / flows
    worker_runtime_env = {"py_modules": [tests]}

    runner = RayTaskRunner(
        address=machine_ray_instance,
        init_kwargs={"runtime_env": worker_runtime_env},
    )
    yield runner
@pytest.fixture(scope="module")
def inprocess_ray_cluster():
    """
    Create a ray cluster inside the test process, yield it, and shut it down
    on teardown.
    """
    ray_cluster = ray.cluster_utils.Cluster(initialize_head=True)
    try:
        # The head node alone is not enough; a second node is added so tasks
        # can run in parallel
        ray_cluster.add_node()
        yield ray_cluster
    finally:
        ray_cluster.shutdown()
@pytest.fixture
def ray_task_runner_with_inprocess_cluster(
    inprocess_ray_cluster, use_hosted_api_server, hosted_api_server  # noqa: F811
):
    """
    Yield a ray task runner pointed at the address of the in-process cluster.

    This exercises a connection via 'localhost', which is not a client-based
    connection.
    """
    # Ship the 'tests' module to the workers or they will not be able to
    # deserialize test tasks / flows
    worker_runtime_env = {"py_modules": [tests]}
    cluster_address = inprocess_ray_cluster.address

    runner = RayTaskRunner(
        address=cluster_address,
        init_kwargs={"runtime_env": worker_runtime_env},
    )
    yield runner
@pytest.fixture
def ray_task_runner_with_temporary_cluster(
    use_hosted_api_server, hosted_api_server  # noqa: F811
):
    """
    Yield a ray task runner that is given no address, so it creates a
    temporary cluster.

    This exercises a connection via 'localhost', which is not a client-based
    connection.
    """
    # Ship the 'tests' module to the workers or they will not be able to
    # deserialize test tasks / flows
    worker_runtime_env = {"py_modules": [tests]}

    runner = RayTaskRunner(init_kwargs={"runtime_env": worker_runtime_env})
    yield runner
class TestRayTaskRunner(TaskRunnerStandardTestSuite):
    """
    Runs the standard task runner test suite against `RayTaskRunner`, together
    with Ray-specific tests and overrides of a few inherited tests.
    """

    @pytest.fixture(
        params=[
            default_ray_task_runner,
            ray_task_runner_with_existing_cluster,
            ray_task_runner_with_inprocess_cluster,
            ray_task_runner_with_temporary_cluster,
        ]
    )
    def task_runner(self, request):
        """
        Yield the task runner produced by each of the parametrized fixtures.

        The params are the fixture functions themselves; each one is resolved
        to its fixture value by name, using the name recorded on the fixture
        marker when one is set and the function's `__name__` otherwise.
        """
        # NOTE(review): `_pytestfixturefunction` is a private pytest attribute;
        # confirm it exists in the pinned pytest version.
        yield request.getfixturevalue(
            request.param._pytestfixturefunction.name or request.param.__name__
        )

    def get_sleep_time(self) -> float:
        """
        Return an amount of time to sleep for concurrency tests.

        The RayTaskRunner is prone to flaking on concurrency tests.
        """
        return 5.0

    @pytest.mark.parametrize("exception", [KeyboardInterrupt(), ValueError("test")])
    async def test_wait_captures_exceptions_as_crashed_state(
        self, task_runner, exception
    ):
        """
        Ray wraps the exception, interrupts will result in "Cancelled" tasks
        or "Killed" workers while normal errors will result in a "RayTaskError".
        We care more about the crash detection and
        lack of re-raise here than the equality of the exception.
        """

        async def fake_orchestrate_task_run(task_run):
            raise exception

        task_run = TaskRun(
            flow_run_id=uuid4(), task_key=str(uuid4()), dynamic_key="bar"
        )

        async with task_runner.start():
            await task_runner.submit(
                call=partial(fake_orchestrate_task_run, task_run=task_run),
                key=task_run.id,
            )

            state = await task_runner.wait(task_run.id, 5)
            assert state is not None, "wait timed out"
            assert isinstance(state, State), "wait should return a state"
            assert state.name == "Crashed"

    @pytest.mark.parametrize(
        "exceptions",
        [
            (KeyboardInterrupt(), TaskCancelledError),
            (ValueError("test"), ValueError),
        ],
    )
    async def test_exception_to_crashed_state_in_flow_run(
        self, exceptions, task_runner, monkeypatch
    ):
        """
        Check that an exception raised in place of `prefect.engine.begin_task_run`
        surfaces from the flow run as the expected exception type.

        `exceptions` is a pair of (exception instance to raise, exception type
        the flow call is expected to raise).
        """
        (raised_exception, state_exception_type) = exceptions

        async def throws_exception_before_task_begins(
            task, task_run, parameters, wait_for, result_factory, settings, **kwds
        ):
            """
            Simulates an exception occurring while a remote task runner is attempting
            to unpickle and run a Prefect task.
            """
            raise raised_exception

        monkeypatch.setattr(
            prefect.engine, "begin_task_run", throws_exception_before_task_begins
        )

        @task()
        def test_task():
            logger = get_run_logger()
            logger.info("Ray should raise an exception before this task runs.")

        @flow(task_runner=task_runner)
        def test_flow():
            future = test_task.submit()
            future.wait(10)

        # ensure that the type of exception raised by the flow matches the type of
        # exception we expected the task runner to receive.
        with pytest.raises(state_exception_type) as exc:
            test_flow()
            # NOTE(review): the check below sits after the call that is expected
            # to raise, so it only runs if `test_flow()` returns normally.
            # Confirm whether it was meant to run after this block against
            # `exc.value`.
            # If Ray passes the same exception type back, it should pass
            # the equality check
            if type(raised_exception) is state_exception_type:
                assert exceptions_equal(raised_exception, exc)

    def test_flow_and_subflow_both_with_task_runner(self, task_runner, tmp_file):
        """
        Run a subflow that has its own `RayTaskRunner` inside a flow using the
        fixture's task runner, and check that the last task submitted by the
        parent flow is the one whose text ends up in `tmp_file`.
        """

        @task
        def some_task(text):
            tmp_file.write_text(text)

        @flow(task_runner=RayTaskRunner())
        def subflow():
            some_task.submit("a")
            some_task.submit("b")
            some_task.submit("c")

        @flow(task_runner=task_runner)
        def base_flow():
            subflow()
            # Sleep before submitting the final task so that its write comes
            # after the subflow's writes
            time.sleep(self.get_sleep_time())
            some_task.submit("d")

        base_flow()
        assert tmp_file.read_text() == "d"

    def test_ray_options(self):
        """
        Check that a task can be submitted inside a `remote_options` context.
        """

        @task
        def process(x):
            return x + 1

        @flow(task_runner=RayTaskRunner())
        def my_flow():
            # equivalent to setting @ray.remote(max_calls=1)
            with remote_options(max_calls=1):
                process.submit(42)

        my_flow()

    def test_dependencies(self):
        """
        Check that a flow whose tasks depend on each other through `wait_for`
        runs to completion.
        """

        @task
        def a():
            time.sleep(self.get_sleep_time())

        # All five names refer to the same task; the distinct names only make
        # the dependency graph below easier to read
        b = c = d = e = a

        @flow(task_runner=RayTaskRunner())
        def flow_with_dependent_tasks():
            for _ in range(3):
                a_future = a.submit(wait_for=[])
                b_future = b.submit(wait_for=[a_future])

                c.submit(wait_for=[b_future])
                d.submit(wait_for=[b_future])
                e.submit(wait_for=[b_future])

        flow_with_dependent_tasks()

    def test_sync_task_timeout(self, task_runner):
        """
        This test is inherited from the prefect testing module and it may not
        appropriately skip on Windows. Here we skip it explicitly.
        """
        if sys.platform.startswith("win"):
            pytest.skip("cancellation due to timeouts is not supported on Windows")
        # Delegate to the inherited sync test of the same name. Calling the
        # async variant from this synchronous method would only create a
        # coroutine that is never awaited, so nothing would be exercised.
        super().test_sync_task_timeout(task_runner)

    async def test_submit_and_wait(self, task_runner):
        """
        This test is inherited from the prefect testing module. The key difference
        here is that task_runner is waiting longer than 5 seconds.
        """
        MAX_WAIT_TIME = 60

        task_run = TaskRun(flow_run_id=uuid4(), task_key="foo", dynamic_key="bar")

        async def fake_orchestrate_task_run(example_kwarg, task_run):
            return State(
                type=StateType.COMPLETED,
                data=example_kwarg,
            )

        async with task_runner.start():
            await task_runner.submit(
                key=task_run.id,
                call=partial(
                    fake_orchestrate_task_run, task_run=task_run, example_kwarg=1
                ),
            )

            state = await task_runner.wait(task_run.id, MAX_WAIT_TIME)
            assert state is not None, "wait timed out"
            assert isinstance(state, State), "wait should return a state"
            assert await state.result() == 1