Skip to content
This repository has been archived by the owner on Jul 3, 2023. It is now read-only.

Async implementation of driver/adapter #171

Merged
merged 5 commits into from
Aug 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,27 @@ jobs:
command: |
. venv/bin/activate
python -m pytest tests/integrations
asyncio-py39:
docker:
- image: circleci/python:3.9
steps:
- checkout
- run:
name: install hamilton dependencies + testing dependencies
command: |
python -m venv venv || virtualenv venv
. venv/bin/activate
python --version
pip --version
pip install -e .
pip install -r graph_adapter_tests/h_async/requirements-test.txt

# run tests!
- run:
name: run tests
command: |
. venv/bin/activate
python -m pytest graph_adapter_tests/h_async
workflows:
version: 2
unit-test-workflow:
Expand All @@ -299,3 +320,4 @@ workflows:
- integrations-py37
- integrations-py38
- integrations-py39
- asyncio-py39
53 changes: 53 additions & 0 deletions examples/async/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Hamilton + Async

This is currently an experimental feature, allowing one to run a hamilton DAG composed (entirely or partially) of async functions.

## How to use

See the [example](fastapi.py) for the example. The difference from a normal driver is two-fold:

1. You call it using the `AsyncDriver` rather than the standard driver
2. `raw_execute`, and `execute` are both coroutines, meaning they should be called with `await`.
3. It allows for tasks as inputs -- they'll get properly awaited

To run the example, make sure to install `requirements.txt`.

Then run `uvicorn fastapi_example:app` in one terminal.

Then curl with:

```bash
curl -X 'POST' \
'http://localhost:8000/execute' \
-H 'accept: application/json' \
-d '{}'
```

You should get the following result:

```json
{"pipeline":{"computation1":false,"computation2":true}}
```


## How it works

Behind the scenes, we create a [GraphAdapter](../../hamilton/experimental/h_async.py)
that turns every function into a coroutine. The function graph then executes, creating tasks for each node,
that are awaited at the end. Thus no computation is complete until a final node
is awaited.

Any node inputs are awaited on prior to node computation if they are awaitable, so you can pass
in external tasks as inputs if you want.

## Caveats

1. This will break in certain cases when decorating an async function (E.G. with `extract_outputs`).
This is because the output of that function is never awaited during delegation. We are looking into ways to fix this,
but for now be careful. We will at least be adding validation so the errors are clearer.
2. Performance *should* be close to optimal but we have not benchmarked. We welcome contributions.

We want feedback! We can determine how to make this part of the core API once we get userse who are happy,
so have some fun!

Fixing the caveats will be the first order of business, and adding validations when things won't work.
36 changes: 36 additions & 0 deletions examples/async/async_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import asyncio
import aiohttp
import fastapi


async def request_raw(request: fastapi.Request) -> dict:
return await request.json()


def foo(request_raw: dict) -> str:
return request_raw.get('foo', 'far')


def bar(request_raw: dict) -> str:
return request_raw.get('bar', 'baz')


async def computation1(foo: str, some_data: dict) -> bool:
await asyncio.sleep(1)
return False


async def some_data() -> dict:
async with aiohttp.ClientSession() as session:
async with session.get('http://httpbin.org/get') as resp:
return await resp.json()


async def computation2(bar: str) -> bool:
await asyncio.sleep(1)
return True


async def pipeline(computation1: bool, computation2: bool) -> dict:
await asyncio.sleep(1)
return {'computation1': computation1, 'computation2': computation2}
16 changes: 16 additions & 0 deletions examples/async/fastapi_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import fastapi

from hamilton.experimental import h_async
from . import async_module

app = fastapi.FastAPI()


