Skip to content

Commit

Permalink
Check permissions for ImportError (#37468)
Browse files Browse the repository at this point in the history
  • Loading branch information
jedcunningham authored Feb 20, 2024
1 parent 16d2671 commit d944eb0
Show file tree
Hide file tree
Showing 4 changed files with 314 additions and 21 deletions.
61 changes: 56 additions & 5 deletions airflow/api_connexion/endpoints/import_error_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,39 +16,59 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Sequence

from sqlalchemy import func, select

from airflow.api_connexion import security
from airflow.api_connexion.exceptions import NotFound
from airflow.api_connexion.exceptions import NotFound, PermissionDenied
from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters
from airflow.api_connexion.schemas.error_schema import (
ImportErrorCollection,
import_error_collection_schema,
import_error_schema,
)
from airflow.auth.managers.models.resource_details import AccessView
from airflow.auth.managers.models.resource_details import AccessView, DagDetails
from airflow.models.dag import DagModel
from airflow.models.errors import ImportError as ImportErrorModel
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.www.extensions.init_auth_manager import get_auth_manager

if TYPE_CHECKING:
from sqlalchemy.orm import Session

from airflow.api_connexion.types import APIResponse
from airflow.auth.managers.models.batch_apis import IsAuthorizedDagRequest


@security.requires_access_view(AccessView.IMPORT_ERRORS)
@provide_session
def get_import_error(*, import_error_id: int, session: Session = NEW_SESSION) -> APIResponse:
"""Get an import error."""
error = session.get(ImportErrorModel, import_error_id)

if error is None:
raise NotFound(
"Import error not found",
detail=f"The ImportError with import_error_id: `{import_error_id}` was not found",
)
session.expunge(error)

can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET")
if not can_read_all_dags:
readable_dag_ids = security.get_readable_dags()
file_dag_ids = {
dag_id[0]
for dag_id in session.query(DagModel.dag_id).filter(DagModel.fileloc == error.filename).all()
}

# Can the user read any DAGs in the file?
if not readable_dag_ids.intersection(file_dag_ids):
raise PermissionDenied(detail="You do not have read permission on any of the DAGs in the file")

# Check if user has read access to all the DAGs defined in the file
if not file_dag_ids.issubset(readable_dag_ids):
error.stacktrace = "REDACTED - you do not have read permission on all DAGs in the file"

return import_error_schema.dump(error)


Expand All @@ -65,10 +85,41 @@ def get_import_errors(
"""Get all import errors."""
to_replace = {"import_error_id": "id"}
allowed_filter_attrs = ["import_error_id", "timestamp", "filename"]
total_entries = session.scalars(func.count(ImportErrorModel.id)).one()
count_query = select(func.count(ImportErrorModel.id))
query = select(ImportErrorModel)
query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs)

can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET")

if not can_read_all_dags:
# if the user doesn't have access to all DAGs, only display errors from visible DAGs
readable_dag_ids = security.get_readable_dags()
dagfiles_subq = (
select(DagModel.fileloc).distinct().where(DagModel.dag_id.in_(readable_dag_ids)).subquery()
)
query = query.where(ImportErrorModel.filename.in_(dagfiles_subq))
count_query = count_query.where(ImportErrorModel.filename.in_(dagfiles_subq))

total_entries = session.scalars(count_query).one()
import_errors = session.scalars(query.offset(offset).limit(limit)).all()

if not can_read_all_dags:
for import_error in import_errors:
# Check if user has read access to all the DAGs defined in the file
file_dag_ids = (
session.query(DagModel.dag_id).filter(DagModel.fileloc == import_error.filename).all()
)
requests: Sequence[IsAuthorizedDagRequest] = [
{
"method": "GET",
"details": DagDetails(id=dag_id[0]),
}
for dag_id in file_dag_ids
]
if not get_auth_manager().batch_is_authorized_dag(requests):
session.expunge(import_error)
import_error.stacktrace = "REDACTED - you do not have read permission on all DAGs in the file"

return import_error_collection_schema.dump(
ImportErrorCollection(import_errors=import_errors, total_entries=total_entries)
)
51 changes: 38 additions & 13 deletions airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@
if TYPE_CHECKING:
from sqlalchemy.orm import Session

from airflow.auth.managers.models.batch_apis import IsAuthorizedDagRequest
from airflow.models.dag import DAG
from airflow.models.operator import Operator

Expand Down Expand Up @@ -935,20 +936,44 @@ def index(self):

owner_links_dict = DagOwnerAttributes.get_all(session)

import_errors = select(errors.ImportError).order_by(errors.ImportError.id)

if not get_auth_manager().is_authorized_dag(method="GET"):
# if the user doesn't have access to all DAGs, only display errors from visible DAGs
import_errors = import_errors.join(
DagModel, DagModel.fileloc == errors.ImportError.filename
).where(DagModel.dag_id.in_(filter_dag_ids))
if get_auth_manager().is_authorized_view(access_view=AccessView.IMPORT_ERRORS):
import_errors = select(errors.ImportError).order_by(errors.ImportError.id)

can_read_all_dags = get_auth_manager().is_authorized_dag(method="GET")
if not can_read_all_dags:
# if the user doesn't have access to all DAGs, only display errors from visible DAGs
import_errors = import_errors.where(
errors.ImportError.filename.in_(
select(DagModel.fileloc)
.distinct()
.where(DagModel.dag_id.in_(filter_dag_ids))
.subquery()
)
)

import_errors = session.scalars(import_errors)
for import_error in import_errors:
flash(
f"Broken DAG: [{import_error.filename}] {import_error.stacktrace}",
"dag_import_error",
)
import_errors = session.scalars(import_errors)
for import_error in import_errors:
stacktrace = import_error.stacktrace
if not can_read_all_dags:
# Check if user has read access to all the DAGs defined in the file
file_dag_ids = (
session.query(DagModel.dag_id)
.filter(DagModel.fileloc == import_error.filename)
.all()
)
requests: Sequence[IsAuthorizedDagRequest] = [
{
"method": "GET",
"details": DagDetails(id=dag_id[0]),
}
for dag_id in file_dag_ids
]
if not get_auth_manager().batch_is_authorized_dag(requests):
stacktrace = "REDACTED - you do not have read permission on all DAGs in the file"
flash(
f"Broken DAG: [{import_error.filename}]\r{stacktrace}",
"dag_import_error",
)

from airflow.plugins_manager import import_errors as plugin_import_errors

Expand Down
162 changes: 159 additions & 3 deletions tests/api_connexion/endpoints/test_import_error_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,19 @@
import pytest

from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
from airflow.models.dag import DagModel
from airflow.models.errors import ImportError
from airflow.security import permissions
from airflow.utils import timezone
from airflow.utils.session import provide_session
from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_import_errors
from tests.test_utils.db import clear_db_dags, clear_db_import_errors

pytestmark = pytest.mark.db_test

TEST_DAG_IDS = ["test_dag", "test_dag2"]


@pytest.fixture(scope="module")
def configured_app(minimal_app_for_api):
Expand All @@ -39,14 +42,34 @@ def configured_app(minimal_app_for_api):
app, # type:ignore
username="test",
role_name="Test",
permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)], # type: ignore
permissions=[
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR),
], # type: ignore
)
create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore
create_user(
app, # type:ignore
username="test_single_dag",
role_name="TestSingleDAG",
permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)], # type: ignore
)
# For some reason, DAG level permissions are not synced when in the above list of perms,
# so do it manually here:
app.appbuilder.sm.bulk_sync_roles(
[
{
"role": "TestSingleDAG",
"perms": [(permissions.ACTION_CAN_READ, permissions.resource_name_for_dag(TEST_DAG_IDS[0]))],
}
]
)

