Skip to content

Commit

Permalink
Enable Rust Extension for Faster PackStream (#979)
Browse files Browse the repository at this point in the history
* Making the driver pick up optional rust extension

* Enable rust extension for packing as well

* Fix not using rust packer

* Minor clean-ups in TestKit glue

* TestKit backend: make error classification more robust

* Optimization: check only once for rust availability
  • Loading branch information
robsdedude authored Nov 3, 2023
1 parent d7dec04 commit 9a2a20a
Show file tree
Hide file tree
Showing 16 changed files with 223 additions and 104 deletions.
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -117,5 +117,9 @@ asyncio_mode = "auto"
[tool.mypy]

[[tool.mypy.overrides]]
module = "pandas.*"
module = [
"pandas.*",
"neo4j._codec.packstream._rust",
"neo4j._codec.packstream._rust.*",
]
ignore_missing_imports = true
30 changes: 4 additions & 26 deletions src/neo4j/_codec/packstream/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,7 @@
# limitations under the License.


class Structure:

def __init__(self, tag, *fields):
self.tag = tag
self.fields = list(fields)

def __repr__(self):
return "Structure[0x%02X](%s)" % (ord(self.tag), ", ".join(map(repr, self.fields)))

def __eq__(self, other):
try:
return self.tag == other.tag and self.fields == other.fields
except AttributeError:
return False

def __ne__(self, other):
return not self.__eq__(other)

def __len__(self):
return len(self.fields)

def __getitem__(self, key):
return self.fields[key]

def __setitem__(self, key, value):
self.fields[key] = value
try:
from ._rust import Structure
except ImportError:
from ._python import Structure
24 changes: 24 additions & 0 deletions src/neo4j/_codec/packstream/_python/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
#
# This file is part of Neo4j.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from ._common import Structure


__all__ = [
"Structure",
]
46 changes: 46 additions & 0 deletions src/neo4j/_codec/packstream/_python/_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
#
# This file is part of Neo4j.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class Structure:

def __init__(self, tag, *fields):
self.tag = tag
self.fields = list(fields)

def __repr__(self):
return "Structure[0x%02X](%s)" % (
ord(self.tag), ", ".join(map(repr, self.fields))
)

def __eq__(self, other):
try:
return self.tag == other.tag and self.fields == other.fields
except AttributeError:
return False

def __ne__(self, other):
return not self.__eq__(other)

def __len__(self):
return len(self.fields)

def __getitem__(self, key):
return self.fields[key]

def __setitem__(self, key, value):
self.fields[key] = value
118 changes: 56 additions & 62 deletions src/neo4j/_codec/packstream/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,46 +16,26 @@
# limitations under the License.


import typing as t
from codecs import decode
from contextlib import contextmanager
from struct import (
pack as struct_pack,
unpack as struct_unpack,
)

from ...._optional_deps import (
np,
pd,
)
from ...hydration import DehydrationHooks
from .._common import Structure
from .types import *


NONE_VALUES: t.Tuple = (None,)
TRUE_VALUES: t.Tuple = (True,)
FALSE_VALUES: t.Tuple = (False,)
INT_TYPES: t.Tuple[t.Type, ...] = (int,)
FLOAT_TYPES: t.Tuple[t.Type, ...] = (float,)
# we can't put tuple here because spatial types subclass tuple,
# and we don't want to treat them as sequences
SEQUENCE_TYPES: t.Tuple[t.Type, ...] = (list,)
MAPPING_TYPES: t.Tuple[t.Type, ...] = (dict,)
BYTES_TYPES: t.Tuple[t.Type, ...] = (bytes, bytearray)


if np is not None:
TRUE_VALUES = (*TRUE_VALUES, np.bool_(True))
FALSE_VALUES = (*FALSE_VALUES, np.bool_(False))
INT_TYPES = (*INT_TYPES, np.integer)
FLOAT_TYPES = (*FLOAT_TYPES, np.floating)
SEQUENCE_TYPES = (*SEQUENCE_TYPES, np.ndarray)

if pd is not None:
NONE_VALUES = (*NONE_VALUES, pd.NA)
SEQUENCE_TYPES = (*SEQUENCE_TYPES, pd.Series, pd.Categorical,
pd.core.arrays.ExtensionArray)
MAPPING_TYPES = (*MAPPING_TYPES, pd.DataFrame)
try:
from .._rust.v1 import (
pack as _rust_pack,
unpack as _rust_unpack,
)
except ImportError:
_rust_pack = None
_rust_unpack = None


PACKED_UINT_8 = [struct_pack(">B", value) for value in range(0x100)]
Expand All @@ -74,12 +54,17 @@ def __init__(self, stream):
self.stream = stream
self._write = self.stream.write

def _pack_raw(self, data):
self._write(data)

def pack(self, data, dehydration_hooks=None):
self._pack(data,
dehydration_hooks=self._inject_hooks(dehydration_hooks))
dehydration_hooks = self._inject_hooks(dehydration_hooks)
self._pack(data, dehydration_hooks=dehydration_hooks)

if _rust_pack:
def _pack(self, data, dehydration_hooks=None):
data = _rust_pack(data, dehydration_hooks)
self._write(data)
else:
def _pack(self, data, dehydration_hooks=None):
self._py_pack(data, dehydration_hooks)

@classmethod
def _inject_hooks(cls, dehydration_hooks=None):
Expand All @@ -93,8 +78,7 @@ def _inject_hooks(cls, dehydration_hooks=None):
subtypes={}
)


def _pack(self, value, dehydration_hooks=None):
def _py_pack(self, value, dehydration_hooks=None):
write = self._write