@app.post('/execute')
async def call(
request: fastapi.Request
) -> dict:
"""Handler for pipeline call"""
dr = h_async.AsyncDriver({}, async_module)
input_data = {'request': request}
return await dr.raw_execute(['pipeline'], inputs=input_data)
3 changes: 3 additions & 0 deletions examples/async/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
aiohttp
fastapi
uvicorn
Empty file.
1 change: 1 addition & 0 deletions graph_adapter_tests/h_async/requirements-test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pytest-asyncio
Empty file.
28 changes: 28 additions & 0 deletions graph_adapter_tests/h_async/resources/simple_async_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import asyncio
from typing import Dict

import hamilton.function_modifiers


async def simple_async_func(external_input: int) -> int:
await asyncio.sleep(.01)
return external_input + 1


async def async_func_with_param(simple_async_func: int, external_input: int) -> int:
await asyncio.sleep(.01)
return simple_async_func + external_input + 1


def simple_non_async_func(simple_async_func: int, async_func_with_param: int) -> int:
return simple_async_func + async_func_with_param + 1


async def another_async_func(simple_non_async_func: int) -> int:
await asyncio.sleep(.01)
return simple_non_async_func + 1


@hamilton.function_modifiers.extract_fields(dict(result_1=int, result_2=int))
def non_async_func_with_decorator(async_func_with_param: int, another_async_func: int) -> Dict[str, int]:
return {'result_1': another_async_func + 1, 'result_2': async_func_with_param + 1}
57 changes: 57 additions & 0 deletions graph_adapter_tests/h_async/test_h_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import asyncio
import pdb

import pytest

from hamilton.experimental import h_async
from .resources import simple_async_module


async def async_identity(n: int) -> int:
await asyncio.sleep(.01)
return n


@pytest.mark.asyncio
async def test_await_dict_of_coroutines():
tasks = {n: async_identity(n) for n in range(0, 10)}
results = await h_async.await_dict_of_tasks(tasks)
assert results == {n: await async_identity(n) for n in range(0, 10)}


@pytest.mark.asyncio
async def test_await_dict_of_tasks():
tasks = {n: asyncio.create_task(async_identity(n)) for n in range(0, 10)}
results = await h_async.await_dict_of_tasks(tasks)
assert results == {n: await async_identity(n) for n in range(0, 10)}


# The following are not parameterized as we need to use the event loop -- fixtures will complicate this
@pytest.mark.asyncio
async def test_process_value_raw():
assert await h_async.process_value(1) == 1


@pytest.mark.asyncio
async def test_process_value_coroutine():
assert await h_async.process_value(async_identity(1)) == 1


@pytest.mark.asyncio
async def test_process_value_task():
assert await h_async.process_value(asyncio.create_task(async_identity(1))) == 1


@pytest.mark.asyncio
async def test_driver_end_to_end():
dr = h_async.AsyncDriver({}, simple_async_module)
all_vars = [var.name for var in dr.list_available_variables()]
result = await dr.raw_execute(final_vars=all_vars, inputs={'external_input': 1})
assert result == {'another_async_func': 8,
'async_func_with_param': 4,
'external_input': 1,
'non_async_func_with_decorator': {'result_1': 9, 'result_2': 5},
'result_1': 9,
'result_2': 5,
'simple_async_func': 2,
'simple_non_async_func': 7}
13 changes: 4 additions & 9 deletions hamilton/driver.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
import logging
from datetime import datetime
from typing import Dict, Collection, List, Any
from types import ModuleType

import pandas as pd

# required if we want to run this code stand alone.
import typing

from dataclasses import dataclass, field
from datetime import datetime
from types import ModuleType
from typing import Dict, Collection, List, Any

from hamilton import node
import pandas as pd

SLACK_ERROR_MESSAGE = (
'-------------------------------------------------------------------\n'
Expand Down Expand Up @@ -266,7 +262,6 @@ def what_is_upstream_of(self, *node_names: str) -> List[Variable]:
upstream_nodes, _ = self.graph.get_upstream_nodes(list(node_names))
return [Variable(node.name, node.type, node.tags) for node in upstream_nodes]


if __name__ == '__main__':
"""some example test code"""
import sys
Expand Down
Loading