Skip to content

Commit

Permalink
refactor XComChecker (#21)
Browse files Browse the repository at this point in the history
- Split out functions and a check method from visit_module
- Cleaned up some of the function code to simplify/improve readability
  • Loading branch information
topherinternational authored Dec 22, 2023
1 parent 46abc00 commit 33e8730
Show file tree
Hide file tree
Showing 2 changed files with 163 additions and 57 deletions.
162 changes: 106 additions & 56 deletions src/pylint_airflow/checkers/xcom.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
"""Checks on Airflow XComs."""
"""Checks on Airflow XComs.
This module contains the XComChecker class and a collection of functions.
XComChecker contains only:
- Methods interfacing with the pylint checker API (i.e. `visit_<nodetype>()` methods)
- Methods that add pylint messages for rules violations (`check_<message>()`)
The module-level functions perform any work that isn't a pylint checker method or adding a message.
"""
from dataclasses import dataclass
from typing import Set, Dict, Tuple

import astroid
from pylint import checkers
Expand All @@ -13,9 +24,83 @@
"Return values from a python_callable function or execute() method are "
"automatically pushed as XCom.",
)
# TODO: add a check for pulling XComs that were never pushed
# TODO: make unused-xcom check sensitive to multiple pushes in the same callable
}


@dataclass
class PythonOperatorSpec:
    """Pairs an operator constructor call with the callable name it was given.

    Holds the call (constructor) node of a PythonOperator construction together with the
    name of the function that was passed as its python_callable argument.
    """

    # Call node of the operator construction.
    python_operator_call_node: astroid.Call
    # Name of the function passed as the python_callable keyword argument.
    python_callable_function_name: str


def get_task_ids_to_python_callable_specs(node: astroid.Module) -> Dict[str, PythonOperatorSpec]:
    """Map task IDs to the calls under `node` that pass a `python_callable` keyword argument.

    Every Assign node yielded by `node.nodes_of_class()` whose value is a call with keyword
    arguments is inspected. A call is recorded when its `python_callable` keyword is a plain
    name. The entry is keyed by the constant passed as `task_id`; when no constant `task_id`
    keyword is present the key is the empty string, so such calls share (and overwrite) one entry.

    Returns a dictionary of the form:
    {task_id: PythonOperatorSpec(call node, python_callable function name)}
    """
    specs_by_task_id = {}

    for assignment in node.nodes_of_class(astroid.Assign):
        call = assignment.value
        if not isinstance(call, astroid.Call) or not call.keywords:
            continue

        found_task_id = ""
        found_callable_name = ""
        for kwarg in call.keywords:
            if kwarg.arg == "python_callable" and isinstance(kwarg.value, astroid.Name):
                found_callable_name = kwarg.value.name
            elif kwarg.arg == "task_id" and isinstance(kwarg.value, astroid.Const):
                found_task_id = kwarg.value.value  # TODO: support non-Const args

        # Calls without a usable python_callable name are of no interest to the XCom checks.
        if found_callable_name:
            specs_by_task_id[found_task_id] = PythonOperatorSpec(call, found_callable_name)

    return specs_by_task_id


