Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix AsyncResolver to match ThreadedResolver behavior #8270

Merged
merged 46 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
a6b09c3
Fix AsyncResolver to match ThreadedResolver behavior
bdraco Mar 30, 2024
d512342
fixes
bdraco Mar 30, 2024
13ff9cf
fix tests
bdraco Mar 30, 2024
28d3651
fix tests
bdraco Mar 30, 2024
b17c087
use async api
bdraco Mar 30, 2024
67e99e7
avoid looped runtime construction of enums
bdraco Mar 30, 2024
763f26e
typing
bdraco Mar 30, 2024
1c22202
match actual signature
bdraco Mar 30, 2024
8fbc80d
match actual signature
bdraco Mar 30, 2024
692b599
match actual signature
bdraco Mar 30, 2024
88bd82c
match actual signature
bdraco Mar 30, 2024
8aec247
match actual signature
bdraco Mar 30, 2024
f90cbab
typing
bdraco Mar 30, 2024
19d9019
link local
bdraco Mar 30, 2024
23824d9
link local threaded
bdraco Mar 30, 2024
7c4d0b1
preen
bdraco Mar 30, 2024
2d2d33d
Merge remote-tracking branch 'upstream/master' into aiodns_fixes
bdraco Mar 30, 2024
c05fea7
typing
bdraco Mar 30, 2024
4c858a6
typing
bdraco Mar 30, 2024
71d96b3
>=3.9.0 required for scope_id
bdraco Mar 31, 2024
ea7cb59
remove unreachable code
bdraco Mar 31, 2024
68cddb1
remove unreachable code
bdraco Mar 31, 2024
97a3ea9
remove unreachable code
bdraco Mar 31, 2024
10ee0e1
Bump aiodns to 3.2.0+
bdraco Mar 31, 2024
c8527a6
changes
bdraco Mar 31, 2024
df90484
changes
bdraco Mar 31, 2024
420e170
changes
bdraco Mar 31, 2024
8ed7095
Update CHANGES/8270.bugfix.rst
bdraco Mar 31, 2024
d2db4b0
adjust
bdraco Mar 31, 2024
a909646
Merge remote-tracking branch 'upstream/aiodns_fixes' into aiodns_fixes
bdraco Mar 31, 2024
c61211c
fix typo
bdraco Mar 31, 2024
dd7bd22
missed some
bdraco Mar 31, 2024
4f49265
tweak changes
bdraco Mar 31, 2024
3d5175f
resolvers are currently not documented
bdraco Mar 31, 2024
aa1f2e5
Update docs/conf.py
bdraco Mar 31, 2024
4347ba2
fixes from manual testing - fix tuple construction
bdraco Mar 31, 2024
e2974d5
Merge remote-tracking branch 'upstream/aiodns_fixes' into aiodns_fixes
bdraco Mar 31, 2024
9ca6a96
add Abstract Resolver
bdraco Mar 31, 2024
9ea1eb4
no autoclass
bdraco Mar 31, 2024
f9120e7
no autoclass
bdraco Mar 31, 2024
a1e592d
missing .
bdraco Mar 31, 2024
4606536
spelling
bdraco Mar 31, 2024
ce0f58a
Merge branch 'master' into aiodns_fixes
bdraco Mar 31, 2024
0516042
Update conf.py
bdraco Mar 31, 2024
009e3c6
Merge branch 'master' into aiodns_fixes
bdraco Apr 2, 2024
c3e3f40
Merge branch 'master' into aiodns_fixes
bdraco Apr 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion aiohttp/abc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import socket
from abc import ABC, abstractmethod
from collections.abc import Sized
from http.cookies import BaseCookie, Morsel
Expand All @@ -13,6 +14,7 @@
List,
Optional,
Tuple,
TypedDict,
)

from multidict import CIMultiDict
Expand Down Expand Up @@ -117,11 +119,23 @@ def __await__(self) -> Generator[Any, None, StreamResponse]:
"""Execute the view handler."""


class ResolveResult(TypedDict):
bdraco marked this conversation as resolved.
Show resolved Hide resolved

hostname: str
host: str
port: int
family: int
proto: int
flags: int


class AbstractResolver(ABC):
"""Abstract DNS resolver."""

