Skip to content

Commit

Permalink
Merge pull request #1067 from muhrin/daemon_megachange
Browse files Browse the repository at this point in the history
Daemon megachange
  • Loading branch information
sphuber authored Jan 19, 2018
2 parents c34d928 + 6b8ba82 commit 17278df
Show file tree
Hide file tree
Showing 20 changed files with 266 additions and 732 deletions.
2 changes: 1 addition & 1 deletion aiida/backends/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
'orm.data.frozendict': ['aiida.backends.tests.orm.data.frozendict'],
'orm.log': ['aiida.backends.tests.orm.log'],
'work.class_loader': ['aiida.backends.tests.work.class_loader'],
# 'work.daemon': ['aiida.backends.tests.work.daemon'],
'work.daemon': ['aiida.backends.tests.work.daemon'],
'work.persistence': ['aiida.backends.tests.work.persistence'],
'work.process': ['aiida.backends.tests.work.process'],
'work.processSpec': ['aiida.backends.tests.work.processSpec'],
Expand Down
185 changes: 1 addition & 184 deletions aiida/backends/tests/work/daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,190 +8,7 @@
# For further information please visit http://www.aiida.net #
###########################################################################
from aiida.backends.testbase import AiidaTestCase
import tempfile
from shutil import rmtree
import unittest
import plum

from aiida.work.persistence import Persistence
from aiida.orm.calculation.job import JobCalculation
from aiida.orm.data.base import get_true_node
import aiida.work.daemon as daemon
from aiida.work.processes import Process
from aiida.work.launch import submit
from aiida.common.lang import override
from aiida.orm import load_node
import aiida.work.utils as util
from aiida.work.test_utils import DummyProcess, ExceptionProcess
import aiida.work.daemon as work_daemon
from aiida.work.utils import CalculationHeartbeat


#@unittest.skip("Rewriting daemon")
class ProcessEventsTester(Process):
    """
    Process that records the lifecycle callbacks it receives as outputs.

    Each overridden ``on_*`` hook first defers to the parent implementation
    and then emits a true node on the output port named after the event.
    """
    EVENTS = ["create", "start", "run", "wait", "resume", "finish", "emitted",
              "stop", "failed", ]

    @classmethod
    def define(cls, spec):
        """Declare one optional output port per recorded event."""
        super(ProcessEventsTester, cls).define(spec)
        # NOTE: "failed" appears in EVENTS but gets no output port here.
        recorded_events = ("create", "start", "run", "wait", "resume",
                           "finish", "emitted", "stop")
        for event_name in recorded_events:
            spec.optional_output(event_name)

    def __init__(self, inputs, pid, logger=None):
        super(ProcessEventsTester, self).__init__(inputs, pid, logger)
        # Ensures the "emitted" output is only produced once.
        self._emitted = False

    def _mark_event(self, event_name):
        """Emit a true node on the output port called ``event_name``."""
        self.out(event_name, get_true_node())

    @override
    def on_create(self):
        super(ProcessEventsTester, self).on_create()
        self._mark_event("create")

    @override
    def on_start(self):
        super(ProcessEventsTester, self).on_start()
        self._mark_event("start")

    @override
    def on_run(self):
        super(ProcessEventsTester, self).on_run()
        self._mark_event("run")

    @override
    def on_output_emitted(self, output_port, value, dynamic):
        super(ProcessEventsTester, self).on_output_emitted(
            output_port, value, dynamic)
        if self._emitted:
            return
        # Set the flag before emitting: emitting an output triggers this
        # very callback again.
        self._emitted = True
        self._mark_event("emitted")

    @override
    def on_wait(self, awaiting_uuid):
        super(ProcessEventsTester, self).on_wait(awaiting_uuid)
        self._mark_event("wait")

    @override
    def on_resume(self):
        super(ProcessEventsTester, self).on_resume()
        self._mark_event("resume")

    @override
    def on_finish(self):
        super(ProcessEventsTester, self).on_finish()
        self._mark_event("finish")

    @override
    def on_stop(self):
        super(ProcessEventsTester, self).on_stop()
        self._mark_event("stop")

    @override
    def _run(self):
        """Hand control back to the engine, continuing at ``finish``."""
        return plum.Continue(self.finish)

    def finish(self, wait_on):
        """Continuation target of ``_run``; intentionally does nothing."""

