Skip to content

Commit

Permalink
PseudoPotentialFamily: add the pseudo_type extra
Browse files Browse the repository at this point in the history
Since pseudopotential families can now optionally support many
pseudopotential formats, although each instance only supports one format
at a time, it should be easy to determine which format any particular
family hosts.

The `pseudo_type` property is added to `PseudoPotentialFamily` which
returns the entry point name of the `PseudoPotentialData` subclass that
is used by the pseudopotentials that it contains. For example, if the
family hosts only `UpfData` nodes, whose entry point string is defined
as `aiida.data:pseudo.upf`, it will return `pseudo.upf`. Note that the
entry point group `aiida.data` is stripped as it is redundant
information.

The `pseudo_type` is stored as an extra on the family instance and is
updated after the node contents are mutated, i.e. after `add_nodes`,
`remove_nodes` and `clear. Having this information stored as an extra
and not calculated each time on the fly, not only helps efficiency but
it also makes it queryable, which is important if one needs to find a
family with pseudos of a particular type.

The class does not provide a `pseudo_type` setter as this only ever
needs to be set by the class itself. Users can of course always change
the value through the `set_extra` method of the `Group` class, causing
inconsistency in the data, but there is nothing one can do to protect
against this.
  • Loading branch information
sphuber committed Dec 7, 2020
1 parent 558be1e commit 1ca1bc4
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
24 changes: 24 additions & 0 deletions aiida_pseudo/groups/family/pseudo.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class PseudoPotentialFamily(Group):
hosted by setting ``_pseudo_types`` to a tuple of ``PseudoPotentialData`` subclasses.
"""

_key_pseudo_type = '_pseudo_type'
_pseudo_types = (PseudoPotentialData,)
_pseudos = None

Expand Down Expand Up @@ -174,6 +175,26 @@ def create_from_folder(cls, dirpath, label, *, description='', pseudo_type=None,

return family

@property
def pseudo_type(self):
"""Return the type of the pseudopotentials that are hosted by this family.
:return: the pseudopotential type or ``None`` if none has been set yet.
"""
return self.get_extra(self._key_pseudo_type, None)

def update_pseudo_type(self):
"""Update the pseudo type, stored as an extra, based on the current nodes in the family."""
pseudo_types = {pseudo.__class__ for pseudo in self.pseudos.values()}

if pseudo_types:
assert len(pseudo_types) == 1, 'Family contains pseudopotential data nodes of various types.'
entry_point_name = tuple(pseudo_types)[0].get_entry_point_name()
else:
entry_point_name = None

self.set_extra(self._key_pseudo_type, entry_point_name)

def add_nodes(self, nodes):
"""Add a node or a set of nodes to the family.
Expand Down Expand Up @@ -203,6 +224,7 @@ def add_nodes(self, nodes):
pseudos[pseudo.element] = pseudo

self.pseudos.update(pseudos)
self.update_pseudo_type()

super().add_nodes(nodes)

Expand All @@ -218,11 +240,13 @@ def remove_nodes(self, nodes):

removed = [node.pk for node in nodes]
self._pseudos = {pseudo.element: pseudo for pseudo in self.pseudos.values() if pseudo.pk not in removed}
self.update_pseudo_type()

def clear(self):
"""Remove all the pseudopotentials from this family."""
super().clear()
self._pseudos = None
self.update_pseudo_type()

@property
def pseudos(self):
Expand Down
20 changes: 20 additions & 0 deletions tests/groups/family/test_pseudo.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,26 @@ class CustomFamily(PseudoPotentialFamily):
CustomFamily(label='custom')


@pytest.mark.usefixtures('clear_db')
def test_pseudo_type(get_pseudo_potential_data):
"""Test ``PseudoPotentialFamily.pseudo_type`` property."""
family = PseudoPotentialFamily(label='label').store()
assert family.pseudo_type is None

pseudo = get_pseudo_potential_data('Ar').store()
family.add_nodes((pseudo,))
assert family.pseudo_type == pseudo.get_entry_point_name()

family.clear()
assert family.pseudo_type is None

family.add_nodes((pseudo,))
assert family.pseudo_type == pseudo.get_entry_point_name()

family.remove_nodes(pseudo)
assert family.pseudo_type is None


@pytest.mark.usefixtures('clear_db')
def test_construct():
"""Test the construction of `PseudoPotentialFamily` works."""
Expand Down

0 comments on commit 1ca1bc4

Please sign in to comment.