yield minimal_app_for_api
yield app

delete_user(app, username="test") # type: ignore
delete_user(app, username="test_no_permissions") # type: ignore
delete_user(app, username="test_single_dag") # type: ignore


class TestBaseImportError:
Expand All @@ -58,9 +81,11 @@ def setup_attrs(self, configured_app) -> None:
self.client = self.app.test_client() # type:ignore

clear_db_import_errors()
clear_db_dags()

def teardown_method(self) -> None:
clear_db_import_errors()
clear_db_dags()

@staticmethod
def _normalize_import_errors(import_errors):
Expand Down Expand Up @@ -121,6 +146,72 @@ def test_should_raise_403_forbidden(self):
)
assert response.status_code == 403

def test_should_raise_403_forbidden_without_dag_read(self, session):
import_error = ImportError(
filename="Lorem_ipsum.py",
stacktrace="Lorem ipsum",
timestamp=timezone.parse(self.timestamp, timezone="UTC"),
)
session.add(import_error)
session.commit()

response = self.client.get(
f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"}
)

assert response.status_code == 403

def test_should_return_200_with_single_dag_read(self, session):
dag_model = DagModel(dag_id=TEST_DAG_IDS[0], fileloc="Lorem_ipsum.py")
session.add(dag_model)
import_error = ImportError(
filename="Lorem_ipsum.py",
stacktrace="Lorem ipsum",
timestamp=timezone.parse(self.timestamp, timezone="UTC"),
)
session.add(import_error)
session.commit()

