From 3577ffee30da67b250e78e24c46bb413363f5fd3 Mon Sep 17 00:00:00 2001 From: Marnik Bercx Date: Sun, 9 May 2021 15:55:32 +0200 Subject: [PATCH] CLI: Add support for units to `family show` Although we now support specifying units for the recommended cutoffs, the `family show` method had not been adapted for this feature. Here we add the `-u/--unit` option to the `family show` method, making sure to show the correct unit in the table header as well as adapting the values in the columns by supplying the unit to the `get_recommended_cutoffs` method. There is also some refactoring regarding the validation of the stringencies. The `RecommendedCutoffMixin.validate_stringency` method was not being used, so we adapt the code for several methods with the stringency input to rely on this validation method. This way we have a consistent and succinct error message in case the user requests a stringency that has not been configured. --- aiida_pseudo/cli/family.py | 20 +++++----- aiida_pseudo/cli/params/options.py | 4 +- aiida_pseudo/cli/params/types.py | 31 +++++++++++++++ aiida_pseudo/groups/mixins/cutoffs.py | 54 +++++++++++++-------------- tests/cli/test_family.py | 50 ++++++++++++++++++++++++- tests/groups/mixins/test_cutoffs.py | 10 ++--- 6 files changed, 122 insertions(+), 47 deletions(-) diff --git a/aiida_pseudo/cli/family.py b/aiida_pseudo/cli/family.py index d40bb8a..46c3d39 100644 --- a/aiida_pseudo/cli/family.py +++ b/aiida_pseudo/cli/family.py @@ -20,21 +20,26 @@ def cmd_family(): @cmd_family.command('show') @arguments.PSEUDO_POTENTIAL_FAMILY() @options.STRINGENCY() +@options.UNIT(default=None) @options_core.RAW() @decorators.with_dbenv() -def cmd_family_show(family, stringency, raw): +def cmd_family_show(family, stringency, unit, raw): """Show details of pseudo potential family.""" from tabulate import tabulate if isinstance(family, RecommendedCutoffMixin): - if stringency is not None and stringency not in family.get_cutoff_stringencies(): - raise click.BadParameter(f'`{stringency}` is not defined for family `{family}`.', param_hint='stringency') + try: + family.validate_stringency(stringency) + except ValueError as exception: + raise click.BadParameter(f'{exception}', param_hint="'-s' / '--stringency'") - headers = ['Element', 'Pseudo', 'MD5', 'Wavefunction (eV)', 'Charge density (eV)'] + unit = unit or family.get_cutoffs_unit(stringency) + + headers = ['Element', 'Pseudo', 'MD5', f'Wavefunction ({unit})', f'Charge density ({unit})'] rows = [[ pseudo.element, pseudo.filename, pseudo.md5, - *family.get_recommended_cutoffs(elements=pseudo.element, stringency=stringency) + *family.get_recommended_cutoffs(elements=pseudo.element, stringency=stringency, unit=unit) ] for pseudo in family.nodes] else: headers = ['Element', 'Pseudo', 'MD5'] @@ -76,11 +81,6 @@ def cmd_family_cutoffs_set(family, cutoffs, stringency, unit): # noqa: D301 if not isinstance(family, RecommendedCutoffMixin): raise click.BadParameter(f'family `{family}` does not support recommended cutoffs to be set.') - try: - family.validate_cutoffs_unit(unit) - except ValueError as exception: - raise click.BadParameter(f'{exception}', param_hint='UNIT') - try: data = json.load(cutoffs) except ValueError as exception: diff --git a/aiida_pseudo/cli/params/options.py b/aiida_pseudo/cli/params/options.py index a7285bc..a158ab3 100644 --- a/aiida_pseudo/cli/params/options.py +++ b/aiida_pseudo/cli/params/options.py @@ -5,7 +5,7 @@ import click from aiida.cmdline.params.options import OverridableOption -from .types import PseudoPotentialFamilyTypeParam, PseudoPotentialTypeParam +from .types import PseudoPotentialFamilyTypeParam, PseudoPotentialTypeParam, UnitParamType __all__ = ( 'VERSION', 'FUNCTIONAL', 'RELATIVISTIC', 'PROTOCOL', 'PSEUDO_FORMAT', 'STRINGENCY', 'DEFAULT_STRINGENCY', @@ -88,7 +88,7 @@ UNIT = OverridableOption( '-u', '--unit', - type=click.STRING, + type=UnitParamType(quantity='energy'), required=False, default='eV', show_default=True, diff --git a/aiida_pseudo/cli/params/types.py b/aiida_pseudo/cli/params/types.py index 2f0b4c5..65b541b 100644 --- a/aiida_pseudo/cli/params/types.py +++ b/aiida_pseudo/cli/params/types.py @@ -9,6 +9,7 @@ from aiida.cmdline.params.types import GroupParamType from ..utils import attempt +from ...common.units import U __all__ = ('PseudoPotentialFamilyTypeParam', 'PseudoPotentialFamilyParam', 'PseudoPotentialTypeParam') @@ -128,3 +129,33 @@ def convert(self, value, param, ctx) -> typing.Union[pathlib.Path, bytes]: response = requests.get(value) response.raise_for_status() return response + + +class UnitParamType(click.ParamType): + """Parameter type .""" + + name = 'unit' + + def __init__(self, quantity: typing.Optional[typing.List[str]] = None, **kwargs): + """Construct the parameter. + + :param quantity: The corresponding quantity of the unit. + """ + super().__init__(**kwargs) + self.quantity = quantity + + def convert(self, value, _, __): + """Check if the provided unit is a valid energy unit. + + :raises: `click.BadParameter` if the provided unit is not a valid energy unit. + """ + try: + if value not in U: + raise ValueError(f'`{value}` is not a valid unit.') + + if not U.Quantity(1, value).check(f'[{self.quantity}]'): + raise ValueError(f'`{value}` is not a valid `{self.quantity}` unit.') + except ValueError as exception: + raise click.BadParameter(f'{exception}') from exception + + return value diff --git a/aiida_pseudo/groups/mixins/cutoffs.py b/aiida_pseudo/groups/mixins/cutoffs.py index 3d06241..6326ef9 100644 --- a/aiida_pseudo/groups/mixins/cutoffs.py +++ b/aiida_pseudo/groups/mixins/cutoffs.py @@ -2,6 +2,8 @@ """Mixin that adds support of recommended cutoffs to a ``Group`` subclass, using its extras.""" import warnings +from typing import Optional + from aiida.common.lang import type_check from aiida.plugins import DataFactory @@ -81,17 +83,23 @@ def validate_cutoffs_unit(unit: str) -> None: if not U.Quantity(1, unit).check('[energy]'): raise ValueError(f'`{unit}` is not a valid energy unit.') - def validate_stringency(self, stringency: str) -> None: + def validate_stringency(self, stringency: Optional[str]) -> None: """Validate a cutoff stringency. + Check if the stringency is defined for the family. If no stringency input is passed, the method checks if a + default stringency has been set. + :param stringency: the cutoff stringency to validate. - :raises ValueError: if stringency is None or the family does not define cutoffs for the specified stringency. + :raises ValueError: if default stringency has not been defined. + :raises ValueError: if the family does not define cutoffs for the specified stringency. """ if stringency is None: - raise ValueError('defining a stringency is required.') - - if stringency not in self.get_cutoff_stringencies(): - raise ValueError(f'stringency `{stringency}` is not defined for this family.') + self.get_default_stringency() + elif stringency not in self.get_cutoff_stringencies(): + raise ValueError( + f'stringency `{stringency}` is not one of the available cutoff stringencies for this family: ' + f'{self.get_cutoff_stringencies()}.' + ) def _get_cutoffs_dict(self) -> dict: """Return the cutoffs dictionary that maps the stringencies to the recommended cutoffs. @@ -125,12 +133,7 @@ def set_default_stringency(self, default_stringency: str) -> None: :raises ValueError: if the provided default stringency is not in the tuple of available cutoff stringencies for the pseudo family. """ - if default_stringency not in self.get_cutoff_stringencies(): - raise ValueError( - 'provided default stringency not in tuple of available cutoff stringencies: ' - f'{self.get_cutoff_stringencies()}.' - ) - + self.validate_stringency(default_stringency) self.set_extra(self._key_default_stringency, default_stringency) def get_cutoff_stringencies(self) -> tuple: @@ -182,7 +185,7 @@ def set_cutoffs(self, cutoffs: dict, stringency: str, unit: str = None) -> None: if len(cutoffs_dict) == 1: self.set_default_stringency(stringency) - def get_cutoffs(self, stringency=None) -> dict: + def get_cutoffs(self, stringency: str = None) -> dict: """Return a set of cutoffs for the given stringency. :param stringency: optional stringency for which to retrieve the cutoffs. If not specified, the default @@ -190,14 +193,11 @@ def get_cutoffs(self, stringency=None) -> dict: :raises ValueError: if no stringency is specified and no default stringency is defined for the family. :raises ValueError: if the requested stringency is not defined for this family. """ + self.validate_stringency(stringency) stringency = stringency or self.get_default_stringency() + return self._get_cutoffs_dict()[stringency] - try: - return self._get_cutoffs_dict()[stringency] - except KeyError as exception: - raise ValueError(f'stringency `{stringency}` is not defined for this family.') from exception - - def delete_cutoffs(self, stringency) -> None: + def delete_cutoffs(self, stringency: str) -> None: """Delete the recommended cutoffs for a specified stringency. .. note: If, after the cutoffs have been deleted, there is only one stringency defined for the pseudo family, @@ -207,8 +207,7 @@ def delete_cutoffs(self, stringency) -> None: :param stringency: stringency for which to delete the cutoffs. :raises ValueError: if the requested stringency is not defined for this family. """ - if stringency not in self.get_cutoff_stringencies(): - raise ValueError(f'stringency `{stringency}` is not defined for this family.') + self.validate_stringency(stringency) cutoffs_dict = self._get_cutoffs_dict() cutoffs_dict.pop(stringency) @@ -257,19 +256,18 @@ def get_cutoffs_unit(self, stringency: str = None) -> str: :raises ValueError: if no stringency is specified and no default stringency is defined for the family. :raises ValueError: if the requested stringency is not defined for this family. """ + self.validate_stringency(stringency) stringency = stringency or self.get_default_stringency() try: return self._get_cutoffs_unit_dict()[stringency] - except KeyError as exception: + except KeyError: # Workaround to deal with pseudo families installed in v0.5.0 - Set default unit in case it is not in extras - if stringency in self.get_cutoff_stringencies(): - cutoffs_unit_dict = self._get_cutoffs_unit_dict() - cutoffs_unit_dict[stringency] = 'eV' - self.set_extra(self._key_cutoffs_unit, cutoffs_unit_dict) - return 'eV' + cutoffs_unit_dict = self._get_cutoffs_unit_dict() + cutoffs_unit_dict[stringency] = 'eV' + self.set_extra(self._key_cutoffs_unit, cutoffs_unit_dict) + return 'eV' # End of workaround - raise ValueError(f'stringency `{stringency}` is not defined for this family.') from exception def get_recommended_cutoffs(self, *, elements=None, structure=None, stringency=None, unit=None): """Return tuple of recommended wavefunction and density cutoffs for the given elements or ``StructureData``. diff --git a/tests/cli/test_family.py b/tests/cli/test_family.py index c7ad79c..c6dcc06 100644 --- a/tests/cli/test_family.py +++ b/tests/cli/test_family.py @@ -66,7 +66,7 @@ def test_family_cutoffs_set_unit(run_cli_command, get_pseudo_family, generate_cu result = run_cli_command( cmd_family_cutoffs_set, [family.label, str(filepath), '-s', stringency, '-u', unit], raises=True ) - assert 'Error: Invalid value for UNIT:' in result.output + assert "Error: Invalid value for '-u' / '--unit': `GME stock` is not a valid unit." in result.output # Correct unit unit = 'hartree' @@ -95,7 +95,7 @@ def test_family_show_recommended_cutoffs(clear_db, run_cli_command, get_pseudo_f # Test the command prints an error for a non-existing stringency result = run_cli_command(cmd_family_show, [family.label, '--stringency', 'non-existing'], raises=True) - assert 'Invalid value for stringency: `non-existing` is not defined' in result.output + assert "Error: Invalid value for '-s' / '--stringency': stringency `non-existing` is not" in result.output # Test the command for default and explicit stringency for stringency in [None, 'high']: @@ -131,3 +131,49 @@ def test_family_show_raw(clear_db, run_cli_command, get_pseudo_family): for option in ['-r', '--raw']: result = run_cli_command(cmd_family_show, [option, family.label]) assert len(result.output_lines) == len(family.nodes) + + +def test_family_show_unit_default(clear_db, run_cli_command, get_pseudo_family): + """Test the `family show` command with default unit.""" + elements = ['Ar', 'Kr'] + cutoff_dict = {'normal': {'Ar': {'cutoff_wfc': 50, 'cutoff_rho': 200}, 'Kr': {'cutoff_wfc': 25, 'cutoff_rho': 100}}} + + family = get_pseudo_family(cls=CutoffsPseudoPotentialFamily, elements=elements, cutoffs_dict=cutoff_dict, unit='Ry') + + # Test the default unit (Ry) + result = run_cli_command(cmd_family_show, [family.label]) + + header_fields = result.output_lines[0].split() + unit = family.get_cutoffs_unit() + + assert header_fields[4] == f'({unit})' + assert header_fields[7] == f'({unit})' + + for index, element in enumerate(elements): + cutoffs = family.get_recommended_cutoffs(elements=element) + fields = result.output_lines[index + 2].split() + assert_almost_equal(cutoffs[0], float(fields[3])) + assert_almost_equal(cutoffs[1], float(fields[4])) + + +@pytest.mark.parametrize('unit', ['Ry', 'eV', 'hartree', 'aJ']) +def test_family_show_unit(clear_db, run_cli_command, get_pseudo_family, unit): + """Test the `-u/--unit` option.""" + elements = [ + 'Ar', + ] + cutoff_dict = {'normal': {'Ar': {'cutoff_wfc': 50, 'cutoff_rho': 200}}} + + family = get_pseudo_family(cls=CutoffsPseudoPotentialFamily, elements=elements, cutoffs_dict=cutoff_dict, unit='Ry') + + # Test both option strings and several units + for option in ['-u', '--unit']: + result = run_cli_command(cmd_family_show, [family.label, option, unit]) + cutoffs = family.get_recommended_cutoffs(elements='Ar', unit=unit) + header_fields = result.output_lines[0].split() + assert header_fields[4] == f'({unit})' + assert header_fields[4] == f'({unit})' + + values_fields = result.output_lines[2].split() + assert round(cutoffs[0], 1) == float(values_fields[3]) + assert round(cutoffs[1], 1) == float(values_fields[4]) diff --git a/tests/groups/mixins/test_cutoffs.py b/tests/groups/mixins/test_cutoffs.py index d4c202f..045e7c8 100644 --- a/tests/groups/mixins/test_cutoffs.py +++ b/tests/groups/mixins/test_cutoffs.py @@ -51,14 +51,14 @@ def test_validate_stringency(get_pseudo_family, generate_cutoffs): """Test the ``CutoffsPseudoPotentialFamily.validate_stringency`` method.""" family = get_pseudo_family(cls=CutoffsPseudoPotentialFamily) - with pytest.raises(ValueError, match=r'stringency `.*` is not defined for this family.'): + with pytest.raises(ValueError, match=r'stringency `.*` is not one of the available cutoff stringencies for this'): family.validate_stringency('default') cutoffs = generate_cutoffs(family) stringency = 'default' family.set_cutoffs(cutoffs, stringency) - with pytest.raises(ValueError, match=r'stringency `.*` is not defined for this family.'): + with pytest.raises(ValueError, match=r'stringency `.*` is not one of the available cutoff stringencies for this'): family.validate_stringency(stringency + 'non-existing') family.validate_stringency(stringency) @@ -91,7 +91,7 @@ def test_set_default_stringency(get_pseudo_family, generate_cutoffs_dict): assert family.get_default_stringency() == 'low' - with pytest.raises(ValueError, match='provided default stringency not in tuple of available cutoff stringencies:'): + with pytest.raises(ValueError, match=r'stringency `nonexistent` is not one of the available cutoff stringencies'): family.set_default_stringency('nonexistent') family.set_default_stringency('normal') @@ -215,7 +215,7 @@ def test_get_cutoffs(get_pseudo_family, generate_cutoffs): family.set_cutoffs(cutoffs, stringency) - with pytest.raises(ValueError, match=r'stringency `.*` is not defined for this family.'): + with pytest.raises(ValueError, match=r'stringency `.*` is not one of the available cutoff stringencies for this'): family.get_cutoffs('non-existing') assert family.get_cutoffs() == cutoffs @@ -314,7 +314,7 @@ def test_delete_cutoffs(get_pseudo_family, generate_cutoffs_dict): for stringency, cutoffs in generate_cutoffs_dict(family, stringencies).items(): family.set_cutoffs(cutoffs, stringency) - with pytest.raises(ValueError, match='stringency `nonexistent` is not defined for this family.'): + with pytest.raises(ValueError, match=r'stringency `nonexistent` is not one of the available cutoff stringencies'): family.delete_cutoffs('nonexistent') with pytest.warns(UserWarning, match='`low` was the default stringency of this family. Please set'):