def get_xcoms_from_tasks(
    node: astroid.Module, task_ids_to_python_callable_specs: Dict[str, PythonOperatorSpec]
) -> Tuple[Dict[str, PythonOperatorSpec], Set[str]]:
    """Determine which tasks push an XCom via a return value and which task IDs are pulled.

    For every entry in `task_ids_to_python_callable_specs`, the function named as python_callable
    is looked up on the module `node`. Entries are skipped when the name is "<lambda>", when the
    name cannot be resolved on the module, or when it does not resolve to a function definition.

    Returns a tuple of:
    - xcoms_pushed: {task_id: PythonOperatorSpec} for each callable with a return statement
      directly in its body (returns nested in other statements are not considered)
    - xcoms_pulled_taskids: the constant values passed as the `task_ids` keyword argument to any
      `<something>.xcom_pull(...)` call found inside the inspected callables
    """
    xcoms_pushed = {}
    xcoms_pulled_taskids = set()

    for task_id, python_operator_spec in task_ids_to_python_callable_specs.items():
        callable_func_name = python_operator_spec.python_callable_function_name
        if callable_func_name == "<lambda>":  # TODO support lambdas
            continue

        try:
            callable_func = node.getattr(callable_func_name)[0]
        except astroid.AttributeInferenceError:
            # The name is not defined at module level (e.g. it is local to a function that
            # builds the operator), so there is nothing to inspect; don't crash the checker.
            continue

        if not isinstance(callable_func, astroid.FunctionDef):
            continue  # Not a function defined in this module (e.g. the name was imported)

        # A return statement in the function body means the task pushes an XCom
        if any(isinstance(statement, astroid.Return) for statement in callable_func.body):
            xcoms_pushed[task_id] = python_operator_spec

        # Collect the task IDs of any XComs the function pulls
        for inner_call in callable_func.nodes_of_class(astroid.Call):
            called = inner_call.func
            if not (isinstance(called, astroid.Attribute) and called.attrname == "xcom_pull"):
                continue
            # Guard against a falsy keywords attribute, as done when collecting the specs
            for keyword in inner_call.keywords or ():
                if keyword.arg == "task_ids" and isinstance(keyword.value, astroid.Const):
                    xcoms_pulled_taskids.add(keyword.value.value)

    return xcoms_pushed, xcoms_pulled_taskids


class XComChecker(checkers.BaseChecker):
"""Checks on Airflow XComs."""

Expand All @@ -28,62 +113,27 @@ def visit_module(self, node: astroid.Module):
XComs can be set (pushed) implicitly via return of a python_callable or
execute() of an operator. And explicitly by calling xcom_push().
Currently this only checks unused XComs from return value of a python_callable.
Currently, this only checks unused XComs from return value of a python_callable.
"""
# pylint: disable=too-many-locals,too-many-branches,too-many-nested-blocks
assign_nodes = [n for n in node.body if isinstance(n, astroid.Assign)]
call_nodes = [n.value for n in assign_nodes if isinstance(n.value, astroid.Call)]

# Store nodes containing python_callable arg as:
# {task_id: (call node, python_callable func name)}
python_callable_nodes = {}
for call_node in call_nodes:
if call_node.keywords:
task_id = ""
python_callable = ""
for keyword in call_node.keywords:
if keyword.arg == "python_callable":
python_callable = keyword.value.name
continue
if keyword.arg == "task_id":
task_id = keyword.value.value

if python_callable:
python_callable_nodes[task_id] = (call_node, python_callable)

# Now fetch the functions mentioned by python_callable args
xcoms_pushed = {}
xcoms_pulled_taskids = set()
for (task_id, (python_callable, callable_func_name)) in python_callable_nodes.items():
if callable_func_name != "<lambda>":
# TODO support lambdas
callable_func = node.getattr(callable_func_name)[0]

if isinstance(callable_func, astroid.FunctionDef):
# Callable_func is str not FunctionDef when imported
callable_func = node.getattr(callable_func_name)[0]

# Check if the function returns any values
if any(isinstance(n, astroid.Return) for n in callable_func.body):
# Found a return statement
xcoms_pushed[task_id] = (python_callable, callable_func_name)

# Check if the function pulls any XComs
callable_func_calls = callable_func.nodes_of_class(astroid.Call)
for callable_func_call in callable_func_calls:
if (
isinstance(callable_func_call.func, astroid.Attribute)
and callable_func_call.func.attrname == "xcom_pull"
):
for keyword in callable_func_call.keywords:
if keyword.arg == "task_ids":
xcoms_pulled_taskids.add(keyword.value.value)
python_callable_nodes = get_task_ids_to_python_callable_specs(node)
xcoms_pushed, xcoms_pulled_taskids = get_xcoms_from_tasks(node, python_callable_nodes)

self.check_unused_xcoms(xcoms_pushed, xcoms_pulled_taskids)

def check_unused_xcoms(
    self, xcoms_pushed: Dict[str, PythonOperatorSpec], xcoms_pulled_taskids: Set[str]
):
    """Adds a message for every key in the xcoms_pushed dictionary that is not present in
    xcoms_pulled_taskids. Note that this check does _not_ flag IDs in xcoms_pulled_taskids
    that are not present in the xcoms_pushed dictionary.

    Messages are added in sorted task ID order so that the output is repeatable."""
    # Every task_id left over was pushed as an XCom but never xcom_pull'd.
    unused_task_ids = xcoms_pushed.keys() - xcoms_pulled_taskids

    # key=str keeps the ordering of string IDs unchanged while tolerating a stray non-string
    # constant task_id, which would otherwise make the comparison raise TypeError.
    for unused_task_id in sorted(unused_task_ids, key=str):
        python_operator_spec = xcoms_pushed[unused_task_id]
        self.add_message(
            "unused-xcom",
            node=python_operator_spec.python_operator_call_node,
            args=python_operator_spec.python_callable_function_name,
        )
58 changes: 57 additions & 1 deletion tests/pylint_airflow/checkers/test_xcom.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Tests for the XCom checker."""
# pylint: disable=missing-function-docstring
"""Tests for the XCom checker and its helper functions."""

