Skip to content

Commit

Permalink
WIP: Migrate Transport to use pydantic for configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
sphuber committed May 30, 2024
1 parent 06f8f4c commit aa42fd7
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 61 deletions.
110 changes: 55 additions & 55 deletions src/aiida/cmdline/commands/cmd_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@

import click

from aiida.cmdline.commands.cmd_verdi import VerdiCommandGroup, verdi
from aiida.cmdline.commands.cmd_verdi import verdi
from aiida.cmdline.groups.dynamic import DynamicEntryPointCommandGroup
from aiida.cmdline.params import arguments, options
from aiida.cmdline.params.options.commands import computer as options_computer
from aiida.cmdline.utils import echo, echo_tabulate
from aiida.cmdline.utils.decorators import with_dbenv
from aiida.common.exceptions import EntryPointError, ValidationError
from aiida.plugins.entry_point import get_entry_point_names
from aiida.common.exceptions import ValidationError


@verdi.group('computer')
Expand Down Expand Up @@ -655,25 +655,40 @@ def _dry_run_callback(pks):
echo.echo_success(f'Computer `{label}` {"and all its associated nodes" if node_pks else ""} deleted.')


class LazyConfigureGroup(VerdiCommandGroup):
"""A click group that will lazily load the subcommands for each transport plugin."""
def configure_computer(ctx: click.Context, cls, non_interactive: bool, **kwargs): # pylint: disable=unused-argument
"""Configure a `Computer` instance."""
from aiida import orm

def list_commands(self, ctx):
subcommands = super().list_commands(ctx)
subcommands.extend(get_entry_point_names('aiida.transports'))
return subcommands
user = kwargs.pop('user', None) or orm.User.collection.get_default()
computer = kwargs.pop('computer')

def get_command(self, ctx, name):
from aiida.transports import cli as transport_cli
echo.echo_report(f'Configuring computer {computer.label} for user {user.email}.')
if not user.is_default:
echo.echo_report('Configuring different user, defaults may not be appropriate.')

try:
command = transport_cli.create_configure_cmd(name)
except EntryPointError:
command = super().get_command(ctx, name)
return command
computer.configure(user=user, **kwargs)
echo.echo_success(f'{computer.label} successfully configured for {user.email}')


@verdi_computer.group('configure', cls=LazyConfigureGroup)
def validate_transport(ctx, _, computer):
"""Validate that the transport of the computer matches that of the command."""
if computer.transport_type != ctx.command.name:
echo.echo_critical(
f'Transport of computer {computer.label} is `{computer.transport_type}` and not `{ctx.command.name}`.'
)
return computer


@verdi_computer.group(
'configure',
cls=DynamicEntryPointCommandGroup,
command=configure_computer,
entry_point_group='aiida.transports',
shared_options=[
options.USER(),
arguments.COMPUTER(callback=validate_transport),
],
)
def computer_configure():
"""Configure the transport for a computer and user."""

Expand All @@ -690,48 +705,33 @@ def computer_configure():
def computer_config_show(computer, user, defaults, as_option_string):
"""Show the current configuration for a computer."""
from aiida.common.escaping import escape_for_bash
from aiida.transports import cli as transport_cli

transport_cls = computer.get_transport_class()
option_list = [
param
for param in transport_cli.create_configure_cmd(computer.transport_type).params
if isinstance(param, click.core.Option)
]
option_list = [option for option in option_list if option.name in transport_cls.get_valid_auth_params()]
configuration = computer.get_configuration(user)
model = transport_cls.Model(**configuration)

if defaults:
config = {option.name: transport_cli.transport_option_default(option.name, computer) for option in option_list}
else:
config = computer.get_configuration(user)

option_items = []
if as_option_string:
for option in option_list:
t_opt = transport_cls.auth_options[option.name]
if config.get(option.name) or config.get(option.name) is False:
if t_opt.get('switch'):
option_value = (
option.opts[-1] if config.get(option.name) else f"--no-{option.name.replace('_', '-')}"
)
elif t_opt.get('is_flag'):
is_default = config.get(option.name) == transport_cli.transport_option_default(
option.name, computer
)
option_value = option.opts[-1] if is_default else ''
else:
option_value = f'{option.opts[-1]}={option.type(config[option.name])}'
option_items.append(option_value)
opt_string = ' '.join(option_items)
echo.echo(escape_for_bash(opt_string))
else:
table = []
for name in transport_cls.get_valid_auth_params():
if name in config:
table.append((f'* {name}', config[name]))
if not as_option_string:
echo_tabulate(list(model.model_dump().items()), tablefmt='plain')
return

