Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): BytecodeParser can now handle mixed/nested and/or control flow #10085

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 135 additions & 106 deletions py-polars/polars/utils/udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import dis
import sys
import warnings
from bisect import bisect_left
from collections import defaultdict
from dis import get_instructions
from inspect import signature
Expand All @@ -29,72 +30,66 @@ class StackValue(NamedTuple):
ApplyTarget: TypeAlias = Literal["expr", "frame", "series"]
StackEntry: TypeAlias = Union[str, StackValue]

_MIN_PY311 = sys.version_info >= (3, 11)


# Namespace grouping the bytecode opcode-name lookup tables used by the parser.
# Several entries are chosen once, at class-creation time, based on `_MIN_PY311`.
class OpNames:
# per-operator binary opcode names mapped to their operator symbol
BINARY = {
"BINARY_ADD": "+",
"BINARY_AND": "&",
"BINARY_FLOOR_DIVIDE": "//",
"BINARY_MODULO": "%",
"BINARY_MULTIPLY": "*",
"BINARY_OR": "|",
"BINARY_POWER": "**",
"BINARY_SUBTRACT": "-",
"BINARY_TRUE_DIVIDE": "/",
"BINARY_XOR": "^",
}
# opname pattern for a function call; "CALL" when `_MIN_PY311` is true, else
# "CALL_*" (presumably a wildcard over the older CALL_ variants; confirm
# against the opname-matching helper)
CALL = "CALL" if _MIN_PY311 else "CALL_*"
# conditional-jump opnames mapped to the logical operator symbol they stand
# for ("&" or "|"); the POP_JUMP names differ depending on `_MIN_PY311`
CONTROL_FLOW = (
{
"POP_JUMP_FORWARD_IF_FALSE": "&",
"POP_JUMP_FORWARD_IF_TRUE": "|",
"JUMP_IF_FALSE_OR_POP": "&",
"JUMP_IF_TRUE_OR_POP": "|",
}
if _MIN_PY311
else {
"POP_JUMP_IF_FALSE": "&",
"POP_JUMP_IF_TRUE": "|",
"JUMP_IF_FALSE_OR_POP": "&",
"JUMP_IF_TRUE_OR_POP": "|",
}
)
# LOAD_VALUES is the subset of load opnames that is included in PARSEABLE_OPS;
# LOAD additionally contains the method/attribute load opnames
LOAD_VALUES = frozenset(("LOAD_CONST", "LOAD_DEREF", "LOAD_FAST", "LOAD_GLOBAL"))
LOAD = LOAD_VALUES | {"LOAD_METHOD", "LOAD_ATTR"}
# synthetic (non-CPython) opnames mapped to an integer
# NOTE(review): the integer looks like the operator arity (1 == unary); confirm
SYNTHETIC = {
"POLARS_EXPRESSION": 1,
}
# unary opnames mapped to the prefix symbol used in the generated expression
UNARY = {
"UNARY_NEGATIVE": "-",
"UNARY_POSITIVE": "+",
"UNARY_NOT": "~",
}
# union of every opname listed here as parseable: the four explicit ops plus
# the UNARY, CONTROL_FLOW and SYNTHETIC keys and the LOAD_VALUES names
PARSEABLE_OPS = (
{"BINARY_OP", "COMPARE_OP", "CONTAINS_OP", "IS_OP"}
| set(UNARY)
| set(CONTROL_FLOW)
| set(SYNTHETIC)
| LOAD_VALUES
)
# the unary operator symbols only ("-", "+", "~"), for membership checks
UNARY_VALUES = frozenset(UNARY.values())

