Skip to content

Commit

Permalink
PseudoPotentialFamily: override remove_nodes and clear (#29)
Browse files Browse the repository at this point in the history
These methods need to be overridden, such that after they have called
the implementation of the `Group` base class, the internal `_pseudos`
class attribute can be updated accordingly. For `clear` this is easy as
it is simply set to `None`, but for `remove_nodes` special care needs to
be taken. It needs to account for the fact that a single node or an
iterable of nodes can be passed as an argument, and that the base class
implementation will not raise if any of the nodes is not actually
contained in the group. So instead of doing the straight forward thing
of removing the nodes from `_pseudos` based on the elements, we remove
them based on the pk, which guarantees we are removing the correct nodes.
  • Loading branch information
sphuber authored Dec 7, 2020
1 parent d02286d commit e2dfb82
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 0 deletions.
18 changes: 18 additions & 0 deletions aiida_pseudo/groups/family/pseudo.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,24 @@ def add_nodes(self, nodes):

super().add_nodes(nodes)

def remove_nodes(self, nodes):
"""Remove a pseudopotential or a set of pseudopotentials from the family.
:param nodes: a single or list of ``PseudoPotentialData`` instances or subclasses thereof.
"""
super().remove_nodes(nodes)

if not isinstance(nodes, (list, tuple)):
nodes = (nodes,)

removed = [node.pk for node in nodes]
self._pseudos = {pseudo.element: pseudo for pseudo in self.pseudos.values() if pseudo.pk not in removed}

def clear(self):
"""Remove all the pseudopotentials from this family."""
super().clear()
self._pseudos = None

@property
def pseudos(self):
"""Return the dictionary of pseudo potentials of this family indexed on the element symbol.
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ max-line-length = 120

[tool.pylint.messages_control]
disable = [
'bad-continuation',
'duplicate-code',
'import-outside-toplevel',
]
45 changes: 45 additions & 0 deletions tests/groups/family/test_pseudo.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,51 @@ def test_add_nodes_duplicate_element(get_pseudo_family, get_pseudo_potential_dat
family.add_nodes(pseudo)


@pytest.mark.usefixtures('clear_db')
def test_remove_nodes(get_pseudo_family):
"""Test the ``PseudoPotentialFamily.remove_nodes`` method."""
elements = ('Ar', 'He', 'Kr')
family = get_pseudo_family(elements=elements)
pseudos = family.get_pseudos(elements=elements)

# Removing a single node
pseudo = pseudos.pop('Ar')
family.remove_nodes(pseudo)
assert family.pseudos == pseudos

# Removing multiple nodes
family.remove_nodes(list(pseudos.values()))
assert family.pseudos == {}


@pytest.mark.usefixtures('clear_db')
def test_remove_nodes_not_existing(get_pseudo_family, get_pseudo_potential_data):
"""Test the ``PseudoPotentialFamily.remove_nodes`` method works even when passing a non-existing pseudo.
The implementation of the ``remove_nodes`` method of the ``Group`` base class does not raise when passing a node
that is not contained within the group but will silently ignore it. Make sure that the corresponding element is not
accidentally still removed by the ``PseudoPotentialFamily.remove_nodes`` implementation.
"""
element = 'Ar'
family = get_pseudo_family(elements=(element,))
pseudo = get_pseudo_potential_data(element).store()

# The node ``pseudo`` is not actually contained within the family and so no pseudopotentials should be removed
family.remove_nodes(pseudo)
assert tuple(family.pseudos.keys()) == ('Ar',)


@pytest.mark.usefixtures('clear_db')
def test_clear(get_pseudo_family):
"""Test the ``PseudoPotentialFamily.clear`` method."""
family = get_pseudo_family(elements=('Ar', 'He', 'Kr'))
assert family.pseudos is not None

family.clear()
assert family.pseudos == {}
assert family.count() == 0


@pytest.mark.usefixtures('clear_db')
def test_pseudos(get_pseudo_potential_data):
"""Test the `PseudoPotentialFamily.pseudos` property."""
Expand Down

0 comments on commit e2dfb82

Please sign in to comment.