Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix hanging direct scheduler+ssh #4735

Merged
2 changes: 1 addition & 1 deletion aiida/transports/plugins/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -1265,7 +1265,7 @@ def _exec_command_internal(self, command, combine_stderr=False, bufsize=-1): #

if self.getcwd() is not None:
escaped_folder = escape_for_bash(self.getcwd())
command_to_execute = (f'cd {escaped_folder} && {command}')
command_to_execute = (f'cd {escaped_folder} && ( {command} )')
else:
command_to_execute = command

Expand Down
190 changes: 93 additions & 97 deletions tests/transports/test_all_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,19 @@
Plugin specific tests will be written in the plugin itself.
"""
import io
import os
import random
import tempfile
import signal
import shutil
import string
import time
import unittest
import uuid

import psutil

from aiida.plugins import SchedulerFactory

# TODO : test for copy with pattern
# TODO : test for copy with/without patterns, overwriting folder
Expand All @@ -35,7 +47,6 @@ def get_all_custom_transports():
it was found)
"""
import importlib
import os

modulename = __name__.rpartition('.')[0]
this_full_fname = __file__
Expand Down Expand Up @@ -133,11 +144,6 @@ def test_makedirs(self, custom_transport):
"""
Verify the functioning of makedirs command
"""
# Imports required later
import random
import string
import os

with custom_transport as transport:
location = transport.normalize(os.path.join('/', 'tmp'))
directory = 'temp_dir_test'
Expand Down Expand Up @@ -176,11 +182,6 @@ def test_rmtree(self, custom_transport):
"""
Verify the functioning of rmtree command
"""
# Imports required later
import random
import string
import os

with custom_transport as transport:
location = transport.normalize(os.path.join('/', 'tmp'))
directory = 'temp_dir_test'
Expand Down Expand Up @@ -221,12 +222,6 @@ def test_listdir(self, custom_transport):
"""
create directories, verify listdir, delete a folder with subfolders
"""
# Imports required later
import tempfile
import random
import string
import os

with custom_transport as trans:
# We cannot use tempfile.mkdtemp because we're on a remote folder
location = trans.normalize(os.path.join('/', 'tmp'))
Expand Down Expand Up @@ -270,11 +265,6 @@ def test_listdir_withattributes(self, custom_transport):
"""
create directories, verify listdir_withattributes, delete a folder with subfolders
"""
# Imports required later
import tempfile
import random
import string
import os

def simplify_attributes(data):
"""
Expand Down Expand Up @@ -340,11 +330,6 @@ def simplify_attributes(data):
@run_for_all_plugins
def test_dir_creation_deletion(self, custom_transport):
"""Test creating and deleting directories."""
# Imports required later
import random
import string
import os

with custom_transport as transport:
location = transport.normalize(os.path.join('/', 'tmp'))
directory = 'temp_dir_test'
Expand All @@ -370,11 +355,6 @@ def test_dir_copy(self, custom_transport):
Verify if in the copy of a directory also the protection bits
are carried over
"""
# Imports required later
import random
import string
import os

with custom_transport as transport:
location = transport.normalize(os.path.join('/', 'tmp'))
directory = 'temp_dir_test'
Expand Down Expand Up @@ -403,11 +383,6 @@ def test_dir_permissions_creation_modification(self, custom_transport): # pylin
verify if chmod raises IOError when trying to change bits on a
non-existing folder
"""
# Imports required later
import random
import string
import os

with custom_transport as transport:
location = transport.normalize(os.path.join('/', 'tmp'))
directory = 'temp_dir_test'
Expand Down Expand Up @@ -460,11 +435,6 @@ def test_dir_reading_permissions(self, custom_transport):
Try to enter a directory with no read permissions.
Verify that the cwd has not changed after failed try.
"""
# Imports required later
import random
import string
import os

with custom_transport as transport:
location = transport.normalize(os.path.join('/', 'tmp'))
directory = 'temp_dir_test'
Expand Down Expand Up @@ -503,8 +473,6 @@ def test_isfile_isdir_to_empty_string(self, custom_transport):
I check that isdir or isfile return False when executed on an
empty string
"""
import os

with custom_transport as transport:
location = transport.normalize(os.path.join('/', 'tmp'))
transport.chdir(location)
Expand All @@ -517,8 +485,6 @@ def test_isfile_isdir_to_non_existing_string(self, custom_transport):
I check that isdir or isfile return False when executed on an
empty string
"""
import os

