Skip to content

Commit

Permalink
TrajectoryData: Fix bug in get_step_data
Browse files Browse the repository at this point in the history
The `get_step_data` would raise a `NameError` in case a `TrajectoryData`
does not define the `cells` array. The variable `cell` would be returned
but it would only be defined if `self.get_cells()` does not return
`None`.
  • Loading branch information
lorisercole authored and sphuber committed Nov 1, 2022
1 parent 798ded5 commit a869c18
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 15 deletions.
35 changes: 26 additions & 9 deletions aiida/orm/nodes/data/array/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
"""
AiiDA class to deal with crystal structure trajectories.
"""

import collections.abc

from .array import ArrayData
Expand Down Expand Up @@ -319,14 +318,14 @@ def get_index_from_stepid(self, stepid):
raise ValueError(f'{stepid} not among the stepids')

def get_step_data(self, index):
r"""
Return a tuple with all information concerning
the stepid with given index (0 is the first step, 1 the second step
and so on). If you know only the step value, use the
:py:meth:`.get_index_from_stepid` method to get the
corresponding index.
"""
Return a tuple with all information concerning the stepid with given
index (0 is the first step, 1 the second step and so on). If you know
only the step value, use the :py:meth:`.get_index_from_stepid` method
to get the corresponding index.
If no velocities were specified, None is returned as the last element.
If no velocities, cells, or times were specified, None is returned as
the corresponding element.
:return: A tuple in the format
``(stepid, time, cell, symbols, positions, velocities)``,
Expand All @@ -352,6 +351,8 @@ def get_step_data(self, index):
cells = self.get_cells()
if cells is not None:
cell = cells[index, :, :]
else:
cell = None
return (self.get_stepids()[index], time, cell, self.symbols, self.get_positions()[index, :, :], vel)

def get_step_structure(self, index, custom_kinds=None):
Expand All @@ -376,6 +377,8 @@ def get_step_structure(self, index, custom_kinds=None):
:py:class:`aiida.orm.nodes.data.structure.StructureData` nodes is used,
meaning that the strings in the ``symbols`` array must be valid
chemical symbols.
:return: :py:class:`aiida.orm.nodes.data.structure.StructureData` node.
"""
from aiida.orm.nodes.data.structure import Kind, Site, StructureData

Expand Down Expand Up @@ -477,9 +480,23 @@ def get_structure(self, store=False, **kwargs):
.. versionadded:: 1.0
Renamed from _get_aiida_structure
:param converter: specify the converter. Default 'ase'.
:param store: If True, intermediate calculation gets stored in the
AiiDA database for record. Default False.
:param index: The index of the step that you want to retrieve, from
0 to ``self.numsteps- 1``.
:param custom_kinds: (Optional) If passed must be a list of
:py:class:`aiida.orm.nodes.data.structure.Kind` objects. There must be one
kind object for each different string in the ``symbols`` array, with
``kind.name`` set to this string.
If this parameter is omitted, the automatic kind generation of AiiDA
:py:class:`aiida.orm.nodes.data.structure.StructureData` nodes is used,
meaning that the strings in the ``symbols`` array must be valid
chemical symbols.
:param custom_cell: (Optional) The cell matrix of the structure.
If omitted, the cell will be read from the trajectory, if present,
otherwise the default cell of
:py:class:`aiida.orm.nodes.data.structure.StructureData` will be used.
:return: :py:class:`aiida.orm.nodes.data.structure.StructureData` node.
"""
from aiida.orm.nodes.data.dict import Dict
Expand Down
6 changes: 4 additions & 2 deletions aiida/tools/data/array/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
@calcfunction
def _get_aiida_structure_inline(trajectory, parameters):
"""
Creates :py:class:`aiida.orm.nodes.data.structure.StructureData` using ASE.
CalcFunction to extract a :py:class:`aiida.orm.nodes.data.structure.StructureData`
from a `TrajectoryData`.
.. note:: requires ASE module.
:param parameters: A dictionary whose key-value pairs are passed as
additional kwargs to the :py:meth:``TrajectoryData.get_step_structure`` method.
"""
kwargs = {}
if parameters is not None:
Expand Down
78 changes: 74 additions & 4 deletions tests/orm/nodes/data/test_trajectory.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,39 @@
# -*- coding: utf-8 -*-
# pylint: disable=no-self-use
# pylint: disable=no-self-use,redefined-outer-name
"""Tests for the `TrajectoryData` class."""
import numpy as np
import pytest

