Skip to content

Commit

Permalink
Parity 1 (apache#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
superDong1998 authored Mar 20, 2024
2 parents 62fdc54 + a8ed0ca commit a72d6f8
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 6 deletions.
7 changes: 5 additions & 2 deletions frontend/guard_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,8 @@ def record_function(self,
if func in (min, max):
scalar = None
node = None
assert len(pargs) == 2
# NOTE: when pargs < 2, it should be a dynamic operation
assert len(pargs) <= 2
for i, obj in enumerate(pargs):
if isinstance(obj, (int, float)) and not dyn.contains(obj):
scalar = obj
Expand Down Expand Up @@ -1548,7 +1549,9 @@ def is_genexpr_func(self, func: Callable[..., Any]) -> bool:

def is_builtin_func(self, func: Callable[..., Any]) -> bool:
return func in (dict, tuple, set, list, hasattr, slice, range, len,
type, all, str.join, reversed, zip, iter, id, next)
type, all, str.join, reversed, zip, iter, id, next,
collections.OrderedDict, str.format, any, str,
str.split)

def is_numpy_constant_func(self, func: Callable[..., Any]) -> bool:
print(dir(func))
Expand Down
4 changes: 4 additions & 0 deletions frontend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ def get_root_module(func: Callable[..., Any]) -> str:
if module is None or 'torch.distributions' in module_str:
return ""
root_module = module_str.split('.')[0]
#NOTE: special cases in torchvision module, need to check whether this module is safe to record in graph
if hasattr(func, '__name__') and func.__name__ in (
'pad', 'resize') and root_module == 'torchvision':
return 'torch'
return root_module


Expand Down
4 changes: 4 additions & 0 deletions frontend/variables/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ def make_var_from_value(
extract_code_at_start: Optional[list[StorePos]] = None) -> Variable:
if extract_code_at_start is None:
extract_code_at_start = []
if type(value) == np.ndarray and value.size == 1:
return NumpyScalarVar.from_value(np.int64(value.tolist()),
need_guard_check, helper_functions,
fx_graph, extract_code_at_start)
if type(value) in ty2var:
return ty2var[type(value)].from_value(value, need_guard_check,
helper_functions, fx_graph,
Expand Down
6 changes: 5 additions & 1 deletion frontend/variables/dict_.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,11 @@ def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos,
items = []
for key, j in zip(self.value.keys(), range(len(self.vars))):
if isinstance(key, str):
key_part = f"'{key}'"
if "\n" not in key:
key_part = f"'{key}'"
else:
key_part = f"'{repr(key)}'"
key_part = key_part.strip("'")
else:
key_part = key
item = f'{key_part}: {name_in_graph_fn}_{j}'
Expand Down
2 changes: 1 addition & 1 deletion frontend/variables/list_.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __init__(self, value: np.ndarray[Any, Any], need_guard_check: bool,
extract_code_at_start: list[StorePos]) -> None:
super().__init__(need_guard_check, value, extract_code_at_start)
self.value = value
self.length = len(value)
self.length = value.size
self.vars = []
self.obj_ids = []
for i, obj in enumerate(value):
Expand Down
4 changes: 2 additions & 2 deletions frontend/variables/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,8 @@ def make_guard_inner(self, codegen: "GuardFnCodegen",
def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos,
codegen: "GraphFnCodegen", in_return: bool,
idx: int) -> None:
codegen.output(name_in_graph_fn, store_pos, f"{self.device}", in_return,
idx)
codegen.output(name_in_graph_fn, store_pos, f"'{self.device}'",
in_return, idx)

def as_fx_node(self) -> "NodeArgs":
return self.device
Expand Down

0 comments on commit a72d6f8

Please sign in to comment.