Skip to content

Commit

Permalink
apply lint to latest updates
Browse files Browse the repository at this point in the history
  • Loading branch information
gpetretto committed Jan 20, 2025
1 parent aeebffe commit cf6f479
Show file tree
Hide file tree
Showing 9 changed files with 52 additions and 55 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ docstring-code-format = true

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401"]
"**/tests/*" = ["INP001", "S101"]
"**/tests/*" = ["INP001", "S101", "SLF001"]

[tool.mypy]
ignore_missing_imports = true
Expand Down
9 changes: 3 additions & 6 deletions src/qtoolkit/io/pbs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from __future__ import annotations

import re

from datetime import timedelta
from typing import ClassVar

from qtoolkit.core.data_objects import QJob, QJobInfo, QState, QSubState
Expand Down Expand Up @@ -93,8 +91,8 @@ class PBSIO(PBSIOBase):
CANCEL_CMD: str | None = "qdel"
system_name: str = "PBS"
default_unit: str = "mb"
power_labels: dict = {"kb": 0, "mb": 1, "gb": 2, "tb": 3}
_qresources_mapping: dict = {
power_labels: ClassVar[dict] = {"kb": 0, "mb": 1, "gb": 2, "tb": 3}
_qresources_mapping: ClassVar[dict] = {
"queue_name": "queue",
"job_name": "job_name",
"account": "account",
Expand Down Expand Up @@ -138,8 +136,7 @@ def _get_qstat_base_command(self) -> list[str]:
return ["qstat", "-f", "-w"]

def _get_job_cmd(self, job_id: str):
cmd = f"{' '.join(self._get_qstat_base_command())} {job_id}"
return cmd
return f"{' '.join(self._get_qstat_base_command())} {job_id}"

def _get_job_ids_flag(self, job_ids_str: str) -> str:
return job_ids_str
Expand Down
14 changes: 7 additions & 7 deletions src/qtoolkit/io/pbs_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import re
from abc import ABC
from datetime import timedelta
from typing import ClassVar

from qtoolkit.core.data_objects import (
CancelResult,
Expand All @@ -24,10 +25,10 @@ class PBSIOBase(BaseSchedulerIO, ABC):

SUBMIT_CMD: str | None = "qsub"
CANCEL_CMD: str | None = "qdel"
_qresources_mapping: dict
_qresources_mapping: ClassVar[dict]
system_name: str
default_unit: str
power_labels: dict
power_labels: ClassVar[dict]

def parse_submit_output(self, exit_code, stdout, stderr) -> SubmissionResult:
if isinstance(stdout, bytes):
Expand Down Expand Up @@ -128,21 +129,20 @@ def _convert_memory_str(self, memory: str | None) -> int | None:

try:
v = int(memory)
except ValueError:
raise OutputParsingError
except ValueError as exc:
raise OutputParsingError from exc

return v * (1024 ** power_labels[units.lower()])

@staticmethod
def _convert_time_to_str(time: int | float | timedelta) -> str:
def _convert_time_to_str(time: int | float | timedelta) -> str: # noqa: PYI041
if not isinstance(time, timedelta):
time = timedelta(seconds=time)

hours, remainder = divmod(int(time.total_seconds()), 3600)
minutes, seconds = divmod(remainder, 60)

time_str = f"{hours}:{minutes}:{seconds}"
return time_str
return f"{hours}:{minutes}:{seconds}"

def _convert_qresources(self, resources: QResources) -> dict:
header_dict = {}
Expand Down
55 changes: 29 additions & 26 deletions src/qtoolkit/io/sge.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import re
import xml.dom.minidom
import xml.parsers.expat
from typing import ClassVar

from qtoolkit.core.data_objects import QJob, QJobInfo, QResources, QState, QSubState
from qtoolkit.core.exceptions import CommandFailedError, OutputParsingError
Expand Down Expand Up @@ -127,8 +128,8 @@ class SGEIO(PBSIOBase):
CANCEL_CMD: str | None = "qdel"
system_name: str = "SGE"
default_unit: str = "M"
power_labels: dict = {"k": 0, "m": 1, "g": 2, "t": 3}
_qresources_mapping: dict = {
power_labels: ClassVar[dict] = {"k": 0, "m": 1, "g": 2, "t": 3}
_qresources_mapping: ClassVar[dict] = {
"queue_name": "queue",
"job_name": "job_name",
"priority": "priority",
Expand Down Expand Up @@ -167,6 +168,8 @@ def _get_jobs_list_cmd(
return " ".join(command)

def parse_job_output(self, exit_code, stdout, stderr) -> QJob | None: # aiida style
# TODO at the moment the command for a single job is not available
# check if this should be removed as well.
if exit_code != 0:
msg = f"command {self.get_job_executable or 'qacct'} failed: {stderr}"
raise CommandFailedError(msg)
Expand Down Expand Up @@ -203,19 +206,19 @@ def parse_job_output(self, exit_code, stdout, stderr) -> QJob | None: # aiida s

# Check if stdout is in XML format
try:
xmldata = xml.dom.minidom.parseString(stdout)
xmldata = xml.dom.minidom.parseString(stdout) # noqa: S318
job_info = xmldata.getElementsByTagName("job_list")[0]
job_id = job_info.getElementsByTagName("JB_job_number")[
0
].firstChild.nodeValue
job_name = job_info.getElementsByTagName("JB_name")[0].firstChild.nodeValue
owner = job_info.getElementsByTagName("JB_owner")[0].firstChild.nodeValue
state = job_info.getElementsByTagName("state")[0].firstChild.nodeValue
].firstChild.nodeValue # type: ignore
job_name = job_info.getElementsByTagName("JB_name")[0].firstChild.nodeValue # type: ignore
owner = job_info.getElementsByTagName("JB_owner")[0].firstChild.nodeValue # type: ignore
state = job_info.getElementsByTagName("state")[0].firstChild.nodeValue # type: ignore
queue_name = job_info.getElementsByTagName("queue_name")[
0
].firstChild.nodeValue
slots = job_info.getElementsByTagName("slots")[0].firstChild.nodeValue
tasks = job_info.getElementsByTagName("tasks")[0].firstChild.nodeValue
].firstChild.nodeValue # type: ignore
slots = job_info.getElementsByTagName("slots")[0].firstChild.nodeValue # type: ignore
tasks = job_info.getElementsByTagName("tasks")[0].firstChild.nodeValue # type: ignore

sge_state = SGEState(state)
job_state = sge_state.qstate
Expand All @@ -242,32 +245,32 @@ def parse_job_output(self, exit_code, stdout, stderr) -> QJob | None: # aiida s
)
except Exception:
# Not XML, fallback to plain text
job_info = {}
job_info_dict: dict = {}
for line in stdout.split("\n"):
if ":" in line:
key, value = line.split(":", 1)
job_info[key.strip()] = value.strip()
job_info_dict[key.strip()] = value.strip()

try:
cpus = int(job_info.get("slots", 1))
nodes = int(job_info.get("tasks", 1))
cpus = int(job_info_dict.get("slots", 1))
nodes = int(job_info_dict.get("tasks", 1))
threads_per_process = int(cpus / nodes)
except ValueError:
cpus = None
nodes = None
threads_per_process = None

state_str = job_info.get("state")
state_str = job_info_dict.get("state")
sge_state = SGEState(state_str) if state_str else None
job_state = sge_state.qstate

return QJob(
name=job_info.get("job_name"),
job_id=job_info.get("job_id"),
name=job_info_dict.get("job_name"),
job_id=job_info_dict.get("job_id"),
state=job_state,
sub_state=sge_state,
account=job_info.get("owner"),
queue_name=job_info.get("queue_name"),
account=job_info_dict.get("owner"),
queue_name=job_info_dict.get("queue_name"),
info=QJobInfo(
nodes=nodes, cpus=cpus, threads_per_process=threads_per_process
),
Expand Down Expand Up @@ -306,9 +309,9 @@ def parse_jobs_list_output(self, exit_code, stdout, stderr) -> list[QJob]:
stderr = stderr.decode()

try:
xmldata = xml.dom.minidom.parseString(stdout)
except xml.parsers.expat.ExpatError:
raise OutputParsingError("XML parsing of stdout failed")
xmldata = xml.dom.minidom.parseString(stdout) # noqa: S318
except xml.parsers.expat.ExpatError as exc:
raise OutputParsingError("XML parsing of stdout failed") from exc

# Ensure <job_info> elements exist
# (==> xml file created via -u option,
Expand All @@ -331,10 +334,10 @@ def parse_jobs_list_output(self, exit_code, stdout, stderr) -> list[QJob]:

try:
sge_job_state = SGEState(job_state_string)
except ValueError:
except ValueError as exc:
raise OutputParsingError(
f"Unknown job state {job_state_string} for job id {qjob.job_id}"
)
) from exc

qjob.sub_state = sge_job_state
qjob.state = sge_job_state.qstate
Expand Down Expand Up @@ -378,8 +381,8 @@ def _convert_str_to_time(time_str: str | None) -> int | None:

try:
return int(hours) * 3600 + int(minutes) * 60 + int(seconds)
except ValueError:
raise OutputParsingError(f"Invalid time format: {time_str}")
except ValueError as exc:
raise OutputParsingError(f"Invalid time format: {time_str}") from exc

def _add_soft_walltime(self, header_dict: dict, resources: QResources):
header_dict["soft_walltime"] = self._convert_time_to_str(
Expand Down
4 changes: 2 additions & 2 deletions src/qtoolkit/io/shell.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, ClassVar

from qtoolkit.core.data_objects import (
CancelResult,
Expand Down Expand Up @@ -248,7 +248,7 @@ def parse_jobs_list_output(self, exit_code, stdout, stderr) -> list[QJob]:

# helper attribute to match the values defined in QResources and
# the dictionary that should be passed to the template
_qresources_mapping = {
_qresources_mapping: ClassVar = {
"job_name": "job_name",
"output_filepath": "qout_path",
"error_filepath": "qerr_path",
Expand Down
1 change: 1 addition & 0 deletions src/qtoolkit/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def get_environment_setup(self, env_config) -> str:
if env_config:
env_setup = []
if "modules" in env_config:
env_setup.append("module purge")
env_setup += [f"module load {mod}" for mod in env_config["modules"]]
if "source_files" in env_config:
env_setup += [
Expand Down
10 changes: 5 additions & 5 deletions tests/io/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,17 +158,17 @@ def test_generate_header(self, scheduler):
):
scheduler.generate_header(res)

res = QResources(
nodes=4,
processes_per_node=16,
scheduler_kwargs={"option32": "xxx", "processes-per-node": "yyy"},
)
with pytest.raises(
ValueError,
match=r"The following keys are not present in the template: option32, processes-per-node. "
r"Check the template in .*MyScheduler.header_template.*'option3' or 'option2' or 'option1' "
r"instead of 'option32'. 'processes_per_node' or 'processes' instead of 'processes-per-node'",
):
res = QResources(
nodes=4,
processes_per_node=16,
scheduler_kwargs={"option32": "xxx", "processes-per-node": "yyy"},
)
scheduler.generate_header(res)

def test_generate_ids_list(self, scheduler):
Expand Down
6 changes: 2 additions & 4 deletions tests/io/test_pbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def maximalist_qresources_pbs():


class TestPBSState:
@pytest.mark.parametrize("sge_state", [s for s in PBSState])
@pytest.mark.parametrize("sge_state", list(PBSState))
def test_qstate(self, sge_state):
assert isinstance(sge_state.qstate, QState)

Expand Down Expand Up @@ -295,9 +295,7 @@ def test_submission_script(self, pbs_io, maximalist_qresources_pbs):
#PBS -o test_output_filepath
#PBS -e test_error_filepath
#PBS -p 1
ls -l""".split(
"\n"
)
ls -l""".split("\n")
)

def test_sanitize_options(self, pbs_io):
Expand Down
6 changes: 2 additions & 4 deletions tests/io/test_sge.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def sge_io():


class TestSGEState:
@pytest.mark.parametrize("sge_state", [s for s in SGEState])
@pytest.mark.parametrize("sge_state", list(SGEState))
def test_qstate(self, sge_state):
assert isinstance(sge_state.qstate, QState)
assert SGEState("hqw") == SGEState.HOLD
Expand Down Expand Up @@ -281,9 +281,7 @@ def test_submission_script(self, sge_io, maximalist_qresources):
#$ -o test_output_filepath
#$ -e test_error_filepath
#$ -p 1
ls -l""".split(
"\n"
)
ls -l""".split("\n")
)

def test_sanitize_options(self, sge_io):
Expand Down

0 comments on commit cf6f479

Please sign in to comment.