with custom_transport as transport:
location = transport.normalize(os.path.join('/', 'tmp'))
transport.chdir(location)
Expand All @@ -535,8 +501,6 @@ def test_chdir_to_empty_string(self, custom_transport):
not change (this is a paramiko default behavior), but getcwd()
is still correctly defined.
"""
import os

with custom_transport as transport:
new_dir = transport.normalize(os.path.join('/', 'tmp'))
transport.chdir(new_dir)
Expand All @@ -555,10 +519,6 @@ class TestPutGetFile(unittest.TestCase):
@run_for_all_plugins
def test_put_and_get(self, custom_transport):
"""Test putting and getting files."""
import os
import random
import string

local_dir = os.path.join('/', 'tmp')
remote_dir = local_dir
directory = 'tmp_try'
Expand Down Expand Up @@ -605,10 +565,6 @@ def test_put_get_abs_path(self, custom_transport):
"""
test of exception for non existing files and abs path
"""
import os
import random
import string

local_dir = os.path.join('/', 'tmp')
remote_dir = local_dir
directory = 'tmp_try'
Expand Down Expand Up @@ -669,10 +625,6 @@ def test_put_get_empty_string(self, custom_transport):
test of exception put/get of empty strings
"""
# TODO : verify the correctness of \n at the end of a file
import os
import random
import string

local_dir = os.path.join('/', 'tmp')
remote_dir = local_dir
directory = 'tmp_try'
Expand Down Expand Up @@ -752,10 +704,6 @@ class TestPutGetTree(unittest.TestCase):
@run_for_all_plugins
def test_put_and_get(self, custom_transport):
"""Test putting and getting files."""
import os
import random
import string

local_dir = os.path.join('/', 'tmp')
remote_dir = local_dir
directory = 'tmp_try'
Expand Down Expand Up @@ -807,8 +755,6 @@ def test_put_and_get(self, custom_transport):
self.assertTrue('file.txt' in list_pushed_file)
self.assertTrue('file.txt' in list_retrieved_file)

import shutil

shutil.rmtree(local_subfolder)
shutil.rmtree(retrieved_subfolder)
transport.rmtree(remote_subfolder)
Expand All @@ -819,11 +765,6 @@ def test_put_and_get(self, custom_transport):
@run_for_all_plugins
def test_put_and_get_overwrite(self, custom_transport):
"""Test putting and getting files with overwrites."""
import os
import random
import shutil
import string

local_dir = os.path.join('/', 'tmp')
remote_dir = local_dir
directory = 'tmp_try'
Expand Down Expand Up @@ -877,10 +818,6 @@ def test_put_and_get_overwrite(self, custom_transport):
@run_for_all_plugins
def test_copy(self, custom_transport):
"""Test copying."""
import os
import random
import string

local_dir = os.path.join('/', 'tmp')
remote_dir = local_dir
directory = 'tmp_try'
Expand Down Expand Up @@ -952,10 +889,6 @@ def test_put(self, custom_transport):
# pylint: disable=too-many-statements
# exactly the same tests of copy, just with the put function
# and therefore the local path must be absolute
import os
import random
import string

local_dir = os.path.join('/', 'tmp')
remote_dir = local_dir
directory = 'tmp_try'
Expand Down Expand Up @@ -1033,11 +966,6 @@ def test_get(self, custom_transport):
# pylint: disable=too-many-statements
# exactly the same tests of copy, just with the put function
# and therefore the local path must be absolute
import os
import random
import shutil
import string

local_dir = os.path.join('/', 'tmp')
remote_dir = local_dir
directory = 'tmp_try'
Expand Down Expand Up @@ -1119,10 +1047,6 @@ def test_put_get_abs_path(self, custom_transport):
"""
test of exception for non existing files and abs path
"""
import os
import random
import string

local_dir = os.path.join('/', 'tmp')
remote_dir = local_dir
directory = 'tmp_try'
Expand Down Expand Up @@ -1194,10 +1118,6 @@ def test_put_get_empty_string(self, custom_transport):
test of exception put/get of empty strings
"""
# TODO : verify the correctness of \n at the end of a file
import os
import random
import string

local_dir = os.path.join('/', 'tmp')
remote_dir = local_dir
directory = 'tmp_try'
Expand Down Expand Up @@ -1263,9 +1183,6 @@ def test_put_get_empty_string(self, custom_transport):
@run_for_all_plugins
def test_gettree_nested_directory(self, custom_transport): # pylint: disable=no-self-use
"""Test `gettree` for a nested directory."""
import os
import tempfile

with tempfile.TemporaryDirectory() as dir_remote, tempfile.TemporaryDirectory() as dir_local:
content = b'dummy\ncontent'
filepath = os.path.join(dir_remote, 'sub', 'path', 'filename.txt')
Expand Down Expand Up @@ -1294,8 +1211,6 @@ def test_exec_pwd(self, custom_transport):
creation (which should be done by paramiko) and in the command
execution (done in this module, in the _exec_command_internal function).
"""
import os

