Skip to content

Commit

Permalink
implement initial support for running shell commands asynchronously u…
Browse files Browse the repository at this point in the history
…sing run_shell_cmd
  • Loading branch information
boegel committed Jan 19, 2024
1 parent 8aaaec2 commit 7cba1dc
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 32 deletions.
40 changes: 26 additions & 14 deletions easybuild/tools/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import subprocess
import sys
import tempfile
import threading
import time
from collections import namedtuple
from datetime import datetime
Expand Down Expand Up @@ -79,7 +80,7 @@


RunShellCmdResult = namedtuple('RunShellCmdResult', ('cmd', 'exit_code', 'output', 'stderr', 'work_dir',
'out_file', 'err_file'))
'out_file', 'err_file', 'thread_id'))


class RunShellCmdError(BaseException):
Expand Down Expand Up @@ -199,7 +200,7 @@ def run_shell_cmd(cmd, fail_on_error=True, split_stderr=False, stdin=None, env=N
:param use_bash: execute command through bash shell (enabled by default)
:param output_file: collect command output in temporary output file
:param stream_output: stream command output to stdout (auto-enabled with --logtostdout if None)
:param asynchronous: run command asynchronously
:param asynchronous: indicate that command is being run asynchronously
:param with_hooks: trigger pre/post run_shell_cmd hooks (if defined)
:param qa_patterns: list of 2-tuples with patterns for questions + corresponding answers
:param qa_wait_patterns: list of 2-tuples with patterns for non-questions
Expand All @@ -223,9 +224,6 @@ def to_cmd_str(cmd):
return cmd_str

# temporarily raise a NotImplementedError until all options are implemented
if asynchronous:
raise NotImplementedError

if qa_patterns or qa_wait_patterns:
raise NotImplementedError

Expand All @@ -235,6 +233,11 @@ def to_cmd_str(cmd):
cmd_str = to_cmd_str(cmd)
cmd_name = os.path.basename(cmd_str.split(' ')[0])

thread_id = None
if asynchronous:
thread_id = threading.get_native_id()
_log.info(f"Initiating running of shell command '{cmd_str}' via thread with ID {thread_id}")

# auto-enable streaming of command output under --logtostdout/-l, unless it was disabled explicitly
if stream_output is None and build_option('logtostdout'):
_log.info(f"Auto-enabling streaming output of '{cmd_str}' command because logging to stdout is enabled")
Expand All @@ -259,16 +262,16 @@ def to_cmd_str(cmd):
if not in_dry_run and build_option('extended_dry_run'):
if not hidden or verbose_dry_run:
silent = build_option('silent')
msg = f" running command \"{cmd_str}\"\n"
msg = f" running shell command \"{cmd_str}\"\n"
msg += f" (in {work_dir})"
dry_run_msg(msg, silent=silent)

return RunShellCmdResult(cmd=cmd_str, exit_code=0, output='', stderr=None, work_dir=work_dir,
out_file=cmd_out_fp, err_file=cmd_err_fp)
out_file=cmd_out_fp, err_file=cmd_err_fp, thread_id=thread_id)

start_time = datetime.now()
if not hidden:
cmd_trace_msg(cmd_str, start_time, work_dir, stdin, cmd_out_fp, cmd_err_fp)
_cmd_trace_msg(cmd_str, start_time, work_dir, stdin, cmd_out_fp, cmd_err_fp, thread_id)

if stream_output:
print_msg(f"(streaming) output for command '{cmd_str}':")
Expand All @@ -293,7 +296,11 @@ def to_cmd_str(cmd):

stderr = subprocess.PIPE if split_stderr else subprocess.STDOUT

_log.info(f"Running command '{cmd_str}' in {work_dir}")
log_msg = f"Running shell command '{cmd_str}' in {work_dir}"
if thread_id:
log_msg += f" (via thread with ID {thread_id})"
_log.info(log_msg)

proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=stderr, stdin=subprocess.PIPE,
cwd=work_dir, env=env, shell=shell, executable=executable)

Expand Down Expand Up @@ -337,7 +344,7 @@ def to_cmd_str(cmd):
raise EasyBuildError(f"Failed to dump command output to temporary file: {err}")

