Skip to content

Commit

Permalink
Uniform TLS verify argument support. (#1027)
Browse files Browse the repository at this point in the history
* Uniform TLS verify argument support.

* async TLS should get verify too
  • Loading branch information
rthalley authored Dec 28, 2023
1 parent 1e7389f commit 609d6b2
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 13 deletions.
10 changes: 4 additions & 6 deletions dns/asyncquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
UDPMode,
_compute_times,
_have_http2,
_make_dot_ssl_context,
_matches_destination,
_remaining,
have_doh,
Expand Down Expand Up @@ -297,7 +298,7 @@ async def send_tcp(
# copying the wire into tcpmsg is inefficient, but lets us
# avoid writev() or doing a short write that would get pushed
# onto the net
tcpmsg = len(what).to_bytes(2, 'big') + what
tcpmsg = len(what).to_bytes(2, "big") + what
sent_time = time.time()
await sock.sendall(tcpmsg, _timeout(expiration, sent_time))
return (len(tcpmsg), sent_time)
Expand Down Expand Up @@ -416,6 +417,7 @@ async def tls(
backend: Optional[dns.asyncbackend.Backend] = None,
ssl_context: Optional[ssl.SSLContext] = None,
server_hostname: Optional[str] = None,
verify: Union[bool, str] = True,
) -> dns.message.Message:
"""Return the response obtained after sending a query via TLS.
Expand All @@ -437,11 +439,7 @@ async def tls(
cm: contextlib.AbstractAsyncContextManager = NullContext(sock)
else:
if ssl_context is None:
# See the comment about ssl.create_default_context() in query.py
ssl_context = ssl.create_default_context() # lgtm[py/insecure-protocol]
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
if server_hostname is None:
ssl_context.check_hostname = False
ssl_context = _make_dot_ssl_context(server_hostname, verify)
af = dns.inet.af_for_address(where)
stuple = _source_tuple(af, source, source_port)
dtuple = (where, port)
Expand Down
21 changes: 19 additions & 2 deletions dns/nameserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,16 @@ async def async_query(


class DoHNameserver(Nameserver):
def __init__(self, url: str, bootstrap_address: Optional[str] = None):
def __init__(
self,
url: str,
bootstrap_address: Optional[str] = None,
verify: Union[bool, str] = True,
):
super().__init__()
self.url = url
self.bootstrap_address = bootstrap_address
self.verify = verify

def kind(self):
return "DoH"
Expand Down Expand Up @@ -198,6 +204,7 @@ def query(
bootstrap_address=self.bootstrap_address,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
verify=self.verify,
)

async def async_query(
Expand All @@ -218,13 +225,21 @@ async def async_query(
bootstrap_address=self.bootstrap_address,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
verify=self.verify,
)


class DoTNameserver(AddressAndPortNameserver):
def __init__(self, address: str, port: int = 853, hostname: Optional[str] = None):
def __init__(
self,
address: str,
port: int = 853,
hostname: Optional[str] = None,
verify: Union[bool, str] = True,
):
super().__init__(address, port)
self.hostname = hostname
self.verify = verify

def kind(self):
return "DoT"
Expand All @@ -247,6 +262,7 @@ def query(
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
server_hostname=self.hostname,
verify=self.verify,
)

async def async_query(
Expand All @@ -268,6 +284,7 @@ async def async_query(
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
server_hostname=self.hostname,
verify=self.verify,
)


Expand Down
37 changes: 32 additions & 5 deletions dns/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import enum
import errno
import os
import os.path
import selectors
import socket
import struct
Expand Down Expand Up @@ -161,6 +162,8 @@ def connect_tcp(self, host, port, timeout, local_address):
except ImportError: # pragma: no cover

class ssl: # type: ignore
CERT_NONE = 0

class WantReadException(Exception):
pass

Expand Down Expand Up @@ -1012,6 +1015,28 @@ def _tls_handshake(s, expiration):
_wait_for_writable(s, expiration)


def _make_dot_ssl_context(
server_hostname: Optional[str], verify: Union[bool, str]
) -> ssl.SSLContext:
cafile: Optional[str] = None
capath: Optional[str] = None
if isinstance(verify, str):
if os.path.isfile(verify):
cafile = verify
elif os.path.isdir(verify):
capath = verify
else:
raise ValueError("invalid verify string")
ssl_context = ssl.create_default_context(cafile=cafile, capath=capath)
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
if server_hostname is None:
ssl_context.check_hostname = False
ssl_context.set_alpn_protocols(["dot"])
if verify is False:
ssl_context.verify_mode = ssl.CERT_NONE
return ssl_context


def tls(
q: dns.message.Message,
where: str,
Expand All @@ -1024,6 +1049,7 @@ def tls(
sock: Optional[ssl.SSLSocket] = None,
ssl_context: Optional[ssl.SSLContext] = None,
server_hostname: Optional[str] = None,
verify: Union[bool, str] = True,
) -> dns.message.Message:
"""Return the response obtained after sending a query via TLS.
Expand Down Expand Up @@ -1063,6 +1089,11 @@ def tls(
default is ``None``, which means that no hostname is known, and if an
SSL context is created, hostname checking will be disabled.
*verify*, a ``bool`` or ``str``. If a ``True``, then TLS certificate verification
of the server is done using the default CA bundle; if ``False``, then no
verification is done; if a `str` then it specifies the path to a certificate file or
directory which will be used for verification.
Returns a ``dns.message.Message``.
"""
Expand All @@ -1089,11 +1120,7 @@ def tls(
where, port, source, source_port
)
if ssl_context is None and not sock:
ssl_context = ssl.create_default_context()
ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
if server_hostname is None:
ssl_context.check_hostname = False
ssl_context.set_alpn_protocols(["dot"])
ssl_context = _make_dot_ssl_context(server_hostname, verify)

with _make_socket(
af,
Expand Down

0 comments on commit 609d6b2

Please sign in to comment.