Skip to content

Commit

Permalink
Merge pull request #49 from airflow-laminar/tkp/sshargs
Browse files Browse the repository at this point in the history
Separate SSH config, pass through args to ssh operator
  • Loading branch information
timkpaine authored Dec 15, 2024
2 parents 4b8194f + 3424514 commit f0712dd
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 38 deletions.
60 changes: 25 additions & 35 deletions airflow_supervisor/airflow/ssh.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from shlex import quote
from typing import Dict, List, Optional, Union
from typing import Dict

from airflow.models.dag import DAG
from airflow.models.operator import Operator
from airflow.providers.ssh.hooks.ssh import SSHHook
from airflow.providers.ssh.operators.ssh import SSHOperator

from airflow_supervisor.config import SupervisorAirflowConfiguration
from airflow_supervisor.config import SupervisorSSHAirflowConfiguration

from .common import SupervisorTaskStep
from .local import Supervisor
Expand All @@ -19,41 +18,32 @@ class SupervisorSSH(Supervisor):
def __init__(
    self,
    dag: DAG,
    cfg: SupervisorSSHAirflowConfiguration,
    **kwargs,
):
    """Set up SSH-specific state, then defer to the base ``Supervisor``.

    Command options (``command_prefix``, ``command_noescape``) and SSH
    operator options are resolved per attribute in this order:

    1. an explicit keyword argument (removed from ``kwargs`` so it is
       not forwarded to the base class),
    2. a truthy value of the same name on ``cfg``.

    Args:
        dag: DAG forwarded unchanged to the base class.
        cfg: SSH-aware supervisor configuration used as the fallback
            source for every option not given as a keyword argument.
        **kwargs: Per-attribute overrides; anything left over after the
            SSH-related keys are consumed is passed to the base class.
    """
    # Always define both attributes. Without the "" fallback they would be
    # left unset whenever the option is neither passed as a kwarg nor
    # truthy on cfg (the cfg defaults are ""), so a later read of
    # self._command_prefix / self._command_noescape would raise
    # AttributeError.
    for attr in ("command_prefix", "command_noescape"):
        if attr in kwargs:
            value = kwargs.pop(attr)
        else:
            value = (getattr(cfg, attr) if cfg else "") or ""
        setattr(self, f"_{attr}", value)

    # Options collected here are kept separately from the remaining kwargs.
    # NOTE(review): cfg values are taken only when truthy, so falsy
    # settings on cfg (0, False, {}) are dropped; explicit kwargs are
    # always honoured, whatever their value.
    self._ssh_operator_kwargs = {}
    for attr in (
        "ssh_hook",
        "ssh_conn_id",
        "remote_host",
        "conn_timeout",
        "cmd_timeout",
        "environment",
        "get_pty",
        "banner_timeout",
        "skip_on_exit_code",
    ):
        if attr in kwargs:
            self._ssh_operator_kwargs[attr] = kwargs.pop(attr)
        elif cfg and getattr(cfg, attr):
            self._ssh_operator_kwargs[attr] = getattr(cfg, attr)

    super().__init__(dag=dag, cfg=cfg, **kwargs)

def get_base_operator_kwargs(self) -> Dict:
Expand Down
8 changes: 7 additions & 1 deletion airflow_supervisor/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@
from .inet_http_server import InetHttpServerConfiguration
from .program import ProgramConfiguration
from .rpcinterface import RpcInterfaceConfiguration
from .supervisor import SupervisorAirflowConfiguration, SupervisorConfiguration, load_airflow_config, load_config
from .supervisor import (
SupervisorAirflowConfiguration,
SupervisorConfiguration,
SupervisorSSHAirflowConfiguration,
load_airflow_config,
load_config,
)
from .supervisorctl import SupervisorctlConfiguration
from .supervisord import SupervisordConfiguration
from .unix_http_server import UnixHttpServerConfiguration
31 changes: 29 additions & 2 deletions airflow_supervisor/config/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from signal import SIGKILL, SIGTERM
from subprocess import Popen
from tempfile import gettempdir
from typing import Dict, Optional
from typing import Dict, List, Optional, Union

from hydra import compose, initialize_config_dir
from hydra.utils import instantiate
from pydantic import BaseModel, Field, PrivateAttr, model_validator
from pydantic import BaseModel, Field, PrivateAttr, field_validator, model_validator

from ..exceptions import ConfigNotFoundError
from ..utils import _get_calling_dag
Expand All @@ -29,6 +29,7 @@
# Public names of this module: the configuration models and their
# module-level loader aliases (including the SSH loader defined at the
# bottom of the file).
__all__ = (
    "SupervisorConfiguration",
    "SupervisorAirflowConfiguration",
    "SupervisorSSHAirflowConfiguration",
    "load_config",
    "load_airflow_config",
    "load_airflow_ssh_config",
)
Expand Down Expand Up @@ -300,5 +301,31 @@ def _setup_airflow_defaults(self):
return self


class SupervisorSSHAirflowConfiguration(SupervisorAirflowConfiguration):
    """Supervisor Airflow configuration extended with SSH options.

    Adds two command options plus a set of optional SSH settings. Every SSH
    setting defaults to ``None``, meaning "not configured".
    """

    # Command options; both default to the empty string.
    # NOTE(review): presumably a prefix for the remote command and a part
    # that must not be shell-escaped - confirm against the command builder.
    command_prefix: Optional[str] = Field(default="")
    command_noescape: Optional[str] = Field(default="")

    # SSH Kwargs
    # Typed as ``object`` so this module does not reference SSHHook at
    # import time; the concrete type is enforced by _validate_ssh_hook.
    ssh_hook: Optional[object] = Field(default=None)
    ssh_conn_id: Optional[str] = Field(default=None)
    remote_host: Optional[str] = Field(default=None)
    conn_timeout: Optional[int] = Field(default=None)
    cmd_timeout: Optional[int] = Field(default=None)
    environment: Optional[dict] = Field(default=None)
    get_pty: Optional[bool] = Field(default=None)
    banner_timeout: Optional[float] = Field(default=None)
    skip_on_exit_code: Optional[Union[int, List[int]]] = Field(default=None)

    @field_validator("ssh_hook")
    @classmethod
    def _validate_ssh_hook(cls, v):
        """Ensure a supplied ``ssh_hook`` is an ``SSHHook`` instance.

        Falsy values (including ``None``) are returned unchanged.

        Raises:
            ValueError: If ``v`` is truthy but not an ``SSHHook``.
        """
        if v:
            # Imported lazily, only when a hook is actually supplied.
            from airflow.providers.ssh.hooks.ssh import SSHHook

            # Raise explicitly instead of using ``assert``: assertions are
            # removed under ``python -O``, which would silently disable
            # this check.
            if not isinstance(v, SSHHook):
                raise ValueError(f"ssh_hook must be an SSHHook instance, got {type(v).__name__}")
        return v


# Module-level aliases for the ``load`` attribute of each configuration
# class, so callers can load a configuration without naming the class.
load_config = SupervisorConfiguration.load
load_airflow_config = SupervisorAirflowConfiguration.load
load_airflow_ssh_config = SupervisorSSHAirflowConfiguration.load

0 comments on commit f0712dd

Please sign in to comment.