Skip to content

Commit

Permalink
PseudoPotentialFamily: allow support of multiple pseudo formats
Browse files Browse the repository at this point in the history
Originally, each pseudopotential family could only specify a single
supported pseudopotential type, meaning a single subclass of the
`PseudoPotentialData` base class. However, this policy forced one to
create a new family subclass for each pseudopotential type that one
might want to create a family off. Semantically, this does not even make
sense, because the pseudopotential types are merely describing the file
format with which the pseudopotential is written to file. In principle,
any pseudo could be transformed from any format to any other format. A
family with a certain set of pseudopotentials in format A, could be
represented by the same pseudos but then in format B, without changing
the informational content. Forcing a different pseudo family class to be
used is nonsensical and unpractical.

The `PseudoPotentialFamily` class is updated such that the `_pseudo_type`
class attribute is renamed to `_pseudo_types` and now takes a tuple of
`PseudoPotentialData` subclasses. An instance of a family can be created
where the pseudopotentials are in any one of these formats. Note that a
family class can support multiple pseudopotential types, but any one
instance can only host pseudopotentials of a single type.
  • Loading branch information
sphuber committed Dec 7, 2020
1 parent 56e8f53 commit 558be1e
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 67 deletions.
116 changes: 70 additions & 46 deletions aiida_pseudo/groups/family/pseudo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
"""Subclass of `Group` that serves as a base class for representing pseudo potential families."""
"""Subclass of ``Group`` that serves as a base class for representing pseudo potential families."""
import os
import re
from typing import Union, List, Tuple, Mapping
Expand All @@ -20,11 +20,12 @@ class PseudoPotentialFamily(Group):
"""Group to represent a pseudo potential family.
This is a base class that provides most of the functionality but does not actually define what type of pseudo
potentials can be contained. Subclasses should define the `_pseudo_type` class attribute to the data type of the
pseudo potentials that are accepted. This *has* to be a subclass of `PseudoPotentialData`.
potentials can be contained. If ``_pseudo_types`` is not defined, any pseudo potential type is accepted in this
family, as long as it is a subclass of ``PseudoPotentialData``. Subclasses can limit which pseudo types can be
hosted by setting ``_pseudo_types`` to a tuple of ``PseudoPotentialData`` subclasses.
"""

_pseudo_type = PseudoPotentialData
_pseudo_types = (PseudoPotentialData,)
_pseudos = None

def __repr__(self):
Expand All @@ -36,38 +37,53 @@ def __str__(self):
return f'{self.__class__.__name__}<{self.label}>'

def __init__(self, *args, **kwargs):
"""Validate that the `_pseudo_type` class attribute is a subclass of `PseudoPotentialData`."""
if not issubclass(self._pseudo_type, PseudoPotentialData):
class_name = self._pseudo_type.__class__.__name__
raise RuntimeError(f'`{class_name}` is not a subclass of `PseudoPotentialData`.')
"""Validate that the ``_pseudo_types`` class attribute is a tuple of ``PseudoPotentialData`` subclasses."""
if not self._pseudo_types or not isinstance(self._pseudo_types, tuple) or any(
not issubclass(pseudo_type, PseudoPotentialData) for pseudo_type in self._pseudo_types
):
raise RuntimeError('`_pseudo_types` should be a tuple of `PseudoPotentialData` subclasses.')

super().__init__(*args, **kwargs)

