Skip to content

Commit

Permalink
Add auto-complete support for CodeParamType and GroupParamType
Browse files Browse the repository at this point in the history
This will enable auto-completion for existing `Code` and `Group`
instances by their label for commands that have options or arguments
with the corresponding parameter type.
  • Loading branch information
sphuber committed Apr 15, 2020
1 parent 5479bbc commit 789fdee
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 104 deletions.
1 change: 0 additions & 1 deletion aiida/cmdline/params/types/choice.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def _click_choice(self):
"""
if self.__click_choice is None:
self.__click_choice = click.Choice(self._get_choices())
# self._get_choices = None
return self.__click_choice

@property
Expand Down
11 changes: 10 additions & 1 deletion aiida/cmdline/params/types/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
# For further information please visit http://www.aiida.net #
###########################################################################
"""Module to define the custom click type for code."""

import click

from aiida.cmdline.utils import decorators
from .identifier import IdentifierParamType


Expand Down Expand Up @@ -40,6 +41,14 @@ def orm_class_loader(self):
from aiida.orm.utils.loaders import CodeEntityLoader
return CodeEntityLoader

@decorators.with_dbenv()
def complete(self, ctx, incomplete): # pylint: disable=unused-argument
"""Return possible completions based on an incomplete value.
:returns: list of tuples of valid entry points (matching incomplete) and a description
"""
return [(option, '') for option, in self.orm_class_loader.get_options(incomplete, project='label')]

def convert(self, value, param, ctx):
code = super().convert(value, param, ctx)

Expand Down
12 changes: 10 additions & 2 deletions aiida/cmdline/params/types/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import click

from aiida.common.lang import type_check
from aiida.cmdline.utils.decorators import with_dbenv
from aiida.cmdline.utils import decorators

from .identifier import IdentifierParamType

Expand Down Expand Up @@ -56,7 +56,15 @@ def orm_class_loader(self):
from aiida.orm.utils.loaders import GroupEntityLoader
return GroupEntityLoader

@with_dbenv()
@decorators.with_dbenv()
def complete(self, ctx, incomplete): # pylint: disable=unused-argument
"""Return possible completions based on an incomplete value.
:returns: list of tuples of valid entry points (matching incomplete) and a description
"""
return [(option, '') for option, in self.orm_class_loader.get_options(incomplete, project='label')]

@decorators.with_dbenv()
def convert(self, value, param, ctx):
try:
group = super().convert(value, param, ctx)
Expand Down
8 changes: 6 additions & 2 deletions aiida/orm/utils/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,15 +456,19 @@ def _get_query_builder_label_identifier(cls, identifier, classes, operator='==',
:raises ValueError: if the identifier is invalid
:raises aiida.common.NotExistent: if the orm base class does not support a LABEL like identifier
"""
from aiida.common.escaping import escape_for_sql_like
from aiida.orm import Computer

try:
label, _, machinename = identifier.partition('@')
identifier, _, machinename = identifier.partition('@')
except AttributeError:
raise ValueError('the identifier needs to be a string')

if operator == 'like':
identifier = escape_for_sql_like(identifier) + '%'

builder = QueryBuilder()
builder.append(cls=classes, tag='code', project=project, filters={'label': {'==': label}})
builder.append(cls=classes, tag='code', project=project, filters={'label': {operator: identifier}})

if machinename:
builder.append(Computer, filters={'name': {'==': machinename}}, with_node='code')
Expand Down
212 changes: 114 additions & 98 deletions tests/cmdline/params/types/test_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,109 +7,125 @@
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
# pylint: disable=redefined-outer-name,unused-variable,unused-argument
"""Tests for the `CodeParamType`."""

import click
import pytest

from aiida.backends.testbase import AiidaTestCase
from aiida.cmdline.params.types import CodeParamType
from aiida.orm import Code
from aiida.orm.utils.loaders import OrmEntityLoader


class TestCodeParamType(AiidaTestCase):
"""Tests for the `CodeParamType`."""

@classmethod
def setUpClass(cls, *args, **kwargs):
"""
Create some code to test the CodeParamType parameter type for the command line infrastructure
We create an initial code with a random name and then on purpose create two code with a name
that matches exactly the ID and UUID, respectively, of the first one. This allows us to test
the rules implemented to solve ambiguities that arise when determing the identifier type
"""
super().setUpClass(*args, **kwargs)

