diff --git a/aiida/transports/plugins/local.py b/aiida/transports/plugins/local.py index 0448187536..1a563e1f21 100644 --- a/aiida/transports/plugins/local.py +++ b/aiida/transports/plugins/local.py @@ -743,10 +743,7 @@ def _exec_command_internal(self, command, **kwargs): # pylint: disable=unused-a # Note: The outer shell will eat one level of escaping, while # 'bash -l -c ...' will eat another. Thus, we need to escape again. - if self._use_login_shell: - bash_commmand = 'bash -l -c ' - else: - bash_commmand = 'bash -c ' + bash_commmand = self._bash_command_str + '-c ' command = bash_commmand + escape_for_bash(command) @@ -807,11 +804,8 @@ def gotocomputer_command(self, remotedir): :param str remotedir: the full path of the remote directory """ - script = ' ; '.join([ - 'if [ -d {escaped_remotedir} ]', 'then cd {escaped_remotedir}', 'bash', "else echo ' ** The directory'", - "echo ' ** {remotedir}'", "echo ' ** seems to have been deleted, I logout...'", 'fi' - ]).format(escaped_remotedir="'{}'".format(remotedir), remotedir=remotedir) - cmd = 'bash -c "{}"'.format(script) + connect_string = self._gotocomputer_string(remotedir) + cmd = 'bash -c {}'.format(connect_string) return cmd def rename(self, oldpath, newpath): diff --git a/aiida/transports/plugins/ssh.py b/aiida/transports/plugins/ssh.py index efa723504c..1482d49b60 100644 --- a/aiida/transports/plugins/ssh.py +++ b/aiida/transports/plugins/ssh.py @@ -310,13 +310,6 @@ def _get_compress_suggestion_string(cls, computer): # pylint: disable=unused-ar """ return 'True' - @classmethod - def _get_use_login_shell_suggestion_string(cls, computer): # pylint: disable=unused-argument - """ - Return a suggestion for the specific field. - """ - return 'True' - @classmethod def _get_load_system_host_keys_suggestion_string(cls, computer): # pylint: disable=unused-argument """ @@ -368,8 +361,6 @@ def __init__(self, *args, **kwargs): Initialize the SshTransport class. :param machine: the machine to connect to - :param use_login_shell: (optional, default True) - if False, do not use a login shell when executing command :param load_system_host_keys: (optional, default False) if False, do not load the system host keys :param key_policy: (optional, default = paramiko.RejectPolicy()) @@ -1241,10 +1232,7 @@ def _exec_command_internal(self, command, combine_stderr=False, bufsize=-1): # # Note: The default shell will eat one level of escaping, while # 'bash -l -c ...' will eat another. Thus, we need to escape again. - if self._use_login_shell: - bash_commmand = 'bash -l -c ' - else: - bash_commmand = 'bash -c ' + bash_commmand = self._bash_command_str + '-c ' channel.exec_command(bash_commmand + escape_for_bash(command_to_execute)) @@ -1312,21 +1300,14 @@ def gotocomputer_command(self, remotedir): further_params.append('-i {}'.format(escape_for_bash(self._connect_args['key_filename']))) further_params_str = ' '.join(further_params) - # I use triple strings because I both have single and double quotes, but I still want everything in - # a single line - connect_string = ( - """ssh -t {machine} {further_params} "if [ -d {escaped_remotedir} ] ;""" - """ then cd {escaped_remotedir} ; bash -l ; else echo ' ** The directory' ; """ - """echo ' ** {remotedir}' ; echo ' ** seems to have been deleted, I logout...' ; fi" """.format( - further_params=further_params_str, - machine=self._machine, - escaped_remotedir="'{}'".format(remotedir), - remotedir=remotedir - ) - ) - # print connect_string - return connect_string + connect_string = self._gotocomputer_string(remotedir) + cmd = 'ssh -t {machine} {further_params} {connect_string}'.format( + further_params=further_params_str, + machine=self._machine, + connect_string=connect_string, + ) + return cmd def symlink(self, remotesource, remotedestination): """ diff --git a/aiida/transports/transport.py b/aiida/transports/transport.py index 2e01dc939d..c51c577a42 100644 --- a/aiida/transports/transport.py +++ b/aiida/transports/transport.py @@ -77,10 +77,20 @@ class Transport(abc.ABC): def __init__(self, *args, **kwargs): # pylint: disable=unused-argument """ __init__ method of the Transport base class. + + :param safe_interval: (optional, default self._DEFAULT_SAFE_OPEN_INTERVAL) + Minimum time interval in seconds between opening new connections. + :param use_login_shell: (optional, default True) + if False, do not use a login shell when executing command """ from aiida.common import AIIDA_LOGGER self._safe_open_interval = kwargs.pop('safe_interval', self._DEFAULT_SAFE_OPEN_INTERVAL) self._use_login_shell = kwargs.pop('use_login_shell', True) + if self._use_login_shell: + self._bash_command_str = 'bash -l ' + else: + self._bash_command_str = 'bash ' + self._logger = AIIDA_LOGGER.getChild('transport').getChild(self.__class__.__name__) self._logger_extra = None self._is_open = False @@ -212,6 +222,13 @@ def _get_safe_interval_suggestion_string(cls, computer): # pylint: disable=unus """ return cls._DEFAULT_SAFE_OPEN_INTERVAL + @classmethod + def _get_use_login_shell_suggestion_string(cls, computer): # pylint: disable=unused-argument + """ + Return a suggestion for the specific field. + """ + return 'True' + @property def logger(self): """ @@ -775,6 +792,18 @@ def glob0(self, dirname, basename): def has_magic(self, string): return self._MAGIC_CHECK.search(string) is not None + def _gotocomputer_string(self, remotedir): + """command executed when goto computer.""" + connect_string = ( + """ "if [ -d {escaped_remotedir} ] ;""" + """ then cd {escaped_remotedir} ; {bash_command} ; else echo ' ** The directory' ; """ + """echo ' ** {remotedir}' ; echo ' ** seems to have been deleted, I logout...' ; fi" """.format( + bash_command=self._bash_command_str, escaped_remotedir="'{}'".format(remotedir), remotedir=remotedir + ) + ) + + return connect_string + class TransportInternalError(InternalError): """ diff --git a/tests/transports/test_local.py b/tests/transports/test_local.py index 35b3e247f1..fa88e00468 100644 --- a/tests/transports/test_local.py +++ b/tests/transports/test_local.py @@ -48,3 +48,16 @@ def test_basic(): """Test constructor.""" with LocalTransport(): pass + + +def test_gotocomputer(): + """Test gotocomputer""" + with LocalTransport() as transport: + cmd_str = transport.gotocomputer_command('/remote_dir/') + + expected_str = ( + """bash -c "if [ -d '/remote_dir/' ] ;""" + """ then cd '/remote_dir/' ; bash -l ; else echo ' ** The directory' ; """ + """echo ' ** /remote_dir/' ; echo ' ** seems to have been deleted, I logout...' ; fi" """ + ) + assert cmd_str == expected_str diff --git a/tests/transports/test_ssh.py b/tests/transports/test_ssh.py index 2b1e083cf3..6c8c713622 100644 --- a/tests/transports/test_ssh.py +++ b/tests/transports/test_ssh.py @@ -55,3 +55,16 @@ def test_no_host_key(self): # Reset logging level logging.disable(logging.NOTSET) + + +def test_gotocomputer(): + """Test gotocomputer""" + with SshTransport(machine='localhost', timeout=30, use_login_shell=False, key_policy='AutoAddPolicy') as transport: + cmd_str = transport.gotocomputer_command('/remote_dir/') + + expected_str = ( + """ssh -t localhost "if [ -d '/remote_dir/' ] ;""" + """ then cd '/remote_dir/' ; bash ; else echo ' ** The directory' ; """ + """echo ' ** /remote_dir/' ; echo ' ** seems to have been deleted, I logout...' ; fi" """ + ) + assert cmd_str == expected_str