diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 01bfe0fb7f..0beaaaab9b 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -326,7 +326,7 @@ class AsyncAgentExecutorMixin: def execute(self: PythonTask, **kwargs) -> LiteralMap: ctx = FlyteContext.current_context() - ss = ctx.serialization_settings or SerializationSettings(ImageConfig()) + ss = ctx.serialization_settings or SerializationSettings(ImageConfig.auto_default_image()) output_prefix = ctx.file_access.get_random_remote_directory() self.resource_meta = None diff --git a/flytekit/extend/backend/utils.py b/flytekit/extend/backend/utils.py index 4dcdf3174a..8aa7952134 100644 --- a/flytekit/extend/backend/utils.py +++ b/flytekit/extend/backend/utils.py @@ -20,9 +20,9 @@ def convert_to_flyte_phase(state: str) -> TaskExecution.Phase: Convert the state from the agent to the phase in flyte. """ state = state.lower() - if state in ["failed", "timeout", "timedout", "canceled", "skipped", "internal_error"]: + if state in ["failed", "timeout", "timedout", "canceled", "cancelled", "skipped", "internal_error"]: return TaskExecution.FAILED - elif state in ["done", "succeeded", "success"]: + elif state in ["done", "succeeded", "success", "completed"]: return TaskExecution.SUCCEEDED elif state in ["running", "terminating"]: return TaskExecution.RUNNING diff --git a/flytekit/extras/tasks/shell.py b/flytekit/extras/tasks/shell.py index 32ae33fcc7..296666c85e 100644 --- a/flytekit/extras/tasks/shell.py +++ b/flytekit/extras/tasks/shell.py @@ -250,7 +250,10 @@ def __init__( if task_config is not None: fully_qualified_class_name = task_config.__module__ + "." + task_config.__class__.__name__ - if not fully_qualified_class_name == "flytekitplugins.pod.task.Pod": + if fully_qualified_class_name not in [ + "flytekitplugins.pod.task.Pod", + "flytekitplugins.slurm.script.task.Slurm", + ]: raise ValueError("TaskConfig can either be empty - indicating simple container task or a PodConfig.") # Each instance of NotebookTask instantiates an underlying task with a dummy function that will only be used @@ -259,11 +262,14 @@ def __init__( # errors. # This seem like a hack. We should use a plugin_class that doesn't require a fake-function to make work. plugin_class = TaskPlugins.find_pythontask_plugin(type(task_config)) - self._config_task_instance = plugin_class(task_config=task_config, task_function=_dummy_task_func) - # Rename the internal task so that there are no conflicts at serialization time. Technically these internal - # tasks should not be serialized at all, but we don't currently have a mechanism for skipping Flyte entities - # at serialization time. - self._config_task_instance._name = f"_bash.{name}" + if plugin_class.__name__ in ["SlurmShellTask"]: + self._config_task_instance = None + else: + self._config_task_instance = plugin_class(task_config=task_config, task_function=_dummy_task_func) + # Rename the internal task so that there are no conflicts at serialization time. Technically these internal + # tasks should not be serialized at all, but we don't currently have a mechanism for skipping Flyte entities + # at serialization time. + self._config_task_instance._name = f"_bash.{name}" self._script = script self._script_file = script_file self._debug = debug @@ -275,7 +281,9 @@ def __init__( super().__init__( name, task_config, - task_type=self._config_task_instance.task_type, + task_type=kwargs.pop("task_type") + if self._config_task_instance is None + else self._config_task_instance.task_type, interface=Interface(inputs=inputs, outputs=outputs), **kwargs, ) @@ -309,7 +317,10 @@ def script_file(self) -> typing.Optional[os.PathLike]: return self._script_file def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: - return self._config_task_instance.pre_execute(user_params) + if self._config_task_instance is None: + return user_params + else: + return self._config_task_instance.pre_execute(user_params) def execute(self, **kwargs) -> typing.Any: """ @@ -367,7 +378,10 @@ def execute(self, **kwargs) -> typing.Any: return None def post_execute(self, user_params: ExecutionParameters, rval: typing.Any) -> typing.Any: - return self._config_task_instance.post_execute(user_params, rval) + if self._config_task_instance is None: + return rval + else: + return self._config_task_instance.post_execute(user_params, rval) class RawShellTask(ShellTask): diff --git a/plugins/flytekit-slurm/README.md b/plugins/flytekit-slurm/README.md new file mode 100644 index 0000000000..af6596cf28 --- /dev/null +++ b/plugins/flytekit-slurm/README.md @@ -0,0 +1,5 @@ +# Flytekit Slurm Plugin + +The Slurm agent is designed to integrate Flyte workflows with Slurm-managed high-performance computing (HPC) clusters, enabling users to leverage Slurm's capability of compute resource allocation, scheduling, and monitoring. + +This [guide](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/demo.md) provides a concise overview of the design philosophy behind the Slurm agent and explains how to set up a local environment for testing the agent. diff --git a/plugins/flytekit-slurm/assets/basic_arch.png b/plugins/flytekit-slurm/assets/basic_arch.png new file mode 100644 index 0000000000..b1ee5d4771 Binary files /dev/null and b/plugins/flytekit-slurm/assets/basic_arch.png differ diff --git a/plugins/flytekit-slurm/assets/flyte_client.png b/plugins/flytekit-slurm/assets/flyte_client.png new file mode 100644 index 0000000000..454769bce5 Binary files /dev/null and b/plugins/flytekit-slurm/assets/flyte_client.png differ diff --git a/plugins/flytekit-slurm/assets/overview_v2.png b/plugins/flytekit-slurm/assets/overview_v2.png new file mode 100644 index 0000000000..c47caa1304 Binary files /dev/null and b/plugins/flytekit-slurm/assets/overview_v2.png differ diff --git a/plugins/flytekit-slurm/assets/remote_tiny_slurm_cluster.png b/plugins/flytekit-slurm/assets/remote_tiny_slurm_cluster.png new file mode 100644 index 0000000000..276b93f304 Binary files /dev/null and b/plugins/flytekit-slurm/assets/remote_tiny_slurm_cluster.png differ diff --git a/plugins/flytekit-slurm/assets/slurm_basic_result.png b/plugins/flytekit-slurm/assets/slurm_basic_result.png new file mode 100644 index 0000000000..4b15aeea51 Binary files /dev/null and b/plugins/flytekit-slurm/assets/slurm_basic_result.png differ diff --git a/plugins/flytekit-slurm/demo.md b/plugins/flytekit-slurm/demo.md new file mode 100644 index 0000000000..170632b19b --- /dev/null +++ b/plugins/flytekit-slurm/demo.md @@ -0,0 +1,264 @@ +# Slurm Agent Demo + +> Note: This document is still a work in progress, focusing on demonstrating the initial implementation. It will be updated and refined frequently until a stable version is ready. + +In this guide, we will briefly introduce how to setup an environment to test Slurm agent locally without running the backend service (e.g., flyte agent gRPC server). It covers both basic and advanced use cases: the basic use case involves executing a shell script directly, while the advanced use case enables running user-defined functions on a Slurm cluster. + +## Table of Content +* [Overview](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/demo.md#overview) +* [Setup a Local Test Environment](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/demo.md#setup-a-local-test-environment) + * [Flyte Client (Localhost)](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/demo.md#flyte-client-localhost) + * [Remote Tiny Slurm Cluster](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/demo.md#remote-tiny-slurm-cluster) + * [SSH Configuration](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/demo.md#ssh-configuration) + * [(Optional) Setup Amazon S3 Bucket](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/demo.md#optional-setup-amazon-s3-bucket) +* [Rich Use Cases](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/demo.md#rich-use-cases) + * [`SlurmTask`](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/demo.md#slurmtask) + * [`SlurmShellTask`](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/demo.md#slurmshelltask) + * [`SlurmFunctionTask`](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/demo.md#slurmfunctiontask) + +## Overview +Slurm agent on the highest level has three core methods to interact with a Slurm cluster: +1. `create`: Use `srun` or `sbatch` to run a job on a Slurm cluster +2. `get`: Use `scontrol show job ` to monitor the Slurm job state +3. `delete`: Use `scancel ` to cancel the Slurm job (this method is still under test) + +In the simplest form, Slurm agent supports directly running a batch script using `sbatch` on a Slurm cluster as shown below: + +![](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/assets/basic_arch.png) + +## Setup a Local Test Environment +Without running the backend service, we can setup an environment to test Slurm agent locally. The setup consists of two main components: a client (localhost) and a remote tiny Slurm cluster. Then, we need to configure SSH connection to facilitate communication between the two, which relies on `asyncssh`. Additionally, an S3-compatible object storage is needed for advanced use cases and we choose [Amazon S3](https://us-west-2.console.aws.amazon.com/s3/get-started?region=us-west-2&bucketType=general) for demonstration here. +> Note: A persistence layer (such as S3-compatible object storage) becomes essential as scenarios grow more complex, especially when integrating heterogeneous task types into a workflow in the future. + +### Flyte Client (Localhost) +1. Setup a local Flyte cluster following this [official guide](https://docs.flyte.org/en/latest/community/contribute/contribute_code.html#how-to-setup-dev-environment-for-flytekit) +2. Build a virtual environment (e.g., [poetry](https://python-poetry.org/), [conda](https://docs.conda.io/en/latest/)) and activate it +3. Clone Flytekit [repo](https://github.com/flyteorg/flytekit), checkout the Slurm agent [PR](https://github.com/flyteorg/flytekit/pull/3005/), and install Flytekit +``` +git clone https://github.com/flyteorg/flytekit.git +gh pr checkout 3005 +make setup && pip install -e . +``` +4. Install Flytekit Slurm agent +``` +cd plugins/flytekit-slurm/ +pip install -e . +``` + +### Remote Tiny Slurm Cluster +To simplify the setup process, we follow this [guide](https://github.com/JiangJiaWei1103/Slurm-101) to configure a single-host Slurm cluster, covering `slurmctld` (the central management daemon) and `slurmd` (the compute node daemon). + +After building a Slurm cluster, we need to install Flytekit and Slurm agent, just as what we did in the previous [section](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/demo.md#flyte-client-localhost). +1. Build a virtual environment and activate it (we take `poetry` as an example): +``` +poetry new demo-env + +# For running a subshell with the virtual environment activated +poetry self add poetry-plugin-shell + +# Activate the virtual environment +poetry shell +``` +2. Clone Flytekit [repo](https://github.com/flyteorg/flytekit), checkout the Slurm agent [PR](https://github.com/flyteorg/flytekit/pull/3005/), and install Flytekit +``` +git clone https://github.com/flyteorg/flytekit.git +gh pr checkout 3005 +make setup && pip install -e . +``` +3. Install Flytekit Slurm agent +``` +cd plugins/flytekit-slurm/ +pip install -e . +``` + +### SSH Configuration +To facilitate communication between the Flyte client and the remote Slurm cluster, we setup SSH on the Flyte client side as follows: +1. Create a new authentication key pair +``` +ssh-keygen -t rsa -b 4096 +``` +2. Copy the public key into the remote Slurm cluster +``` +ssh-copy-id @ +``` +3. Enable key-based authentication +``` +# ~/.ssh/config +Host + HostName + Port + User + IdentityFile +``` +Then, run a sanity check to make sure we can connect to the Slurm cluster: +``` +ssh +``` +Simple and elegant! + +### (Optional) Setup Amazon S3 Bucket +For those interested in advanced use cases, in which user-defined functions are sent and executed on the Slurm cluster, an S3-compitable object storage becomes a necessary component. Following summarizes the setup process: +1. Click "Create bucket" button (marked in yellow) to create a bucket on this [page](https://us-west-2.console.aws.amazon.com/s3/get-started?region=us-west-2&bucketType=general) + * Give the cluster an unique name and leave other settings as default +2. Click the user on the top right corner and go to "Security credentials" +3. Create an access key and save it +4. Configure AWS access on **both** machines +``` +# ~/.aws/config +[default] +region= + +# ~/.aws/credentials +[default] +aws_access_key_id= +aws_secret_access_key= +``` + +Now, both machines have access to the Amazon S3 bucket. Perfect! + + +## Rich Use Cases +In this section, we will demonstrate three supported use cases, ranging from basic to advanced. + +### `SlurmTask` +In the simplest use case, we specify the path to the batch script that is already available on the cluster. + +Suppose we have a batch script as follows: +``` +#!/bin/bash + +echo "Hello AWS slurm, run a Flyte SlurmTask!" >> ./echo_aws.txt +``` + +We use the following python script to test Slurm agent on the [client](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/demo.md#flyte-client-localhost): +```python +import os + +from flytekit import workflow +from flytekitplugins.slurm import SlurmRemoteScript, SlurmTask + + +echo_job = SlurmTask( + name="", + task_config=SlurmRemoteScript( + slurm_host="", + batch_script_path="", + sbatch_conf={ + "partition": "debug", + "job-name": "tiny-slurm", + } + ) +) + + +@workflow +def wf() -> None: + echo_job() + + +if __name__ == "__main__": + from flytekit.clis.sdk_in_container import pyflyte + from click.testing import CliRunner + + runner = CliRunner() + path = os.path.realpath(__file__) + + print(f">>> LOCAL EXEC <<<") + result = runner.invoke(pyflyte.main, ["run", path, "wf"]) + print(result.output) +``` + +### `SlurmShellTask` +`SlurmShellTask` offers users the flexibility to define the content of shell scripts. Below is an example of creating a task that executes a Python script already present on the Slurm cluster: +```python +import os + +from flytekit import workflow +from flytekitplugins.slurm import Slurm, SlurmShellTask + + +shell_task = SlurmShellTask( + name="test-shell", + script="""#!/bin/bash +# We can define sbatch options here, but using sbatch_conf can be more neat +echo "Run a Flyte SlurmShellTask...\n" + +# Run a python script on Slurm +# Activate the virtual env first if any +python3 +""", + task_config=Slurm( + slurm_host="", + sbatch_conf={ + "partition": "debug", + "job-name": "tiny-slurm", + } + ), +) + + +@workflow +def wf() -> None: + shell_task() + + +if __name__ == "__main__": + from flytekit.clis.sdk_in_container import pyflyte + from click.testing import CliRunner + + runner = CliRunner() + path = os.path.realpath(__file__) + + print(f">>> LOCAL EXEC <<<") + result = runner.invoke(pyflyte.main, ["run", path, "wf"]) + print(result.output) +``` + +### `SlurmFunctionTask` +In the most advanced use case, `SlurmFunctionTask` allows users to define custom Python functions that are sent to and executed on the Slurm cluster. Following figure demonstrates the process of running a `SlurmFunctionTask`: + +![](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/assets/overview_v2.png) + +```python +import os + +from flytekit import task, workflow +from flytekitplugins.slurm import SlurmFunction + + +@task( + task_config=SlurmFunction( + slurm_host="", + sbatch_conf={ + "partition": "debug", + "job-name": "tiny-slurm", + } + ) +) +def plus_one(x: int) -> int: + return x + 1 + + +@task +def greet(year: int) -> str: + return f"Hello {year}!!!" + + +@workflow +def wf(x: int) -> str: + x = plus_one(x=x) + msg = greet(year=x) + return msg + + +if __name__ == "__main__": + from flytekit.clis.sdk_in_container import pyflyte + from click.testing import CliRunner + + runner = CliRunner() + path = os.path.realpath(__file__) + + print(f">>> LOCAL EXEC <<<") + result = runner.invoke(pyflyte.main, ["run", "--raw-output-data-prefix", "", path, "wf", "--x", 2024]) + print(result.output) +``` diff --git a/plugins/flytekit-slurm/flytekitplugins/slurm/__init__.py b/plugins/flytekit-slurm/flytekitplugins/slurm/__init__.py new file mode 100644 index 0000000000..75dc5ea9ff --- /dev/null +++ b/plugins/flytekit-slurm/flytekitplugins/slurm/__init__.py @@ -0,0 +1,4 @@ +from .function.agent import SlurmFunctionAgent +from .function.task import SlurmFunction, SlurmFunctionTask +from .script.agent import SlurmScriptAgent +from .script.task import Slurm, SlurmRemoteScript, SlurmShellTask, SlurmTask diff --git a/plugins/flytekit-slurm/flytekitplugins/slurm/function/agent.py b/plugins/flytekit-slurm/flytekitplugins/slurm/function/agent.py new file mode 100644 index 0000000000..4050217c0a --- /dev/null +++ b/plugins/flytekit-slurm/flytekitplugins/slurm/function/agent.py @@ -0,0 +1,191 @@ +import tempfile +import uuid +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + +from asyncssh import SSHClientConnection + +from flytekit import logger +from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta +from flytekit.extend.backend.utils import convert_to_flyte_phase +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate + +from ..ssh_utils import ssh_connect + + +@dataclass +class SlurmJobMetadata(ResourceMeta): + """Slurm job metadata. + + Args: + job_id: Slurm job id. + ssh_config: Options of SSH client connection. For available options, please refer to + + """ + + job_id: str + ssh_config: Dict[str, Union[str, List[str], Tuple[str, ...]]] + + +@dataclass +class SlurmCluster: + host: str + username: Optional[str] = None + + def __hash__(self): + return hash((self.host, self.username)) + + +class SlurmFunctionAgent(AsyncAgentBase): + name = "Slurm Function Agent" + + # SSH connection pool for multi-host environment + ssh_config_to_ssh_conn: Dict[SlurmCluster, SSHClientConnection] = {} + + def __init__(self) -> None: + super(SlurmFunctionAgent, self).__init__(task_type_name="slurm_fn", metadata_type=SlurmJobMetadata) + + async def create( + self, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + **kwargs, + ) -> SlurmJobMetadata: + unique_script_name = f"/tmp/task_{uuid.uuid4().hex}.slurm" + + # Retrieve task config + ssh_config = task_template.custom["ssh_config"] + sbatch_conf = task_template.custom["sbatch_conf"] + script = task_template.custom["script"] + + # Construct command for Slurm cluster + cmd, script = _get_sbatch_cmd_and_script( + sbatch_conf=sbatch_conf, + entrypoint=" ".join(task_template.container.args), + script=script, + batch_script_path=unique_script_name, + ) + + logger.info("@@@ task_template.container.args:") + logger.info(task_template.container.args) + logger.info("@@@ Slurm Command: ") + logger.info(cmd) + logger.info("@@@ Batch script: ") + logger.info(script) + + # Run Slurm job + conn = await self._get_or_create_ssh_connection(ssh_config) + with tempfile.NamedTemporaryFile("w") as f: + f.write(script) + f.flush() + async with conn.start_sftp_client() as sftp: + await sftp.put(f.name, unique_script_name) + res = await conn.run(cmd, check=True) + + # Retrieve Slurm job id + job_id = res.stdout.split()[-1] + logger.info("@@@ create slurm job id: " + job_id) + + return SlurmJobMetadata(job_id=job_id, ssh_config=ssh_config) + + async def get(self, resource_meta: SlurmJobMetadata, **kwargs) -> Resource: + ssh_config = resource_meta.ssh_config + conn = await self._get_or_create_ssh_connection(ssh_config) + job_res = await conn.run(f"scontrol show job {resource_meta.job_id}", check=True) + + # Determine the current flyte phase from Slurm job state + job_state = "running" + for o in job_res.stdout.split(" "): + if "JobState" in o: + job_state = o.split("=")[1].strip().lower() + elif "StdOut" in o: + stdout_path = o.split("=")[1].strip() + msg_res = await conn.run(f"cat {stdout_path}", check=True) + msg = msg_res.stdout + + logger.info("@@@ GET PHASE: ") + logger.info(str(job_state)) + cur_phase = convert_to_flyte_phase(job_state) + + return Resource(phase=cur_phase, message=msg) + + async def delete(self, resource_meta: SlurmJobMetadata, **kwargs) -> None: + conn = await self._get_or_create_ssh_connection(resource_meta.ssh_config) + _ = await conn.run(f"scancel {resource_meta.job_id}", check=True) + + async def _get_or_create_ssh_connection( + self, ssh_config: Dict[str, Union[str, List[str], Tuple[str, ...]]] + ) -> SSHClientConnection: + """Get an existing SSH connection or create a new one if needed. + + Args: + ssh_config: SSH configuration dictionary. + + Returns: + An active SSH connection, either pre-existing or newly established. + """ + host = ssh_config.get("host") + username = ssh_config.get("username") + + ssh_cluster_config = SlurmCluster(host=host, username=username) + if self.ssh_config_to_ssh_conn.get(ssh_cluster_config) is None: + logger.info("ssh connection key not found, creating new connection") + conn = await ssh_connect(ssh_config=ssh_config) + self.ssh_config_to_ssh_conn[ssh_cluster_config] = conn + else: + conn = self.ssh_config_to_ssh_conn[ssh_cluster_config] + try: + await conn.run("echo [TEST] SSH connection", check=True) + logger.info("re-using new connection") + except Exception as e: + logger.info(f"Re-establishing SSH connection due to error: {e}") + conn = await ssh_connect(ssh_config=ssh_config) + self.ssh_config_to_ssh_conn[ssh_cluster_config] = conn + + return conn + + +def _get_sbatch_cmd_and_script( + sbatch_conf: Dict[str, str], + entrypoint: str, + script: Optional[str] = None, + batch_script_path: str = "/tmp/task.slurm", +) -> str: + """Construct the Slurm sbatch command and the batch script content. + + Flyte entrypoint, pyflyte-execute, is run within a bash shell process. + + Args: + sbatch_conf: Options of sbatch command. + entrypoint: Flyte entrypoint. + script: User-defined script where "{task.fn}" serves as a placeholder for the + task function execution. Users should insert "{task.fn}" at the desired + execution point within the script. If the script is not provided, the task + function will be executed directly. + batch_script_path: Absolute path of the batch script on Slurm cluster. + + Returns: + cmd: Slurm sbatch command. + """ + # Setup sbatch options + cmd = ["sbatch"] + for opt, val in sbatch_conf.items(): + cmd.extend([f"--{opt}", str(val)]) + + # Assign the batch script to run + cmd.append(batch_script_path) + + if script is None: + script = f"""#!/bin/bash -i + {entrypoint} + """ + else: + script = script.replace("{task.fn}", entrypoint) + + cmd = " ".join(cmd) + + return cmd, script + + +AgentRegistry.register(SlurmFunctionAgent()) diff --git a/plugins/flytekit-slurm/flytekitplugins/slurm/function/task.py b/plugins/flytekit-slurm/flytekitplugins/slurm/function/task.py new file mode 100644 index 0000000000..2dc2d7ce15 --- /dev/null +++ b/plugins/flytekit-slurm/flytekitplugins/slurm/function/task.py @@ -0,0 +1,78 @@ +""" +Slurm task. +""" + +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +from flytekit import FlyteContextManager, PythonFunctionTask +from flytekit.configuration import SerializationSettings +from flytekit.extend import TaskPlugins +from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin +from flytekit.image_spec import ImageSpec + + +@dataclass +class SlurmFunction(object): + """Configure Slurm settings. Note that we focus on sbatch command now. + + Args: + ssh_config: Options of SSH client connection. For available options, please refer to + + sbatch_conf: Options of sbatch command. + script: User-defined script where "{task.fn}" serves as a placeholder for the + task function execution. Users should insert "{task.fn}" at the desired + execution point within the script. If the script is not provided, the task + function will be executed directly. + """ + + ssh_config: Dict[str, Union[str, List[str], Tuple[str, ...]]] + sbatch_conf: Optional[Dict[str, str]] = None + script: Optional[str] = None + + def __post_init__(self): + assert self.ssh_config["host"] is not None, "'host' must be specified in ssh_config." + if self.sbatch_conf is None: + self.sbatch_conf = {} + + +class SlurmFunctionTask(AsyncAgentExecutorMixin, PythonFunctionTask[SlurmFunction]): + """ + Actual Plugin that transforms the local python code for execution within a slurm context... + """ + + _TASK_TYPE = "slurm_fn" + + def __init__( + self, + task_config: SlurmFunction, + task_function: Callable, + container_image: Optional[Union[str, ImageSpec]] = None, + **kwargs, + ): + super(SlurmFunctionTask, self).__init__( + task_config=task_config, + task_type=self._TASK_TYPE, + task_function=task_function, + container_image=container_image, + **kwargs, + ) + + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + return { + "ssh_config": self.task_config.ssh_config, + "sbatch_conf": self.task_config.sbatch_conf, + "script": self.task_config.script, + } + + def execute(self, **kwargs) -> Any: + ctx = FlyteContextManager.current_context() + if ctx.execution_state and ctx.execution_state.is_local_execution(): + # Mimic the propeller's behavior in local agent test + return AsyncAgentExecutorMixin.execute(self, **kwargs) + else: + # Execute the task with a direct python function call + return PythonFunctionTask.execute(self, **kwargs) + + +TaskPlugins.register_pythontask_plugin(SlurmFunction, SlurmFunctionTask) diff --git a/plugins/flytekit-slurm/flytekitplugins/slurm/script/agent.py b/plugins/flytekit-slurm/flytekitplugins/slurm/script/agent.py new file mode 100644 index 0000000000..326664671d --- /dev/null +++ b/plugins/flytekit-slurm/flytekitplugins/slurm/script/agent.py @@ -0,0 +1,187 @@ +import tempfile +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +from asyncssh import SSHClientConnection + +import flytekit +from flytekit.core.type_engine import TypeEngine +from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta +from flytekit.extend.backend.utils import convert_to_flyte_phase +from flytekit.extras.tasks.shell import OutputLocation, _PythonFStringInterpolizer +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate + +from ..ssh_utils import ssh_connect + + +@dataclass +class SlurmJobMetadata(ResourceMeta): + """Slurm job metadata. + + Args: + job_id: Slurm job id. + ssh_config: Options of SSH client connection. For available options, please refer to + + outputs: A dictionary mapping from the output variable name to the output location. + """ + + job_id: str + ssh_config: Dict[str, Any] + outputs: Dict[str, str] + + +class SlurmScriptAgent(AsyncAgentBase): + name = "Slurm Script Agent" + + # SSH connection pool for multi-host environment + # _ssh_clients: Dict[str, SSHClientConnection] + _conn: Optional[SSHClientConnection] = None + + # Tmp remote path of the batch script + REMOTE_PATH = "/tmp/echo_shell.slurm" + + # Dummy script content + DUMMY_SCRIPT = "#!/bin/bash" + + def __init__(self) -> None: + super(SlurmScriptAgent, self).__init__(task_type_name="slurm", metadata_type=SlurmJobMetadata) + + async def create( + self, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + **kwargs, + ) -> SlurmJobMetadata: + outputs = {} + + # Retrieve task config + ssh_config = task_template.custom["ssh_config"] + batch_script_args = task_template.custom["batch_script_args"] + sbatch_conf = task_template.custom["sbatch_conf"] + + # Construct sbatch command for Slurm cluster + upload_script = False + if "script" in task_template.custom: + script = task_template.custom["script"] + assert script != self.DUMMY_SCRIPT, "Please write the user-defined batch script content." + script, outputs = self._interpolate_script( + script, + input_literal_map=inputs, + python_input_types=task_template.custom["python_input_types"], + output_locs=task_template.custom["output_locs"], + ) + + batch_script_path = self.REMOTE_PATH + upload_script = True + else: + # Assume the batch script is already on Slurm + batch_script_path = task_template.custom["batch_script_path"] + cmd = _get_sbatch_cmd( + sbatch_conf=sbatch_conf, batch_script_path=batch_script_path, batch_script_args=batch_script_args + ) + + # Run Slurm job + conn = await ssh_connect(ssh_config=ssh_config) + if upload_script: + with tempfile.NamedTemporaryFile("w") as f: + f.write(script) + f.flush() + async with conn.start_sftp_client() as sftp: + await sftp.put(f.name, self.REMOTE_PATH) + res = await conn.run(cmd, check=True) + + # Retrieve Slurm job id + job_id = res.stdout.split()[-1] + + return SlurmJobMetadata(job_id=job_id, ssh_config=ssh_config, outputs=outputs) + + async def get(self, resource_meta: SlurmJobMetadata, **kwargs) -> Resource: + conn = await ssh_connect(ssh_config=resource_meta.ssh_config) + job_res = await conn.run(f"scontrol show job {resource_meta.job_id}", check=True) + + # Determine the current flyte phase from Slurm job state + msg = "" + job_state = "running" + for o in job_res.stdout.split(" "): + if "JobState" in o: + job_state = o.split("=")[1].strip().lower() + elif "StdOut" in o: + stdout_path = o.split("=")[1].strip() + msg_res = await conn.run(f"cat {stdout_path}", check=True) + msg = msg_res.stdout + cur_phase = convert_to_flyte_phase(job_state) + + return Resource(phase=cur_phase, message=msg, outputs=resource_meta.outputs) + + async def delete(self, resource_meta: SlurmJobMetadata, **kwargs) -> None: + conn = await ssh_connect(ssh_config=resource_meta.ssh_config) + _ = await conn.run(f"scancel {resource_meta.job_id}", check=True) + + def _interpolate_script( + self, + script: str, + input_literal_map: Optional[LiteralMap] = None, + python_input_types: Optional[Dict[str, Type]] = None, + output_locs: Optional[List[OutputLocation]] = None, + ) -> Tuple[str, Dict[str, str]]: + """Interpolate the user-defined script with the specified input and output arguments. + + Args: + script: The user-defined script with placeholders for dynamic input and output values. + input_literal_map: The input literal map. + python_input_types: A dictionary of input names to types. + output_locs: Output locations. + + Returns: + A tuple (script, outputs), where script is the interpolated script, and outputs is a + dictionary mapping from the output variable name to the output location. + """ + input_kwargs = TypeEngine.literal_map_to_kwargs( + flytekit.current_context(), lm=input_literal_map, python_types=python_input_types + ) + interpolizer = _PythonFStringInterpolizer() + + # Interpolate output locations with input values + outputs = {} + if output_locs is not None: + for oloc in output_locs: + outputs[oloc.var] = interpolizer.interpolate(oloc.location, inputs=input_kwargs) + + # Interpolate the script + script = interpolizer.interpolate(script, inputs=input_kwargs, outputs=outputs) + + return script, outputs + + +def _get_sbatch_cmd(sbatch_conf: Dict[str, str], batch_script_path: str, batch_script_args: List[str] = None) -> str: + """Construct Slurm sbatch command. + + We assume all main scripts and dependencies are on Slurm cluster. + + Args: + sbatch_conf: Options of srun command. + batch_script_path: Absolute path of the batch script on Slurm cluster. + batch_script_args: Additional args for the batch script on Slurm cluster. + + Returns: + cmd: Slurm sbatch command. + """ + # Setup sbatch options + cmd = ["sbatch"] + for opt, val in sbatch_conf.items(): + cmd.extend([f"--{opt}", str(val)]) + + # Assign the batch script to run + cmd.append(batch_script_path) + + # Add args if present + if batch_script_args: + for arg in batch_script_args: + cmd.append(arg) + + cmd = " ".join(cmd) + return cmd + + +AgentRegistry.register(SlurmScriptAgent()) diff --git a/plugins/flytekit-slurm/flytekitplugins/slurm/script/task.py b/plugins/flytekit-slurm/flytekitplugins/slurm/script/task.py new file mode 100644 index 0000000000..24e7b0cf14 --- /dev/null +++ b/plugins/flytekit-slurm/flytekitplugins/slurm/script/task.py @@ -0,0 +1,110 @@ +""" +Slurm task. +""" + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Type + +from flytekit.configuration import SerializationSettings +from flytekit.core.base_task import PythonTask +from flytekit.core.interface import Interface +from flytekit.extend import TaskPlugins +from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin +from flytekit.extras.tasks.shell import OutputLocation, ShellTask + + +@dataclass +class Slurm(object): + """Configure Slurm settings. Note that we focus on sbatch command now. + + Compared with spark, please refer to https://api-docs.databricks.com/python/pyspark/latest/api/pyspark.SparkContext.html. + + Args: + ssh_config: Options of SSH client connection. For available options, please refer to + + sbatch_conf: Options of sbatch command. For available options, please refer to + https://slurm.schedmd.com/sbatch.html. + batch_script_args: Additional args for the batch script on Slurm cluster. + """ + + ssh_config: Dict[str, Any] + sbatch_conf: Optional[Dict[str, str]] = None + batch_script_args: Optional[List[str]] = None + + def __post_init__(self): + if self.sbatch_conf is None: + self.sbatch_conf = {} + + +# See https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses +@dataclass(kw_only=True) +class SlurmRemoteScript(Slurm): + """Encounter collision if Slurm is shared btw SlurmTask and SlurmShellTask.""" + + batch_script_path: str + + +class SlurmTask(AsyncAgentExecutorMixin, PythonTask[SlurmRemoteScript]): + _TASK_TYPE = "slurm" + + def __init__( + self, + name: str, + task_config: SlurmRemoteScript, + **kwargs, + ): + super(SlurmTask, self).__init__( + task_type=self._TASK_TYPE, + name=name, + task_config=task_config, + # Dummy interface, will support this after discussion + interface=Interface(None, None), + **kwargs, + ) + + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + return { + "ssh_config": self.task_config.ssh_config, + "batch_script_path": self.task_config.batch_script_path, + "batch_script_args": self.task_config.batch_script_args, + "sbatch_conf": self.task_config.sbatch_conf, + } + + +class SlurmShellTask(AsyncAgentExecutorMixin, ShellTask[Slurm]): + _TASK_TYPE = "slurm" + + def __init__( + self, + name: str, + task_config: Slurm, + script: Optional[str] = None, + inputs: Optional[Dict[str, Type]] = None, + output_locs: Optional[List[OutputLocation]] = None, + **kwargs, + ): + self._inputs = inputs + + super(SlurmShellTask, self).__init__( + name, + task_config=task_config, + task_type=self._TASK_TYPE, + script=script, + inputs=inputs, + output_locs=output_locs, + **kwargs, + ) + + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + return { + "ssh_config": self.task_config.ssh_config, + "batch_script_args": self.task_config.batch_script_args, + "sbatch_conf": self.task_config.sbatch_conf, + "script": self._script, + "python_input_types": self._inputs, + "output_locs": self._output_locs, + } + + +TaskPlugins.register_pythontask_plugin(SlurmRemoteScript, SlurmTask) +TaskPlugins.register_pythontask_plugin(Slurm, SlurmShellTask) diff --git a/plugins/flytekit-slurm/flytekitplugins/slurm/ssh_utils.py b/plugins/flytekit-slurm/flytekitplugins/slurm/ssh_utils.py new file mode 100644 index 0000000000..2067d8f229 --- /dev/null +++ b/plugins/flytekit-slurm/flytekitplugins/slurm/ssh_utils.py @@ -0,0 +1,127 @@ +""" +Utilities of asyncssh connections. +""" + +import os +import sys +from dataclasses import asdict, dataclass +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union + +import asyncssh +from asyncssh import SSHClientConnection + +from flytekit import logger +from flytekit.extend.backend.utils import get_agent_secret + +T = TypeVar("T", bound="SSHConfig") +SLURM_PRIVATE_KEY = "FLYTE_SLURM_PRIVATE_KEY" + + +@dataclass(frozen=True) +class SSHConfig: + """A customized version of SSHClientConnectionOptions, tailored to specific needs. + + This config is based on the official SSHClientConnectionOptions but includes + only a subset of options, with some fields adjusted to be optional or required. + For the official options, please refer to: + https://asyncssh.readthedocs.io/en/latest/api.html#asyncssh.SSHClientConnectionOptions + + Args: + host: The hostname or address to connect to. + username: The username to authenticate as on the server. + client_keys: File paths to private keys which will be used to authenticate the + client via public key authentication. The default value is not None since + client public key authentication is mandatory. + """ + + host: str + username: Optional[str] = None + client_keys: Union[str, List[str], Tuple[str, ...]] = () + + @classmethod + def from_dict(cls: Type[T], ssh_config: Dict[str, Any]) -> T: + return cls(**ssh_config) + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + def __eq__(self, other): + if not isinstance(other, SSHConfig): + return False + return self.host == other.host and self.username == other.username and self.client_keys == other.client_keys + + +async def ssh_connect(ssh_config: Dict[str, Any]) -> SSHClientConnection: + """Make an SSH client connection. + + Args: + ssh_config: Options of SSH client connection defined in SSHConfig. + + Returns: + An SSH client connection object. + """ + # Validate ssh_config + ssh_config = SSHConfig.from_dict(ssh_config).to_dict() + ssh_config["known_hosts"] = None + + # Make the first SSH connection using either OpenSSH client config files or + # a user-defined private key. If using OpenSSH config, it will attempt to + # load settings from ~/.ssh/config. + try: + conn = await asyncssh.connect(**ssh_config) + return conn + except Exception as e: + logger.info( + "Failed to make an SSH connection using the default OpenSSH client config (~/.ssh/config) or " + f"the provided private keys. Error details:\n{e}" + ) + + try: + default_client_key = get_agent_secret(secret_key=SLURM_PRIVATE_KEY) + except ValueError: + logger.info("The secret for key FLYTE_SLURM_PRIVATE_KEY is not set.") + default_client_key = None + + if default_client_key is None and ssh_config.get("client_keys") == (): + raise ValueError( + "Both the secret for key FLYTE_SLURM_PRIVATE_KEY and ssh_config['private_key'] are missing. " + "At least one must be set." + ) + + client_keys = [] + if default_client_key is not None: + # Write the private key to a local path + # This may not be a good practice... + private_key_path = os.path.abspath("./slurm_private_key") + with open(private_key_path, "w") as f: + f.write(default_client_key) + client_keys.append(private_key_path) + + user_client_keys = ssh_config.get("client_keys") + if user_client_keys is not None: + client_keys.extend([user_client_keys] if isinstance(user_client_keys, str) else user_client_keys) + + ssh_config["client_keys"] = client_keys + logger.info(f"Updated SSH config: {ssh_config}") + try: + conn = await asyncssh.connect(**ssh_config) + return conn + except Exception as e: + logger.info( + "Failed to make an SSH connection using the provided private keys. Please verify your setup." + f"Error details:\n{e}" + ) + sys.exit(1) + + +if __name__ == "__main__": + import asyncio + + async def test_connect(): + conn = await ssh_connect({"host": "aws2", "username": "ubuntu"}) + res = await conn.run("echo [TEST] SSH connection", check=True) + out = res.stdout + + return out + + logger.info(asyncio.run(test_connect())) diff --git a/plugins/flytekit-slurm/setup.py b/plugins/flytekit-slurm/setup.py new file mode 100644 index 0000000000..2c338db47e --- /dev/null +++ b/plugins/flytekit-slurm/setup.py @@ -0,0 +1,40 @@ +from setuptools import setup + +PLUGIN_NAME = "slurm" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>1.13.8", "flyteidl>=1.11.0b1", "asyncssh"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="This package holds the Slurm plugins for flytekit", + namespace_packages=["flytekitplugins"], + packages=[ + f"flytekitplugins.{PLUGIN_NAME}", + f"flytekitplugins.{PLUGIN_NAME}.script", + f"flytekitplugins.{PLUGIN_NAME}.function", + ], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.9", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, +) diff --git a/plugins/flytekit-slurm/tests/__init__.py b/plugins/flytekit-slurm/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-slurm/tests/test_slurm.py b/plugins/flytekit-slurm/tests/test_slurm.py new file mode 100644 index 0000000000..e69de29bb2