Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SOT][3.12] replace POP_JUMP_{BACKWARD,FORWARD}_IF_{TRUE,FALSE} with POP_JUMP_IF_{TRUE,FALSE} #62155

Merged
merged 12 commits into from
Mar 4, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -1686,8 +1686,9 @@ def FOR_ITER(self, instr):

self._inline_call_for_loop(iterator, instr)
self._lasti = self.indexof(instr.jump_to)
next_instr = self._instructions[self._lasti]
self._lasti += int(next_instr.opname == 'END_FOR')
if sys.version_info >= (3, 12):
assert self._instructions[self._lasti].opname == "END_FOR"
self._lasti += 1
except BreakGraphError as e:
log(3, f"[BreakGraph] FOR_ITER sim for loop failed for: {e}\n")
if backup_iter_idx:
Expand Down Expand Up @@ -2060,10 +2061,17 @@ def create_after_loop_fn():
return None
pycode_gen = PyCodeGen(self._frame)
origin_instrs = get_instructions(pycode_gen._origin_code)
resume_fn_end_idx = loop_body_end_idx

# skip resume END_FOR in python3.12
if sys.version_info >= (3, 12):
assert origin_instrs[loop_body_end_idx].opname == "END_FOR"
resume_fn_end_idx += 1

pycode_gen.set_function_inputs(
after_loop_fn_inputs, stack_size=len(self.stack) - 1
)
pycode_gen.extend_instrs(origin_instrs[loop_body_end_idx:])
pycode_gen.extend_instrs(origin_instrs[resume_fn_end_idx:])
# the resume_fn contains return code, so we don't need set output here
# global vars are updated correctly, and need local vars will return
after_loop_fn = pycode_gen.create_function()
Expand Down Expand Up @@ -2127,8 +2135,13 @@ def create_after_loop_fn():
self._graph.pycode_gen.gen_jump(
for_iter, direction=JumpDirection.BACKWARD
)

if sys.version_info >= (3, 12):
end_for = self._graph.pycode_gen.add_instr("END_FOR")

nop = self._graph.pycode_gen.add_instr("NOP")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3.12 能否直接用 END_FOR 呢,就不需要先插一个 NOP 再插一个 END_FOR 了,这里的 NOP 应该是冗余的

变量名直接统一用 end_for 就好,也不需要两个版本一个用 end_for 一个用 nop

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

POP_JUMP_IF_FALSE并不能跳到END_FOR

        >>   90 FOR_ITER                19 (to 132)
             94 STORE_FAST               0 (i)
             96 LOAD_GLOBAL             11 (NULL + $resume_90@undefined_var_case_1_dc807)
            106 LOAD_FAST                0 (i)
            108 LOAD_FAST                1 (aaa)
            110 LOAD_CONST               5 (True)
            112 CALL                     3
            120 UNPACK_SEQUENCE          3
            124 STORE_FAST               1 (aaa)
            126 STORE_FAST               0 (i)
            128 POP_JUMP_IF_FALSE        2 (to 134)
            130 JUMP_BACKWARD           21 (to 90)
        >>  132 END_FOR
        >>  134 NOP

for_iter.jump_to = nop

for_iter.jump_to = end_for if sys.version_info >= (3, 12) else nop
jump_if_break.jump_to = nop

# 9. prepare inputs and call after_loop_fn
Expand Down Expand Up @@ -2198,6 +2211,8 @@ def create_inline_call_fn():
for_iter_instr, direction=JumpDirection.BACKWARD
)

if sys.version_info >= (3, 12):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我在测test_side_effects.py单测的时候,发现在3.12resume的FOR_ITER字节码会出现结束循环体的位置会比上一次情况往后多一步,导致这个单测出错;感觉和这里的END_FOR的生成有关,但是不理解为啥3.12要专门生成END_FOR,想请问下

Copy link
Member Author

@gouzil gouzil Feb 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是的这个 PR 就是来解决这个问题的,具体表现情况为下面三种情况:

  • 生成跳转错误
[Resumed Function]: Inline call for loop function $resume_1@for_list_1_af1a0
 39           0 RESUME                   0
              2 LOAD_FAST                2 (object_168)
        >>    4 FOR_ITER                24 (to 56)