@abstractmethod
async def resolve(self, host: str, port: int, family: int) -> List[Dict[str, Any]]:
async def resolve(
self, host: str, port: int = 0, family: int = socket.AF_INET
) -> List[ResolveResult]:
"""Return IP address for given hostname"""

@abstractmethod
Expand Down
16 changes: 8 additions & 8 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import aiohappyeyeballs

from . import hdrs, helpers
from .abc import AbstractResolver
from .abc import AbstractResolver, ResolveResult
from .client_exceptions import (
ClientConnectionError,
ClientConnectorCertificateError,
Expand Down Expand Up @@ -674,14 +674,14 @@ async def _create_connection(

class _DNSCacheTable:
def __init__(self, ttl: Optional[float] = None) -> None:
self._addrs_rr: Dict[Tuple[str, int], Tuple[Iterator[Dict[str, Any]], int]] = {}
self._addrs_rr: Dict[Tuple[str, int], Tuple[Iterator[ResolveResult], int]] = {}
self._timestamps: Dict[Tuple[str, int], float] = {}
self._ttl = ttl

def __contains__(self, host: object) -> bool:
return host in self._addrs_rr

def add(self, key: Tuple[str, int], addrs: List[Dict[str, Any]]) -> None:
def add(self, key: Tuple[str, int], addrs: List[ResolveResult]) -> None:
self._addrs_rr[key] = (cycle(addrs), len(addrs))

if self._ttl is not None:
Expand All @@ -697,7 +697,7 @@ def clear(self) -> None:
self._addrs_rr.clear()
self._timestamps.clear()

def next_addrs(self, key: Tuple[str, int]) -> List[Dict[str, Any]]:
def next_addrs(self, key: Tuple[str, int]) -> List[ResolveResult]:
loop, length = self._addrs_rr[key]
addrs = list(islice(loop, length))
# Consume one more element to shift internal state of `cycle`
Expand Down Expand Up @@ -813,7 +813,7 @@ def clear_dns_cache(

async def _resolve_host(
self, host: str, port: int, traces: Optional[List["Trace"]] = None
) -> List[Dict[str, Any]]:
) -> List[ResolveResult]:
"""Resolve host and return list of addresses."""
if is_ip_address(host):
return [
Expand Down Expand Up @@ -868,7 +868,7 @@ async def _resolve_host(
return await asyncio.shield(resolved_host_task)
except asyncio.CancelledError:

def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
def drop_exception(fut: "asyncio.Future[List[ResolveResult]]") -> None:
with suppress(Exception, asyncio.CancelledError):
fut.result()

Expand All @@ -881,7 +881,7 @@ async def _resolve_host_with_throttle(
host: str,
port: int,
traces: Optional[List["Trace"]],
) -> List[Dict[str, Any]]:
) -> List[ResolveResult]:
"""Resolve host with a dns events throttle."""
if key in self._throttle_dns_events:
# get event early, before any await (#4014)
Expand Down Expand Up @@ -1129,7 +1129,7 @@ async def _start_tls_connection(
return tls_transport, tls_proto

def _convert_hosts_to_addr_infos(
self, hosts: List[Dict[str, Any]]
self, hosts: List[ResolveResult]
) -> List[aiohappyeyeballs.AddrInfoType]:
"""Converts the list of hosts to a list of addr_infos.

Expand Down
91 changes: 61 additions & 30 deletions aiohttp/resolver.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import asyncio
import socket
from typing import Any, Dict, List, Type, Union
from typing import Any, List, Tuple, Type, Union

from .abc import AbstractResolver
from .abc import AbstractResolver, ResolveResult

__all__ = ("ThreadedResolver", "AsyncResolver", "DefaultResolver")

Expand All @@ -13,8 +13,11 @@
except ImportError: # pragma: no cover
aiodns = None


aiodns_default = False

_NUMERIC_SOCKET_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV


class ThreadedResolver(AbstractResolver):
"""Threaded resolver.
Expand All @@ -27,17 +30,17 @@ def __init__(self) -> None:
self._loop = asyncio.get_running_loop()

async def resolve(
self, hostname: str, port: int = 0, family: int = socket.AF_INET
) -> List[Dict[str, Any]]:
self, host: str, port: int = 0, family: int = socket.AF_INET
) -> List[ResolveResult]:
infos = await self._loop.getaddrinfo(
hostname,
host,
port,
type=socket.SOCK_STREAM,
family=family,
flags=socket.AI_ADDRCONFIG,
)

hosts = []
hosts: List[ResolveResult] = []
for family, _, proto, _, address in infos:
if family == socket.AF_INET6:
if len(address) < 3:
Expand All @@ -48,24 +51,24 @@ async def resolve(
# This is essential for link-local IPv6 addresses.
# LL IPv6 is a VERY rare case. Strictly speaking, we should use
# getnameinfo() unconditionally, but performance makes sense.
host, _port = socket.getnameinfo(
address, socket.NI_NUMERICHOST | socket.NI_NUMERICSERV
resolved_host, _port = await self._loop.getnameinfo(
bdraco marked this conversation as resolved.
Show resolved Hide resolved
address, _NUMERIC_SOCKET_FLAGS
)
port = int(_port)
else:
host, port = address[:2]
resolved_host, port = address[:2]
else: # IPv4
assert family == socket.AF_INET
host, port = address # type: ignore[misc]
resolved_host, port = address # type: ignore[misc]
hosts.append(
{
"hostname": hostname,
"host": host,
"port": port,
"family": family,
"proto": proto,
"flags": socket.AI_NUMERICHOST | socket.AI_NUMERICSERV,
}
ResolveResult(
hostname=host,
host=resolved_host,
port=port,
family=family,
proto=proto,
flags=_NUMERIC_SOCKET_FLAGS,
)
)

return hosts
Expand All @@ -86,23 +89,51 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:

async def resolve(
self, host: str, port: int = 0, family: int = socket.AF_INET
) -> List[Dict[str, Any]]:
) -> List[ResolveResult]:
try:
resp = await self._resolver.gethostbyname(host, family)
resp = await self._resolver.getaddrinfo(
host,
port=port,
type=socket.SOCK_STREAM,
family=family,
flags=socket.AI_ADDRCONFIG,
)
except aiodns.error.DNSError as exc:
msg = exc.args[1] if len(exc.args) >= 1 else "DNS lookup failed"
raise OSError(msg) from exc
hosts = []
for address in resp.addresses:
hosts: List[ResolveResult] = []
for node in resp.nodes:
address: Union[Tuple[bytes, int], Tuple[bytes, int, int, int]] = node.addr
family = node.family
if family == socket.AF_INET6:
if len(address) < 3:
# IPv6 is not supported by Python build,
# or IPv6 is not enabled in the host
continue
if address[3]:
# This is essential for link-local IPv6 addresses.
# LL IPv6 is a VERY rare case. Strictly speaking, we should use
# getnameinfo() unconditionally, but performance makes sense.
result = await self._resolver.getnameinfo(
address[0].decode("ascii"), *address[1:], _NUMERIC_SOCKET_FLAGS
)
resolved_host = result.node
else:
resolved_host = address[0].decode("ascii")
port = address[1]
else: # IPv4
assert family == socket.AF_INET
resolved_host = address[0].decode("ascii")
port = address[1]
hosts.append(
{
"hostname": host,
"host": address,
"port": port,
"family": family,
"proto": 0,
"flags": socket.AI_NUMERICHOST | socket.AI_NUMERICSERV,
}
ResolveResult(
hostname=host,
host=resolved_host,
port=port,
family=family,
proto=0,
flags=_NUMERIC_SOCKET_FLAGS,
)
)

if not hosts:
Expand Down
6 changes: 3 additions & 3 deletions examples/fake_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import pathlib
import socket
import ssl
from typing import Any, Dict, List, Union
from typing import Dict, List, Union

from aiohttp import ClientSession, TCPConnector, resolver, test_utils, web
from aiohttp.abc import AbstractResolver
from aiohttp.abc import AbstractResolver, ResolveResult


class FakeResolver(AbstractResolver):
Expand All @@ -22,7 +22,7 @@ async def resolve(
host: str,
port: int = 0,
family: Union[socket.AddressFamily, int] = socket.AF_INET,
) -> List[Dict[str, Any]]:
) -> List[ResolveResult]:
fake_port = self._fakes.get(host)
if fake_port is not None:
return [
Expand Down
Loading
Loading