From 482d486c21561ff9b2007a86e7ed8bf80ab270a9 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Tue, 25 Jul 2023 18:25:01 +0400 Subject: [PATCH 1/4] feat(python): BytecodeParser can now handle mixed/nested `and/or` control flow --- py-polars/polars/utils/udfs.py | 84 +++++++++++++------ py-polars/tests/test_udfs.py | 26 ++++++ .../unit/operations/test_inefficient_apply.py | 1 - 3 files changed, 84 insertions(+), 27 deletions(-) diff --git a/py-polars/polars/utils/udfs.py b/py-polars/polars/utils/udfs.py index 8d9090abfb6d..06ec30a8514c 100644 --- a/py-polars/polars/utils/udfs.py +++ b/py-polars/polars/utils/udfs.py @@ -4,6 +4,7 @@ import dis import sys import warnings +from bisect import bisect_left from collections import defaultdict from dis import get_instructions from inspect import signature @@ -46,11 +47,11 @@ class StackValue(NamedTuple): _CONTROL_FLOW_OPCODES = { # note: once we add additional JUMP op support, we'll need to disambiguate # between and/or (what we currently support) and if/else (which we don't) - # "POP_JUMP_FORWARD_IF_FALSE": "?", - # "POP_JUMP_FORWARD_IF_TRUE": "?", - # "POP_JUMP_IF_FALSE": "?", - # "POP_JUMP_IF_TRUE": "?", - # "JUMP_FORWARD": "?", + # "POP_JUMP_IF_FALSE": ..., + # "POP_JUMP_IF_TRUE": ..., + # "JUMP_FORWARD": ..., + "POP_JUMP_FORWARD_IF_FALSE": "&", + "POP_JUMP_FORWARD_IF_TRUE": "|", "JUMP_IF_FALSE_OR_POP": "&", "JUMP_IF_TRUE_OR_POP": "|", } @@ -159,16 +160,11 @@ def can_rewrite(self) -> bool: if self._apply_target == "frame" else _SIMPLE_EXPR_OPS ) - if len(self._rewritten_instructions) >= 2 and all( + self._can_rewrite[self._apply_target] = len( + self._rewritten_instructions + ) >= 2 and all( inst.opname in simple_ops for inst in self._rewritten_instructions - ): - # can (currently) only handle logical 'and'/'or' if they not mixed - logical_ops = { - _CONTROL_FLOW_OPCODES[inst.opname] - for inst in self._rewritten_instructions - if inst.opname in _CONTROL_FLOW_OPCODES - } - self._can_rewrite[self._apply_target] = len(logical_ops) <= 1 + ) return self._can_rewrite[self._apply_target] @@ -196,13 +192,13 @@ def rewritten_instructions(self) -> list[Instruction]: """The rewritten bytecode instructions from the function we are parsing.""" return list(self._rewritten_instructions) - def to_expression(self, col: str) -> str | None: - """Translate postfix bytecode instructions to polars expression string.""" + def to_expression(self, col: str, as_repr: bool = True) -> str | None: + """Translate postfix bytecode instructions to polars expression/string.""" if not self.can_rewrite() or self._param_name is None: return None # decompose bytecode into logical 'and'/'or' expression blocks (if present) - logical_instruction_blocks = defaultdict(list) + control_flow_blocks = defaultdict(list) logical_instructions = [] jump_offset = 0 for idx, inst in enumerate(self._rewritten_instructions): @@ -210,9 +206,9 @@ def to_expression(self, col: str) -> str | None: jump_offset = self._rewritten_instructions[idx + 1].offset logical_instructions.append(inst) else: - logical_instruction_blocks[jump_offset].append(inst) + control_flow_blocks[jump_offset].append(inst) - # convert each logical block to a polars expression string + # convert each block to a polars expression string expression_strings = { offset: InstructionTranslator( instructions=ops, @@ -222,14 +218,50 @@ def to_expression(self, col: str) -> str | None: param_name=self._param_name, depth=int(bool(logical_instructions)), ) - for offset, ops in logical_instruction_blocks.items() + for offset, ops in control_flow_blocks.items() } - for inst in logical_instructions: - expression_strings[inst.offset] = InstructionTranslator.op(inst) - - # TODO: handle mixed 'and'/'or' blocks (e.g. `x > 0 AND (x > 0 OR (x == -1)`). - # (need to reconstruct the correct nesting boundaries to do this properly) + if logical_instructions: + # reconstruct nesting boundaries for mixed and/or ops by associating + # control flow jump offsets with their target expression blocks and + # injecting appropriate parentheses + combined_offset_idxs = set() + if len({inst.opname for inst in logical_instructions}) > 1: + block_offsets: list[int] = list(expression_strings.keys()) + idx_min, idx_max = 0, len(block_offsets) - 1 + previous_logical_opname = "" + for i, inst in enumerate(logical_instructions): + # operator precedence means that we can combine logically connected + # 'and' blocks into one (depending on follow-on logic) and should + # parenthesise nested 'or' blocks + logical_op = _CONTROL_FLOW_OPCODES[inst.opname] + start = block_offsets[ + max(idx_min, bisect_left(block_offsets, inst.offset) - 1) + ] + if previous_logical_opname == "POP_JUMP_FORWARD_IF_FALSE": + prev = block_offsets[ + max(idx_min, bisect_left(block_offsets, start) - 1) + ] + expression_strings[ + prev + ] += f" & {expression_strings.pop(start)}" + block_offsets = list(expression_strings.keys()) + combined_offset_idxs.add(i - 1) + start = prev + + if logical_op == "|": + end = block_offsets[ + min(idx_max, bisect_left(block_offsets, inst.argval) - 1) + ] + if not (start == 0 and end == block_offsets[-1]): + expression_strings[start] = "(" + expression_strings[start] + expression_strings[end] += ")" + + previous_logical_opname = inst.opname + + for i, inst in enumerate(logical_instructions): + if i not in combined_offset_idxs: + expression_strings[inst.offset] = _CONTROL_FLOW_OPCODES[inst.opname] exprs = sorted(expression_strings.items()) polars_expr = " ".join(expr for _, expr in exprs) @@ -239,7 +271,7 @@ def to_expression(self, col: str) -> str | None: if "pl.col(" not in polars_expr: return None - return polars_expr + return polars_expr if as_repr else eval(polars_expr, globals()) def warn( self, diff --git a/py-polars/tests/test_udfs.py b/py-polars/tests/test_udfs.py index 7053a49371a1..1cd9e60d8b4a 100644 --- a/py-polars/tests/test_udfs.py +++ b/py-polars/tests/test_udfs.py @@ -68,6 +68,32 @@ lambda x: (float(x) * int(x)) // 2, '(pl.col("a").cast(pl.Float64) * pl.col("a").cast(pl.Int64)) // 2', ), + # logical 'and/or' (validate nesting levels) + ( + "a", + lambda x: x > 1 or (x == 1 and x == 2), + '(pl.col("a") > 1) | (pl.col("a") == 1) & (pl.col("a") == 2)', + ), + ( + "a", + lambda x: (x > 1 or x == 1) and x == 2, + '((pl.col("a") > 1) | (pl.col("a") == 1)) & (pl.col("a") == 2)', + ), + ( + "a", + lambda x: x > 2 or x != 3 and x not in (0, 1, 4), + '(pl.col("a") > 2) | (pl.col("a") != 3) & ~pl.col("a").is_in((0, 1, 4))', + ), + ( + "a", + lambda x: x > 1 and x != 2 or x % 2 == 0 and x < 3, + '(pl.col("a") > 1) & (pl.col("a") != 2) | ((pl.col("a") % 2) == 0) & (pl.col("a") < 3)', + ), + ( + "a", + lambda x: x > 1 and (x != 2 or x % 2 == 0) and x < 3, + '(pl.col("a") > 1) & ((pl.col("a") != 2) | ((pl.col("a") % 2) == 0)) & (pl.col("a") < 3)', + ), # string expr: case/cast ops ("b", lambda x: str(x).title(), 'pl.col("b").cast(pl.Utf8).str.to_titlecase()'), ( diff --git a/py-polars/tests/unit/operations/test_inefficient_apply.py b/py-polars/tests/unit/operations/test_inefficient_apply.py index 833efdd603b7..077afd50ca92 100644 --- a/py-polars/tests/unit/operations/test_inefficient_apply.py +++ b/py-polars/tests/unit/operations/test_inefficient_apply.py @@ -19,7 +19,6 @@ lambda x: x, lambda x, y: x + y, lambda x: x[0] + 1, - lambda x: x > 0 and (x < 100 or (x % 2 == 0)), ], ) def test_parse_invalid_function(func: Callable[[Any], Any]) -> None: From 44f5cd309e065af71cb145948bf6da2362456dd8 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Wed, 26 Jul 2023 18:59:53 +0400 Subject: [PATCH 2/4] tidy-up OpCodes and support mixed and/or chains for python < py3.11 --- py-polars/polars/utils/udfs.py | 175 ++++++++++++++++----------------- py-polars/pyproject.toml | 1 + 2 files changed, 87 insertions(+), 89 deletions(-) diff --git a/py-polars/polars/utils/udfs.py b/py-polars/polars/utils/udfs.py index 06ec30a8514c..79d2042e3ba2 100644 --- a/py-polars/polars/utils/udfs.py +++ b/py-polars/polars/utils/udfs.py @@ -30,72 +30,66 @@ class StackValue(NamedTuple): ApplyTarget: TypeAlias = Literal["expr", "frame", "series"] StackEntry: TypeAlias = Union[str, StackValue] +_MIN_PY311 = sys.version_info >= (3, 11) + + +class OpCodes: + BINARY = { + "BINARY_ADD": "+", + "BINARY_AND": "&", + "BINARY_FLOOR_DIVIDE": "//", + "BINARY_MODULO": "%", + "BINARY_MULTIPLY": "*", + "BINARY_OR": "|", + "BINARY_POWER": "**", + "BINARY_SUBTRACT": "-", + "BINARY_TRUE_DIVIDE": "/", + "BINARY_XOR": "^", + } + CALL = "CALL" if _MIN_PY311 else "CALL_*" + CONTROL_FLOW = ( + { + "POP_JUMP_FORWARD_IF_FALSE": "&", + "POP_JUMP_FORWARD_IF_TRUE": "|", + "JUMP_IF_FALSE_OR_POP": "&", + "JUMP_IF_TRUE_OR_POP": "|", + } + if _MIN_PY311 + else { + "POP_JUMP_IF_FALSE": "&", + "POP_JUMP_IF_TRUE": "|", + "JUMP_IF_FALSE_OR_POP": "&", + "JUMP_IF_TRUE_OR_POP": "|", + } + ) + LOAD_VALUES = frozenset(("LOAD_CONST", "LOAD_DEREF", "LOAD_FAST", "LOAD_GLOBAL")) + LOAD = LOAD_VALUES | {"LOAD_METHOD", "LOAD_ATTR"} + SYNTHETIC = { + "POLARS_EXPRESSION": 1, + } + UNARY = { + "UNARY_NEGATIVE": "-", + "UNARY_POSITIVE": "+", + "UNARY_NOT": "~", + } + PARSEABLE_OPS = ( + {"BINARY_OP", "COMPARE_OP", "CONTAINS_OP", "IS_OP"} + | set(UNARY) + | set(CONTROL_FLOW) + | set(SYNTHETIC) + | LOAD_VALUES + ) + UNARY_VALUES = frozenset(UNARY.values()) -# Note: in 3.11 individual binary opcodes were folded into a new BINARY_OP -_BINARY_OPCODES = { - "BINARY_ADD": "+", - "BINARY_AND": "&", - "BINARY_FLOOR_DIVIDE": "//", - "BINARY_MODULO": "%", - "BINARY_MULTIPLY": "*", - "BINARY_OR": "|", - "BINARY_POWER": "**", - "BINARY_SUBTRACT": "-", - "BINARY_TRUE_DIVIDE": "/", - "BINARY_XOR": "^", -} -_CONTROL_FLOW_OPCODES = { - # note: once we add additional JUMP op support, we'll need to disambiguate - # between and/or (what we currently support) and if/else (which we don't) - # "POP_JUMP_IF_FALSE": ..., - # "POP_JUMP_IF_TRUE": ..., - # "JUMP_FORWARD": ..., - "POP_JUMP_FORWARD_IF_FALSE": "&", - "POP_JUMP_FORWARD_IF_TRUE": "|", - "JUMP_IF_FALSE_OR_POP": "&", - "JUMP_IF_TRUE_OR_POP": "|", -} -_UNARY_OPCODES = { - "UNARY_NEGATIVE": "-", - "UNARY_POSITIVE": "+", - "UNARY_NOT": "~", -} -_SYNTHETIC_OPS = { - "POLARS_EXPRESSION": 1, -} -_LOAD_OPS = { - "LOAD_CONST", - "LOAD_DEREF", - "LOAD_FAST", - "LOAD_GLOBAL", -} -_SIMPLE_EXPR_OPS = { - "BINARY_OP", - "COMPARE_OP", - "CONTAINS_OP", - "IS_OP", -} -_SIMPLE_EXPR_OPS |= ( - set(_UNARY_OPCODES) | set(_CONTROL_FLOW_OPCODES) | set(_SYNTHETIC_OPS) | _LOAD_OPS -) -_SIMPLE_FRAME_OPS = _SIMPLE_EXPR_OPS | {"BINARY_SUBSCR"} -_UNARY_OPCODE_VALUES = set(_UNARY_OPCODES.values()) -_LOAD_OPS |= {"LOAD_METHOD", "LOAD_ATTR"} - -if sys.version_info < (3, 11): - _UPGRADE_BINARY_OPS = True - _CALL_OP = "CALL_*" -else: - _UPGRADE_BINARY_OPS = False - _CALL_OP = "CALL" # numpy functions that we can map to a native expression -_NUMPY_MODULE_ALIASES = {"np", "numpy"} -_NUMPY_FUNCTIONS = {"cbrt", "cos", "cosh", "sin", "sinh", "sqrt", "tan", "tanh"} - +_NUMPY_MODULE_ALIASES = frozenset(("np", "numpy")) +_NUMPY_FUNCTIONS = frozenset( + ("cbrt", "cos", "cosh", "sin", "sinh", "sqrt", "tan", "tanh") +) # python function that we can map to a native expression _PYTHON_CASTS_MAP = {"float": "Float64", "int": "Int64", "str": "Utf8"} -_PYTHON_METHOD_MAP = { +_PYTHON_METHODS_MAP = { "lower": "str.to_lowercase", "title": "str.to_titlecase", "upper": "str.to_uppercase", @@ -155,15 +149,11 @@ def can_rewrite(self) -> bool: else: self._can_rewrite[self._apply_target] = False if self._rewritten_instructions and self._param_name is not None: - simple_ops = ( - _SIMPLE_FRAME_OPS - if self._apply_target == "frame" - else _SIMPLE_EXPR_OPS - ) self._can_rewrite[self._apply_target] = len( self._rewritten_instructions ) >= 2 and all( - inst.opname in simple_ops for inst in self._rewritten_instructions + inst.opname in OpCodes.PARSEABLE_OPS + for inst in self._rewritten_instructions ) return self._can_rewrite[self._apply_target] @@ -202,7 +192,7 @@ def to_expression(self, col: str, as_repr: bool = True) -> str | None: logical_instructions = [] jump_offset = 0 for idx, inst in enumerate(self._rewritten_instructions): - if inst.opname in _CONTROL_FLOW_OPCODES: + if inst.opname in OpCodes.CONTROL_FLOW: jump_offset = self._rewritten_instructions[idx + 1].offset logical_instructions.append(inst) else: @@ -234,11 +224,15 @@ def to_expression(self, col: str, as_repr: bool = True) -> str | None: # operator precedence means that we can combine logically connected # 'and' blocks into one (depending on follow-on logic) and should # parenthesise nested 'or' blocks - logical_op = _CONTROL_FLOW_OPCODES[inst.opname] + logical_op = OpCodes.CONTROL_FLOW[inst.opname] start = block_offsets[ max(idx_min, bisect_left(block_offsets, inst.offset) - 1) ] - if previous_logical_opname == "POP_JUMP_FORWARD_IF_FALSE": + if previous_logical_opname == ( + "POP_JUMP_FORWARD_IF_FALSE" + if _MIN_PY311 + else "POP_JUMP_IF_FALSE" + ): prev = block_offsets[ max(idx_min, bisect_left(block_offsets, start) - 1) ] @@ -261,7 +255,7 @@ def to_expression(self, col: str, as_repr: bool = True) -> str | None: for i, inst in enumerate(logical_instructions): if i not in combined_offset_idxs: - expression_strings[inst.offset] = _CONTROL_FLOW_OPCODES[inst.opname] + expression_strings[inst.offset] = OpCodes.CONTROL_FLOW[inst.opname] exprs = sorted(expression_strings.items()) polars_expr = " ".join(expr for _, expr in exprs) @@ -332,16 +326,16 @@ def to_expression(self, col: str, param_name: str, depth: int) -> str: @classmethod def op(cls, inst: Instruction) -> str: """Convert bytecode instruction to suitable intermediate op string.""" - if inst.opname in _CONTROL_FLOW_OPCODES: - return _CONTROL_FLOW_OPCODES[inst.opname] + if inst.opname in OpCodes.CONTROL_FLOW: + return OpCodes.CONTROL_FLOW[inst.opname] elif inst.argrepr: return inst.argrepr elif inst.opname == "IS_OP": return "is not" if inst.argval else "is" elif inst.opname == "CONTAINS_OP": return "not in" if inst.argval else "in" - elif inst.opname in _UNARY_OPCODES: - return _UNARY_OPCODES[inst.opname] + elif inst.opname in OpCodes.UNARY: + return OpCodes.UNARY[inst.opname] else: raise AssertionError( "Unrecognised opname; please report a bug to https://github.com/pola-rs/polars/issues " @@ -356,7 +350,7 @@ def _expr(cls, value: StackEntry, col: str, param_name: str, depth: int) -> str: op = value.operator e1 = cls._expr(value.left_operand, col, param_name, depth + 1) if value.operator_arity == 1: - if op not in _UNARY_OPCODE_VALUES: + if op not in OpCodes.UNARY_VALUES: call = "" if op.endswith(")") else "()" return f"{e1}.{op}{call}" return f"{op}{e1}" @@ -390,7 +384,7 @@ def _to_intermediate_stack( for inst in instructions: stack.append( inst.argrepr - if inst.opname in _LOAD_OPS + if inst.opname in OpCodes.LOAD else ( StackValue( operator=self.op(inst), @@ -399,8 +393,8 @@ def _to_intermediate_stack( right_operand=None, # type: ignore[arg-type] ) if ( - inst.opname in _UNARY_OPCODES - or _SYNTHETIC_OPS.get(inst.opname) == 1 + inst.opname in OpCodes.UNARY + or OpCodes.SYNTHETIC.get(inst.opname) == 1 ) else StackValue( operator=self.op(inst), @@ -450,7 +444,7 @@ def _matches( idx: int, *, opnames: list[str], - argvals: list[set[Any] | dict[Any, Any]] | None, + argvals: list[set[Any] | frozenset[Any] | dict[Any, Any]] | None, ) -> list[Instruction]: """ Check if a sequence of Instructions matches the specified ops/argvals. @@ -493,7 +487,7 @@ def _rewrite(self, instructions: Iterator[Instruction]) -> list[Instruction]: idx = 0 while idx < len(self._instructions): inst, increment = self._instructions[idx], 1 - if inst.opname not in _LOAD_OPS or not any( + if inst.opname not in OpCodes.LOAD or not any( (increment := apply_rewrite(idx, updated_instructions)) for apply_rewrite in ( # add any other rewrite methods here @@ -512,7 +506,7 @@ def _rewrite_builtins( """Replace builtin function calls with a synthetic POLARS_EXPRESSION op.""" if matching_instructions := self._matches( idx, - opnames=["LOAD_GLOBAL", "LOAD_FAST", _CALL_OP], + opnames=["LOAD_GLOBAL", "LOAD_FAST", OpCodes.CALL], argvals=[_PYTHON_CASTS_MAP], ): inst1, inst2 = matching_instructions[:2] @@ -535,8 +529,11 @@ def _rewrite_functions( """Replace numpy/json function calls with a synthetic POLARS_EXPRESSION op.""" if matching_instructions := self._matches( idx, - opnames=["LOAD_GLOBAL", "LOAD_*", "LOAD_*", _CALL_OP], - argvals=[_NUMPY_MODULE_ALIASES | {"json"}, _NUMPY_FUNCTIONS | {"loads"}], + opnames=["LOAD_GLOBAL", "LOAD_*", "LOAD_*", OpCodes.CALL], + argvals=[ + _NUMPY_MODULE_ALIASES | {"json"}, + _NUMPY_FUNCTIONS | {"loads"}, + ], ): inst1, inst2, inst3 = matching_instructions[:3] expr_name = "str.json_extract" if inst1.argval == "json" else inst2.argval @@ -558,11 +555,11 @@ def _rewrite_methods( """Replace python method calls with synthetic POLARS_EXPRESSION op.""" if matching_instructions := self._matches( idx, - opnames=["LOAD_METHOD", _CALL_OP], - argvals=[_PYTHON_METHOD_MAP], + opnames=["LOAD_METHOD", OpCodes.CALL], + argvals=[_PYTHON_METHODS_MAP], ): inst = matching_instructions[0] - expr_name = _PYTHON_METHOD_MAP[inst.argval] + expr_name = _PYTHON_METHODS_MAP[inst.argval] synthetic_call = inst._replace( opname="POLARS_EXPRESSION", argval=expr_name, argrepr=expr_name ) @@ -573,9 +570,9 @@ def _rewrite_methods( @staticmethod def _upgrade_instruction(inst: Instruction) -> Instruction: """Rewrite any older binary opcodes using py 3.11 'BINARY_OP' instead.""" - if _UPGRADE_BINARY_OPS and inst.opname in _BINARY_OPCODES: + if not _MIN_PY311 and inst.opname in OpCodes.BINARY: inst = inst._replace( - argrepr=_BINARY_OPCODES[inst.opname], + argrepr=OpCodes.BINARY[inst.opname], opname="BINARY_OP", ) return inst diff --git a/py-polars/pyproject.toml b/py-polars/pyproject.toml index 30624fca245f..11bb61416009 100644 --- a/py-polars/pyproject.toml +++ b/py-polars/pyproject.toml @@ -162,6 +162,7 @@ strict = true "polars/datatypes.py" = ["B019"] "tests/**/*.py" = ["D100", "D103", "B018"] "polars/utils/show_versions.py" = ["D301"] +"polars/utils/udfs.py" = ["RUF012"] [tool.pytest.ini_options] addopts = [ From d108f3e16b72076f6b6cfa102f697618cff093dc Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Wed, 26 Jul 2023 19:11:27 +0400 Subject: [PATCH 3/4] rename OpCodes -> OpNames --- py-polars/polars/utils/udfs.py | 38 +++++++++++++++++----------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/py-polars/polars/utils/udfs.py b/py-polars/polars/utils/udfs.py index 79d2042e3ba2..86cddbe4d35b 100644 --- a/py-polars/polars/utils/udfs.py +++ b/py-polars/polars/utils/udfs.py @@ -33,7 +33,7 @@ class StackValue(NamedTuple): _MIN_PY311 = sys.version_info >= (3, 11) -class OpCodes: +class OpNames: BINARY = { "BINARY_ADD": "+", "BINARY_AND": "&", @@ -152,7 +152,7 @@ def can_rewrite(self) -> bool: self._can_rewrite[self._apply_target] = len( self._rewritten_instructions ) >= 2 and all( - inst.opname in OpCodes.PARSEABLE_OPS + inst.opname in OpNames.PARSEABLE_OPS for inst in self._rewritten_instructions ) @@ -192,7 +192,7 @@ def to_expression(self, col: str, as_repr: bool = True) -> str | None: logical_instructions = [] jump_offset = 0 for idx, inst in enumerate(self._rewritten_instructions): - if inst.opname in OpCodes.CONTROL_FLOW: + if inst.opname in OpNames.CONTROL_FLOW: jump_offset = self._rewritten_instructions[idx + 1].offset logical_instructions.append(inst) else: @@ -224,7 +224,7 @@ def to_expression(self, col: str, as_repr: bool = True) -> str | None: # operator precedence means that we can combine logically connected # 'and' blocks into one (depending on follow-on logic) and should # parenthesise nested 'or' blocks - logical_op = OpCodes.CONTROL_FLOW[inst.opname] + logical_op = OpNames.CONTROL_FLOW[inst.opname] start = block_offsets[ max(idx_min, bisect_left(block_offsets, inst.offset) - 1) ] @@ -255,7 +255,7 @@ def to_expression(self, col: str, as_repr: bool = True) -> str | None: for i, inst in enumerate(logical_instructions): if i not in combined_offset_idxs: - expression_strings[inst.offset] = OpCodes.CONTROL_FLOW[inst.opname] + expression_strings[inst.offset] = OpNames.CONTROL_FLOW[inst.opname] exprs = sorted(expression_strings.items()) polars_expr = " ".join(expr for _, expr in exprs) @@ -326,16 +326,16 @@ def to_expression(self, col: str, param_name: str, depth: int) -> str: @classmethod def op(cls, inst: Instruction) -> str: """Convert bytecode instruction to suitable intermediate op string.""" - if inst.opname in OpCodes.CONTROL_FLOW: - return OpCodes.CONTROL_FLOW[inst.opname] + if inst.opname in OpNames.CONTROL_FLOW: + return OpNames.CONTROL_FLOW[inst.opname] elif inst.argrepr: return inst.argrepr elif inst.opname == "IS_OP": return "is not" if inst.argval else "is" elif inst.opname == "CONTAINS_OP": return "not in" if inst.argval else "in" - elif inst.opname in OpCodes.UNARY: - return OpCodes.UNARY[inst.opname] + elif inst.opname in OpNames.UNARY: + return OpNames.UNARY[inst.opname] else: raise AssertionError( "Unrecognised opname; please report a bug to https://github.com/pola-rs/polars/issues " @@ -350,7 +350,7 @@ def _expr(cls, value: StackEntry, col: str, param_name: str, depth: int) -> str: op = value.operator e1 = cls._expr(value.left_operand, col, param_name, depth + 1) if value.operator_arity == 1: - if op not in OpCodes.UNARY_VALUES: + if op not in OpNames.UNARY_VALUES: call = "" if op.endswith(")") else "()" return f"{e1}.{op}{call}" return f"{op}{e1}" @@ -384,7 +384,7 @@ def _to_intermediate_stack( for inst in instructions: stack.append( inst.argrepr - if inst.opname in OpCodes.LOAD + if inst.opname in OpNames.LOAD else ( StackValue( operator=self.op(inst), @@ -393,8 +393,8 @@ def _to_intermediate_stack( right_operand=None, # type: ignore[arg-type] ) if ( - inst.opname in OpCodes.UNARY - or OpCodes.SYNTHETIC.get(inst.opname) == 1 + inst.opname in OpNames.UNARY + or OpNames.SYNTHETIC.get(inst.opname) == 1 ) else StackValue( operator=self.op(inst), @@ -487,7 +487,7 @@ def _rewrite(self, instructions: Iterator[Instruction]) -> list[Instruction]: idx = 0 while idx < len(self._instructions): inst, increment = self._instructions[idx], 1 - if inst.opname not in OpCodes.LOAD or not any( + if inst.opname not in OpNames.LOAD or not any( (increment := apply_rewrite(idx, updated_instructions)) for apply_rewrite in ( # add any other rewrite methods here @@ -506,7 +506,7 @@ def _rewrite_builtins( """Replace builtin function calls with a synthetic POLARS_EXPRESSION op.""" if matching_instructions := self._matches( idx, - opnames=["LOAD_GLOBAL", "LOAD_FAST", OpCodes.CALL], + opnames=["LOAD_GLOBAL", "LOAD_FAST", OpNames.CALL], argvals=[_PYTHON_CASTS_MAP], ): inst1, inst2 = matching_instructions[:2] @@ -529,7 +529,7 @@ def _rewrite_functions( """Replace numpy/json function calls with a synthetic POLARS_EXPRESSION op.""" if matching_instructions := self._matches( idx, - opnames=["LOAD_GLOBAL", "LOAD_*", "LOAD_*", OpCodes.CALL], + opnames=["LOAD_GLOBAL", "LOAD_*", "LOAD_*", OpNames.CALL], argvals=[ _NUMPY_MODULE_ALIASES | {"json"}, _NUMPY_FUNCTIONS | {"loads"}, @@ -555,7 +555,7 @@ def _rewrite_methods( """Replace python method calls with synthetic POLARS_EXPRESSION op.""" if matching_instructions := self._matches( idx, - opnames=["LOAD_METHOD", OpCodes.CALL], + opnames=["LOAD_METHOD", OpNames.CALL], argvals=[_PYTHON_METHODS_MAP], ): inst = matching_instructions[0] @@ -570,9 +570,9 @@ def _rewrite_methods( @staticmethod def _upgrade_instruction(inst: Instruction) -> Instruction: """Rewrite any older binary opcodes using py 3.11 'BINARY_OP' instead.""" - if not _MIN_PY311 and inst.opname in OpCodes.BINARY: + if not _MIN_PY311 and inst.opname in OpNames.BINARY: inst = inst._replace( - argrepr=OpCodes.BINARY[inst.opname], + argrepr=OpNames.BINARY[inst.opname], opname="BINARY_OP", ) return inst From 3c0d7380dbd1b4512aaa75f0649ef36a4c0b3a22 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Wed, 26 Jul 2023 23:40:17 +0400 Subject: [PATCH 4/4] slightly simplify nesting code and move into its own private function --- py-polars/polars/utils/udfs.py | 122 +++++++++++++++++---------------- py-polars/tests/test_udfs.py | 8 +++ 2 files changed, 70 insertions(+), 60 deletions(-) diff --git a/py-polars/polars/utils/udfs.py b/py-polars/polars/utils/udfs.py index 86cddbe4d35b..57d6ea4109e9 100644 --- a/py-polars/polars/utils/udfs.py +++ b/py-polars/polars/utils/udfs.py @@ -131,6 +131,53 @@ def _get_param_name(function: Callable[[Any], Any]) -> str | None: except ValueError: return None + def _inject_nesting( + self, + expression_blocks: dict[int, str], + logical_instructions: list[Instruction], + ) -> list[tuple[int, str]]: + """Inject nesting boundaries into expression blocks (as parentheses).""" + if logical_instructions: + # reconstruct nesting boundaries for mixed and/or ops by associating + # control flow jump offsets with their target expression blocks and + # injecting appropriate parentheses + combined_offset_idxs = set() + if len({inst.opname for inst in logical_instructions}) > 1: + block_offsets: list[int] = list(expression_blocks.keys()) + previous_logical_opname = "" + for i, inst in enumerate(logical_instructions): + # operator precedence means that we can combine logically connected + # 'and' blocks into one (depending on follow-on logic) and should + # parenthesise nested 'or' blocks + logical_op = OpNames.CONTROL_FLOW[inst.opname] + start = block_offsets[bisect_left(block_offsets, inst.offset) - 1] + if previous_logical_opname == ( + "POP_JUMP_FORWARD_IF_FALSE" + if _MIN_PY311 + else "POP_JUMP_IF_FALSE" + ): + # combine logical '&' blocks (and update start/block_offsets) + prev = block_offsets[bisect_left(block_offsets, start) - 1] + expression_blocks[prev] += f" & {expression_blocks.pop(start)}" + block_offsets = list(expression_blocks.keys()) + combined_offset_idxs.add(i - 1) + start = prev + + if logical_op == "|": + # parenthesise connected 'or' blocks + end = block_offsets[bisect_left(block_offsets, inst.argval) - 1] + if not (start == 0 and end == block_offsets[-1]): + expression_blocks[start] = "(" + expression_blocks[start] + expression_blocks[end] += ")" + + previous_logical_opname = inst.opname + + for i, inst in enumerate(logical_instructions): + if i not in combined_offset_idxs: + expression_blocks[inst.offset] = OpNames.CONTROL_FLOW[inst.opname] + + return sorted(expression_blocks.items()) + @property def apply_target(self) -> ApplyTarget: """The apply target, eg: one of 'expr', 'frame', or 'series'.""" @@ -199,66 +246,21 @@ def to_expression(self, col: str, as_repr: bool = True) -> str | None: control_flow_blocks[jump_offset].append(inst) # convert each block to a polars expression string - expression_strings = { - offset: InstructionTranslator( - instructions=ops, - apply_target=self._apply_target, - ).to_expression( - col=col, - param_name=self._param_name, - depth=int(bool(logical_instructions)), - ) - for offset, ops in control_flow_blocks.items() - } - - if logical_instructions: - # reconstruct nesting boundaries for mixed and/or ops by associating - # control flow jump offsets with their target expression blocks and - # injecting appropriate parentheses - combined_offset_idxs = set() - if len({inst.opname for inst in logical_instructions}) > 1: - block_offsets: list[int] = list(expression_strings.keys()) - idx_min, idx_max = 0, len(block_offsets) - 1 - previous_logical_opname = "" - for i, inst in enumerate(logical_instructions): - # operator precedence means that we can combine logically connected - # 'and' blocks into one (depending on follow-on logic) and should - # parenthesise nested 'or' blocks - logical_op = OpNames.CONTROL_FLOW[inst.opname] - start = block_offsets[ - max(idx_min, bisect_left(block_offsets, inst.offset) - 1) - ] - if previous_logical_opname == ( - "POP_JUMP_FORWARD_IF_FALSE" - if _MIN_PY311 - else "POP_JUMP_IF_FALSE" - ): - prev = block_offsets[ - max(idx_min, bisect_left(block_offsets, start) - 1) - ] - expression_strings[ - prev - ] += f" & {expression_strings.pop(start)}" - block_offsets = list(expression_strings.keys()) - combined_offset_idxs.add(i - 1) - start = prev - - if logical_op == "|": - end = block_offsets[ - min(idx_max, bisect_left(block_offsets, inst.argval) - 1) - ] - if not (start == 0 and end == block_offsets[-1]): - expression_strings[start] = "(" + expression_strings[start] - expression_strings[end] += ")" - - previous_logical_opname = inst.opname - - for i, inst in enumerate(logical_instructions): - if i not in combined_offset_idxs: - expression_strings[inst.offset] = OpNames.CONTROL_FLOW[inst.opname] - - exprs = sorted(expression_strings.items()) - polars_expr = " ".join(expr for _, expr in exprs) + expression_blocks = self._inject_nesting( + { + offset: InstructionTranslator( + instructions=ops, + apply_target=self._apply_target, + ).to_expression( + col=col, + param_name=self._param_name, + depth=int(bool(logical_instructions)), + ) + for offset, ops in control_flow_blocks.items() + }, + logical_instructions, + ) + polars_expr = " ".join(expr for _offset, expr in expression_blocks) # note: if no 'pl.col' in the expression, it likely represents a compound # constant value (e.g. `lambda x: CONST + 123`), so we don't want to warn diff --git a/py-polars/tests/test_udfs.py b/py-polars/tests/test_udfs.py index 1cd9e60d8b4a..b223c66286a4 100644 --- a/py-polars/tests/test_udfs.py +++ b/py-polars/tests/test_udfs.py @@ -23,7 +23,9 @@ # column_name, function, expected_suggestion TEST_CASES = [ + # --------------------------------------------- # numeric expr: math, comparison, logic ops + # --------------------------------------------- ("a", lambda x: x + 1 - (2 / 3), '(pl.col("a") + 1) - 0.6666666666666666'), ("a", lambda x: x // 1 % 2, '(pl.col("a") // 1) % 2'), ("a", lambda x: x & True, 'pl.col("a") & True'), @@ -68,7 +70,9 @@ lambda x: (float(x) * int(x)) // 2, '(pl.col("a").cast(pl.Float64) * pl.col("a").cast(pl.Int64)) // 2', ), + # --------------------------------------------- # logical 'and/or' (validate nesting levels) + # --------------------------------------------- ( "a", lambda x: x > 1 or (x == 1 and x == 2), @@ -94,14 +98,18 @@ lambda x: x > 1 and (x != 2 or x % 2 == 0) and x < 3, '(pl.col("a") > 1) & ((pl.col("a") != 2) | ((pl.col("a") % 2) == 0)) & (pl.col("a") < 3)', ), + # --------------------------------------------- # string expr: case/cast ops + # --------------------------------------------- ("b", lambda x: str(x).title(), 'pl.col("b").cast(pl.Utf8).str.to_titlecase()'), ( "b", lambda x: x.lower() + ":" + x.upper() + ":" + x.title(), '(((pl.col("b").str.to_lowercase() + \':\') + pl.col("b").str.to_uppercase()) + \':\') + pl.col("b").str.to_titlecase()', ), + # --------------------------------------------- # json expr: load/extract + # --------------------------------------------- ("c", lambda x: json.loads(x), 'pl.col("c").str.json_extract()'), ]