option_list = []

for key, value in model.model_dump().items():
if value is None or value == '':
continue

if model.model_fields[key].annotation is bool:
if value:
option_list.append(f'--{key.replace("_", "-")}')
else:
table.append((f'* {name}', '-'))
echo_tabulate(table, tablefmt='plain')
option_list.append(f'--no-{key.replace("_", "-")}')
else:
try:
option_list.append(f'--{key.replace("_", "-")}={value.value}')
except AttributeError:
option_list.append(f'--{key.replace("_", "-")}={value}')

echo.echo(escape_for_bash(' '.join(option_list)))


@verdi_computer.group('export')
Expand Down
8 changes: 6 additions & 2 deletions src/aiida/cmdline/groups/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import enum
import functools
import re
import typing as t
Expand Down Expand Up @@ -196,6 +197,9 @@ def list_options(self, entry_point: str) -> list:
for metadata_key, metadata_value in metadata.items():
options_spec[key][metadata_key] = metadata_value

if issubclass(field_type, enum.Enum):
options_spec[key]['type'] = click.Choice([e.value for e in field_type])

options_ordered = []

for name, spec in sorted(options_spec.items(), key=lambda x: x[1].get('priority', 0), reverse=True):
Expand All @@ -207,13 +211,13 @@ def list_options(self, entry_point: str) -> list:
@staticmethod
def create_option(name, spec: dict) -> t.Callable[[t.Any], t.Any]:
"""Create a click option from a name and a specification."""
is_flag = spec.pop('is_flag', False)
is_flag = spec.get('is_flag', False)
name_dashed = name.replace('_', '-')
option_name = f'--{name_dashed}/--no-{name_dashed}' if is_flag else f'--{name_dashed}'
option_short_name = spec.pop('short_name', None)
option_names = (option_short_name, option_name) if option_short_name else (option_name,)

kwargs = {'cls': spec.pop('option_cls', InteractiveOption), 'show_default': True, 'is_flag': is_flag, **spec}
kwargs = {'cls': spec.pop('option_cls', InteractiveOption), 'show_default': True, **spec}

# If the option is a flag with no default, make sure it is not prompted for, as that will force the user to
# specify it to be on or off, but cannot let it unspecified.
Expand Down
82 changes: 81 additions & 1 deletion src/aiida/transports/plugins/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,35 @@
# For further information please visit http://www.aiida.net #
###########################################################################
"""Plugin for transport over SSH (and SFTP for file transfer)."""
from __future__ import annotations

import enum
import getpass
import glob
import io
import os
import re
from stat import S_ISDIR, S_ISREG

import click
from pydantic import ConfigDict

from aiida.cmdline.params import options
from aiida.cmdline.params.types.path import AbsolutePathOrEmptyParamType
from aiida.common.escaping import escape_for_bash
from aiida.common.pydantic import MetadataField

from ..transport import Transport, TransportInternalError

__all__ = ('parse_sshconfig', 'convert_to_bool', 'SshTransport')


class KeyPolicy(str, enum.Enum):
REJECT = 'RejectPolicy'
WARNING = 'WarningPolicy'
AUTOADD = 'AutoAddPolicy'