@classproperty
def pseudo_type(cls): # pylint: disable=no-self-argument
"""Return the pseudo potential type that this family accepts.
def pseudo_types(cls): # pylint: disable=no-self-argument
"""Return the pseudo potential types that this family accepts.
:return: the subclass of ``PseudoPotentialData`` that this family hosts nodes of.
:return: the tuple of subclasses of ``PseudoPotentialData`` that this family can host nodes of. If it returns
``None``, that means all subclasses are supported.
"""
return cls._pseudo_type
return cls._pseudo_types

@classmethod
def parse_pseudos_from_directory(cls, dirpath):
def parse_pseudos_from_directory(cls, dirpath, pseudo_type=None):
"""Parse the pseudo potential files in the given directory into a list of data nodes.
.. note:: The directory pointed to by `dirpath` should only contain pseudo potential files. Optionally, it can
contain just a single directory, that contains all the pseudo potential files. If any other files are stored
in the basepath or the subdirectory, that cannot be successfully parsed as pseudo potential files the method
will raise a `ValueError`.
will raise a ``ValueError``.
:param dirpath: absolute path to a directory containing pseudo potentials.
:param pseudo_type: subclass of ``PseudoPotentialData`` to be used for the parsed pseudos. If not specified and
the family only defines a single supported pseudo type in ``_pseudo_types`` then that will be used otherwise
a ``ValueError`` is raised.
:return: list of data nodes
:raises ValueError: if `dirpath` is not a directory or contains anything other than files.
:raises ValueError: if `dirpath` contains multiple pseudo potentials for the same element.
:raises ParsingError: if the constructor of the pseudo type fails for one of the files in the `dirpath`.
:raises ValueError: if ``dirpath`` is not a directory or contains anything other than files.
:raises ValueError: if ``dirpath`` contains multiple pseudo potentials for the same element.
:raises ValueError: if ``pseudo_type`` is explicitly specified and is not supported by this family class.
:raises ValueError: if ``pseudo_type`` is not specified and the class supports more than one pseudo type.
:raises ParsingError: if the constructor of the pseudo type fails for one of the files in the ``dirpath``.
"""
from aiida.common.exceptions import ParsingError

if pseudo_type is None and len(cls._pseudo_types) > 1:
raise ValueError(f'`{cls}` supports more than one type, so `pseudo_type` needs to be explicitly passed.')

pseudo_type = pseudo_type or cls._pseudo_types[0]

if all(not issubclass(pseudo_type, supported_type) for supported_type in cls._pseudo_types):
raise ValueError(f'`{pseudo_type}` is not supported by `{cls}`.')

pseudos = []

if not os.path.isdir(dirpath):
Expand All @@ -84,22 +100,22 @@ def parse_pseudos_from_directory(cls, dirpath):
if not os.path.isfile(filepath):
raise ValueError(f'dirpath `{dirpath}` contains at least one entry that is not a file')

try:
with open(filepath, 'rb') as handle:
pseudo = cls._pseudo_type(handle, filename=filename)
except ParsingError as exception:
raise ParsingError(f'failed to parse `{filepath}`: {exception}') from exception
else:
if pseudo.element is None:
match = re.search(r'^([A-Za-z]{1,2})\.\w+', filename)
if match is None:
raise ParsingError(
f'`{cls._pseudo_type}` constructor did not define the element and could not parse a valid '
'element symbol from the filename `{filename}` either. It should have the format '
'`ELEMENT.EXTENSION`'
)
pseudo.element = match.group(1)
pseudos.append(pseudo)
with open(filepath, 'rb') as handle:
try:
pseudo = pseudo_type(handle, filename=filename)
except ParsingError as exception:
raise ParsingError(f'failed to parse `{filepath}`: {exception}') from exception

if pseudo.element is None:
match = re.search(r'^([A-Za-z]{1,2})\.\w+', filename)
if match is None:
raise ParsingError(
f'`{pseudo.__class__}` constructor did not define the element and could not parse a valid '
'element symbol from the filename `{filename}` either. It should have the format '
'`ELEMENT.EXTENSION`'
)
pseudo.element = match.group(1)
pseudos.append(pseudo)

if not pseudos:
raise ValueError(f'no pseudo potentials were parsed from `{dirpath}`')
Expand All @@ -112,15 +128,23 @@ def parse_pseudos_from_directory(cls, dirpath):
return pseudos

@classmethod
def create_from_folder(cls, dirpath, label, description='', deduplicate=True):
"""Create a new `PseudoPotentialFamily` from the pseudo potentials contained in a directory.
def create_from_folder(cls, dirpath, label, *, description='', pseudo_type=None, deduplicate=True):
"""Create a new ``PseudoPotentialFamily`` from the pseudo potentials contained in a directory.
:param dirpath: absolute path to the folder containing the UPF files.
:param label: label to give to the `PseudoPotentialFamily`, should not already exist.
:param label: label to give to the ``PseudoPotentialFamily``, should not already exist.
:param description: description to give to the family.
:param pseudo_type: subclass of ``PseudoPotentialData`` to be used for the parsed pseudos. If not specified and
the family only defines a single supported pseudo type in ``_pseudo_types`` then that will be used otherwise
a ``ValueError`` is raised.
:param deduplicate: if True, will scan database for existing pseudo potentials of same type and with the same
md5 checksum, and use that instead of the parsed one.
:raises ValueError: if a `PseudoPotentialFamily` already exists with the given name.
:raises ValueError: if a ``PseudoPotentialFamily`` already exists with the given name.
:raises ValueError: if ``dirpath`` is not a directory or contains anything other than files.
:raises ValueError: if ``dirpath`` contains multiple pseudo potentials for the same element.
:raises ValueError: if ``pseudo_type`` is explicitly specified and is not supported by this family class.
:raises ValueError: if ``pseudo_type`` is not specified and the class supports more than one pseudo type.
:raises ParsingError: if the constructor of the pseudo type fails for one of the files in the ``dirpath``.
"""
type_check(description, str, allow_none=True)