res = RunShellCmdResult(cmd=cmd_str, exit_code=proc.returncode, output=output, stderr=stderr, work_dir=work_dir,
out_file=cmd_out_fp, err_file=cmd_err_fp)
out_file=cmd_out_fp, err_file=cmd_err_fp, thread_id=thread_id)

# always log command output
cmd_name = cmd_str.split(' ')[0]
Expand Down Expand Up @@ -370,7 +377,7 @@ def to_cmd_str(cmd):
return res


def cmd_trace_msg(cmd, start_time, work_dir, stdin, cmd_out_fp, cmd_err_fp):
def _cmd_trace_msg(cmd, start_time, work_dir, stdin, cmd_out_fp, cmd_err_fp, thread_id):
"""
Helper function to construct and print trace message for command being run
Expand All @@ -380,11 +387,18 @@ def cmd_trace_msg(cmd, start_time, work_dir, stdin, cmd_out_fp, cmd_err_fp):
:param stdin: stdin input value for command
:param cmd_out_fp: path to output file for command
:param cmd_err_fp: path to errors/warnings output file for command
:param thread_id: thread ID (None when not running shell command asynchronously)
"""
start_time = start_time.strftime('%Y-%m-%d %H:%M:%S')

if thread_id:
run_cmd_msg = f"running shell command (asynchronously, thread ID: {thread_id}):"
else:
run_cmd_msg = "running shell command:"

lines = [
"running command:",
run_cmd_msg,
f"\t{cmd}",
f"\t[started at: {start_time}]",
f"\t[working dir: {work_dir}]",
]
Expand All @@ -395,8 +409,6 @@ def cmd_trace_msg(cmd, start_time, work_dir, stdin, cmd_out_fp, cmd_err_fp):
if cmd_err_fp:
lines.append(f"\t[errors/warnings saved to {cmd_err_fp}]")

lines.append('\t' + cmd)

trace_msg('\n'.join(lines))


Expand Down
95 changes: 79 additions & 16 deletions test/framework/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import tempfile
import textwrap
import time
from concurrent.futures import ThreadPoolExecutor
from test.framework.utilities import EnhancedTestCase, TestLoaderFiltered, init_config
from unittest import TextTestRunner
from easybuild.base.fancylogger import setLogLevelDebug
Expand Down Expand Up @@ -248,7 +249,7 @@ def test_run_shell_cmd_log(self):
fd, logfile = tempfile.mkstemp(suffix='.log', prefix='eb-test-')
os.close(fd)

regex_start_cmd = re.compile("Running command 'echo hello' in /")
regex_start_cmd = re.compile("Running shell command 'echo hello' in /")
regex_cmd_exit = re.compile(r"Shell command completed successfully \(see output above\): echo hello")

# command output is always logged
Expand Down Expand Up @@ -448,7 +449,7 @@ def test_run_cmd_work_dir(self):

def test_run_shell_cmd_work_dir(self):
"""
Test running command in specific directory with run_shell_cmd function.
Test running shell command in specific directory with run_shell_cmd function.
"""
orig_wd = os.getcwd()
self.assertFalse(os.path.samefile(orig_wd, self.test_prefix))
Expand Down Expand Up @@ -615,11 +616,11 @@ def test_run_shell_cmd_trace(self):
"""Test run_shell_cmd function in trace mode, and with tracing disabled."""

pattern = [
r"^ >> running command:",
r"^ >> running shell command:",
r"\techo hello",
r"\t\[started at: .*\]",
r"\t\[working dir: .*\]",
r"\t\[output saved to .*\]",
r"\techo hello",
r" >> command completed: exit 0, ran in .*",
]

Expand Down Expand Up @@ -675,11 +676,11 @@ def test_run_shell_cmd_trace_stdin(self):
init_config(build_options={'trace': True})

pattern = [
r"^ >> running command:",
r"^ >> running shell command:",
r"\techo hello",
r"\t\[started at: [0-9][0-9][0-9][0-9]-[0-9][0-9]-[0-9][0-9] [0-9][0-9]:[0-9][0-9]:[0-9][0-9]\]",
r"\t\[working dir: .*\]",
r"\t\[output saved to .*\]",
r"\techo hello",
r" >> command completed: exit 0, ran in .*",
]