# Note: in 3.11 individual binary opcodes were folded into a new BINARY_OP
_BINARY_OPCODES = {
"BINARY_ADD": "+",
"BINARY_AND": "&",
"BINARY_FLOOR_DIVIDE": "//",
"BINARY_MODULO": "%",
"BINARY_MULTIPLY": "*",
"BINARY_OR": "|",
"BINARY_POWER": "**",
"BINARY_SUBTRACT": "-",
"BINARY_TRUE_DIVIDE": "/",
"BINARY_XOR": "^",
}
_CONTROL_FLOW_OPCODES = {
# note: once we add additional JUMP op support, we'll need to disambiguate
# between and/or (what we currently support) and if/else (which we don't)
# "POP_JUMP_FORWARD_IF_FALSE": "?",
# "POP_JUMP_FORWARD_IF_TRUE": "?",
# "POP_JUMP_IF_FALSE": "?",
# "POP_JUMP_IF_TRUE": "?",
# "JUMP_FORWARD": "?",
"JUMP_IF_FALSE_OR_POP": "&",
"JUMP_IF_TRUE_OR_POP": "|",
}
_UNARY_OPCODES = {
"UNARY_NEGATIVE": "-",
"UNARY_POSITIVE": "+",
"UNARY_NOT": "~",
}
_SYNTHETIC_OPS = {
"POLARS_EXPRESSION": 1,
}
_LOAD_OPS = {
"LOAD_CONST",
"LOAD_DEREF",
"LOAD_FAST",
"LOAD_GLOBAL",
}
_SIMPLE_EXPR_OPS = {
"BINARY_OP",
"COMPARE_OP",
"CONTAINS_OP",
"IS_OP",
}
_SIMPLE_EXPR_OPS |= (
set(_UNARY_OPCODES) | set(_CONTROL_FLOW_OPCODES) | set(_SYNTHETIC_OPS) | _LOAD_OPS
)
_SIMPLE_FRAME_OPS = _SIMPLE_EXPR_OPS | {"BINARY_SUBSCR"}
MarcoGorelli marked this conversation as resolved.
Show resolved Hide resolved
_UNARY_OPCODE_VALUES = set(_UNARY_OPCODES.values())
_LOAD_OPS |= {"LOAD_METHOD", "LOAD_ATTR"}

if sys.version_info < (3, 11):
_UPGRADE_BINARY_OPS = True
_CALL_OP = "CALL_*"
else:
_UPGRADE_BINARY_OPS = False
_CALL_OP = "CALL"

# numpy functions that we can map to a native expression
_NUMPY_MODULE_ALIASES = {"np", "numpy"}
_NUMPY_FUNCTIONS = {"cbrt", "cos", "cosh", "sin", "sinh", "sqrt", "tan", "tanh"}

