Skip to content

Commit

Permalink
Display explicit error in case UID has no actual username (#15212)
Browse files Browse the repository at this point in the history
Fixes #9963 : Don't require a current username

Previously, we used getpass.getuser() directly, which errors out with a
bare KeyError if there is no username specified for the current UID (which
happens a lot more in environments like Docker & Kubernetes). This updates
most calls to use our own wrapper, which instead raises an explicit error
explaining that Airflow must be run as a full user with a username.

GitOrigin-RevId: 3e9e954d9ec5236cbbc6da2091b38e69c1b4c0c0
  • Loading branch information
Andrew Godwin authored and Cloud Composer Team committed Mar 10, 2022
1 parent 05ea318 commit e35958b
Show file tree
Hide file tree
Showing 14 changed files with 59 additions and 29 deletions.
4 changes: 2 additions & 2 deletions airflow/cli/commands/info_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""Config sub-commands"""
import getpass
import locale
import logging
import os
Expand All @@ -33,6 +32,7 @@
from airflow.providers_manager import ProvidersManager
from airflow.typing_compat import Protocol
from airflow.utils.cli import suppress_logs_and_warning
from airflow.utils.platform import getuser
from airflow.version import version as airflow_version

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -67,7 +67,7 @@ class PiiAnonymizer(Anonymizer):

def __init__(self):
home_path = os.path.expanduser("~")
username = getpass.getuser()
username = getuser()
self._path_replacements = {home_path: "${HOME}", username: "${USER}"}

def process_path(self, value):
Expand Down
4 changes: 2 additions & 2 deletions airflow/jobs/base_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# under the License.
#

import getpass
from time import sleep
from typing import Optional

Expand All @@ -37,6 +36,7 @@
from airflow.utils.helpers import convert_camel_to_snake
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
from airflow.utils.platform import getuser
from airflow.utils.session import create_session, provide_session
from airflow.utils.sqlalchemy import UtcDateTime
from airflow.utils.state import State
Expand Down Expand Up @@ -100,7 +100,7 @@ def __init__(self, executor=None, heartrate=None, *args, **kwargs):
self.latest_heartbeat = timezone.utcnow()
if heartrate is not None:
self.heartrate = heartrate
self.unixname = getpass.getuser()
self.unixname = getuser()
self.max_tis_per_query = conf.getint('scheduler', 'max_tis_per_query')
super().__init__(*args, **kwargs)

Expand Down
4 changes: 2 additions & 2 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# specific language governing permissions and limitations
# under the License.
import contextlib
import getpass
import hashlib
import logging
import math
Expand Down Expand Up @@ -68,6 +67,7 @@
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
from airflow.utils.operator_helpers import context_to_airflow_vars
from airflow.utils.platform import getuser
from airflow.utils.session import provide_session
from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks
from airflow.utils.state import State
Expand Down Expand Up @@ -327,7 +327,7 @@ def __init__(self, task, execution_date: datetime, state: Optional[str] = None):
self.execution_date = execution_date

self.try_number = 0
self.unixname = getpass.getuser()
self.unixname = getuser()
if state:
self.state = state
self.hostname = ''
Expand Down
8 changes: 6 additions & 2 deletions airflow/providers/microsoft/winrm/hooks/winrm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@
# under the License.
#
"""Hook for winrm remote execution."""
import getpass
from typing import Optional

from winrm.protocol import Protocol

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook

try:
from airflow.utils.platform import getuser
except ImportError:
from getpass import getuser


# TODO: Fixme please - I have too complex implementation
# pylint: disable=too-many-instance-attributes,too-many-arguments,too-many-branches
Expand Down Expand Up @@ -201,7 +205,7 @@ def get_conn(self):
self.remote_host,
self.ssh_conn_id,
)
self.username = getpass.getuser()
self.username = getuser()

# If endpoint is not set, then build a standard wsman endpoint from host and port.
if not self.endpoint:
Expand Down
8 changes: 6 additions & 2 deletions airflow/providers/ssh/hooks/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# specific language governing permissions and limitations
# under the License.
"""Hook for SSH connections."""
import getpass
import os
import warnings
from base64 import decodebytes
Expand All @@ -30,6 +29,11 @@
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook

try:
from airflow.utils.platform import getuser
except ImportError:
from getpass import getuser


class SSHHook(BaseHook): # pylint: disable=too-many-instance-attributes
"""
Expand Down Expand Up @@ -173,7 +177,7 @@ def __init__( # pylint: disable=too-many-statements
self.remote_host,
self.ssh_conn_id,
)
self.username = getpass.getuser()
self.username = getuser()

user_ssh_config_filename = os.path.expanduser('~/.ssh/config')
if os.path.isfile(user_ssh_config_filename):
Expand Down
4 changes: 2 additions & 2 deletions airflow/task/task_runner/base_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# specific language governing permissions and limitations
# under the License.
"""Base task runner"""
import getpass
import os
import subprocess
import threading
Expand All @@ -29,6 +28,7 @@
from airflow.utils.configuration import tmp_configuration_copy
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
from airflow.utils.platform import getuser

PYTHONPATH_VAR = 'PYTHONPATH'

Expand Down Expand Up @@ -60,7 +60,7 @@ def __init__(self, local_task_job):
# Add sudo commands to change user if we need to. Needed to handle SubDagOperator
# case using a SequentialExecutor.
self.log.debug("Planning to run as the %s user", self.run_as_user)
if self.run_as_user and (self.run_as_user != getpass.getuser()):
if self.run_as_user and (self.run_as_user != getuser()):
# We want to include any environment variables now, as we won't
# want to have to specify them in the sudo call - they would show
# up in `ps` that way! And run commands now, as the other user
Expand Down
4 changes: 2 additions & 2 deletions airflow/task/task_runner/cgroup_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
"""Task runner for cgroup to run Airflow task"""

import datetime
import getpass
import os
import uuid

Expand All @@ -28,6 +27,7 @@

from airflow.task.task_runner.base_task_runner import BaseTaskRunner
from airflow.utils.operator_resources import Resources
from airflow.utils.platform import getuser
from airflow.utils.process_utils import reap_process_group


Expand Down Expand Up @@ -70,7 +70,7 @@ def __init__(self, local_task_job):
self.cpu_cgroup_name = None
self._created_cpu_cgroup = False
self._created_mem_cgroup = False
self._cur_user = getpass.getuser()
self._cur_user = getuser()

def _create_cgroup(self, path):
"""
Expand Down
5 changes: 2 additions & 3 deletions airflow/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#
"""Utilities module for cli"""
import functools
import getpass
import json
import logging
import os
Expand All @@ -35,7 +34,7 @@
from airflow import settings
from airflow.exceptions import AirflowException
from airflow.utils import cli_action_loggers
from airflow.utils.platform import is_terminal_support_colors
from airflow.utils.platform import getuser, is_terminal_support_colors
from airflow.utils.session import provide_session

T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
Expand Down Expand Up @@ -131,7 +130,7 @@ def _build_metrics(func_name, namespace):
'sub_command': func_name,
'start_datetime': datetime.utcnow(),
'full_command': f'{full_command}',
'user': getpass.getuser(),
'user': getuser(),
}

if not isinstance(namespace, Namespace):
Expand Down
23 changes: 23 additions & 0 deletions airflow/utils/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

"""Platform and system specific function."""
import getpass
import logging
import os
import pkgutil
Expand Down Expand Up @@ -57,3 +58,25 @@ def get_airflow_git_version():
log.debug(e)

return git_version


def getuser() -> str:
    """
    Return the username of the user the current process is running as.

    This is a thin wrapper around ``getpass.getuser()`` that converts a failed
    lookup into an ``AirflowConfigException`` carrying an actionable message.

    We deliberately don't fall back to os.getuid() because not having a
    username probably means the rest of the user environment is wrong
    (e.g. no $HOME). Explicit failure is better than silently trying to
    work badly.

    :return: the current username, as reported by ``getpass.getuser()``
    :raises AirflowConfigException: if no username can be determined for the
        current user
    """
    try:
        return getpass.getuser()
    except (KeyError, OSError) as err:
        # KeyError is what the lookup raises on older Pythons when the current
        # UID has no username; Python 3.13+ reports the same condition as
        # OSError, so handle both to keep the explicit error on every version.

        # Inner import to avoid circular import
        from airflow.exceptions import AirflowConfigException

        # Each fragment ends with a space: adjacent string literals are
        # concatenated verbatim, so a missing one fuses two words together.
        raise AirflowConfigException(
            "The user that Airflow is running as has no username; you must run "
            "Airflow as a full user, with a username and home directory, "
            "in order for it to function properly."
        ) from err
8 changes: 4 additions & 4 deletions tests/api_connexion/endpoints/test_task_instance_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
# specific language governing permissions and limitations
# under the License.
import datetime as dt
import getpass
from unittest import mock

import pytest
from parameterized import parameterized

from airflow.models import DagBag, DagRun, SlaMiss, TaskInstance
from airflow.security import permissions
from airflow.utils.platform import getuser
from airflow.utils.session import provide_session
from airflow.utils.state import State
from airflow.utils.timezone import datetime
Expand Down Expand Up @@ -160,7 +160,7 @@ def test_should_respond_200(self, session):
"state": "running",
"task_id": "print_the_context",
"try_number": 0,
"unixname": getpass.getuser(),
"unixname": getuser(),
}

def test_should_respond_200_with_task_state_in_removed(self, session):
Expand Down Expand Up @@ -190,7 +190,7 @@ def test_should_respond_200_with_task_state_in_removed(self, session):
"state": "removed",
"task_id": "print_the_context",
"try_number": 0,
"unixname": getpass.getuser(),
"unixname": getuser(),
}

def test_should_respond_200_task_instance_with_sla(self, session):
Expand Down Expand Up @@ -238,7 +238,7 @@ def test_should_respond_200_task_instance_with_sla(self, session):
"state": "running",
"task_id": "print_the_context",
"try_number": 0,
"unixname": getpass.getuser(),
"unixname": getuser(),
}

def test_should_raises_401_unauthenticated(self):
Expand Down
6 changes: 3 additions & 3 deletions tests/api_connexion/schemas/test_task_instance_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# under the License.

import datetime as dt
import getpass
import unittest

import pytest
Expand All @@ -30,6 +29,7 @@
)
from airflow.models import DAG, SlaMiss, TaskInstance as TI
from airflow.operators.dummy import DummyOperator
from airflow.utils.platform import getuser
from airflow.utils.session import create_session, provide_session
from airflow.utils.state import State
from airflow.utils.timezone import datetime
Expand Down Expand Up @@ -87,7 +87,7 @@ def test_task_instance_schema_without_sla(self, session):
"state": "running",
"task_id": "TEST_TASK_ID",
"try_number": 0,
"unixname": getpass.getuser(),
"unixname": getuser(),
}
assert serialized_ti == expected_json

Expand Down Expand Up @@ -133,7 +133,7 @@ def test_task_instance_schema_with_sla(self, session):
"state": "running",
"task_id": "TEST_TASK_ID",
"try_number": 0,
"unixname": getpass.getuser(),
"unixname": getuser(),
}
assert serialized_ti == expected_json

Expand Down
2 changes: 1 addition & 1 deletion tests/jobs/test_base_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def test_heartbeat_failed(self, mock_create_session):
@conf_vars({('scheduler', 'max_tis_per_query'): '100'})
@patch('airflow.jobs.base_job.ExecutorLoader.get_default_executor')
@patch('airflow.jobs.base_job.get_hostname')
@patch('airflow.jobs.base_job.getpass.getuser')
@patch('airflow.jobs.base_job.getuser')
def test_essential_attr(self, mock_getuser, mock_hostname, mock_default_executor):
mock_sequential_executor = SequentialExecutor()
mock_hostname.return_value = "test_hostname"
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/microsoft/winrm/hooks/test_winrm.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def test_get_conn_from_connection(self, mock_get_connection, mock_protocol):
send_cbt=str(connection.extra_dejson['send_cbt']).lower() == 'true',
)

@patch('airflow.providers.microsoft.winrm.hooks.winrm.getpass.getuser', return_value='user')
@patch('airflow.providers.microsoft.winrm.hooks.winrm.getuser', return_value='user')
@patch('airflow.providers.microsoft.winrm.hooks.winrm.Protocol')
def test_get_conn_no_username(self, mock_protocol, mock_getuser):
winrm_hook = WinRMHook(remote_host='host', password='password')
Expand Down
6 changes: 3 additions & 3 deletions tests/task/task_runner/test_standard_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import getpass
import logging
import os
import time
Expand All @@ -30,6 +29,7 @@
from airflow.models import TaskInstance as TI
from airflow.task.task_runner.standard_task_runner import StandardTaskRunner
from airflow.utils import timezone
from airflow.utils.platform import getuser
from airflow.utils.state import State
from tests.test_utils.db import clear_db_runs

Expand Down Expand Up @@ -105,7 +105,7 @@ def test_start_and_terminate(self):
def test_start_and_terminate_run_as_user(self):
local_task_job = mock.Mock()
local_task_job.task_instance = mock.MagicMock()
local_task_job.task_instance.run_as_user = getpass.getuser()
local_task_job.task_instance.run_as_user = getuser()
local_task_job.task_instance.command_as_list.return_value = [
'airflow',
'tasks',
Expand Down Expand Up @@ -142,7 +142,7 @@ def test_early_reap_exit(self, caplog):
# Set up mock task
local_task_job = mock.Mock()
local_task_job.task_instance = mock.MagicMock()
local_task_job.task_instance.run_as_user = getpass.getuser()
local_task_job.task_instance.run_as_user = getuser()
local_task_job.task_instance.command_as_list.return_value = [
'airflow',
'tasks',
Expand Down

0 comments on commit e35958b

Please sign in to comment.