Simplify loading of node class from type string (#2376)
The type string of a `CalcJobNode` should now be just that of the
node class and not of the subclass. Since users currently still
subclass the node class rather than a process class, we have to make
a special exception when generating the type string for the node.
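In practice, the exception means that when the type string is generated
for a subclass of `CalcJobNode`, the base class is used instead. Below is
a hedged sketch of that rule, not the actual `aiida-core` code: the helper
name `get_type_string` is hypothetical, and the format is inferred from the
`data.Data.`/`node.Node.` type strings referenced in this message.

```python
def get_type_string(cls):
    """Sketch: derive the type string stored in a node's `type` column.

    Assumed format: the module path relative to `aiida.orm`, then the
    class name, then a trailing period, e.g. `data.Data.` or `node.Node.`.
    """
    from aiida.orm.node.process import CalcJobNode

    if issubclass(cls, CalcJobNode):
        # The special exception: a user subclass of CalcJobNode is stored
        # with the type string of CalcJobNode itself, not of the subclass
        cls = CalcJobNode

    module_path = cls.__module__.replace('aiida.orm.', '', 1)
    return '{}.{}.'.format(module_path, cls.__name__)
```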

Conversely, this type string, stored in the `type` column of the node,
should be used to load the `CalcJobNode` class when loading from
the database. The only classes that can legitimately exist in a database,
and therefore be loaded, are those defined in `aiida-core`. Using the
entry point system to map the type string onto an actual ORM class is
therefore no longer necessary. We rename the `aiida.plugins.loader.load_plugin`
function to the more accurate `load_node_class`, which, given a type string,
returns the corresponding ORM node subclass.
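Since every loadable class now lives in `aiida-core`, the type string can
be resolved with a plain import instead of an entry point lookup. The
following is a minimal sketch of the idea behind `load_node_class`, under
the assumption that a type string such as `data.int.Int.` encodes a module
path relative to `aiida.orm` plus a class name; the real implementation
adds validation and the special cases noted below.

```python
import importlib


def load_node_class(type_string):
    """Sketch: return the ORM node subclass for a stored type string.

    Interprets e.g. `data.int.Int.` as class `Int` living in the module
    `aiida.orm.data.int`. Illustrative only: the actual function handles
    the exceptions described in this commit message.
    """
    from aiida.orm import Node

    if not type_string:
        # An empty type string denotes the base Node class itself
        return Node

    # Strip the trailing period and split off the class name
    parts = type_string.rstrip('.').split('.')
    module_path, class_name = '.'.join(parts[:-1]), parts[-1]

    module = importlib.import_module('aiida.orm.{}'.format(module_path))
    return getattr(module, class_name)
```

Under these assumptions, `load_node_class('data.int.Int.')` returns the
`Int` class without ever consulting the entry point system.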

Note that the machinery for generating type and query strings, and for
loading nodes based on them, is still somewhat convoluted and contains
hacks, for two reasons:

 1) Data has not yet been moved into the `aiida.orm.node` submodule and
    as a result gets the `data.Data.` type string, which will not match
    the `node.Node.` type string when subclassing in queries (see the
    sketch below).

 2) Calculation job processes are still defined by subclassing
    `JobCalculation`. Until users directly define a `Process` subclass
    that uses `CalcJobNode` as its node class, exceptions will have to
    be made.

If these two issues are addressed, a lot of the code around type strings
can be simplified and cleaned up.
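To make the first issue concrete: subclass queries work by prefix-matching
the stored type string against the query type string of the appended class.
A small illustration, assuming hypothetical query type strings `node.` and
`data.` consistent with the type strings above:

```python
def matches_subclass_query(type_string, query_type_string):
    """A stored node matches a subclass query if its type string starts
    with the query type string of the class being queried (assumption)."""
    return type_string.startswith(query_type_string)


# Found when querying for Data subclasses
assert matches_subclass_query('data.Data.', 'data.')

# Missed when querying for Node subclasses, even though Data IS-A Node:
# this is the mismatch that the remaining hacks have to paper over
assert not matches_subclass_query('data.Data.', 'node.')
```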
sphuber authored Jan 10, 2019
1 parent dd0b1b9 commit 1e339c3
Showing 20 changed files with 205 additions and 244 deletions.
98 changes: 43 additions & 55 deletions .ci/test_daemon.py
@@ -10,21 +10,20 @@
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
-import os

import subprocess
import sys
import time

from six.moves import range

-from aiida.common.exceptions import NotExistent
+from aiida.common import exceptions
from aiida.manage.caching import enable_caching
from aiida.daemon.client import get_daemon_client
from aiida.orm import Code, CalculationFactory, DataFactory
from aiida.orm.data.int import Int
from aiida.orm.data.str import Str
from aiida.orm.data.list import List
from aiida.orm.data.structure import StructureData
from aiida.orm.node.process import CalcJobNode
from aiida.work.launch import run_get_node, submit
from aiida.work.persistence import ObjectLoader
@@ -37,9 +36,9 @@
ParameterData = DataFactory('parameter')

codename = 'doubler@torquessh'
-timeout_secs = 4 * 60 # 4 minutes
-number_calculations = 15 # Number of calculations to submit
-number_workchains = 8 # Number of workchains to submit
+timeout_secs = 4 * 60 # 4 minutes
+number_calculations = 15 # Number of calculations to submit
+number_workchains = 8 # Number of workchains to submit


def print_daemon_log():
@@ -68,11 +67,11 @@ def jobs_have_finished(pks):
return not (False in finished_list)


-def print_logshow(pk):
-print("Output of 'verdi calculation logshow {}':".format(pk))
+def print_report(pk):
+print("Output of 'verdi process report {}':".format(pk))
try:
print(subprocess.check_output(
-["verdi", "calculation", "logshow", "{}".format(pk)],
+["verdi", "process", "report", "{}".format(pk)],
stderr=subprocess.STDOUT,
))
except subprocess.CalledProcessError as exception:
@@ -88,14 +87,14 @@ def validate_calculations(expected_results):
if not calc.is_finished_ok:
print('Calculation<{}> not finished ok: process_state<{}> exit_status<{}>'
.format(pk, calc.process_state, calc.exit_status))
-print_logshow(pk)
+print_report(pk)
valid = False

try:
actual_dict = calc.out.output_parameters.get_dict()
-except (KeyError, AttributeError) as exception:
+except exceptions.NotExistent:
print('Could not retrieve output_parameters node for Calculation<{}>'.format(pk))
-print_logshow(pk)
+print_report(pk)
valid = False

try:
@@ -120,7 +119,7 @@ def validate_workchains(expected_results):
try:
calc = load_node(pk)
actual_value = calc.out.output
-except (NotExistent, AttributeError) as exception:
+except (exceptions.NotExistent, AttributeError) as exception:
print("* UNABLE TO RETRIEVE VALUE for workchain pk={}: I expected {}, I got {}: {}"
.format(pk, expected_value, type(exception), exception))
valid = False
@@ -131,7 +130,7 @@
if this_valid and not calc.is_finished_ok:
print('Calculation<{}> not finished ok: process_state<{}> exit_status<{}>'
.format(pk, calc.process_state, calc.exit_status))
-print_logshow(pk)
+print_report(pk)
valid = False
this_valid = False

@@ -156,18 +155,18 @@ def validate_cached(cached_calcs):
if not calc.is_finished_ok:
print('Cached calculation<{}> not finished ok: process_state<{}> exit_status<{}>'
.format(calc.pk, calc.process_state, calc.exit_status))
-print_logshow(calc.pk)
+print_report(calc.pk)
valid = False

if '_aiida_cached_from' not in calc.extras() or calc.get_hash() != calc.get_extra('_aiida_hash'):
print('Cached calculation<{}> has invalid hash'.format(calc.pk))
-print_logshow(calc.pk)
+print_report(calc.pk)
valid = False

if isinstance(calc, CalcJobNode):
if 'raw_input' not in calc.folder.get_content_list():
print("Cached calculation <{}> does not have a 'raw_input' folder".format(calc.pk))
-print_logshow(calc.pk)
+print_report(calc.pk)
valid = False
original_calc = load_node(calc.get_extra('_aiida_cached_from'))
if 'raw_input' not in original_calc.folder.get_content_list():
@@ -181,11 +180,11 @@ def validate_cached(cached_calcs):
def create_calculation(code, counter, inputval, use_cache=False):
parameters = ParameterData(dict={'value': inputval})
template = ParameterData(dict={
-## The following line adds a significant sleep time.
-## I set it to 1 second to speed up tests
-## I keep it to a non-zero value because I want
-## To test the case when AiiDA finds some calcs
-## in a queued state
+# The following line adds a significant sleep time.
+# I set it to 1 second to speed up tests
+# I keep it to a non-zero value because I want
+# To test the case when AiiDA finds some calcs
+# in a queued state
# 'cmdline_params': ["{}".format(counter % 3)], # Sleep time
'cmdline_params': ["1"],
'input_file_template': "{value}", # File just contains the value to double
@@ -220,6 +219,7 @@ def submit_calculation(code, counter, inputval):
print("[{}] calculation submitted.".format(counter))
return calc, expected_result

+
def launch_calculation(code, counter, inputval):
"""
Launch calculations to the daemon through the Process layer
@@ -229,6 +229,7 @@ def launch_calculation(code, counter, inputval):
print("[{}] launched calculation {}, pk={}".format(counter, calc.uuid, calc.dbnode.pk))
return calc, expected_result

+
def run_calculation(code, counter, inputval):
"""
Run a calculation through the Process layer.
@@ -238,6 +239,7 @@ def run_calculation(code, counter, inputval):
print("[{}] ran calculation {}, pk={}".format(counter, calc.uuid, calc.pk))
return calc, expected_result

+
def create_calculation_process(code, inputval):
"""
Create the process and inputs for a submitting / running a calculation.
@@ -247,18 +249,18 @@ def create_calculation_process(code, inputval):

parameters = ParameterData(dict={'value': inputval})
template = ParameterData(dict={
-## The following line adds a significant sleep time.
-## I set it to 1 second to speed up tests
-## I keep it to a non-zero value because I want
-## To test the case when AiiDA finds some calcs
-## in a queued state
-#'cmdline_params': ["{}".format(counter % 3)], # Sleep time
-'cmdline_params': ["1"],
-'input_file_template': "{value}", # File just contains the value to double
-'input_file_name': 'value_to_double.txt',
-'output_file_name': 'output.txt',
-'retrieve_temporary_files': ['triple_value.tmp']
-})
+# The following line adds a significant sleep time.
+# I set it to 1 second to speed up tests
+# I keep it to a non-zero value because I want
+# To test the case when AiiDA finds some calcs
+# in a queued state
+# 'cmdline_params': ["{}".format(counter % 3)], # Sleep time
+'cmdline_params': ["1"],
+'input_file_template': "{value}", # File just contains the value to double
+'input_file_name': 'value_to_double.txt',
+'output_file_name': 'output.txt',
+'retrieve_temporary_files': ['triple_value.tmp']
+})
options = {
'resources': {
'num_machines': 1
@@ -283,13 +285,15 @@
}
return process, inputs, expected_result

+
def create_cache_calc(code, counter, inputval):
calc, expected_result = create_calculation(
code=code, counter=counter, inputval=inputval, use_cache=True
)
print("[{}] created cached calculation.".format(counter))
return calc, expected_result

+
def main():
expected_results_calculations = {}
expected_results_workchains = {}
@@ -392,24 +396,6 @@ def main():
except subprocess.CalledProcessError as e:
print("Note: the command failed, message: {}".format(e))

-print("Output of 'verdi calculation list -a':")
-try:
-print(subprocess.check_output(
-["verdi", "calculation", "list", "-a"],
-stderr=subprocess.STDOUT,
-))
-except subprocess.CalledProcessError as e:
-print("Note: the command failed, message: {}".format(e))
-
-print("Output of 'verdi work list':")
-try:
-print(subprocess.check_output(
-['verdi', 'work', 'list', '-a', '-p1'],
-stderr=subprocess.STDOUT,
-))
-except subprocess.CalledProcessError as e:
-print("Note: the command failed, message: {}".format(e))

print("Output of 'verdi daemon status':")
try:
print(subprocess.check_output(
@@ -445,9 +431,11 @@ def main():
cached_calcs.append(calc)
expected_results_calculations[calc.pk] = expected_result

-if (validate_calculations(expected_results_calculations)
-and validate_workchains(expected_results_workchains)
-and validate_cached(cached_calcs)):
+if (
+validate_calculations(expected_results_calculations) and
+validate_workchains(expected_results_workchains) and
+validate_cached(cached_calcs)
+):
print_daemon_log()
print("")
print("OK, all calculations have the expected parsed result")
10 changes: 4 additions & 6 deletions aiida/backends/djsite/db/subtests/query.py
@@ -26,14 +26,14 @@ def test_clsf_django(self):
from aiida.orm.querybuilder import QueryBuilder
from aiida.orm.data.structure import StructureData
from aiida.orm import Group, Node, Computer, Data
-from aiida.common.exceptions import InputValidationError
+from aiida.common.exceptions import DbContentError
qb = QueryBuilder()

-with self.assertRaises(InputValidationError):
+with self.assertRaises(DbContentError):
qb._get_ormclass(None, 'data')
-with self.assertRaises(InputValidationError):
+with self.assertRaises(DbContentError):
qb._get_ormclass(None, 'data.Data')
-with self.assertRaises(InputValidationError):
+with self.assertRaises(DbContentError):
qb._get_ormclass(None, '.')

for cls, clstype, query_type_string in (
@@ -47,9 +47,7 @@ def test_clsf_django(self):
StructureData._query_type_string)

for cls, clstype, query_type_string in (
qb._get_ormclass(Node, None),
qb._get_ormclass(DbNode, None),
qb._get_ormclass(None, '')
):
self.assertEqual(clstype, Node._plugin_type_string)
self.assertEqual(query_type_string, Node._query_type_string)
12 changes: 6 additions & 6 deletions aiida/backends/tests/cmdline/commands/test_rehash.py
@@ -42,48 +42,48 @@ def test_rehash(self):
expected_node_count = 5
options = []
result = self.cli_runner.invoke(cmd_rehash.rehash, options)
+self.assertClickResultNoException(result)
self.assertTrue('{} nodes'.format(expected_node_count) in result.output)
-self.assertIsNone(result.exception, result.output)

def test_rehash_bool(self):
"""Limiting the queryset by defining an entry point, in this case bool, should limit nodes to 2."""
expected_node_count = 2
options = ['-e', 'aiida.data:bool']
result = self.cli_runner.invoke(cmd_rehash.rehash, options)
+self.assertClickResultNoException(result)
self.assertTrue('{} nodes'.format(expected_node_count) in result.output)
-self.assertIsNone(result.exception, result.output)

def test_rehash_float(self):
"""Limiting the queryset by defining an entry point, in this case float, should limit nodes to 1."""
expected_node_count = 1
options = ['-e', 'aiida.data:float']
result = self.cli_runner.invoke(cmd_rehash.rehash, options)
self.assertClickResultNoException(result)
self.assertTrue('{} nodes'.format(expected_node_count) in result.output)
self.assertIsNone(result.exception, result.output)

def test_rehash_int(self):
"""Limiting the queryset by defining an entry point, in this case int, should limit nodes to 1."""
expected_node_count = 1
options = ['-e', 'aiida.data:int']
result = self.cli_runner.invoke(cmd_rehash.rehash, options)
self.assertClickResultNoException(result)
self.assertTrue('{} nodes'.format(expected_node_count) in result.output)
self.assertIsNone(result.exception, result.output)

def test_rehash_explicit_pk(self):
"""Limiting the queryset by defining explicit identifiers, should limit nodes to 2 in this example."""
expected_node_count = 2
options = [str(self.node_bool_true.pk), str(self.node_float.uuid)]
result = self.cli_runner.invoke(cmd_rehash.rehash, options)
self.assertClickResultNoException(result)
self.assertTrue('{} nodes'.format(expected_node_count) in result.output)
self.assertIsNone(result.exception, result.output)

def test_rehash_explicit_pk_and_entry_point(self):
"""Limiting the queryset by defining explicit identifiers and entry point, should limit nodes to 1."""
expected_node_count = 1
options = ['-e', 'aiida.data:bool', str(self.node_bool_true.pk), str(self.node_float.uuid)]
result = self.cli_runner.invoke(cmd_rehash.rehash, options)
self.assertClickResultNoException(result)
self.assertTrue('{} nodes'.format(expected_node_count) in result.output)
self.assertIsNone(result.exception, result.output)

def test_rehash_entry_point_no_matches(self):
"""Limiting the queryset by defining explicit entry point, with no nodes should exit with non-zero status."""
14 changes: 3 additions & 11 deletions aiida/backends/tests/nodes.py
@@ -264,9 +264,6 @@ def test_with_subclasses(self):
a1 = CalcJobNode(**calc_params).store()
# To query only these nodes later
a1.set_extra(extra_name, True)
-a2 = TemplateReplacerCalc(**calc_params).store()
-# To query only these nodes later
-a2.set_extra(extra_name, True)
a3 = Data().store()
a3.set_extra(extra_name, True)
a4 = ParameterData(dict={'a': 'b'}).store()
@@ -284,13 +281,13 @@ def test_with_subclasses(self):
results = [_ for [_] in qb.all()]
# a3, a4 should not be found because they are not CalcJobNodes.
# a6, a7 should not be found because they have not the attribute set.
-self.assertEquals(set([i.pk for i in results]), set([a1.pk, a2.pk]))
+self.assertEquals(set([i.pk for i in results]), set([a1.pk]))

# Same query, but by the generic Node class
qb = QueryBuilder()
qb.append(Node, filters={'extras': {'has_key': extra_name}})
results = [_ for [_] in qb.all()]
-self.assertEquals(set([i.pk for i in results]), set([a1.pk, a2.pk, a3.pk, a4.pk]))
+self.assertEquals(set([i.pk for i in results]), set([a1.pk, a3.pk, a4.pk]))

# Same query, but by the Data class
qb = QueryBuilder()
@@ -304,12 +301,6 @@ def test_with_subclasses(self):
results = [_ for [_] in qb.all()]
self.assertEquals(set([i.pk for i in results]), set([a4.pk]))

-# Same query, but by the TemplateReplacerCalc subclass
-qb = QueryBuilder()
-qb.append(TemplateReplacerCalc, filters={'extras': {'has_key': extra_name}})
-results = [_ for [_] in qb.all()]
-self.assertEquals(set([i.pk for i in results]), set([a2.pk]))


class TestNodeBasic(AiidaTestCase):
"""
@@ -1492,6 +1483,7 @@ def test_load_node(self):
with self.assertRaises(NotExistent):
load_node(spec, sub_classes=(ArrayData,))

+@unittest.skip('open issue JobCalculations cannot be stored')
def test_load_unknown_calculation_type(self):
"""
Test that the loader will choose a common calculation ancestor for an unknown data type.