Skip to content

Commit

Permalink
DNS signals implementation aio-libs#2313
Browse files Browse the repository at this point in the history
  • Loading branch information
pfreixes committed Oct 25, 2017
1 parent 7f4471e commit 06e987d
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 12 deletions.
20 changes: 20 additions & 0 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,26 @@ def on_request_createconn_start(self):
def on_request_createconn_end(self):
return self._connector.on_createconn_end

@property
def on_request_reuseconn(self):
return self._connector.on_reuseconn

@property
def on_request_resolvehost_start(self):
return self._connector.on_resolvehost_start

@property
def on_request_resolvehost_end(self):
return self._connector.on_resolvehost_end

@property
def on_request_dnscache_hit(self):
return self._connector.on_dnscache_hit

@property
def on_request_dnscache_miss(self):
return self._connector.on_dnscache_miss

# req resp signals

@property
Expand Down
71 changes: 60 additions & 11 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def __init__(self, *, keepalive_timeout=sentinel,
self._on_queued_end = FuncSignal()
self._on_createconn_start = FuncSignal()
self._on_createconn_end = FuncSignal()
self._on_reuseconn = FuncSignal()

def __del__(self, _warnings=warnings):
if self._closed:
Expand Down Expand Up @@ -407,7 +408,10 @@ def connect(self, req, trace_context=None):
self.on_createconn_start.send(trace_context)

try:
proto = yield from self._create_connection(req)
proto = yield from self._create_connection(
req,
trace_context=trace_context
)
if self._closed:
proto.close()
raise ClientConnectionError("Connector is closed.")
Expand All @@ -424,6 +428,8 @@ def connect(self, req, trace_context=None):
self._acquired_per_host[key].remove(placeholder)

self.on_createconn_end.send(trace_context)
else:
self.on_reuseconn.send(trace_context)

self._acquired.add(proto)
self._acquired_per_host[key].add(proto)
Expand Down Expand Up @@ -518,7 +524,7 @@ def _release(self, key, protocol, *, should_close=False):
self, '_cleanup', self._keepalive_timeout, self._loop)

@asyncio.coroutine
def _create_connection(self, req):
def _create_connection(self, req, trace_context=None):
raise NotImplementedError()

@property
Expand All @@ -537,6 +543,10 @@ def on_createconn_start(self):
def on_createconn_end(self):
return self._on_createconn_end

@property
def on_reuseconn(self):
return self._on_reuseconn


_SSL_OP_NO_COMPRESSION = getattr(ssl, "OP_NO_COMPRESSION", 0)

