Skip to content

Commit

Permalink
CLI: Add support for units to family show
Browse files Browse the repository at this point in the history
Although we now support specifying units for the recommended cutoffs,
the `family show` method had not been adapted for this feature.

Here we add the `-u/--unit` option to the `family show` method, making
sure to show the correct unit in the table header as well as adapting
the values in the columns by supplying the unit to the
`get_recommended_cutoffs` method.

There is also some refactoring regarding the validation of the
stringencies. The `RecommendedCutoffMixin.validate_stringency` method
was not being used, so we adapt the code for several methods with the
stringency input to rely on this validation method. This way we have a
consistent and succinct error message in case the user requests a
stringency that has not been configured.
  • Loading branch information
mbercx committed May 9, 2021
1 parent 2ecf430 commit 3577ffe
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 47 deletions.
20 changes: 10 additions & 10 deletions aiida_pseudo/cli/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,26 @@ def cmd_family():
@cmd_family.command('show')
@arguments.PSEUDO_POTENTIAL_FAMILY()
@options.STRINGENCY()
@options.UNIT(default=None)
@options_core.RAW()
@decorators.with_dbenv()
def cmd_family_show(family, stringency, raw):
def cmd_family_show(family, stringency, unit, raw):
"""Show details of pseudo potential family."""
from tabulate import tabulate

if isinstance(family, RecommendedCutoffMixin):

if stringency is not None and stringency not in family.get_cutoff_stringencies():
raise click.BadParameter(f'`{stringency}` is not defined for family `{family}`.', param_hint='stringency')
try:
family.validate_stringency(stringency)
except ValueError as exception:
raise click.BadParameter(f'{exception}', param_hint="'-s' / '--stringency'")

headers = ['Element', 'Pseudo', 'MD5', 'Wavefunction (eV)', 'Charge density (eV)']
unit = unit or family.get_cutoffs_unit(stringency)

headers = ['Element', 'Pseudo', 'MD5', f'Wavefunction ({unit})', f'Charge density ({unit})']
rows = [[
pseudo.element, pseudo.filename, pseudo.md5,
*family.get_recommended_cutoffs(elements=pseudo.element, stringency=stringency)
*family.get_recommended_cutoffs(elements=pseudo.element, stringency=stringency, unit=unit)
] for pseudo in family.nodes]
else:
headers = ['Element', 'Pseudo', 'MD5']
Expand Down Expand Up @@ -76,11 +81,6 @@ def cmd_family_cutoffs_set(family, cutoffs, stringency, unit): # noqa: D301
if not isinstance(family, RecommendedCutoffMixin):
raise click.BadParameter(f'family `{family}` does not support recommended cutoffs to be set.')

try:
family.validate_cutoffs_unit(unit)
except ValueError as exception:
raise click.BadParameter(f'{exception}', param_hint='UNIT')

try:
data = json.load(cutoffs)
except ValueError as exception:
Expand Down
4 changes: 2 additions & 2 deletions aiida_pseudo/cli/params/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import click

from aiida.cmdline.params.options import OverridableOption
from .types import PseudoPotentialFamilyTypeParam, PseudoPotentialTypeParam
from .types import PseudoPotentialFamilyTypeParam, PseudoPotentialTypeParam, UnitParamType

