Skip to content

Commit

Permalink
def is bnb multi backend avaliable
Browse files Browse the repository at this point in the history
  • Loading branch information
jiqing-feng committed Sep 26, 2024
1 parent 63be237 commit 80e3293
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
9 changes: 2 additions & 7 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
get_mixed_precision_context_manager,
get_pretty_name,
is_bf16_available,
is_bitsandbytes_multi_backend_available,
is_deepspeed_available,
is_ipex_available,
is_lomo_available,
Expand Down Expand Up @@ -1443,13 +1444,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
"you're training on. Make sure you loaded the model on the correct device using for example `device_map={'':torch.cuda.current_device()}` or `device_map={'':torch.xpu.current_device()}`"
)

bnb_multi_backends = False
try:
from transformers.utils import is_bitsandbytes_multi_backend_available
bnb_multi_backends = is_bitsandbytes_multi_backend_available()
except ImportError:
bnb_multi_backends = False
if ("cpu" in model_devices and not bnb_multi_backends) or "disk" in model_devices:
if ("cpu" in model_devices and not is_bitsandbytes_multi_backend_available()) or "disk" in model_devices:
raise ValueError(
"You can't train a model that has been loaded in 8-bit precision with CPU or disk offload. "
"If you want train the 8-bit model in CPU, please install bitsandbytes with multi-backend, see https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend"
Expand Down
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
is_8bit_bnb_available,
is_aim_available,
is_bf16_available,
is_bitsandbytes_multi_backend_available,
is_bnb_available,
is_boto3_available,
is_ccl_available,
Expand Down
8 changes: 8 additions & 0 deletions src/accelerate/utils/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,14 @@ def is_bnb_available():
return _is_package_available("bitsandbytes")


def is_bitsandbytes_multi_backend_available():
if not is_bnb_available():
return False
import bitsandbytes as bnb

return "multi_backend" in getattr(bnb, "features", set())


def is_torchvision_available():
return _is_package_available("torchvision")

Expand Down

0 comments on commit 80e3293

Please sign in to comment.