Skip to content

Commit

Permalink
use a mock.patch instead of a raw socket in test_web_preload(_worker)
Browse files Browse the repository at this point in the history
  • Loading branch information
graingert committed Apr 7, 2022
1 parent f7c465d commit 70ca320
Showing 1 changed file with 30 additions and 48 deletions.
78 changes: 30 additions & 48 deletions distributed/tests/test_preload.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
import asyncio
import os
import re
import shutil
import socket
import sys
import tempfile
from textwrap import dedent
from unittest import mock

import pytest

import dask

from distributed import Client, Nanny, Scheduler, Worker
from distributed.compatibility import to_thread
from distributed.utils_test import captured_logger, cluster, gen_cluster, gen_test

PRELOAD_TEXT = """
Expand Down Expand Up @@ -157,38 +155,30 @@ async def test_preload_import_time(cleanup):

@gen_test()
async def test_web_preload():
def preload():
with server_sock.accept()[0] as client_sock, client_sock.makefile("rwb") as f:
assert f.readline() == b"GET /preload HTTP/1.1\r\n"
f.write(
b"HTTP/1.1 200 OK"
b"\r\nContent-Length: 53"
b"\r\n\r\n"
b"def dask_setup(dask_server):"
b"\n dask_server.foo = 1"
b"\n"
)

async def test():
port = server_sock.getsockname()[1]
with captured_logger("distributed.preloading") as log:
async with Scheduler(
host="localhost",
preload=[f"http://127.0.0.1:{port}/preload"],
) as s:
assert s.foo == 1
with mock.patch(
"urllib3.PoolManager.request",
**{
"return_value.data": b"def dask_setup(dask_server):"
b"\n dask_server.foo = 1"
b"\n"
},
) as request, captured_logger("distributed.preloading") as log:
async with Scheduler(
host="localhost", preload=["http://example.com/preload"]
) as s:
assert s.foo == 1
assert (
re.match(
r"(?s).*Downloading preload at http://127.0.0.1:\d+/preload\n"
r".*Run preload setup function: http://127.0.0.1:\d+/preload\n"
r"(?s).*Downloading preload at http://example.com/preload\n"
r".*Run preload setup function: http://example.com/preload\n"
r".*",
log.getvalue(),
)
is not None
)

with socket.create_server(("127.0.0.1", 0)) as server_sock:
await asyncio.gather(to_thread(preload), test())
assert request.mock_calls == [
mock.call(method="GET", url="http://example.com/preload", retries=mock.ANY)
]


@gen_cluster(nthreads=[])
Expand All @@ -213,28 +203,20 @@ async def test_scheduler_startup_nanny(s):

@gen_test()
async def test_web_preload_worker():
def preload():
with server_sock.accept()[0] as client_sock, client_sock.makefile("rwb") as f:
assert f.readline() == b"GET /preload HTTP/1.1\r\n"
f.write(
b"HTTP/1.1 200 OK"
b"\r\nContent-Length: 70"
b"\r\n\r\n"
b"import dask"
b'\ndask.config.set(scheduler_address="tcp://127.0.0.1:8786")'
b"\n"
)

async def test():
port = server_sock.getsockname()[1]
with mock.patch(
"urllib3.PoolManager.request",
**{
"return_value.data": b"import dask"
b'\ndask.config.set(scheduler_address="tcp://127.0.0.1:8786")'
b"\n"
},
) as request:
async with Scheduler(port=8786, host="localhost") as s:
async with Nanny(
preload_nanny=[f"http://127.0.0.1:{port}/preload"]
) as nanny:
async with Nanny(preload_nanny=["http://example.com/preload"]) as nanny:
assert nanny.scheduler_addr == s.address

with socket.create_server(("127.0.0.1", 0)) as server_sock:
await asyncio.gather(to_thread(preload), test())
assert request.mock_calls == [
mock.call(method="GET", url="http://example.com/preload", retries=mock.ANY)
]


# This test is blocked on https://github.com/dask/distributed/issues/5819
Expand Down

0 comments on commit 70ca320

Please sign in to comment.