__all__ = (
'VERSION', 'FUNCTIONAL', 'RELATIVISTIC', 'PROTOCOL', 'PSEUDO_FORMAT', 'STRINGENCY', 'DEFAULT_STRINGENCY',
Expand Down Expand Up @@ -88,7 +88,7 @@
UNIT = OverridableOption(
'-u',
'--unit',
type=click.STRING,
type=UnitParamType(quantity='energy'),
required=False,
default='eV',
show_default=True,
Expand Down
31 changes: 31 additions & 0 deletions aiida_pseudo/cli/params/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from aiida.cmdline.params.types import GroupParamType
from ..utils import attempt
from ...common.units import U

__all__ = ('PseudoPotentialFamilyTypeParam', 'PseudoPotentialFamilyParam', 'PseudoPotentialTypeParam')

Expand Down Expand Up @@ -128,3 +129,33 @@ def convert(self, value, param, ctx) -> typing.Union[pathlib.Path, bytes]:
response = requests.get(value)
response.raise_for_status()
return response


class UnitParamType(click.ParamType):
"""Parameter type ."""

name = 'unit'

def __init__(self, quantity: typing.Optional[typing.List[str]] = None, **kwargs):
"""Construct the parameter.
:param quantity: The corresponding quantity of the unit.
"""
super().__init__(**kwargs)
self.quantity = quantity

def convert(self, value, _, __):
"""Check if the provided unit is a valid energy unit.
:raises: `click.BadParameter` if the provided unit is not a valid energy unit.
"""
try:
if value not in U:
raise ValueError(f'`{value}` is not a valid unit.')

if not U.Quantity(1, value).check(f'[{self.quantity}]'):
raise ValueError(f'`{value}` is not a valid `{self.quantity}` unit.')
except ValueError as exception:
raise click.BadParameter(f'{exception}') from exception

return value
54 changes: 26 additions & 28 deletions aiida_pseudo/groups/mixins/cutoffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
"""Mixin that adds support of recommended cutoffs to a ``Group`` subclass, using its extras."""
import warnings

from typing import Optional

from aiida.common.lang import type_check
from aiida.plugins import DataFactory

Expand Down Expand Up @@ -81,17 +83,23 @@ def validate_cutoffs_unit(unit: str) -> None:
if not U.Quantity(1, unit).check('[energy]'):
raise ValueError(f'`{unit}` is not a valid energy unit.')

def validate_stringency(self, stringency: str) -> None:
def validate_stringency(self, stringency: Optional[str]) -> None:
"""Validate a cutoff stringency.
Check if the stringency is defined for the family. If no stringency input is passed, the method checks if a
default stringency has been set.
:param stringency: the cutoff stringency to validate.
:raises ValueError: if stringency is None or the family does not define cutoffs for the specified stringency.
:raises ValueError: if default stringency has not been defined.
:raises ValueError: if the family does not define cutoffs for the specified stringency.
"""
if stringency is None:
raise ValueError('defining a stringency is required.')

if stringency not in self.get_cutoff_stringencies():
raise ValueError(f'stringency `{stringency}` is not defined for this family.')
self.get_default_stringency()
elif stringency not in self.get_cutoff_stringencies():
raise ValueError(
f'stringency `{stringency}` is not one of the available cutoff stringencies for this family: '
f'{self.get_cutoff_stringencies()}.'
)

def _get_cutoffs_dict(self) -> dict:
"""Return the cutoffs dictionary that maps the stringencies to the recommended cutoffs.
Expand Down Expand Up @@ -125,12 +133,7 @@ def set_default_stringency(self, default_stringency: str) -> None:
:raises ValueError: if the provided default stringency is not in the tuple of available cutoff stringencies for
the pseudo family.
"""
if default_stringency not in self.get_cutoff_stringencies():
raise ValueError(
'provided default stringency not in tuple of available cutoff stringencies: '
f'{self.get_cutoff_stringencies()}.'
)

self.validate_stringency(default_stringency)
self.set_extra(self._key_default_stringency, default_stringency)

def get_cutoff_stringencies(self) -> tuple:
Expand Down Expand Up @@ -182,22 +185,19 @@ def set_cutoffs(self, cutoffs: dict, stringency: str, unit: str = None) -> None:
if len(cutoffs_dict) == 1:
self.set_default_stringency(stringency)

def get_cutoffs(self, stringency=None) -> dict:
def get_cutoffs(self, stringency: str = None) -> dict:
"""Return a set of cutoffs for the given stringency.
:param stringency: optional stringency for which to retrieve the cutoffs. If not specified, the default
stringency of the family is used.
:raises ValueError: if no stringency is specified and no default stringency is defined for the family.
:raises ValueError: if the requested stringency is not defined for this family.
"""
self.validate_stringency(stringency)
stringency = stringency or self.get_default_stringency()
return self._get_cutoffs_dict()[stringency]

try:
return self._get_cutoffs_dict()[stringency]
except KeyError as exception:
raise ValueError(f'stringency `{stringency}` is not defined for this family.') from exception

def delete_cutoffs(self, stringency) -> None:
def delete_cutoffs(self, stringency: str) -> None:
"""Delete the recommended cutoffs for a specified stringency.
.. note: If, after the cutoffs have been deleted, there is only one stringency defined for the pseudo family,
Expand All @@ -207,8 +207,7 @@ def delete_cutoffs(self, stringency) -> None:
:param stringency: stringency for which to delete the cutoffs.
:raises ValueError: if the requested stringency is not defined for this family.
"""
if stringency not in self.get_cutoff_stringencies():
raise ValueError(f'stringency `{stringency}` is not defined for this family.')
self.validate_stringency(stringency)

cutoffs_dict = self._get_cutoffs_dict()
cutoffs_dict.pop(stringency)
Expand Down Expand Up @@ -257,19 +256,18 @@ def get_cutoffs_unit(self, stringency: str = None) -> str:
:raises ValueError: if no stringency is specified and no default stringency is defined for the family.
:raises ValueError: if the requested stringency is not defined for this family.
"""
self.validate_stringency(stringency)
stringency = stringency or self.get_default_stringency()

try:
return self._get_cutoffs_unit_dict()[stringency]
except KeyError as exception:
except KeyError:
# Workaround to deal with pseudo families installed in v0.5.0 - Set default unit in case it is not in extras
if stringency in self.get_cutoff_stringencies():
cutoffs_unit_dict = self._get_cutoffs_unit_dict()
cutoffs_unit_dict[stringency] = 'eV'
self.set_extra(self._key_cutoffs_unit, cutoffs_unit_dict)
return 'eV'
cutoffs_unit_dict = self._get_cutoffs_unit_dict()
cutoffs_unit_dict[stringency] = 'eV'
self.set_extra(self._key_cutoffs_unit, cutoffs_unit_dict)
return 'eV'
# End of workaround
raise ValueError(f'stringency `{stringency}` is not defined for this family.') from exception

def get_recommended_cutoffs(self, *, elements=None, structure=None, stringency=None, unit=None):
"""Return tuple of recommended wavefunction and density cutoffs for the given elements or ``StructureData``.
Expand Down
50 changes: 48 additions & 2 deletions tests/cli/test_family.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_family_cutoffs_set_unit(run_cli_command, get_pseudo_family, generate_cu
result = run_cli_command(
cmd_family_cutoffs_set, [family.label, str(filepath), '-s', stringency, '-u', unit], raises=True
)
assert 'Error: Invalid value for UNIT:' in result.output
assert "Error: Invalid value for '-u' / '--unit': `GME stock` is not a valid unit." in result.output

# Correct unit
unit = 'hartree'
Expand Down Expand Up @@ -95,7 +95,7 @@ def test_family_show_recommended_cutoffs(clear_db, run_cli_command, get_pseudo_f

# Test the command prints an error for a non-existing stringency
result = run_cli_command(cmd_family_show, [family.label, '--stringency', 'non-existing'], raises=True)
assert 'Invalid value for stringency: `non-existing` is not defined' in result.output
assert "Error: Invalid value for '-s' / '--stringency': stringency `non-existing` is not" in result.output

# Test the command for default and explicit stringency
for stringency in [None, 'high']:
Expand Down Expand Up @@ -131,3 +131,49 @@ def test_family_show_raw(clear_db, run_cli_command, get_pseudo_family):
for option in ['-r', '--raw']:
result = run_cli_command(cmd_family_show, [option, family.label])
assert len(result.output_lines) == len(family.nodes)


def test_family_show_unit_default(clear_db, run_cli_command, get_pseudo_family):
"""Test the `family show` command with default unit."""
elements = ['Ar', 'Kr']
cutoff_dict = {'normal': {'Ar': {'cutoff_wfc': 50, 'cutoff_rho': 200}, 'Kr': {'cutoff_wfc': 25, 'cutoff_rho': 100}}}

family = get_pseudo_family(cls=CutoffsPseudoPotentialFamily, elements=elements, cutoffs_dict=cutoff_dict, unit='Ry')

# Test the default unit (Ry)
result = run_cli_command(cmd_family_show, [family.label])

header_fields = result.output_lines[0].split()
unit = family.get_cutoffs_unit()

assert header_fields[4] == f'({unit})'
assert header_fields[7] == f'({unit})'

for index, element in enumerate(elements):
cutoffs = family.get_recommended_cutoffs(elements=element)
fields = result.output_lines[index + 2].split()
assert_almost_equal(cutoffs[0], float(fields[3]))
assert_almost_equal(cutoffs[1], float(fields[4]))


@pytest.mark.parametrize('unit', ['Ry', 'eV', 'hartree', 'aJ'])
def test_family_show_unit(clear_db, run_cli_command, get_pseudo_family, unit):
"""Test the `-u/--unit` option."""
elements = [
'Ar',
]
cutoff_dict = {'normal': {'Ar': {'cutoff_wfc': 50, 'cutoff_rho': 200}}}

family = get_pseudo_family(cls=CutoffsPseudoPotentialFamily, elements=elements, cutoffs_dict=cutoff_dict, unit='Ry')

# Test both option strings and several units
for option in ['-u', '--unit']:
result = run_cli_command(cmd_family_show, [family.label, option, unit])
cutoffs = family.get_recommended_cutoffs(elements='Ar', unit=unit)
header_fields = result.output_lines[0].split()
assert header_fields[4] == f'({unit})'
assert header_fields[4] == f'({unit})'

values_fields = result.output_lines[2].split()
assert round(cutoffs[0], 1) == float(values_fields[3])
assert round(cutoffs[1], 1) == float(values_fields[4])
10 changes: 5 additions & 5 deletions tests/groups/mixins/test_cutoffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ def test_validate_stringency(get_pseudo_family, generate_cutoffs):
"""Test the ``CutoffsPseudoPotentialFamily.validate_stringency`` method."""
family = get_pseudo_family(cls=CutoffsPseudoPotentialFamily)

with pytest.raises(ValueError, match=r'stringency `.*` is not defined for this family.'):
with pytest.raises(ValueError, match=r'stringency `.*` is not one of the available cutoff stringencies for this'):
family.validate_stringency('default')

cutoffs = generate_cutoffs(family)
stringency = 'default'
family.set_cutoffs(cutoffs, stringency)

with pytest.raises(ValueError, match=r'stringency `.*` is not defined for this family.'):
with pytest.raises(ValueError, match=r'stringency `.*` is not one of the available cutoff stringencies for this'):
family.validate_stringency(stringency + 'non-existing')

family.validate_stringency(stringency)
Expand Down Expand Up @@ -91,7 +91,7 @@ def test_set_default_stringency(get_pseudo_family, generate_cutoffs_dict):

assert family.get_default_stringency() == 'low'

with pytest.raises(ValueError, match='provided default stringency not in tuple of available cutoff stringencies:'):
with pytest.raises(ValueError, match=r'stringency `nonexistent` is not one of the available cutoff stringencies'):
family.set_default_stringency('nonexistent')

family.set_default_stringency('normal')
Expand Down Expand Up @@ -215,7 +215,7 @@ def test_get_cutoffs(get_pseudo_family, generate_cutoffs):

family.set_cutoffs(cutoffs, stringency)

with pytest.raises(ValueError, match=r'stringency `.*` is not defined for this family.'):
with pytest.raises(ValueError, match=r'stringency `.*` is not one of the available cutoff stringencies for this'):
family.get_cutoffs('non-existing')

assert family.get_cutoffs() == cutoffs
Expand Down Expand Up @@ -314,7 +314,7 @@ def test_delete_cutoffs(get_pseudo_family, generate_cutoffs_dict):
for stringency, cutoffs in generate_cutoffs_dict(family, stringencies).items():
family.set_cutoffs(cutoffs, stringency)

with pytest.raises(ValueError, match='stringency `nonexistent` is not defined for this family.'):
with pytest.raises(ValueError, match=r'stringency `nonexistent` is not one of the available cutoff stringencies'):
family.delete_cutoffs('nonexistent')

with pytest.warns(UserWarning, match='`low` was the default stringency of this family. Please set'):
Expand Down

0 comments on commit 3577ffe

Please sign in to comment.