Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mypyc] Add lowered primitive for unsafe list get item op #18136

Merged
merged 6 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,8 +399,14 @@ def load_module(self, name: str) -> Value:
def call_c(self, desc: CFunctionDescription, args: list[Value], line: int) -> Value:
return self.builder.call_c(desc, args, line)

def primitive_op(self, desc: PrimitiveDescription, args: list[Value], line: int) -> Value:
return self.builder.primitive_op(desc, args, line)
def primitive_op(
self,
desc: PrimitiveDescription,
args: list[Value],
line: int,
result_type: RType | None = None,
) -> Value:
return self.builder.primitive_op(desc, args, line, result_type)

def int_op(self, type: RType, lhs: Value, rhs: Value, op: int, line: int) -> Value:
return self.builder.int_op(type, lhs, rhs, op, line)
Expand Down Expand Up @@ -760,7 +766,7 @@ def process_sequence_assignment(
item = target.items[i]
index = self.builder.load_int(i)
if is_list_rprimitive(rvalue.type):
item_value = self.call_c(list_get_item_unsafe_op, [rvalue, index], line)
item_value = self.primitive_op(list_get_item_unsafe_op, [rvalue, index], line)
else:
item_value = self.builder.gen_method_call(
rvalue, "__getitem__", [index], item.type, line
Expand Down
2 changes: 1 addition & 1 deletion mypyc/irbuild/for_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ def unsafe_index(builder: IRBuilder, target: Value, index: Value, line: int) ->
# since we want to use __getitem__ if we don't have an unsafe version,
# so we just check manually.
if is_list_rprimitive(target.type):
return builder.call_c(list_get_item_unsafe_op, [target, index], line)
return builder.primitive_op(list_get_item_unsafe_op, [target, index], line)
else:
return builder.gen_method_call(target, "__getitem__", [index], None, line)

Expand Down
8 changes: 5 additions & 3 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,10 +509,12 @@ def coerce_int_to_fixed_width(self, src: Value, target_type: RType, line: int) -
return res

def coerce_short_int_to_fixed_width(self, src: Value, target_type: RType, line: int) -> Value:
if is_int64_rprimitive(target_type):
if is_int64_rprimitive(target_type) or (
PLATFORM_SIZE == 4 and is_int32_rprimitive(target_type)
):
return self.int_op(target_type, src, Integer(1, target_type), IntOp.RIGHT_SHIFT, line)
# TODO: i32
assert False, (src.type, target_type)
# TODO: i32 on 64-bit platform
assert False, (src.type, target_type, PLATFORM_SIZE)

def coerce_fixed_width_to_int(self, src: Value, line: int) -> Value:
if (
Expand Down
30 changes: 29 additions & 1 deletion mypyc/lower/list_ops.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from mypyc.common import PLATFORM_SIZE
from mypyc.ir.ops import GetElementPtr, Integer, IntOp, LoadMem, SetMem, Value
from mypyc.ir.ops import GetElementPtr, IncRef, Integer, IntOp, LoadMem, SetMem, Value
from mypyc.ir.rtypes import (
PyListObject,
c_pyssize_t_rprimitive,
Expand Down Expand Up @@ -43,3 +43,31 @@ def buf_init_item(builder: LowLevelIRBuilder, args: list[Value], line: int) -> V
def list_items(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
ob_item_ptr = builder.add(GetElementPtr(args[0], PyListObject, "ob_item", line))
return builder.add(LoadMem(pointer_rprimitive, ob_item_ptr, line))


def list_item_ptr(builder: LowLevelIRBuilder, obj: Value, index: Value, line: int) -> Value:
"""Get a pointer to a list item (index must be valid and non-negative).

Type of index must be c_pyssize_t_rprimitive, and obj must refer to a list object.
"""
# List items are represented as an array of pointers. Pointer to the item obj[index] is
# <pointer to first item> + index * <pointer size>.
items = list_items(builder, [obj], line)
delta = builder.add(
IntOp(
c_pyssize_t_rprimitive,
index,
Integer(PLATFORM_SIZE, c_pyssize_t_rprimitive),
IntOp.MUL,
)
)
return builder.add(IntOp(pointer_rprimitive, items, delta, IntOp.ADD))


@lower_primitive_op("list_get_item_unsafe")
def list_get_item_unsafe(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
index = builder.coerce(args[1], c_pyssize_t_rprimitive, line)
item_ptr = list_item_ptr(builder, args[0], index, line)
value = builder.add(LoadMem(object_rprimitive, item_ptr, line))
builder.add(IncRef(value))
return value
4 changes: 2 additions & 2 deletions mypyc/primitives/list_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,10 @@

# This is unsafe because it assumes that the index is a non-negative short integer
# that is in-bounds for the list.
list_get_item_unsafe_op = custom_op(
list_get_item_unsafe_op = custom_primitive_op(
name="list_get_item_unsafe",
arg_types=[list_rprimitive, short_int_rprimitive],
return_type=object_rprimitive,
c_function_name="CPyList_GetItemUnsafe",
error_kind=ERR_NEVER,
)

Expand Down
8 changes: 4 additions & 4 deletions mypyc/test-data/irbuild-basic.test
Original file line number Diff line number Diff line change
Expand Up @@ -1874,7 +1874,7 @@ L1:
r9 = int_lt r6, r8
if r9 goto L2 else goto L8 :: bool
L2:
r10 = CPyList_GetItemUnsafe(r1, r6)
r10 = list_get_item_unsafe r1, r6
r11 = unbox(int, r10)
x = r11
r12 = int_ne x, 4
Expand Down Expand Up @@ -1938,7 +1938,7 @@ L1:
r9 = int_lt r6, r8
if r9 goto L2 else goto L8 :: bool
L2:
r10 = CPyList_GetItemUnsafe(r1, r6)
r10 = list_get_item_unsafe r1, r6
r11 = unbox(int, r10)
x = r11
r12 = int_ne x, 4
Expand Down Expand Up @@ -2000,7 +2000,7 @@ L1:
r3 = int_lt r0, r2
if r3 goto L2 else goto L4 :: bool
L2:
r4 = CPyList_GetItemUnsafe(l, r0)
r4 = list_get_item_unsafe l, r0
r5 = unbox(tuple[int, int, int], r4)
r6 = r5[0]
x = r6
Expand All @@ -2022,7 +2022,7 @@ L5:
r15 = int_lt r12, r14
if r15 goto L6 else goto L8 :: bool
L6:
r16 = CPyList_GetItemUnsafe(l, r12)
r16 = list_get_item_unsafe l, r12
r17 = unbox(tuple[int, int, int], r16)
r18 = r17[0]
x_2 = r18
Expand Down
10 changes: 5 additions & 5 deletions mypyc/test-data/irbuild-lists.test
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ L1:
r5 = int_lt r2, r4
if r5 goto L2 else goto L4 :: bool
L2:
r6 = CPyList_GetItemUnsafe(source, r2)
r6 = list_get_item_unsafe source, r2
r7 = unbox(int, r6)
x = r7
r8 = CPyTagged_Add(x, 2)
Expand All @@ -362,7 +362,7 @@ L5:
r17 = int_lt r14, r16
if r17 goto L6 else goto L8 :: bool
L6:
r18 = CPyList_GetItemUnsafe(source, r14)
r18 = list_get_item_unsafe source, r14
r19 = unbox(int, r18)
x_2 = r19
r20 = CPyTagged_Add(x_2, 2)
Expand Down Expand Up @@ -403,7 +403,7 @@ L1:
r3 = int_lt r0, r2
if r3 goto L2 else goto L4 :: bool
L2:
r4 = CPyList_GetItemUnsafe(x, r0)
r4 = list_get_item_unsafe x, r0
r5 = unbox(int, r4)
i = r5
r6 = box(int, i)
Expand Down Expand Up @@ -476,7 +476,7 @@ L1:
r3 = int_lt r0, r2
if r3 goto L2 else goto L4 :: bool
L2:
r4 = CPyList_GetItemUnsafe(a, r0)
r4 = list_get_item_unsafe a, r0
r5 = cast(union[str, bytes], r4)
x = r5
L3:
Expand All @@ -502,7 +502,7 @@ L1:
r3 = int_lt r0, r2
if r3 goto L2 else goto L4 :: bool
L2:
r4 = CPyList_GetItemUnsafe(a, r0)
r4 = list_get_item_unsafe a, r0
r5 = cast(union[str, None], r4)
x = r5
L3:
Expand Down
4 changes: 2 additions & 2 deletions mypyc/test-data/irbuild-set.test
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ L1:
r9 = int_lt r6, r8
if r9 goto L2 else goto L4 :: bool
L2:
r10 = CPyList_GetItemUnsafe(tmp_list, r6)
r10 = list_get_item_unsafe tmp_list, r6
r11 = unbox(int, r10)
x = r11
r12 = f(x)
Expand Down Expand Up @@ -361,7 +361,7 @@ L1:
r13 = int_lt r10, r12
if r13 goto L2 else goto L6 :: bool
L2:
r14 = CPyList_GetItemUnsafe(tmp_list, r10)
r14 = list_get_item_unsafe tmp_list, r10
r15 = unbox(int, r14)
z = r15
r16 = int_lt z, 8
Expand Down
12 changes: 6 additions & 6 deletions mypyc/test-data/irbuild-statements.test
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ L1:
r3 = int_lt r0, r2
if r3 goto L2 else goto L4 :: bool
L2:
r4 = CPyList_GetItemUnsafe(ls, r0)
r4 = list_get_item_unsafe ls, r0
r5 = unbox(int, r4)
x = r5
r6 = CPyTagged_Add(y, x)
Expand Down Expand Up @@ -594,8 +594,8 @@ def f(l, t):
L0:
r0 = CPySequence_CheckUnpackCount(l, 2)
r1 = r0 >= 0 :: signed
r2 = CPyList_GetItemUnsafe(l, 0)
r3 = CPyList_GetItemUnsafe(l, 2)
r2 = list_get_item_unsafe l, 0
r3 = list_get_item_unsafe l, 2
x = r2
r4 = unbox(int, r3)
y = r4
Expand Down Expand Up @@ -882,7 +882,7 @@ L1:
if r4 goto L2 else goto L4 :: bool
L2:
i = r0
r5 = CPyList_GetItemUnsafe(a, r1)
r5 = list_get_item_unsafe a, r1
r6 = unbox(int, r5)
x = r6
r7 = CPyTagged_Add(i, x)
Expand Down Expand Up @@ -961,7 +961,7 @@ L2:
r5 = PyIter_Next(r1)
if is_error(r5) goto L7 else goto L3
L3:
r6 = CPyList_GetItemUnsafe(a, r0)
r6 = list_get_item_unsafe a, r0
r7 = unbox(int, r6)
x = r7
r8 = unbox(bool, r5)
Expand Down Expand Up @@ -1015,7 +1015,7 @@ L3:
L4:
r8 = unbox(bool, r3)
x = r8
r9 = CPyList_GetItemUnsafe(b, r1)
r9 = list_get_item_unsafe b, r1
r10 = unbox(int, r9)
y = r10
x = 0
Expand Down
2 changes: 1 addition & 1 deletion mypyc/test-data/irbuild-tuple.test
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ L1:
r10 = int_lt r7, r9
if r10 goto L2 else goto L4 :: bool
L2:
r11 = CPyList_GetItemUnsafe(source, r7)
r11 = list_get_item_unsafe source, r7
r12 = unbox(int, r11)
x = r12
r13 = f(x)
Expand Down
38 changes: 24 additions & 14 deletions mypyc/test-data/lowering-int.test
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ L4:
L5:
return 4

[case testLowerIntForLoop]
[case testLowerIntForLoop_64bit]
from __future__ import annotations

def f(l: list[int]) -> None:
Expand All @@ -346,10 +346,14 @@ def f(l):
r2 :: native_int
r3 :: short_int
r4 :: bit
r5 :: object
r6, x :: int
r7 :: short_int
r8 :: None
r5 :: native_int
r6, r7 :: ptr
r8 :: native_int
r9 :: ptr
r10 :: object
r11, x :: int
r12 :: short_int
r13 :: None
L0:
r0 = 0
L1:
Expand All @@ -359,19 +363,25 @@ L1:
r4 = r0 < r3 :: signed
if r4 goto L2 else goto L5 :: bool
L2:
r5 = CPyList_GetItemUnsafe(l, r0)
r6 = unbox(int, r5)
dec_ref r5
if is_error(r6) goto L6 (error at f:4) else goto L3
r5 = r0 >> 1
r6 = get_element_ptr l ob_item :: PyListObject
r7 = load_mem r6 :: ptr*
r8 = r5 * 8
r9 = r7 + r8
r10 = load_mem r9 :: builtins.object*
inc_ref r10
r11 = unbox(int, r10)
dec_ref r10
if is_error(r11) goto L6 (error at f:4) else goto L3
L3:
x = r6
x = r11
dec_ref x :: int
L4:
r7 = r0 + 2
r0 = r7
r12 = r0 + 2
r0 = r12
goto L1
L5:
return 1
L6:
r8 = <error> :: None
return r8
r13 = <error> :: None
return r13
7 changes: 6 additions & 1 deletion mypyc/test/test_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
MypycDataSuite,
assert_test_output,
build_ir_for_single_file,
infer_ir_build_options_from_test_name,
remove_comment_lines,
replace_word_size,
use_custom_builtins,
Expand All @@ -31,11 +32,15 @@ class TestLowering(MypycDataSuite):
base_path = test_temp_dir

def run_case(self, testcase: DataDrivenTestCase) -> None:
options = infer_ir_build_options_from_test_name(testcase.name)
if options is None:
# Skipped test case
return
with use_custom_builtins(os.path.join(self.data_prefix, ICODE_GEN_BUILTINS), testcase):
expected_output = remove_comment_lines(testcase.output)
expected_output = replace_word_size(expected_output)
try:
ir = build_ir_for_single_file(testcase.input)
ir = build_ir_for_single_file(testcase.input, options)
except CompileError as e:
actual = e.messages
else:
Expand Down
Loading