diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index 8d9739e6d..e9e33e02c 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -34,6 +34,7 @@ Space, analysis_inputs, analysis_used_names_with_space, + calc_stack_effect, get_instructions, ) from .dispatch_functions import ( @@ -1754,7 +1755,7 @@ def _break_graph_in_jump(self, result: VariableBase, instr: Instruction): ] + inputs_var # Collect all the to store variables. store_vars = [] - for stack_arg in self.stack._data: + for stack_arg in self.stack: store_vars.append(stack_arg) for name in inputs_name: store_vars.append(self.get_var(name)) @@ -1772,7 +1773,7 @@ def _break_graph_in_jump(self, result: VariableBase, instr: Instruction): if_fn, if_fn.__code__.co_name ) insert_index = len(self._graph.pycode_gen._instructions) - 1 - for stack_arg in self.stack._data: + for stack_arg in self.stack: var_loader.load(stack_arg) for name in if_inputs: var_loader.load(self.get_var(name)) @@ -1789,7 +1790,7 @@ def _break_graph_in_jump(self, result: VariableBase, instr: Instruction): else_fn, else_fn.__code__.co_name ) jump_to = self._graph.pycode_gen._instructions[-1] - for stack_arg in self.stack._data: + for stack_arg in self.stack: var_loader.load(stack_arg) for name in else_inputs: var_loader.load(self.get_var(name)) @@ -1829,7 +1830,7 @@ def _break_graph_in_call( # gen call static fn opcode ret_vars = [ arg - for arg in self.stack._data + for arg in self.stack if isinstance(arg, (TensorVariable, ContainerVariable)) ] resume_input_name = analysis_inputs(self._instructions, index + 1) @@ -1841,7 +1842,7 @@ def _break_graph_in_call( # Collect all the to store variables. store_vars = [] - for stack_arg in self.stack._data: + for stack_arg in self.stack: store_vars.append(stack_arg) for name in resume_input_name: store_vars.append(self.get_var(name)) @@ -1853,30 +1854,35 @@ def _break_graph_in_call( self._graph.pycode_gen.gen_pop_top() # gen graph break call fn opcode - stack_effect = dis.stack_effect(instr.opcode, instr.arg) + stack_effect = calc_stack_effect(instr) pop_n = push_n - stack_effect - for i, stack_arg in enumerate(self.stack._data): + + for i, stack_arg in enumerate(self.stack): # Avoid passing NULL as a parameter to the resume function if ( isinstance(stack_arg, NullVariable) and i < len(self.stack) - pop_n ): self._graph.pycode_gen.gen_load_object( - NullVariable(), f'dummy_var{i}' + NullVariable(), f'dummy_var{i}', push_null=False ) else: var_loader.load(stack_arg) - self._graph.pycode_gen.add_pure_instructions([instr]) # gen call resume fn opcode + self._graph.pycode_gen.add_pure_instructions([instr]) self.stack.pop_n(pop_n) stack_size = len(self.stack) + push_n + resume_fn, _ = self._create_resume_fn(index + 1, stack_size) if resume_fn: self._graph.pycode_gen.gen_load_object( resume_fn, resume_fn.__code__.co_name ) - self._graph.pycode_gen.gen_rot_n(stack_size + 1) + # NOTE(zrr1999): We need to shift the resume_fn under its arguments. + # In Python 3.11+, NULL + resume_fn should be shifted together. + shift_n = 2 if sys.version_info >= (3, 11) else 1 + self._graph.pycode_gen.gen_shift_n(shift_n, stack_size + shift_n) for name in resume_input_name: var_loader.load(self.get_var(name)) self._graph.pycode_gen.gen_call_function( @@ -2001,9 +2007,7 @@ def _break_graph_in_for_loop( raise InnerError("Can not balance stack in loop body.") cur_instr = self._instructions[loop_body_start_idx] # do not consider jump instr - stack_effect = dis.stack_effect( - cur_instr.opcode, cur_instr.arg, jump=False - ) + stack_effect = calc_stack_effect(cur_instr, jump=False) curent_stack += stack_effect loop_body_start_idx += 1 if curent_stack == 0: @@ -2041,7 +2045,7 @@ def _break_graph_in_for_loop( ] ret_vars = [self.get_var(name) for name in ret_names] store_vars = [ret_vars[idx] for idx in range(len(ret_names))] - store_vars.extend(iter(self.stack._data)) + store_vars.extend(iter(self.stack)) var_loader = self._graph.start_compile_with_name_store( ret_vars, store_vars ) @@ -2106,7 +2110,7 @@ def _break_graph_in_for_loop( after_loop_fn, after_loop_fn.__code__.co_name ) - for stack_arg in self.stack._data: + for stack_arg in self.stack: var_loader.load(stack_arg) for name in fn_inputs: self._graph.pycode_gen.gen_load(name) diff --git a/sot/opcode_translator/executor/pycode_generator.py b/sot/opcode_translator/executor/pycode_generator.py index 9e24621ca..e12cc2d0b 100644 --- a/sot/opcode_translator/executor/pycode_generator.py +++ b/sot/opcode_translator/executor/pycode_generator.py @@ -4,7 +4,6 @@ from __future__ import annotations -import dis import sys import types from typing import TYPE_CHECKING @@ -25,6 +24,7 @@ ) from ..instruction_utils import ( analysis_inputs, + calc_stack_effect, gen_instr, get_instructions, instrs_info, @@ -383,11 +383,11 @@ def update_stacksize(lasti: int, nexti: int, stack_effect: int): idx + 1 < len(instructions) and instr.opname not in UNCONDITIONAL_JUMP ): - stack_effect = dis.stack_effect(instr.opcode, instr.arg, jump=False) + stack_effect = calc_stack_effect(instr, jump=False) update_stacksize(idx, idx + 1, stack_effect) if instr.opcode in opcode.hasjabs or instr.opcode in opcode.hasjrel: - stack_effect = dis.stack_effect(instr.opcode, instr.arg, jump=True) + stack_effect = calc_stack_effect(instr, jump=True) target_idx = instructions.index(instr.jump_to) update_stacksize(idx, target_idx, stack_effect) @@ -433,7 +433,9 @@ def gen_pycode(self) -> types.CodeType: ) return new_code - def gen_resume_fn_at(self, index: int, stack_size: int = 0): + def gen_resume_fn_at( + self, index: int, stack_size: int = 0 + ) -> tuple[None | types.FunctionType, OrderedSet[str]]: """ Generates a resume function at the specified index in the instruction list. @@ -662,7 +664,7 @@ def gen_load_global(self, name, push_null=False): idx |= 1 self._add_instr("LOAD_GLOBAL", arg=idx, argval=name) - def gen_load_object(self, obj, obj_name: str): + def gen_load_object(self, obj, obj_name: str, push_null: bool = True): """ Generate the bytecode for loading an object. @@ -673,7 +675,7 @@ def gen_load_object(self, obj, obj_name: str): if obj_name not in self._f_globals: self._f_globals[obj_name] = obj - self.gen_load_global(obj_name, push_null=True) + self.gen_load_global(obj_name, push_null=push_null) def gen_load_fast(self, name): """ @@ -814,6 +816,51 @@ def rot_n_fn(n): self._add_instr("CALL_FUNCTION_EX", arg=0) self.gen_unpack_sequence(n) + def gen_shift_n(self, s: int, n: int): + """ + Generate the bytecode for shifting the stack. + + Args: + s (int): Steps to shift. + n (int): The number of elements to shift. + """ + if s == 0 or n <= 1: + return + + # NOTE(zrr1999): right shift s steps is equal to left shift n-s steps + if abs(s) > n // 2: + new_s = s - n if s > 0 else s + n + self.gen_shift_n(new_s, n) + return + if s > 0: + # NOTE: s=1, n=3 [1,2,3,4,5] -> [1,2,5,3,4] + # s=2, n=3 [1,2,3,4,5] -> [1,2,4,5,3] + if s == 1: + self.gen_rot_n(n) + else: + self.gen_rot_n(n) + self.gen_shift_n(s - 1, n) + + else: # s < 0 + if sys.version_info >= (3, 11): + # NOTE: s=-1, n=3 [1,2,3,4,5] -> [1,2,4,5,3] + if s == -1: + for i in range(2, n + 1): + self._add_instr("SWAP", arg=i) + else: + self.gen_shift_n(-1, n) + self.gen_shift_n(s + 1, n) + else: + raise NotImplementedError( + "shift_n is not supported before python3.11" + ) + + def gen_swap(self, n): + if sys.version_info >= (3, 11): + self._add_instr("SWAP", arg=n) + else: + raise NotImplementedError("swap is not supported before python3.11") + def gen_return(self): self._add_instr("RETURN_VALUE") diff --git a/sot/opcode_translator/executor/variable_stack.py b/sot/opcode_translator/executor/variable_stack.py index 4a37006a2..3e45fff17 100644 --- a/sot/opcode_translator/executor/variable_stack.py +++ b/sot/opcode_translator/executor/variable_stack.py @@ -192,6 +192,9 @@ def top(self, value): assert len(self) > 0, "stack is empty" self.peek[1] = value + def __iter__(self): + return iter(self._data) + def __len__(self) -> int: return len(self._data) diff --git a/sot/opcode_translator/instruction_utils/__init__.py b/sot/opcode_translator/instruction_utils/__init__.py index bf6d34f0f..64f3ed4aa 100644 --- a/sot/opcode_translator/instruction_utils/__init__.py +++ b/sot/opcode_translator/instruction_utils/__init__.py @@ -1,6 +1,7 @@ from .instruction_utils import ( Instruction, calc_offset_from_bytecode_offset, + calc_stack_effect, convert_instruction, gen_instr, get_instructions, @@ -22,6 +23,7 @@ "analysis_inputs", "analysis_used_names_with_space", "calc_offset_from_bytecode_offset", + "calc_stack_effect", "Instruction", "convert_instruction", "gen_instr", diff --git a/sot/opcode_translator/instruction_utils/instruction_utils.py b/sot/opcode_translator/instruction_utils/instruction_utils.py index 6e1a391ca..dd5f5a8f5 100644 --- a/sot/opcode_translator/instruction_utils/instruction_utils.py +++ b/sot/opcode_translator/instruction_utils/instruction_utils.py @@ -325,3 +325,25 @@ def instrs_info(instrs, mark=None, range=None): if idx == mark: ret[-1] = "\033[31m" + ret[-1] + "\033[0m" return ret + + +def calc_stack_effect(instr: Instruction, *, jump: bool | None = None) -> int: + """ + Gets the stack effect of the given instruction. In Python 3.11, the stack effect of `CALL` is -1, + refer to https://github.com/python/cpython/blob/3.11/Python/compile.c#L1123-L1124. + + Args: + instr: The instruction. + + Returns: + The stack effect of the instruction. + + """ + if sys.version_info[:2] == (3, 11): + if instr.opname == "PRECALL": + return 0 + elif instr.opname == "CALL": + # NOTE(zrr1999): push_n = 1, pop_n = oparg + 2, stack_effect = push_n - pop_n = -oparg-1 + assert instr.arg is not None + return -instr.arg - 1 + return dis.stack_effect(instr.opcode, instr.arg, jump=jump) diff --git a/tests/run_all.sh b/tests/run_all.sh index 3a2ced999..6255b4864 100644 --- a/tests/run_all.sh +++ b/tests/run_all.sh @@ -9,22 +9,12 @@ echo "IS_PY311:" $IS_PY311 failed_tests=() py311_skiped_tests=( - # ./test_01_basic.py There are some case need to be fixed - # ./test_04_list.py There are some case need to be fixed - # ./test_05_dict.py There are some case need to be fixed - # ./test_11_jumps.py There are some case need to be fixed ./test_12_for_loop.py - # ./test_14_operators.py There are some case need to be fixed ./test_15_slice.py - ./test_18_tensor_method.py ./test_19_closure.py ./test_21_global.py - ./test_break_graph.py - ./test_builtin_dispatch.py ./test_constant_graph.py ./test_enumerate.py - ./test_exception.py - ./test_execution_base.py ./test_guard_user_defined_fn.py ./test_inplace_api.py ./test_range.py @@ -32,7 +22,6 @@ py311_skiped_tests=( ./test_resnet50_backward.py # ./test_side_effects.py There are some case need to be fixed ./test_sir_rollback.py - ./test_str_format.py ./test_tensor_dtype_in_guard.py ) diff --git a/tests/test_01_basic.py b/tests/test_01_basic.py index 4101d9213..e1e5dca35 100644 --- a/tests/test_01_basic.py +++ b/tests/test_01_basic.py @@ -1,4 +1,3 @@ -import sys import unittest from test_case_base import TestCaseBase, strict_mode_guard @@ -20,9 +19,6 @@ def numpy_add(x, y): return out -@unittest.skipIf( - sys.version_info >= (3, 11), "Python 3.11+ is not supported yet." -) class TestNumpyAdd(TestCaseBase): @strict_mode_guard(0) def test_numpy_add(self): diff --git a/tests/test_04_list.py b/tests/test_04_list.py index 370fbeefc..e71a34957 100644 --- a/tests/test_04_list.py +++ b/tests/test_04_list.py @@ -5,7 +5,6 @@ from __future__ import annotations -import sys import unittest from test_case_base import TestCaseBase @@ -219,9 +218,6 @@ def test_list_basic(self): class TestListMethods(TestCaseBase): - @unittest.skipIf( - sys.version_info >= (3, 11), "Python 3.11+ not support breakgraph" - ) def test_list_setitem(self): self.assert_results_with_side_effects( list_setitem_tensor, 1, paddle.to_tensor(2) diff --git a/tests/test_05_dict.py b/tests/test_05_dict.py index f250d6fae..06be140ad 100644 --- a/tests/test_05_dict.py +++ b/tests/test_05_dict.py @@ -2,7 +2,6 @@ # BUILD_MAP (new) # BUILD_CONST_KEY_MAP (new) -import sys import unittest from test_case_base import TestCaseBase @@ -225,10 +224,6 @@ def test_dict_popitem(self): dict_popitem, 1, paddle.to_tensor(2) ) - @unittest.skipIf( - sys.version_info >= (3, 11), - "dict_construct_from_comprehension Python 3.11+ has some issues", - ) def test_construct(self): self.assert_results(dict_construct_from_dict) self.assert_results(dict_construct_from_list) diff --git a/tests/test_14_operators.py b/tests/test_14_operators.py index 667923953..4d1e0ff72 100644 --- a/tests/test_14_operators.py +++ b/tests/test_14_operators.py @@ -1,5 +1,4 @@ import operator -import sys import unittest from test_case_base import TestCaseBase @@ -273,10 +272,6 @@ def operator_pos(y: int): class TestExecutor(TestCaseBase): - @unittest.skipIf( - sys.version_info >= (3, 11), - "Python 3.11+ breakbreak occurred in unary_not", - ) def test_simple(self): a = paddle.to_tensor(1) b = paddle.to_tensor(True) @@ -319,9 +314,6 @@ def test_simple(self): self.assert_results(inplace_or, b, g) self.assert_results(inplace_xor, b, g) - @unittest.skipIf( - sys.version_info >= (3, 11), "Python 3.11+ breakbreak occurred in truth" - ) def test_operator_simple(self): self.assert_results(operator_add, 1, paddle.to_tensor(2)) self.assert_results(operator_mul, 1, paddle.to_tensor(2)) @@ -338,9 +330,6 @@ def test_operator_simple(self): self.assert_results(operator_not_in_, 12, [1, 2, 3]) self.assert_results(operator_not_in_, 12, [1, 2, 3]) - @unittest.skipIf( - sys.version_info >= (3, 11), "Python 3.11+ not support breakgraph" - ) def test_operator_list(self): self.assert_results(list_getitem, 1, paddle.to_tensor(2)) self.assert_results(list_getitem_slice, 1, paddle.to_tensor(2)) @@ -351,9 +340,6 @@ def test_operator_list(self): self.assert_results(list_delitem_int, 1, paddle.to_tensor(2)) self.assert_results(list_delitem_tensor, 1, paddle.to_tensor(2)) - @unittest.skipIf( - sys.version_info >= (3, 11), "Python 3.11+ not support breakgraph" - ) def test_operator_dict(self): self.assert_results(dict_getitem_int, 1, paddle.to_tensor(2)) self.assert_results(dict_getitem_tensor, 1, paddle.to_tensor(2)) @@ -364,9 +350,6 @@ def test_operator_dict(self): self.assert_results(dict_delitem_int, 1, paddle.to_tensor(2)) self.assert_results(dict_delitem_tensor, 1, paddle.to_tensor(2)) - @unittest.skipIf( - sys.version_info >= (3, 11), "Python 3.11+ not support breakgraph" - ) def test_operator_tuple(self): self.assert_results(tuple_getitem_int, 1, paddle.to_tensor(2)) self.assert_results(tuple_getitem_tensor, 1, paddle.to_tensor(2)) diff --git a/tests/test_side_effects.py b/tests/test_side_effects.py index dd6cfaebe..e2cedf037 100644 --- a/tests/test_side_effects.py +++ b/tests/test_side_effects.py @@ -235,7 +235,8 @@ def test_list_reverse(self): ) @unittest.skipIf( - sys.version_info >= (3, 11), "Python 3.11+ not support breakgraph" + sys.version_info >= (3, 11), + "Python 3.11+ not support for-loop breakgraph", ) def test_slice_in_for_loop(self): x = 2 @@ -247,7 +248,7 @@ def test_list_nested(self): @unittest.skipIf( - sys.version_info >= (3, 11), "Python 3.11+ not support breakgraph" + sys.version_info >= (3, 11), "Python 3.11+ not support for-loop breakgraph" ) class TestSliceAfterChange(TestCaseBase): def test_slice_list_after_change(self):