cls.param_base = CodeParamType()
cls.param_entry_point = CodeParamType(entry_point='arithmetic.add')
cls.entity_01 = Code(remote_computer_exec=(cls.computer, '/bin/true')).store()
cls.entity_02 = Code(remote_computer_exec=(cls.computer, '/bin/true'),
input_plugin_name='arithmetic.add').store()
cls.entity_03 = Code(remote_computer_exec=(cls.computer, '/bin/true'),
input_plugin_name='templatereplacer').store()

cls.entity_01.label = 'computer_01'
cls.entity_02.label = str(cls.entity_01.pk)
cls.entity_03.label = str(cls.entity_01.uuid)

def test_get_by_id(self):
"""
Verify that using the ID will retrieve the correct entity
"""
identifier = '{}'.format(self.entity_01.pk)
result = self.param_base.convert(identifier, None, None)
self.assertEqual(result.uuid, self.entity_01.uuid)

def test_get_by_uuid(self):
"""
Verify that using the UUID will retrieve the correct entity
"""
identifier = '{}'.format(self.entity_01.uuid)
result = self.param_base.convert(identifier, None, None)
self.assertEqual(result.uuid, self.entity_01.uuid)

def test_get_by_label(self):
"""
Verify that using the LABEL will retrieve the correct entity
"""
identifier = '{}'.format(self.entity_01.label)
result = self.param_base.convert(identifier, None, None)
self.assertEqual(result.uuid, self.entity_01.uuid)

def test_get_by_fullname(self):
"""
Verify that using the LABEL@machinename will retrieve the correct entity
"""
identifier = '{}@{}'.format(self.entity_01.label, self.computer.name) # pylint: disable=no-member
result = self.param_base.convert(identifier, None, None)
self.assertEqual(result.uuid, self.entity_01.uuid)

def test_ambiguous_label_pk(self):
"""
Situation: LABEL of entity_02 is exactly equal to ID of entity_01
Verify that using an ambiguous identifier gives precedence to the ID interpretation
Appending the special ambiguity breaker character will force the identifier to be treated as a LABEL
"""
identifier = '{}'.format(self.entity_02.label)
result = self.param_base.convert(identifier, None, None)
self.assertEqual(result.uuid, self.entity_01.uuid)

identifier = '{}{}'.format(self.entity_02.label, OrmEntityLoader.label_ambiguity_breaker)
result = self.param_base.convert(identifier, None, None)
self.assertEqual(result.uuid, self.entity_02.uuid)

def test_ambiguous_label_uuid(self):
"""
Situation: LABEL of entity_03 is exactly equal to UUID of entity_01
Verify that using an ambiguous identifier gives precedence to the UUID interpretation
Appending the special ambiguity breaker character will force the identifier to be treated as a LABEL
"""
identifier = '{}'.format(self.entity_03.label)
result = self.param_base.convert(identifier, None, None)
self.assertEqual(result.uuid, self.entity_01.uuid)

identifier = '{}{}'.format(self.entity_03.label, OrmEntityLoader.label_ambiguity_breaker)
result = self.param_base.convert(identifier, None, None)
self.assertEqual(result.uuid, self.entity_03.uuid)

def test_entry_point_validation(self):
"""Verify that when an `entry_point` is defined in the constructor, it is respected in the validation."""
identifier = '{}'.format(self.entity_02.pk)
result = self.param_entry_point.convert(identifier, None, None)
self.assertEqual(result.uuid, self.entity_02.uuid)

with self.assertRaises(click.BadParameter):
identifier = '{}'.format(self.entity_03.pk)
result = self.param_entry_point.convert(identifier, None, None)
@pytest.fixture
def parameter_type():
"""Return an instance of the `CodeParamType`."""
return CodeParamType()


@pytest.fixture
def setup_codes(clear_database_before_test, aiida_localhost):
"""Create some `Code` instances to test the `CodeParamType` parameter type for the command line infrastructure.
We create an initial code with a random name and then on purpose create two code with a name that matches exactly
the ID and UUID, respectively, of the first one. This allows us to test the rules implemented to solve ambiguities
that arise when determing the identifier type.
"""
entity_01 = Code(remote_computer_exec=(aiida_localhost, '/bin/true')).store()
entity_02 = Code(remote_computer_exec=(aiida_localhost, '/bin/true'), input_plugin_name='arithmetic.add').store()
entity_03 = Code(remote_computer_exec=(aiida_localhost, '/bin/true'), input_plugin_name='templatereplacer').store()

entity_01.label = 'computer_01'
entity_02.label = str(entity_01.pk)
entity_03.label = str(entity_01.uuid)

return entity_01, entity_02, entity_03


