Skip to content

Commit

Permalink
Fix hanging direct scheduler+ssh (#4735)
Browse files Browse the repository at this point in the history
* Fix hanging direct scheduler+ssh

The fix is very simple: in the ssh transport, to emulate 'chdir',
we keep the current directory in memory, and we prepend every command
with a `cd FOLDER_NAME && ACTUALCOMMAND`.

One could put `;` instead of `&&`, but then if the folder does not
exist the ACTUALCOMMAND would still be run in the wrong folder, which is
very bad (imagine you are removing files...).

Now, in general this is not a problem. However, the direct scheduler
inserts a complex-syntax bash command to run the command in the background
and immediately get the PID of that process without waiting.
When combined with SSH, this hangs until the whole process is completed, unless
the actual command is wrapped in parentheses (i.e. run in a subshell).

A simple way to check this is to run the following two commands, which reproduce
the issue with plain ssh, without paramiko:

This hangs for 5 seconds:
```
ssh localhost 'cd tmp && sleep 5 > /dev/null 2>&1 & echo $!'
```

This returns immediately, as we want:
```
ssh localhost 'cd tmp && ( sleep 5 > /dev/null 2>&1 & echo $! )'
```

Also, add a regression test for the hanging direct+ssh combination.
This test checks that submitting a long job over the direct scheduler
does not "hang" with any plugin.

Co-authored-by: Leopold Talirz <leopold.talirz@gmail.com>
  • Loading branch information
giovannipizzi and ltalirz authored Feb 18, 2021
1 parent d63c9c1 commit 652dee3
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 98 deletions.
2 changes: 1 addition & 1 deletion aiida/transports/plugins/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -1265,7 +1265,7 @@ def _exec_command_internal(self, command, combine_stderr=False, bufsize=-1): #

if self.getcwd() is not None:
escaped_folder = escape_for_bash(self.getcwd())
command_to_execute = (f'cd {escaped_folder} && {command}')
command_to_execute = (f'cd {escaped_folder} && ( {command} )')
else:
command_to_execute = command

Expand Down
190 changes: 93 additions & 97 deletions tests/transports/test_all_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,19 @@
Plugin specific tests will be written in the plugin itself.
"""
import io
import os
import random
import tempfile
import signal
import shutil
import string
import time
import unittest
import uuid

import psutil

from aiida.plugins import SchedulerFactory

# TODO : test for copy with pattern
# TODO : test for copy with/without patterns, overwriting folder
Expand All @@ -35,7 +47,6 @@ def get_all_custom_transports():
it was found)
"""
import importlib
import os

modulename = __name__.rpartition('.')[0]
this_full_fname = __file__
Expand Down Expand Up @@ -133,11 +144,6 @@ def test_makedirs(self, custom_transport):
"""
Verify the functioning of makedirs command
"""
# Imports required later
import random
import string
import os

with custom_transport as transport:
location = transport.normalize(os.path.join('/', 'tmp'))
directory = 'temp_dir_test'
Expand Down Expand Up @@ -176,11 +182,6 @@ def test_rmtree(self, custom_transport):
"""
Verify the functioning of rmtree command
"""
# Imports required later
import random
import string
import os

with custom_transport as transport:
location = transport.normalize(os.path.join('/', 'tmp'))
directory = 'temp_dir_test'
Expand Down Expand Up @@ -221,12 +222,6 @@ def test_listdir(self, custom_transport):
"""
create directories, verify listdir, delete a folder with subfolders
"""
# Imports required later
import tempfile
import random
import string
import os

with custom_transport as trans:
# We cannot use tempfile.mkdtemp because we're on a remote folder
location = trans.normalize(os.path.join('/', 'tmp'))
Expand Down Expand Up @@ -270,11 +265,6 @@ def test_listdir_withattributes(self, custom_transport):
"""
create directories, verify listdir_withattributes, delete a folder with subfolders
"""
# Imports required later
import tempfile
import random
import string
import os

def simplify_attributes(data):
"""
Expand Down Expand Up @@ -340,11 +330,6 @@ def simplify_attributes(data):
@run_for_all_plugins
def test_dir_creation_deletion(self, custom_transport):
"""Test creating and deleting directories."""
# Imports required later
import random
import string
import os

with custom_transport as transport:
location = transport.normalize(os.path.join('/', 'tmp'))
directory = 'temp_dir_test'
Expand All @@ -370,11 +355,6 @@ def test_dir_copy(self, custom_transport):
Verify if in the copy of a directory also the protection bits
are carried over
"""
# Imports required later
import random
import string
import os

with custom_transport as transport:
location = transport.normalize(os.path.join('/', 'tmp'))
directory = 'temp_dir_test'
Expand Down Expand Up @@ -403,11 +383,6 @@ def test_dir_permissions_creation_modification(self, custom_transport): # pylin
verify if chmod raises IOError when trying to change bits on a
non-existing folder
"""
# Imports required later
import random
import string
import os

with custom_transport as transport:
location = transport.normalize(os.path.join('/', 'tmp'))
directory = 'temp_dir_test'
Expand Down Expand Up @@ -460,11 +435,6 @@ def test_dir_reading_permissions(self, custom_transport):
Try to enter a directory with no read permissions.
Verify that the cwd has not changed after failed try.
"""
# Imports required later
import random
import string
import os

with custom_transport as transport:
location = transport.normalize(os.path.join('/', 'tmp'))
directory = 'temp_dir_test'
Expand Down Expand Up @@ -503,8 +473,6 @@ def test_isfile_isdir_to_empty_string(self, custom_transport):
I check that isdir or isfile return False when executed on an
empty string
"""
import os

with custom_transport as transport:
location = transport.normalize(os.path.join('/', 'tmp'))
transport.chdir(location)
Expand All @@ -517,8 +485,6 @@ def test_isfile_isdir_to_non_existing_string(self, custom_transport):
I check that isdir or isfile return False when executed on an
empty string
"""
import os

with custom_transport as transport:
location = transport.normalize(os.path.join('/', 'tmp'))
transport.chdir(location)
Expand All @@ -535,8 +501,6 @@ def test_chdir_to_empty_string(self, custom_transport):
not change (this is a paramiko default behavior), but getcwd()
is still correctly defined.
"""
import os

with custom_transport as transport:
new_dir = transport.normalize(os.path.join('/', 'tmp'))
transport.chdir(new_dir)
Expand All @@ -555,10 +519,6 @@ class TestPutGetFile(unittest.TestCase):
@run_for_all_plugins
def test_put_and_get(self, custom_transport):
"""Test putting and getting files."""
import os
import random
import string

local_dir = os.path.join('/', 'tmp')
remote_dir = local_dir
directory = 'tmp_try'
Expand Down Expand Up @@ -605,10 +565,6 @@ def test_put_get_abs_path(self, custom_transport):
"""
test of exception for non existing files and abs path
"""
import os
import random
import string

local_dir = os.path.join('/', 'tmp')
remote_dir = local_dir
directory = 'tmp_try'
Expand Down Expand Up @@ -669,10 +625,6 @@ def test_put_get_empty_string(self, custom_transport):
test of exception put/get of empty strings
"""
# TODO : verify the correctness of \n at the end of a file
import os
import random
import string

local_dir = os.path.join('/', 'tmp')
remote_dir = local_dir
directory = 'tmp_try'
Expand Down Expand Up @@ -752,10 +704,6 @@ class TestPutGetTree(unittest.TestCase):
@run_for_all_plugins
def test_put_and_get(self, custom_transport):
"""Test putting and getting files."""
import os
import random
import string

local_dir = os.path.join('/', 'tmp')
remote_dir = local_dir
directory = 'tmp_try'
Expand Down Expand Up @@ -807,8 +755,6 @@ def test_put_and_get(self, custom_transport):
self.assertTrue('file.txt' in list_pushed_file)
self.assertTrue('file.txt' in list_retrieved_file)

import shutil

shutil.rmtree(local_subfolder)
shutil.rmtree(retrieved_subfolder)
transport.rmtree(remote_subfolder)
Expand All @@ -819,11 +765,6 @@ def test_put_and_get(self, custom_transport):
@run_for_all_plugins
def test_put_and_get_overwrite(self, custom_transport):
"""Test putting and getting files with overwrites."""
import os
import random
import shutil
import string

local_dir = os.path.join('/', 'tmp')
remote_dir = local_dir
directory = 'tmp_try'
Expand Down Expand Up @@ -877,10 +818,6 @@ def test_put_and_get_overwrite(self, custom_transport):
@run_for_all_plugins
def test_copy(self, custom_transport):
"""Test copying."""
import os
import random
import string

local_dir = os.path.join('/', 'tmp')
remote_dir = local_dir
directory = 'tmp_try'
Expand Down Expand Up @@ -952,10 +889,6 @@ def test_put(self, custom_transport):
# pylint: disable=too-many-statements
# exactly the same tests of copy, just with the put function
# and therefore the local path must be absolute
import os
import random
import string

local_dir = os.path.join('/', 'tmp')
remote_dir = local_dir
directory = 'tmp_try'
Expand Down Expand Up @@ -1033,11 +966,6 @@ def test_get(self, custom_transport):
# pylint: disable=too-many-statements
# exactly the same tests of copy, just with the put function
# and therefore the local path must be absolute
import os
import random
import shutil
import string

local_dir = os.path.join('/', 'tmp')
remote_dir = local_dir
directory = 'tmp_try'
Expand Down Expand Up @@ -1119,10 +1047,6 @@ def test_put_get_abs_path(self, custom_transport):
"""
test of exception for non existing files and abs path
"""
import os
import random
import string

local_dir = os.path.join('/', 'tmp')
remote_dir = local_dir
directory = 'tmp_try'
Expand Down Expand Up @@ -1194,10 +1118,6 @@ def test_put_get_empty_string(self, custom_transport):
test of exception put/get of empty strings
"""
# TODO : verify the correctness of \n at the end of a file
import os
import random
import string

local_dir = os.path.join('/', 'tmp')
remote_dir = local_dir
directory = 'tmp_try'
Expand Down Expand Up @@ -1263,9 +1183,6 @@ def test_put_get_empty_string(self, custom_transport):
@run_for_all_plugins
def test_gettree_nested_directory(self, custom_transport): # pylint: disable=no-self-use
"""Test `gettree` for a nested directory."""
import os
import tempfile

with tempfile.TemporaryDirectory() as dir_remote, tempfile.TemporaryDirectory() as dir_local:
content = b'dummy\ncontent'
filepath = os.path.join(dir_remote, 'sub', 'path', 'filename.txt')
Expand Down Expand Up @@ -1294,8 +1211,6 @@ def test_exec_pwd(self, custom_transport):
creation (which should be done by paramiko) and in the command
execution (done in this module, in the _exec_command_internal function).
"""
import os

# Start value
delete_at_end = False

Expand Down Expand Up @@ -1365,3 +1280,84 @@ def test_exec_with_wrong_stdin(self, custom_transport):
with custom_transport as transport:
with self.assertRaises(ValueError):
transport.exec_command_wait('cat', stdin=1)


class TestDirectScheduler(unittest.TestCase):
    """
    Test how the direct scheduler works together with each transport plugin.

    While this is technically a scheduler test, it lives under the transport tests
    because 1) what is really being tested is the interaction of each transport with
    the direct scheduler; 2) the direct scheduler is always available; 3) it reuses
    the infrastructure to run a test on multiple transport plugins.
    """

    @run_for_all_plugins
    def test_asynchronous_execution(self, custom_transport):
        """Test that the execution of a long(ish) command via the direct scheduler does not block.

        This is a regression test for #3094, where running a long job on the direct scheduler
        (via SSH) would lock the interpreter until the job was done.

        :param custom_transport: the transport instance under test, injected by
            ``run_for_all_plugins``.
        """
        # Use a unique name, using a UUID, to avoid concurrent tests (or very rapid
        # tests that follow each other) to overwrite the same destination
        script_fname = f'sleep-submit-{uuid.uuid4().hex}-{custom_transport.__class__.__name__}.sh'

        scheduler = SchedulerFactory('direct')()
        scheduler.set_transport(custom_transport)

        # Remains ``None`` until the submission returned something that parses as a PID.
        # The cleanup in ``finally`` relies on this: without it, any failure before the
        # assignment below would raise ``UnboundLocalError`` during cleanup, hiding the
        # real error and skipping the removal of the uploaded script.
        job_id = None

        with custom_transport as transport:
            try:
                with tempfile.NamedTemporaryFile() as tmpf:
                    # Put a submission script that sleeps 10 seconds
                    tmpf.write(b'#!/bin/bash\nsleep 10\n')
                    tmpf.flush()

                    transport.chdir('/tmp')
                    transport.putfile(tmpf.name, script_fname)

                timestamp_before = time.time()
                job_id_string = scheduler.submit_from_script('/tmp', script_fname)
                elapsed_time = time.time() - timestamp_before

                # We want to get back control. If it takes < 5 seconds, it means that it is not blocking
                # as the job is taking at least 10 seconds. The threshold is 5 because the machine could be
                # slow (including the SSH connection etc.) and false failures should be avoided.
                # Actually, if the time is short, it could also mean that the execution failed!
                # So it is double checked below that the job is really running.
                self.assertLess(
                    elapsed_time, 5, 'Getting back control after remote execution took more than 5 seconds! '
                    'Probably submission is blocking'
                )

                # Wait 0.2 more seconds, to avoid a super-quick check that might return True
                # even if the job is not actually sleeping
                time.sleep(0.2)

                # Check that the job is still running - IMPORTANT: this assumes that all transports actually
                # act on the *same* local machine, and that the job_id is actually the process PID.
                # This needs to be adapted if:
                # - a new transport plugin is tested and this does not test the same machine
                # - a new scheduler is used and does not use the process PID, or the job_id of the 'direct'
                #   scheduler is not anymore simply the job PID
                job_id = int(job_id_string)
                self.assertTrue(
                    psutil.pid_exists(job_id), 'The job is not there after a bit more than 1 second! Probably it failed'
                )
            finally:
                # Clean up by killing the remote job.
                # This assumes it's on the same machine; if tests on a different machine are added,
                # 'kill' needs to be called via the transport instead.
                # In reality it's not critical to kill it since it will end after 10 seconds of
                # sleeping, but this might avoid warnings (e.g. ResourceWarning)
                if job_id is not None:
                    try:
                        os.kill(job_id, signal.SIGTERM)
                    except ProcessLookupError:
                        # If the process is already dead (or has never run), just ignore the error
                        pass

                # Also remove the script
                try:
                    transport.remove(f'/tmp/{script_fname}')
                except FileNotFoundError:
                    # If the file wasn't even created, just ignore this error
                    pass

0 comments on commit 652dee3

Please sign in to comment.