...
             54 END_FOR
        >>   56 LOAD_FAST                1 (i)
             58 LOAD_FAST                0 (x)
             60 LOAD_FAST                2 (object_168)
             62 BUILD_TUPLE              3
             64 RETURN_VALUE
  • END_FOR字节码不生成
[Resumed Function]: Inline call for loop function $resume_1@undefined_var_case_0_af1a0
269           0 RESUME                   0
              2 LOAD_FAST                2 (object_165)
        >>    4 FOR_ITER                38 (84)

...
             74 JUMP_FORWARD             1 (to 78)
             76 JUMP_FORWARD             2 (to 82)
        >>   78 NOP
             80 JUMP_BACKWARD           39 (to 4)
        >>   82 NOP
             86 LOAD_FAST                0 (i)
...
  • resume错误

在 Python3.12 下

 13           0 RESUME                   0

              2 END_FOR                                 # <-- 这里 resume 错误
              4 LOAD_FAST_CHECK          0 (zzz)
              6 LOAD_CONST               2 (1)
              8 BINARY_OP                0 (+)
             10 RETURN_VALUE

在 Python3.11 下

 13           0 RESUME                   0
              2 LOAD_FAST                0 (zzz)
              4 LOAD_CONST               2 (1)
              6 BINARY_OP                0 (+)

 14          10 STORE_FAST               0 (zzz)
             12 LOAD_FAST                0 (zzz)
             14 RETURN_VALUE

在 python3.12 _inline_call_for_loop理论上正常生成的字节码为

 39           0 RESUME                   0
              2 LOAD_FAST                2 (object_168)
        >>    4 FOR_ITER                27 (to 60)

 41           8 STORE_FAST               1 (i)
             10 LOAD_FAST                0 (x)
             12 LOAD_FAST                1 (i)
             14 BINARY_OP               13 (+=)

 43          18 STORE_FAST               0 (x)
             20 LOAD_FAST                0 (x)
             22 LOAD_CONST               2 (2)
             24 COMPARE_OP              68 (>)

 44          28 POP_JUMP_IF_FALSE        6 (to 42)
             30 LOAD_FAST                0 (x)
             32 LOAD_CONST               3 (1)
             34 BINARY_OP               13 (+=)
             38 STORE_FAST               0 (x)

 46          40 JUMP_FORWARD             7 (to 56)
        >>   42 LOAD_FAST                0 (x)
             44 LOAD_CONST               3 (1)
             46 BINARY_OP               23 (-=)
             50 STORE_FAST               0 (x)
             52 JUMP_FORWARD             1 (to 56)
             54 JUMP_FORWARD             2 (to 62)      # 无用
        >>   56 NOP
             58 JUMP_BACKWARD           28 (to 4)
        >>   60 END_FOR
        >>   62 NOP                                     
             64 LOAD_FAST                1 (i)
             66 LOAD_FAST                0 (x)
             68 LOAD_FAST                2 (object_168)
             70 BUILD_TUPLE              3
             72 RETURN_VALUE

在 python3.11 _inline_call_for_loop生成的字节码为

 39           0 RESUME                   0
              2 LOAD_FAST                2 (object_168)
        >>    4 FOR_ITER                27 (to 60)

 41           6 STORE_FAST               1 (i)
              8 LOAD_FAST                0 (x)
             10 LOAD_FAST                1 (i)
             12 BINARY_OP               13 (+=)

 43          16 STORE_FAST               0 (x)
             18 LOAD_FAST                0 (x)
             20 LOAD_CONST               2 (2)
             22 COMPARE_OP               4 (>)

 44          28 POP_JUMP_FORWARD_IF_FALSE     6 (to 42)
             30 LOAD_FAST                0 (x)
             32 LOAD_CONST               3 (1)
             34 BINARY_OP               13 (+=)
             38 STORE_FAST               0 (x)

 46          40 JUMP_FORWARD             7 (to 56)
        >>   42 LOAD_FAST                0 (x)
             44 LOAD_CONST               3 (1)
             46 BINARY_OP               23 (-=)
             50 STORE_FAST               0 (x)
             52 JUMP_FORWARD             1 (to 56)
             54 JUMP_FORWARD             2 (to 60)
        >>   56 NOP
             58 JUMP_BACKWARD           28 (to 4)
        >>   60 NOP
             62 LOAD_FAST                1 (i)
             64 LOAD_FAST                0 (x)
             66 LOAD_FAST                2 (object_168)
             68 BUILD_TUPLE              3
             70 RETURN_VALUE

