From 7fbe67cb6273cf2bae4256cdbda284aeb89a9372 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Thu, 10 Aug 2023 20:42:30 +0200 Subject: [PATCH] `ArrayData`: Make `name` optional in `get_array` The `ArrayData` was designed to be able to store multiple numpy arrays. While useful, it forced users to be more verbose than necessary when only storing a single array as an explicit array name is always required: node = ArrayData() node.set_array('some_key', numpy.array([])) node.get_array('some_key') The `get_array` method is updated to allow `None` for the `name` argument as long as the node only stores a single array so that it can return the correct array unambiguously. This simplifies typical user code significantly: node = ArrayData(numpy.array([])) node.get_array() --- aiida/orm/nodes/data/array/array.py | 18 ++++++++++++++++-- docs/source/nitpick-exceptions | 1 + tests/orm/nodes/data/test_array.py | 18 ++++++++++++++++++ 3 files changed, 35 insertions(+), 2 deletions(-) diff --git a/aiida/orm/nodes/data/array/array.py b/aiida/orm/nodes/data/array/array.py index 9d71908b77..602aa3d939 100644 --- a/aiida/orm/nodes/data/array/array.py +++ b/aiida/orm/nodes/data/array/array.py @@ -130,14 +130,28 @@ def get_iterarrays(self) -> Iterator[tuple[str, 'ndarray']]: for name in self.get_arraynames(): yield (name, self.get_array(name)) - def get_array(self, name: str) -> 'ndarray': + def get_array(self, name: str | None = None) -> 'ndarray': """ Return an array stored in the node - :param name: The name of the array to return. + :param name: The name of the array to return. The name can be omitted in case the node contains only a single + array, which will be returned in that case. If ``name`` is ``None`` and the node contains multiple arrays or + no arrays at all a ``ValueError`` is raised. + :raises ValueError: If ``name`` is ``None`` and the node contains more than one arrays or no arrays at all. """ import numpy + if name is None: + names = self.get_arraynames() + narrays = len(names) + + if narrays == 0: + raise ValueError('`name` not specified but the node contains no arrays.') + if narrays > 1: + raise ValueError('`name` not specified but the node contains multiple arrays.') + + name = names[0] + def get_array_from_file(self, name: str) -> 'ndarray': """Return the array stored in a .npy file""" filename = f'{name}.npy' diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index 76bb64c9a4..60e0c89b89 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -130,6 +130,7 @@ py:class disk_objectstore.utils.LazyOpener py:class frozenset py:class numpy.bool_ +py:class ndarray py:class paramiko.proxy.ProxyCommand diff --git a/tests/orm/nodes/data/test_array.py b/tests/orm/nodes/data/test_array.py index 1d31788bdc..659e56048f 100644 --- a/tests/orm/nodes/data/test_array.py +++ b/tests/orm/nodes/data/test_array.py @@ -9,6 +9,7 @@ ########################################################################### """Tests for the :mod:`aiida.orm.nodes.data.array.array` module.""" import numpy +import pytest from aiida.orm import ArrayData, load_node @@ -43,3 +44,20 @@ def test_constructor(): assert sorted(node.get_arraynames()) == ['a', 'b'] assert (node.get_array('a') == arrays['a']).all() assert (node.get_array('b') == arrays['b']).all() + + +def test_get_array(): + """Test :meth:`aiida.orm.nodes.data.array.array.ArrayData:get_array`.""" + node = ArrayData() + with pytest.raises(ValueError, match='`name` not specified but the node contains no arrays.'): + node.get_array() + + node = ArrayData({'a': numpy.array([]), 'b': numpy.array([])}) + with pytest.raises(ValueError, match='`name` not specified but the node contains multiple arrays.'): + node.get_array() + + node = ArrayData({'a': numpy.array([1, 2])}) + assert (node.get_array() == numpy.array([1, 2])).all() + + node = ArrayData(numpy.array([1, 2])) + assert (node.get_array() == numpy.array([1, 2])).all()