Skip to content

Commit

Permalink
add checks
Browse files Browse the repository at this point in the history
Signed-off-by: kalyanr <kalyan.ben10@live.com>
  • Loading branch information
rawwar committed Feb 4, 2024
1 parent caec4c7 commit f8a1ea9
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 11 deletions.
13 changes: 8 additions & 5 deletions airflow/models/dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations

import hashlib
import importlib
import importlib.machinery
import importlib.util
Expand Down Expand Up @@ -48,7 +47,12 @@
from airflow.utils import timezone
from airflow.utils.dag_cycle_tester import check_cycle
from airflow.utils.docs import get_docs_url
from airflow.utils.file import correct_maybe_zipped, list_py_file_paths, might_contain_dag
from airflow.utils.file import (
correct_maybe_zipped,
get_unique_dag_module_name,
list_py_file_paths,
might_contain_dag,
)
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries
from airflow.utils.session import NEW_SESSION, provide_session
Expand Down Expand Up @@ -326,9 +330,8 @@ def _load_modules_from_file(self, filepath, safe_mode):
return []

self.log.debug("Importing %s", filepath)
path_hash = hashlib.sha1(filepath.encode("utf-8")).hexdigest()
org_mod_name = Path(filepath).stem
mod_name = f"unusual_prefix_{path_hash}_{org_mod_name}"

mod_name = get_unique_dag_module_name(filepath)

if mod_name in sys.modules:
del sys.modules[mod_name]
Expand Down
27 changes: 21 additions & 6 deletions airflow/operators/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from airflow.operators.branch import BranchMixIn
from airflow.utils import hashlib_wrapper
from airflow.utils.context import context_copy_partial, context_merge
from airflow.utils.file import get_unique_dag_module_name
from airflow.utils.operator_helpers import KeywordParameters
from airflow.utils.process_utils import execute_in_subprocess
from airflow.utils.python_virtualenv import prepare_virtualenv, write_python_script
Expand Down Expand Up @@ -364,12 +365,7 @@ def __init__(
skip_on_exit_code: int | Container[int] | None = None,
**kwargs,
):
if (
not isinstance(python_callable, types.FunctionType)
or isinstance(python_callable, types.LambdaType)
and python_callable.__name__ == "<lambda>"
):
raise AirflowException("PythonVirtualenvOperator only supports functions for python_callable arg")
self._validate_python_callable(python_callable)
super().__init__(
python_callable=python_callable,
op_args=op_args,
Expand Down Expand Up @@ -403,6 +399,25 @@ def get_python_source(self):
"""Return the source of self.python_callable."""
return textwrap.dedent(inspect.getsource(self.python_callable))

def _validate_python_callable(self, python_callable):
"""Verifies if python_callable can be be used with the PythonVirtualenvOperator."""
if self.check_callable_in_dag_module(python_callable):
raise AirflowException(
"Functions defined within dag module are not supported for PythonVirtualenvOperator"
)

if (
not isinstance(python_callable, types.FunctionType)
or isinstance(python_callable, types.LambdaType)
and python_callable.__name__ == "<lambda>"
):
raise AirflowException("PythonVirtualenvOperator only supports functions for python_callable arg")

@staticmethod
def check_callable_in_dag_module(python_callable):
if get_unique_dag_module_name(inspect.getfile(python_callable)) == python_callable.__module__:
return True

def _write_args(self, file: Path):
if self.op_args or self.op_kwargs:
file.write_bytes(self.pickling_library.dumps({"args": self.op_args, "kwargs": self.op_kwargs}))
Expand Down
9 changes: 9 additions & 0 deletions airflow/utils/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import ast
import hashlib
import logging
import os
import zipfile
Expand All @@ -33,6 +34,8 @@

log = logging.getLogger(__name__)

MODIFIED_DAG_MODULE_NAME = "unusual_prefix_{path_hash}_{module_name}"


class _IgnoreRule(Protocol):
"""Interface for ignore rules for structural subtyping."""
Expand Down Expand Up @@ -379,3 +382,9 @@ def iter_airflow_imports(file_path: str) -> Generator[str, None, None]:
for m in _find_imported_modules(parsed):
if m.startswith("airflow."):
yield m


def get_unique_dag_module_name(file_path):
path_hash = hashlib.sha1(file_path.encode("utf-8")).hexdigest()
org_mod_name = Path(file_path).stem
return MODIFIED_DAG_MODULE_NAME.format(path_hash=path_hash, module_name=org_mod_name)

0 comments on commit f8a1ea9

Please sign in to comment.