def test_get_by_id(setup_codes, parameter_type):
"""Verify that using the ID will retrieve the correct entity."""
entity_01, entity_02, entity_03 = setup_codes
identifier = '{}'.format(entity_01.pk)
result = parameter_type.convert(identifier, None, None)
assert result.uuid == entity_01.uuid


def test_get_by_uuid(setup_codes, parameter_type):
"""Verify that using the UUID will retrieve the correct entity."""
entity_01, entity_02, entity_03 = setup_codes
identifier = '{}'.format(entity_01.uuid)
result = parameter_type.convert(identifier, None, None)
assert result.uuid == entity_01.uuid


def test_get_by_label(setup_codes, parameter_type):
"""Verify that using the LABEL will retrieve the correct entity."""
entity_01, entity_02, entity_03 = setup_codes
identifier = '{}'.format(entity_01.label)
result = parameter_type.convert(identifier, None, None)
assert result.uuid == entity_01.uuid


def test_get_by_fullname(setup_codes, parameter_type):
"""Verify that using the LABEL@machinename will retrieve the correct entity."""
entity_01, entity_02, entity_03 = setup_codes
identifier = '{}@{}'.format(entity_01.label, entity_01.computer.name)
result = parameter_type.convert(identifier, None, None)
assert result.uuid == entity_01.uuid


def test_ambiguous_label_pk(setup_codes, parameter_type):
"""Situation: LABEL of entity_02 is exactly equal to ID of entity_01.
Verify that using an ambiguous identifier gives precedence to the ID interpretation
Appending the special ambiguity breaker character will force the identifier to be treated as a LABEL
"""
entity_01, entity_02, entity_03 = setup_codes
identifier = '{}'.format(entity_02.label)
result = parameter_type.convert(identifier, None, None)
assert result.uuid == entity_01.uuid

identifier = '{}{}'.format(entity_02.label, OrmEntityLoader.label_ambiguity_breaker)
result = parameter_type.convert(identifier, None, None)
assert result.uuid == entity_02.uuid


def test_ambiguous_label_uuid(setup_codes, parameter_type):
"""Situation: LABEL of entity_03 is exactly equal to UUID of entity_01.
Verify that using an ambiguous identifier gives precedence to the UUID interpretation
Appending the special ambiguity breaker character will force the identifier to be treated as a LABEL
"""
entity_01, entity_02, entity_03 = setup_codes
identifier = '{}'.format(entity_03.label)
result = parameter_type.convert(identifier, None, None)
assert result.uuid == entity_01.uuid

identifier = '{}{}'.format(entity_03.label, OrmEntityLoader.label_ambiguity_breaker)
result = parameter_type.convert(identifier, None, None)
assert result.uuid == entity_03.uuid


def test_entry_point_validation(setup_codes):
"""Verify that when an `entry_point` is defined in the constructor, it is respected in the validation."""
entity_01, entity_02, entity_03 = setup_codes
parameter_type = CodeParamType(entry_point='arithmetic.add')
identifier = '{}'.format(entity_02.pk)
result = parameter_type.convert(identifier, None, None)
assert result.uuid == entity_02.uuid

with pytest.raises(click.BadParameter):
identifier = '{}'.format(entity_03.pk)
result = parameter_type.convert(identifier, None, None)


def test_complete(setup_codes, parameter_type, aiida_localhost):
"""Test the `complete` method that provides auto-complete functionality."""
entity_01, entity_02, entity_03 = setup_codes
entity_04 = Code(label='xavier', remote_computer_exec=(aiida_localhost, '/bin/true')).store()

options = [item[0] for item in parameter_type.complete(None, '')]
assert sorted(options) == sorted([entity_01.label, entity_02.label, entity_03.label, entity_04.label])

options = [item[0] for item in parameter_type.complete(None, 'xa')]
assert sorted(options) == sorted([entity_04.label])
12 changes: 12 additions & 0 deletions tests/cmdline/params/types/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,15 @@ def test_sub_classes(setup_groups, sub_classes, expected):
results.append(True)

assert tuple(results) == expected


def test_complete(setup_groups, parameter_type):
"""Test the `complete` method that provides auto-complete functionality."""
entity_01, entity_02, entity_03 = setup_groups
entity_04 = Group(label='xavier').store()

options = [item[0] for item in parameter_type.complete(None, '')]
assert sorted(options) == sorted([entity_01.label, entity_02.label, entity_03.label, entity_04.label])

options = [item[0] for item in parameter_type.complete(None, 'xa')]
assert sorted(options) == sorted([entity_04.label])

0 comments on commit 789fdee

Please sign in to comment.