Skip to content

Commit

Permalink
fix: PythonVirtualenvOperator crashes if any python_callable function…
Browse files Browse the repository at this point in the history
… is defined in the same source as DAG (#37165)

---------

Signed-off-by: kalyanr <kalyan.ben10@live.com>
(cherry picked from commit e75522b)
  • Loading branch information
rawwar authored and jedcunningham committed Feb 9, 2024
1 parent c0c5ab4 commit f841e70
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 15 deletions.
12 changes: 7 additions & 5 deletions airflow/models/dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations

import hashlib
import importlib
import importlib.machinery
import importlib.util
Expand Down Expand Up @@ -48,7 +47,12 @@
from airflow.utils import timezone
from airflow.utils.dag_cycle_tester import check_cycle
from airflow.utils.docs import get_docs_url
from airflow.utils.file import correct_maybe_zipped, list_py_file_paths, might_contain_dag
from airflow.utils.file import (
correct_maybe_zipped,
get_unique_dag_module_name,
list_py_file_paths,
might_contain_dag,
)
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries
from airflow.utils.session import NEW_SESSION, provide_session
Expand Down Expand Up @@ -326,9 +330,7 @@ def _load_modules_from_file(self, filepath, safe_mode):
return []

self.log.debug("Importing %s", filepath)
path_hash = hashlib.sha1(filepath.encode("utf-8")).hexdigest()
org_mod_name = Path(filepath).stem
mod_name = f"unusual_prefix_{path_hash}_{org_mod_name}"
mod_name = get_unique_dag_module_name(filepath)

if mod_name in sys.modules:
del sys.modules[mod_name]
Expand Down
23 changes: 15 additions & 8 deletions airflow/operators/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from airflow.operators.branch import BranchMixIn
from airflow.utils import hashlib_wrapper
from airflow.utils.context import context_copy_partial, context_merge
from airflow.utils.file import get_unique_dag_module_name
from airflow.utils.operator_helpers import KeywordParameters
from airflow.utils.process_utils import execute_in_subprocess
from airflow.utils.python_virtualenv import prepare_virtualenv, write_python_script
Expand Down Expand Up @@ -437,15 +438,21 @@ def _execute_python_callable_in_subprocess(self, python_path: Path):

self._write_args(input_path)
self._write_string_args(string_args_path)

jinja_context = {
"op_args": self.op_args,
"op_kwargs": op_kwargs,
"expect_airflow": self.expect_airflow,
"pickling_library": self.pickling_library.__name__,
"python_callable": self.python_callable.__name__,
"python_callable_source": self.get_python_source(),
}

if inspect.getfile(self.python_callable) == self.dag.fileloc:
jinja_context["modified_dag_module_name"] = get_unique_dag_module_name(self.dag.fileloc)

write_python_script(
jinja_context={
"op_args": self.op_args,
"op_kwargs": op_kwargs,
"expect_airflow": self.expect_airflow,
"pickling_library": self.pickling_library.__name__,
"python_callable": self.python_callable.__name__,
"python_callable_source": self.get_python_source(),
},
jinja_context=jinja_context,
filename=os.fspath(script_path),
render_template_as_native_obj=self.dag.render_template_as_native_obj,
)
Expand Down
12 changes: 12 additions & 0 deletions airflow/utils/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import ast
import hashlib
import logging
import os
import zipfile
Expand All @@ -33,6 +34,8 @@

log = logging.getLogger(__name__)

MODIFIED_DAG_MODULE_NAME = "unusual_prefix_{path_hash}_{module_name}"


class _IgnoreRule(Protocol):
"""Interface for ignore rules for structural subtyping."""
Expand Down Expand Up @@ -379,3 +382,12 @@ def iter_airflow_imports(file_path: str) -> Generator[str, None, None]:
for m in _find_imported_modules(parsed):
if m.startswith("airflow."):
yield m


def get_unique_dag_module_name(file_path: str) -> str:
"""Returns a unique module name in the format unusual_prefix_{sha1 of module's file path}_{original module name}."""
if isinstance(file_path, str):
path_hash = hashlib.sha1(file_path.encode("utf-8")).hexdigest()
org_mod_name = Path(file_path).stem
return MODIFIED_DAG_MODULE_NAME.format(path_hash=path_hash, module_name=org_mod_name)
raise ValueError("file_path should be a string to generate unique module name")
19 changes: 17 additions & 2 deletions airflow/utils/python_virtualenv_script.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,22 @@ if sys.version_info >= (3,6):
pass
{% endif %}

# Script
{{ python_callable_source }}

# monkey patching for the cases when python_callable is part of the dag module.
{% if modified_dag_module_name is defined %}

import types

{{ modified_dag_module_name }} = types.ModuleType("{{ modified_dag_module_name }}")

{{ modified_dag_module_name }}.{{ python_callable }} = {{ python_callable }}

sys.modules["{{modified_dag_module_name}}"] = {{modified_dag_module_name}}

{% endif%}

{% if op_args or op_kwargs %}
with open(sys.argv[1], "rb") as file:
arg_dict = {{ pickling_library }}.load(file)
Expand All @@ -47,8 +63,7 @@ with open(sys.argv[3], "r") as file:
virtualenv_string_args = list(map(lambda x: x.strip(), list(file)))
{% endif %}

# Script
{{ python_callable_source }}

try:
res = {{ python_callable }}(*arg_dict["args"], **arg_dict["kwargs"])
except Exception as e:
Expand Down

0 comments on commit f841e70

Please sign in to comment.