Expand All @@ -131,19 +155,19 @@ def create_from_folder(cls, dirpath, label, description='', deduplicate=True):
else:
raise ValueError(f'the {cls.__name__} `{label}` already exists')

parsed_pseudos = cls.parse_pseudos_from_directory(dirpath)
parsed_pseudos = cls.parse_pseudos_from_directory(dirpath, pseudo_type)
family_pseudos = []

for pseudo in parsed_pseudos:
if deduplicate:
query = QueryBuilder()
query.append(cls.pseudo_type, subclassing=False, filters={'attributes.md5': pseudo.md5})
query.append(pseudo.__class__, subclassing=False, filters={'attributes.md5': pseudo.md5})
existing = query.first()
if existing:
pseudo = existing[0]
family_pseudos.append(pseudo)

# Only store the `Group` and the pseudo nodes now, such that we don't have to worry about the clean up in the
# Only store the ``Group`` and the pseudo nodes now, such that we don't have to worry about the clean up in the
# case that an exception is raised during creating them.
family.store()
family.add_nodes([pseudo.store() for pseudo in family_pseudos])
Expand All @@ -155,10 +179,10 @@ def add_nodes(self, nodes):
.. note: Each family instance can only contain a single pseudo potential for each element.
:param nodes: a single `Node` or a list of `Nodes` of type `PseudoPotentialFamily._pseudo_type`. Note that
subclasses of `_pseudo_type` are not accepted, only instances of that very type.
:param nodes: a single or list of ``Node`` instances of type that is in ``PseudoPotentialFamily._pseudo_types``.
:raises ModificationNotAllowed: if the family is not stored.
:raises TypeError: if nodes are not an instance or list of instance of `PseudoPotentialFamily._pseudo_type`.
:raises TypeError: if nodes are not an instance or list of instance of any of the classes listed by
``PseudoPotentialFamily._pseudo_types``.
:raises ValueError: if any of the nodes are not stored or their elements already exist in this family.
"""
if not self.is_stored:
Expand All @@ -167,8 +191,8 @@ def add_nodes(self, nodes):
if not isinstance(nodes, (list, tuple)):
nodes = [nodes]

if any([type(node) is not self._pseudo_type for node in nodes]): # pylint: disable=unidiomatic-typecheck
raise TypeError(f'only nodes of type `{self._pseudo_type}` can be added: {nodes}')
if any(not isinstance(node, self._pseudo_types) for node in nodes):
raise TypeError(f'only nodes of types `{self._pseudo_types}` can be added: {nodes}')

pseudos = {}

Expand Down Expand Up @@ -231,7 +255,7 @@ def get_pseudo(self, element):
except KeyError:
builder = QueryBuilder()
builder.append(self.__class__, filters={'id': self.pk}, tag='group')
builder.append(self._pseudo_type, filters={'attributes.element': element}, with_group='group')
builder.append(self._pseudo_types, filters={'attributes.element': element}, with_group='group')

try:
pseudo = builder.one()[0]
Expand Down
2 changes: 1 addition & 1 deletion aiida_pseudo/groups/family/sssp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class SsspFamily(RecommendedCutoffMixin, PseudoPotentialFamily):
to contain the pseudo potentials and corresponding metadata of an official SSSP configuration.
"""

