Skip to content

Commit

Permalink
Merge pull request #38 from esm-tools/feat/custom_funcs
Browse files Browse the repository at this point in the history
Script Steps
  • Loading branch information
pgierz authored Oct 8, 2024
2 parents 64b868c + 67c3281 commit fda5345
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 6 deletions.
11 changes: 6 additions & 5 deletions src/pymorize/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,13 @@
Pipeline of the data processing steps.
"""

import json
import os

import randomname
from prefect import flow
from prefect.tasks import Task
from prefect_dask import DaskTaskRunner

from .logging import logger
from .utils import get_callable_by_name
from .utils import get_callable, get_callable_by_name


class Pipeline:
Expand Down Expand Up @@ -87,6 +84,10 @@ def from_qualname_list(cls, qualnames: list, name=None):
[get_callable_by_name(name) for name in qualnames], name=name
)

@classmethod
def from_callable_strings(cls, step_strings: list, name=None):
return cls.from_list([get_callable(name) for name in step_strings], name=name)

@classmethod
def from_dict(cls, data):
if "uses" in data and "steps" in data:
Expand All @@ -95,7 +96,7 @@ def from_dict(cls, data):
# FIXME(PG): This is bad. What if I need to pass arguments to the constructor?
return get_callable_by_name(data["uses"])(name=data.get("name"))
if "steps" in data:
return cls.from_qualname_list(data["steps"], name=data.get("name"))
return cls.from_callable_strings(data["steps"], name=data.get("name"))
raise ValueError("Pipeline data must have 'uses' or 'steps' key")


Expand Down
66 changes: 65 additions & 1 deletion src/pymorize/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Various utility functions needed around the package
"""

import importlib
import inspect
import time
from functools import partial
Expand All @@ -11,6 +12,28 @@
from .logging import logger


def get_callable(name):
"""Get a callable from a string
First, tries standard import, then tries entry points, then from script
"""
try:
return get_callable_by_name(name)
except (ImportError, AttributeError):
pass

try:
return get_entrypoint_by_name(name)
except ValueError:
pass

try:
return get_callable_by_script(name)
except ValueError:
pass

raise ValueError(f"Callable '{name}' not found")


def get_callable_by_name(name):
"""
Get a callable by its name.
Expand Down Expand Up @@ -48,7 +71,7 @@ def get_callable_by_name(name):
return getattr(module, callable_name)


def get_entrypoint_by_name(name, group="pymorize.rules"):
def get_entrypoint_by_name(name, group="pymorize.steps"):
"""
Get an entry point by its name.
Expand Down Expand Up @@ -161,6 +184,47 @@ def can_be_partialized(
return len(param_names) == 1 and param_names[0] == open_arg


def get_function_from_script(script_path: str, function_name: str):
"""
Get a function from a Python script.
This function takes the path to a Python script and the name of a function defined in that script,
and returns the actual function object. If the script does not exist or the function is not defined
in the script, this function will raise an ImportError.
Parameters
----------
script_path : str
The path to the Python script where the function is defined.
function_name : str
The name of the function to be retrieved.
Returns
-------
callable
The function object that corresponds to the given name in the specified script.
Raises
------
ImportError
If the script does not exist or the function is not defined in the script.
"""
logger.debug(f"Importing function '{function_name}' from script '{script_path}'")
spec = importlib.util.spec_from_file_location("script", script_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return getattr(module, function_name)


def get_callable_by_script(step_signature):
if not step_signature.startswith("script://"):
raise ValueError(f"Step signature '{step_signature}' is not a script step")
script_spec = step_signature.split("script://")[1]
script_path = script_spec.split(":")[0]
function_name = script_spec.split(":")[1]
return get_function_from_script(script_path, function_name)


def wait_for_workers(client, n_workers, timeout=600):
"""
Wait for a specific number of workers to be available.
Expand Down

0 comments on commit fda5345

Please sign in to comment.