# Start value
delete_at_end = False

Expand Down Expand Up @@ -1365,3 +1280,84 @@ def test_exec_with_wrong_stdin(self, custom_transport):
with custom_transport as transport:
with self.assertRaises(ValueError):
transport.exec_command_wait('cat', stdin=1)


class TestDirectScheduler(unittest.TestCase):
    """
    Test how the direct scheduler works.

    While this is technically a scheduler test, I put it under the transport tests
    because 1) in reality I am testing the interaction of each transport with the
    direct scheduler; 2) the direct scheduler is always available; 3) I am reusing
    the infrastructure to test on multiple transport plugins.
    """

    @run_for_all_plugins
    def test_asynchronous_execution(self, custom_transport):
        """Test that the execution of a long(ish) command via the direct scheduler does not block.

        This is a regression test for #3094, where running a long job on the direct scheduler
        (via SSH) would lock the interpreter until the job was done.

        :param custom_transport: the transport instance under test, injected by
            the ``run_for_all_plugins`` decorator.
        """
        # Use a unique name, using a UUID, to avoid concurrent tests (or very rapid
        # tests that follow each other) to overwrite the same destination
        script_fname = f'sleep-submit-{uuid.uuid4().hex}-{custom_transport.__class__.__name__}.sh'

        scheduler = SchedulerFactory('direct')()
        scheduler.set_transport(custom_transport)

        # Must be defined before entering the ``try`` block: the ``finally`` clause
        # reads it, and any failure before the job ID is parsed (upload, submission,
        # timing assertion) would otherwise raise an UnboundLocalError during cleanup,
        # hiding the real error and skipping the removal of the script.
        job_id = None

        with custom_transport as transport:
            try:
                with tempfile.NamedTemporaryFile() as tmpf:
                    # Put a submission script that sleeps 10 seconds
                    tmpf.write(b'#!/bin/bash\nsleep 10\n')
                    tmpf.flush()

                    transport.chdir('/tmp')
                    transport.putfile(tmpf.name, script_fname)

                timestamp_before = time.time()
                job_id_string = scheduler.submit_from_script('/tmp', script_fname)

                elapsed_time = time.time() - timestamp_before
                # We want to get back control. If it takes < 5 seconds, it means that it is not blocking
                # as the job is taking at least 10 seconds. I put 5 as the machine could be slow (including the
                # SSH connection etc.) and I don't want to have false failures.
                # Actually, if the time is short, it could mean also that the execution failed!
                # So I double check later that the execution was successful.
                self.assertLess(
                    elapsed_time, 5, 'Getting back control after remote execution took more than 5 seconds! '
                    'Probably submission is blocking'
                )

                # Check that the job is still running
                # Wait 0.2 more seconds, so that I don't do a super-quick check that might return True
                # even if it's not sleeping
                time.sleep(0.2)
                # Check that the job is still running - IMPORTANT, I'm assuming that all transports actually act
                # on the *same* local machine, and that the job_id is actually the process PID.
                # This needs to be adapted if:
                # - a new transport plugin is tested and this does not test the same machine
                # - a new scheduler is used and does not use the process PID, or the job_id of the 'direct' scheduler
                #   is not anymore simply the job PID
                job_id = int(job_id_string)
                self.assertTrue(
                    psutil.pid_exists(job_id), 'The job is not there after a bit more than 1 second! Probably it failed'
                )
            finally:
                # Clean up by killing the remote job.
                # This assumes it's on the same machine; if we add tests on a different machine,
                # we need to call 'kill' via the transport instead.
                # In reality it's not critical to remove it since it will end after 10 seconds of
                # sleeping, but this might avoid warnings (e.g. ResourceWarning)
                if job_id is not None:
                    try:
                        os.kill(job_id, signal.SIGTERM)
                    except ProcessLookupError:
                        # If the process is already dead (or has never run), I just ignore the error
                        pass

                # Also remove the script
                try:
                    transport.remove(f'/tmp/{script_fname}')
                except FileNotFoundError:
                    # If the file wasn't even created, I just ignore this error
                    pass