diff --git a/frontend/guard_tracker.py b/frontend/guard_tracker.py
index 2b35af67a27e..ebe1fdc7663e 100644
--- a/frontend/guard_tracker.py
+++ b/frontend/guard_tracker.py
@@ -286,7 +286,8 @@ def record_function(self,
         if func in (min, max):
             scalar = None
             node = None
-            assert len(pargs) == 2
+            # NOTE: when pargs < 2, it should be a dynamic operation
+            assert len(pargs) <= 2
             for i, obj in enumerate(pargs):
                 if isinstance(obj, (int, float)) and not dyn.contains(obj):
                     scalar = obj
@@ -1548,7 +1549,9 @@ def is_genexpr_func(self, func: Callable[..., Any]) -> bool:
 
     def is_builtin_func(self, func: Callable[..., Any]) -> bool:
         return func in (dict, tuple, set, list, hasattr, slice, range, len,
-                        type, all, str.join, reversed, zip, iter, id, next)
+                        type, all, str.join, reversed, zip, iter, id, next,
+                        collections.OrderedDict, str.format, any, str,
+                        str.split)
 
     def is_numpy_constant_func(self, func: Callable[..., Any]) -> bool:
         print(dir(func))
diff --git a/frontend/utils.py b/frontend/utils.py
index 476bad81cac4..6ed1d6124874 100644
--- a/frontend/utils.py
+++ b/frontend/utils.py
@@ -156,6 +156,10 @@ def get_root_module(func: Callable[..., Any]) -> str:
     if module is None or 'torch.distributions' in module_str:
         return ""
     root_module = module_str.split('.')[0]
+    #NOTE: special cases in torchvision module, need to check whether this module is safe to record in graph
+    if hasattr(func, '__name__') and func.__name__ in (
+            'pad', 'resize') and root_module == 'torchvision':
+        return 'torch'
     return root_module
 
 
diff --git a/frontend/variables/__init__.py b/frontend/variables/__init__.py
index d0f48804c463..03fbecdc4d86 100644
--- a/frontend/variables/__init__.py
+++ b/frontend/variables/__init__.py
@@ -54,6 +54,10 @@ def make_var_from_value(
         extract_code_at_start: Optional[list[StorePos]] = None) -> Variable:
     if extract_code_at_start is None:
         extract_code_at_start = []
+    if type(value) == np.ndarray and value.size == 1:
+        return NumpyScalarVar.from_value(np.int64(value.tolist()),
+                                         need_guard_check, helper_functions,
+                                         fx_graph, extract_code_at_start)
     if type(value) in ty2var:
         return ty2var[type(value)].from_value(value, need_guard_check,
                                               helper_functions, fx_graph,
diff --git a/frontend/variables/dict_.py b/frontend/variables/dict_.py
index 2c0b03c0ebaa..81d0c9ae81a8 100644
--- a/frontend/variables/dict_.py
+++ b/frontend/variables/dict_.py
@@ -76,7 +76,11 @@ def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos,
         items = []
         for key, j in zip(self.value.keys(), range(len(self.vars))):
             if isinstance(key, str):
-                key_part = f"'{key}'"
+                if "\n" not in key:
+                    key_part = f"'{key}'"
+                else:
+                    key_part = f"'{repr(key)}'"
+                    key_part = key_part.strip("'")
             else:
                 key_part = key
             item = f'{key_part}: {name_in_graph_fn}_{j}'
diff --git a/frontend/variables/list_.py b/frontend/variables/list_.py
index 9e7cc021881e..33907340843b 100644
--- a/frontend/variables/list_.py
+++ b/frontend/variables/list_.py
@@ -114,7 +114,7 @@ def __init__(self, value: np.ndarray[Any, Any], need_guard_check: bool,
                  extract_code_at_start: list[StorePos]) -> None:
         super().__init__(need_guard_check, value, extract_code_at_start)
         self.value = value
-        self.length = len(value)
+        self.length = value.size
         self.vars = []
         self.obj_ids = []
         for i, obj in enumerate(value):
diff --git a/frontend/variables/tensor.py b/frontend/variables/tensor.py
index 4915636d3a1c..72c8da6253d8 100644
--- a/frontend/variables/tensor.py
+++ b/frontend/variables/tensor.py
@@ -238,8 +238,8 @@ def make_guard_inner(self, codegen: "GuardFnCodegen",
     def make_output_inner(self, name_in_graph_fn: str, store_pos: StorePos,
                           codegen: "GraphFnCodegen", in_return: bool,
                           idx: int) -> None:
-        codegen.output(name_in_graph_fn, store_pos, f"{self.device}", in_return,
-                       idx)
+        codegen.output(name_in_graph_fn, store_pos, f"'{self.device}'",
+                       in_return, idx)
 
     def as_fx_node(self) -> "NodeArgs":
         return self.device
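A minimal sketch (not part of the patch) of why the quoting changes in dict_.py and tensor.py matter: both make_output_inner paths emit Python source text, so string-valued constants must be written as valid literals. The helper names emit_dict_literal and emit_device below are hypothetical stand-ins for the codegen paths touched by this diff, not functions from the repository.

import torch


def emit_dict_literal(d: dict) -> str:
    # Build Python source for a dict whose values are graph variables
    # (mirrors the key handling added in dict_.py above).
    items = []
    for j, key in enumerate(d.keys()):
        if isinstance(key, str):
            if "\n" not in key:
                key_part = f"'{key}'"  # ordinary keys can be quoted directly
            else:
                key_part = repr(key)   # repr() escapes the newline: 'a\nb'
        else:
            key_part = key             # non-str keys (e.g. int) are emitted as-is
        items.append(f"{key_part}: var_{j}")
    return "{" + ", ".join(items) + "}"


def emit_device(device: torch.device) -> str:
    # f"{device}" would emit the bare token cuda:0, which is not valid
    # Python source; quoting it emits the string literal 'cuda:0' instead,
    # as in the make_output_inner fix in tensor.py above.
    return f"'{device}'"


if __name__ == "__main__":
    print(emit_dict_literal({"x": 1, "a\nb": 2, 3: 4}))
    # {'x': var_0, 'a\nb': var_1, 3: var_2}
    print(emit_device(torch.device("cpu")))
    # 'cpu'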