Skip to content

Commit

Permalink
PseudoPotentialData: add the get_entry_point_name classmethod (#30)
Browse files Browse the repository at this point in the history
The method will return the entry point name under which the subclass of
`PseudoPotentialData` is registered. Note that the entry point name does
not include the entry point group. For example, the `UpfData` plugin is
registered under `aiida.data:pseudo.upf`, but the classmethod will
return the string `pseudo.upf`.
  • Loading branch information
sphuber authored Dec 7, 2020
1 parent e2dfb82 commit beb85e6
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
10 changes: 10 additions & 0 deletions aiida_pseudo/data/pseudo/pseudo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ class PseudoPotentialData(SingleFileData):
_key_element = 'element'
_key_md5 = 'md5'

@classmethod
def get_entry_point_name(cls):
"""Return the entry point name associated with this data class.
:return: the entry point name.
"""
from aiida.plugins.entry_point import get_entry_point_from_class
_, entry_point = get_entry_point_from_class(cls.__module__, cls.__name__)
return entry_point.name

@classmethod
def validate_element(cls, element: str):
"""Validate the given element symbol.
Expand Down
30 changes: 30 additions & 0 deletions tests/data/pseudo/test_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
# pylint: disable=redefined-outer-name
"""Tests that are common to all data plugins in the :py:mod:`~aiida_pseudo.data.pseudo` module."""
import pytest

from aiida import plugins


def get_entry_point_names():
"""Return the registered entry point names for the given common workflow.
:param workflow: the name of the common workflow.
:param leaf: if True, only return the leaf of the entry point name, i.e., the name of plugin that implements it.
:return: list of entry points names.
"""
prefix = 'pseudo.'
entry_points_names = plugins.entry_point.get_entry_point_names('aiida.data')
return [name for name in entry_points_names if name.startswith(prefix)]


@pytest.fixture(scope='function', params=get_entry_point_names())
def entry_point_name(request):
"""Fixture that parametrizes over all the registered subclass implementations of ``PseudoPotentialData``."""
return request.param


def test_get_entry_point_name(entry_point_name):
"""Test the ``PseudoPotentialData.get_entry_point_name`` method."""
cls = plugins.DataFactory(entry_point_name)
assert cls.get_entry_point_name() == entry_point_name

0 comments on commit beb85e6

Please sign in to comment.