Skip to content

Commit

Permalink
Fix AsyncResolver to match ThreadedResolver behavior (#8270)
Browse files Browse the repository at this point in the history
Co-authored-by: Sviatoslav Sydorenko (Святослав Сидоренко) <sviat@redhat.com>
  • Loading branch information
bdraco and webknjaz authored Apr 5, 2024
1 parent 28f1fd8 commit 012f986
Show file tree
Hide file tree
Showing 10 changed files with 334 additions and 76 deletions.
9 changes: 9 additions & 0 deletions CHANGES/8270.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Fix ``AsyncResolver`` to match ``ThreadedResolver`` behavior
-- by :user:`bdraco`.

On system with IPv6 support, the :py:class:`~aiohttp.resolver.AsyncResolver` would not fallback
to providing A records when AAAA records were not available.
Additionally, unlike the :py:class:`~aiohttp.resolver.ThreadedResolver`, the :py:class:`~aiohttp.resolver.AsyncResolver`
did not handle link-local addresses correctly.

This change makes the behavior consistent with the :py:class:`~aiohttp.resolver.ThreadedResolver`.
28 changes: 27 additions & 1 deletion aiohttp/abc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import socket
from abc import ABC, abstractmethod
from collections.abc import Sized
from http.cookies import BaseCookie, Morsel
Expand All @@ -13,6 +14,7 @@
List,
Optional,
Tuple,
TypedDict,
)

from multidict import CIMultiDict
Expand Down Expand Up @@ -117,11 +119,35 @@ def __await__(self) -> Generator[Any, None, StreamResponse]:
"""Execute the view handler."""


class ResolveResult(TypedDict):
"""Resolve result.
This is the result returned from an AbstractResolver's
resolve method.
:param hostname: The hostname that was provided.
:param host: The IP address that was resolved.
:param port: The port that was resolved.
:param family: The address family that was resolved.
:param proto: The protocol that was resolved.
:param flags: The flags that were resolved.
"""

hostname: str
host: str
port: int
family: int
proto: int
flags: int


class AbstractResolver(ABC):
"""Abstract DNS resolver."""

@abstractmethod
async def resolve(self, host: str, port: int, family: int) -> List[Dict[str, Any]]:
async def resolve(
self, host: str, port: int = 0, family: int = socket.AF_INET
) -> List[ResolveResult]:
"""Return IP address for given hostname"""

@abstractmethod
Expand Down
16 changes: 8 additions & 8 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import aiohappyeyeballs

