diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index e5e955d..272d3d1 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -28,7 +28,7 @@ concurrency:
 jobs:
   python-build:
     secrets: inherit
-    uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-24.12
+    uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-25.02
     with:
       build_type: ${{ inputs.build_type || 'branch' }}
       branch: ${{ inputs.branch }}
@@ -40,7 +40,7 @@ jobs:
   upload-conda:
     needs: python-build
     secrets: inherit
-    uses: rapidsai/shared-workflows/.github/workflows/conda-upload-packages.yaml@branch-24.12
+    uses: rapidsai/shared-workflows/.github/workflows/conda-upload-packages.yaml@branch-25.02
     with:
       build_type: ${{ inputs.build_type || 'branch' }}
       branch: ${{ inputs.branch }}
@@ -48,7 +48,7 @@ jobs:
       sha: ${{ inputs.sha }}
   wheel-build:
     secrets: inherit
-    uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.12
+    uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-25.02
     with:
       build_type: ${{ inputs.build_type || 'branch' }}
       branch: ${{ inputs.branch }}
@@ -61,7 +61,7 @@ jobs:
   wheel-publish:
     needs: wheel-build
     secrets: inherit
-    uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@branch-24.12
+    uses: rapidsai/shared-workflows/.github/workflows/wheels-publish.yaml@branch-25.02
     with:
       build_type: ${{ inputs.build_type || 'branch' }}
       branch: ${{ inputs.branch }}
diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml
index bbe54fc..0d92ae8 100644
--- a/.github/workflows/pr.yaml
+++ b/.github/workflows/pr.yaml
@@ -15,10 +15,10 @@ jobs:
       - conda-python-build
       - wheel-build
     secrets: inherit
-    uses: rapidsai/shared-workflows/.github/workflows/pr-builder.yaml@branch-24.12
+    uses: rapidsai/shared-workflows/.github/workflows/pr-builder.yaml@branch-25.02
   conda-python-build:
     secrets: inherit
-    uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-24.12
+    uses: rapidsai/shared-workflows/.github/workflows/conda-python-build.yaml@branch-25.02
     with:
       build_type: pull-request
       # Package is pure Python and only ever requires one build.
@@ -26,7 +26,7 @@ jobs:
       matrix_filter: map(select(.ARCH == "amd64")) | max_by([(.PY_VER|split(".")|map(tonumber)), (.CUDA_VER|split(".")|map(tonumber))]) | [.]
   wheel-build:
     secrets: inherit
-    uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-24.12
+    uses: rapidsai/shared-workflows/.github/workflows/wheels-build.yaml@branch-25.02
     with:
       build_type: pull-request
       # Package is pure Python and only ever requires one build.
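The `matrix_filter` expression in the `pr.yaml` hunk above is the one non-obvious piece of these workflow changes: it drops every non-amd64 entry from the shared-workflows test matrix and then keeps only the single entry with the highest Python and CUDA versions, since a pure-Python package needs exactly one build. A rough Python equivalent of that jq program, using a hypothetical matrix for illustration (the real matrix is supplied by rapidsai/shared-workflows at run time):

    # Hypothetical stand-in for the shared-workflows matrix entries.
    matrix = [
        {"ARCH": "amd64", "PY_VER": "3.10", "CUDA_VER": "11.8.0"},
        {"ARCH": "amd64", "PY_VER": "3.12", "CUDA_VER": "12.5.1"},
        {"ARCH": "arm64", "PY_VER": "3.12", "CUDA_VER": "12.5.1"},
    ]

    def version_key(entry):
        # (.PY_VER|split(".")|map(tonumber)): compare versions numerically,
        # not lexically, so "3.10" sorts above "3.9".
        return (
            [int(part) for part in entry["PY_VER"].split(".")],
            [int(part) for part in entry["CUDA_VER"].split(".")],
        )

    # map(select(.ARCH == "amd64")) | max_by(...) | [.]
    amd64_entries = [e for e in matrix if e["ARCH"] == "amd64"]
    print([max(amd64_entries, key=version_key)])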
diff --git a/.github/workflows/trigger-breaking-change-alert.yaml b/.github/workflows/trigger-breaking-change-alert.yaml
new file mode 100644
index 0000000..01dd243
--- /dev/null
+++ b/.github/workflows/trigger-breaking-change-alert.yaml
@@ -0,0 +1,26 @@
+name: Trigger Breaking Change Notifications
+
+on:
+  pull_request_target:
+    types:
+      - closed
+      - reopened
+      - labeled
+      - unlabeled
+
+jobs:
+  trigger-notifier:
+    if: contains(github.event.pull_request.labels.*.name, 'breaking')
+    secrets: inherit
+    uses: rapidsai/shared-workflows/.github/workflows/breaking-change-alert.yaml@branch-25.02
+    with:
+      sender_login: ${{ github.event.sender.login }}
+      sender_avatar: ${{ github.event.sender.avatar_url }}
+      repo: ${{ github.repository }}
+      pr_number: ${{ github.event.pull_request.number }}
+      pr_title: "${{ github.event.pull_request.title }}"
+      pr_body: "${{ github.event.pull_request.body || '_Empty PR description_' }}"
+      pr_base_ref: ${{ github.event.pull_request.base.ref }}
+      pr_author: ${{ github.event.pull_request.user.login }}
+      event_action: ${{ github.event.action }}
+      pr_merged: ${{ github.event.pull_request.merged }}
diff --git a/conda/recipes/rapids-dask-dependency/meta.yaml b/conda/recipes/rapids-dask-dependency/meta.yaml
index 8c1c505..9ea27a9 100644
--- a/conda/recipes/rapids-dask-dependency/meta.yaml
+++ b/conda/recipes/rapids-dask-dependency/meta.yaml
@@ -28,11 +28,10 @@ requirements:
     - setuptools
     - conda-verify
   run:
-    - dask ==2024.11.2
-    - dask-core ==2024.11.2
-    - distributed ==2024.11.2
-    - dask-expr ==1.1.19
-    - pynvml >=11.0.0,<11.5.0a0
+    - dask ==2024.12.1
+    - dask-core ==2024.12.1
+    - distributed ==2024.12.1
+    - dask-expr ==1.1.21
 
 about:
   home: https://rapids.ai/
diff --git a/pyproject.toml b/pyproject.toml
index b0527d3..548be13 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,13 +9,12 @@ requires = [
 
 [project]
 name = "rapids-dask-dependency"
-version = "24.12.00a0"
+version = "25.02.00a0"
 description = "Dask and Distributed version pinning for RAPIDS"
 dependencies = [
-    "dask==2024.11.2",
-    "distributed==2024.11.2",
-    "dask-expr==1.1.19",
-    "pynvml>=11.0.0,<11.5.0a0",
+    "dask==2024.12.1",
+    "distributed==2024.12.1",
+    "dask-expr==1.1.21",
 ]
 license = { text = "Apache 2.0" }
 readme = { file = "README.md", content-type = "text/markdown" }
diff --git a/rapids_dask_dependency/patches/distributed/comm/__rdd_patch_ucx.py b/rapids_dask_dependency/patches/distributed/comm/__rdd_patch_ucx.py
new file mode 100644
index 0000000..974e256
--- /dev/null
+++ b/rapids_dask_dependency/patches/distributed/comm/__rdd_patch_ucx.py
@@ -0,0 +1,667 @@
+"""
+:ref:`UCX`_ based communications for distributed.
+
+See :ref:`communications` for more.
+
+.. _UCX: https://github.com/openucx/ucx
+"""
+
+from __future__ import annotations
+
+import functools
+import logging
+import os
+import struct
+import weakref
+from collections.abc import Awaitable, Callable, Collection
+from typing import TYPE_CHECKING, Any
+from unittest.mock import patch
+
+import dask
+from dask.utils import parse_bytes
+
+from distributed.comm.addressing import parse_host_port, unparse_host_port
+from distributed.comm.core import BaseListener, Comm, CommClosedError, Connector
+from distributed.comm.registry import Backend, backends
+from distributed.comm.utils import ensure_concrete_host, from_frames, to_frames
+from distributed.diagnostics.nvml import (
+    CudaDeviceInfo,
+    get_device_index_and_uuid,
+    has_cuda_context,
+)
+from distributed.protocol.utils import host_array
+from distributed.utils import ensure_ip, get_ip, get_ipv6, log_errors, nbytes
+
+logger = logging.getLogger(__name__)
+
+# In order to avoid double init when forking/spawning new processes (multiprocess),
+# we make sure only to import and initialize UCX once at first use. This is also
+# required to ensure Dask configuration gets propagated to UCX, which needs
+# variables to be set before being imported.
+if TYPE_CHECKING:
+    try:
+        import ucp
+    except ImportError:
+        pass
+else:
+    ucp = None
+
+device_array = None
+pre_existing_cuda_context = False
+cuda_context_created = False
+
+
+_warning_suffix = (
+    "This is often the result of a CUDA-enabled library calling a CUDA runtime function before "
+    "Dask-CUDA can spawn worker processes. Please make sure any such function calls don't happen "
+    "at import time or in the global scope of a program."
+)
+
+
+def _get_device_and_uuid_str(device_info: CudaDeviceInfo) -> str:
+    return f"{device_info.device_index} ({str(device_info.uuid)})"
+
+
+def _warn_existing_cuda_context(device_info: CudaDeviceInfo, pid: int) -> None:
+    device_uuid_str = _get_device_and_uuid_str(device_info)
+    logger.warning(
+        f"A CUDA context for device {device_uuid_str} already exists "
+        f"on process ID {pid}. {_warning_suffix}"
+    )
+
+
+def _warn_cuda_context_wrong_device(
+    device_info_expected: CudaDeviceInfo, device_info_actual: CudaDeviceInfo, pid: int
+) -> None:
+    expected_device_uuid_str = _get_device_and_uuid_str(device_info_expected)
+    actual_device_uuid_str = _get_device_and_uuid_str(device_info_actual)
+    logger.warning(
+        f"Worker with process ID {pid} should have a CUDA context assigned to device "
+        f"{expected_device_uuid_str}, but instead the CUDA context is on device "
+        f"{actual_device_uuid_str}. {_warning_suffix}"
+    )
+
+
+def synchronize_stream(stream=0):
+    import numba.cuda
+
+    ctx = numba.cuda.current_context()
+    cu_stream = numba.cuda.driver.drvapi.cu_stream(stream)
+    stream = numba.cuda.driver.Stream(ctx, cu_stream, None)
+    stream.synchronize()
+
+
+def init_once():
+    global ucp, device_array
+    global ucx_create_endpoint, ucx_create_listener
+    global pre_existing_cuda_context, cuda_context_created
+
+    if ucp is not None:
+        return
+
+    # remove/process dask.ucx flags for valid ucx options
+    ucx_config, ucx_environment = _prepare_ucx_config()
+
+    # We ensure the CUDA context is created before initializing UCX. This can't
+    # be safely handled externally because communications in Dask start before
+    # preload scripts run.
+    # Precedence:
+    # 1. external environment
+    # 2. ucx_config (high level settings passed to ucp.init)
+    # 3. ucx_environment (low level settings equivalent to environment variables)
+    ucx_tls = os.environ.get(
+        "UCX_TLS",
+        ucx_config.get("TLS", ucx_environment.get("UCX_TLS", "")),
+    )
+    if (
+        dask.config.get("distributed.comm.ucx.create-cuda-context") is True
+        # This is not foolproof, if UCX_TLS=all we might require CUDA
+        # depending on configuration of UCX, but this is better than
+        # nothing
+        or ("cuda" in ucx_tls and "^cuda" not in ucx_tls)
+    ):
+        try:
+            import numba.cuda
+        except ImportError:
+            raise ImportError(
+                "CUDA support with UCX requires Numba for context management"
+            )
+
+        cuda_visible_device = get_device_index_and_uuid(
+            os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
+        )
+        pre_existing_cuda_context = has_cuda_context()
+        if pre_existing_cuda_context.has_context:
+            _warn_existing_cuda_context(
+                pre_existing_cuda_context.device_info, os.getpid()
+            )
+
+        numba.cuda.current_context()
+
+        cuda_context_created = has_cuda_context()
+        if (
+            cuda_context_created.has_context
+            and cuda_context_created.device_info.uuid != cuda_visible_device.uuid
+        ):
+            _warn_cuda_context_wrong_device(
+                cuda_visible_device, cuda_context_created.device_info, os.getpid()
+            )
+
+    import ucp as _ucp
+
+    ucp = _ucp
+
+    with patch.dict(os.environ, ucx_environment):
+        # We carefully ensure that ucx_environment only contains things
+        # that don't override ucx_config or existing slots in the
+        # environment, so the user's external environment can safely
+        # override things here.
+        ucp.init(options=ucx_config, env_takes_precedence=True)
+
+    pool_size_str = dask.config.get("distributed.rmm.pool-size")
+
+    # Find the function, `cuda_array()`, to use when allocating new CUDA arrays
+    try:
+        import rmm
+
+        def device_array(n):
+            return rmm.DeviceBuffer(size=n)
+
+        if pool_size_str is not None:
+            pool_size = parse_bytes(pool_size_str)
+            rmm.reinitialize(
+                pool_allocator=True, managed_memory=False, initial_pool_size=pool_size
+            )
+    except ImportError:
+        try:
+            import numba.cuda
+
+            def numba_device_array(n):
+                a = numba.cuda.device_array((n,), dtype="u1")
+                weakref.finalize(a, numba.cuda.current_context)
+                return a
+
+            device_array = numba_device_array
+
+        except ImportError:
+
+            def device_array(n):
+                raise RuntimeError(
+                    "In order to send/recv CUDA arrays, Numba or RMM is required"
+                )
+
+        if pool_size_str is not None:
+            logger.warning(
+                "Initial RMM pool size defined, but RMM is not available. "
+                "Please consider installing RMM or removing the pool size option."
+            )
+
+
+def _close_comm(ref):
+    """Callback to close Dask Comm when UCX Endpoint closes or errors
+
+    Parameters
+    ----------
+    ref: weak reference to a Dask UCX comm
+    """
+    comm = ref()
+    if comm is not None:
+        comm._closed = True
+
+
+class UCX(Comm):
+    """Comm object using UCP.
+
+    Parameters
+    ----------
+    ep : ucp.Endpoint
+        The UCP endpoint.
+    address : str
+        The address, prefixed with `ucx://` to use.
+    deserialize : bool, default True
+        Whether to deserialize data in :meth:`distributed.protocol.loads`
+
+    Notes
+    -----
+    The read-write cycle uses the following pattern:
+
+    Each msg is serialized into a number of "data" frames. We prepend these
+    real frames with two additional frames
+
+        1. is_gpus: Boolean indicator for whether the frame should be
+           received into GPU memory. Packed in '?' format. Unpack with
+           ``<n_frames>?`` format.
+        2. frame_size : Unsigned int describing the size of frame (in bytes)
+           to receive. Packed in 'Q' format, so a length-0 frame is equivalent
+           to an unsized frame. Unpacked with ``<n_frames>Q``.
+
+    The expected read cycle is
+
+    1. Read the frame describing if connection is closing and number of frames
+    2. Read the frame describing whether each data frame is gpu-bound
+    3. Read the frame describing whether each data frame is sized
+    4. Read all the data frames.
+    """
+
+    def __init__(  # type: ignore[no-untyped-def]
+        self, ep, local_addr: str, peer_addr: str, deserialize: bool = True
+    ):
+        super().__init__(deserialize=deserialize)
+        self._ep = ep
+        if local_addr:
+            assert local_addr.startswith("ucx")
+        assert peer_addr.startswith("ucx")
+        self._local_addr = local_addr
+        self._peer_addr = peer_addr
+        self.comm_flag = None
+
+        # When the UCX endpoint closes or errors the registered callback
+        # is called.
+        if hasattr(self._ep, "set_close_callback"):
+            ref = weakref.ref(self)
+            self._ep.set_close_callback(functools.partial(_close_comm, ref))
+            self._closed = False
+            self._has_close_callback = True
+        else:
+            self._has_close_callback = False
+
+        logger.debug("UCX.__init__ %s", self)
+
+    @property
+    def local_address(self) -> str:
+        return self._local_addr
+
+    @property
+    def peer_address(self) -> str:
+        return self._peer_addr
+
+    @property
+    def same_host(self) -> bool:
+        """Unlike in TCP, local_address can be blank"""
+        return super().same_host if self._local_addr else False
+
+    @log_errors
+    async def write(
+        self,
+        msg: dict,
+        serializers: Collection[str] | None = None,
+        on_error: str = "message",
+    ) -> int:
+        if self.closed():
+            raise CommClosedError("Endpoint is closed -- unable to send message")
+
+        if serializers is None:
+            serializers = ("cuda", "dask", "pickle", "error")
+        # msg can also be a list of dicts when sending batched messages
+        frames = await to_frames(
+            msg,
+            serializers=serializers,
+            on_error=on_error,
+            allow_offload=self.allow_offload,
+        )
+        nframes = len(frames)
+        cuda_frames = tuple(hasattr(f, "__cuda_array_interface__") for f in frames)
+        sizes = tuple(nbytes(f) for f in frames)
+        cuda_send_frames, send_frames = zip(
+            *(
+                (is_cuda, each_frame)
+                for is_cuda, each_frame in zip(cuda_frames, frames)
+                if nbytes(each_frame) > 0
+            )
+        )
+
+        try:
+            # Send meta data
+
+            # Send close flag and number of frames (_Bool, int64)
+            await self.ep.send(struct.pack("?Q", False, nframes))
+            # Send which frames are CUDA (bool) and
+            # how large each frame is (uint64)
+            await self.ep.send(
+                struct.pack(nframes * "?" + nframes * "Q", *cuda_frames, *sizes)
+            )
+
+            # Send frames
+
+            # It is necessary to first synchronize the default stream before start
+            # sending. We synchronize the default stream because UCX is not
+            # stream-ordered and syncing the default stream will wait for other
+            # non-blocking CUDA streams. Note this is only sufficient if the memory
+            # being sent is not currently in use on non-blocking CUDA streams.
+            if any(cuda_send_frames):
+                synchronize_stream(0)
+
+            for each_frame in send_frames:
+                await self.ep.send(each_frame)
+            return sum(sizes)
+        except ucp.exceptions.UCXBaseException:
+            self.abort()
+            raise CommClosedError("While writing, the connection was closed")
+
+    @log_errors
+    async def read(self, deserializers=("cuda", "dask", "pickle", "error")):
+        if deserializers is None:
+            deserializers = ("cuda", "dask", "pickle", "error")
+
+        try:
+            # Recv meta data
+
+            # Recv close flag and number of frames (_Bool, int64)
+            msg = host_array(struct.calcsize("?Q"))
+            await self.ep.recv(msg)
+            (shutdown, nframes) = struct.unpack("?Q", msg)
+
+            if shutdown:  # The writer is closing the connection
+                raise CommClosedError("Connection closed by writer")
+
+            # Recv which frames are CUDA (bool) and
+            # how large each frame is (uint64)
+            header_fmt = nframes * "?" + nframes * "Q"
+            header = host_array(struct.calcsize(header_fmt))
+            await self.ep.recv(header)
+            header = struct.unpack(header_fmt, header)
+            cuda_frames, sizes = header[:nframes], header[nframes:]
+        except BaseException as e:  # noqa: B036
+            # In addition to UCX exceptions, may be CancelledError or another
+            # "low-level" exception. The only safe thing to do is to abort.
+            # (See also https://github.com/dask/distributed/pull/6574).
+            self.abort()
+            raise CommClosedError(
+                f"Connection closed by writer.\nInner exception: {e!r}"
+            )
+        else:
+            # Recv frames
+            frames = [
+                device_array(each_size) if is_cuda else host_array(each_size)
+                for is_cuda, each_size in zip(cuda_frames, sizes)
+            ]
+            cuda_recv_frames, recv_frames = zip(
+                *(
+                    (is_cuda, each_frame)
+                    for is_cuda, each_frame in zip(cuda_frames, frames)
+                    if nbytes(each_frame) > 0
+                )
+            )
+
+            # It is necessary to first populate `frames` with CUDA arrays and synchronize
+            # the default stream before starting receiving to ensure buffers have been allocated
+            if any(cuda_recv_frames):
+                synchronize_stream(0)
+
+            try:
+                for each_frame in recv_frames:
+                    await self.ep.recv(each_frame)
+            except BaseException as e:  # noqa: B036
+                # In addition to UCX exceptions, may be CancelledError or another
+                # "low-level" exception. The only safe thing to do is to abort.
+                # (See also https://github.com/dask/distributed/pull/6574).
+                self.abort()
+                raise CommClosedError(
+                    f"Connection closed by writer.\nInner exception: {e!r}"
+                )
+
+            try:
+                msg = await from_frames(
+                    frames,
+                    deserialize=self.deserialize,
+                    deserializers=deserializers,
+                    allow_offload=self.allow_offload,
+                )
+            except EOFError:
+                # Frames possibly garbled or truncated by communication error
+                self.abort()
+                raise CommClosedError("Aborted stream on truncated data")
+            return msg
+
+    async def close(self):
+        self._closed = True
+        if self._ep is not None:
+            try:
+                await self.ep.send(struct.pack("?Q", True, 0))
+            except (  # noqa: B030
+                ucp.exceptions.UCXError,
+                ucp.exceptions.UCXCloseError,
+                ucp.exceptions.UCXCanceled,
+            ) + (getattr(ucp.exceptions, "UCXConnectionReset", ()),):
+                # If the other end is in the process of closing,
+                # UCX will sometimes raise a `Input/output` error,
+                # which we can ignore.
+                pass
+            self.abort()
+            self._ep = None
+
+    def abort(self):
+        self._closed = True
+        if self._ep is not None:
+            self._ep.abort()
+            self._ep = None
+
+    @property
+    def ep(self):
+        if self._ep is not None:
+            return self._ep
+        else:
+            raise CommClosedError("UCX Endpoint is closed")
+
+    def closed(self):
+        if self._has_close_callback is True:
+            # The self._closed flag is separate from the endpoint's lifetime, even when
+            # the endpoint has closed or errored, there may be messages on its buffer
+            # still to be received, even though sending is not possible anymore.
+            return self._closed
+        else:
+            return self._ep is None
+
+
+class UCXConnector(Connector):
+    prefix = "ucx://"
+    comm_class = UCX
+    encrypted = False
+
+    async def connect(
+        self, address: str, deserialize: bool = True, **connection_args: Any
+    ) -> UCX:
+        logger.debug("UCXConnector.connect: %s", address)
+        ip, port = parse_host_port(address)
+        init_once()
+        try:
+            ep = await ucp.create_endpoint(ip, port)
+        except ucp.exceptions.UCXBaseException:
+            raise CommClosedError("Connection closed before handshake completed")
+        return self.comm_class(
+            ep,
+            local_addr="",
+            peer_addr=self.prefix + address,
+            deserialize=deserialize,
+        )
+
+
+class UCXListener(BaseListener):
+    prefix = UCXConnector.prefix
+    comm_class = UCXConnector.comm_class
+    encrypted = UCXConnector.encrypted
+
+    def __init__(
+        self,
+        address: str,
+        comm_handler: Callable[[UCX], Awaitable[None]] | None = None,
+        deserialize: bool = False,
+        allow_offload: bool = True,
+        **connection_args: Any,
+    ):
+        super().__init__()
+        if not address.startswith("ucx"):
+            address = "ucx://" + address
+        self.ip, self._input_port = parse_host_port(address, default_port=0)
+        self.comm_handler = comm_handler
+        self.deserialize = deserialize
+        self.allow_offload = allow_offload
+        self._ep = None  # type: ucp.Endpoint
+        self.ucp_server = None
+        self.connection_args = connection_args
+
+    @property
+    def port(self):
+        return self.ucp_server.port
+
+    @property
+    def address(self):
+        return "ucx://" + self.ip + ":" + str(self.port)
+
+    async def start(self):
+        async def serve_forever(client_ep):
+            ucx = UCX(
+                client_ep,
+                local_addr=self.address,
+                peer_addr=self.address,
+                deserialize=self.deserialize,
+            )
+            ucx.allow_offload = self.allow_offload
+            try:
+                await self.on_connection(ucx)
+            except CommClosedError:
+                logger.debug("Connection closed before handshake completed")
+                return
+            if self.comm_handler:
+                await self.comm_handler(ucx)
+
+        init_once()
+        self.ucp_server = ucp.create_listener(serve_forever, port=self._input_port)
+
+    def stop(self):
+        self.ucp_server = None
+
+    def get_host_port(self):
+        # TODO: TCP raises if this hasn't started yet.
+        return self.ip, self.port
+
+    @property
+    def listen_address(self):
+        return self.prefix + unparse_host_port(*self.get_host_port())
+
+    @property
+    def contact_address(self):
+        host, port = self.get_host_port()
+        host = ensure_concrete_host(host)  # TODO: ensure_concrete_host
+        return self.prefix + unparse_host_port(host, port)
+
+    @property
+    def bound_address(self):
+        # TODO: Does this become part of the base API? Kinda hazy, since
+        # we exclude in for inproc.
+        return self.get_host_port()
+
+
+class UCXBackend(Backend):
+    # I / O
+
+    def get_connector(self):
+        return UCXConnector()
+
+    def get_listener(self, loc, handle_comm, deserialize, **connection_args):
+        return UCXListener(loc, handle_comm, deserialize, **connection_args)
+
+    # Address handling
+    # This duplicates BaseTCPBackend
+
+    def get_address_host(self, loc):
+        return parse_host_port(loc)[0]
+
+    def get_address_host_port(self, loc):
+        return parse_host_port(loc)
+
+    def resolve_address(self, loc):
+        host, port = parse_host_port(loc)
+        return unparse_host_port(ensure_ip(host), port)
+
+    def get_local_address_for(self, loc):
+        host, port = parse_host_port(loc)
+        host = ensure_ip(host)
+        if ":" in host:
+            local_host = get_ipv6(host)
+        else:
+            local_host = get_ip(host)
+        return unparse_host_port(local_host, None)
+
+
+backends["ucx"] = UCXBackend()
+
+
+def _prepare_ucx_config():
+    """Translate dask config options to appropriate UCX config options
+
+    Returns
+    -------
+    tuple
+        Options suitable for passing to ``ucp.init`` and additional
+        UCX options that will be inserted directly into the environment
+        while calling ``ucp.init``.
+    """
+
+    # configuration of UCX can happen in two ways:
+    # 1) high level on/off flags which correspond to UCX configuration
+    # 2) explicitly defined UCX configuration flags in distributed.comm.ucx.environment
+    # High-level settings in (1) are preferred to settings in (2)
+    # Settings in the external environment override both
+
+    high_level_options = {}
+
+    # if any of the high level flags are set, as long as they are not Null/None,
+    # we assume we should configure basic TLS settings for UCX, otherwise we
+    # leave UCX to its default configuration
+    if any(
+        [
+            dask.config.get("distributed.comm.ucx.tcp"),
+            dask.config.get("distributed.comm.ucx.nvlink"),
+            dask.config.get("distributed.comm.ucx.infiniband"),
+        ]
+    ):
+        if dask.config.get("distributed.comm.ucx.rdmacm"):
+            tls = "tcp"
+            tls_priority = "rdmacm"
+        else:
+            tls = "tcp"
+            tls_priority = "tcp"
+
+        # CUDA COPY can optionally be used with ucx -- we rely on the user
+        # to define when messages will include CUDA objects. Note:
+        # defining only the Infiniband flag will not enable cuda_copy
+        if any(
+            [
+                dask.config.get("distributed.comm.ucx.nvlink"),
+                dask.config.get("distributed.comm.ucx.cuda-copy"),
+            ]
+        ):
+            tls = tls + ",cuda_copy"
+
+        if dask.config.get("distributed.comm.ucx.infiniband"):
+            tls = "rc," + tls
+        if dask.config.get("distributed.comm.ucx.nvlink"):
+            tls = tls + ",cuda_ipc"
+
+        high_level_options = {"TLS": tls, "SOCKADDR_TLS_PRIORITY": tls_priority}
+
+    # Pick up any other ucx environment settings
+    environment_options = {}
+    for k, v in dask.config.get("distributed.comm.ucx.environment", {}).items():
+        # {"some-name": value} is translated to {"UCX_SOME_NAME": value}
+        key = "_".join(map(str.upper, ("UCX", *k.split("-"))))
+        if (hl_key := key[4:]) in high_level_options:
+            logger.warning(
+                f"Ignoring {k}={v} ({key=}) in ucx.environment, "
+                f"preferring {hl_key}={high_level_options[hl_key]} "
+                "from high level options"
+            )
+        elif key in os.environ:
+            # This is only info because setting UCX configuration via
+            # environment variables is a reasonably common approach
+            logger.info(
+                f"Ignoring {k}={v} ({key=}) in ucx.environment, "
+                f"preferring {key}={os.environ[key]} from external environment"
+            )
+        else:
+            environment_options[key] = v
+
+    return high_level_options, environment_options
diff --git a/rapids_dask_dependency/patches/distributed/comm/ucx.py b/rapids_dask_dependency/patches/distributed/comm/ucx.py
new file mode 100644
index 0000000..adbcf24
--- /dev/null
+++ b/rapids_dask_dependency/patches/distributed/comm/ucx.py
@@ -0,0 +1,6 @@
+# Copyright (c) 2025, NVIDIA CORPORATION.
+import sys
+
+from rapids_dask_dependency.loaders import make_vendored_loader
+
+load_module = make_vendored_loader(__name__)