_NUMPY_MODULE_ALIASES = frozenset(("np", "numpy"))
_NUMPY_FUNCTIONS = frozenset(
("cbrt", "cos", "cosh", "sin", "sinh", "sqrt", "tan", "tanh")
)
# python functions that we can map to native expressions
_PYTHON_CASTS_MAP = {"float": "Float64", "int": "Int64", "str": "Utf8"}
_PYTHON_METHOD_MAP = {
_PYTHON_METHODS_MAP = {
"lower": "str.to_lowercase",
"title": "str.to_titlecase",
"upper": "str.to_uppercase",
Expand Down Expand Up @@ -154,21 +149,12 @@ def can_rewrite(self) -> bool:
else:
self._can_rewrite[self._apply_target] = False
if self._rewritten_instructions and self._param_name is not None:
simple_ops = (
_SIMPLE_FRAME_OPS
if self._apply_target == "frame"
else _SIMPLE_EXPR_OPS
self._can_rewrite[self._apply_target] = len(
self._rewritten_instructions
) >= 2 and all(
inst.opname in OpNames.PARSEABLE_OPS
for inst in self._rewritten_instructions
)
if len(self._rewritten_instructions) >= 2 and all(
inst.opname in simple_ops for inst in self._rewritten_instructions
):
# can (currently) only handle logical 'and'/'or' if they are not mixed
logical_ops = {
_CONTROL_FLOW_OPCODES[inst.opname]
for inst in self._rewritten_instructions
if inst.opname in _CONTROL_FLOW_OPCODES
}
self._can_rewrite[self._apply_target] = len(logical_ops) <= 1

return self._can_rewrite[self._apply_target]

Expand Down Expand Up @@ -196,23 +182,23 @@ def rewritten_instructions(self) -> list[Instruction]:
"""The rewritten bytecode instructions from the function we are parsing."""
return list(self._rewritten_instructions)

def to_expression(self, col: str) -> str | None:
"""Translate postfix bytecode instructions to polars expression string."""
def to_expression(self, col: str, as_repr: bool = True) -> str | None:
MarcoGorelli marked this conversation as resolved.
Show resolved Hide resolved
"""Translate postfix bytecode instructions to polars expression/string."""
if not self.can_rewrite() or self._param_name is None:
return None

# decompose bytecode into logical 'and'/'or' expression blocks (if present)
logical_instruction_blocks = defaultdict(list)
control_flow_blocks = defaultdict(list)
logical_instructions = []
jump_offset = 0
for idx, inst in enumerate(self._rewritten_instructions):
if inst.opname in _CONTROL_FLOW_OPCODES:
if inst.opname in OpNames.CONTROL_FLOW:
jump_offset = self._rewritten_instructions[idx + 1].offset
logical_instructions.append(inst)
else:
logical_instruction_blocks[jump_offset].append(inst)
control_flow_blocks[jump_offset].append(inst)

# convert each logical block to a polars expression string
# convert each block to a polars expression string
expression_strings = {
offset: InstructionTranslator(
instructions=ops,
Expand All @@ -222,14 +208,54 @@ def to_expression(self, col: str) -> str | None:
param_name=self._param_name,
depth=int(bool(logical_instructions)),
)
for offset, ops in logical_instruction_blocks.items()
for offset, ops in control_flow_blocks.items()
}

for inst in logical_instructions:
expression_strings[inst.offset] = InstructionTranslator.op(inst)

# TODO: handle mixed 'and'/'or' blocks (e.g. `x > 0 AND (x > 0 OR (x == -1))`).
# (need to reconstruct the correct nesting boundaries to do this properly)
if logical_instructions:
# reconstruct nesting boundaries for mixed and/or ops by associating
# control flow jump offsets with their target expression blocks and
# injecting appropriate parentheses
combined_offset_idxs = set()
if len({inst.opname for inst in logical_instructions}) > 1:
block_offsets: list[int] = list(expression_strings.keys())
idx_min, idx_max = 0, len(block_offsets) - 1
previous_logical_opname = ""
for i, inst in enumerate(logical_instructions):
# operator precedence means that we can combine logically connected
# 'and' blocks into one (depending on follow-on logic) and should
# parenthesise nested 'or' blocks
logical_op = OpNames.CONTROL_FLOW[inst.opname]
start = block_offsets[
max(idx_min, bisect_left(block_offsets, inst.offset) - 1)
alexander-beedie marked this conversation as resolved.
Show resolved Hide resolved
]
if previous_logical_opname == (
"POP_JUMP_FORWARD_IF_FALSE"
if _MIN_PY311
else "POP_JUMP_IF_FALSE"
):
prev = block_offsets[
max(idx_min, bisect_left(block_offsets, start) - 1)
]
expression_strings[
prev
] += f" & {expression_strings.pop(start)}"
block_offsets = list(expression_strings.keys())
combined_offset_idxs.add(i - 1)
start = prev

if logical_op == "|":
end = block_offsets[
min(idx_max, bisect_left(block_offsets, inst.argval) - 1)
]
if not (start == 0 and end == block_offsets[-1]):
expression_strings[start] = "(" + expression_strings[start]
expression_strings[end] += ")"

previous_logical_opname = inst.opname

for i, inst in enumerate(logical_instructions):
if i not in combined_offset_idxs:
expression_strings[inst.offset] = OpNames.CONTROL_FLOW[inst.opname]

exprs = sorted(expression_strings.items())
polars_expr = " ".join(expr for _, expr in exprs)
Expand All @@ -239,7 +265,7 @@ def to_expression(self, col: str) -> str | None:
if "pl.col(" not in polars_expr:
return None

return polars_expr
return polars_expr if as_repr else eval(polars_expr, globals())

def warn(
self,
Expand Down Expand Up @@ -300,16 +326,16 @@ def to_expression(self, col: str, param_name: str, depth: int) -> str:
@classmethod
def op(cls, inst: Instruction) -> str:
"""Convert bytecode instruction to suitable intermediate op string."""
if inst.opname in _CONTROL_FLOW_OPCODES:
return _CONTROL_FLOW_OPCODES[inst.opname]
if inst.opname in OpNames.CONTROL_FLOW:
return OpNames.CONTROL_FLOW[inst.opname]
elif inst.argrepr:
return inst.argrepr
elif inst.opname == "IS_OP":
return "is not" if inst.argval else "is"
elif inst.opname == "CONTAINS_OP":
return "not in" if inst.argval else "in"
elif inst.opname in _UNARY_OPCODES:
return _UNARY_OPCODES[inst.opname]
elif inst.opname in OpNames.UNARY:
return OpNames.UNARY[inst.opname]
else:
raise AssertionError(
"Unrecognised opname; please report a bug to https://github.com/pola-rs/polars/issues "
Expand All @@ -324,7 +350,7 @@ def _expr(cls, value: StackEntry, col: str, param_name: str, depth: int) -> str:
op = value.operator
e1 = cls._expr(value.left_operand, col, param_name, depth + 1)
if value.operator_arity == 1:
if op not in _UNARY_OPCODE_VALUES:
if op not in OpNames.UNARY_VALUES:
call = "" if op.endswith(")") else "()"
return f"{e1}.{op}{call}"
return f"{op}{e1}"
Expand Down Expand Up @@ -358,7 +384,7 @@ def _to_intermediate_stack(
for inst in instructions:
stack.append(
inst.argrepr
if inst.opname in _LOAD_OPS
if inst.opname in OpNames.LOAD
else (
StackValue(
operator=self.op(inst),
Expand All @@ -367,8 +393,8 @@ def _to_intermediate_stack(
right_operand=None, # type: ignore[arg-type]
)
if (
inst.opname in _UNARY_OPCODES
or _SYNTHETIC_OPS.get(inst.opname) == 1
inst.opname in OpNames.UNARY
or OpNames.SYNTHETIC.get(inst.opname) == 1
)
else StackValue(
operator=self.op(inst),
Expand Down Expand Up @@ -418,7 +444,7 @@ def _matches(
idx: int,
*,
opnames: list[str],
argvals: list[set[Any] | dict[Any, Any]] | None,
argvals: list[set[Any] | frozenset[Any] | dict[Any, Any]] | None,
) -> list[Instruction]:
"""
Check if a sequence of Instructions matches the specified ops/argvals.
Expand Down Expand Up @@ -461,7 +487,7 @@ def _rewrite(self, instructions: Iterator[Instruction]) -> list[Instruction]:
idx = 0
while idx < len(self._instructions):
inst, increment = self._instructions[idx], 1
if inst.opname not in _LOAD_OPS or not any(
if inst.opname not in OpNames.LOAD or not any(
(increment := apply_rewrite(idx, updated_instructions))
for apply_rewrite in (
# add any other rewrite methods here
Expand All @@ -480,7 +506,7 @@ def _rewrite_builtins(
"""Replace builtin function calls with a synthetic POLARS_EXPRESSION op."""
if matching_instructions := self._matches(
idx,
opnames=["LOAD_GLOBAL", "LOAD_FAST", _CALL_OP],
opnames=["LOAD_GLOBAL", "LOAD_FAST", OpNames.CALL],
argvals=[_PYTHON_CASTS_MAP],
):
inst1, inst2 = matching_instructions[:2]
Expand All @@ -503,8 +529,11 @@ def _rewrite_functions(
"""Replace numpy/json function calls with a synthetic POLARS_EXPRESSION op."""
if matching_instructions := self._matches(
idx,
opnames=["LOAD_GLOBAL", "LOAD_*", "LOAD_*", _CALL_OP],
argvals=[_NUMPY_MODULE_ALIASES | {"json"}, _NUMPY_FUNCTIONS | {"loads"}],
opnames=["LOAD_GLOBAL", "LOAD_*", "LOAD_*", OpNames.CALL],
argvals=[
_NUMPY_MODULE_ALIASES | {"json"},
_NUMPY_FUNCTIONS | {"loads"},
],
):
inst1, inst2, inst3 = matching_instructions[:3]
expr_name = "str.json_extract" if inst1.argval == "json" else inst2.argval
Expand All @@ -526,11 +555,11 @@ def _rewrite_methods(
"""Replace python method calls with synthetic POLARS_EXPRESSION op."""
if matching_instructions := self._matches(
idx,
opnames=["LOAD_METHOD", _CALL_OP],
argvals=[_PYTHON_METHOD_MAP],
opnames=["LOAD_METHOD", OpNames.CALL],
argvals=[_PYTHON_METHODS_MAP],
):
inst = matching_instructions[0]
expr_name = _PYTHON_METHOD_MAP[inst.argval]
expr_name = _PYTHON_METHODS_MAP[inst.argval]
synthetic_call = inst._replace(
opname="POLARS_EXPRESSION", argval=expr_name, argrepr=expr_name
)
Expand All @@ -541,9 +570,9 @@ def _rewrite_methods(
@staticmethod
def _upgrade_instruction(inst: Instruction) -> Instruction:
"""Rewrite any older binary opcodes using py 3.11 'BINARY_OP' instead."""
if _UPGRADE_BINARY_OPS and inst.opname in _BINARY_OPCODES:
if not _MIN_PY311 and inst.opname in OpNames.BINARY:
inst = inst._replace(
argrepr=_BINARY_OPCODES[inst.opname],
argrepr=OpNames.BINARY[inst.opname],
opname="BINARY_OP",
)
return inst
Expand Down
1 change: 1 addition & 0 deletions py-polars/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ strict = true
"polars/datatypes.py" = ["B019"]
"tests/**/*.py" = ["D100", "D103", "B018"]
"polars/utils/show_versions.py" = ["D301"]
"polars/utils/udfs.py" = ["RUF012"]

[tool.pytest.ini_options]
addopts = [
Expand Down
Loading