_pseudo_type = UpfData
_pseudo_types = (UpfData,)

label_template = 'SSSP/{version}/{functional}/{protocol}'
default_configuration = SsspConfiguration('1.1', 'PBE', 'efficiency')
Expand Down
3 changes: 2 additions & 1 deletion tests/cli/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""Tests for the command `aiida-pseudo list`."""
from aiida_pseudo.cli import cmd_list
from aiida_pseudo.cli.list import PROJECTIONS_VALID
from aiida_pseudo.data.pseudo import UpfData
from aiida_pseudo.groups.family import PseudoPotentialFamily, SsspFamily


Expand Down Expand Up @@ -43,7 +44,7 @@ def test_list_project(clear_db, run_cli_command, get_pseudo_family):
def test_list_filter(clear_db, run_cli_command, get_pseudo_family):
"""Test the filtering option `-T`."""
family_base = get_pseudo_family(label='Pseudo potential family', cls=PseudoPotentialFamily)
family_sssp = get_pseudo_family(label='SSSP/1.0/PBE/efficiency', cls=SsspFamily)
family_sssp = get_pseudo_family(label='SSSP/1.0/PBE/efficiency', cls=SsspFamily, pseudo_type=UpfData)

assert PseudoPotentialFamily.objects.count() == 2

Expand Down
18 changes: 15 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,12 @@ def _get_pseudo_potential_data(element='Ar', entry_point=None) -> PseudoPotentia
def get_pseudo_family(tmpdir, filepath_pseudos):
"""Return a factory for a `PseudoPotentialFamily` instance."""

def _get_pseudo_family(label='family', cls=PseudoPotentialFamily, elements=None) -> PseudoPotentialFamily:
def _get_pseudo_family(
label='family',
cls=PseudoPotentialFamily,
pseudo_type=PseudoPotentialData,
elements=None
) -> PseudoPotentialFamily:
"""Return an instance of `PseudoPotentialFamily` or subclass containing the given elements.
:param elements: optional list of elements to include instead of all the available ones
Expand All @@ -128,12 +133,19 @@ def _get_pseudo_family(label='family', cls=PseudoPotentialFamily, elements=None)
if elements is not None:
elements = {re.sub('[0-9]+', '', element) for element in elements}

dirpath = filepath_pseudos('upf')
if pseudo_type is PseudoPotentialData:
# There is no actual pseudopotential file fixtures for the base class, so default back to `.upf` files
extension = 'upf'
else:
extension = pseudo_type.get_entry_point_name()[len('pseudo.'):]

dirpath = filepath_pseudos(extension)

for pseudo in os.listdir(dirpath):
if elements is None or any([pseudo.startswith(element) for element in elements]):
shutil.copyfile(os.path.join(dirpath, pseudo), os.path.join(str(tmpdir), pseudo))

return cls.create_from_folder(str(tmpdir), label)
return cls.create_from_folder(str(tmpdir), label, pseudo_type=pseudo_type)

return _get_pseudo_family

Expand Down
Loading

0 comments on commit 558be1e

Please sign in to comment.