from . import hdrs, helpers
from .abc import AbstractResolver
from .abc import AbstractResolver, ResolveResult
from .client_exceptions import (
ClientConnectionError,
ClientConnectorCertificateError,
Expand Down Expand Up @@ -674,14 +674,14 @@ async def _create_connection(

class _DNSCacheTable:
def __init__(self, ttl: Optional[float] = None) -> None:
self._addrs_rr: Dict[Tuple[str, int], Tuple[Iterator[Dict[str, Any]], int]] = {}
self._addrs_rr: Dict[Tuple[str, int], Tuple[Iterator[ResolveResult], int]] = {}
self._timestamps: Dict[Tuple[str, int], float] = {}
self._ttl = ttl

def __contains__(self, host: object) -> bool:
return host in self._addrs_rr

def add(self, key: Tuple[str, int], addrs: List[Dict[str, Any]]) -> None:
def add(self, key: Tuple[str, int], addrs: List[ResolveResult]) -> None:
self._addrs_rr[key] = (cycle(addrs), len(addrs))

if self._ttl is not None:
Expand All @@ -697,7 +697,7 @@ def clear(self) -> None:
self._addrs_rr.clear()
self._timestamps.clear()

def next_addrs(self, key: Tuple[str, int]) -> List[Dict[str, Any]]:
def next_addrs(self, key: Tuple[str, int]) -> List[ResolveResult]:
loop, length = self._addrs_rr[key]
addrs = list(islice(loop, length))
# Consume one more element to shift internal state of `cycle`
Expand Down Expand Up @@ -813,7 +813,7 @@ def clear_dns_cache(

async def _resolve_host(
self, host: str, port: int, traces: Optional[List["Trace"]] = None
) -> List[Dict[str, Any]]:
) -> List[ResolveResult]:
"""Resolve host and return list of addresses."""
if is_ip_address(host):
return [
Expand Down Expand Up @@ -868,7 +868,7 @@ async def _resolve_host(
return await asyncio.shield(resolved_host_task)
except asyncio.CancelledError:

def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
def drop_exception(fut: "asyncio.Future[List[ResolveResult]]") -> None:
with suppress(Exception, asyncio.CancelledError):
fut.result()

Expand All @@ -881,7 +881,7 @@ async def _resolve_host_with_throttle(
host: str,
port: int,
traces: Optional[List["Trace"]],
) -> List[Dict[str, Any]]:
) -> List[ResolveResult]:
"""Resolve host with a dns events throttle."""
if key in self._throttle_dns_events:
# get event early, before any await (#4014)
Expand Down Expand Up @@ -1129,7 +1129,7 @@ async def _start_tls_connection(
return tls_transport, tls_proto

def _convert_hosts_to_addr_infos(
self, hosts: List[Dict[str, Any]]
self, hosts: List[ResolveResult]
) -> List[aiohappyeyeballs.AddrInfoType]:
"""Converts the list of hosts to a list of addr_infos.
Expand Down
94 changes: 62 additions & 32 deletions aiohttp/resolver.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@
import asyncio
import socket
from typing import Any, Dict, List, Type, Union
import sys
from typing import Any, List, Tuple, Type, Union

from .abc import AbstractResolver
from .abc import AbstractResolver, ResolveResult

__all__ = ("ThreadedResolver", "AsyncResolver", "DefaultResolver")

try:
import aiodns

# aiodns_default = hasattr(aiodns.DNSResolver, 'gethostbyname')
# aiodns_default = hasattr(aiodns.DNSResolver, 'getaddrinfo')
except ImportError: # pragma: no cover
aiodns = None


aiodns_default = False

_NUMERIC_SOCKET_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV
_SUPPORTS_SCOPE_ID = sys.version_info >= (3, 9, 0)


class ThreadedResolver(AbstractResolver):
"""Threaded resolver.
Expand All @@ -27,45 +32,45 @@ def __init__(self) -> None:
self._loop = asyncio.get_running_loop()

async def resolve(
self, hostname: str, port: int = 0, family: int = socket.AF_INET
) -> List[Dict[str, Any]]:
self, host: str, port: int = 0, family: int = socket.AF_INET
) -> List[ResolveResult]:
infos = await self._loop.getaddrinfo(
hostname,
host,
port,
type=socket.SOCK_STREAM,
family=family,
flags=socket.AI_ADDRCONFIG,
)

hosts = []
hosts: List[ResolveResult] = []
for family, _, proto, _, address in infos:
if family == socket.AF_INET6:
if len(address) < 3:
# IPv6 is not supported by Python build,
# or IPv6 is not enabled in the host
continue
if address[3]:
if address[3] and _SUPPORTS_SCOPE_ID:
# This is essential for link-local IPv6 addresses.
# LL IPv6 is a VERY rare case. Strictly speaking, we should use
# getnameinfo() unconditionally, but performance makes sense.
host, _port = socket.getnameinfo(
address, socket.NI_NUMERICHOST | socket.NI_NUMERICSERV
resolved_host, _port = await self._loop.getnameinfo(
address, _NUMERIC_SOCKET_FLAGS
)
port = int(_port)
else:
host, port = address[:2]
resolved_host, port = address[:2]
else: # IPv4
assert family == socket.AF_INET
host, port = address # type: ignore[misc]
resolved_host, port = address # type: ignore[misc]
hosts.append(
{
"hostname": hostname,
"host": host,
"port": port,
"family": family,
"proto": proto,
"flags": socket.AI_NUMERICHOST | socket.AI_NUMERICSERV,
}
ResolveResult(
hostname=host,
host=resolved_host,
port=port,
family=family,
proto=proto,
flags=_NUMERIC_SOCKET_FLAGS,
)
)

return hosts
Expand All @@ -86,23 +91,48 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:

async def resolve(
self, host: str, port: int = 0, family: int = socket.AF_INET
) -> List[Dict[str, Any]]:
) -> List[ResolveResult]:
try:
resp = await self._resolver.gethostbyname(host, family)
resp = await self._resolver.getaddrinfo(
host,
port=port,
type=socket.SOCK_STREAM,
family=family,
flags=socket.AI_ADDRCONFIG,
)
except aiodns.error.DNSError as exc:
msg = exc.args[1] if len(exc.args) >= 1 else "DNS lookup failed"
raise OSError(msg) from exc
hosts = []
for address in resp.addresses:
hosts: List[ResolveResult] = []
for node in resp.nodes:
address: Union[Tuple[bytes, int], Tuple[bytes, int, int, int]] = node.addr
family = node.family
if family == socket.AF_INET6:
if len(address) > 3 and address[3] and _SUPPORTS_SCOPE_ID:
# This is essential for link-local IPv6 addresses.
# LL IPv6 is a VERY rare case. Strictly speaking, we should use
# getnameinfo() unconditionally, but performance makes sense.
result = await self._resolver.getnameinfo(
(address[0].decode("ascii"), *address[1:]),
_NUMERIC_SOCKET_FLAGS,
)
resolved_host = result.node
else:
resolved_host = address[0].decode("ascii")
port = address[1]
else: # IPv4
assert family == socket.AF_INET
resolved_host = address[0].decode("ascii")
port = address[1]
hosts.append(
{
"hostname": host,
"host": address,
"port": port,
"family": family,
"proto": 0,
"flags": socket.AI_NUMERICHOST | socket.AI_NUMERICSERV,
}
ResolveResult(
hostname=host,
host=resolved_host,
port=port,
family=family,
proto=0,
flags=_NUMERIC_SOCKET_FLAGS,
)
)

if not hosts:
Expand Down
54 changes: 54 additions & 0 deletions docs/abc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,57 @@ Abstract Access Logger
:param response: :class:`aiohttp.web.Response` object.

:param float time: Time taken to serve the request.


Abstract Resolver
-------------------------------

.. class:: AbstractResolver

An abstract class, base for all resolver implementations.

Method ``resolve`` should be overridden.

.. method:: resolve(host, port, family)

Resolve host name to IP address.

:param str host: host name to resolve.

:param int port: port number.

:param int family: socket family.

:return: list of :class:`aiohttp.abc.ResolveResult` instances.

.. method:: close()

Release resolver.

.. class:: ResolveResult

Result of host name resolution.

.. attribute:: hostname

The host name that was provided.

.. attribute:: host

The IP address that was resolved.

.. attribute:: port

The port that was resolved.

.. attribute:: family

The address family that was resolved.

.. attribute:: proto

The protocol that was resolved.

.. attribute:: flags

The flags that were resolved.
3 changes: 2 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,8 @@
("py:class", "aiohttp.protocol.HttpVersion"), # undocumented
("py:class", "aiohttp.ClientRequest"), # undocumented
("py:class", "aiohttp.payload.Payload"), # undocumented
("py:class", "aiohttp.abc.AbstractResolver"), # undocumented
("py:class", "aiohttp.resolver.AsyncResolver"), # undocumented
("py:class", "aiohttp.resolver.ThreadedResolver"), # undocumented
("py:func", "aiohttp.ws_connect"), # undocumented
("py:meth", "start"), # undocumented
("py:exc", "aiohttp.ClientHttpProxyError"), # undocumented
Expand Down
6 changes: 3 additions & 3 deletions examples/fake_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import pathlib
import socket
import ssl
from typing import Any, Dict, List, Union
from typing import Dict, List, Union

from aiohttp import ClientSession, TCPConnector, resolver, test_utils, web
from aiohttp.abc import AbstractResolver
from aiohttp.abc import AbstractResolver, ResolveResult


class FakeResolver(AbstractResolver):
Expand All @@ -22,7 +22,7 @@ async def resolve(
host: str,
port: int = 0,
family: Union[socket.AddressFamily, int] = socket.AF_INET,
) -> List[Dict[str, Any]]:
) -> List[ResolveResult]:
fake_port = self._fakes.get(host)
if fake_port is not None:
return [
Expand Down
2 changes: 1 addition & 1 deletion requirements/runtime-deps.in
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Extracted from `setup.cfg` via `make sync-direct-runtime-deps`

aiodns >= 1.1; sys_platform=="linux" or sys_platform=="darwin"
aiodns >= 3.2.0; sys_platform=="linux" or sys_platform=="darwin"
aiohappyeyeballs >= 2.3.0
aiosignal >= 1.1.2
async-timeout >= 4.0, < 5.0 ; python_version < "3.11"
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ install_requires =
[options.extras_require]
speedups =
# required c-ares (aiodns' backend) will not build on windows
aiodns >= 1.1; sys_platform=="linux" or sys_platform=="darwin"
aiodns >= 3.2.0; sys_platform=="linux" or sys_platform=="darwin"
Brotli; platform_python_implementation == 'CPython'
brotlicffi; platform_python_implementation != 'CPython'

Expand Down
Loading

0 comments on commit 012f986

Please sign in to comment.