Skip to content

Commit

Permalink
feat: speed up parsing incoming records (#1458)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Jan 6, 2025
1 parent 9f6af54 commit 783c1b3
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 26 deletions.
20 changes: 18 additions & 2 deletions src/zeroconf/_dns.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,25 @@ cdef class DNSEntry:
cdef public cython.uint class_
cdef public bint unique

cdef _set_class(self, cython.uint class_)
cdef _fast_init_entry(self, str name, cython.uint type_, cython.uint class_)

cdef bint _dns_entry_matches(self, DNSEntry other)

cdef class DNSQuestion(DNSEntry):

cdef public cython.int _hash

cdef _fast_init(self, str name, cython.uint type_, cython.uint class_)

cpdef bint answered_by(self, DNSRecord rec)

cdef class DNSRecord(DNSEntry):

cdef public cython.float ttl
cdef public double created

cdef _fast_init_record(self, str name, cython.uint type_, cython.uint class_, cython.float ttl, double created)

cdef bint _suppressed_by_answer(self, DNSRecord answer)

@cython.locals(
Expand All @@ -69,9 +73,11 @@ cdef class DNSRecord(DNSEntry):
cdef class DNSAddress(DNSRecord):

cdef public cython.int _hash
cdef public object address
cdef public bytes address
cdef public object scope_id

cdef _fast_init(self, str name, cython.uint type_, cython.uint class_, cython.float ttl, bytes address, object scope_id, double created)

cdef bint _eq(self, DNSAddress other)

cpdef write(self, DNSOutgoing out)
Expand All @@ -83,6 +89,8 @@ cdef class DNSHinfo(DNSRecord):
cdef public str cpu
cdef public str os

cdef _fast_init(self, str name, cython.uint type_, cython.uint class_, cython.float ttl, str cpu, str os, double created)

cdef bint _eq(self, DNSHinfo other)

cpdef write(self, DNSOutgoing out)
Expand All @@ -93,6 +101,8 @@ cdef class DNSPointer(DNSRecord):
cdef public str alias
cdef public str alias_key

cdef _fast_init(self, str name, cython.uint type_, cython.uint class_, cython.float ttl, str alias, double created)

cdef bint _eq(self, DNSPointer other)

cpdef write(self, DNSOutgoing out)
Expand All @@ -102,6 +112,8 @@ cdef class DNSText(DNSRecord):
cdef public cython.int _hash
cdef public bytes text

cdef _fast_init(self, str name, cython.uint type_, cython.uint class_, cython.float ttl, bytes text, double created)

cdef bint _eq(self, DNSText other)

cpdef write(self, DNSOutgoing out)
Expand All @@ -115,6 +127,8 @@ cdef class DNSService(DNSRecord):
cdef public str server
cdef public str server_key

cdef _fast_init(self, str name, cython.uint type_, cython.uint class_, cython.float ttl, cython.uint priority, cython.uint weight, cython.uint port, str server, double created)

cdef bint _eq(self, DNSService other)

cpdef write(self, DNSOutgoing out)
Expand All @@ -125,6 +139,8 @@ cdef class DNSNsec(DNSRecord):
cdef public object next_name
cdef public cython.list rdtypes

cdef _fast_init(self, str name, cython.uint type_, cython.uint class_, cython.float ttl, str next_name, cython.list rdtypes, double created)

cdef bint _eq(self, DNSNsec other)

cpdef write(self, DNSOutgoing out)
Expand Down
92 changes: 79 additions & 13 deletions src/zeroconf/_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,13 @@ class DNSEntry:
__slots__ = ("class_", "key", "name", "type", "unique")

def __init__(self, name: str, type_: int, class_: int) -> None:
self._fast_init_entry(name, type_, class_)

def _fast_init_entry(self, name: str, type_: _int, class_: _int) -> None:
"""Fast init for reuse."""
self.name = name
self.key = name.lower()
self.type = type_
self._set_class(class_)

def _set_class(self, class_: _int) -> None:
self.class_ = class_ & _CLASS_MASK
self.unique = (class_ & _CLASS_UNIQUE) != 0

Expand Down Expand Up @@ -111,7 +112,11 @@ class DNSQuestion(DNSEntry):
__slots__ = ("_hash",)

def __init__(self, name: str, type_: int, class_: int) -> None:
super().__init__(name, type_, class_)
self._fast_init(name, type_, class_)

def _fast_init(self, name: str, type_: _int, class_: _int) -> None:
"""Fast init for reuse."""
self._fast_init_entry(name, type_, class_)
self._hash = hash((self.key, type_, self.class_))

def answered_by(self, rec: "DNSRecord") -> bool:
Expand Down Expand Up @@ -168,9 +173,13 @@ def __init__(
ttl: Union[float, int],
created: Optional[float] = None,
) -> None:
super().__init__(name, type_, class_)
self._fast_init_record(name, type_, class_, ttl, created or current_time_millis())

def _fast_init_record(self, name: str, type_: _int, class_: _int, ttl: _float, created: _float) -> None:
"""Fast init for reuse."""
self._fast_init_entry(name, type_, class_)
self.ttl = ttl
self.created = created or current_time_millis()
self.created = created

def __eq__(self, other: Any) -> bool: # pylint: disable=no-self-use
"""Abstract method"""
Expand Down Expand Up @@ -248,7 +257,20 @@ def __init__(
scope_id: Optional[int] = None,
created: Optional[float] = None,
) -> None:
super().__init__(name, type_, class_, ttl, created)
self._fast_init(name, type_, class_, ttl, address, scope_id, created or current_time_millis())

def _fast_init(
self,
name: str,
type_: _int,
class_: _int,
ttl: _float,
address: bytes,
scope_id: Optional[_int],
created: _float,
) -> None:
"""Fast init for reuse."""
self._fast_init_record(name, type_, class_, ttl, created)
self.address = address
self.scope_id = scope_id
self._hash = hash((self.key, type_, self.class_, address, scope_id))
Expand Down Expand Up @@ -300,7 +322,13 @@ def __init__(
os: str,
created: Optional[float] = None,
) -> None:
super().__init__(name, type_, class_, ttl, created)
self._fast_init(name, type_, class_, ttl, cpu, os, created or current_time_millis())

def _fast_init(
self, name: str, type_: _int, class_: _int, ttl: _float, cpu: str, os: str, created: _float
) -> None:
"""Fast init for reuse."""
self._fast_init_record(name, type_, class_, ttl, created)
self.cpu = cpu
self.os = os
self._hash = hash((self.key, type_, self.class_, cpu, os))
Expand Down Expand Up @@ -341,7 +369,12 @@ def __init__(
alias: str,
created: Optional[float] = None,
) -> None:
super().__init__(name, type_, class_, ttl, created)
self._fast_init(name, type_, class_, ttl, alias, created or current_time_millis())

def _fast_init(
self, name: str, type_: _int, class_: _int, ttl: _float, alias: str, created: _float
) -> None:
self._fast_init_record(name, type_, class_, ttl, created)
self.alias = alias
self.alias_key = alias.lower()
self._hash = hash((self.key, type_, self.class_, self.alias_key))
Expand Down Expand Up @@ -391,7 +424,12 @@ def __init__(
text: bytes,
created: Optional[float] = None,
) -> None:
super().__init__(name, type_, class_, ttl, created)
self._fast_init(name, type_, class_, ttl, text, created or current_time_millis())

def _fast_init(
self, name: str, type_: _int, class_: _int, ttl: _float, text: bytes, created: _float
) -> None:
self._fast_init_record(name, type_, class_, ttl, created)
self.text = text
self._hash = hash((self.key, type_, self.class_, text))

Expand Down Expand Up @@ -435,7 +473,23 @@ def __init__(
server: str,
created: Optional[float] = None,
) -> None:
super().__init__(name, type_, class_, ttl, created)
self._fast_init(
name, type_, class_, ttl, priority, weight, port, server, created or current_time_millis()
)

def _fast_init(
self,
name: str,
type_: _int,
class_: _int,
ttl: _float,
priority: _int,
weight: _int,
port: _int,
server: str,
created: _float,
) -> None:
self._fast_init_record(name, type_, class_, ttl, created)
self.priority = priority
self.weight = weight
self.port = port
Expand Down Expand Up @@ -483,12 +537,24 @@ def __init__(
name: str,
type_: int,
class_: int,
ttl: int,
ttl: Union[int, float],
next_name: str,
rdtypes: List[int],
created: Optional[float] = None,
) -> None:
super().__init__(name, type_, class_, ttl, created)
self._fast_init(name, type_, class_, ttl, next_name, rdtypes, created or current_time_millis())

def _fast_init(
self,
name: str,
type_: _int,
class_: _int,
ttl: _float,
next_name: str,
rdtypes: List[_int],
created: _float,
) -> None:
self._fast_init_record(name, type_, class_, ttl, created)
self.next_name = next_name
self.rdtypes = sorted(rdtypes)
self._hash = hash((self.key, type_, self.class_, next_name, *self.rdtypes))
Expand Down
12 changes: 9 additions & 3 deletions src/zeroconf/_protocol/incoming.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ cdef class DNSIncoming:
)
cdef void _read_others(self)

@cython.locals(offset="unsigned int")
@cython.locals(offset="unsigned int", question=DNSQuestion)
cdef _read_questions(self)

@cython.locals(
Expand All @@ -109,9 +109,15 @@ cdef class DNSIncoming:

@cython.locals(
name_start="unsigned int",
offset="unsigned int"
offset="unsigned int",
address_rec=DNSAddress,
pointer_rec=DNSPointer,
text_rec=DNSText,
srv_rec=DNSService,
hinfo_rec=DNSHinfo,
nsec_rec=DNSNsec,
)
cdef _read_record(self, object domain, unsigned int type_, unsigned int class_, unsigned int ttl, unsigned int length)
cdef _read_record(self, str domain, unsigned int type_, unsigned int class_, unsigned int ttl, unsigned int length)

@cython.locals(
offset="unsigned int",
Expand Down
31 changes: 23 additions & 8 deletions src/zeroconf/_protocol/incoming.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,8 @@ def _read_questions(self) -> None:
# The question has 2 unsigned shorts in network order
type_ = view[offset] << 8 | view[offset + 1]
class_ = view[offset + 2] << 8 | view[offset + 3]
question = DNSQuestion(name, type_, class_)
question = DNSQuestion.__new__(DNSQuestion)
question._fast_init(name, type_, class_)
if question.unique: # QU questions use the same bit as unique
self._has_qu_question = True
questions.append(question)
Expand Down Expand Up @@ -306,11 +307,17 @@ def _read_record(
) -> Optional[DNSRecord]:
"""Read known records types and skip unknown ones."""
if type_ == _TYPE_A:
return DNSAddress(domain, type_, class_, ttl, self._read_string(4), None, self.now)
address_rec = DNSAddress.__new__(DNSAddress)
address_rec._fast_init(domain, type_, class_, ttl, self._read_string(4), None, self.now)
return address_rec
if type_ in (_TYPE_CNAME, _TYPE_PTR):
return DNSPointer(domain, type_, class_, ttl, self._read_name(), self.now)
pointer_rec = DNSPointer.__new__(DNSPointer)
pointer_rec._fast_init(domain, type_, class_, ttl, self._read_name(), self.now)
return pointer_rec
if type_ == _TYPE_TXT:
return DNSText(domain, type_, class_, ttl, self._read_string(length), self.now)
text_rec = DNSText.__new__(DNSText)
text_rec._fast_init(domain, type_, class_, ttl, self._read_string(length), self.now)
return text_rec
if type_ == _TYPE_SRV:
view = self.view
offset = self.offset
Expand All @@ -319,7 +326,8 @@ def _read_record(
priority = view[offset] << 8 | view[offset + 1]
weight = view[offset + 2] << 8 | view[offset + 3]
port = view[offset + 4] << 8 | view[offset + 5]
return DNSService(
srv_rec = DNSService.__new__(DNSService)
srv_rec._fast_init(
domain,
type_,
class_,
Expand All @@ -330,8 +338,10 @@ def _read_record(
self._read_name(),
self.now,
)
return srv_rec
if type_ == _TYPE_HINFO:
return DNSHinfo(
hinfo_rec = DNSHinfo.__new__(DNSHinfo)
hinfo_rec._fast_init(
domain,
type_,
class_,
Expand All @@ -340,8 +350,10 @@ def _read_record(
self._read_character_string(),
self.now,
)
return hinfo_rec
if type_ == _TYPE_AAAA:
return DNSAddress(
address_rec = DNSAddress.__new__(DNSAddress)
address_rec._fast_init(
domain,
type_,
class_,
Expand All @@ -350,9 +362,11 @@ def _read_record(
self.scope_id,
self.now,
)
return address_rec
if type_ == _TYPE_NSEC:
name_start = self.offset
return DNSNsec(
nsec_rec = DNSNsec.__new__(DNSNsec)
nsec_rec._fast_init(
domain,
type_,
class_,
Expand All @@ -361,6 +375,7 @@ def _read_record(
self._read_bitmap(name_start + length),
self.now,
)
return nsec_rec
# Try to ignore types we don't know about
# Skip the payload for the resource record so the next
# records can be parsed correctly
Expand Down

0 comments on commit 783c1b3

Please sign in to comment.