From b769edd5cff5db142705a64af4e4210c81261151 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Thu, 27 Jul 2023 00:35:40 +0400 Subject: [PATCH] feat(python): BytecodeParser can now handle mixed/nested `and/or` control flow (#10085) --- py-polars/polars/utils/udfs.py | 273 ++++++++++-------- py-polars/pyproject.toml | 1 + py-polars/tests/test_udfs.py | 34 +++ .../unit/operations/test_inefficient_apply.py | 1 - 4 files changed, 187 insertions(+), 122 deletions(-) diff --git a/py-polars/polars/utils/udfs.py b/py-polars/polars/utils/udfs.py index 8d9090abfb6d..57d6ea4109e9 100644 --- a/py-polars/polars/utils/udfs.py +++ b/py-polars/polars/utils/udfs.py @@ -4,6 +4,7 @@ import dis import sys import warnings +from bisect import bisect_left from collections import defaultdict from dis import get_instructions from inspect import signature @@ -29,72 +30,66 @@ class StackValue(NamedTuple): ApplyTarget: TypeAlias = Literal["expr", "frame", "series"] StackEntry: TypeAlias = Union[str, StackValue] +_MIN_PY311 = sys.version_info >= (3, 11) + + +class OpNames: + BINARY = { + "BINARY_ADD": "+", + "BINARY_AND": "&", + "BINARY_FLOOR_DIVIDE": "//", + "BINARY_MODULO": "%", + "BINARY_MULTIPLY": "*", + "BINARY_OR": "|", + "BINARY_POWER": "**", + "BINARY_SUBTRACT": "-", + "BINARY_TRUE_DIVIDE": "/", + "BINARY_XOR": "^", + } + CALL = "CALL" if _MIN_PY311 else "CALL_*" + CONTROL_FLOW = ( + { + "POP_JUMP_FORWARD_IF_FALSE": "&", + "POP_JUMP_FORWARD_IF_TRUE": "|", + "JUMP_IF_FALSE_OR_POP": "&", + "JUMP_IF_TRUE_OR_POP": "|", + } + if _MIN_PY311 + else { + "POP_JUMP_IF_FALSE": "&", + "POP_JUMP_IF_TRUE": "|", + "JUMP_IF_FALSE_OR_POP": "&", + "JUMP_IF_TRUE_OR_POP": "|", + } + ) + LOAD_VALUES = frozenset(("LOAD_CONST", "LOAD_DEREF", "LOAD_FAST", "LOAD_GLOBAL")) + LOAD = LOAD_VALUES | {"LOAD_METHOD", "LOAD_ATTR"} + SYNTHETIC = { + "POLARS_EXPRESSION": 1, + } + UNARY = { + "UNARY_NEGATIVE": "-", + "UNARY_POSITIVE": "+", + "UNARY_NOT": "~", + } + PARSEABLE_OPS = ( + {"BINARY_OP", "COMPARE_OP", "CONTAINS_OP", "IS_OP"} + | set(UNARY) + | set(CONTROL_FLOW) + | set(SYNTHETIC) + | LOAD_VALUES + ) + UNARY_VALUES = frozenset(UNARY.values()) -# Note: in 3.11 individual binary opcodes were folded into a new BINARY_OP -_BINARY_OPCODES = { - "BINARY_ADD": "+", - "BINARY_AND": "&", - "BINARY_FLOOR_DIVIDE": "//", - "BINARY_MODULO": "%", - "BINARY_MULTIPLY": "*", - "BINARY_OR": "|", - "BINARY_POWER": "**", - "BINARY_SUBTRACT": "-", - "BINARY_TRUE_DIVIDE": "/", - "BINARY_XOR": "^", -} -_CONTROL_FLOW_OPCODES = { - # note: once we add additional JUMP op support, we'll need to disambiguate - # between and/or (what we currently support) and if/else (which we don't) - # "POP_JUMP_FORWARD_IF_FALSE": "?", - # "POP_JUMP_FORWARD_IF_TRUE": "?", - # "POP_JUMP_IF_FALSE": "?", - # "POP_JUMP_IF_TRUE": "?", - # "JUMP_FORWARD": "?", - "JUMP_IF_FALSE_OR_POP": "&", - "JUMP_IF_TRUE_OR_POP": "|", -} -_UNARY_OPCODES = { - "UNARY_NEGATIVE": "-", - "UNARY_POSITIVE": "+", - "UNARY_NOT": "~", -} -_SYNTHETIC_OPS = { - "POLARS_EXPRESSION": 1, -} -_LOAD_OPS = { - "LOAD_CONST", - "LOAD_DEREF", - "LOAD_FAST", - "LOAD_GLOBAL", -} -_SIMPLE_EXPR_OPS = { - "BINARY_OP", - "COMPARE_OP", - "CONTAINS_OP", - "IS_OP", -} -_SIMPLE_EXPR_OPS |= ( - set(_UNARY_OPCODES) | set(_CONTROL_FLOW_OPCODES) | set(_SYNTHETIC_OPS) | _LOAD_OPS -) -_SIMPLE_FRAME_OPS = _SIMPLE_EXPR_OPS | {"BINARY_SUBSCR"} -_UNARY_OPCODE_VALUES = set(_UNARY_OPCODES.values()) -_LOAD_OPS |= {"LOAD_METHOD", "LOAD_ATTR"} - -if sys.version_info < (3, 11): - _UPGRADE_BINARY_OPS = True - _CALL_OP = "CALL_*" -else: - _UPGRADE_BINARY_OPS = False - _CALL_OP = "CALL" # numpy functions that we can map to a native expression -_NUMPY_MODULE_ALIASES = {"np", "numpy"} -_NUMPY_FUNCTIONS = {"cbrt", "cos", "cosh", "sin", "sinh", "sqrt", "tan", "tanh"} - +_NUMPY_MODULE_ALIASES = frozenset(("np", "numpy")) +_NUMPY_FUNCTIONS = frozenset( + ("cbrt", "cos", "cosh", "sin", "sinh", "sqrt", "tan", "tanh") +) # python function that we can map to a native expression _PYTHON_CASTS_MAP = {"float": "Float64", "int": "Int64", "str": "Utf8"} -_PYTHON_METHOD_MAP = { +_PYTHON_METHODS_MAP = { "lower": "str.to_lowercase", "title": "str.to_titlecase", "upper": "str.to_uppercase", @@ -136,6 +131,53 @@ def _get_param_name(function: Callable[[Any], Any]) -> str | None: except ValueError: return None + def _inject_nesting( + self, + expression_blocks: dict[int, str], + logical_instructions: list[Instruction], + ) -> list[tuple[int, str]]: + """Inject nesting boundaries into expression blocks (as parentheses).""" + if logical_instructions: + # reconstruct nesting boundaries for mixed and/or ops by associating + # control flow jump offsets with their target expression blocks and + # injecting appropriate parentheses + combined_offset_idxs = set() + if len({inst.opname for inst in logical_instructions}) > 1: + block_offsets: list[int] = list(expression_blocks.keys()) + previous_logical_opname = "" + for i, inst in enumerate(logical_instructions): + # operator precedence means that we can combine logically connected + # 'and' blocks into one (depending on follow-on logic) and should + # parenthesise nested 'or' blocks + logical_op = OpNames.CONTROL_FLOW[inst.opname] + start = block_offsets[bisect_left(block_offsets, inst.offset) - 1] + if previous_logical_opname == ( + "POP_JUMP_FORWARD_IF_FALSE" + if _MIN_PY311 + else "POP_JUMP_IF_FALSE" + ): + # combine logical '&' blocks (and update start/block_offsets) + prev = block_offsets[bisect_left(block_offsets, start) - 1] + expression_blocks[prev] += f" & {expression_blocks.pop(start)}" + block_offsets = list(expression_blocks.keys()) + combined_offset_idxs.add(i - 1) + start = prev + + if logical_op == "|": + # parenthesise connected 'or' blocks + end = block_offsets[bisect_left(block_offsets, inst.argval) - 1] + if not (start == 0 and end == block_offsets[-1]): + expression_blocks[start] = "(" + expression_blocks[start] + expression_blocks[end] += ")" + + previous_logical_opname = inst.opname + + for i, inst in enumerate(logical_instructions): + if i not in combined_offset_idxs: + expression_blocks[inst.offset] = OpNames.CONTROL_FLOW[inst.opname] + + return sorted(expression_blocks.items()) + @property def apply_target(self) -> ApplyTarget: """The apply target, eg: one of 'expr', 'frame', or 'series'.""" @@ -154,21 +196,12 @@ def can_rewrite(self) -> bool: else: self._can_rewrite[self._apply_target] = False if self._rewritten_instructions and self._param_name is not None: - simple_ops = ( - _SIMPLE_FRAME_OPS - if self._apply_target == "frame" - else _SIMPLE_EXPR_OPS + self._can_rewrite[self._apply_target] = len( + self._rewritten_instructions + ) >= 2 and all( + inst.opname in OpNames.PARSEABLE_OPS + for inst in self._rewritten_instructions ) - if len(self._rewritten_instructions) >= 2 and all( - inst.opname in simple_ops for inst in self._rewritten_instructions - ): - # can (currently) only handle logical 'and'/'or' if they not mixed - logical_ops = { - _CONTROL_FLOW_OPCODES[inst.opname] - for inst in self._rewritten_instructions - if inst.opname in _CONTROL_FLOW_OPCODES - } - self._can_rewrite[self._apply_target] = len(logical_ops) <= 1 return self._can_rewrite[self._apply_target] @@ -196,50 +229,45 @@ def rewritten_instructions(self) -> list[Instruction]: """The rewritten bytecode instructions from the function we are parsing.""" return list(self._rewritten_instructions) - def to_expression(self, col: str) -> str | None: - """Translate postfix bytecode instructions to polars expression string.""" + def to_expression(self, col: str, as_repr: bool = True) -> str | None: + """Translate postfix bytecode instructions to polars expression/string.""" if not self.can_rewrite() or self._param_name is None: return None # decompose bytecode into logical 'and'/'or' expression blocks (if present) - logical_instruction_blocks = defaultdict(list) + control_flow_blocks = defaultdict(list) logical_instructions = [] jump_offset = 0 for idx, inst in enumerate(self._rewritten_instructions): - if inst.opname in _CONTROL_FLOW_OPCODES: + if inst.opname in OpNames.CONTROL_FLOW: jump_offset = self._rewritten_instructions[idx + 1].offset logical_instructions.append(inst) else: - logical_instruction_blocks[jump_offset].append(inst) - - # convert each logical block to a polars expression string - expression_strings = { - offset: InstructionTranslator( - instructions=ops, - apply_target=self._apply_target, - ).to_expression( - col=col, - param_name=self._param_name, - depth=int(bool(logical_instructions)), - ) - for offset, ops in logical_instruction_blocks.items() - } - - for inst in logical_instructions: - expression_strings[inst.offset] = InstructionTranslator.op(inst) - - # TODO: handle mixed 'and'/'or' blocks (e.g. `x > 0 AND (x > 0 OR (x == -1)`). - # (need to reconstruct the correct nesting boundaries to do this properly) - - exprs = sorted(expression_strings.items()) - polars_expr = " ".join(expr for _, expr in exprs) + control_flow_blocks[jump_offset].append(inst) + + # convert each block to a polars expression string + expression_blocks = self._inject_nesting( + { + offset: InstructionTranslator( + instructions=ops, + apply_target=self._apply_target, + ).to_expression( + col=col, + param_name=self._param_name, + depth=int(bool(logical_instructions)), + ) + for offset, ops in control_flow_blocks.items() + }, + logical_instructions, + ) + polars_expr = " ".join(expr for _offset, expr in expression_blocks) # note: if no 'pl.col' in the expression, it likely represents a compound # constant value (e.g. `lambda x: CONST + 123`), so we don't want to warn if "pl.col(" not in polars_expr: return None - return polars_expr + return polars_expr if as_repr else eval(polars_expr, globals()) def warn( self, @@ -300,16 +328,16 @@ def to_expression(self, col: str, param_name: str, depth: int) -> str: @classmethod def op(cls, inst: Instruction) -> str: """Convert bytecode instruction to suitable intermediate op string.""" - if inst.opname in _CONTROL_FLOW_OPCODES: - return _CONTROL_FLOW_OPCODES[inst.opname] + if inst.opname in OpNames.CONTROL_FLOW: + return OpNames.CONTROL_FLOW[inst.opname] elif inst.argrepr: return inst.argrepr elif inst.opname == "IS_OP": return "is not" if inst.argval else "is" elif inst.opname == "CONTAINS_OP": return "not in" if inst.argval else "in" - elif inst.opname in _UNARY_OPCODES: - return _UNARY_OPCODES[inst.opname] + elif inst.opname in OpNames.UNARY: + return OpNames.UNARY[inst.opname] else: raise AssertionError( "Unrecognised opname; please report a bug to https://github.com/pola-rs/polars/issues " @@ -324,7 +352,7 @@ def _expr(cls, value: StackEntry, col: str, param_name: str, depth: int) -> str: op = value.operator e1 = cls._expr(value.left_operand, col, param_name, depth + 1) if value.operator_arity == 1: - if op not in _UNARY_OPCODE_VALUES: + if op not in OpNames.UNARY_VALUES: call = "" if op.endswith(")") else "()" return f"{e1}.{op}{call}" return f"{op}{e1}" @@ -358,7 +386,7 @@ def _to_intermediate_stack( for inst in instructions: stack.append( inst.argrepr - if inst.opname in _LOAD_OPS + if inst.opname in OpNames.LOAD else ( StackValue( operator=self.op(inst), @@ -367,8 +395,8 @@ def _to_intermediate_stack( right_operand=None, # type: ignore[arg-type] ) if ( - inst.opname in _UNARY_OPCODES - or _SYNTHETIC_OPS.get(inst.opname) == 1 + inst.opname in OpNames.UNARY + or OpNames.SYNTHETIC.get(inst.opname) == 1 ) else StackValue( operator=self.op(inst), @@ -418,7 +446,7 @@ def _matches( idx: int, *, opnames: list[str], - argvals: list[set[Any] | dict[Any, Any]] | None, + argvals: list[set[Any] | frozenset[Any] | dict[Any, Any]] | None, ) -> list[Instruction]: """ Check if a sequence of Instructions matches the specified ops/argvals. @@ -461,7 +489,7 @@ def _rewrite(self, instructions: Iterator[Instruction]) -> list[Instruction]: idx = 0 while idx < len(self._instructions): inst, increment = self._instructions[idx], 1 - if inst.opname not in _LOAD_OPS or not any( + if inst.opname not in OpNames.LOAD or not any( (increment := apply_rewrite(idx, updated_instructions)) for apply_rewrite in ( # add any other rewrite methods here @@ -480,7 +508,7 @@ def _rewrite_builtins( """Replace builtin function calls with a synthetic POLARS_EXPRESSION op.""" if matching_instructions := self._matches( idx, - opnames=["LOAD_GLOBAL", "LOAD_FAST", _CALL_OP], + opnames=["LOAD_GLOBAL", "LOAD_FAST", OpNames.CALL], argvals=[_PYTHON_CASTS_MAP], ): inst1, inst2 = matching_instructions[:2] @@ -503,8 +531,11 @@ def _rewrite_functions( """Replace numpy/json function calls with a synthetic POLARS_EXPRESSION op.""" if matching_instructions := self._matches( idx, - opnames=["LOAD_GLOBAL", "LOAD_*", "LOAD_*", _CALL_OP], - argvals=[_NUMPY_MODULE_ALIASES | {"json"}, _NUMPY_FUNCTIONS | {"loads"}], + opnames=["LOAD_GLOBAL", "LOAD_*", "LOAD_*", OpNames.CALL], + argvals=[ + _NUMPY_MODULE_ALIASES | {"json"}, + _NUMPY_FUNCTIONS | {"loads"}, + ], ): inst1, inst2, inst3 = matching_instructions[:3] expr_name = "str.json_extract" if inst1.argval == "json" else inst2.argval @@ -526,11 +557,11 @@ def _rewrite_methods( """Replace python method calls with synthetic POLARS_EXPRESSION op.""" if matching_instructions := self._matches( idx, - opnames=["LOAD_METHOD", _CALL_OP], - argvals=[_PYTHON_METHOD_MAP], + opnames=["LOAD_METHOD", OpNames.CALL], + argvals=[_PYTHON_METHODS_MAP], ): inst = matching_instructions[0] - expr_name = _PYTHON_METHOD_MAP[inst.argval] + expr_name = _PYTHON_METHODS_MAP[inst.argval] synthetic_call = inst._replace( opname="POLARS_EXPRESSION", argval=expr_name, argrepr=expr_name ) @@ -541,9 +572,9 @@ def _rewrite_methods( @staticmethod def _upgrade_instruction(inst: Instruction) -> Instruction: """Rewrite any older binary opcodes using py 3.11 'BINARY_OP' instead.""" - if _UPGRADE_BINARY_OPS and inst.opname in _BINARY_OPCODES: + if not _MIN_PY311 and inst.opname in OpNames.BINARY: inst = inst._replace( - argrepr=_BINARY_OPCODES[inst.opname], + argrepr=OpNames.BINARY[inst.opname], opname="BINARY_OP", ) return inst diff --git a/py-polars/pyproject.toml b/py-polars/pyproject.toml index 30624fca245f..11bb61416009 100644 --- a/py-polars/pyproject.toml +++ b/py-polars/pyproject.toml @@ -162,6 +162,7 @@ strict = true "polars/datatypes.py" = ["B019"] "tests/**/*.py" = ["D100", "D103", "B018"] "polars/utils/show_versions.py" = ["D301"] +"polars/utils/udfs.py" = ["RUF012"] [tool.pytest.ini_options] addopts = [ diff --git a/py-polars/tests/test_udfs.py b/py-polars/tests/test_udfs.py index 7053a49371a1..b223c66286a4 100644 --- a/py-polars/tests/test_udfs.py +++ b/py-polars/tests/test_udfs.py @@ -23,7 +23,9 @@ # column_name, function, expected_suggestion TEST_CASES = [ + # --------------------------------------------- # numeric expr: math, comparison, logic ops + # --------------------------------------------- ("a", lambda x: x + 1 - (2 / 3), '(pl.col("a") + 1) - 0.6666666666666666'), ("a", lambda x: x // 1 % 2, '(pl.col("a") // 1) % 2'), ("a", lambda x: x & True, 'pl.col("a") & True'), @@ -68,14 +70,46 @@ lambda x: (float(x) * int(x)) // 2, '(pl.col("a").cast(pl.Float64) * pl.col("a").cast(pl.Int64)) // 2', ), + # --------------------------------------------- + # logical 'and/or' (validate nesting levels) + # --------------------------------------------- + ( + "a", + lambda x: x > 1 or (x == 1 and x == 2), + '(pl.col("a") > 1) | (pl.col("a") == 1) & (pl.col("a") == 2)', + ), + ( + "a", + lambda x: (x > 1 or x == 1) and x == 2, + '((pl.col("a") > 1) | (pl.col("a") == 1)) & (pl.col("a") == 2)', + ), + ( + "a", + lambda x: x > 2 or x != 3 and x not in (0, 1, 4), + '(pl.col("a") > 2) | (pl.col("a") != 3) & ~pl.col("a").is_in((0, 1, 4))', + ), + ( + "a", + lambda x: x > 1 and x != 2 or x % 2 == 0 and x < 3, + '(pl.col("a") > 1) & (pl.col("a") != 2) | ((pl.col("a") % 2) == 0) & (pl.col("a") < 3)', + ), + ( + "a", + lambda x: x > 1 and (x != 2 or x % 2 == 0) and x < 3, + '(pl.col("a") > 1) & ((pl.col("a") != 2) | ((pl.col("a") % 2) == 0)) & (pl.col("a") < 3)', + ), + # --------------------------------------------- # string expr: case/cast ops + # --------------------------------------------- ("b", lambda x: str(x).title(), 'pl.col("b").cast(pl.Utf8).str.to_titlecase()'), ( "b", lambda x: x.lower() + ":" + x.upper() + ":" + x.title(), '(((pl.col("b").str.to_lowercase() + \':\') + pl.col("b").str.to_uppercase()) + \':\') + pl.col("b").str.to_titlecase()', ), + # --------------------------------------------- # json expr: load/extract + # --------------------------------------------- ("c", lambda x: json.loads(x), 'pl.col("c").str.json_extract()'), ] diff --git a/py-polars/tests/unit/operations/test_inefficient_apply.py b/py-polars/tests/unit/operations/test_inefficient_apply.py index 833efdd603b7..077afd50ca92 100644 --- a/py-polars/tests/unit/operations/test_inefficient_apply.py +++ b/py-polars/tests/unit/operations/test_inefficient_apply.py @@ -19,7 +19,6 @@ lambda x: x, lambda x, y: x + y, lambda x: x[0] + 1, - lambda x: x > 0 and (x < 100 or (x % 2 == 0)), ], ) def test_parse_invalid_function(func: Callable[[Any], Any]) -> None: