diff --git a/distributed/comm/ws.py b/distributed/comm/ws.py index 2f6a9279fd8..dcfdc41e8d0 100644 --- a/distributed/comm/ws.py +++ b/distributed/comm/ws.py @@ -15,6 +15,7 @@ from tornado.websocket import WebSocketClosedError, WebSocketHandler, websocket_connect import dask +from dask.utils import ensure_bytes from distributed.comm.addressing import parse_host_port, unparse_host_port from distributed.comm.core import ( @@ -36,7 +37,7 @@ get_tcp_server_address, to_frames, ) -from distributed.utils import ensure_bytes, nbytes +from distributed.utils import nbytes logger = logging.getLogger(__name__) diff --git a/distributed/protocol/compression.py b/distributed/protocol/compression.py index 5995d1d9d51..d054d88d1ef 100644 --- a/distributed/protocol/compression.py +++ b/distributed/protocol/compression.py @@ -15,8 +15,7 @@ from tlz import identity import dask - -from distributed.utils import ensure_bytes +from dask.utils import ensure_bytes compressions: dict[ str | None | Literal[False], diff --git a/distributed/protocol/serialize.py b/distributed/protocol/serialize.py index 0cd1de3db69..66ee828df57 100644 --- a/distributed/protocol/serialize.py +++ b/distributed/protocol/serialize.py @@ -11,7 +11,7 @@ import dask from dask.base import normalize_token -from dask.utils import typename +from dask.utils import ensure_bytes, typename from distributed.protocol import pickle from distributed.protocol.compression import decompress, maybe_compress @@ -22,7 +22,7 @@ pack_frames_prelude, unpack_frames, ) -from distributed.utils import ensure_bytes, has_keyword +from distributed.utils import has_keyword dask_serialize = dask.utils.Dispatch("dask_serialize") dask_deserialize = dask.utils.Dispatch("dask_deserialize") diff --git a/distributed/protocol/tests/test_numpy.py b/distributed/protocol/tests/test_numpy.py index 93c42aacf1d..895d5ad8896 100644 --- a/distributed/protocol/tests/test_numpy.py +++ b/distributed/protocol/tests/test_numpy.py @@ -4,7 +4,7 @@ np = pytest.importorskip("numpy") -from dask.utils import tmpfile +from dask.utils import ensure_bytes, tmpfile from distributed.protocol import ( decompress, @@ -20,7 +20,7 @@ from distributed.protocol.pickle import HIGHEST_PROTOCOL from distributed.protocol.utils import BIG_BYTES_SHARD_SIZE from distributed.system import MEMORY_LIMIT -from distributed.utils import ensure_bytes, nbytes +from distributed.utils import nbytes from distributed.utils_test import gen_cluster diff --git a/distributed/protocol/tests/test_pandas.py b/distributed/protocol/tests/test_pandas.py index 58bfb90f75e..07fc916064d 100644 --- a/distributed/protocol/tests/test_pandas.py +++ b/distributed/protocol/tests/test_pandas.py @@ -4,6 +4,7 @@ np = pytest.importorskip("numpy") from dask.dataframe.utils import assert_eq +from dask.utils import ensure_bytes from distributed.protocol import ( decompress, @@ -13,7 +14,6 @@ serialize, to_serialize, ) -from distributed.utils import ensure_bytes dfs = [ pd.DataFrame({}), diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py index 3358dbc1907..296fea8ee55 100644 --- a/distributed/tests/test_utils.py +++ b/distributed/tests/test_utils.py @@ -26,7 +26,6 @@ LoopRunner, TimeoutError, _maybe_complex, - ensure_bytes, ensure_ip, format_dashboard_link, get_ip_interface, @@ -248,27 +247,6 @@ def test_seek_delimiter_endline(): assert f.tell() == 7 -def test_ensure_bytes(): - data = [b"1", "1", memoryview(b"1"), bytearray(b"1"), array.array("b", [49])] - for d in data: - result = ensure_bytes(d) - assert isinstance(result, bytes) - assert result == b"1" - - -def test_ensure_bytes_ndarray(): - np = pytest.importorskip("numpy") - result = ensure_bytes(np.arange(12)) - assert isinstance(result, bytes) - - -def test_ensure_bytes_pyarrow_buffer(): - pa = pytest.importorskip("pyarrow") - buf = pa.py_buffer(b"123") - result = ensure_bytes(buf) - assert isinstance(result, bytes) - - def test_nbytes(): np = pytest.importorskip("numpy") diff --git a/distributed/utils.py b/distributed/utils.py index 40116f6b01b..ec155e539cb 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -972,47 +972,6 @@ def read_block(f, offset, length, delimiter=None): return bytes -def ensure_bytes(s): - """Attempt to turn `s` into bytes. - - Parameters - ---------- - s : Any - The object to be converted. Will correctly handled - - * str - * bytes - * objects implementing the buffer protocol (memoryview, ndarray, etc.) - - Returns - ------- - b : bytes - - Raises - ------ - TypeError - When `s` cannot be converted - - Examples - -------- - >>> ensure_bytes('123') - b'123' - >>> ensure_bytes(b'123') - b'123' - """ - if isinstance(s, bytes): - return s - elif hasattr(s, "encode"): - return s.encode() - else: - try: - return bytes(s) - except Exception as e: - raise TypeError( - "Object %s is neither a bytes object nor has an encode method" % s - ) from e - - def open_port(host=""): """Return a probably-open port