Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sanitizing ArrayData for REST-API #5613

Merged
merged 3 commits into from
Aug 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion aiida/orm/nodes/data/array/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,11 @@ def _get_array_entries(self):
the value is the numpy array transformed into a list. This is so that
it can be transformed into a json object.
"""

array_dict = {}
for key, val in self.get_iterarrays():
array_dict[key] = val.tolist()

array_dict[key] = clean_array(val)
return array_dict

def _prepare_json(self, main_file_name='', comments=True): # pylint: disable=unused-argument
Expand All @@ -222,3 +224,28 @@ def _prepare_json(self, main_file_name='', comments=True): # pylint: disable=un
json_dict['comments'] = get_file_header(comment_char='')

return json.dumps(json_dict).encode('utf-8'), {}


def clean_array(array):
    """
    Convert an array to nested lists, replacing non-finite numbers by ``None``.

    ``np.nan``, ``np.inf`` and ``-np.inf`` have no representation in JSON as
    defined by the
    `ECMA-262 standard <https://www.ecma-international.org/publications-and-standards/standards/ecma-262/>`_,
    so every such entry is replaced by ``None`` (serialized as ``null``).
    In this way the resulting JSON is always valid.

    Arrays whose dtype cannot hold non-finite numbers (integers, booleans,
    strings, ...) are converted with a plain ``tolist()`` call.

    :param array: input array to be cleaned
    :return: cleaned (nested) list with the same shape as ``array``, ready to be serialized
    :rtype: list
    """
    import numpy as np

    array = np.asarray(array)
    kind = array.dtype.kind

    if kind in 'fc':
        # Floating point and complex dtypes: find nan/inf/-inf in one vectorized pass.
        invalid = ~np.isfinite(array)
    elif kind == 'O':
        # Object arrays may mix numbers with arbitrary Python objects; only
        # test the entries that are real or complex floating point numbers so
        # that e.g. ``None`` or strings do not raise a ``TypeError``.
        flags = [
            isinstance(entry, (float, complex, np.inexact)) and not np.isfinite(entry)
            for entry in array.ravel().tolist()
        ]
        invalid = np.reshape(np.asarray(flags, dtype=bool), array.shape)
    else:
        # Integers, booleans, strings, ...: nothing can be non-finite.
        return array.tolist()

    if not invalid.any():
        return array.tolist()

    # ``astype(object)`` returns a copy, so the caller's array is never modified.
    cleaned = array.astype(object)
    cleaned[invalid] = None
    return cleaned.tolist()
5 changes: 5 additions & 0 deletions docs/source/howto/share_data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,11 @@ The AiiDA REST API allows to query your AiiDA database over HTTP(S) and returns
As of October 2020, the AiiDA REST API only supports ``GET`` methods (reading); in particular, it does *not* yet support workflow management.
This feature is, however, part of the `AiiDA roadmap <https://github.com/aiidateam/aiida-core/wiki/AiiDA-release-roadmap>`_.


.. note::
To ensure that serving an ``orm.ArrayData`` always yields valid JSON that is compliant with the `ECMA-262 standard <https://www.ecma-international.org/publications-and-standards/standards/ecma-262/>`_, any ``np.nan``, ``np.inf`` and/or ``-np.inf`` entries are replaced by ``None``, which is rendered as ``null`` when the array is retrieved through the API.


.. _how-to:share:serve:launch:

Launching the REST API
Expand Down
2 changes: 1 addition & 1 deletion docs/source/reference/rest_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -946,7 +946,7 @@ Querybuilder
Posts a query to the database. The content of the query is passed in an attached JSON file.

To use this endpoint, you need an HTTP client that allows passing attachments.
We will demonstrate two options, the `HTTPie <https://httpie.io/>`_ (to use in the terminal) and the python library `Requests <https://docs.python-requests.org/en/latest/#>`_ (to use in python).
We will demonstrate two options, the `HTTPie <https://httpie.io/>`_ (to use in the terminal) and the python library `Requests <https://requests.readthedocs.io/en/latest/>`_ (to use in python).

Option 1: HTTPie

Expand Down
38 changes: 34 additions & 4 deletions tests/restapi/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
import json

from flask_cors.core import ACL_ORIGIN
import numpy as np
import pytest

from aiida import orm
from aiida.common.links import LinkType
from aiida.manage import get_manager
from aiida.orm.nodes.data.array.array import clean_array
from aiida.restapi.run_api import configure_api


Expand Down Expand Up @@ -141,6 +143,12 @@ def init_profile(self, aiida_profile_clean, aiida_localhost): # pylint: disable
computer = orm.Computer(**dummy_computer)
computer.store()

# Setting array data for the tests
array = orm.ArrayData()
array.set_array('array_clean', np.asarray([[4, 5, 7], [9, 5, 1], [3, 4, 4]]))
array.set_array('array_dirty', np.asarray([[4, 5, np.nan], [9, np.inf, -1 * np.inf], [np.nan, 4, 4]]))
array.store()

# Prepare typical REST responses
self.process_dummy_data()

Expand Down Expand Up @@ -205,6 +213,7 @@ def process_dummy_data(self):
'parameterdata': orm.Dict,
'structuredata': orm.StructureData,
'data': orm.Data,
'arraydata': orm.ArrayData,
}
for label, dataclass in data_types.items():
data = orm.QueryBuilder().append(dataclass, tag='data', project=data_projections).order_by({
Expand Down Expand Up @@ -300,7 +309,6 @@ def process_test(
if result_node_type is None and result_name is None:
result_node_type = entity_type
result_name = entity_type

url = self._url_prefix + url

with self.app.test_client() as client:
Expand All @@ -321,7 +329,6 @@ def process_test(
else:
from aiida.common.exceptions import InputValidationError
raise InputValidationError('Pass the expected range of the dummydata')

expected_node_uuids = [node['uuid'] for node in expected_data]
result_node_uuids = [node['uuid'] for node in response['data'][result_name]]
assert expected_node_uuids == result_node_uuids
Expand Down Expand Up @@ -745,7 +752,7 @@ def test_calculation_inputs(self):
self.process_test(
'nodes',
f'/nodes/{str(node_uuid)}/links/incoming?orderby=id',
expected_list_ids=[5, 3],
expected_list_ids=[6, 4],
uuid=node_uuid,
result_node_type='data',
result_name='incoming'
Expand All @@ -759,7 +766,7 @@ def test_calculation_input_filters(self):
self.process_test(
'nodes',
f"/nodes/{str(node_uuid)}/links/incoming?node_type=\"data.core.dict.Dict.\"",
expected_list_ids=[3],
expected_list_ids=[4],
uuid=node_uuid,
result_node_type='data',
result_name='incoming'
Expand Down Expand Up @@ -1347,3 +1354,26 @@ def test_querybuilder_project_implicit(self):
# All are Nodes, and all properties are projected, full_type should be present
assert 'full_type' in entity
assert 'attributes' in entity

def test_array_download(self):
    """
    Test download of arraydata as a json file.

    The downloaded payload must be strict JSON (no ``NaN``/``Infinity``
    tokens) and every array must match the stored one, with non-finite
    entries replaced by ``None``.
    """
    from aiida.orm import load_node

    node_uuid = self.get_dummy_data()['arraydata'][0]['uuid']
    url = f'{self.get_url_prefix()}/nodes/{node_uuid}/download?download_format=json&download=False'
    with self.app.test_client() as client:
        rv_obj = client.get(url)

    data_json = json.loads(rv_obj.json['data']['download']['data'])

    # ``json.loads`` accepts NaN/Infinity by default; dumping with
    # ``allow_nan=False`` raises ``ValueError`` if any slipped through.
    assert json.dumps(data_json, allow_nan=False)

    data_array = load_node(node_uuid)
    for name in data_array.get_arraynames():
        stored = data_array.get_array(name)
        # ``clean_array`` nulls nan *and* +/-inf, so the numeric comparison is
        # only valid when every entry is finite; otherwise the downloaded
        # list contains ``None`` and must be compared against the cleaned list.
        if np.isfinite(stored).all():
            assert np.allclose(stored, data_json[name])
        else:
            assert clean_array(stored) == data_json[name]