diff --git a/aiida/orm/nodes/data/array/trajectory.py b/aiida/orm/nodes/data/array/trajectory.py index 70f29c252f..d1c0a7585e 100644 --- a/aiida/orm/nodes/data/array/trajectory.py +++ b/aiida/orm/nodes/data/array/trajectory.py @@ -10,7 +10,6 @@ """ AiiDA class to deal with crystal structure trajectories. """ - import collections.abc from .array import ArrayData @@ -319,14 +318,14 @@ def get_index_from_stepid(self, stepid): raise ValueError(f'{stepid} not among the stepids') def get_step_data(self, index): - r""" - Return a tuple with all information concerning - the stepid with given index (0 is the first step, 1 the second step - and so on). If you know only the step value, use the - :py:meth:`.get_index_from_stepid` method to get the - corresponding index. + """ + Return a tuple with all information concerning the stepid with given + index (0 is the first step, 1 the second step and so on). If you know + only the step value, use the :py:meth:`.get_index_from_stepid` method + to get the corresponding index. - If no velocities were specified, None is returned as the last element. + If no velocities, cells, or times were specified, None is returned as + the corresponding element. :return: A tuple in the format ``(stepid, time, cell, symbols, positions, velocities)``, @@ -352,6 +351,8 @@ def get_step_data(self, index): cells = self.get_cells() if cells is not None: cell = cells[index, :, :] + else: + cell = None return (self.get_stepids()[index], time, cell, self.symbols, self.get_positions()[index, :, :], vel) def get_step_structure(self, index, custom_kinds=None): @@ -376,6 +377,8 @@ def get_step_structure(self, index, custom_kinds=None): :py:class:`aiida.orm.nodes.data.structure.StructureData` nodes is used, meaning that the strings in the ``symbols`` array must be valid chemical symbols. + + :return: :py:class:`aiida.orm.nodes.data.structure.StructureData` node. """ from aiida.orm.nodes.data.structure import Kind, Site, StructureData @@ -477,9 +480,23 @@ def get_structure(self, store=False, **kwargs): .. versionadded:: 1.0 Renamed from _get_aiida_structure - :param converter: specify the converter. Default 'ase'. :param store: If True, intermediate calculation gets stored in the AiiDA database for record. Default False. + :param index: The index of the step that you want to retrieve, from + 0 to ``self.numsteps- 1``. + :param custom_kinds: (Optional) If passed must be a list of + :py:class:`aiida.orm.nodes.data.structure.Kind` objects. There must be one + kind object for each different string in the ``symbols`` array, with + ``kind.name`` set to this string. + If this parameter is omitted, the automatic kind generation of AiiDA + :py:class:`aiida.orm.nodes.data.structure.StructureData` nodes is used, + meaning that the strings in the ``symbols`` array must be valid + chemical symbols. + :param custom_cell: (Optional) The cell matrix of the structure. + If omitted, the cell will be read from the trajectory, if present, + otherwise the default cell of + :py:class:`aiida.orm.nodes.data.structure.StructureData` will be used. + :return: :py:class:`aiida.orm.nodes.data.structure.StructureData` node. """ from aiida.orm.nodes.data.dict import Dict diff --git a/aiida/tools/data/array/trajectory.py b/aiida/tools/data/array/trajectory.py index 62cbd4fe9d..73b3e48a96 100644 --- a/aiida/tools/data/array/trajectory.py +++ b/aiida/tools/data/array/trajectory.py @@ -15,9 +15,11 @@ @calcfunction def _get_aiida_structure_inline(trajectory, parameters): """ - Creates :py:class:`aiida.orm.nodes.data.structure.StructureData` using ASE. + CalcFunction to extract a :py:class:`aiida.orm.nodes.data.structure.StructureData` + from a `TrajectoryData`. - .. note:: requires ASE module. + :param parameters: A dictionary whose key-value pairs are passed as + additional kwargs to the :py:meth:``TrajectoryData.get_step_structure`` method. """ kwargs = {} if parameters is not None: diff --git a/tests/orm/nodes/data/test_trajectory.py b/tests/orm/nodes/data/test_trajectory.py index 7b55e93878..100b3bb1b4 100644 --- a/tests/orm/nodes/data/test_trajectory.py +++ b/tests/orm/nodes/data/test_trajectory.py @@ -1,20 +1,39 @@ # -*- coding: utf-8 -*- -# pylint: disable=no-self-use +# pylint: disable=no-self-use,redefined-outer-name """Tests for the `TrajectoryData` class.""" import numpy as np import pytest -from aiida.orm import TrajectoryData, load_node +from aiida.orm import StructureData, TrajectoryData, load_node + + +@pytest.fixture +def trajectory_data(): + """Return a dictionary of data to create a ``TrajectoryData``.""" + symbols = ['H'] * 5 + ['Cl'] * 5 + stepids = np.arange(1000, 3000, 10) + times = stepids * 0.01 + positions = np.arange(6000, dtype=float).reshape((200, 10, 3)) + velocities = -np.arange(6000, dtype=float).reshape((200, 10, 3)) + cell = [[[3., 0.1, 0.3], [-0.05, 3., -0.2], [0.02, -0.08, 3.]]] + cells = np.array(cell * 200) + np.arange(0, 0.2, 0.001)[:, np.newaxis, np.newaxis] + return { + 'symbols': symbols, + 'positions': positions, + 'stepids': stepids, + 'cells': cells, + 'times': times, + 'velocities': velocities + } class TestTrajectory: """Test for the `TrajectoryData` class.""" @pytest.fixture(autouse=True) - def init_profile(self, aiida_profile_clean, aiida_localhost): # pylint: disable=unused-argument + def init_profile(self, aiida_profile_clean): # pylint: disable=unused-argument """Initialize the profile.""" # pylint: disable=attribute-defined-outside-init - n_atoms = 5 n_steps = 30 @@ -74,3 +93,54 @@ def test_units(self): tjd2 = load_node(tjd.pk) assert tjd2.base.attributes.get('units|positions') == 'some_random_pos_unit' assert tjd2.base.attributes.get('units|times') == 'some_random_time_unit' + + def test_trajectory_get_index_from_stepid(self, trajectory_data): + """Test the ``get_index_from_stepid`` method.""" + trajectory = TrajectoryData() + trajectory.set_trajectory(**trajectory_data) + assert trajectory.get_index_from_stepid(1050) == 5 + + with pytest.raises(ValueError): + trajectory.get_index_from_stepid(2333) + + def test_trajectory_get_step_data(self, trajectory_data): + """Test the ``get_step_data`` method.""" + trajectory = TrajectoryData() + trajectory.set_trajectory(**trajectory_data) + stepid, time, cell, symbols, positions, velocities = trajectory.get_step_data(-2) + assert stepid == trajectory_data['stepids'][-2] + assert time == trajectory_data['times'][-2] + np.array_equal(cell, trajectory_data['cells'][-2, :, :]) + np.array_equal(symbols, trajectory_data['symbols']) + np.array_equal(positions, trajectory_data['positions'][-2, :, :]) + np.array_equal(velocities, trajectory_data['velocities'][-2, :, :]) + + def test_trajectory_get_step_data_empty(self, trajectory_data): + """Test the `get_step_data` method when some arrays are not defined.""" + trajectory = TrajectoryData() + trajectory.set_trajectory(symbols=trajectory_data['symbols'], positions=trajectory_data['positions']) + stepid, time, cell, symbols, positions, velocities = trajectory.get_step_data(3) + assert stepid == 3 + assert time is None + assert cell is None + assert np.array_equal(symbols, trajectory_data['symbols']) + assert np.array_equal(positions, trajectory_data['positions'][3, :, :]) + assert velocities is None + + def test_trajectory_get_step_structure(self, trajectory_data): + """Test the `get_step_structure` method.""" + trajectory = TrajectoryData() + trajectory.set_trajectory(**trajectory_data) + structure = trajectory.get_step_structure(50) + + expected = StructureData() + expected.cell = trajectory_data['cells'][50] + for symbol, position in zip(trajectory_data['symbols'], trajectory_data['positions'][50, :, :]): + expected.append_atom(symbols=symbol, position=position) + + structure.store() + expected.store() + assert structure.get_hash() == expected.get_hash() + + with pytest.raises(IndexError): + trajectory.get_step_structure(500)