import astroid
from pylint.testutils import CheckerTestCase, MessageTest

import pylint_airflow
from pylint_airflow.checkers.xcom import PythonOperatorSpec


class TestXComChecker(CheckerTestCase):
Expand Down Expand Up @@ -55,3 +57,57 @@ def _pulltask():
ignore_position=True,
):
self.checker.visit_module(ast)


class TestCheckUnusedXComs(CheckerTestCase):
    """Tests for the XComChecker.check_unused_xcoms() method."""

    CHECKER_CLASS = pylint_airflow.checkers.xcom.XComChecker

    def test_empty_inputs_should_not_message(self):
        """No pushed and no pulled XComs: nothing to report."""
        test_xcoms_pushed = {}
        test_xcoms_pulled_taskids = set()

        with self.assertNoMessages():
            self.checker.check_unused_xcoms(test_xcoms_pushed, test_xcoms_pulled_taskids)

    def test_all_xcoms_used_should_not_message(self):
        """A pushed task ID that is also pulled: nothing to report."""
        test_code = """
from airflow.operators.python_operator import PythonOperator
pushtask = PythonOperator(task_id="pushtask", python_callable=_pushtask)
# further code omitted as not necessary for the test
"""
        module_node = astroid.parse(test_code)
        push_call = module_node.body[1].value

        test_xcoms_pushed = {"pushtask": PythonOperatorSpec(push_call, "_pushtask")}
        test_xcoms_pulled_taskids = {"pushtask"}

        with self.assertNoMessages():
            self.checker.check_unused_xcoms(test_xcoms_pushed, test_xcoms_pulled_taskids)

    def test_xcoms_not_used_should_message(self):
        """Pushed task IDs that are never pulled: one unused-xcom message per task ID."""
        test_code = """
from airflow.operators.python_operator import PythonOperator
pushtask_1 = PythonOperator(task_id="pushtask_1", python_callable=_pushtask_1)
pushtask_2 = PythonOperator(task_id="pushtask_2", python_callable=_pushtask_2)
# further code omitted as not necessary for the test
"""
        module_node = astroid.parse(test_code)
        push_call_1 = module_node.body[1].value
        push_call_2 = module_node.body[2].value

        test_xcoms_pushed = {
            "pushtask_1": PythonOperatorSpec(push_call_1, "_pushtask_1"),
            "pushtask_2": PythonOperatorSpec(push_call_2, "_pushtask_2"),
        }
        test_xcoms_pulled_taskids = set()

        with self.assertAddsMessages(
            MessageTest(msg_id="unused-xcom", node=push_call_1, args="_pushtask_1"),
            MessageTest(msg_id="unused-xcom", node=push_call_2, args="_pushtask_2"),
            ignore_position=True,
        ):
            self.checker.check_unused_xcoms(test_xcoms_pushed, test_xcoms_pulled_taskids)

0 comments on commit 33e8730

Please sign in to comment.