diff --git a/aiida/transports/plugins/ssh.py b/aiida/transports/plugins/ssh.py index 9f73bdbc14..5f1a765608 100644 --- a/aiida/transports/plugins/ssh.py +++ b/aiida/transports/plugins/ssh.py @@ -1265,7 +1265,7 @@ def _exec_command_internal(self, command, combine_stderr=False, bufsize=-1): # if self.getcwd() is not None: escaped_folder = escape_for_bash(self.getcwd()) - command_to_execute = (f'cd {escaped_folder} && {command}') + command_to_execute = (f'cd {escaped_folder} && ( {command} )') else: command_to_execute = command diff --git a/tests/transports/test_all_plugins.py b/tests/transports/test_all_plugins.py index 2920126cd2..7470dd15c9 100644 --- a/tests/transports/test_all_plugins.py +++ b/tests/transports/test_all_plugins.py @@ -15,7 +15,19 @@ Plugin specific tests will be written in the plugin itself. """ import io +import os +import random +import tempfile +import signal +import shutil +import string +import time import unittest +import uuid + +import psutil + +from aiida.plugins import SchedulerFactory # TODO : test for copy with pattern # TODO : test for copy with/without patterns, overwriting folder @@ -35,7 +47,6 @@ def get_all_custom_transports(): it was found) """ import importlib - import os modulename = __name__.rpartition('.')[0] this_full_fname = __file__ @@ -133,11 +144,6 @@ def test_makedirs(self, custom_transport): """ Verify the functioning of makedirs command """ - # Imports required later - import random - import string - import os - with custom_transport as transport: location = transport.normalize(os.path.join('/', 'tmp')) directory = 'temp_dir_test' @@ -176,11 +182,6 @@ def test_rmtree(self, custom_transport): """ Verify the functioning of rmtree command """ - # Imports required later - import random - import string - import os - with custom_transport as transport: location = transport.normalize(os.path.join('/', 'tmp')) directory = 'temp_dir_test' @@ -221,12 +222,6 @@ def test_listdir(self, custom_transport): """ create directories, verify listdir, delete a folder with subfolders """ - # Imports required later - import tempfile - import random - import string - import os - with custom_transport as trans: # We cannot use tempfile.mkdtemp because we're on a remote folder location = trans.normalize(os.path.join('/', 'tmp')) @@ -270,11 +265,6 @@ def test_listdir_withattributes(self, custom_transport): """ create directories, verify listdir_withattributes, delete a folder with subfolders """ - # Imports required later - import tempfile - import random - import string - import os def simplify_attributes(data): """ @@ -340,11 +330,6 @@ def simplify_attributes(data): @run_for_all_plugins def test_dir_creation_deletion(self, custom_transport): """Test creating and deleting directories.""" - # Imports required later - import random - import string - import os - with custom_transport as transport: location = transport.normalize(os.path.join('/', 'tmp')) directory = 'temp_dir_test' @@ -370,11 +355,6 @@ def test_dir_copy(self, custom_transport): Verify if in the copy of a directory also the protection bits are carried over """ - # Imports required later - import random - import string - import os - with custom_transport as transport: location = transport.normalize(os.path.join('/', 'tmp')) directory = 'temp_dir_test' @@ -403,11 +383,6 @@ def test_dir_permissions_creation_modification(self, custom_transport): # pylin verify if chmod raises IOError when trying to change bits on a non-existing folder """ - # Imports required later - import random - import string - import os - with custom_transport as transport: location = transport.normalize(os.path.join('/', 'tmp')) directory = 'temp_dir_test' @@ -460,11 +435,6 @@ def test_dir_reading_permissions(self, custom_transport): Try to enter a directory with no read permissions. Verify that the cwd has not changed after failed try. """ - # Imports required later - import random - import string - import os - with custom_transport as transport: location = transport.normalize(os.path.join('/', 'tmp')) directory = 'temp_dir_test' @@ -503,8 +473,6 @@ def test_isfile_isdir_to_empty_string(self, custom_transport): I check that isdir or isfile return False when executed on an empty string """ - import os - with custom_transport as transport: location = transport.normalize(os.path.join('/', 'tmp')) transport.chdir(location) @@ -517,8 +485,6 @@ def test_isfile_isdir_to_non_existing_string(self, custom_transport): I check that isdir or isfile return False when executed on an empty string """ - import os - with custom_transport as transport: location = transport.normalize(os.path.join('/', 'tmp')) transport.chdir(location) @@ -535,8 +501,6 @@ def test_chdir_to_empty_string(self, custom_transport): not change (this is a paramiko default behavior), but getcwd() is still correctly defined. """ - import os - with custom_transport as transport: new_dir = transport.normalize(os.path.join('/', 'tmp')) transport.chdir(new_dir) @@ -555,10 +519,6 @@ class TestPutGetFile(unittest.TestCase): @run_for_all_plugins def test_put_and_get(self, custom_transport): """Test putting and getting files.""" - import os - import random - import string - local_dir = os.path.join('/', 'tmp') remote_dir = local_dir directory = 'tmp_try' @@ -605,10 +565,6 @@ def test_put_get_abs_path(self, custom_transport): """ test of exception for non existing files and abs path """ - import os - import random - import string - local_dir = os.path.join('/', 'tmp') remote_dir = local_dir directory = 'tmp_try' @@ -669,10 +625,6 @@ def test_put_get_empty_string(self, custom_transport): test of exception put/get of empty strings """ # TODO : verify the correctness of \n at the end of a file - import os - import random - import string - local_dir = os.path.join('/', 'tmp') remote_dir = local_dir directory = 'tmp_try' @@ -752,10 +704,6 @@ class TestPutGetTree(unittest.TestCase): @run_for_all_plugins def test_put_and_get(self, custom_transport): """Test putting and getting files.""" - import os - import random - import string - local_dir = os.path.join('/', 'tmp') remote_dir = local_dir directory = 'tmp_try' @@ -807,8 +755,6 @@ def test_put_and_get(self, custom_transport): self.assertTrue('file.txt' in list_pushed_file) self.assertTrue('file.txt' in list_retrieved_file) - import shutil - shutil.rmtree(local_subfolder) shutil.rmtree(retrieved_subfolder) transport.rmtree(remote_subfolder) @@ -819,11 +765,6 @@ def test_put_and_get(self, custom_transport): @run_for_all_plugins def test_put_and_get_overwrite(self, custom_transport): """Test putting and getting files with overwrites.""" - import os - import random - import shutil - import string - local_dir = os.path.join('/', 'tmp') remote_dir = local_dir directory = 'tmp_try' @@ -877,10 +818,6 @@ def test_put_and_get_overwrite(self, custom_transport): @run_for_all_plugins def test_copy(self, custom_transport): """Test copying.""" - import os - import random - import string - local_dir = os.path.join('/', 'tmp') remote_dir = local_dir directory = 'tmp_try' @@ -952,10 +889,6 @@ def test_put(self, custom_transport): # pylint: disable=too-many-statements # exactly the same tests of copy, just with the put function # and therefore the local path must be absolute - import os - import random - import string - local_dir = os.path.join('/', 'tmp') remote_dir = local_dir directory = 'tmp_try' @@ -1033,11 +966,6 @@ def test_get(self, custom_transport): # pylint: disable=too-many-statements # exactly the same tests of copy, just with the put function # and therefore the local path must be absolute - import os - import random - import shutil - import string - local_dir = os.path.join('/', 'tmp') remote_dir = local_dir directory = 'tmp_try' @@ -1119,10 +1047,6 @@ def test_put_get_abs_path(self, custom_transport): """ test of exception for non existing files and abs path """ - import os - import random - import string - local_dir = os.path.join('/', 'tmp') remote_dir = local_dir directory = 'tmp_try' @@ -1194,10 +1118,6 @@ def test_put_get_empty_string(self, custom_transport): test of exception put/get of empty strings """ # TODO : verify the correctness of \n at the end of a file - import os - import random - import string - local_dir = os.path.join('/', 'tmp') remote_dir = local_dir directory = 'tmp_try' @@ -1263,9 +1183,6 @@ def test_put_get_empty_string(self, custom_transport): @run_for_all_plugins def test_gettree_nested_directory(self, custom_transport): # pylint: disable=no-self-use """Test `gettree` for a nested directory.""" - import os - import tempfile - with tempfile.TemporaryDirectory() as dir_remote, tempfile.TemporaryDirectory() as dir_local: content = b'dummy\ncontent' filepath = os.path.join(dir_remote, 'sub', 'path', 'filename.txt') @@ -1294,8 +1211,6 @@ def test_exec_pwd(self, custom_transport): creation (which should be done by paramiko) and in the command execution (done in this module, in the _exec_command_internal function). """ - import os - # Start value delete_at_end = False @@ -1365,3 +1280,84 @@ def test_exec_with_wrong_stdin(self, custom_transport): with custom_transport as transport: with self.assertRaises(ValueError): transport.exec_command_wait('cat', stdin=1) + + +class TestDirectScheduler(unittest.TestCase): + """ + Test how the direct scheduler works. + + While this is technically a scheduler test, I put it under the transport tests + because 1) in reality I am testing the interaction of each transport with the + direct scheduler; 2) the direct scheduler is always available; 3) I am reusing + the infrastructure to test on multiple transport plugins. + """ + + @run_for_all_plugins + def test_asynchronous_execution(self, custom_transport): + """Test that the execution of a long(ish) command via the direct scheduler does not block. + + This is a regression test for #3094, where running a long job on the direct scheduler + (via SSH) would lock the interpreter until the job was done. + """ + # Use a unique name, using a UUID, to avoid concurrent tests (or very rapid + # tests that follow each other) to overwrite the same destination + script_fname = f'sleep-submit-{uuid.uuid4().hex}-{custom_transport.__class__.__name__}.sh' + + scheduler = SchedulerFactory('direct')() + scheduler.set_transport(custom_transport) + with custom_transport as transport: + try: + with tempfile.NamedTemporaryFile() as tmpf: + # Put a submission script that sleeps 10 seconds + tmpf.write(b'#!/bin/bash\nsleep 10\n') + tmpf.flush() + + transport.chdir('/tmp') + transport.putfile(tmpf.name, script_fname) + + timestamp_before = time.time() + job_id_string = scheduler.submit_from_script('/tmp', script_fname) + + elapsed_time = time.time() - timestamp_before + # We want to get back control. If it takes < 5 seconds, it means that it is not blocking + # as the job is taking at least 10 seconds. I put 5 as the machine could be slow (including the + # SSH connection etc.) and I don't want to have false failures. + # Actually, if the time is short, it could mean also that the execution failed! + # So I double check later that the execution was successful. + self.assertTrue( + elapsed_time < 5, 'Getting back control after remote execution took more than 5 seconds! ' + 'Probably submission is blocking' + ) + + # Check that the job is still running + # Wait 0.2 more seconds, so that I don't do a super-quick check that might return True + # even if it's not sleeping + time.sleep(0.2) + # Check that the job is still running - IMPORTANT, I'm assuming that all transports actually act + # on the *same* local machine, and that the job_id is actually the process PID. + # This needs to be adapted if: + # - a new transport plugin is tested and this does not test the same machine + # - a new scheduler is used and does not use the process PID, or the job_id of the 'direct' scheduler + # is not anymore simply the job PID + job_id = int(job_id_string) + self.assertTrue( + psutil.pid_exists(job_id), 'The job is not there after a bit more than 1 second! Probably it failed' + ) + finally: + # Clean up by killing the remote job. + # This assumes it's on the same machine; if we add tests on a different machine, + # we need to call 'kill' via the transport instead. + # In reality it's not critical to remove it since it will end after 10 seconds of + # sleeping, but this might avoid warnings (e.g. ResourceWarning) + try: + os.kill(job_id, signal.SIGTERM) + except ProcessLookupError: + # If the process is already dead (or has never run), I just ignore the error + pass + + # Also remove the script + try: + transport.remove(f'/tmp/{script_fname}') + except FileNotFoundError: + # If the file wasn't even created, I just ignore this error + pass