@unittest.skip("Rewriting daemon")
class FailCreateFromSavedStateProcess(DummyProcess):
    """
    Dummy process whose restoration from a saved state always fails.

    Used to emulate an error occurring while a process is being loaded
    from its persisted state.
    """

    @override
    def load_instance_state(self, saved_state, logger):
        """Restore through the parent class, then raise ``RuntimeError``."""
        # NOTE(review): ``logger`` is accepted but not forwarded to the
        # parent call -- confirm this matches the parent's signature.
        parent = super(FailCreateFromSavedStateProcess, self)
        parent.load_instance_state(saved_state)
        raise RuntimeError()


@unittest.skip("Moving to new daemon")
class TestDaemon(AiidaTestCase):
    """
    Tests that submit processes to a file based job store and tick the
    daemon until no pending work is left.
    """

    def setUp(self):
        """Check the process stack is empty and create a fresh job store."""
        self.assertEqual(len(util.ProcessStack.stack()), 0)

        self.storedir = tempfile.mkdtemp()
        # Register the removal right away so the directory is deleted even
        # if the rest of setUp, the test, or tearDown fails.
        self.addCleanup(rmtree, self.storedir)
        self.storage = Persistence.create_from_basedir(self.storedir)

    def tearDown(self):
        """Check that the test left no process on the stack."""
        self.assertEqual(len(util.ProcessStack.stack()), 0)

    def _tick_until_idle(self, max_ticks=10):
        """
        Tick the daemon until ``launch_pending_jobs`` returns a falsy value.

        :param max_ticks: number of ticks after which the test is failed
        """
        ticks = 0
        while daemon.launch_pending_jobs(self.storage):
            self.assertLess(
                ticks, max_ticks,
                "Engine not done after {} ticks".format(max_ticks))
            ticks += 1

    def test_submit(self):
        # This call should create an entry in the database with a PK
        rinfo = submit(DummyProcess)
        self.assertIsNotNone(rinfo)
        self.assertIsNotNone(load_node(pk=rinfo.pid))

    def test_tick(self):
        # NOTE(review): ProcessRegistry is not among the imports visible
        # for this module -- confirm where it should come from.
        registry = ProcessRegistry()

        rinfo = submit(ProcessEventsTester, _jobs_store=self.storage)
        # Tick the engine a number of times or until there is no more work
        self._tick_until_idle()
        self.assertTrue(registry.has_finished(rinfo.pid))

    def test_multiple_processes(self):
        submit(DummyProcess, _jobs_store=self.storage)
        submit(ExceptionProcess, _jobs_store=self.storage)
        submit(ExceptionProcess, _jobs_store=self.storage)
        submit(DummyProcess, _jobs_store=self.storage)

        self.assertFalse(daemon.launch_pending_jobs(self.storage))

    def test_create_fail(self):
        # NOTE(review): ProcessRegistry is not among the imports visible
        # for this module -- confirm where it should come from.
        registry = ProcessRegistry()

        dp_rinfo = submit(DummyProcess, _jobs_store=self.storage)
        fail_rinfo = submit(FailCreateFromSavedStateProcess, _jobs_store=self.storage)

        # Tick the engine a number of times or until there is no more work
        self._tick_until_idle()

        self.assertTrue(registry.has_finished(dp_rinfo.pid))
        self.assertFalse(registry.has_finished(fail_rinfo.pid))


