Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Update tests #7

Merged
merged 27 commits into from
Jun 3, 2022
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ jobs:
- "3.7"
- "3.8"
- "3.9"
- "3.10"
# temporarily remove because Ray doesn't support yet
# https://github.com/ray-project/ray/issues/19116
# - "3.10"
fail-fast: false
steps:
- uses: actions/checkout@v3
Expand Down
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ repos:
- id: isort
language_version: python3
- repo: https://github.com/psf/black
rev: 22.1.0
rev: 22.3.0
hooks:
- id: black
language_version: python3
Expand All @@ -18,4 +18,4 @@ repos:
hooks:
- id: interrogate
args: [-vv]
pass_filenames: false
pass_filenames: false
26 changes: 16 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,25 @@ pip install prefect-ray
### Write and run a flow

```python
from prefect import flow
from prefect_ray.tasks import (
goodbye_prefect_ray,
hello_prefect_ray,
)
from prefect import flow, task
from prefect.task_runners import RayTaskRunner

@task
def say_hello(name):
print(f"hello {name}")

@flow
def example_flow():
hello_prefect_ray
goodbye_prefect_ray
@task
def say_goodbye(name):
print(f"goodbye {name}")

example_flow()
@flow(task_runner=RayTaskRunner())
def greetings(names):
for name in names:
say_hello(name)
say_goodbye(name)

if __name__ == "__main__":
greetings(["arthur", "trillian", "ford", "marvin"])
```

## Resources
Expand Down
3 changes: 1 addition & 2 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,5 @@ plugins:

nav:
- Home: index.md
- Tasks: tasks.md
- Flows: flows.md
- Task Runners: task_runners.md

2 changes: 1 addition & 1 deletion prefect_ray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
__version__ = _version.get_versions()["version"]


from .task_runners import RayTaskRunner
from .task_runners import RayTaskRunner # noqa
99 changes: 54 additions & 45 deletions prefect_ray/task_runners.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,43 @@
"""
Interface and implementations of the Ray Task Runner.
[Task Runners](/concepts/task-runners/) in Prefect are
responsible for managing the execution of Prefect task runs.
Generally speaking, users are not expected to interact with
task runners outside of configuring and initializing them for a flow.

Example:
>>> from prefect import flow, task
>>> from prefect.task_runners import SequentialTaskRunner
>>> from typing import List
>>>
>>> @task
>>> def say_hello(name):
... print(f"hello {name}")
>>>
>>> @task
>>> def say_goodbye(name):
... print(f"goodbye {name}")
>>>
>>> @flow(task_runner=SequentialTaskRunner())
>>> def greetings(names: List[str]):
... for name in names:
... say_hello(name)
... say_goodbye(name)

Switching to a `RayTaskRunner`:
>>> from prefect.task_runners import RayTaskRunner
>>> flow.task_runner = RayTaskRunner()
>>> greetings(["arthur", "trillian", "ford", "marvin"])
hello arthur
goodbye arthur
hello trillian
hello ford
goodbye marvin
hello marvin
goodbye ford
goodbye trillian
"""

from contextlib import AsyncExitStack
from typing import Any, Awaitable, Callable, Dict, Optional
from uuid import UUID
Expand All @@ -8,31 +48,25 @@
from prefect.orion.schemas.core import TaskRun
from prefect.orion.schemas.states import State
from prefect.states import exception_to_crashed_state
from prefect.task_runners import BaseTaskRunner, R
from prefect.task_runners import BaseTaskRunner, R, TaskConcurrencyType
from prefect.utilities.asyncio import A, sync_compatible


