Skip to content

Commit

Permalink
[mypyc] Add lowered primitive for unsafe list get item op (#18136)
Browse files Browse the repository at this point in the history
This inlines the list get item op in loops like `for x in <list>`.

I estimated the impact using two microbenchmarks that iterate over
`list[int]` objects. One of them was 1.3x faster, while the other was
1.09x faster.

Since we now generate detailed IR for the op, instead of using a C
primitive function, this also opens up further IR optimization
opportunities in the future.
  • Loading branch information
JukkaL authored Dec 10, 2024
1 parent d920e6c commit 568648d
Show file tree
Hide file tree
Showing 12 changed files with 94 additions and 43 deletions.
12 changes: 9 additions & 3 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,8 +399,14 @@ def load_module(self, name: str) -> Value:
def call_c(self, desc: CFunctionDescription, args: list[Value], line: int) -> Value:
return self.builder.call_c(desc, args, line)

def primitive_op(self, desc: PrimitiveDescription, args: list[Value], line: int) -> Value:
return self.builder.primitive_op(desc, args, line)
def primitive_op(
self,
desc: PrimitiveDescription,
args: list[Value],
line: int,
result_type: RType | None = None,
) -> Value:
return self.builder.primitive_op(desc, args, line, result_type)

def int_op(self, type: RType, lhs: Value, rhs: Value, op: int, line: int) -> Value:
return self.builder.int_op(type, lhs, rhs, op, line)
Expand Down Expand Up @@ -760,7 +766,7 @@ def process_sequence_assignment(
item = target.items[i]
index = self.builder.load_int(i)
if is_list_rprimitive(rvalue.type):
item_value = self.call_c(list_get_item_unsafe_op, [rvalue, index], line)
item_value = self.primitive_op(list_get_item_unsafe_op, [rvalue, index], line)
else:
item_value = self.builder.gen_method_call(
rvalue, "__getitem__", [index], item.type, line
Expand Down
2 changes: 1 addition & 1 deletion mypyc/irbuild/for_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ def unsafe_index(builder: IRBuilder, target: Value, index: Value, line: int) ->
# since we want to use __getitem__ if we don't have an unsafe version,
# so we just check manually.
if is_list_rprimitive(target.type):
return builder.call_c(list_get_item_unsafe_op, [target, index], line)
return builder.primitive_op(list_get_item_unsafe_op, [target, index], line)
else:
return builder.gen_method_call(target, "__getitem__", [index], None, line)

Expand Down
8 changes: 5 additions & 3 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,10 +509,12 @@ def coerce_int_to_fixed_width(self, src: Value, target_type: RType, line: int) -
return res

def coerce_short_int_to_fixed_width(self, src: Value, target_type: RType, line: int) -> Value:
if is_int64_rprimitive(target_type):
if is_int64_rprimitive(target_type) or (
PLATFORM_SIZE == 4 and is_int32_rprimitive(target_type)
):
return self.int_op(target_type, src, Integer(1, target_type), IntOp.RIGHT_SHIFT, line)
# TODO: i32
assert False, (src.type, target_type)
# TODO: i32 on 64-bit platform
assert False, (src.type, target_type, PLATFORM_SIZE)

def coerce_fixed_width_to_int(self, src: Value, line: int) -> Value:
if (
Expand Down
30 changes: 29 additions & 1 deletion mypyc/lower/list_ops.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from mypyc.common import PLATFORM_SIZE
from mypyc.ir.ops import GetElementPtr, Integer, IntOp, LoadMem, SetMem, Value
from mypyc.ir.ops import GetElementPtr, IncRef, Integer, IntOp, LoadMem, SetMem, Value
from mypyc.ir.rtypes import (
PyListObject,
c_pyssize_t_rprimitive,
Expand Down Expand Up @@ -43,3 +43,31 @@ def buf_init_item(builder: LowLevelIRBuilder, args: list[Value], line: int) -> V
def list_items(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
    """Load the pointer stored in the "ob_item" field of a list object.

    args[0] is the list object. The result is the content of its "ob_item"
    field, loaded as a pointer_rprimitive value.
    """
    list_obj = args[0]
    # Compute the address of the ob_item field first, then read the pointer held there.
    field_addr = builder.add(GetElementPtr(list_obj, PyListObject, "ob_item", line))
    items_load = LoadMem(pointer_rprimitive, field_addr, line)
    return builder.add(items_load)


def list_item_ptr(builder: LowLevelIRBuilder, obj: Value, index: Value, line: int) -> Value:
    """Return a pointer to the slot that holds obj[index].

    The index must be valid and non-negative, its type must be
    c_pyssize_t_rprimitive, and obj must refer to a list object.
    """
    # A list keeps its items as an array of pointers, so the slot address is
    # <pointer to first item> + index * <pointer size>.
    base = list_items(builder, [obj], line)
    pointer_size = Integer(PLATFORM_SIZE, c_pyssize_t_rprimitive)
    byte_offset = builder.add(IntOp(c_pyssize_t_rprimitive, index, pointer_size, IntOp.MUL))
    item_address = IntOp(pointer_rprimitive, base, byte_offset, IntOp.ADD)
    return builder.add(item_address)


@lower_primitive_op("list_get_item_unsafe")
def list_get_item_unsafe(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
    """Lower "list_get_item_unsafe" to a direct load from the list's item array.

    args[0] is the list and args[1] is the index. No bounds check is emitted
    here; the item is read straight from the slot computed by list_item_ptr.
    """
    # list_item_ptr requires a c_pyssize_t index, so coerce the index first.
    ssize_index = builder.coerce(args[1], c_pyssize_t_rprimitive, line)
    slot = list_item_ptr(builder, args[0], ssize_index, line)
    item = builder.add(LoadMem(object_rprimitive, slot, line))
    # Explicitly increment the reference count of the loaded item before returning it.
    builder.add(IncRef(item))
    return item
4 changes: 2 additions & 2 deletions mypyc/primitives/list_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,10 @@

# This is unsafe because it assumes that the index is a non-negative short integer
# that is in-bounds for the list.
list_get_item_unsafe_op = custom_op(
list_get_item_unsafe_op = custom_primitive_op(
name="list_get_item_unsafe",
arg_types=[list_rprimitive, short_int_rprimitive],
return_type=object_rprimitive,
c_function_name="CPyList_GetItemUnsafe",
error_kind=ERR_NEVER,
)

Expand Down
8 changes: 4 additions & 4 deletions mypyc/test-data/irbuild-basic.test
Original file line number Diff line number Diff line change
Expand Up @@ -1874,7 +1874,7 @@ L1:
r9 = int_lt r6, r8
if r9 goto L2 else goto L8 :: bool
L2:
r10 = CPyList_GetItemUnsafe(r1, r6)
r10 = list_get_item_unsafe r1, r6
r11 = unbox(int, r10)
x = r11
r12 = int_ne x, 4
Expand Down Expand Up @@ -1938,7 +1938,7 @@ L1:
r9 = int_lt r6, r8
if r9 goto L2 else goto L8 :: bool
L2:
r10 = CPyList_GetItemUnsafe(r1, r6)
r10 = list_get_item_unsafe r1, r6
r11 = unbox(int, r10)
x = r11
r12 = int_ne x, 4
Expand Down Expand Up @@ -2000,7 +2000,7 @@ L1:
r3 = int_lt r0, r2
if r3 goto L2 else goto L4 :: bool
L2:
r4 = CPyList_GetItemUnsafe(l, r0)
r4 = list_get_item_unsafe l, r0
r5 = unbox(tuple[int, int, int], r4)
r6 = r5[0]
x = r6
Expand All @@ -2022,7 +2022,7 @@ L5:
r15 = int_lt r12, r14
if r15 goto L6 else goto L8 :: bool
L6:
r16 = CPyList_GetItemUnsafe(l, r12)
r16 = list_get_item_unsafe l, r12
r17 = unbox(tuple[int, int, int], r16)
r18 = r17[0]
x_2 = r18
Expand Down
10 changes: 5 additions & 5 deletions mypyc/test-data/irbuild-lists.test
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ L1:
r5 = int_lt r2, r4
if r5 goto L2 else goto L4 :: bool
L2:
r6 = CPyList_GetItemUnsafe(source, r2)
r6 = list_get_item_unsafe source, r2
r7 = unbox(int, r6)
x = r7
r8 = CPyTagged_Add(x, 2)
Expand All @@ -362,7 +362,7 @@ L5:
r17 = int_lt r14, r16
if r17 goto L6 else goto L8 :: bool
L6:
r18 = CPyList_GetItemUnsafe(source, r14)
r18 = list_get_item_unsafe source, r14
r19 = unbox(int, r18)
x_2 = r19
r20 = CPyTagged_Add(x_2, 2)
Expand Down Expand Up @@ -403,7 +403,7 @@ L1:
r3 = int_lt r0, r2
if r3 goto L2 else goto L4 :: bool
L2:
r4 = CPyList_GetItemUnsafe(x, r0)
r4 = list_get_item_unsafe x, r0
r5 = unbox(int, r4)
i = r5
r6 = box(int, i)
Expand Down Expand Up @@ -476,7 +476,7 @@ L1:
r3 = int_lt r0, r2
if r3 goto L2 else goto L4 :: bool
L2:
r4 = CPyList_GetItemUnsafe(a, r0)
r4 = list_get_item_unsafe a, r0
r5 = cast(union[str, bytes], r4)
x = r5
L3:
Expand All @@ -502,7 +502,7 @@ L1:
r3 = int_lt r0, r2
if r3 goto L2 else goto L4 :: bool
L2:
r4 = CPyList_GetItemUnsafe(a, r0)
r4 = list_get_item_unsafe a, r0
r5 = cast(union[str, None], r4)
x = r5
L3:
Expand Down
4 changes: 2 additions & 2 deletions mypyc/test-data/irbuild-set.test
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ L1:
r9 = int_lt r6, r8
if r9 goto L2 else goto L4 :: bool
L2:
r10 = CPyList_GetItemUnsafe(tmp_list, r6)
r10 = list_get_item_unsafe tmp_list, r6
r11 = unbox(int, r10)
x = r11
r12 = f(x)
Expand Down Expand Up @@ -361,7 +361,7 @@ L1:
r13 = int_lt r10, r12
if r13 goto L2 else goto L6 :: bool
L2:
r14 = CPyList_GetItemUnsafe(tmp_list, r10)
r14 = list_get_item_unsafe tmp_list, r10
r15 = unbox(int, r14)
z = r15
r16 = int_lt z, 8
Expand Down
12 changes: 6 additions & 6 deletions mypyc/test-data/irbuild-statements.test
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ L1:
r3 = int_lt r0, r2
if r3 goto L2 else goto L4 :: bool
L2:
r4 = CPyList_GetItemUnsafe(ls, r0)
r4 = list_get_item_unsafe ls, r0
r5 = unbox(int, r4)
x = r5
r6 = CPyTagged_Add(y, x)
Expand Down Expand Up @@ -594,8 +594,8 @@ def f(l, t):
L0:
r0 = CPySequence_CheckUnpackCount(l, 2)
r1 = r0 >= 0 :: signed
r2 = CPyList_GetItemUnsafe(l, 0)
r3 = CPyList_GetItemUnsafe(l, 2)
r2 = list_get_item_unsafe l, 0
r3 = list_get_item_unsafe l, 2
x = r2
r4 = unbox(int, r3)
y = r4
Expand Down Expand Up @@ -882,7 +882,7 @@ L1:
if r4 goto L2 else goto L4 :: bool
L2:
i = r0
r5 = CPyList_GetItemUnsafe(a, r1)
r5 = list_get_item_unsafe a, r1
r6 = unbox(int, r5)
x = r6
r7 = CPyTagged_Add(i, x)
Expand Down Expand Up @@ -961,7 +961,7 @@ L2:
r5 = PyIter_Next(r1)
if is_error(r5) goto L7 else goto L3
L3:
r6 = CPyList_GetItemUnsafe(a, r0)
r6 = list_get_item_unsafe a, r0
r7 = unbox(int, r6)
x = r7
r8 = unbox(bool, r5)
Expand Down Expand Up @@ -1015,7 +1015,7 @@ L3:
L4:
r8 = unbox(bool, r3)
x = r8
r9 = CPyList_GetItemUnsafe(b, r1)
r9 = list_get_item_unsafe b, r1
r10 = unbox(int, r9)
y = r10
x = 0
Expand Down
2 changes: 1 addition & 1 deletion mypyc/test-data/irbuild-tuple.test
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ L1:
r10 = int_lt r7, r9
if r10 goto L2 else goto L4 :: bool
L2:
r11 = CPyList_GetItemUnsafe(source, r7)
r11 = list_get_item_unsafe source, r7
r12 = unbox(int, r11)
x = r12
r13 = f(x)
Expand Down
38 changes: 24 additions & 14 deletions mypyc/test-data/lowering-int.test
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ L4:
L5:
return 4

[case testLowerIntForLoop]
[case testLowerIntForLoop_64bit]
from __future__ import annotations

def f(l: list[int]) -> None:
Expand All @@ -346,10 +346,14 @@ def f(l):
r2 :: native_int
r3 :: short_int
r4 :: bit
r5 :: object
r6, x :: int
r7 :: short_int
r8 :: None
r5 :: native_int
r6, r7 :: ptr
r8 :: native_int
r9 :: ptr
r10 :: object
r11, x :: int
r12 :: short_int
r13 :: None
L0:
r0 = 0
L1:
Expand All @@ -359,19 +363,25 @@ L1:
r4 = r0 < r3 :: signed
if r4 goto L2 else goto L5 :: bool
L2:
r5 = CPyList_GetItemUnsafe(l, r0)
r6 = unbox(int, r5)
dec_ref r5
if is_error(r6) goto L6 (error at f:4) else goto L3
r5 = r0 >> 1
r6 = get_element_ptr l ob_item :: PyListObject
r7 = load_mem r6 :: ptr*
r8 = r5 * 8
r9 = r7 + r8
r10 = load_mem r9 :: builtins.object*
inc_ref r10
r11 = unbox(int, r10)
dec_ref r10
if is_error(r11) goto L6 (error at f:4) else goto L3
L3:
x = r6
x = r11
dec_ref x :: int
L4:
r7 = r0 + 2
r0 = r7
r12 = r0 + 2
r0 = r12
goto L1
L5:
return 1
L6:
r8 = <error> :: None
return r8
r13 = <error> :: None
return r13
7 changes: 6 additions & 1 deletion mypyc/test/test_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
MypycDataSuite,
assert_test_output,
build_ir_for_single_file,
infer_ir_build_options_from_test_name,
remove_comment_lines,
replace_word_size,
use_custom_builtins,
Expand All @@ -31,11 +32,15 @@ class TestLowering(MypycDataSuite):
base_path = test_temp_dir

def run_case(self, testcase: DataDrivenTestCase) -> None:
options = infer_ir_build_options_from_test_name(testcase.name)
if options is None:
# Skipped test case
return
with use_custom_builtins(os.path.join(self.data_prefix, ICODE_GEN_BUILTINS), testcase):
expected_output = remove_comment_lines(testcase.output)
expected_output = replace_word_size(expected_output)
try:
ir = build_ir_for_single_file(testcase.input)
ir = build_ir_for_single_file(testcase.input, options)
except CompileError as e:
actual = e.messages
else:
Expand Down

0 comments on commit 568648d

Please sign in to comment.