@unittest.skip("Moving to new daemon")
class TestJobCalculationDaemon(AiidaTestCase):
    """
    Tests that submitted job calculations show up in the list returned by
    ``get_all_pending_job_calculations``.
    """

    def _new_calculation(self):
        """Return an unstored JobCalculation on the test computer."""
        return JobCalculation(
            computer=self.computer,
            resources={'num_machines': 1,
                       'num_mpiprocs_per_machine': 1})

    def test_launch_pending_submitted(self):
        count_before = len(work_daemon.get_all_pending_job_calculations())

        # Create the calculation
        calc = self._new_calculation()
        calc.store()
        calc.submit()

        heartbeat = calc.get_attr(CalculationHeartbeat.HEARTBEAT_EXPIRES, None)
        self.assertIsNone(heartbeat)
        pending = work_daemon.get_all_pending_job_calculations()
        self.assertEqual(len(pending), count_before + 1)
        pending_pks = [p.pk for p in pending]
        self.assertIn(calc.pk, pending_pks)

    def test_launch_pending_expired(self):
        count_before = len(work_daemon.get_all_pending_job_calculations())

        calc = self._new_calculation()
        # Set the heartbeat expiry to 0 before storing the calculation.
        calc._set_attr(CalculationHeartbeat.HEARTBEAT_EXPIRES, 0)
        calc.store()
        calc.submit()

        pending = work_daemon.get_all_pending_job_calculations()
        self.assertEqual(len(pending), count_before + 1)
        pending_pks = [p.pk for p in pending]
        self.assertIn(calc.pk, pending_pks)
pass
15 changes: 3 additions & 12 deletions aiida/backends/tests/work/test_rmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import aiida.work.test_utils as test_utils
from aiida.orm.data import base
import aiida.work as work
from aiida.work import rmq
from aiida.orm.calculation.work import WorkCalculation

__copyright__ = u"Copyright (c), This file is part of the AiiDA platform. For further information please visit http://www.aiida.net/. All rights reserved."
Expand All @@ -16,13 +17,13 @@
__version__ = "0.7.0"


class TestProcess(AiidaTestCase):
class TestProcessControl(AiidaTestCase):
"""
Test AiiDA's RabbitMQ functionalities.
"""

def setUp(self):
super(TestProcess, self).setUp()
super(TestProcessControl, self).setUp()
prefix = "{}.{}".format(self.__class__.__name__, uuid.uuid4())

self.loop = plum.new_event_loop()
Expand Down Expand Up @@ -104,16 +105,6 @@ def test_kill(self):
# TODO: Check kill message
self.assertTrue(result)

# def test_launch_and_get_status(self):
# a = base.Int(5)
# b = base.Int(10)
#
# calc_node = self.runner.submit(test_utils.AddProcess, a=a, b=b)
# self._wait_for_calc(calc_node)
# future = self.runner.rmq.request_status(calc_node.pk)
# result = plum.run_until_complete(future, self.loop)
# self.assertIsNotNone(result)

def _wait_for_calc(self, calc_node, timeout=5.):
def stop(*args):
self.loop.stop()
Expand Down
94 changes: 62 additions & 32 deletions aiida/cmdline/commands/work.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
import click
from functools import partial
import logging
from functools import partial

import click
from tabulate import tabulate

from aiida.cmdline.commands import work, verdi
from aiida.cmdline.baseclass import VerdiCommandWithSubcommands
from aiida.cmdline.commands import work, verdi
from aiida.utils.ascii_vis import print_tree_descending

CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])
LIST_CMDLINE_PROJECT_CHOICES = ['id', 'ctime', 'label', 'uuid', 'descr', 'mtime', 'state', 'sealed']
Expand Down Expand Up @@ -179,7 +181,6 @@ def report(pk, levelname, order_by, indent_size, max_depth):

import itertools
from aiida.orm.backend import construct
from aiida.orm.log import OrderSpecifier, ASCENDING, DESCENDING
from aiida.orm.querybuilder import QueryBuilder
from aiida.orm.calculation.work import WorkCalculation

Expand Down Expand Up @@ -333,69 +334,98 @@ def kill_old(pks):
def kill(pks):
    """Kill the processes with the given pks."""
    from aiida import try_load_dbenv
    try_load_dbenv()
    # Shadows the module level ``work`` command group inside this function.
    from aiida import work

    control_panel = work.new_blocking_control_panel()

    for pk in pks:
        try:
            # A truthy result is reported as success, a falsy one as a problem.
            if control_panel.kill_process(pk):
                click.echo("Killed '{}'".format(pk))
            else:
                click.echo("Problem killing '{}'".format(pk))
        except (work.RemoteException, work.DeliveryFailed) as e:
            # Format the exception itself: ``e.message`` is not a reliable
            # attribute of exception objects.
            click.echo("Failed to kill '{}': {}".format(pk, e))


