Skip to content

Commit

Permalink
feat(python): detect and warn about usage of json.loads in conjunct…
Browse files Browse the repository at this point in the history
…ion with `apply`
  • Loading branch information
alexander-beedie committed Jul 21, 2023
1 parent bdc28d9 commit e8410f2
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 26 deletions.
56 changes: 34 additions & 22 deletions py-polars/polars/utils/udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class StackValue(NamedTuple):
"lower": "str.to_lowercase",
"title": "str.to_titlecase",
"upper": "str.to_uppercase",
"loads": "str.json_extract",
}


Expand Down Expand Up @@ -403,28 +404,32 @@ def _apply_rules(self, instructions: Iterator[Instruction]) -> list[Instruction]
apply_rewrite(inst, updated_instructions)
for apply_rewrite in (
# add any other rewrite methods here
self._numpy_functions,
self._python_functions,
self._functions,
self._methods,
)
):
updated_instructions.append(inst)
return updated_instructions

def _numpy_functions(
self, inst: Instruction, instructions: list[Instruction]
) -> bool:
"""Replace numpy function calls with a synthetic POLARS_EXPRESSION op."""
if inst.opname == "LOAD_GLOBAL" and inst.argval in _NUMPY_MODULE_ALIASES:
def _functions(self, inst: Instruction, instructions: list[Instruction]) -> bool:
"""Replace numpy/json function calls with a synthetic POLARS_EXPRESSION op."""
if inst.opname == "LOAD_GLOBAL" and (
inst.argval in _NUMPY_MODULE_ALIASES or inst.argval == "json"
):
instruction_buffer = list(islice(self._instructions, 3))
if (
len(instruction_buffer) == 3
and instruction_buffer[0].argval in _NUMPY_FUNCTIONS
and (
instruction_buffer[0].argval == "loads"
or instruction_buffer[0].argval in _NUMPY_FUNCTIONS
)
and instruction_buffer[1].opname.startswith("LOAD_")
and instruction_buffer[2].opname.startswith("CALL")
):
# note: synthetic POLARS_EXPRESSION is mapped as a unary
# op, so we switch the instruction order on injection
expr_name = instruction_buffer[0].argval
expr_name = _PYFUNCTION_MAP.get(expr_name, expr_name)
offsets = inst.offset, instruction_buffer[1].offset
synthetic_call = inst._replace(
opname="POLARS_EXPRESSION",
Expand All @@ -440,9 +445,7 @@ def _numpy_functions(
return True
return False

def _python_functions(
self, inst: Instruction, instructions: list[Instruction]
) -> bool:
def _methods(self, inst: Instruction, instructions: list[Instruction]) -> bool:
"""Replace python method calls with synthetic POLARS_EXPRESSION op."""
if inst.opname == "LOAD_METHOD" and inst.argval in _PYFUNCTION_MAP:
if (
Expand Down Expand Up @@ -472,15 +475,18 @@ def _upgrade_instruction(inst: Instruction) -> Instruction:
return inst


def _is_raw_numpy_function(function: Callable[[Any], Any]) -> bool:
"""Identify numpy calls that are not wrapped in a lambda/function."""
def _is_raw_function(function: Callable[[Any], Any]) -> tuple[str, str]:
"""Identify translatable calls that aren't wrapped inside a lambda/function."""
try:
return (
function.__class__.__module__ == "numpy"
and function.__name__ in _NUMPY_FUNCTIONS
)
func_module = function.__class__.__module__
func_name = function.__name__
if func_module == "numpy" and func_name in _NUMPY_FUNCTIONS:
return "np", func_name
elif func_module in ("builtins", "json") and func_name == "loads":
return "json", "str.json_extract"
except AttributeError:
return False
pass
return "", ""


def warn_on_inefficient_apply(
Expand Down Expand Up @@ -512,10 +518,16 @@ def warn_on_inefficient_apply(
parser = BytecodeParser(function, apply_target)
if parser.can_rewrite():
parser.warn(col)
elif _is_raw_numpy_function(function):
fn = function.__name__
suggestion = f'pl.col("{col}").{fn}()'
parser.warn(col, suggestion_override=suggestion, func_name_override=f"np.{fn}")
else:
# handle bare numpy/json functions
module, func_name = _is_raw_function(function)
if module and func_name:
fn = function.__name__
parser.warn(
col,
suggestion_override=f'pl.col("{col}").{func_name}()',
func_name_override=f"{module}.{fn}",
)


__all__ = [
Expand Down
32 changes: 28 additions & 4 deletions py-polars/tests/unit/operations/test_inefficient_apply.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import json
from typing import Any, Callable

import numpy
Expand Down Expand Up @@ -31,7 +32,7 @@ def test_parse_invalid_function(func: Callable[[Any], Any]) -> None:
@pytest.mark.parametrize(
("col", "func"),
[
# numeric cols: math, comparison, logic ops
# numeric col: math, comparison, logic ops
("a", lambda x: x + 1 - (2 / 3)),
("a", lambda x: x // 1 % 2),
("a", lambda x: x & True),
Expand All @@ -51,9 +52,11 @@ def test_parse_invalid_function(func: Callable[[Any], Any]) -> None:
("a", lambda x: MY_CONSTANT + x),
("a", lambda x: 0 + numpy.cbrt(x)),
("a", lambda x: np.sin(x) + 1),
# string cols
# string col: case ops
("b", lambda x: x.title()),
("b", lambda x: x.lower() + ":" + x.upper()),
# json col: load/extract
("c", lambda x: json.loads(x)),
],
)
def test_parse_apply_functions(col: str, func: Callable[[Any], Any]) -> None:
Expand All @@ -68,6 +71,7 @@ def test_parse_apply_functions(col: str, func: Callable[[Any], Any]) -> None:
{
"a": [1, 2, 3],
"b": ["AB", "cd", "eF"],
"c": ['{"a": 1}', '{"b": 2}', '{"c": 3}'],
}
)
result = df.select(
Expand All @@ -81,9 +85,10 @@ def test_parse_apply_functions(col: str, func: Callable[[Any], Any]) -> None:
assert_frame_equal(result, expected)


def test_parse_apply_numpy_raw() -> None:
def test_parse_apply_raw_functions() -> None:
lf = pl.LazyFrame({"a": [1, 2, 3]})

# test bare numpy functions
for func_name in _NUMPY_FUNCTIONS:
func = getattr(numpy, func_name)

Expand All @@ -93,12 +98,31 @@ def test_parse_apply_numpy_raw() -> None:

# ...but we ARE still able to warn
with pytest.warns(
PolarsInefficientApplyWarning, match="In this case, you can replace"
PolarsInefficientApplyWarning,
match=rf"(?s)In this case, you can replace.*np\.{func_name}",
):
df1 = lf.select(pl.col("a").apply(func)).collect()
df2 = lf.select(getattr(pl.col("a"), func_name)()).collect()
assert_frame_equal(df1, df2)

# test bare json.loads
result_frames = []
with pytest.warns(
PolarsInefficientApplyWarning,
match=r"(?s)In this case, you can replace.*\.str\.json_extract",
):
for expr in (
pl.col("value").str.json_extract(),
pl.col("value").apply(json.loads),
):
result_frames.append(
pl.LazyFrame({"value": ['{"a":1, "b": true, "c": "xx"}', None]})
.select(extracted=expr)
.unnest("extracted")
.collect()
)
assert_frame_equal(*result_frames)


def test_parse_apply_miscellaneous() -> None:
# note: can also identify inefficient functions and methods as well as lambdas
Expand Down

0 comments on commit e8410f2

Please sign in to comment.