Expand Down Expand Up @@ -707,8 +708,8 @@ def test_run_shell_cmd_trace_stdin(self):
self.assertEqual(res.output, 'hello')
self.assertEqual(res.exit_code, 0)
self.assertEqual(stderr, '')
pattern.insert(3, r"\t\[input: hello\]")
pattern[-2] = "\tcat"
pattern.insert(4, r"\t\[input: hello\]")
pattern[1] = "\tcat"
regex = re.compile('\n'.join(pattern))
self.assertTrue(regex.search(stdout), "Pattern '%s' found in: %s" % (regex.pattern, stdout))

Expand Down Expand Up @@ -909,7 +910,8 @@ def test_run_shell_cmd_cache(self):
# inject value into cache to check whether executing command again really returns cached value
with self.mocked_stdout_stderr():
cached_res = RunShellCmdResult(cmd=cmd, output="123456", exit_code=123, stderr=None,
work_dir='/test_ulimit', out_file='/tmp/foo.out', err_file=None)
work_dir='/test_ulimit', out_file='/tmp/foo.out', err_file=None,
thread_id=None)
run_shell_cmd.update_cache({(cmd, None): cached_res})
res = run_shell_cmd(cmd)
self.assertEqual(res.cmd, cmd)
Expand All @@ -928,7 +930,8 @@ def test_run_shell_cmd_cache(self):
# inject different output for cat with 'foo' as stdin to check whether cached value is used
with self.mocked_stdout_stderr():
cached_res = RunShellCmdResult(cmd=cmd, output="bar", exit_code=123, stderr=None,
work_dir='/test_cat', out_file='/tmp/cat.out', err_file=None)
work_dir='/test_cat', out_file='/tmp/cat.out', err_file=None,
thread_id=None)
run_shell_cmd.update_cache({(cmd, 'foo'): cached_res})
res = run_shell_cmd(cmd, stdin='foo')
self.assertEqual(res.cmd, cmd)
Expand Down Expand Up @@ -1006,7 +1009,7 @@ def test_run_shell_cmd_dry_run(self):
self.assertEqual(res.output, '')
self.assertEqual(res.stderr, None)
# check dry run output
expected = """ running command "somecommand foo 123 bar"\n"""
expected = """ running shell command "somecommand foo 123 bar"\n"""
self.assertIn(expected, stdout)

# check enabling 'hidden'
Expand All @@ -1029,7 +1032,7 @@ def test_run_shell_cmd_dry_run(self):
fail_on_error=False, in_dry_run=True)
stdout = self.get_stdout()
self.mock_stdout(False)
self.assertNotIn('running command "', stdout)
self.assertNotIn('running shell command "', stdout)
self.assertNotEqual(res.exit_code, 0)
self.assertEqual(res.output, 'done\n')
self.assertEqual(res.stderr, None)
Expand Down Expand Up @@ -1207,7 +1210,7 @@ def test_run_cmd_async(self):
"for i in $(seq 1 50)",
"do sleep 0.1",
"for j in $(seq 1000)",
"do echo foo",
"do echo foo${i}${j}",
"done",
"done",
"echo done",
Expand Down Expand Up @@ -1257,8 +1260,68 @@ def test_run_cmd_async(self):
res = check_async_cmd(*cmd_info, output=res['output'])
self.assertEqual(res['done'], True)
self.assertEqual(res['exit_code'], 0)
self.assertTrue(res['output'].startswith('start\n'))
self.assertTrue(res['output'].endswith('\ndone\n'))
self.assertEqual(len(res['output']), 435661)
self.assertTrue(res['output'].startswith('start\nfoo11\nfoo12\n'))
self.assertTrue('\nfoo49999\nfoo491000\nfoo501\n' in res['output'])
self.assertTrue(res['output'].endswith('\nfoo501000\ndone\n'))

