fix: D401 lint issues in airflow core (apache#37274)
rawwar authored Feb 9, 2024
1 parent 7835fd2 commit ab9e2e1
Showing 17 changed files with 42 additions and 56 deletions.
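
For context: D401 is the pydocstyle rule, enforced here via ruff, that requires a docstring's first line to be written in the imperative mood ("Return ..." rather than "Returns ..."). Every hunk below applies that one transformation. A minimal before/after sketch, using a hypothetical function rather than one from this diff:

# Before: D401 flags the third-person summary line.
def get_dag_run(run_id: str):
    """Returns the DAG run for the given run id."""

# After: an imperative summary line satisfies D401.
def get_dag_run(run_id: str):
    """Return the DAG run for the given run id."""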
4 changes: 2 additions & 2 deletions airflow/auth/managers/utils/fab.py
@@ -40,12 +40,12 @@


def get_fab_action_from_method_map():
"""Returns the map associating a method to a FAB action."""
"""Return the map associating a method to a FAB action."""
return _MAP_METHOD_NAME_TO_FAB_ACTION_NAME


def get_method_from_fab_action_map():
"""Returns the map associating a FAB action to a method."""
"""Return the map associating a FAB action to a method."""
return {
**{v: k for k, v in _MAP_METHOD_NAME_TO_FAB_ACTION_NAME.items()},
ACTION_CAN_ACCESS_MENU: "GET",
2 changes: 1 addition & 1 deletion airflow/decorators/bash.py
@@ -84,7 +84,7 @@ def bash_task(
python_callable: Callable | None = None,
**kwargs,
) -> TaskDecorator:
"""Wraps a function into a BashOperator.
"""Wrap a function into a BashOperator.
Accepts kwargs for operator kwargs. Can be reused in a single DAG. This function is only used
during type checking or auto-completion.
2 changes: 1 addition & 1 deletion airflow/executors/debug_executor.py
@@ -61,7 +61,7 @@ def __init__(self):
self.fail_fast = conf.getboolean("debug", "fail_fast")

def execute_async(self, *args, **kwargs) -> None:
"""The method is replaced by custom trigger_task implementation."""
"""Replace the method with a custom trigger_task implementation."""

def sync(self) -> None:
task_succeeded = True
2 changes: 1 addition & 1 deletion airflow/models/baseoperator.py
@@ -1602,7 +1602,7 @@ def defer(
raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout)

def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, context: Context):
"""This method is called when a deferred task is resumed."""
"""Call this method when a deferred task is resumed."""
# __fail__ is a special signal value for next_method that indicates
# this task was scheduled specifically to fail.
if next_method == "__fail__":
4 changes: 2 additions & 2 deletions airflow/models/dagrun.py
@@ -568,7 +568,7 @@ def get_task_instances(
session: Session = NEW_SESSION,
) -> list[TI]:
"""
-Returns the task instances for this dag run.
+Return the task instances for this dag run.
Redirect to DagRun.fetch_task_instances method.
Keep this method because it is widely used across the code.
@@ -611,7 +611,7 @@ def fetch_task_instance(
map_index: int = -1,
) -> TI | TaskInstancePydantic | None:
"""
-Returns the task instance specified by task_id for this dag run.
+Return the task instance specified by task_id for this dag run.
:param dag_id: the DAG id
:param dag_run_id: the DAG run id
12 changes: 6 additions & 6 deletions airflow/models/taskinstance.py
@@ -461,7 +461,7 @@ def _refresh_from_db(
*, task_instance: TaskInstance | TaskInstancePydantic, session: Session, lock_for_update: bool = False
) -> None:
"""
-Refreshes the task instance from the database based on the primary key.
+Refresh the task instance from the database based on the primary key.
:param task_instance: the task instance
:param session: SQLAlchemy ORM Session
@@ -531,7 +531,7 @@ def _set_duration(*, task_instance: TaskInstance | TaskInstancePydantic) -> None

def _stats_tags(*, task_instance: TaskInstance | TaskInstancePydantic) -> dict[str, str]:
"""
-Returns task instance tags.
+Return task instance tags.
:param task_instance: the task instance
@@ -943,7 +943,7 @@ def _get_previous_dagrun(
session: Session | None = None,
) -> DagRun | None:
"""
-The DagRun that ran before this task instance's DagRun.
+Return the DagRun that ran prior to this task instance's DagRun.
:param task_instance: the task instance
:param state: If passed, it only take into account instances of a specific state.
@@ -983,7 +983,7 @@ def _get_previous_execution_date(
session: Session,
) -> pendulum.DateTime | None:
"""
-The execution date from property previous_ti_success.
+Get execution date from property previous_ti_success.
:param task_instance: the task instance
:param session: SQLAlchemy ORM Session
@@ -1178,7 +1178,7 @@ def _get_previous_ti(
state: DagRunState | None = None,
) -> TaskInstance | TaskInstancePydantic | None:
"""
-The task instance for the task that ran before this task instance.
+Get task instance for the task that ran before this task instance.
:param task_instance: the task instance
:param state: If passed, it only take into account instances of a specific state.
@@ -1436,7 +1436,7 @@ def try_number(self):
@try_number.expression
def try_number(cls):
"""
-This is what will be used by SQLAlchemy when filtering on try_number.
+Return the expression to be used by SQLAlchemy when filtering on try_number.
This is required because the override in the get_try_number function causes
try_number values to be off by one when listing tasks in the UI.
4 changes: 2 additions & 2 deletions airflow/operators/python.py
@@ -640,7 +640,7 @@ def _prepare_venv(self, venv_path: Path) -> None:
)

def _calculate_cache_hash(self) -> tuple[str, str]:
"""Helper to generate the hash of the cache folder to use.
"""Generate the hash of the cache folder to use.
The following factors are used as input for the hash:
- (sorted) list of requirements
@@ -666,7 +666,7 @@ def _calculate_cache_hash(self) -> tuple[str, str]:
return requirements_hash[:8], hash_text

def _ensure_venv_cache_exists(self, venv_cache_path: Path) -> Path:
"""Helper to ensure a valid virtual environment is set up and will create inplace."""
"""Ensure a valid virtual environment is set up and will create inplace."""
cache_hash, hash_data = self._calculate_cache_hash()
venv_path = venv_cache_path / f"venv-{cache_hash}"
self.log.info("Python virtual environment will be cached in %s", venv_path)
6 changes: 3 additions & 3 deletions airflow/plugins_manager.py
@@ -171,14 +171,14 @@ class AirflowPlugin:

@classmethod
def validate(cls):
"""Validates that plugin has a name."""
"""Validate if plugin has a name."""
if not cls.name:
raise AirflowPluginException("Your plugin needs a name.")

@classmethod
def on_load(cls, *args, **kwargs):
"""
-Executed when the plugin is loaded; This method is only called once during runtime.
+Execute when the plugin is loaded; This method is only called once during runtime.
:param args: If future arguments are passed in on call.
:param kwargs: If future arguments are passed in on call.
@@ -296,7 +296,7 @@ def load_providers_plugins():


def make_module(name: str, objects: list[Any]):
"""Creates new module."""
"""Create new module."""
if not objects:
return None
log.debug("Creating module %s", name)
26 changes: 13 additions & 13 deletions airflow/providers_manager.py
@@ -146,7 +146,7 @@ def _read_schema_from_resources_or_local_file(filename: str) -> dict:


def _create_provider_info_schema_validator():
"""Creates JSON schema validator from the provider_info.schema.json."""
"""Create JSON schema validator from the provider_info.schema.json."""
import jsonschema

schema = _read_schema_from_resources_or_local_file("provider_info.schema.json")
@@ -156,7 +156,7 @@ def _create_provider_info_schema_validator():


def _create_customized_form_field_behaviours_schema_validator():
"""Creates JSON schema validator from the customized_form_field_behaviours.schema.json."""
"""Create JSON schema validator from the customized_form_field_behaviours.schema.json."""
import jsonschema

schema = _read_schema_from_resources_or_local_file("customized_form_field_behaviours.schema.json")
@@ -305,7 +305,7 @@ def _correctness_check(
provider_package: str, class_name: str, provider_info: ProviderInfo
) -> type[BaseHook] | None:
"""
-Performs coherence check on provider classes.
+Perform coherence check on provider classes.
For apache-airflow providers - it checks if it starts with appropriate package. For all providers
it tries to import the provider - checking that there are no exceptions during importing.
@@ -408,7 +408,7 @@ def initialization_stack_trace() -> str | None:
return ProvidersManager._initialization_stack_trace

def __init__(self):
"""Initializes the manager."""
"""Initialize the manager."""
super().__init__()
ProvidersManager._initialized = True
ProvidersManager._initialization_stack_trace = "".join(traceback.format_stack(inspect.currentframe()))
@@ -445,7 +445,7 @@ def __init__(self):
self._init_airflow_core_hooks()

def _init_airflow_core_hooks(self):
"""Initializes the hooks dict with default hooks from Airflow core."""
"""Initialize the hooks dict with default hooks from Airflow core."""
core_dummy_hooks = {
"generic": "Generic",
"email": "Email",
@@ -563,7 +563,7 @@ def initialize_providers_configuration(self):

def _initialize_providers_configuration(self):
"""
-Internal method to initialize providers configuration information.
+Initialize providers configuration information.
Should be used if we do not want to trigger caching for ``initialize_providers_configuration`` method.
In some cases we might want to make sure that the configuration is initialized, but we do not want
@@ -626,7 +626,7 @@ def _discover_all_providers_from_packages(self) -> None:

def _discover_all_airflow_builtin_providers_from_local_sources(self) -> None:
"""
-Finds all built-in airflow providers if airflow is run from the local sources.
+Find all built-in airflow providers if airflow is run from the local sources.
It finds `provider.yaml` files for all such providers and registers the providers using those.
@@ -654,7 +654,7 @@ def _discover_all_airflow_builtin_providers_from_local_sources(self) -> None:

def _add_provider_info_from_local_source_files_on_path(self, path) -> None:
"""
-Finds all the provider.yaml files in the directory specified.
+Find all the provider.yaml files in the directory specified.
:param path: path where to look for provider.yaml files
"""
@@ -672,7 +672,7 @@ def _add_provider_info_from_local_source_files_on_path(self, path) -> None:

def _add_provider_info_from_local_source_file(self, path, package_name) -> None:
"""
-Parses found provider.yaml file and adds found provider to the dictionary.
+Parse found provider.yaml file and adds found provider to the dictionary.
:param path: full file path of the provider.yaml file
:param package_name: name of the package
@@ -1069,23 +1069,23 @@ def _add_customized_fields(self, package_name: str, hook_class: type, customized
)

def _discover_auth_managers(self) -> None:
"""Retrieves all auth managers defined in the providers."""
"""Retrieve all auth managers defined in the providers."""
for provider_package, provider in self._provider_dict.items():
if provider.data.get("auth-managers"):
for auth_manager_class_name in provider.data["auth-managers"]:
if _correctness_check(provider_package, auth_manager_class_name, provider):
self._auth_manager_class_name_set.add(auth_manager_class_name)

def _discover_notifications(self) -> None:
"""Retrieves all notifications defined in the providers."""
"""Retrieve all notifications defined in the providers."""
for provider_package, provider in self._provider_dict.items():
if provider.data.get("notifications"):
for notification_class_name in provider.data["notifications"]:
if _correctness_check(provider_package, notification_class_name, provider):
self._notification_info_set.add(notification_class_name)

def _discover_extra_links(self) -> None:
"""Retrieves all extra links defined in the providers."""
"""Retrieve all extra links defined in the providers."""
for provider_package, provider in self._provider_dict.items():
if provider.data.get("extra-links"):
for extra_link_class_name in provider.data["extra-links"]:
@@ -1149,7 +1149,7 @@ def _discover_plugins(self) -> None:

@provider_info_cache("triggers")
def initialize_providers_triggers(self):
"""Initialization of providers triggers."""
"""Initialize providers triggers."""
self.initialize_providers_list()
for provider_package, provider in self._provider_dict.items():
for trigger in provider.data.get("triggers", []):
6 changes: 3 additions & 3 deletions airflow/serialization/serde.py
@@ -288,20 +288,20 @@ def _convert(old: dict) -> dict:


def _match(classname: str) -> bool:
"""Checks if the given classname matches a path pattern either using glob format or regexp format."""
"""Check if the given classname matches a path pattern either using glob format or regexp format."""
return _match_glob(classname) or _match_regexp(classname)


@functools.lru_cache(maxsize=None)
def _match_glob(classname: str):
"""Checks if the given classname matches a pattern from allowed_deserialization_classes using glob syntax."""
"""Check if the given classname matches a pattern from allowed_deserialization_classes using glob syntax."""
patterns = _get_patterns()
return any(fnmatch(classname, p.pattern) for p in patterns)


@functools.lru_cache(maxsize=None)
def _match_regexp(classname: str):
"""Checks if the given classname matches a pattern from allowed_deserialization_classes_regexp using regexp."""
"""Check if the given classname matches a pattern from allowed_deserialization_classes_regexp using regexp."""
patterns = _get_regexp_patterns()
return any(p.match(classname) is not None for p in patterns)

2 changes: 1 addition & 1 deletion airflow/utils/file.py
@@ -385,7 +385,7 @@ def iter_airflow_imports(file_path: str) -> Generator[str, None, None]:


def get_unique_dag_module_name(file_path: str) -> str:
"""Returns a unique module name in the format unusual_prefix_{sha1 of module's file path}_{original module name}."""
"""Return a unique module name in the format unusual_prefix_{sha1 of module's file path}_{original module name}."""
if isinstance(file_path, str):
path_hash = hashlib.sha1(file_path.encode("utf-8")).hexdigest()
org_mod_name = Path(file_path).stem
2 changes: 1 addition & 1 deletion airflow/utils/log/task_context_logger.py
@@ -65,7 +65,7 @@ def _should_enable(self) -> bool:

@staticmethod
def _get_task_handler() -> FileTaskHandler | None:
"""Returns the task handler that supports task context logging."""
"""Return the task handler that supports task context logging."""
handlers = [
handler
for handler in logging.getLogger("airflow.task").handlers
2 changes: 2 additions & 0 deletions airflow/utils/sqlalchemy.py
@@ -271,6 +271,8 @@ def process(value):

def compare_values(self, x, y):
"""
+Compare x and y using self.comparator if available. Else, use __eq__.
The TaskInstance.executor_config attribute is a pickled object that may contain kubernetes objects.
If the installed library version has changed since the object was originally pickled,
4 changes: 2 additions & 2 deletions airflow/www/auth.py
@@ -88,7 +88,7 @@ def has_access(permissions: Sequence[tuple[str, str]] | None = None) -> Callable

def has_access_with_pk(f):
"""
-This decorator is used to check permissions on views.
+Check permissions on views.
The implementation is very similar from
https://github.com/dpgaspar/Flask-AppBuilder/blob/c6fecdc551629e15467fde5d06b4437379d90592/flask_appbuilder/security/decorators.py#L134
@@ -345,5 +345,5 @@ def decorated(*args, **kwargs):


def has_access_view(access_view: AccessView = AccessView.WEBSITE) -> Callable[[T], T]:
"""Decorator that checks current user's permissions to access the website."""
"""Check current user's permissions to access the website."""
return _has_access_no_details(lambda: get_auth_manager().is_authorized_view(access_view=access_view))
2 changes: 1 addition & 1 deletion airflow/www/blueprints.py
@@ -24,5 +24,5 @@

@routes.route("/")
def index():
"""Main Airflow page."""
"""Return main Airflow page."""
return redirect(url_for("Airflow.index"))
2 changes: 1 addition & 1 deletion airflow/www/views.py
@@ -3994,7 +3994,7 @@ def delete(self, pk):
@expose("/action_post", methods=["POST"])
def action_post(self):
"""
-Action method to handle multiple records selected from a list view.
+Handle multiple records selected from a list view.
Same implementation as
https://github.com/dpgaspar/Flask-AppBuilder/blob/2c5763371b81cd679d88b9971ba5d1fc4d71d54b/flask_appbuilder/views.py#L677
16 changes: 0 additions & 16 deletions pyproject.toml
@@ -1350,22 +1350,6 @@ combine-as-imports = true
"tests/providers/elasticsearch/log/elasticmock/utilities/__init__.py" = ["E402"]

# All the modules which do not follow D401 yet, please remove as soon as it becomes compatible
"airflow/auth/managers/utils/fab.py" = ["D401"]
"airflow/decorators/bash.py" = ["D401"]
"airflow/executors/debug_executor.py" = ["D401"]
"airflow/models/baseoperator.py" = ["D401"]
"airflow/models/dagrun.py" = ["D401"]
"airflow/models/taskinstance.py" = ["D401"]
"airflow/operators/python.py" = ["D401"]
"airflow/plugins_manager.py" = ["D401"]
"airflow/providers_manager.py" = ["D401"]
"airflow/serialization/serde.py" = ["D401"]
"airflow/utils/log/task_context_logger.py" = ["D401"]
"airflow/utils/sqlalchemy.py" = ["D401"]
"airflow/www/auth.py" = ["D401"]
"airflow/www/blueprints.py" = ["D401"]
"airflow/www/views.py" = ["D401"]
"airflow/utils/file.py" = ["D401"]
"airflow/providers/airbyte/hooks/airbyte.py" = ["D401"]
"airflow/providers/airbyte/operators/airbyte.py" = ["D401"]
"airflow/providers/airbyte/sensors/airbyte.py" = ["D401"]
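
Deleting these sixteen per-file-ignores entries is what actually switches D401 enforcement on for the core modules fixed above; the provider entries that remain keep the rule suppressed until those packages are cleaned up. For illustration only, here is a rough sketch of the kind of check D401 performs — an assumed word-list heuristic, not pydocstyle's or ruff's actual implementation:

import ast
import sys

# Illustrative subset of third-person openers; the real D401 check in
# pydocstyle/ruff is considerably more thorough than this word list.
THIRD_PERSON = {"Returns", "Creates", "Checks", "Retrieves", "Initializes",
                "Finds", "Parses", "Wraps", "Validates", "Executes"}


def check(path: str) -> None:
    """Print a line for each function whose docstring opens in third person."""
    tree = ast.parse(open(path, encoding="utf-8").read())
    for node in ast.walk(tree):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            doc = ast.get_docstring(node)
            if doc and doc.split() and doc.split()[0] in THIRD_PERSON:
                print(f"{path}:{node.lineno} {node.name}: {doc.splitlines()[0]!r}")


if __name__ == "__main__":
    for source_file in sys.argv[1:]:
        check(source_file)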
