Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CLI: Add support for units to family show #97

Merged
merged 3 commits into from
May 12, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions aiida_pseudo/cli/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,26 @@ def cmd_family():
@cmd_family.command('show')
@arguments.PSEUDO_POTENTIAL_FAMILY()
@options.STRINGENCY()
@options.UNIT(default=None)
@options_core.RAW()
@decorators.with_dbenv()
def cmd_family_show(family, stringency, raw):
def cmd_family_show(family, stringency, unit, raw):
"""Show details of pseudo potential family."""
from tabulate import tabulate

if isinstance(family, RecommendedCutoffMixin):

if stringency is not None and stringency not in family.get_cutoff_stringencies():
raise click.BadParameter(f'`{stringency}` is not defined for family `{family}`.', param_hint='stringency')
try:
family.validate_stringency(stringency)
except ValueError as exception:
raise click.BadParameter(f'{exception}', param_hint="'-s' / '--stringency'")

headers = ['Element', 'Pseudo', 'MD5', 'Wavefunction (eV)', 'Charge density (eV)']
unit = unit or family.get_cutoffs_unit(stringency)

headers = ['Element', 'Pseudo', 'MD5', f'Wavefunction ({unit})', f'Charge density ({unit})']
rows = [[
pseudo.element, pseudo.filename, pseudo.md5,
*family.get_recommended_cutoffs(elements=pseudo.element, stringency=stringency)
*family.get_recommended_cutoffs(elements=pseudo.element, stringency=stringency, unit=unit)
] for pseudo in family.nodes]
else:
headers = ['Element', 'Pseudo', 'MD5']
Expand Down Expand Up @@ -76,11 +81,6 @@ def cmd_family_cutoffs_set(family, cutoffs, stringency, unit): # noqa: D301
if not isinstance(family, RecommendedCutoffMixin):
raise click.BadParameter(f'family `{family}` does not support recommended cutoffs to be set.')

try:
family.validate_cutoffs_unit(unit)
except ValueError as exception:
raise click.BadParameter(f'{exception}', param_hint='UNIT')

try:
data = json.load(cutoffs)
except ValueError as exception:
Expand Down
4 changes: 2 additions & 2 deletions aiida_pseudo/cli/params/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import click

from aiida.cmdline.params.options import OverridableOption
from .types import PseudoPotentialFamilyTypeParam, PseudoPotentialTypeParam
from .types import PseudoPotentialFamilyTypeParam, PseudoPotentialTypeParam, UnitParamType

