Skip to content

Commit

Permalink
Avoid repeat _bash_command and gotocomputer_command
Browse files Browse the repository at this point in the history
`bash_command` are used in three places and set by `use_login_shell`.
Encapsulate it as attribute `_bash_command_str` of `Transport`.

gotocompyter_command have same string for local and ssh transport
plugin, Encapsulate the string as `_gotocomputer_string` in `Transport`
class.
  • Loading branch information
unkcpz committed Jul 20, 2020
1 parent d88fc5e commit dd5afaa
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 36 deletions.
12 changes: 3 additions & 9 deletions aiida/transports/plugins/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,10 +743,7 @@ def _exec_command_internal(self, command, **kwargs): # pylint: disable=unused-a

# Note: The outer shell will eat one level of escaping, while
# 'bash -l -c ...' will eat another. Thus, we need to escape again.
if self._use_login_shell:
bash_commmand = 'bash -l -c '
else:
bash_commmand = 'bash -c '
bash_commmand = self._bash_command_str + '-c '

command = bash_commmand + escape_for_bash(command)

Expand Down Expand Up @@ -807,11 +804,8 @@ def gotocomputer_command(self, remotedir):
:param str remotedir: the full path of the remote directory
"""
script = ' ; '.join([
'if [ -d {escaped_remotedir} ]', 'then cd {escaped_remotedir}', 'bash', "else echo ' ** The directory'",
"echo ' ** {remotedir}'", "echo ' ** seems to have been deleted, I logout...'", 'fi'
]).format(escaped_remotedir="'{}'".format(remotedir), remotedir=remotedir)
cmd = 'bash -c "{}"'.format(script)
connect_string = self._gotocomputer_string(remotedir)
cmd = 'bash -c {}'.format(connect_string)
return cmd

def rename(self, oldpath, newpath):
Expand Down
35 changes: 8 additions & 27 deletions aiida/transports/plugins/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,13 +310,6 @@ def _get_compress_suggestion_string(cls, computer): # pylint: disable=unused-ar
"""
return 'True'

@classmethod
def _get_use_login_shell_suggestion_string(cls, computer): # pylint: disable=unused-argument
"""
Return a suggestion for the specific field.
"""
return 'True'

@classmethod
def _get_load_system_host_keys_suggestion_string(cls, computer): # pylint: disable=unused-argument
"""
Expand Down Expand Up @@ -368,8 +361,6 @@ def __init__(self, *args, **kwargs):
Initialize the SshTransport class.
:param machine: the machine to connect to
:param use_login_shell: (optional, default True)
if False, do not use a login shell when executing command
:param load_system_host_keys: (optional, default False)
if False, do not load the system host keys
:param key_policy: (optional, default = paramiko.RejectPolicy())
Expand Down Expand Up @@ -1241,10 +1232,7 @@ def _exec_command_internal(self, command, combine_stderr=False, bufsize=-1): #

# Note: The default shell will eat one level of escaping, while
# 'bash -l -c ...' will eat another. Thus, we need to escape again.
if self._use_login_shell:
bash_commmand = 'bash -l -c '
else:
bash_commmand = 'bash -c '
bash_commmand = self._bash_command_str + '-c '

channel.exec_command(bash_commmand + escape_for_bash(command_to_execute))

Expand Down Expand Up @@ -1312,21 +1300,14 @@ def gotocomputer_command(self, remotedir):
further_params.append('-i {}'.format(escape_for_bash(self._connect_args['key_filename'])))

further_params_str = ' '.join(further_params)
# I use triple strings because I both have single and double quotes, but I still want everything in
# a single line
connect_string = (
"""ssh -t {machine} {further_params} "if [ -d {escaped_remotedir} ] ;"""
""" then cd {escaped_remotedir} ; bash -l ; else echo ' ** The directory' ; """
"""echo ' ** {remotedir}' ; echo ' ** seems to have been deleted, I logout...' ; fi" """.format(
further_params=further_params_str,
machine=self._machine,
escaped_remotedir="'{}'".format(remotedir),
remotedir=remotedir
)
)

