Skip to content

Commit

Permalink
Fix slowdown on init with device_map="auto" (#2914)
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerzr authored Jul 4, 2024
1 parent 167cb5e commit 2471eac
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions src/accelerate/utils/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,13 @@
from .imports import is_mlu_available, is_mps_available, is_npu_available, is_xpu_available


def clear_device_cache():
gc.collect()
def clear_device_cache(garbage_collection=False):
"""
Clears the device cache by calling `torch.{backend}.empty_cache`. Can also run `gc.collect()`, but do note that
this is a *considerable* slowdown and should be used sparingly.
"""
if garbage_collection:
gc.collect()

if is_xpu_available():
torch.xpu.empty_cache()
Expand Down Expand Up @@ -67,7 +72,7 @@ def release_memory(*objects):
objects = list(objects)
for i in range(len(objects)):
objects[i] = None
clear_device_cache()
clear_device_cache(garbage_collection=True)
return objects


Expand Down Expand Up @@ -123,7 +128,7 @@ def find_executable_batch_size(function: callable = None, starting_batch_size: i

def decorator(*args, **kwargs):
nonlocal batch_size
clear_device_cache()
clear_device_cache(garbage_collection=True)
params = list(inspect.signature(function).parameters.keys())
# Guard against user error
if len(params) < (len(args) + 1):
Expand All @@ -139,7 +144,7 @@ def decorator(*args, **kwargs):
return function(batch_size, *args, **kwargs)
except Exception as e:
if should_reduce_batch_size(e):
clear_device_cache()
clear_device_cache(garbage_collection=True)
batch_size //= 2
else:
raise
Expand Down

0 comments on commit 2471eac

Please sign in to comment.