Skip to content

Commit

Permalink
Fix REST API attributes_filter and extras_filter (#3556)
Browse files Browse the repository at this point in the history
The REST API `attributes_filter` and `extras_filter` were not working
correctly when pagination is requested. Without pagination they return
attributes and extras dictionaries with the keys specified in
`attributes_filter` and `extras_filter` as expected. When pagination is
added they return instead the attributes keys as `attributes.key`, which
is not consistent with the expected behaviour.

Changed REST API version to 4.0.1.
  • Loading branch information
elsapassaro authored and sphuber committed Nov 22, 2019
1 parent 952282e commit 6a66c6b
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 19 deletions.
70 changes: 70 additions & 0 deletions aiida/backends/tests/test_restapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,76 @@ def test_structure_attributes_filter(self):
response = json.loads(rv_obj.data)
self.assertEqual(response['data']['nodes'][0]['attributes']['cell'], cell)

############### node attributes_filter with pagination #############
def test_node_attributes_filter_pagination(self):
"""
Check that node attributes specified in attributes_filter are
returned as a dictionary when pagination is set
"""
expected_attributes = ['resources', 'cell']
url = self.get_url_prefix() + '/nodes/page/1?perpage=10&attributes=true&attributes_filter=resources,cell'
with self.app.test_client() as client:
response_value = client.get(url)
response = json.loads(response_value.data)
self.assertNotEqual(len(response['data']['nodes']), 0)
for node in response['data']['nodes']:
self.assertIn('attributes', node)
self.assertNotIn('attributes.resources', node)
self.assertNotIn('attributes.cell', node)
self.assertEqual(len(node['attributes']), len(expected_attributes))
for attr in expected_attributes:
self.assertIn(attr, node['attributes'])

############### node get one attributes_filter with pagination #############
def test_node_single_attributes_filter(self):
"""
Check that when only one node attribute is specified in attributes_filter
only this attribute is returned as a dictionary when pagination is set
"""
expected_attribute = ['resources']
url = self.get_url_prefix() + '/nodes/page/1?perpage=10&attributes=true&attributes_filter=resources'
with self.app.test_client() as client:
response_value = client.get(url)
response = json.loads(response_value.data)
self.assertNotEqual(len(response['data']['nodes']), 0)
for node in response['data']['nodes']:
self.assertEqual(list(node['attributes'].keys()), expected_attribute)

############### node extras_filter with pagination #############
def test_node_extras_filter_pagination(self):
"""
Check that node extras specified in extras_filter are
returned as a dictionary when pagination is set
"""
expected_extras = ['extra1', 'extra2']
url = self.get_url_prefix() + '/nodes/page/1?perpage=10&extras=true&extras_filter=extra1,extra2'
with self.app.test_client() as client:
response_value = client.get(url)
response = json.loads(response_value.data)
self.assertNotEqual(len(response['data']['nodes']), 0)
for node in response['data']['nodes']:
self.assertIn('extras', node)
self.assertNotIn('extras.extra1', node)
self.assertNotIn('extras.extra2', node)
self.assertEqual(len(node['extras']), len(expected_extras))
for extra in expected_extras:
self.assertIn(extra, node['extras'])

############### node get one extras_filter with pagination #############
def test_node_single_extras_filter(self):
"""
Check that when only one node extra is specified in extras_filter
only this extra is returned as a dictionary when pagination is set
"""
expected_extra = ['extra2']
url = self.get_url_prefix() + '/nodes/page/1?perpage=10&extras=true&extras_filter=extra2'
with self.app.test_client() as client:
response_value = client.get(url)
response = json.loads(response_value.data)
self.assertNotEqual(len(response['data']['nodes']), 0)
for node in response['data']['nodes']:
self.assertEqual(list(node['extras'].keys()), expected_extra)

############### node full_type filter #############
def test_nodes_full_type_filter(self):
"""
Expand Down
2 changes: 1 addition & 1 deletion aiida/restapi/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

##Version prefix for all the URLs
PREFIX = '/api/v4'
VERSION = '4.0.0'
VERSION = '4.0.1'
"""
Flask app configs.
Expand Down
36 changes: 18 additions & 18 deletions aiida/restapi/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,24 +328,6 @@ def get(self, id=None, page=None): # pylint: disable=redefined-builtin,invalid-
## Retrieve results
results = self.trans.get_results()

if attributes_filter is not None and attributes:
for node in results['nodes']:
node['attributes'] = {}
if not isinstance(attributes_filter, list):
attributes_filter = [attributes_filter]
for attr in attributes_filter:
node['attributes'][str(attr)] = node['attributes.' + str(attr)]
del node['attributes.' + str(attr)]

if extras_filter is not None and extras:
for node in results['nodes']:
node['extras'] = {}
if not isinstance(extras_filter, list):
extras_filter = [extras_filter]
for extra in extras_filter:
node['extras'][str(extra)] = node['extras.' + str(extra)]
del node['extras.' + str(extra)]

if query_type == 'repo_contents' and results:
response = make_response(results)
response.headers['content-type'] = 'application/octet-stream'
Expand All @@ -366,6 +348,24 @@ def get(self, id=None, page=None): # pylint: disable=redefined-builtin,invalid-

headers = self.utils.build_headers(url=request.url, total_count=total_count)

if attributes_filter is not None and attributes:
for node in results['nodes']:
node['attributes'] = {}
if not isinstance(attributes_filter, list):
attributes_filter = [attributes_filter]
for attr in attributes_filter:
node['attributes'][str(attr)] = node['attributes.' + str(attr)]
del node['attributes.' + str(attr)]

if extras_filter is not None and extras:
for node in results['nodes']:
node['extras'] = {}
if not isinstance(extras_filter, list):
extras_filter = [extras_filter]
for extra in extras_filter:
node['extras'][str(extra)] = node['extras.' + str(extra)]
del node['extras.' + str(extra)]

## Build response
data = dict(
method=request.method,
Expand Down

0 comments on commit 6a66c6b

Please sign in to comment.