def parse_sshconfig(computername):
"""Return the ssh configuration for a given computer name.
Expand Down Expand Up @@ -64,6 +75,75 @@ def convert_to_bool(string):
class SshTransport(Transport):
"""Support connection, command execution and data transfer to remote computers via SSH+SFTP."""

class Model(Transport.Model):
"""Model describing required information to create an instance."""

model_config = ConfigDict(use_enum_values=True)

safe_interval: float = MetadataField(
30.0,
title='Connection cooldown time (s)',
description='Minimum time interval in seconds between opening new connections.',
)
username: str = MetadataField(
getpass.getuser(), title='User name', description='Login user name on the remote machine.'
)
port: int = MetadataField(22, title='Port number', description='Port number.', short_name='-P')
look_for_keys: bool = MetadataField(
True, title='Look for keys', description='Automatically look for private keys in the ~/.ssh folder.'
)
key_filename: str = MetadataField(
'',
title='SSH key file',
description='Absolute path to your private SSH key. Leave empty to use the path set in the SSH config.',
)
timeout: int = MetadataField(
60,
title='Connection timeout (s)',
description='Time in seconds to wait for connection before giving up. Leave empty to use default value.',
)
allow_agent: bool = MetadataField(
False, title='Allow ssh agent', description='Switch to allow or disallow using an SSH agent.'
)
proxy_jump: str = MetadataField(
'',
title='SSH proxy jump',
description='SSH proxy jump for tunneling through other SSH hosts. Use a comma-separated list of hosts of '
'the form [user@]host[:port]. If user or port are not specified for a host, the user & port values from '
'the target host are used. This option must be provided explicitly and is not parsed from the SSH config '
'file when left empty.',
)
proxy_command: str = MetadataField(
'',
title='SSH proxy command',
description='SSH proxy command for tunneling through a proxy server. For tunneling through another SSH '
'host, consider using the "SSH proxy jump" option instead! Leave empty to parse the proxy command from the '
'SSH config file.',
)
compress: bool = MetadataField(
True, title='Compress file transfers', description='Turn file transfer compression on or off.'
)
gss_auth: bool = MetadataField(
False, title='GSS auth', description='Enable when using GSS kerberos token to connect.'
)
gss_kex: bool = MetadataField(
False, title='GSS kex', description='GSS kex for kerberos, if not configured in SSH config file.'
)
gss_deleg_creds: bool = MetadataField(
False,
title='GSS deleg_creds',
description='GSS deleg_creds for kerberos, if not configured in SSH config file.',
)
gss_host: str = MetadataField(
'', title='GSS host', description='GSS host for kerberos, if not configured in SSH config file.'
)
load_system_host_keys: bool = MetadataField(
True, title='Load system host keys', description='Load system host keys from default SSH location.'
)
key_policy: KeyPolicy = MetadataField(
KeyPolicy.REJECT, title='Key policy', description='SSH key policy if host is not known.'
)

# Valid keywords accepted by the connect method of paramiko.SSHClient
# I disable 'password' and 'pkey' to avoid these data to get logged in the
# aiida log file.
Expand Down Expand Up @@ -235,8 +315,8 @@ def _get_username_suggestion_string(cls, computer):
"""Return a suggestion for the specific field."""
import getpass

config = parse_sshconfig(computer.hostname)
# Either the configured user in the .ssh/config, or the current username
config = parse_sshconfig(computer.hostname)
return str(config.get('user', getpass.getuser()))

@classmethod
Expand Down
18 changes: 18 additions & 0 deletions src/aiida/transports/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
import sys
from collections import OrderedDict

from pydantic import BaseModel

from aiida.common.exceptions import InternalError
from aiida.common.lang import classproperty
from aiida.common.pydantic import MetadataField

__all__ = ('Transport',)

Expand Down Expand Up @@ -76,6 +79,21 @@ class Transport(abc.ABC):
),
]

class Model(BaseModel):
"""Model describing required information to create an instance."""

use_login_shell: bool = MetadataField(
True,
title='Use login shell when executing command',
description='Not using a login shell can help suppress potential spurious text output that can prevent '
'AiiDA from parsing the output of commands, but may result in some startup files not being sourced.',
)
safe_interval: float = MetadataField(
0.0,
title='Connection cooldown time (s)',
description='Minimum time interval in seconds between opening new connections.',
)

def __init__(self, *args, **kwargs):
"""__init__ method of the Transport base class.
Expand Down
6 changes: 3 additions & 3 deletions tests/cmdline/commands/test_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def test_ssh_interactive(self):
# I just pass the first four arguments:
# the username, the port, look_for_keys, and the key_filename
# This testing also checks that an empty key_filename is ok
command_input = f"{remote_username}\n{port}\n{'yes' if look_for_keys else 'no'}\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n"
command_input = f"\n\n{remote_username}\n{port}\n{'yes' if look_for_keys else 'no'}\n\n\n\n\n\n\n\n\n\n\n\n\n\n"

result = self.cli_runner(computer_configure, ['core.ssh', comp.label], user_input=command_input)
assert comp.is_configured, result.output
Expand Down Expand Up @@ -497,15 +497,15 @@ def test_show(self):
result = self.cli_runner(computer_configure, ['show', comp.label])

result = self.cli_runner(computer_configure, ['show', comp.label, '--defaults'])
assert '* username' in result.output
assert 'username' in result.output

result = self.cli_runner(
computer_configure, ['show', comp.label, '--defaults', '--as-option-string'], suppress_warnings=True
)
assert '--username=' in result.output

config_cmd = ['core.ssh', comp.label, '--non-interactive']
config_cmd.extend(result.output.replace("'", '').split(' '))
config_cmd.extend(result.output.strip().replace("'", '').split(' '))
result_config = self.cli_runner(computer_configure, config_cmd, suppress_warnings=True)
assert comp.is_configured, result_config.output

Expand Down

0 comments on commit aa42fd7

Please sign in to comment.