Update code for evaluation (apache#42)
heheda12345 authored May 30, 2024
2 parents 45dd313 + af1072a commit a6418be
Showing 6 changed files with 89 additions and 25 deletions.
12 changes: 11 additions & 1 deletion frontend/fx_graph.py
@@ -102,7 +102,17 @@ def eager_due_to_inductor_bug(node: torch.fx.Node) -> bool:
else:
random_number = str(random.randint(0, 1000000))
folder_name = f'tmp/fx_module_{random_number}'

for node in gm.graph.nodes:
# to avoid error like
# interpolate(Tensor input, int? size=None, float[]? scale_factor=None, str mode="nearest", bool? align_corners=None, bool? recompute_scale_factor=None, bool antialias=False) -> Tensor:
# Expected a value of type 'Optional[List[float]]' for argument 'scale_factor' but instead found type 'int'.
if node.target == torch.nn.functional.interpolate and 'scale_factor' in node.kwargs:
new_dict = {k: v for k, v in node.kwargs.items()}
new_dict['scale_factor'] = float(new_dict['scale_factor'])
node.kwargs = new_dict
print(node.kwargs)

gm.recompile()
os.makedirs(folder_name, exist_ok=True)
gm.to_folder(folder_name)

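As a side note, a minimal sketch (hypothetical tensor shapes, not taken from the repo) of the mismatch described in the comment above and the float coercion this hunk applies:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
# Eager mode happily accepts an int scale_factor...
y_int = F.interpolate(x, scale_factor=2, mode="nearest")
# ...but the TorchScript schema quoted in the comment expects
# Optional[List[float]], so the rewrite coerces the kwarg to float before
# the module is dumped with gm.to_folder():
y_float = F.interpolate(x, scale_factor=2.0, mode="nearest")
assert y_int.shape == y_float.shape == (1, 3, 16, 16)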
49 changes: 28 additions & 21 deletions frontend/guard_tracker.py
@@ -26,7 +26,7 @@
from .object_table import ObjectTable
from .pycode_writer import new_name
from .pycode_generator import GraphFnCodegen, GuardFnCodegen
from .fx_graph import FxGraph, get_frame_root, is_leaf_module, NodeArgs
from .fx_graph import FxGraph, get_frame_root, is_leaf_module, NodeArgs, BaseArgumentTypes
from .bytecode_analysis import livevars_analysis, end_of_control_flow
from .variables.const import ClsByNamedTupleVar
from .variables.base import Variable
@@ -1458,7 +1458,6 @@ def make_sub_var(value: Any, fx_node: torch.fx.Node) -> None:
self.state.fx_graph,
partial.extract_code_at_start)
else:

var = make_var_fn(value, partial.need_guard_check,
self.state.objects.helper_functions,
self.state.fx_graph,
@@ -1593,10 +1592,12 @@ def is_builtin_func(self, func: Callable[..., Any]) -> bool:
collections.OrderedDict, str.format, any, str,
str.split, sorted)

def is_numpy_constant_func(self, func: Callable[..., Any]) -> bool:
# print(dir(func))
if (hasattr(func, '__module__') and 'numpy' in func.__module__ and
'random' not in func.__module__):
def is_numpy_func(self, func: Callable[..., Any]) -> bool:
if get_root_module(func) == 'numpy':
return True
if hasattr(
func, '__module__'
) and func.__module__ is not None and 'numpy' in func.__module__:
return True
if type(func) == np.ufunc:
return True
@@ -1623,13 +1624,24 @@ def call_function(
if func == operator.is_ and args[1] is None: # is_none check
return
if func == enumerate:
assert len(args) == 1
assert len(kwargs) == 0
var = self.state.objects.get_or_none(args[0])
assert var is not None
vars = [
self.state.objects.get(a, allow_unexist_const=True)
for a in args
]
assert all(v is not None for v in vars)
poss: list[list[StorePos]] = []
for a, var in zip(args, vars):
if len(var.extract_code_at_start) > 0:
poss.append(var.extract_code_at_start)
elif isinstance(a, (int, float)):
poss.append([StoreConstant(a, id(a))])
pos_product: list[list[StorePos]] = list(
itertools.product(*poss)) # type: ignore
arg_ids = [id(a) for a in args]
new_store_pos: list[StorePos] = [
ExtractFromFunction([pos], [id(args[0])], func.__name__, func)
for pos in var.extract_code_at_start
ExtractFromFunction(p, arg_ids, func.__name__, func)
for p in pos_product
]
self.state.set_partial_var({
-1: [
@@ -1827,7 +1839,7 @@ def set_if_inplace_return() -> None:
"is_tracing", "is_scripting", "get_autocast_gpu_dtype",
"is_autocast_enabled", "ndimension", "get_enum",
"is_tensor", "is_complex", "is_contiguous", "stride",
"get_device"):
"get_device", "Size", "_output_padding"):
return
if hasattr(func, "__module__"
) and func.__module__ == 'torch.autograd.profiler':
@@ -1859,13 +1871,10 @@ def set_if_inplace_return() -> None:
]
})
return
elif get_root_module(func) == 'numpy' or has_ndarray_flag:
print("record numpy function in graph", func)
# self.state.record_function(func,
# args,
# kwargs,
# inplace_ref=inplace_ref,
# force_new_value=False)
elif self.is_numpy_func(func) or has_ndarray_flag:
if hasattr(func, '__self__') and isinstance(func.__self__,
np.random.RandomState):
raise ValueError("numpy random function")
self.state.set_partial_var({
-1: [
PartialVar(node=None,
@@ -1933,8 +1942,6 @@ def set_if_inplace_return() -> None:
return
elif len(args) > 0 and isinstance(args[0], torch.nn.ModuleList):
return
elif self.is_numpy_constant_func(func):
return
elif self.has_unknown_arg(args, kwargs):
print(
f"func is {func}, {is_user_defined_func(func)}, args: {args}, kwargs:{kwargs}"
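The enumerate handling added above covers every combination of the arguments' candidate store positions via itertools.product. A rough illustration of that step, with plain strings standing in for StorePos objects (all names here are illustrative only):

import itertools

# one inner list of candidate positions per argument of enumerate(...)
poss = [["locals['x']", "globals()['x']"],  # where the iterable might be found
        ["const(2)"]]                       # the constant start value
pos_product = list(itertools.product(*poss))
# -> [("locals['x']", "const(2)"), ("globals()['x']", "const(2)")]
assert len(pos_product) == len(poss[0]) * len(poss[1])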
2 changes: 2 additions & 0 deletions frontend/tracer.py
@@ -43,6 +43,8 @@ def trace_func(frame: FrameType, event: str, arg: Any) -> None:
except Exception as e:
print("exception in trace_func:", e, type(e))
print(traceback.format_exc())
print("code stack:")
traceback.print_stack(f=frame, file=sys.stdout)
if get_config("enable_fallback"):
run_trace_func = False
for i in trackers:
11 changes: 10 additions & 1 deletion frontend/utils.py
@@ -149,6 +149,10 @@ def get_root_module(func: Callable[..., Any]) -> str:
if hasattr(func, '__class__') and func.__class__ == np.ufunc:
return 'numpy'

if hasattr(func, '__self__') and isinstance(func.__self__,
np.random.RandomState):
return 'numpy'

module = inspect.getmodule(func)
module_str = ""
if module is not None:
@@ -197,6 +201,8 @@ def is_user_defined_func(func: Callable[..., Any]) -> bool:
if hasattr(func, '__self__'):
if isinstance(func.__self__, (torch.Tensor, random.Random)):
return False
elif isinstance(func.__self__, numpy.random.RandomState):
return False
elif isinstance(func.__self__, (list, tuple, set, dict, str)):
return False
elif isinstance(func.__self__, torch.nn.Sequential):
@@ -223,6 +229,7 @@ def is_user_defined_func(func: Callable[..., Any]) -> bool:
return False

root_module = get_root_module(func)
print("root module", func, "===is==", root_module, type(root_module))
if root_module == 'torch' and hasattr(
func, '__name__') and func.__name__ == '_call_impl':
return True
@@ -447,7 +454,9 @@ def call_user_defined_iterator(x: Any) -> bool:
return len(args) >= 1 and call_user_defined_iterator(args[0])
elif func == tuple:
return len(args) >= 1 and call_user_defined_iterator(
args[0]) and not isinstance(args[0], Generator)
args[0]) and not isinstance(
args[0],
Generator) # generator contains yield, which is not supported yet
elif func == iter:
return len(args) >= 1 and is_user_defined_iter(args[0])
elif func == enumerate:
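Two of the checks added to utils.py rely on runtime introspection; a small self-contained sketch of both, using only standard numpy and Python behaviour (the examples are assumptions for illustration, not code from the tracer):

import numpy as np
from collections.abc import Generator

# np.random.* functions are bound methods of the global RandomState, so they
# can be recognised through __self__ and routed to the numpy handling:
assert isinstance(np.random.rand.__self__, np.random.RandomState)

# generators contain `yield`, which tracing does not support yet, so tuple()
# over a generator is excluded while ordinary iterators remain allowed:
def gen():
    yield 1

assert isinstance(gen(), Generator)
assert not isinstance(iter([1, 2]), Generator)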
13 changes: 13 additions & 0 deletions test/test_builtins.py
@@ -13,6 +13,13 @@ def run_enumerate(x):
return s, enumerate(x)


def run_enumerate2(x):
s = 0
for i, v in enumerate(x, 2):
s += i * v
return s


def test_enumerate(caplog):
reset()
compiled_run_enumerate = compile(run_enumerate)
@@ -22,3 +29,9 @@ def test_enumerate(caplog):
expect_result = run_enumerate([1, 2, 3, 4, 5])
run_and_check(compiled_run_enumerate, [HIT], 1, caplog, expect_result,
[1, 2, 3, 4, 5])
compiled_run_enumerate2 = compile(run_enumerate2)
expect_result2 = run_enumerate2([1, 2, 3, 4, 5])
run_and_check(compiled_run_enumerate2, [MISS], 2, caplog, expect_result2,
[1, 2, 3, 4, 5])
run_and_check(compiled_run_enumerate2, [HIT], 2, caplog, expect_result2,
[1, 2, 3, 4, 5])
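For reference, a plain-Python check of what the new start-offset test computes (no tracing involved; the value follows directly from enumerate's second argument):

x = [1, 2, 3, 4, 5]
s = sum(i * v for i, v in enumerate(x, 2))  # indices start at 2
assert s == 2*1 + 3*2 + 4*3 + 5*4 + 6*5 == 70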
27 changes: 25 additions & 2 deletions test/test_model_deberta.py
@@ -4,6 +4,7 @@
from common.checker import assert_equal, run_and_check_cache, run_and_check, HIT, MISS, ALL_MISS
from transformers import AutoTokenizer, AutoConfig, AutoModel
import torch
import os

# dynamo compile error in torch 2.0.1: torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised ValueError: Cannot view a tensor with shape torch.Size([1, 512, 12, 64]) and strides (393216, 64, 32768, 1) as a tensor with shape (1, 512, 768)!
# 10.4: don't know how to reproduce the bug
@@ -1160,13 +1161,12 @@ def get_model():
return model


def get_input(batch_size):
def get_input(batch_size, seq_len=256):
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# inputs = tokenizer("Hello world! Hello world! Hello world! Hello world! Hello world!", return_tensors="pt").to(device)
# assert len(inputs) == 3
# return (inputs['input_ids'], inputs['attention_mask'], inputs['token_type_ids']), {}
vocab_size = 50265
seq_len = 256
input_ids = torch.randint(0,
vocab_size, (batch_size, seq_len),
dtype=torch.int64).to(device)
@@ -1209,3 +1209,26 @@ def test_model_deberta_dyn(caplog):
**input_kwargs1)
run_and_check(compiled, [HIT], 1, caplog, expect2, *input_args2,
**input_kwargs2)


# need this command before the unit test:
# sed -i 's/py_all(a\.shape\[i\] for i in dims)/py_all(a.shape[i] > 0 for i in dims)/g' `python3 -c 'import torch; print(torch.__path__[0])'`/_refs/__init__.py
@pytest.mark.skipif(os.getenv('FORCE_RUN_SKIPPED_TEST') != '1',
reason="will affect other tests, run it solo")
@pytest.mark.model
def test_model_deberta_dynlen(caplog):
reset()
with enable_dyn_shape():
with torch.no_grad():
model = get_model().eval()
input_args1, input_kwargs1 = get_input(batch_size=2, seq_len=40)
input_args2, input_kwargs2 = get_input(batch_size=2, seq_len=48)
expect1 = model(*input_args1, **input_kwargs1)
expect2 = model(*input_args2, **input_kwargs2)
compiled = compile(model)
run_and_check(compiled, [ALL_MISS], 1, caplog, expect1,
*input_args1, **input_kwargs1)
run_and_check(compiled, [HIT], 1, caplog, expect1, *input_args1,
**input_kwargs1)
run_and_check(compiled, [HIT], 1, caplog, expect2, *input_args2,
**input_kwargs2)