def test_run_shell_cmd_async(self):
    """
    Test running shell commands asynchronously with run_shell_cmd,
    by submitting the calls to a thread pool and collecting the results.
    """

    os.environ['TEST'] = 'test123'
    env = os.environ.copy()

    # use the executor as a context manager, so its worker threads are always shut down,
    # also when one of the assertions below fails
    with ThreadPoolExecutor() as thread_pool:

        test_cmd = "echo 'sleeping...'; sleep 2; echo $TEST"
        task = thread_pool.submit(run_shell_cmd, test_cmd, hidden=True, asynchronous=True, env=env)

        # change value of $TEST to check that command is completed with correct environment
        os.environ['TEST'] = 'some_other_value'

        # initial poll should indicate the task is not done yet,
        # since it takes a while for the command to complete
        self.assertEqual(task.done(), False)

        # wait until command is done
        while not task.done():
            time.sleep(1)
        res = task.result()

        self.assertEqual(res.exit_code, 0)
        self.assertEqual(res.output, 'sleeping...\ntest123\n')

        # check asynchronous running of failing command;
        # task.result() blocks until the command has completed, so no explicit sleep is needed
        error_test_cmd = "echo 'FAIL!' >&2; exit 123"
        task = thread_pool.submit(run_shell_cmd, error_test_cmd, hidden=True, fail_on_error=False,
                                  asynchronous=True)
        res = task.result()
        self.assertEqual(res.exit_code, 123)
        self.assertEqual(res.output, "FAIL!\n")
        self.assertTrue(res.thread_id)

        # also test with a command that produces a lot of output,
        # since that tends to lock up things unless we frequently grab some output...
        verbose_test_cmd = ';'.join([
            "echo start",
            "for i in $(seq 1 50)",
            "do sleep 0.1",
            "for j in $(seq 1000)",
            "do echo foo${i}${j}",
            "done",
            "done",
            "echo done",
        ])
        task = thread_pool.submit(run_shell_cmd, verbose_test_cmd, hidden=True, asynchronous=True)

        while not task.done():
            time.sleep(1)
        res = task.result()

        self.assertEqual(res.exit_code, 0)
        self.assertEqual(len(res.output), 435661)
        # plain assertTrue is used for the checks below, which keeps failure messages short
        # (the output is over 400k characters long)
        self.assertTrue(res.output.startswith('start\nfoo11\nfoo12\n'))
        self.assertTrue('\nfoo49999\nfoo491000\nfoo501\n' in res.output)
        self.assertTrue(res.output.endswith('\nfoo501000\ndone\n'))

def test_check_log_for_errors(self):
fd, logfile = tempfile.mkstemp(suffix='.log', prefix='eb-test-')
Expand Down Expand Up @@ -1373,7 +1436,7 @@ def post_run_shell_cmd_hook(cmd, *args, **kwargs):

def test_run_shell_cmd_with_hooks(self):
"""
Test running command with run_shell_cmd function with pre/post run_shell_cmd hooks in place.
Test running shell command with run_shell_cmd function with pre/post run_shell_cmd hooks in place.
"""
cwd = os.getcwd()

Expand Down
4 changes: 2 additions & 2 deletions test/framework/systemtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def mocked_run_shell_cmd(cmd, **kwargs):
}
if cmd in known_cmds:
return RunShellCmdResult(cmd=cmd, exit_code=0, output=known_cmds[cmd], stderr=None, work_dir=os.getcwd(),
out_file=None, err_file=None)
out_file=None, err_file=None, thread_id=None)
else:
return run_shell_cmd(cmd, **kwargs)

Expand Down Expand Up @@ -774,7 +774,7 @@ def test_gcc_version_darwin(self):
out = "Apple LLVM version 7.0.0 (clang-700.1.76)"
cwd = os.getcwd()
mocked_run_res = RunShellCmdResult(cmd="gcc --version", exit_code=0, output=out, stderr=None, work_dir=cwd,
out_file=None, err_file=None)
out_file=None, err_file=None, thread_id=None)
st.run_shell_cmd = lambda *args, **kwargs: mocked_run_res
self.assertEqual(get_gcc_version(), None)

Expand Down

0 comments on commit 7cba1dc

Please sign in to comment.