Expand Down Expand Up @@ -659,6 +669,11 @@ def __init__(self, *, verify_ssl=True, fingerprint=None,
self._family = family
self._local_addr = local_addr

self._on_resolvehost_start = FuncSignal()
self._on_resolvehost_end = FuncSignal()
self._on_dnscache_hit = FuncSignal()
self._on_dnscache_miss = FuncSignal()

def close(self):
"""Close all ongoing DNS calls."""
for ev in self._throttle_dns_events.values():
Expand Down Expand Up @@ -723,32 +738,40 @@ def clear_dns_cache(self, host=None, port=None):
self._cached_hosts.clear()

@asyncio.coroutine
def _resolve_host(self, host, port):
def _resolve_host(self, host, port, trace_context=None):
if is_ip_address(host):
return [{'hostname': host, 'host': host, 'port': port,
'family': self._family, 'proto': 0, 'flags': 0}]

if not self._use_dns_cache:
return (yield from self._resolver.resolve(
self.on_resolvehost_start.send(trace_context)
res = (yield from self._resolver.resolve(
host, port, family=self._family))
self.on_resolvehost_end.send(trace_context)
return res

key = (host, port)

if (key in self._cached_hosts) and\
(not self._cached_hosts.expired(key)):
self.on_dnscache_hit.send(trace_context)
return self._cached_hosts.next_addrs(key)

if key in self._throttle_dns_events:
self.on_dnscache_hit.send(trace_context)
yield from self._throttle_dns_events[key].wait()
else:
self.on_dnscache_miss.send(trace_context)
self._throttle_dns_events[key] = \
EventResultOrError(self._loop)
try:
self.on_resolvehost_start.send(trace_context)
addrs = yield from \
asyncio.shield(self._resolver.resolve(host,
port,
family=self._family),
loop=self._loop)
self.on_resolvehost_end.send(trace_context)
self._cached_hosts.add(key, addrs)
self._throttle_dns_events[key].set()
except Exception as e:
Expand All @@ -762,15 +785,21 @@ def _resolve_host(self, host, port):
return self._cached_hosts.next_addrs(key)

@asyncio.coroutine
def _create_connection(self, req):
def _create_connection(self, req, trace_context=None):
"""Create connection.
Has same keyword arguments as BaseEventLoop.create_connection.
"""
if req.proxy:
_, proto = yield from self._create_proxy_connection(req)
_, proto = yield from self._create_proxy_connection(
req,
trace_context=None
)
else:
_, proto = yield from self._create_direct_connection(req)
_, proto = yield from self._create_direct_connection(
req,
trace_context=None
)

return proto

Expand Down Expand Up @@ -814,11 +843,15 @@ def _get_fingerprint_and_hashfunc(self, req):
return (None, None)

@asyncio.coroutine
def _create_direct_connection(self, req):
def _create_direct_connection(self, req, trace_context=None):
sslcontext = self._get_ssl_context(req)
fingerprint, hashfunc = self._get_fingerprint_and_hashfunc(req)

hosts = yield from self._resolve_host(req.url.raw_host, req.port)
hosts = yield from self._resolve_host(
req.url.raw_host,
req.port,
trace_context=trace_context
)

for hinfo in hosts:
try:
Expand Down Expand Up @@ -859,7 +892,7 @@ def _create_direct_connection(self, req):
raise ClientConnectorError(req.connection_key, exc) from exc

@asyncio.coroutine
def _create_proxy_connection(self, req):
def _create_proxy_connection(self, req, trace_context=None):
headers = {}
if req.proxy_headers is not None:
headers = req.proxy_headers
Expand Down Expand Up @@ -937,6 +970,22 @@ def _create_proxy_connection(self, req):

return transport, proto

@property
def on_resolvehost_start(self):
return self._on_resolvehost_start

@property
def on_resolvehost_end(self):
return self._on_resolvehost_end

@property
def on_dnscache_hit(self):
return self._on_dnscache_hit

@property
def on_dnscache_miss(self):
return self._on_dnscache_miss


class UnixConnector(BaseConnector):
"""Unix socket connector.
Expand Down Expand Up @@ -970,7 +1019,7 @@ def path(self):
return self._path

@asyncio.coroutine
def _create_connection(self, req):
def _create_connection(self, req, trace_context=None):
_, proto = yield from self._loop.create_unix_connection(
self._factory, self._path)
return proto
11 changes: 10 additions & 1 deletion tests/test_client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def test_reraise_os_error(create_session):
session = create_session(request_class=req_factory)

@asyncio.coroutine
def create_connection(req):
def create_connection(req, trace_context=None):
# return self.transport, self.protocol
return mock.Mock()
session._connector._create_connection = create_connection
Expand Down Expand Up @@ -585,6 +585,15 @@ def test_request_tracing_proxies_connector_signals(loop):
id(connector.on_createconn_start)
assert id(session.on_request_createconn_end) ==\
id(connector.on_createconn_end)
assert id(session.on_request_reuseconn) == id(connector.on_reuseconn)
assert id(session.on_request_resolvehost_start) ==\
id(connector.on_resolvehost_start)
assert id(session.on_request_resolvehost_end) ==\
id(connector.on_resolvehost_end)
assert id(session.on_request_dnscache_hit) ==\
id(connector.on_dnscache_hit)
assert id(session.on_request_dnscache_miss) ==\
id(connector.on_dnscache_miss)


@asyncio.coroutine
Expand Down
59 changes: 59 additions & 0 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,45 @@ def test_tcp_connector_dns_throttle_requests_cancelled_when_close(
yield from f


@asyncio.coroutine
def test_tcp_connector_dns_tracing(loop, dns_response):
trace_context = mock.Mock()
on_resolvehost_start = mock.Mock()
on_resolvehost_end = mock.Mock()
on_dnscache_hit = mock.Mock()
on_dnscache_miss = mock.Mock()

with mock.patch('aiohttp.connector.DefaultResolver') as m_resolver:
conn = aiohttp.TCPConnector(
loop=loop,
use_dns_cache=True,
ttl_dns_cache=10
)
conn.on_resolvehost_start.append(on_resolvehost_start)
conn.on_resolvehost_end.append(on_resolvehost_end)
conn.on_dnscache_hit.append(on_dnscache_hit)
conn.on_dnscache_miss.append(on_dnscache_miss)

m_resolver().resolve.return_value = dns_response()

yield from conn._resolve_host(
'localhost',
8080,
trace_context=trace_context
)
on_resolvehost_start.assert_called_once_with(trace_context)
on_resolvehost_end.assert_called_once_with(trace_context)
on_dnscache_miss.assert_called_once_with(trace_context)
assert not on_dnscache_hit.called

yield from conn._resolve_host(
'localhost',
8080,
trace_context=trace_context
)
on_dnscache_hit.assert_called_once_with(trace_context)


def test_get_pop_empty_conns(loop):
# see issue #473
conn = aiohttp.BaseConnector(loop=loop)
Expand Down Expand Up @@ -946,6 +985,26 @@ def f():
conn.close()


@asyncio.coroutine
def test_connect_reuseconn_tracing(loop, key):
proto = mock.Mock()
proto.is_connected.return_value = True
trace_context = mock.Mock()
on_reuseconn = mock.Mock()

req = ClientRequest('GET', URL('http://localhost1:80'),
loop=loop,
response_class=mock.Mock())

conn = aiohttp.BaseConnector(loop=loop, limit=1)
conn.on_reuseconn.append(on_reuseconn)
conn._conns[key] = [(proto, loop.time())]
yield from conn.connect(req, trace_context=trace_context)

on_reuseconn.assert_called_with(trace_context)
conn.close()


@asyncio.coroutine
def test_connect_with_limit_and_limit_per_host(loop, key):
proto = mock.Mock()
Expand Down

0 comments on commit 06e987d

Please sign in to comment.