Skip to content

Commit

Permalink
[BE] Change _marked_safe_globals_list to set
Browse files Browse the repository at this point in the history
ghstack-source-id: fb6183a0733be1aa68c6a663609580f93235c7f4
Pull Request resolved: #139303
  • Loading branch information
mikaylagawarecki committed Oct 30, 2024
1 parent 483e52a commit 9e0a9e2
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions torch/_weights_only_unpickler.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,29 +83,26 @@
"nt",
]

_marked_safe_globals_list: List[Any] = []
_marked_safe_globals_set: Set[Any] = set()


def _add_safe_globals(safe_globals: List[Any]):
global _marked_safe_globals_list
_marked_safe_globals_list += safe_globals

global _marked_safe_globals_set
_marked_safe_globals_set = _marked_safe_globals_set.union(set(safe_globals))

def _get_safe_globals() -> List[Any]:
global _marked_safe_globals_list
return _marked_safe_globals_list
global _marked_safe_globals_set
return list(_marked_safe_globals_set)


def _clear_safe_globals():
global _marked_safe_globals_list
_marked_safe_globals_list = []
global _marked_safe_globals_set
_marked_safe_globals_set = set()


def _remove_safe_globals(globals_to_remove: List[Any]):
global _marked_safe_globals_list
_marked_safe_globals_list = list(
set(_marked_safe_globals_list) - set(globals_to_remove)
)
global _marked_safe_globals_set
_marked_safe_globals_set = _marked_safe_globals_set - set(globals_to_remove)


class _safe_globals:
Expand All @@ -128,7 +125,7 @@ def __exit__(self, type, value, tb):
# _get_allowed_globals due to the lru_cache
def _get_user_allowed_globals():
rc: Dict[str, Any] = {}
for f in _marked_safe_globals_list:
for f in _marked_safe_globals_set:
module, name = f.__module__, f.__name__
rc[f"{module}.{name}"] = f
return rc
Expand Down

0 comments on commit 9e0a9e2

Please sign in to comment.