Skip to content

Commit

Permalink
bugfix/flux-nodes-prior-versions (#487)
Browse files Browse the repository at this point in the history
* add a version check for flux when getting node count

* update CHANGELOG

* add major version check for flux
  • Loading branch information
bgunnar5 authored Jun 11, 2024
1 parent 0f6bebf commit 831bc40
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Link to Merlin banner in readme
- Issue with escape sequences in ascii art (caught by python 3.12)
- Bug where Flux wasn't identifying total number of nodes on an allocation
- Not supporting Flux versions below 0.17.0


## [1.12.1]
Expand Down
12 changes: 9 additions & 3 deletions merlin/study/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
import subprocess
from typing import Dict, Optional, Union

from merlin.utils import convert_timestring, get_flux_alloc, get_yaml_var
from merlin.utils import convert_timestring, get_flux_alloc, get_flux_version, get_yaml_var


LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -126,7 +126,7 @@ def get_batch_type(scheduler_legend, default=None):
return default


def get_node_count(default=1):
def get_node_count(parsed_batch: Dict, default=1):
"""
Determine a default node count based on the environment.
Expand All @@ -135,6 +135,12 @@ def get_node_count(default=1):
:returns: (int) The number of nodes to use.
"""

# Flux version check
flux_ver = get_flux_version(parsed_batch["flux exe"], no_errors=True)
major, minor, _ = map(int, flux_ver.split("."))
if major < 1 and minor < 17:
raise ValueError("Flux version is too old. Supported versions are 0.17.0+.")

# If flux is the scheduler, we can get the size of the allocation with this
try:
get_size_proc = subprocess.run("flux getattr size", shell=True, capture_output=True, text=True)
Expand Down Expand Up @@ -254,7 +260,7 @@ def batch_worker_launch(

# Get the number of nodes from the environment if unset
if nodes is None or nodes == "all":
nodes = get_node_count(default=1)
nodes = get_node_count(parsed_batch, default=1)
elif not isinstance(nodes, int):
raise TypeError("Nodes was passed into batch_worker_launch with an invalid type (likely a string other than 'all').")

Expand Down

0 comments on commit 831bc40

Please sign in to comment.