response = self.client.get(
f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"}
)

assert response.status_code == 200
response_data = response.json
response_data["import_error_id"] = 1
assert {
"filename": "Lorem_ipsum.py",
"import_error_id": 1,
"stack_trace": "Lorem ipsum",
"timestamp": "2020-06-10T12:00:00+00:00",
} == response_data

def test_should_return_200_redacted_with_single_dag_read_in_dagfile(self, session):
for dag_id in TEST_DAG_IDS:
dag_model = DagModel(dag_id=dag_id, fileloc="Lorem_ipsum.py")
session.add(dag_model)
import_error = ImportError(
filename="Lorem_ipsum.py",
stacktrace="Lorem ipsum",
timestamp=timezone.parse(self.timestamp, timezone="UTC"),
)
session.add(import_error)
session.commit()

response = self.client.get(
f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"}
)

assert response.status_code == 200
response_data = response.json
response_data["import_error_id"] = 1
assert {
"filename": "Lorem_ipsum.py",
"import_error_id": 1,
"stack_trace": "REDACTED - you do not have read permission on all DAGs in the file",
"timestamp": "2020-06-10T12:00:00+00:00",
} == response_data


class TestGetImportErrorsEndpoint(TestBaseImportError):
def test_get_import_errors(self, session):
Expand Down Expand Up @@ -231,6 +322,71 @@ def test_should_raises_401_unauthenticated(self, session):

assert_401(response)

def test_get_import_errors_single_dag(self, session):
for dag_id in TEST_DAG_IDS:
fake_filename = f"/tmp/{dag_id}.py"
dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename)
session.add(dag_model)
importerror = ImportError(
filename=fake_filename,
stacktrace="Lorem ipsum",
timestamp=timezone.parse(self.timestamp, timezone="UTC"),
)
session.add(importerror)
session.commit()

response = self.client.get(
"/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"}
)

assert response.status_code == 200
response_data = response.json
self._normalize_import_errors(response_data["import_errors"])
assert {
"import_errors": [
{
"filename": "/tmp/test_dag.py",
"import_error_id": 1,
"stack_trace": "Lorem ipsum",
"timestamp": "2020-06-10T12:00:00+00:00",
},
],
"total_entries": 1,
} == response_data

def test_get_import_errors_single_dag_in_dagfile(self, session):
for dag_id in TEST_DAG_IDS:
fake_filename = "/tmp/all_in_one.py"
dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename)
session.add(dag_model)

importerror = ImportError(
filename="/tmp/all_in_one.py",
stacktrace="Lorem ipsum",
timestamp=timezone.parse(self.timestamp, timezone="UTC"),
)
session.add(importerror)
session.commit()

response = self.client.get(
"/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"}
)

assert response.status_code == 200
response_data = response.json
self._normalize_import_errors(response_data["import_errors"])
assert {
"import_errors": [
{
"filename": "/tmp/all_in_one.py",
"import_error_id": 1,
"stack_trace": "REDACTED - you do not have read permission on all DAGs in the file",
"timestamp": "2020-06-10T12:00:00+00:00",
},
],
"total_entries": 1,
} == response_data


class TestGetImportErrorsEndpointPagination(TestBaseImportError):
@pytest.mark.parametrize(
Expand Down
Loading

0 comments on commit d944eb0

Please sign in to comment.