from aiida.orm import TrajectoryData, load_node
from aiida.orm import StructureData, TrajectoryData, load_node


@pytest.fixture
def trajectory_data():
"""Return a dictionary of data to create a ``TrajectoryData``."""
symbols = ['H'] * 5 + ['Cl'] * 5
stepids = np.arange(1000, 3000, 10)
times = stepids * 0.01
positions = np.arange(6000, dtype=float).reshape((200, 10, 3))
velocities = -np.arange(6000, dtype=float).reshape((200, 10, 3))
cell = [[[3., 0.1, 0.3], [-0.05, 3., -0.2], [0.02, -0.08, 3.]]]
cells = np.array(cell * 200) + np.arange(0, 0.2, 0.001)[:, np.newaxis, np.newaxis]
return {
'symbols': symbols,
'positions': positions,
'stepids': stepids,
'cells': cells,
'times': times,
'velocities': velocities
}


class TestTrajectory:
"""Test for the `TrajectoryData` class."""

@pytest.fixture(autouse=True)
def init_profile(self, aiida_profile_clean, aiida_localhost): # pylint: disable=unused-argument
def init_profile(self, aiida_profile_clean): # pylint: disable=unused-argument
"""Initialize the profile."""
# pylint: disable=attribute-defined-outside-init

n_atoms = 5
n_steps = 30

Expand Down Expand Up @@ -74,3 +93,54 @@ def test_units(self):
tjd2 = load_node(tjd.pk)
assert tjd2.base.attributes.get('units|positions') == 'some_random_pos_unit'
assert tjd2.base.attributes.get('units|times') == 'some_random_time_unit'

def test_trajectory_get_index_from_stepid(self, trajectory_data):
"""Test the ``get_index_from_stepid`` method."""
trajectory = TrajectoryData()
trajectory.set_trajectory(**trajectory_data)
assert trajectory.get_index_from_stepid(1050) == 5

with pytest.raises(ValueError):
trajectory.get_index_from_stepid(2333)

def test_trajectory_get_step_data(self, trajectory_data):
"""Test the ``get_step_data`` method."""
trajectory = TrajectoryData()
trajectory.set_trajectory(**trajectory_data)
stepid, time, cell, symbols, positions, velocities = trajectory.get_step_data(-2)
assert stepid == trajectory_data['stepids'][-2]
assert time == trajectory_data['times'][-2]
np.array_equal(cell, trajectory_data['cells'][-2, :, :])
np.array_equal(symbols, trajectory_data['symbols'])
np.array_equal(positions, trajectory_data['positions'][-2, :, :])
np.array_equal(velocities, trajectory_data['velocities'][-2, :, :])

def test_trajectory_get_step_data_empty(self, trajectory_data):
"""Test the `get_step_data` method when some arrays are not defined."""
trajectory = TrajectoryData()
trajectory.set_trajectory(symbols=trajectory_data['symbols'], positions=trajectory_data['positions'])
stepid, time, cell, symbols, positions, velocities = trajectory.get_step_data(3)
assert stepid == 3
assert time is None
assert cell is None
assert np.array_equal(symbols, trajectory_data['symbols'])
assert np.array_equal(positions, trajectory_data['positions'][3, :, :])
assert velocities is None

def test_trajectory_get_step_structure(self, trajectory_data):
"""Test the `get_step_structure` method."""
trajectory = TrajectoryData()
trajectory.set_trajectory(**trajectory_data)
structure = trajectory.get_step_structure(50)

expected = StructureData()
expected.cell = trajectory_data['cells'][50]
for symbol, position in zip(trajectory_data['symbols'], trajectory_data['positions'][50, :, :]):
expected.append_atom(symbols=symbol, position=position)

structure.store()
expected.store()
assert structure.get_hash() == expected.get_hash()

with pytest.raises(IndexError):
trajectory.get_step_structure(500)

0 comments on commit a869c18

Please sign in to comment.