Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

[Compat][3.11] support call breakgraph #369

Merged
merged 16 commits into from
Sep 3, 2023
38 changes: 26 additions & 12 deletions sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1754,7 +1754,7 @@ def _break_graph_in_jump(self, result: VariableBase, instr: Instruction):
] + inputs_var
# Collect all the to store variables.
store_vars = []
for stack_arg in self.stack._data:
for stack_arg in self.stack:
store_vars.append(stack_arg)
for name in inputs_name:
store_vars.append(self.get_var(name))
Expand All @@ -1772,7 +1772,7 @@ def _break_graph_in_jump(self, result: VariableBase, instr: Instruction):
if_fn, if_fn.__code__.co_name
)
insert_index = len(self._graph.pycode_gen._instructions) - 1
for stack_arg in self.stack._data:
for stack_arg in self.stack:
var_loader.load(stack_arg)
for name in if_inputs:
var_loader.load(self.get_var(name))
Expand All @@ -1789,7 +1789,7 @@ def _break_graph_in_jump(self, result: VariableBase, instr: Instruction):
else_fn, else_fn.__code__.co_name
)
jump_to = self._graph.pycode_gen._instructions[-1]
for stack_arg in self.stack._data:
for stack_arg in self.stack:
var_loader.load(stack_arg)
for name in else_inputs:
var_loader.load(self.get_var(name))
Expand Down Expand Up @@ -1829,7 +1829,7 @@ def _break_graph_in_call(
# gen call static fn opcode
ret_vars = [
arg
for arg in self.stack._data
for arg in self.stack
if isinstance(arg, (TensorVariable, ContainerVariable))
]
resume_input_name = analysis_inputs(self._instructions, index + 1)
Expand All @@ -1841,7 +1841,7 @@ def _break_graph_in_call(

# Collect all the to store variables.
store_vars = []
for stack_arg in self.stack._data:
for stack_arg in self.stack:
store_vars.append(stack_arg)
for name in resume_input_name:
store_vars.append(self.get_var(name))
Expand All @@ -1853,30 +1853,44 @@ def _break_graph_in_call(
self._graph.pycode_gen.gen_pop_top()

# gen graph break call fn opcode
stack_effect = dis.stack_effect(instr.opcode, instr.arg)
if sys.version_info >= (3, 11) and instr.opname == "CALL":
assert instr.arg is not None
stack_effect = -instr.arg - 1
else:
stack_effect = dis.stack_effect(instr.opcode, instr.arg)
zrr1999 marked this conversation as resolved.
Show resolved Hide resolved
pop_n = push_n - stack_effect
for i, stack_arg in enumerate(self.stack._data):

for i, stack_arg in enumerate(self.stack):
# Avoid passing NULL as a parameter to the resume function
if (
isinstance(stack_arg, NullVariable)
and i < len(self.stack) - pop_n
):
self._graph.pycode_gen.gen_load_object(
NullVariable(), f'dummy_var{i}'
NullVariable(), f'dummy_var{i}', push_null=False
)
else:
var_loader.load(stack_arg)
self._graph.pycode_gen.add_pure_instructions([instr])

# gen call resume fn opcode
if sys.version_info >= (3, 11) and instr.opname == "CALL":
assert instr.arg is not None
self._graph.pycode_gen.gen_call_function(instr.arg)
else:
zrr1999 marked this conversation as resolved.
Show resolved Hide resolved
self._graph.pycode_gen.add_pure_instructions([instr])
self.stack.pop_n(pop_n)
stack_size = len(self.stack) + push_n

resume_fn, _ = self._create_resume_fn(index + 1, stack_size)
if resume_fn:
self._graph.pycode_gen.gen_load_object(
resume_fn, resume_fn.__code__.co_name
)
self._graph.pycode_gen.gen_rot_n(stack_size + 1)
if sys.version_info >= (3, 11):
zrr1999 marked this conversation as resolved.
Show resolved Hide resolved
self._graph.pycode_gen.gen_rot_n(stack_size + 2)
self._graph.pycode_gen.gen_rot_n(stack_size + 2)
else:
self._graph.pycode_gen.gen_rot_n(stack_size + 1)
for name in resume_input_name:
var_loader.load(self.get_var(name))
self._graph.pycode_gen.gen_call_function(
Expand Down Expand Up @@ -2041,7 +2055,7 @@ def _break_graph_in_for_loop(
]
ret_vars = [self.get_var(name) for name in ret_names]
store_vars = [ret_vars[idx] for idx in range(len(ret_names))]
store_vars.extend(iter(self.stack._data))
store_vars.extend(iter(self.stack))
var_loader = self._graph.start_compile_with_name_store(
ret_vars, store_vars
)
Expand Down Expand Up @@ -2106,7 +2120,7 @@ def _break_graph_in_for_loop(
after_loop_fn, after_loop_fn.__code__.co_name
)

for stack_arg in self.stack._data:
for stack_arg in self.stack:
var_loader.load(stack_arg)
for name in fn_inputs:
self._graph.pycode_gen.gen_load(name)
Expand Down
14 changes: 11 additions & 3 deletions sot/opcode_translator/executor/pycode_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,9 @@ def gen_pycode(self) -> types.CodeType:
)
return new_code

def gen_resume_fn_at(self, index: int, stack_size: int = 0):
def gen_resume_fn_at(
self, index: int, stack_size: int = 0
) -> tuple[None | types.FunctionType, OrderedSet]:
zrr1999 marked this conversation as resolved.
Show resolved Hide resolved
"""
Generates a resume function at the specified index in the instruction list.

Expand Down Expand Up @@ -662,7 +664,7 @@ def gen_load_global(self, name, push_null=False):
idx |= 1
self._add_instr("LOAD_GLOBAL", arg=idx, argval=name)

def gen_load_object(self, obj, obj_name: str):
def gen_load_object(self, obj, obj_name: str, push_null: bool = True):
"""
Generate the bytecode for loading an object.

Expand All @@ -673,7 +675,7 @@ def gen_load_object(self, obj, obj_name: str):

if obj_name not in self._f_globals:
self._f_globals[obj_name] = obj
self.gen_load_global(obj_name, push_null=True)
self.gen_load_global(obj_name, push_null=push_null)

def gen_load_fast(self, name):
"""
Expand Down Expand Up @@ -814,6 +816,12 @@ def rot_n_fn(n):
self._add_instr("CALL_FUNCTION_EX", arg=0)
self.gen_unpack_sequence(n)

def gen_swap(self, n):
if sys.version_info >= (3, 11):
self._add_instr("SWAP", arg=n)
else:
raise NotImplementedError("swap is not supported before python3.11")

def gen_return(self):
self._add_instr("RETURN_VALUE")

Expand Down
3 changes: 3 additions & 0 deletions sot/opcode_translator/executor/variable_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,9 @@ def top(self, value):
assert len(self) > 0, "stack is empty"
self.peek[1] = value

def __iter__(self):
return iter(self._data)

def __len__(self) -> int:
return len(self._data)

Expand Down
17 changes: 0 additions & 17 deletions tests/test_14_operators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import operator
import sys
import unittest

from test_case_base import TestCaseBase
Expand Down Expand Up @@ -273,10 +272,6 @@ def operator_pos(y: int):


class TestExecutor(TestCaseBase):
@unittest.skipIf(
sys.version_info >= (3, 11),
"Python 3.11+ breakbreak occurred in unary_not",
)
def test_simple(self):
a = paddle.to_tensor(1)
b = paddle.to_tensor(True)
Expand Down Expand Up @@ -319,9 +314,6 @@ def test_simple(self):
self.assert_results(inplace_or, b, g)
self.assert_results(inplace_xor, b, g)

@unittest.skipIf(
sys.version_info >= (3, 11), "Python 3.11+ breakbreak occurred in truth"
)
def test_operator_simple(self):
self.assert_results(operator_add, 1, paddle.to_tensor(2))
self.assert_results(operator_mul, 1, paddle.to_tensor(2))
Expand All @@ -338,9 +330,6 @@ def test_operator_simple(self):
self.assert_results(operator_not_in_, 12, [1, 2, 3])
self.assert_results(operator_not_in_, 12, [1, 2, 3])

@unittest.skipIf(
sys.version_info >= (3, 11), "Python 3.11+ not support breakgraph"
)
def test_operator_list(self):
self.assert_results(list_getitem, 1, paddle.to_tensor(2))
self.assert_results(list_getitem_slice, 1, paddle.to_tensor(2))
Expand All @@ -351,9 +340,6 @@ def test_operator_list(self):
self.assert_results(list_delitem_int, 1, paddle.to_tensor(2))
self.assert_results(list_delitem_tensor, 1, paddle.to_tensor(2))

@unittest.skipIf(
sys.version_info >= (3, 11), "Python 3.11+ not support breakgraph"
)
def test_operator_dict(self):
self.assert_results(dict_getitem_int, 1, paddle.to_tensor(2))
self.assert_results(dict_getitem_tensor, 1, paddle.to_tensor(2))
Expand All @@ -364,9 +350,6 @@ def test_operator_dict(self):
self.assert_results(dict_delitem_int, 1, paddle.to_tensor(2))
self.assert_results(dict_delitem_tensor, 1, paddle.to_tensor(2))

@unittest.skipIf(
sys.version_info >= (3, 11), "Python 3.11+ not support breakgraph"
)
def test_operator_tuple(self):
self.assert_results(tuple_getitem_int, 1, paddle.to_tensor(2))
self.assert_results(tuple_getitem_tensor, 1, paddle.to_tensor(2))
Expand Down