Skip to content

Commit

Permalink
Use ensure_bytes from dask.utils (#6295)
Browse files Browse the repository at this point in the history
Keep `ensure_bytes` around and have it call the `dask.utils`
implementation. Though have it raise a `DeprecationWarning` so users
know this will be going away in the future and should update their code.
  • Loading branch information
jakirkham authored May 9, 2022
1 parent 61aa93e commit decfe7f
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 37 deletions.
3 changes: 2 additions & 1 deletion distributed/comm/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from tornado.websocket import WebSocketClosedError, WebSocketHandler, websocket_connect

import dask
from dask.utils import ensure_bytes

from distributed.comm.addressing import parse_host_port, unparse_host_port
from distributed.comm.core import (
Expand All @@ -36,7 +37,7 @@
get_tcp_server_address,
to_frames,
)
from distributed.utils import ensure_bytes, nbytes
from distributed.utils import nbytes

logger = logging.getLogger(__name__)

Expand Down
4 changes: 2 additions & 2 deletions distributed/protocol/tests/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

np = pytest.importorskip("numpy")

from dask.utils import tmpfile
from dask.utils import ensure_bytes, tmpfile

from distributed.protocol import (
decompress,
Expand All @@ -20,7 +20,7 @@
from distributed.protocol.pickle import HIGHEST_PROTOCOL
from distributed.protocol.utils import BIG_BYTES_SHARD_SIZE
from distributed.system import MEMORY_LIMIT
from distributed.utils import ensure_bytes, nbytes
from distributed.utils import nbytes
from distributed.utils_test import gen_cluster


Expand Down
2 changes: 1 addition & 1 deletion distributed/protocol/tests/test_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
np = pytest.importorskip("numpy")

from dask.dataframe.utils import assert_eq
from dask.utils import ensure_bytes

from distributed.protocol import (
decompress,
Expand All @@ -13,7 +14,6 @@
serialize,
to_serialize,
)
from distributed.utils import ensure_bytes

dfs = [
pd.DataFrame({}),
Expand Down
22 changes: 0 additions & 22 deletions distributed/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
LoopRunner,
TimeoutError,
_maybe_complex,
ensure_bytes,
ensure_ip,
ensure_memoryview,
format_dashboard_link,
Expand Down Expand Up @@ -249,27 +248,6 @@ def test_seek_delimiter_endline():
assert f.tell() == 7


def test_ensure_bytes():
data = [b"1", "1", memoryview(b"1"), bytearray(b"1"), array.array("b", [49])]
for d in data:
result = ensure_bytes(d)
assert isinstance(result, bytes)
assert result == b"1"


def test_ensure_bytes_ndarray():
np = pytest.importorskip("numpy")
result = ensure_bytes(np.arange(12))
assert isinstance(result, bytes)


def test_ensure_bytes_pyarrow_buffer():
pa = pytest.importorskip("pyarrow")
buf = pa.py_buffer(b"123")
result = ensure_bytes(buf)
assert isinstance(result, bytes)


def test_ensure_memoryview_empty():
result = ensure_memoryview(b"")
assert isinstance(result, memoryview)
Expand Down
20 changes: 9 additions & 11 deletions distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@

import dask
from dask import istask
from dask.utils import ensure_bytes as _ensure_bytes
from dask.utils import parse_timedelta as _parse_timedelta
from dask.widgets import get_template

Expand Down Expand Up @@ -1000,17 +1001,14 @@ def ensure_bytes(s):
>>> ensure_bytes(b'123')
b'123'
"""
if isinstance(s, bytes):
return s
elif hasattr(s, "encode"):
return s.encode()
else:
try:
return bytes(s)
except Exception as e:
raise TypeError(
"Object %s is neither a bytes object nor has an encode method" % s
) from e
warnings.warn(
"`distributed.utils.ensure_bytes` is deprecated. "
"Please switch to `dask.utils.ensure_bytes`. "
"This will be removed in `2022.6.0`.",
DeprecationWarning,
stacklevel=2,
)
return _ensure_bytes(s)


def ensure_memoryview(obj):
Expand Down

0 comments on commit decfe7f

Please sign in to comment.