diff --git a/var/ramble/repos/builtin/modifiers/nccl-tcpxo/modifier.py b/var/ramble/repos/builtin/modifiers/nccl-tcpxo/modifier.py new file mode 100644 index 000000000..e23a0f877 --- /dev/null +++ b/var/ramble/repos/builtin/modifiers/nccl-tcpxo/modifier.py @@ -0,0 +1,77 @@ +# Copyright 2022-2024 The Ramble Authors +# +# Licensed under the Apache License, Version 2.0 or the MIT license +# , at your +# option. This file may not be copied, modified, or distributed +# except according to those terms. + +import os +from ramble.modkit import * + + +class NcclTcpxo(BasicModifier): + """Modifier to ensure NCCL TCPXO is loaded into the execution environment.""" + + name = "nccl-tcpxo" + + tags("gpu") + + maintainers("douglasjacobsen") + + mode( + "auto", description="Auto detected shell based on config:shell setting" + ) + default_mode("auto") + + def _cos_paths(self): + return "/usr/local/nvidia/lib64/nccl-env-profile.sh" + + def _debian_paths(self): + return "/var/lib/tcpxo/lib64/nccl-env-profile.sh" + + register_builtin("source_tcpxo", injection_method="prepend") + + def source_tcpxo(self): + import ramble.util.shell_utils + import ramble.config + + shell = ramble.config.get("config:shell") + source_str = ramble.util.shell_utils.source_str(shell) + cmds = [] + + path_funcs = [self._cos_paths, self._debian_paths] + + for path_func in path_funcs: + script = path_func() + script_dir = os.path.dirname(script) + if shell in ["bash", "sh"]: + cmds.extend( + [ + f'if [ -d "{script_dir}" ]; then', + f" NCCL_LIB_DIR={script_dir} {source_str} {script}", + "fi", + ] + ) + elif shell == "csh": + cmds.extend( + [ + f'if ( -d "{script_dir}" ) then', + f" NCCL_LIB_DIR={script_dir} {source_str} {script}", + "endif", + ] + ) + elif shell == "fish": + cmds.extend( + [ + f'if test -d "{script_dir}"', + f" NCCL_LIB_DIR={script_dir} {source_str} {script}", + "end", + ] + ) + elif shell == "bat": + logger.die( + "The nccl-tcpxo modifier is not currently supported for batch shell." + ) + + return cmds