# print connect_string
return connect_string
connect_string = self._gotocomputer_string(remotedir)
cmd = 'ssh -t {machine} {further_params} {connect_string}'.format(
further_params=further_params_str,
machine=self._machine,
connect_string=connect_string,
)
return cmd

def symlink(self, remotesource, remotedestination):
"""
Expand Down
29 changes: 29 additions & 0 deletions aiida/transports/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,20 @@ class Transport(abc.ABC):
def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
"""
__init__ method of the Transport base class.
:param safe_interval: (optional, default self._DEFAULT_SAFE_OPEN_INTERVAL)
Minimum time interval in seconds between opening new connections.
:param use_login_shell: (optional, default True)
if False, do not use a login shell when executing command
"""
from aiida.common import AIIDA_LOGGER
self._safe_open_interval = kwargs.pop('safe_interval', self._DEFAULT_SAFE_OPEN_INTERVAL)
self._use_login_shell = kwargs.pop('use_login_shell', True)
if self._use_login_shell:
self._bash_command_str = 'bash -l '
else:
self._bash_command_str = 'bash '

self._logger = AIIDA_LOGGER.getChild('transport').getChild(self.__class__.__name__)
self._logger_extra = None
self._is_open = False
Expand Down Expand Up @@ -212,6 +222,13 @@ def _get_safe_interval_suggestion_string(cls, computer): # pylint: disable=unus
"""
return cls._DEFAULT_SAFE_OPEN_INTERVAL

@classmethod
def _get_use_login_shell_suggestion_string(cls, computer): # pylint: disable=unused-argument
"""
Return a suggestion for the specific field.
"""
return 'True'

@property
def logger(self):
"""
Expand Down Expand Up @@ -775,6 +792,18 @@ def glob0(self, dirname, basename):
def has_magic(self, string):
return self._MAGIC_CHECK.search(string) is not None

def _gotocomputer_string(self, remotedir):
"""command executed when goto computer."""
connect_string = (
""" "if [ -d {escaped_remotedir} ] ;"""
""" then cd {escaped_remotedir} ; {bash_command} ; else echo ' ** The directory' ; """
"""echo ' ** {remotedir}' ; echo ' ** seems to have been deleted, I logout...' ; fi" """.format(
bash_command=self._bash_command_str, escaped_remotedir="'{}'".format(remotedir), remotedir=remotedir
)
)

return connect_string


class TransportInternalError(InternalError):
"""
Expand Down
13 changes: 13 additions & 0 deletions tests/transports/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,16 @@ def test_basic():
"""Test constructor."""
with LocalTransport():
pass


def test_gotocomputer():
"""Test gotocomputer"""
with LocalTransport() as transport:
cmd_str = transport.gotocomputer_command('/remote_dir/')

expected_str = (
"""bash -c "if [ -d '/remote_dir/' ] ;"""
""" then cd '/remote_dir/' ; bash -l ; else echo ' ** The directory' ; """
"""echo ' ** /remote_dir/' ; echo ' ** seems to have been deleted, I logout...' ; fi" """
)
assert cmd_str == expected_str
13 changes: 13 additions & 0 deletions tests/transports/test_ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,16 @@ def test_no_host_key(self):

# Reset logging level
logging.disable(logging.NOTSET)


def test_gotocomputer():
"""Test gotocomputer"""
with SshTransport(machine='localhost', timeout=30, use_login_shell=False, key_policy='AutoAddPolicy') as transport:
cmd_str = transport.gotocomputer_command('/remote_dir/')

expected_str = (
"""ssh -t localhost "if [ -d '/remote_dir/' ] ;"""
""" then cd '/remote_dir/' ; bash ; else echo ' ** The directory' ; """
"""echo ' ** /remote_dir/' ; echo ' ** seems to have been deleted, I logout...' ; fi" """
)
assert cmd_str == expected_str

0 comments on commit dd5afaa

Please sign in to comment.