__all__ = (
'VERSION', 'FUNCTIONAL', 'RELATIVISTIC', 'PROTOCOL', 'PSEUDO_FORMAT', 'STRINGENCY', 'DEFAULT_STRINGENCY',
Expand Down Expand Up @@ -88,7 +88,7 @@
UNIT = OverridableOption(
'-u',
'--unit',
type=click.STRING,
type=UnitParamType(quantity='energy'),
required=False,
default='eV',
show_default=True,
Expand Down
31 changes: 31 additions & 0 deletions aiida_pseudo/cli/params/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from aiida.cmdline.params.types import GroupParamType
from ..utils import attempt
from ...common.units import U

__all__ = ('PseudoPotentialFamilyTypeParam', 'PseudoPotentialFamilyParam', 'PseudoPotentialTypeParam')

Expand Down Expand Up @@ -128,3 +129,33 @@ def convert(self, value, param, ctx) -> typing.Union[pathlib.Path, bytes]:
response = requests.get(value)
response.raise_for_status()
return response


class UnitParamType(click.ParamType):
"""Parameter type ."""
mbercx marked this conversation as resolved.
Show resolved Hide resolved

name = 'unit'

def __init__(self, quantity: typing.Optional[typing.List[str]] = None, **kwargs):
"""Construct the parameter.

:param quantity: The corresponding quantity of the unit.
"""
super().__init__(**kwargs)
self.quantity = quantity

def convert(self, value, _, __):
"""Check if the provided unit is a valid energy unit.

:raises: `click.BadParameter` if the provided unit is not a valid energy unit.
mbercx marked this conversation as resolved.
Show resolved Hide resolved
"""
try:
if value not in U:
raise ValueError(f'`{value}` is not a valid unit.')

if not U.Quantity(1, value).check(f'[{self.quantity}]'):
raise ValueError(f'`{value}` is not a valid `{self.quantity}` unit.')
except ValueError as exception:
raise click.BadParameter(f'{exception}') from exception
mbercx marked this conversation as resolved.
Show resolved Hide resolved

return value
54 changes: 26 additions & 28 deletions aiida_pseudo/groups/mixins/cutoffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
"""Mixin that adds support of recommended cutoffs to a ``Group`` subclass, using its extras."""
import warnings

from typing import Optional

from aiida.common.lang import type_check
from aiida.plugins import DataFactory

Expand Down Expand Up @@ -81,17 +83,23 @@ def validate_cutoffs_unit(unit: str) -> None:
if not U.Quantity(1, unit).check('[energy]'):
raise ValueError(f'`{unit}` is not a valid energy unit.')

def validate_stringency(self, stringency: str) -> None:
def validate_stringency(self, stringency: Optional[str]) -> None:
"""Validate a cutoff stringency.

Check if the stringency is defined for the family. If no stringency input is passed, the method checks if a
default stringency has been set.

:param stringency: the cutoff stringency to validate.
:raises ValueError: if stringency is None or the family does not define cutoffs for the specified stringency.
:raises ValueError: if default stringency has not been defined.
mbercx marked this conversation as resolved.
Show resolved Hide resolved
:raises ValueError: if the family does not define cutoffs for the specified stringency.
"""
if stringency is None:
mbercx marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError('defining a stringency is required.')

if stringency not in self.get_cutoff_stringencies():
raise ValueError(f'stringency `{stringency}` is not defined for this family.')
self.get_default_stringency()
elif stringency not in self.get_cutoff_stringencies():
raise ValueError(
f'stringency `{stringency}` is not one of the available cutoff stringencies for this family: '
f'{self.get_cutoff_stringencies()}.'
)

def _get_cutoffs_dict(self) -> dict:
"""Return the cutoffs dictionary that maps the stringencies to the recommended cutoffs.
Expand Down Expand Up @@ -125,12 +133,7 @@ def set_default_stringency(self, default_stringency: str) -> None:
:raises ValueError: if the provided default stringency is not in the tuple of available cutoff stringencies for
the pseudo family.
"""
if default_stringency not in self.get_cutoff_stringencies():
raise ValueError(
'provided default stringency not in tuple of available cutoff stringencies: '
f'{self.get_cutoff_stringencies()}.'
)

self.validate_stringency(default_stringency)
mbercx marked this conversation as resolved.
Show resolved Hide resolved
self.set_extra(self._key_default_stringency, default_stringency)

def get_cutoff_stringencies(self) -> tuple:
Expand Down Expand Up @@ -182,22 +185,19 @@ def set_cutoffs(self, cutoffs: dict, stringency: str, unit: str = None) -> None:
if len(cutoffs_dict) == 1:
self.set_default_stringency(stringency)

def get_cutoffs(self, stringency=None) -> dict:
def get_cutoffs(self, stringency: str = None) -> dict:
"""Return a set of cutoffs for the given stringency.

:param stringency: optional stringency for which to retrieve the cutoffs. If not specified, the default
stringency of the family is used.
:raises ValueError: if no stringency is specified and no default stringency is defined for the family.
:raises ValueError: if the requested stringency is not defined for this family.
"""
self.validate_stringency(stringency)
stringency = stringency or self.get_default_stringency()
return self._get_cutoffs_dict()[stringency]

try:
return self._get_cutoffs_dict()[stringency]
except KeyError as exception:
raise ValueError(f'stringency `{stringency}` is not defined for this family.') from exception

def delete_cutoffs(self, stringency) -> None:
def delete_cutoffs(self, stringency: str) -> None:
"""Delete the recommended cutoffs for a specified stringency.

.. note: If, after the cutoffs have been deleted, there is only one stringency defined for the pseudo family,
Expand All @@ -207,8 +207,7 @@ def delete_cutoffs(self, stringency) -> None:
:param stringency: stringency for which to delete the cutoffs.
:raises ValueError: if the requested stringency is not defined for this family.
"""
if stringency not in self.get_cutoff_stringencies():
raise ValueError(f'stringency `{stringency}` is not defined for this family.')
self.validate_stringency(stringency)

cutoffs_dict = self._get_cutoffs_dict()
cutoffs_dict.pop(stringency)
Expand Down Expand Up @@ -257,19 +256,18 @@ def get_cutoffs_unit(self, stringency: str = None) -> str:
:raises ValueError: if no stringency is specified and no default stringency is defined for the family.
:raises ValueError: if the requested stringency is not defined for this family.
"""
self.validate_stringency(stringency)
stringency = stringency or self.get_default_stringency()

try:
return self._get_cutoffs_unit_dict()[stringency]
except KeyError as exception:
except KeyError:
# Workaround to deal with pseudo families installed in v0.5.0 - Set default unit in case it is not in extras
if stringency in self.get_cutoff_stringencies():
sphuber marked this conversation as resolved.
Show resolved Hide resolved
cutoffs_unit_dict = self._get_cutoffs_unit_dict()
cutoffs_unit_dict[stringency] = 'eV'
self.set_extra(self._key_cutoffs_unit, cutoffs_unit_dict)
return 'eV'
cutoffs_unit_dict = self._get_cutoffs_unit_dict()
cutoffs_unit_dict[stringency] = 'eV'
self.set_extra(self._key_cutoffs_unit, cutoffs_unit_dict)
return 'eV'
# End of workaround
raise ValueError(f'stringency `{stringency}` is not defined for this family.') from exception

def get_recommended_cutoffs(self, *, elements=None, structure=None, stringency=None, unit=None):
"""Return tuple of recommended wavefunction and density cutoffs for the given elements or ``StructureData``.
Expand Down
50 changes: 48 additions & 2 deletions tests/cli/test_family.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_family_cutoffs_set_unit(run_cli_command, get_pseudo_family, generate_cu
result = run_cli_command(
cmd_family_cutoffs_set, [family.label, str(filepath), '-s', stringency, '-u', unit], raises=True
)
assert 'Error: Invalid value for UNIT:' in result.output
assert "Error: Invalid value for '-u' / '--unit': `GME stock` is not a valid unit." in result.output

# Correct unit
unit = 'hartree'
Expand Down Expand Up @@ -95,7 +95,7 @@ def test_family_show_recommended_cutoffs(clear_db, run_cli_command, get_pseudo_f

# Test the command prints an error for a non-existing stringency
result = run_cli_command(cmd_family_show, [family.label, '--stringency', 'non-existing'], raises=True)
assert 'Invalid value for stringency: `non-existing` is not defined' in result.output
assert "Error: Invalid value for '-s' / '--stringency': stringency `non-existing` is not" in result.output

# Test the command for default and explicit stringency
for stringency in [None, 'high']:
Expand Down Expand Up @@ -131,3 +131,49 @@ def test_family_show_raw(clear_db, run_cli_command, get_pseudo_family):
for option in ['-r', '--raw']:
result = run_cli_command(cmd_family_show, [option, family.label])
assert len(result.output_lines) == len(family.nodes)


def test_family_show_unit_default(clear_db, run_cli_command, get_pseudo_family):
"""Test the `family show` command with default unit."""
elements = ['Ar', 'Kr']
cutoff_dict = {'normal': {'Ar': {'cutoff_wfc': 50, 'cutoff_rho': 200}, 'Kr': {'cutoff_wfc': 25, 'cutoff_rho': 100}}}

family = get_pseudo_family(cls=CutoffsPseudoPotentialFamily, elements=elements, cutoffs_dict=cutoff_dict, unit='Ry')

# Test the default unit (Ry)
result = run_cli_command(cmd_family_show, [family.label])

header_fields = result.output_lines[0].split()
unit = family.get_cutoffs_unit()

assert header_fields[4] == f'({unit})'
assert header_fields[7] == f'({unit})'

for index, element in enumerate(elements):
cutoffs = family.get_recommended_cutoffs(elements=element)
fields = result.output_lines[index + 2].split()
assert_almost_equal(cutoffs[0], float(fields[3]))
assert_almost_equal(cutoffs[1], float(fields[4]))


@pytest.mark.parametrize('unit', ['Ry', 'eV', 'hartree', 'aJ'])
def test_family_show_unit(clear_db, run_cli_command, get_pseudo_family, unit):
"""Test the `-u/--unit` option."""
elements = [
'Ar',
]
cutoff_dict = {'normal': {'Ar': {'cutoff_wfc': 50, 'cutoff_rho': 200}}}

family = get_pseudo_family(cls=CutoffsPseudoPotentialFamily, elements=elements, cutoffs_dict=cutoff_dict, unit='Ry')

# Test both option strings and several units
for option in ['-u', '--unit']:
result = run_cli_command(cmd_family_show, [family.label, option, unit])
cutoffs = family.get_recommended_cutoffs(elements='Ar', unit=unit)
header_fields = result.output_lines[0].split()
assert header_fields[4] == f'({unit})'
assert header_fields[4] == f'({unit})'

values_fields = result.output_lines[2].split()
assert round(cutoffs[0], 1) == float(values_fields[3])
assert round(cutoffs[1], 1) == float(values_fields[4])
10 changes: 5 additions & 5 deletions tests/groups/mixins/test_cutoffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ def test_validate_stringency(get_pseudo_family, generate_cutoffs):
"""Test the ``CutoffsPseudoPotentialFamily.validate_stringency`` method."""
family = get_pseudo_family(cls=CutoffsPseudoPotentialFamily)

with pytest.raises(ValueError, match=r'stringency `.*` is not defined for this family.'):
with pytest.raises(ValueError, match=r'stringency `.*` is not one of the available cutoff stringencies for this'):
family.validate_stringency('default')

cutoffs = generate_cutoffs(family)
stringency = 'default'
family.set_cutoffs(cutoffs, stringency)

with pytest.raises(ValueError, match=r'stringency `.*` is not defined for this family.'):
with pytest.raises(ValueError, match=r'stringency `.*` is not one of the available cutoff stringencies for this'):
family.validate_stringency(stringency + 'non-existing')

family.validate_stringency(stringency)
Expand Down Expand Up @@ -91,7 +91,7 @@ def test_set_default_stringency(get_pseudo_family, generate_cutoffs_dict):

assert family.get_default_stringency() == 'low'

with pytest.raises(ValueError, match='provided default stringency not in tuple of available cutoff stringencies:'):
with pytest.raises(ValueError, match=r'stringency `nonexistent` is not one of the available cutoff stringencies'):
family.set_default_stringency('nonexistent')

family.set_default_stringency('normal')
Expand Down Expand Up @@ -215,7 +215,7 @@ def test_get_cutoffs(get_pseudo_family, generate_cutoffs):

family.set_cutoffs(cutoffs, stringency)

with pytest.raises(ValueError, match=r'stringency `.*` is not defined for this family.'):
with pytest.raises(ValueError, match=r'stringency `.*` is not one of the available cutoff stringencies for this'):
family.get_cutoffs('non-existing')

assert family.get_cutoffs() == cutoffs
Expand Down Expand Up @@ -314,7 +314,7 @@ def test_delete_cutoffs(get_pseudo_family, generate_cutoffs_dict):
for stringency, cutoffs in generate_cutoffs_dict(family, stringencies).items():
family.set_cutoffs(cutoffs, stringency)

with pytest.raises(ValueError, match='stringency `nonexistent` is not defined for this family.'):
with pytest.raises(ValueError, match=r'stringency `nonexistent` is not one of the available cutoff stringencies'):
family.delete_cutoffs('nonexistent')

with pytest.warns(UserWarning, match='`low` was the default stringency of this family. Please set'):
Expand Down