From 37a9d23bcee8d24ddf1071b29fd5987ce8edd55e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tiziano=20M=C3=BCller?= Date: Wed, 28 Jul 2021 16:36:14 +0200 Subject: [PATCH] `DirectScheduler`: use `num_cores_per_mpiproc` if defined in resources If `num_cores_per_mpiproc` is specified in the job resources, the value will now be exported as the `OMP_NUM_THREADS` variable. Co-Authored-By: Sebastiaan Huber --- aiida/schedulers/plugins/direct.py | 18 ++++- tests/schedulers/test_direct.py | 126 +++++++++++------------------ 2 files changed, 63 insertions(+), 81 deletions(-) diff --git a/aiida/schedulers/plugins/direct.py b/aiida/schedulers/plugins/direct.py index 2f9ece8e38..09d1dd7155 100644 --- a/aiida/schedulers/plugins/direct.py +++ b/aiida/schedulers/plugins/direct.py @@ -150,19 +150,29 @@ def _get_submit_script_header(self, job_tmpl): if job_tmpl.custom_scheduler_commands: lines.append(job_tmpl.custom_scheduler_commands) + env_lines = [] + + if job_tmpl.job_resource and job_tmpl.job_resource.num_cores_per_mpiproc: + # since this was introduced after the environment injection below, + # it is intentionally put before it to avoid breaking current users script by overruling + # any explicit OMP_NUM_THREADS they may have set in their job_environment + env_lines.append(f'export OMP_NUM_THREADS={job_tmpl.job_resource.num_cores_per_mpiproc}') + # Job environment variables are to be set on one single line. # This is a tough job due to the escaping of commas, etc. # moreover, I am having issues making it work. # Therefore, I assume that this is bash and export variables by # and. - if job_tmpl.job_environment: - lines.append(empty_line) - lines.append('# ENVIRONMENT VARIABLES BEGIN ###') if not isinstance(job_tmpl.job_environment, dict): raise ValueError('If you provide job_environment, it must be a dictionary') for key, value in job_tmpl.job_environment.items(): - lines.append(f'export {key.strip()}={escape_for_bash(value)}') + env_lines.append(f'export {key.strip()}={escape_for_bash(value)}') + + if env_lines: + lines.append(empty_line) + lines.append('# ENVIRONMENT VARIABLES BEGIN ###') + lines += env_lines lines.append('# ENVIRONMENT VARIABLES END ###') lines.append(empty_line) diff --git a/tests/schedulers/test_direct.py b/tests/schedulers/test_direct.py index 1d9520bd50..f982a3cd5a 100644 --- a/tests/schedulers/test_direct.py +++ b/tests/schedulers/test_direct.py @@ -7,96 +7,68 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=invalid-name,protected-access -"""Tests for the `DirectScheduler` plugin.""" -import unittest +# pylint: disable=redefined-outer-name +"""Tests for the ``DirectScheduler`` plugin.""" +import pytest -from aiida.schedulers.plugins.direct import DirectScheduler +from aiida.common.datastructures import CodeInfo, CodeRunMode from aiida.schedulers import SchedulerError +from aiida.schedulers.datastructures import JobTemplate +from aiida.schedulers.plugins.direct import DirectScheduler -# This was executed with ps -o pid,stat,user,time | tail -n +2 -mac_ps_output_str = """21259 S+ broeder 0:00.04 -87619 S+ broeder 0:00.44 -87634 S+ broeder 0:00.01 -87649 S+ broeder 0:00.02 -87664 S+ broeder 0:00.01 -87679 S+ broeder 0:00.01 -87694 S+ broeder 0:00.01 -87711 S+ broeder 0:00.01 -87726 S+ broeder 0:00.02 -87741 S+ broeder 0:00.01 -87756 S+ broeder 0:00.01 -87771 S+ broeder 0:00.01 -87787 S+ broeder 0:00.02 -87803 S+ broeder 0:00.01 -87818 S+ broeder 0:00.02 -87834 S+ broeder 0:00.02 -87849 S+ broeder 0:00.11 -87865 S+ broeder 0:00.02 -87880 S+ broeder 0:00.02 -15967 S+ broeder 0:00.05 -87910 S+ broeder 0:00.02 -87925 S+ broeder 0:00.02 -16814 S broeder 0:00.02 -24516 S+ broeder 0:00.06 -""" -linux_ps_output_str = """11354 Ss aiida 00:00:00 -11383 R+ aiida 00:00:00 -11384 S+ aiida 00:00:00 -""" - -wrong_output = """aaa""" - - -class TestParserGetJobList(unittest.TestCase): - """ - Tests to verify if teh function _parse_joblist_output behave correctly - The tests is done parsing a string defined above, to be used offline - """ - - def test_parse_mac_wrong(self): - """ - Test whether _parse_joblist can parse the qstat -f output - """ - scheduler = DirectScheduler() - - with self.assertRaises(SchedulerError): - scheduler._parse_joblist_output(retval=0, stdout=wrong_output, stderr='') - def test_parse_mac_joblist_output(self): - """ - Test whether _parse_joblist can parse the qstat -f output - """ - s = DirectScheduler() +@pytest.fixture +def scheduler(): + """Return an instance of the ``DirectScheduler``.""" + return DirectScheduler() - result = s._parse_joblist_output(retval=0, stdout=mac_ps_output_str, stderr='') - self.assertEqual(len(result), 24) - job_ids = [job.job_id for job in result] - self.assertIn('87849', job_ids) +@pytest.fixture +def template(): + """Return an instance of the ``JobTemplate`` with some required presets.""" + code_info = CodeInfo() + code_info.cmdline_params = [] - def test_parse_linux_joblist_output(self): - """ - Test whether _parse_joblist can parse the qstat -f output - """ - scheduler = DirectScheduler() + template = JobTemplate() + template.codes_info = [code_info] + template.codes_run_mode = CodeRunMode.SERIAL - result = scheduler._parse_joblist_output(retval=0, stdout=linux_ps_output_str, stderr='') - self.assertEqual(len(result), 3) + return template - job_ids = [job.job_id for job in result] - self.assertIn('11383', job_ids) +@pytest.mark.parametrize( + 'stdout', + ( + """21259 S+ broeder 0:00.04\n87619 S+ broeder 0:00.44\n87634 S+ broeder 0:00.01""", # MacOS + """11354 Ss aiida 00:00:00\n\n87619 R+ aiida 00:00:00\n11384 S+ aiida 00:00:00""", # Linux + ) +) +def test_parse_joblist_output(scheduler, stdout): + """Test the ``_parse_joblist_output`` for output taken from MacOS and Linux.""" + result = scheduler._parse_joblist_output(retval=0, stdout=stdout, stderr='') # pylint: disable=protected-access + assert len(result) == 3 + assert '87619' in [job.job_id for job in result] -def test_submit_script_rerunnable(aiida_caplog): - """Test that setting the `rerunnable` option gives a warning.""" - from aiida.schedulers.datastructures import JobTemplate - direct = DirectScheduler() - job_tmpl = JobTemplate() +def test_parse_joblist_output_incorrect(scheduler): + """Test the ``_parse_joblist_output`` for invalid output.""" + with pytest.raises(SchedulerError): + scheduler._parse_joblist_output(retval=0, stdout='aaa', stderr='') # pylint: disable=protected-access - job_tmpl.rerunnable = True - direct._get_submit_script_header(job_tmpl) +def test_submit_script_rerunnable(scheduler, template, aiida_caplog): + """Test that setting the ``rerunnable`` option gives a warning.""" + template.rerunnable = True + scheduler.get_submit_script(template) assert 'rerunnable' in aiida_caplog.text assert 'has no effect' in aiida_caplog.text + + +def test_submit_script_with_num_cores_per_mpiproc(scheduler, template): + """Test that passing ``num_cores_per_mpiproc`` in job resources results in ``OMP_NUM_THREADS`` being set.""" + num_cores_per_mpiproc = 24 + template.job_resource = scheduler.create_job_resource( + num_machines=1, tot_num_mpiprocs=1, num_cores_per_mpiproc=num_cores_per_mpiproc + ) + result = scheduler.get_submit_script(template) + assert f'export OMP_NUM_THREADS={num_cores_per_mpiproc}' in result