Skip to content

Commit

Permalink
Implemented shell_task and basic unit tests. Generated tasks do not work yet, as the best way to map them onto input_spec still needs to be determined.
Browse files Browse the repository at this point in the history
  • Loading branch information
tclose committed May 15, 2023
1 parent 83c06e3 commit 0357fc8
Show file tree
Hide file tree
Showing 2 changed files with 227 additions and 76 deletions.
81 changes: 57 additions & 24 deletions pydra/mark/shell_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,45 @@ def shell_task(
"input_field arguments"
)
name = klass_or_name

if output_fields is None:
output_fields = {}
if bases is None:
bases = [pydra.engine.task.ShellCommandTask]
if input_bases is None:
input_bases = [pydra.engine.specs.ShellSpec]
if output_bases is None:
output_bases = [pydra.engine.specs.ShellOutSpec]
Inputs = type("Inputs", tuple(input_bases), input_fields)
Outputs = type("Outputs", tuple(output_bases), output_fields)

# Ensure bases are lists and can be modified
bases = list(bases) if bases is not None else []
input_bases = list(input_bases) if input_bases is not None else []
output_bases = list(output_bases) if output_bases is not None else []

# Ensure base classes included somewhere in MRO
def ensure_base_of(base_class: type, bases_list: list[type]):
    """Make sure `base_class` is reachable from at least one entry of `bases_list`.

    The list is modified in place: when none of its entries is a subclass of
    (or identical to) `base_class`, `base_class` is appended to the end.

    Parameters
    ----------
    base_class : type
        the class that must appear somewhere in the resulting MRO
    bases_list : list[type]
        mutable list of base classes; may gain `base_class` as its last item
    """
    for candidate in bases_list:
        if issubclass(candidate, base_class):
            # Already covered by an existing base, nothing to add
            return
    bases_list.append(base_class)

ensure_base_of(pydra.engine.task.ShellCommandTask, bases)
ensure_base_of(pydra.engine.specs.ShellSpec, input_bases)
ensure_base_of(pydra.engine.specs.ShellOutSpec, output_bases)

def convert_to_attrs(fields: dict[str, dict[str, ty.Any]], attrs_func):
    """Build a class-body dictionary from a mapping of field specifications.

    Each field spec must contain a "type" entry, which becomes the field's
    annotation. All remaining entries are passed as keyword arguments to
    `attrs_func`, and the value it returns becomes the class attribute.

    Parameters
    ----------
    fields : dict[str, dict[str, Any]]
        field name -> specification dict (left unmodified)
    attrs_func : callable
        factory called as ``attrs_func(**spec_without_type)`` for each field

    Returns
    -------
    dict[str, Any]
        a dict with an "__annotations__" entry plus one entry per field,
        suitable as the third argument of ``type(name, bases, dct)``

    Raises
    ------
    KeyError
        if a field specification has no "type" entry; the message names the
        offending field
    """
    annotations = {}
    attrs_dict = {"__annotations__": annotations}
    for field_name, spec in fields.items():
        # Work on a copy so that popping "type" doesn't alter the caller's dict
        kwargs = dict(spec)
        try:
            annotations[field_name] = kwargs.pop("type")
        except KeyError:
            raise KeyError(
                f"Field '{field_name}' is missing the required 'type' entry"
            ) from None
        attrs_dict[field_name] = attrs_func(**kwargs)
    return attrs_dict

