Skip to content

Commit

Permalink
fix str and ndarray size
Browse files Browse the repository at this point in the history
  • Loading branch information
superDong1998 committed Mar 19, 2024
1 parent b41e1cb commit a8ed0ca
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 2 deletions.
3 changes: 2 additions & 1 deletion frontend/guard_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1550,7 +1550,8 @@ def is_genexpr_func(self, func: Callable[..., Any]) -> bool:
def is_builtin_func(self, func: Callable[..., Any]) -> bool:
return func in (dict, tuple, set, list, hasattr, slice, range, len,
type, all, str.join, reversed, zip, iter, id, next,
collections.OrderedDict, str.format, any)
collections.OrderedDict, str.format, any, str,
str.split)

def is_numpy_constant_func(self, func: Callable[..., Any]) -> bool:
print(dir(func))
Expand Down
4 changes: 4 additions & 0 deletions frontend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ def get_root_module(func: Callable[..., Any]) -> str:
if module is None or 'torch.distributions' in module_str:
return ""
root_module = module_str.split('.')[0]
#NOTE: special cases in torchvision module, need to check whether this module is safe to record in graph
if hasattr(func, '__name__') and func.__name__ in (
'pad', 'resize') and root_module == 'torchvision':
return 'torch'
return root_module


Expand Down
4 changes: 4 additions & 0 deletions frontend/variables/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ def make_var_from_value(
extract_code_at_start: Optional[list[StorePos]] = None) -> Variable:
if extract_code_at_start is None:
extract_code_at_start = []
if type(value) == np.ndarray and value.size == 1:
return NumpyScalarVar.from_value(np.int64(value.tolist()),
need_guard_check, helper_functions,
fx_graph, extract_code_at_start)
if type(value) in ty2var:
return ty2var[type(value)].from_value(value, need_guard_check,
helper_functions, fx_graph,
Expand Down
2 changes: 1 addition & 1 deletion frontend/variables/list_.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __init__(self, value: np.ndarray[Any, Any], need_guard_check: bool,
extract_code_at_start: list[StorePos]) -> None:
super().__init__(need_guard_check, value, extract_code_at_start)
self.value = value
self.length = len(value)
self.length = value.size
self.vars = []
self.obj_ids = []
for i, obj in enumerate(value):
Expand Down

0 comments on commit a8ed0ca

Please sign in to comment.