Skip to content

Commit

Permalink
Use ensure_bytes from dask.utils
Browse files Browse the repository at this point in the history
Also drop `ensure_bytes` implementation & tests from Distributed.
  • Loading branch information
jakirkham committed May 6, 2022
1 parent 9f7646a commit cadb0bf
Show file tree
Hide file tree
Showing 7 changed files with 8 additions and 71 deletions.
3 changes: 2 additions & 1 deletion distributed/comm/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from tornado.websocket import WebSocketClosedError, WebSocketHandler, websocket_connect

import dask
from dask.utils import ensure_bytes

from distributed.comm.addressing import parse_host_port, unparse_host_port
from distributed.comm.core import (
Expand All @@ -36,7 +37,7 @@
get_tcp_server_address,
to_frames,
)
from distributed.utils import ensure_bytes, nbytes
from distributed.utils import nbytes

logger = logging.getLogger(__name__)

Expand Down
3 changes: 1 addition & 2 deletions distributed/protocol/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
from tlz import identity

import dask

from distributed.utils import ensure_bytes
from dask.utils import ensure_bytes

compressions: dict[
str | None | Literal[False],
Expand Down
4 changes: 2 additions & 2 deletions distributed/protocol/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import dask
from dask.base import normalize_token
from dask.utils import typename
from dask.utils import ensure_bytes, typename

from distributed.protocol import pickle
from distributed.protocol.compression import decompress, maybe_compress
Expand All @@ -22,7 +22,7 @@
pack_frames_prelude,
unpack_frames,
)
from distributed.utils import ensure_bytes, has_keyword
from distributed.utils import has_keyword

dask_serialize = dask.utils.Dispatch("dask_serialize")
dask_deserialize = dask.utils.Dispatch("dask_deserialize")
Expand Down
4 changes: 2 additions & 2 deletions distributed/protocol/tests/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

np = pytest.importorskip("numpy")

from dask.utils import tmpfile
from dask.utils import ensure_bytes, tmpfile

from distributed.protocol import (
decompress,
Expand All @@ -20,7 +20,7 @@
from distributed.protocol.pickle import HIGHEST_PROTOCOL
from distributed.protocol.utils import BIG_BYTES_SHARD_SIZE
from distributed.system import MEMORY_LIMIT
from distributed.utils import ensure_bytes, nbytes
from distributed.utils import nbytes
from distributed.utils_test import gen_cluster


Expand Down
2 changes: 1 addition & 1 deletion distributed/protocol/tests/test_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
np = pytest.importorskip("numpy")

from dask.dataframe.utils import assert_eq
from dask.utils import ensure_bytes

from distributed.protocol import (
decompress,
Expand All @@ -13,7 +14,6 @@
serialize,
to_serialize,
)
from distributed.utils import ensure_bytes

dfs = [
pd.DataFrame({}),
Expand Down
22 changes: 0 additions & 22 deletions distributed/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
LoopRunner,
TimeoutError,
_maybe_complex,
ensure_bytes,
ensure_ip,
format_dashboard_link,
get_ip_interface,
Expand Down Expand Up @@ -248,27 +247,6 @@ def test_seek_delimiter_endline():
assert f.tell() == 7


def test_ensure_bytes():
data = [b"1", "1", memoryview(b"1"), bytearray(b"1"), array.array("b", [49])]
for d in data:
result = ensure_bytes(d)
assert isinstance(result, bytes)
assert result == b"1"


def test_ensure_bytes_ndarray():
np = pytest.importorskip("numpy")
result = ensure_bytes(np.arange(12))
assert isinstance(result, bytes)


def test_ensure_bytes_pyarrow_buffer():
pa = pytest.importorskip("pyarrow")
buf = pa.py_buffer(b"123")
result = ensure_bytes(buf)
assert isinstance(result, bytes)


def test_nbytes():
np = pytest.importorskip("numpy")

Expand Down
41 changes: 0 additions & 41 deletions distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,47 +972,6 @@ def read_block(f, offset, length, delimiter=None):
return bytes


def ensure_bytes(s):
"""Attempt to turn `s` into bytes.
Parameters
----------
s : Any
The object to be converted. Will correctly handled
* str
* bytes
* objects implementing the buffer protocol (memoryview, ndarray, etc.)
Returns
-------
b : bytes
Raises
------
TypeError
When `s` cannot be converted
Examples
--------
>>> ensure_bytes('123')
b'123'
>>> ensure_bytes(b'123')
b'123'
"""
if isinstance(s, bytes):
return s
elif hasattr(s, "encode"):
return s.encode()
else:
try:
return bytes(s)
except Exception as e:
raise TypeError(
"Object %s is neither a bytes object nor has an encode method" % s
) from e


def open_port(host=""):
"""Return a probably-open port
Expand Down

0 comments on commit cadb0bf

Please sign in to comment.