Inputs = attrs.define(kw_only=True, slots=False)(
type(
"Inputs", tuple(input_bases), convert_to_attrs(input_fields, shell_arg)
)
)
Outputs = attrs.define(kw_only=True, slots=False)(
type(
"Outputs",
tuple(output_bases),
convert_to_attrs(output_fields, shell_out),
)
)
else:
if (
executable,
Expand Down Expand Up @@ -96,39 +125,43 @@ def shell_task(
"Classes decorated by `shell_task` should contain an `Inputs` class attribute "
"specifying the inputs to the shell tool"
)
if not issubclass(Inputs, pydra.engine.specs.ShellSpec):
Inputs = type("Inputs", (Inputs, pydra.engine.specs.ShellSpec), {})

try:
Outputs = klass.Outputs
except KeyError:
Outputs = type("Outputs", (pydra.engine.specs.ShellOutSpec,))

Inputs = attrs.define(kw_only=True, slots=False)(Inputs)
Outputs = attrs.define(kw_only=True, slots=False)(Outputs)

if not issubclass(Inputs, pydra.engine.specs.ShellSpec):
Inputs = attrs.define(kw_only=True, slots=False)(
type("Inputs", (Inputs, pydra.engine.specs.ShellSpec), {})
)

if not issubclass(Outputs, pydra.engine.specs.ShellOutSpec):
Outputs = attrs.define(kw_only=True, slots=False)(
type("Outputs", (Outputs, pydra.engine.specs.ShellOutSpec), {})
)

bases = [klass]
if not issubclass(klass, pydra.engine.task.ShellCommandTask):
bases.append(pydra.engine.task.ShellCommandTask)

Inputs = attrs.define(kw_only=True, slots=False)(Inputs)
Outputs = attrs.define(kw_only=True, slots=False)(Outputs)

dct = {
"executable": executable,
"Inputs": Outputs,
"Outputs": Inputs,
"inputs": attrs.field(factory=Inputs),
"outputs": attrs.field(factory=Outputs),
"Inputs": Inputs,
"Outputs": Outputs,
"__annotations__": {
"executable": str,
"inputs": Inputs,
"outputs": Outputs,
"Inputs": type,
"Outputs": type,
},
}

return attrs.define(kw_only=True, slots=False)(
type(
name,
tuple(bases),
dct,
)
)
return type(name, tuple(bases), dct)


def shell_arg(
Expand Down
222 changes: 170 additions & 52 deletions pydra/mark/tests/test_shell_commands.py
Original file line number Diff line number Diff line change
@@ -1,69 +1,187 @@
import os
import tempfile
from pathlib import Path
import attrs
import pydra.engine
from pathlib import Path
import pytest
import cloudpickle as cp
from pydra.mark import shell_task, shell_arg, shell_out


def test_shell_task_full():
@attrs.define(kw_only=True, slots=False)
class LsInputSpec(pydra.specs.ShellSpec):
directory: os.PathLike = shell_arg(
help_string="the directory to list the contents of",
argstr="",
mandatory=True,
)
hidden: bool = shell_arg(help_string=("display hidden FS objects"), argstr="-a")
long_format: bool = shell_arg(
help_string=(
"display properties of FS object, such as permissions, size and timestamps "
),
argstr="-l",
)
human_readable: bool = shell_arg(
help_string="display file sizes in human readable form",
argstr="-h",
requires=["long_format"],
)
complete_date: bool = shell_arg(
help_string="Show complete date in long format",
argstr="-T",
requires=["long_format"],
xor=["date_format_str"],
)
date_format_str: str = shell_arg(
help_string="format string for ",
argstr="-D",
requires=["long_format"],
xor=["complete_date"],
)
def list_entries(stdout):
    """Split `stdout` on newlines and return every piece except the last one."""
    pieces = stdout.split("\n")
    # str.split always yields at least one piece, so there is always one to drop
    pieces.pop()
    return pieces

def list_outputs(stdout):
return stdout.split("\n")[:-1]

@attrs.define(kw_only=True, slots=False)
class LsOutputSpec(pydra.specs.ShellOutSpec):
entries: list = shell_out(
help_string="list of entries returned by ls command", callable=list_outputs
)
@pytest.fixture
def tmpdir():
    """Yield a fresh temporary directory as a `Path`, deleted after the test.

    The directory and its contents are removed once the requesting test
    finishes, so repeated test runs don't accumulate stray directories.
    """
    with tempfile.TemporaryDirectory() as dirname:
        yield Path(dirname)

class Ls(pydra.engine.ShellCommandTask):
"""Task definition for the `ls` command line tool"""

executable = "ls"
@pytest.fixture(params=["static", "dynamic"])
def Ls(request):
if request.param == "static":

input_spec = pydra.specs.SpecInfo(
name="LsInput",
bases=(LsInputSpec,),
)
@shell_task
class Ls:
executable = "ls"

output_spec = pydra.specs.SpecInfo(
name="LsOutput",
bases=(LsOutputSpec,),
class Inputs:
directory: os.PathLike = shell_arg(
help_string="the directory to list the contents of",
argstr="",
mandatory=True,
)
hidden: bool = shell_arg(
help_string=("display hidden FS objects"),
argstr="-a",
default=False,
)
long_format: bool = shell_arg(
help_string=(
"display properties of FS object, such as permissions, size and "
"timestamps "
),
default=False,
argstr="-l",
)
human_readable: bool = shell_arg(
help_string="display file sizes in human readable form",
argstr="-h",
default=False,
requires=["long_format"],
)
complete_date: bool = shell_arg(
help_string="Show complete date in long format",
argstr="-T",
default=False,
requires=["long_format"],
xor=["date_format_str"],
)
date_format_str: str = shell_arg(
help_string="format string for ",
argstr="-D",
default=None,
requires=["long_format"],
xor=["complete_date"],
)

class Outputs:
entries: list = shell_out(
help_string="list of entries returned by ls command",
callable=list_entries,
)

elif request.param == "dynamic":
Ls = shell_task(
"Ls",
executable="ls",
input_fields={
"directory": {
"type": os.PathLike,
"help_string": "the directory to list the contents of",
"argstr": "",
"mandatory": True,
},
"hidden": {
"type": bool,
"help_string": "display hidden FS objects",
"argstr": "-a",
},
"long_format": {
"type": bool,
"help_string": (
"display properties of FS object, such as permissions, size and "
"timestamps "
),
"argstr": "-l",
},
"human_readable": {
"type": bool,
"help_string": "display file sizes in human readable form",
"argstr": "-h",
"requires": ["long_format"],
},
"complete_date": {
"type": bool,
"help_string": "Show complete date in long format",
"argstr": "-T",
"requires": ["long_format"],
"xor": ["date_format_str"],
},
"date_format_str": {
"type": str,
"help_string": "format string for ",
"argstr": "-D",
"requires": ["long_format"],
"xor": ["complete_date"],
},
},
output_fields={
"entries": {
"type": list,
"help_string": "list of entries returned by ls command",
"callable": list_entries,
}
},
)

tmpdir = Path(tempfile.mkdtemp())
else:
assert False

return Ls


def test_shell_task_fields(Ls):
    """Check the names and order of the attrs fields on `Ls.Inputs` and `Ls.Outputs`."""
    input_names = [field.name for field in attrs.fields(Ls.Inputs)]
    expected_inputs = [
        "executable",
        "args",
        "directory",
        "hidden",
        "long_format",
        "human_readable",
        "complete_date",
        "date_format_str",
    ]
    assert input_names == expected_inputs

    output_names = [field.name for field in attrs.fields(Ls.Outputs)]
    expected_outputs = [
        "return_code",
        "stdout",
        "stderr",
        "entries",
    ]
    assert output_names == expected_outputs


def test_shell_task_pickle_roundtrip(Ls, tmpdir):
    """Dump `Ls` to a file with cloudpickle, load it back and compare by identity."""
    pkl_file = tmpdir / "ls.pkl"

    with open(pkl_file, "wb") as out_stream:
        cp.dump(Ls, out_stream)

    with open(pkl_file, "rb") as in_stream:
        reloaded = cp.load(in_stream)

    assert reloaded is Ls


@pytest.mark.xfail(
    reason=(
        "Need to change relationship between Inputs/Outputs and input_spec/output_spec "
        "for the task to run"
    )
)
def test_shell_task_init(Ls, tmpdir):
    """Instantiate `Ls.Inputs` and `Ls.Outputs` and check the stored values."""
    ls_inputs = Ls.Inputs(directory=tmpdir)
    assert ls_inputs.directory == tmpdir
    assert not ls_inputs.hidden

    ls_outputs = Ls.Outputs(entries=[])
    assert ls_outputs.entries == []


@pytest.mark.xfail(
reason=(
"Need to change relationship between Inputs/Outputs and input_spec/output_spec "
"for the task to run"
)
)
def test_shell_task_run(Ls, tmpdir):
Path.touch(tmpdir / "a")
Path.touch(tmpdir / "b")
Path.touch(tmpdir / "c")
Expand Down

0 comments on commit 0357fc8

Please sign in to comment.