Skip to content

Commit

Permalink
Add back checks for v1.12.0
Browse files Browse the repository at this point in the history
  • Loading branch information
pentschev committed Jun 28, 2024
1 parent 02bc2a5 commit 02b148f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
7 changes: 7 additions & 0 deletions docs/source/configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,17 @@ UCX-Py redefines some of the UCX defaults for a variety of reasons, including be

Below is a list of the UCX-Py redefined default values, and what conditions are required for them to apply.

Apply to all UCX versions:

::

UCX_RNDV_THRESH=8192
UCX_RNDV_SCHEME=get_zcopy

Apply to UCX >= 1.12.0; older UCX versions rely on UCX defaults:

::

UCX_CUDA_COPY_MAX_REG_RATIO=1.0
UCX_MAX_RNDV_RAILS=1
UCX_PROTO_ENABLE=n
Expand Down
12 changes: 8 additions & 4 deletions ucp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@
logger.info("Setting UCX_RNDV_FRAG_MEM_TYPE=cuda")
os.environ["UCX_RNDV_FRAG_MEM_TYPE"] = "cuda"

if pynvml is not None and "UCX_CUDA_COPY_MAX_REG_RATIO" not in os.environ:
if (
pynvml is not None
and "UCX_CUDA_COPY_MAX_REG_RATIO" not in os.environ
and get_ucx_version() >= (1, 12, 0)
):
try:
pynvml.nvmlInit()
device_count = pynvml.nvmlDeviceGetCount()
Expand Down Expand Up @@ -94,19 +98,19 @@ def _is_mig_device(handle):
):
pass

if "UCX_MAX_RNDV_RAILS" not in os.environ:
if "UCX_MAX_RNDV_RAILS" not in os.environ and get_ucx_version() >= (1, 12, 0):
logger.info("Setting UCX_MAX_RNDV_RAILS=1")
os.environ["UCX_MAX_RNDV_RAILS"] = "1"

if "UCX_PROTO_ENABLE" not in os.environ:
if "UCX_PROTO_ENABLE" not in os.environ and get_ucx_version() >= (1, 12, 0):
# UCX protov2 still doesn't support CUDA async/managed memory
logger.info("Setting UCX_PROTO_ENABLE=n")
os.environ["UCX_PROTO_ENABLE"] = "n"


__ucx_version__ = "%d.%d.%d" % get_ucx_version()

if get_ucx_version() < (1, 11, 1):
if get_ucx_version() < (1, 15, 0):
raise ImportError(
f"Support for UCX {__ucx_version__} has ended. Please upgrade to "
"1.11.1 or newer."
Expand Down

0 comments on commit 02b148f

Please sign in to comment.