Skip to content

Commit

Permalink
remove ref
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
  • Loading branch information
wyli committed Dec 22, 2021
1 parent 7658764 commit 5f44fd1
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,7 @@ def _wrapper(*args, **kwargs):
for p in processes:
p.join()
assert results.get(), "Distributed call failed."
_del_original_func(obj)

return _wrapper

Expand Down Expand Up @@ -521,6 +522,7 @@ def _wrapper(*args, **kwargs):
finally:
p.join()

_del_original_func(obj)
res = None
try:
res = results.get(block=False)
Expand All @@ -546,6 +548,15 @@ def _cache_original_func(obj) -> None:
_original_funcs[obj.__name__] = obj


def _del_original_func(obj):
"""pop the original function from cache."""
global _original_funcs
_original_funcs.pop(obj.__name__, None)
if torch.cuda.is_available(): # clean up the cached function
torch.cuda.synchronize()
torch.cuda.empty_cache()


def _call_original_func(name, module, *args, **kwargs):
if name not in _original_funcs:
_original_module = importlib.import_module(module) # reimport, refresh _original_funcs
Expand Down

0 comments on commit 5f44fd1

Please sign in to comment.