我们可以看到,在 3.11 下 FOR_ITER 会在迭代器耗尽时直接跳出整个循环体的字节码;而在 3.12 下整个 for 循环由 FOR_ITER 和 END_FOR 成对组成,大部分情况下 FOR_ITER 及其超指令会在 FOR_ITER 字节码内部直接跳过 END_FOR 字节码(END_FOR 仅做标识)。

下面的内容可能说的不对:

  • 关于为什么需要专门生成END_FOR(或许也可以直接拷贝):

因为需要与FOR_ITER对应。

# 2.2. copy main logic
pycode_gen.extend_instrs(origin_instrs[start_idx:end_idx+1])
  • 关于生成跳转错误

这个我目前定位到 code gen 的 assemble 只生成了 opcode 和 arg,而没有其他属性(例如 argval 这些信息),当然也可能是我们生成时的问题

end_for = pycode_gen.add_instr("END_FOR")
nop_for_break = pycode_gen.add_instr("NOP")

# 2.4. relocate jumps
Expand All @@ -2212,6 +2227,8 @@ def create_inline_call_fn():
instr.jump_to = nop_for_break

jump.jump_to = for_iter_instr
if sys.version_info >= (3, 12):
for_iter_instr.jump_to = end_for

pycode_gen.set_function_outputs(output_var_names)
inline_call_fn = pycode_gen.create_function()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import contextlib
import inspect
import re
import sys
from typing import TYPE_CHECKING

from ...profiler import event_register
Expand Down Expand Up @@ -316,6 +317,9 @@ def FOR_ITER(self, instr: Instruction):
self.stack.pop()
assert isinstance(instr.jump_to, Instruction)
self._lasti = self.indexof(instr.jump_to)
if sys.version_info >= (3, 12):
assert self._instructions[self._lasti].opname == "END_FOR"
self._lasti += 1