class RayTaskRunner(BaseTaskRunner):
"""
A parallel task_runner that submits tasks to `ray`.

By default, a temporary Ray cluster is created for the duration of the flow run.

Alternatively, if you already have a `ray` instance running, you can provide
the connection URL via the `address` kwarg.

Args:
address (string, optional): Address of a currently running `ray` instance; if
one is not provided, a temporary instance will be created.
init_kwargs (dict, optional): Additional kwargs to use when calling `ray.init`.

Examples:

Using a temporary local ray cluster:
>>> from prefect import flow
>>> from prefect.task_runners import RayTaskRunner
>>> @flow(task_runner=RayTaskRunner)

Connecting to an existing ray instance:
>>> RayTaskRunner(address="ray://192.0.2.255:8786")
"""
Expand All @@ -54,6 +88,10 @@ def __init__(

super().__init__()

@property
def concurrency_type(self) -> TaskConcurrencyType:
return TaskConcurrencyType.PARALLEL

async def submit(
self,
task_run: TaskRun,
Expand Down Expand Up @@ -94,24 +132,6 @@ async def wait(

return result

@property
def _ray(self) -> "ray":
"""
Delayed import of `ray` allowing configuration of the task runner
without the extra installed and improves `prefect` import times.
"""
global ray

if ray is None:
try:
import ray
except ImportError as exc:
raise RuntimeError(
"Using the `RayTaskRunner` requires `ray` to be installed."
) from exc

return ray

async def _start(self, exit_stack: AsyncExitStack):
"""
Start the task runner and prep for context exit.
Expand All @@ -129,37 +149,26 @@ async def _start(self, exit_stack: AsyncExitStack):
self.logger.info("Creating a local Ray instance")
init_args = ()

# When connecting to an out-of-process cluster (e.g. ray://ip) this returns a
# `ClientContext` otherwise it returns a `dict`.
context_or_metadata = self._ray.init(*init_args, **self.init_kwargs)
if isinstance(context_or_metadata, dict):
metadata = context_or_metadata
context = None
else:
metadata = None # TODO: Some of this may be retrievable from the client ctx
context = context_or_metadata

# Shutdown differs depending on the connection type
if context:
# Just disconnect the client
exit_stack.push(context)
else:
# Shutdown ray
exit_stack.push_async_callback(self._shutdown_ray)
context = ray.init(*init_args, **self.init_kwargs)
dashboard_url = getattr(context, "dashboard_url", None)
exit_stack.push(context)

# Display some information about the cluster
nodes = ray.nodes()
living_nodes = [node for node in nodes if node.get("alive")]
self.logger.info(f"Using Ray cluster with {len(living_nodes)} nodes.")

if metadata and metadata.get("webui_url"):
if dashboard_url:
self.logger.info(
f"The Ray UI is available at {metadata['webui_url']}",
f"The Ray UI is available at {dashboard_url}",
)

async def _shutdown_ray(self):
"""
Shuts down the cluster.
"""
self.logger.debug("Shutting down Ray cluster...")
self._ray.shutdown()
ray.shutdown()

def _get_ray_ref(self, prefect_future: PrefectFuture) -> "ray.ObjectRef":
"""
Expand Down
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
prefect>=2.0a13
ray[default] >= 1.9
prefect>=2.0a5
ray[default] >= 1.12.0; python_version >= '3.7' and python_version < '3.10' and platform_machine == "x86_64"

3 changes: 2 additions & 1 deletion requirements_dev.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pytest
black
flake8
flaky
mypy
mkdocs
mkdocs-material
Expand All @@ -11,4 +12,4 @@ pytest-asyncio
mock; python_version < '3.8'
mkdocs-gen-files
interrogate
coverage
coverage
6 changes: 6 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ parentdir_prefix =
[tool:interrogate]
ignore-init-module = True
exclude = prefect_ray/_version.py, tests, setup.py, versioneer.py, docs, site
ignore_init_method = True
ignore_regex = submit,wait,concurrency_type
fail-under = 95
omit-covered-files = True

Expand All @@ -35,3 +37,7 @@ fail_under = 80

[tool:pytest]
asyncio_mode = auto

markers =
service(arg): a service integration test. For example 'docker'
enable_orion_handler: by default, sending logs to the API is disabled. Tests marked with this use the handler.
Empty file added tests/__init__.py
Empty file.
Loading