diff --git a/airflow_supervisor/airflow/ssh.py b/airflow_supervisor/airflow/ssh.py index eae5dc4..941df96 100644 --- a/airflow_supervisor/airflow/ssh.py +++ b/airflow_supervisor/airflow/ssh.py @@ -1,12 +1,11 @@ from shlex import quote -from typing import Dict, List, Optional, Union +from typing import Dict from airflow.models.dag import DAG from airflow.models.operator import Operator -from airflow.providers.ssh.hooks.ssh import SSHHook from airflow.providers.ssh.operators.ssh import SSHOperator -from airflow_supervisor.config import SupervisorAirflowConfiguration +from airflow_supervisor.config import SupervisorSSHAirflowConfiguration from .common import SupervisorTaskStep from .local import Supervisor @@ -19,41 +18,32 @@ class SupervisorSSH(Supervisor): def __init__( self, dag: DAG, - cfg: SupervisorAirflowConfiguration, - command_prefix: str = "", - command_noescape: str = "", - ssh_hook: Optional[SSHHook] = None, - ssh_conn_id: Optional[str] = None, - remote_host: Optional[str] = None, - conn_timeout: Optional[int] = None, - cmd_timeout: Optional[int] = None, - environment: Optional[dict] = None, - get_pty: Optional[bool] = None, - banner_timeout: Optional[float] = None, - skip_on_exit_code: Optional[Union[int, List[int]]] = None, + cfg: SupervisorSSHAirflowConfiguration, **kwargs, ): - self._command_prefix = command_prefix - self._command_noescape = command_noescape + for attr in ("command_prefix", "command_noescape"): + if attr in kwargs: + setattr(self, f"_{attr}", kwargs.pop(attr)) + elif cfg and getattr(cfg, attr): + setattr(self, f"_{attr}", getattr(cfg, attr)) + self._ssh_operator_kwargs = {} - if ssh_hook: - self._ssh_operator_kwargs["ssh_hook"] = ssh_hook - if ssh_conn_id: - self._ssh_operator_kwargs["ssh_conn_id"] = ssh_conn_id - if remote_host: - self._ssh_operator_kwargs["remote_host"] = remote_host - if conn_timeout: - self._ssh_operator_kwargs["conn_timeout"] = conn_timeout - if cmd_timeout: - self._ssh_operator_kwargs["cmd_timeout"] = cmd_timeout - if environment: - self._ssh_operator_kwargs["environment"] = environment - if get_pty: - self._ssh_operator_kwargs["get_pty"] = get_pty - if banner_timeout: - self._ssh_operator_kwargs["banner_timeout"] = banner_timeout - if skip_on_exit_code: - self._ssh_operator_kwargs["skip_on_exit_code"] = skip_on_exit_code + for attr in ( + "ssh_hook", + "ssh_conn_id", + "remote_host", + "conn_timeout", + "cmd_timeout", + "environment", + "get_pty", + "banner_timeout", + "skip_on_exit_code", + ): + if attr in kwargs: + self._ssh_operator_kwargs[attr] = kwargs.pop(attr) + elif cfg and getattr(cfg, attr): + self._ssh_operator_kwargs[attr] = getattr(cfg, attr) + super().__init__(dag=dag, cfg=cfg, **kwargs) def get_base_operator_kwargs(self) -> Dict: diff --git a/airflow_supervisor/config/__init__.py b/airflow_supervisor/config/__init__.py index 45ff0df..b722dc1 100644 --- a/airflow_supervisor/config/__init__.py +++ b/airflow_supervisor/config/__init__.py @@ -7,7 +7,13 @@ from .inet_http_server import InetHttpServerConfiguration from .program import ProgramConfiguration from .rpcinterface import RpcInterfaceConfiguration -from .supervisor import SupervisorAirflowConfiguration, SupervisorConfiguration, load_airflow_config, load_config +from .supervisor import ( + SupervisorAirflowConfiguration, + SupervisorConfiguration, + SupervisorSSHAirflowConfiguration, + load_airflow_config, + load_config, +) from .supervisorctl import SupervisorctlConfiguration from .supervisord import SupervisordConfiguration from .unix_http_server import UnixHttpServerConfiguration diff --git a/airflow_supervisor/config/supervisor.py b/airflow_supervisor/config/supervisor.py index 73b01b7..f594c2b 100644 --- a/airflow_supervisor/config/supervisor.py +++ b/airflow_supervisor/config/supervisor.py @@ -6,11 +6,11 @@ from signal import SIGKILL, SIGTERM from subprocess import Popen from tempfile import gettempdir -from typing import Dict, Optional +from typing import Dict, List, Optional, Union from hydra import compose, initialize_config_dir from hydra.utils import instantiate -from pydantic import BaseModel, Field, PrivateAttr, model_validator +from pydantic import BaseModel, Field, PrivateAttr, field_validator, model_validator from ..exceptions import ConfigNotFoundError from ..utils import _get_calling_dag @@ -29,6 +29,7 @@ __all__ = ( "SupervisorConfiguration", "SupervisorAirflowConfiguration", + "SupervisorSSHAirflowConfiguration", "load_config", "load_airflow_config", ) @@ -300,5 +301,31 @@ def _setup_airflow_defaults(self): return self +class SupervisorSSHAirflowConfiguration(SupervisorAirflowConfiguration): + command_prefix: Optional[str] = Field(default="") + command_noescape: Optional[str] = Field(default="") + + # SSH Kwargs + ssh_hook: Optional[object] = Field(default=None) + ssh_conn_id: Optional[str] = Field(default=None) + remote_host: Optional[str] = Field(default=None) + conn_timeout: Optional[int] = Field(default=None) + cmd_timeout: Optional[int] = Field(default=None) + environment: Optional[dict] = Field(default=None) + get_pty: Optional[bool] = Field(default=None) + banner_timeout: Optional[float] = Field(default=None) + skip_on_exit_code: Optional[Union[int, List[int]]] = Field(default=None) + + @field_validator("ssh_hook") + @classmethod + def _validate_ssh_hook(cls, v): + if v: + from airflow.providers.ssh.hooks.ssh import SSHHook + + assert isinstance(v, SSHHook) + return v + + load_config = SupervisorConfiguration.load load_airflow_config = SupervisorAirflowConfiguration.load +load_airflow_ssh_config = SupervisorSSHAirflowConfiguration.load