else:
self._graph.remove_global_guarded_variable(iterator)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -956,7 +956,7 @@ def gen_pop_jump(
direction: JumpDirection = JumpDirection.FORWARD,
suffix: PopJumpCond = PopJumpCond.NONE,
) -> Instruction:
if sys.version_info >= (3, 11):
if sys.version_info >= (3, 11) and sys.version_info < (3, 12):
return self.add_instr(
f"POP_JUMP_{direction.value}_IF_{suffix.value}", jump_to=jump_to
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import sys
from typing import TYPE_CHECKING

from paddle.jit.sot.utils import log, log_do

from ...utils import InnerError
from .instruction_utils import instrs_info
from .stack_analyse import StackAnalyser

if TYPE_CHECKING:
from .instruction_utils import Instruction


def apply_instr_pass(instrs, code_options):
def apply_instr_pass(instrs: list[Instruction], code_options):
log(4, f"[Opcode Pass]: Original New Code {code_options['co_name']}:\n")
log_do(4, lambda: print(instrs_info(instrs)))
supported_passes = (
supported_passes = [
remove_load_store_pass,
remove_duplicate_resume,
check_precall_followed_by_call,
)
]

if sys.version_info >= (3, 12):
supported_passes.append(check_for_iter_jump_to)

for instr_pass in supported_passes:
instr_pass(instrs, code_options)
Expand All @@ -38,7 +49,7 @@ def apply_instr_pass(instrs, code_options):
log_do(4, lambda: print(instrs_info(instrs)))


def find_stored_once_local_vars(instrs, code_options):
def find_stored_once_local_vars(instrs: list[Instruction], code_options):
"""
find out the local var names which is only stored once
"""
Expand All @@ -61,13 +72,13 @@ def find_stored_once_local_vars(instrs, code_options):
return stored_once


def find_loaded_once_local_vars(instrs, code_options):
def find_loaded_once_local_vars(instrs: list[Instruction], code_options):
"""
find out the local var names which is only stored once
"""
loaded_vars = {}
for instr in instrs:
if instr.opname == "LOAD_FAST":
if instr.opname in ["LOAD_FAST", "LOAD_FAST_CHECK"]:
if instr.argval in loaded_vars:
loaded_vars[instr.argval] += 1
else:
Expand All @@ -77,14 +88,14 @@ def find_loaded_once_local_vars(instrs, code_options):
return loaded_once


def find_related_local_opcodes(instrs, code_options):
def find_related_local_opcodes(instrs: list[Instruction], code_options):
"""
find out the opcode pairs consist with LOAD_FAST and STORE_FAST
find out the opcode pairs consist with LOAD_FAST and STORE_FAST and LOAD_FAST_CHECK
"""
stack = []
opcode_pairs = []
for instr in instrs:
if instr.opname == "LOAD_FAST":
if instr.opname in ["LOAD_FAST", "LOAD_FAST_CHECK"]:
stack.append(instr)
elif instr.opname == "STORE_FAST":
if len(stack) > 0 and stack[-1] is not None:
Expand All @@ -105,7 +116,7 @@ def find_related_local_opcodes(instrs, code_options):
return opcode_pairs


def remove_load_store_pass(instrs, code_options):
def remove_load_store_pass(instrs: list[Instruction], code_options):
"""
This question is extremely complex, so we just simplify it as
'remove renames which is between var names who only stored once'
Expand Down Expand Up @@ -158,7 +169,8 @@ def code_exist(opname, argval, instrs):
if a_name != b_name:
for instr in instrs:
if (
instr.opname in ("LOAD_FAST", "STORE_FAST")
instr.opname
in ("LOAD_FAST_CHECK", "LOAD_FAST", "STORE_FAST")
and instr.argval == b_name
):
instr.argval = a_name
Expand Down Expand Up @@ -211,7 +223,13 @@ def code_exist(opname, argval, instrs):
code_range = instrs[last_store_idx : instrs.index(store_b)]
if (
not code_exist("STORE_FAST", b_name, code_range)
and not code_exist("LOAD_FAST_CHECK", b_name, code_range)
and not code_exist("LOAD_FAST", b_name, code_range)
and not code_exist(
"LOAD_FAST_CHECK",
a_name,
instrs[instrs.index(store_b) :],
)
and not code_exist(
"LOAD_FAST", a_name, instrs[instrs.index(store_b) :]
)
Expand All @@ -222,7 +240,8 @@ def code_exist(opname, argval, instrs):
instrs.remove(store_b)
for instr in instrs[last_store_idx:]:
if (
instr.opname in ("LOAD_FAST", "STORE_FAST")
instr.opname
in ("LOAD_FAST_CHECK", "LOAD_FAST", "STORE_FAST")
and instr.argval == a_name
):
instr.argval = b_name
Expand All @@ -245,6 +264,7 @@ def code_exist(opname, argval, instrs):
and opcode2 not in jump_target
and opcode1.opname == "STORE_FAST"
and opcode2.opname == "LOAD_FAST"
and opcode2.opname == "LOAD_FAST_CHECK"
and opcode1.argval == opcode2.argval
and opcode1.argval in loaded_once
):
Expand All @@ -255,15 +275,15 @@ def code_exist(opname, argval, instrs):
idx += 1


def remove_duplicate_resume(instrs, code_options):
def remove_duplicate_resume(instrs: list[Instruction], code_options):
    """
    Keep only the first RESUME instruction, dropping every later one in place.

    Duplicates are identified by their position in the list rather than by
    `list.remove`, which matches by equality and would delete the *first*
    RESUME if two instructions happened to compare equal.

    Args:
        instrs: Instruction list; mutated in place.
        code_options: Unused here; accepted so every pass shares one signature.
    """
    seen_resume = False
    kept = []
    for instr in instrs:
        if instr.opname == "RESUME":
            if seen_resume:
                # A later RESUME: drop it.
                continue
            seen_resume = True
        kept.append(instr)
    # Slice-assign so callers holding a reference to `instrs` see the result.
    instrs[:] = kept


def check_precall_followed_by_call(instrs, code_options):
def check_precall_followed_by_call(instrs: list[Instruction], code_options):
"""
PRECALL should be followed by CALL, otherwise it will cause a segmentation fault
"""
Expand All @@ -272,3 +292,14 @@ def check_precall_followed_by_call(instrs, code_options):
raise InnerError(
f"PRECALL is not followed by CALL in {code_options['co_name']}"
)


def check_for_iter_jump_to(instrs: list[Instruction], code_options):
    """
    Check that the `jump_to` of every FOR_ITER is an END_FOR, in Python3.12+

    Args:
        instrs: Instructions to check; `opname` and `jump_to` are read.
        code_options: Code options; `co_name` is used in the error message.

    Raises:
        InnerError: If a FOR_ITER has no jump target, or its target is not
            an END_FOR instruction.
    """
    for instr in instrs:
        if instr.opname != "FOR_ITER":
            continue
        target = instr.jump_to
        # Raise explicitly instead of asserting so the check survives `python -O`.
        if target is None or target.opname != "END_FOR":
            raise InnerError(
                f"FOR_ITER jump_to is not END_FOR in {code_options['co_name']}"
            )
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@
from typing import TYPE_CHECKING, Any

from ...utils import InnerError
from .opcode_info import ABS_JUMP, ALL_JUMP, REL_BWD_JUMP, REL_JUMP
from .opcode_info import (
ABS_JUMP,
ALL_JUMP,
PYOPCODE_CACHE_SIZE,
REL_BWD_JUMP,
REL_JUMP,
)

if TYPE_CHECKING:
import types
Expand Down Expand Up @@ -239,7 +245,8 @@ def relocate_jump_target(instructions: list[Instruction]) -> None:
if instr.opname in ABS_JUMP:
new_arg = jump_target
else: # instr.opname in REL_JUMP
new_arg = jump_target - instr.offset - 2
cache_size = PYOPCODE_CACHE_SIZE.get(instr.opname, 0)
new_arg = jump_target - (2 * cache_size) - instr.offset - 2
if instr.opname in REL_BWD_JUMP:
new_arg = -new_arg

Expand Down Expand Up @@ -315,12 +322,12 @@ def bind_ex_arg_with_instr(ex_arg, instr):
return modify_completed


def modify_vars(instructions, code_options):
def modify_vars(instructions: list[Instruction], code_options):
co_names = code_options['co_names']
co_varnames = code_options['co_varnames']
co_freevars = code_options['co_freevars']
for instrs in instructions:
if instrs.opname == 'LOAD_FAST' or instrs.opname == 'STORE_FAST':
if instrs.opname in ['LOAD_FAST', 'LOAD_FAST_CHECK', 'STORE_FAST']:
assert (
instrs.argval in co_varnames
), f"`{instrs.argval}` not in {co_varnames}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class PopJumpCond(Enum):
NOT_NONE = "NOT_NONE"


def get_pyopcode_cache_size() -> dict[str, int]:
def _get_pyopcode_cache_size() -> dict[str, int]:
if sys.version_info >= (3, 11) and sys.version_info < (3, 12):
# Cache for some opcodes, it's for Python 3.11+
# https://github.com/python/cpython/blob/3.11/Include/internal/pycore_opcode.h#L41-L53
Expand Down Expand Up @@ -87,4 +87,4 @@ def get_pyopcode_cache_size() -> dict[str, int]:
return {}


PYOPCODE_CACHE_SIZE = get_pyopcode_cache_size()
PYOPCODE_CACHE_SIZE = _get_pyopcode_cache_size()
5 changes: 0 additions & 5 deletions test/sot/skip_files_py312
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
./test_11_jumps.py
./test_12_for_loop.py
./test_builtin_zip.py
./test_inplace_api.py
./test_min_graph_size.py
./test_side_effects.py
./test_sot_cost_model.py
./test_sot_resnet.py
./test_sot_resnet50_backward.py