# None
Expand Down Expand Up @@ -136,18 +120,18 @@ def _pack(self, value, dehydration_hooks=None):
elif isinstance(value, str):
encoded = value.encode("utf-8")
self._pack_string_header(len(encoded))
self._pack_raw(encoded)
self._write(encoded)

# Bytes
elif isinstance(value, BYTES_TYPES):
self._pack_bytes_header(len(value))
self._pack_raw(value)
self._write(value)

# List
elif isinstance(value, SEQUENCE_TYPES):
self._pack_list_header(len(value))
for item in value:
self._pack(item, dehydration_hooks)
self._py_pack(item, dehydration_hooks)

# Map
elif isinstance(value, MAPPING_TYPES):
Expand All @@ -157,8 +141,8 @@ def _pack(self, value, dehydration_hooks=None):
raise TypeError(
"Map keys must be strings, not {}".format(type(key))
)
self._pack(key, dehydration_hooks)
self._pack(item, dehydration_hooks)
self._py_pack(key, dehydration_hooks)
self._py_pack(item, dehydration_hooks)

# Structure
elif isinstance(value, Structure):
Expand All @@ -169,7 +153,7 @@ def _pack(self, value, dehydration_hooks=None):
if dehydration_hooks:
transformer = dehydration_hooks.get_transformer(value)
if transformer is not None:
self._pack(transformer(value), dehydration_hooks)
self._py_pack(transformer(value), dehydration_hooks)
return

raise ValueError("Values of type %s are not supported" % type(value))
Expand Down Expand Up @@ -298,11 +282,16 @@ def read(self, n=1):
def read_u8(self):
return self.unpackable.read_u8()

def unpack(self, hydration_hooks=None):
value = self._unpack(hydration_hooks=hydration_hooks)
if hydration_hooks and type(value) in hydration_hooks:
return hydration_hooks[type(value)](value)
return value
if _rust_unpack:
def unpack(self, hydration_hooks=None):
value, i = _rust_unpack(
self.unpackable.data, self.unpackable.p, hydration_hooks
)
self.unpackable.p = i
return value
else:
def unpack(self, hydration_hooks=None):
return self._unpack(hydration_hooks=hydration_hooks)

def _unpack(self, hydration_hooks=None):
marker = self.read_u8()
Expand Down Expand Up @@ -384,8 +373,13 @@ def _unpack(self, hydration_hooks=None):
size, tag = self._unpack_structure_header(marker)
value = Structure(tag, *([None] * size))
for i in range(len(value)):
value[i] = self.unpack(hydration_hooks=hydration_hooks)
return value
value[i] = self._unpack(hydration_hooks=hydration_hooks)
if not hydration_hooks:
return value
hydration_hook = hydration_hooks.get(type(value))
if not hydration_hook:
return value
return hydration_hook(value)

else:
raise ValueError("Unknown PackStream marker %02X" % marker)
Expand All @@ -397,22 +391,22 @@ def _unpack_list_items(self, marker, hydration_hooks=None):
if size == 0:
return
elif size == 1:
yield self.unpack(hydration_hooks=hydration_hooks)
yield self._unpack(hydration_hooks=hydration_hooks)
else:
for _ in range(size):
yield self.unpack(hydration_hooks=hydration_hooks)
yield self._unpack(hydration_hooks=hydration_hooks)
elif marker == 0xD4: # LIST_8:
size, = struct_unpack(">B", self.read(1))
for _ in range(size):
yield self.unpack(hydration_hooks=hydration_hooks)
yield self._unpack(hydration_hooks=hydration_hooks)
elif marker == 0xD5: # LIST_16:
size, = struct_unpack(">H", self.read(2))
for _ in range(size):
yield self.unpack(hydration_hooks=hydration_hooks)
yield self._unpack(hydration_hooks=hydration_hooks)
elif marker == 0xD6: # LIST_32:
size, = struct_unpack(">I", self.read(4))
for _ in range(size):
yield self.unpack(hydration_hooks=hydration_hooks)
yield self._unpack(hydration_hooks=hydration_hooks)
else:
return

Expand All @@ -426,29 +420,29 @@ def _unpack_map(self, marker, hydration_hooks=None):
size = marker & 0x0F
value = {}
for _ in range(size):
key = self.unpack(hydration_hooks=hydration_hooks)
value[key] = self.unpack(hydration_hooks=hydration_hooks)
key = self._unpack(hydration_hooks=hydration_hooks)
value[key] = self._unpack(hydration_hooks=hydration_hooks)
return value
elif marker == 0xD8: # MAP_8:
size, = struct_unpack(">B", self.read(1))
value = {}
for _ in range(size):
key = self.unpack(hydration_hooks=hydration_hooks)
value[key] = self.unpack(hydration_hooks=hydration_hooks)
key = self._unpack(hydration_hooks=hydration_hooks)
value[key] = self._unpack(hydration_hooks=hydration_hooks)
return value
elif marker == 0xD9: # MAP_16:
size, = struct_unpack(">H", self.read(2))
value = {}
for _ in range(size):
key = self.unpack(hydration_hooks=hydration_hooks)
value[key] = self.unpack(hydration_hooks=hydration_hooks)
key = self._unpack(hydration_hooks=hydration_hooks)
value[key] = self._unpack(hydration_hooks=hydration_hooks)
return value
elif marker == 0xDA: # MAP_32:
size, = struct_unpack(">I", self.read(4))
value = {}
for _ in range(size):
key = self.unpack(hydration_hooks=hydration_hooks)
value[key] = self.unpack(hydration_hooks=hydration_hooks)
key = self._unpack(hydration_hooks=hydration_hooks)
value[key] = self._unpack(hydration_hooks=hydration_hooks)
return value
else:
return None
Expand Down
Loading

0 comments on commit 9a2a20a

Please sign in to comment.