Skip to content

Commit

Permalink
chore(python): clean up bytecode parsing a bit (#13221)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Dec 27, 2023
1 parent 89c549f commit 4304ffe
Showing 1 changed file with 40 additions and 55 deletions.
95 changes: 40 additions & 55 deletions py-polars/polars/utils/udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,6 @@ def _get_all_caller_variables() -> dict[str, Any]:
class BytecodeParser:
"""Introspect UDF bytecode and determine if we can rewrite as native expression."""

_can_attempt_rewrite: dict[str, bool]
_map_target_name: str | None = None

def __init__(self, function: Callable[[Any], Any], map_target: MapTarget):
Expand All @@ -237,7 +236,6 @@ def __init__(self, function: Callable[[Any], Any], map_target: MapTarget):
# unavailable, like a bare numpy ufunc that isn't in a lambda/function)
original_instructions = iter([])

self._can_attempt_rewrite = {}
self._function = function
self._map_target = map_target
self._param_name = self._get_param_name(function)
Expand All @@ -251,13 +249,13 @@ def _get_param_name(function: Callable[[Any], Any]) -> str | None:
try:
# note: we do not parse/handle functions with > 1 params
sig = signature(function)
return (
next(iter(parameters.keys()))
if len(parameters := sig.parameters) == 1
else None
)
except ValueError:
return None
return (
next(iter(parameters.keys()))
if len(parameters := sig.parameters) == 1
else None
)

def _inject_nesting(
self,
Expand Down Expand Up @@ -323,30 +321,22 @@ def can_attempt_rewrite(self) -> bool:
guaranteed that using the equivalent bare constant value will return the
same output. (Hopefully nobody is writing lambdas like that anyway...)
"""
if (
can_attempt_rewrite := self._can_attempt_rewrite.get(self._map_target, None)
) is not None:
return can_attempt_rewrite
else:
self._can_attempt_rewrite[self._map_target] = False
if self._rewritten_instructions and self._param_name is not None:
self._can_attempt_rewrite[self._map_target] = (
# check minimum number of ops, ensuring all are parseable
len(self._rewritten_instructions) >= 2
and all(
inst.opname in OpNames.PARSEABLE_OPS
for inst in self._rewritten_instructions
)
# exclude constructs/functions with multiple RETURN_VALUE ops
and sum(
1
for inst in self.original_instructions
if inst.opname == "RETURN_VALUE"
)
== 1
)

return self._can_attempt_rewrite[self._map_target]
return (
self._param_name is not None
# check minimum number of ops, ensuring all are parseable
and len(self._rewritten_instructions) >= 2
and all(
inst.opname in OpNames.PARSEABLE_OPS
for inst in self._rewritten_instructions
)
# exclude constructs/functions with multiple RETURN_VALUE ops
and sum(
1
for inst in self.original_instructions
if inst.opname == "RETURN_VALUE"
)
== 1
)

def dis(self) -> None:
"""Print disassembled function bytecode."""
Expand Down Expand Up @@ -375,7 +365,7 @@ def rewritten_instructions(self) -> list[Instruction]:
def to_expression(self, col: str) -> str | None:
"""Translate postfix bytecode instructions to polars expression/string."""
self._map_target_name = None
if not self.can_attempt_rewrite() or self._param_name is None:
if self._param_name is None:
return None

# decompose bytecode into logical 'and'/'or' expression blocks (if present)
Expand All @@ -390,8 +380,8 @@ def to_expression(self, col: str) -> str | None:
control_flow_blocks[jump_offset].append(inst)

# convert each block to a polars expression string
caller_variables: dict[str, Any] = {}
try:
caller_variables: dict[str, Any] = {}
expression_strings = self._inject_nesting(
{
offset: InstructionTranslator(
Expand All @@ -407,10 +397,9 @@ def to_expression(self, col: str) -> str | None:
},
logical_instructions,
)
polars_expr = " ".join(expr for _offset, expr in expression_strings)
except NotImplementedError:
self._can_attempt_rewrite[self._map_target] = False
return None
polars_expr = " ".join(expr for _offset, expr in expression_strings)

# note: if no 'pl.col' in the expression, it likely represents a compound
# constant value (e.g. `lambda x: CONST + 123`), so we don't want to warn
Expand Down Expand Up @@ -826,26 +815,22 @@ def _is_stdlib_datetime(

def _is_raw_function(function: Callable[[Any], Any]) -> tuple[str, str]:
    """
    Identify translatable calls that aren't wrapped inside a lambda/function.

    Parameters
    ----------
    function
        The callable to inspect (for example a bare numpy function, a builtin
        such as ``int``, or ``json.loads``).

    Returns
    -------
    tuple[str, str]
        A ``(prefix, suffix)`` pair: ``("np", "<name>()")`` for a recognised
        numpy function, ``("builtins", "cast(pl.<dtype>)")`` for a recognised
        python cast, ``("json", "str.json_decode()")`` for ``json.loads``, or
        ``("", "")`` when the callable is not recognised.
    """
    try:
        func_module = function.__class__.__module__
        func_name = function.__name__
    except AttributeError:
        # Not every callable exposes `__name__` (e.g. `functools.partial`
        # objects, or instances that define `__call__`); such callables are
        # never directly translatable, so report "no match" instead of raising.
        return "", ""

    # numpy function calls
    if func_module == "numpy" and func_name in _NUMPY_FUNCTIONS:
        return "np", f"{func_name}()"

    # python function calls
    if func_module == "builtins":
        if func_name in _PYTHON_CASTS_MAP:
            return "builtins", f"cast(pl.{_PYTHON_CASTS_MAP[func_name]})"
        if func_name == "loads":
            import json  # double-check since it is referenced via 'builtins'

            if function is json.loads:
                return "json", "str.json_decode()"

    return "", ""

Expand Down

0 comments on commit 4304ffe

Please sign in to comment.