@work.command('pause', context_settings=CONTEXT_SETTINGS)
@click.argument('pks', nargs=-1, type=int)
def pause(pks):
    """Pause the processes with the given pks."""
    from aiida import try_load_dbenv
    try_load_dbenv()
    # Shadows the module level ``work`` command group inside this function.
    from aiida import work

    control_panel = work.new_blocking_control_panel()

    for pk in pks:
        try:
            # A truthy result is reported as success, a falsy one as a problem.
            if control_panel.pause_process(pk):
                click.echo("Paused '{}'".format(pk))
            else:
                click.echo("Problem pausing '{}'".format(pk))
        except (work.RemoteException, work.DeliveryFailed) as e:
            # Format the exception itself: ``e.message`` is not a reliable
            # attribute of exception objects.
            click.echo("Failed to pause '{}': {}".format(pk, e))


@work.command('play', context_settings=CONTEXT_SETTINGS)
@click.argument('pks', nargs=-1, type=int)
def play(pks):
    """Play (resume) the processes with the given pks."""
    from aiida import try_load_dbenv
    try_load_dbenv()
    # Shadows the module level ``work`` command group inside this function.
    from aiida import work

    control_panel = work.new_blocking_control_panel()

    for pk in pks:
        try:
            # A truthy result is reported as success, a falsy one as a problem.
            if control_panel.play_process(pk):
                click.echo("Played '{}'".format(pk))
            else:
                click.echo("Problem playing '{}'".format(pk))
        except (work.RemoteException, work.DeliveryFailed) as e:
            # Format the exception itself: ``e.message`` is not a reliable
            # attribute of exception objects.
            click.echo("Failed to play '{}': {}".format(pk, e))

@work.command('status', context_settings=CONTEXT_SETTINGS)
@click.argument('pks', nargs=-1, type=int)
def status(pks):
    # Prints the call graph of every calculation node identified by ``pks``.
    # Documented with comments rather than a docstring so the generated
    # command help text stays as it was.
    from aiida import try_load_dbenv
    try_load_dbenv()
    from aiida.orm import load_node
    from aiida.utils.ascii_vis import print_call_graph

    for node_pk in pks:
        print_call_graph(load_node(node_pk))

def _action_done(intent, pk, future):
if future.exception() is not None:
click.echo("Failed to {} process {}: {}".format(intent, pk, future.exception()))

def _create_status_info(calc_node):
    """
    Build a nested status description for ``calc_node``.

    :param calc_node: node whose ``called`` attribute lists its children
    :return: the status line alone when the node called nothing, otherwise
        a tuple ``(status_line, [status info of each called node])``
    """
    line = _format_status_line(calc_node)
    children = calc_node.called
    if not children:
        return line
    return line, [_create_status_info(child) for child in children]


def _format_status_line(calc_node):
    """
    Format a one line ``label <pk=...> [state]`` summary of a calculation.

    For a ``WorkCalculation`` the label and state are read from the
    ``_process_label`` and ``process_state`` attributes. For a
    ``JobCalculation`` the label is the class name and the state is the
    string form of ``get_state()``.

    :param calc_node: a WorkCalculation or JobCalculation node
    :return: the formatted status line
    :raises TypeError: if the node is of neither type
    """
    from aiida.orm.calculation.work import WorkCalculation
    from aiida.orm.calculation.job import JobCalculation

    if isinstance(calc_node, WorkCalculation):
        label = calc_node.get_attr('_process_label')
        state = calc_node.get_attr('process_state')
    elif isinstance(calc_node, JobCalculation):
        label = type(calc_node).__name__
        state = str(calc_node.get_state())
    else:
        # Only the TypeError belongs here: the stray "OK" echo referenced
        # names that do not exist in this function.
        raise TypeError("Unknown type")
    return "{} <pk={}> [{}]".format(label, calc_node.pk, state)


def _build_query(projections=None, order_by=None, limit=None, past_days=None):
import datetime
from aiida.utils import timezone
from aiida.orm.mixins import Sealable
from aiida.orm.querybuilder import QueryBuilder
from aiida.orm.calculation.work import WorkCalculation

Expand Down
Loading

0 comments on commit 17278df

Please sign in to comment.