Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable common device abstraction for 8bits/4bits #898

Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
314d5e0
init device abstraction
jianan-gu Nov 10, 2023
2e9550a
refinement
jianan-gu Nov 28, 2023
68fd024
device stepup
jianan-gu Nov 30, 2023
ba92680
Merge branch 'main' into enable_device_abstract
jianan-gu Nov 30, 2023
65b17a2
Update modules.py
jianan-gu Nov 30, 2023
c5044e0
Update bitsandbytes/functional.py
jianan-gu Dec 1, 2023
b23789a
add backends
jianan-gu Dec 3, 2023
b2a4d54
add quant to device when init weight paam
jianan-gu Dec 4, 2023
c44cf06
minor fix
jianan-gu Dec 4, 2023
365491a
mv cuda to common backends
jianan-gu Dec 4, 2023
4050fe3
format fix
jianan-gu Dec 4, 2023
30175d1
format fix
jianan-gu Dec 4, 2023
e17549e
use device.type
jianan-gu Dec 4, 2023
a53bc31
minor fix
jianan-gu Dec 4, 2023
80c598c
backend refinement
jianan-gu Dec 4, 2023
59facc8
minor fix
jianan-gu Dec 5, 2023
066d0dc
final refinement
jianan-gu Dec 5, 2023
e34c30e
Merge remote-tracking branch 'main/main' into upstream_device_abstrac…
jianan-gu Feb 5, 2024
cebd83c
refine backend register with base-backend
jianan-gu Feb 6, 2024
e0f2e18
Merge remote-tracking branch 'main/main' into upstream_device_abstrac…
jianan-gu Feb 6, 2024
d20c017
minor clean format
jianan-gu Feb 6, 2024
9f23308
Merge remote-tracking branch 'main/main' into upstream_device_abstrac…
jianan-gu Feb 7, 2024
b41c1c4
format in CI
jianan-gu Feb 7, 2024
1ab611e
minor fix for format
jianan-gu Feb 7, 2024
b933f9f
refactor base backend registering
jianan-gu Feb 7, 2024
8b4baaa
refine structures of backends
jianan-gu Feb 7, 2024
0905ad7
fix import issue
jianan-gu Feb 8, 2024
145a835
minor clean
jianan-gu Feb 8, 2024
d270832
fix CI python format
jianan-gu Feb 13, 2024
68e7859
fix py38 vers incompatibility from other PR
Titus-von-Koeller Feb 15, 2024
012b565
update pre-commit
Titus-von-Koeller Feb 16, 2024
8fa27f6
cuda.py: harmonize whitespace
Titus-von-Koeller Feb 16, 2024
2c04d48
delete dead code
Titus-von-Koeller Feb 16, 2024
c184655
fix whitespace
Titus-von-Koeller Feb 16, 2024
03b53d7
fix typo
Titus-von-Koeller Feb 16, 2024
ba7a162
remove exstraneous import
Titus-von-Koeller Feb 16, 2024
d162998
factor out ensure_backend_is_available, exc instead of assert
Titus-von-Koeller Feb 17, 2024
2cd9718
Remove minor device filter to avoid confusion
jianan-gu Feb 21, 2024
f26a4e6
Merge remote-tracking branch 'tim/multi-backend-refactor' into upstre…
jianan-gu Mar 28, 2024
adfb5e2
clean up device setup
jianan-gu Mar 28, 2024
6f08879
clean
jianan-gu Mar 28, 2024
a9e4548
fix utils
jianan-gu Mar 28, 2024
84f67d2
link QuantState in F.
jianan-gu Mar 28, 2024
9ff6c63
pre-commit run --all-files
Titus-von-Koeller Apr 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions bitsandbytes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import cuda_setup, research, utils
from . import device_setup, research, utils
from .autograd._functions import (
MatmulLtState,
bmm_cublas,
Expand All @@ -17,7 +17,9 @@

if COMPILED_WITH_CUDA:
from .optim import adam

from .backends import register_backend
from .backends.cuda import CUDABackend
register_backend("cuda", CUDABackend())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this cause any problems? What if we have a backend that, upon initialization, makes assumptions about the hardware/system? I think this can work if the backend does not have any state. However, is it a realistic assumption if we think about other backends?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AIUI, this is only really here to keep things working at present and we could think about deferred initialization later.
Even in the preimage of this PR, bnb initializes a backend (the native library) at import time.

__pdoc__ = {
"libbitsandbytes": False,
"optim.optimizer.Optimizer8bit": False,
Expand Down
3 changes: 1 addition & 2 deletions bitsandbytes/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,8 @@ def print_debug_info() -> None:

def main():
generate_bug_report_information()

from . import COMPILED_WITH_CUDA
from .cuda_setup.main import get_compute_capabilities
from .device_setup.cuda.main import get_compute_capabilities

print_header("OTHER")
print(f"COMPILED_WITH_CUDA = {COMPILED_WITH_CUDA}")
Expand Down
6 changes: 5 additions & 1 deletion bitsandbytes/autograd/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,10 @@ def supports_igemmlt(device: torch.device) -> bool:
nvidia16_models = ('GTX 1630', 'GTX 1650', 'GTX 1660') # https://en.wikipedia.org/wiki/GeForce_16_series
if any(model_name in device_name for model_name in nvidia16_models):
return False # these devices are technically cuda 7.5-capable, but they lack tensor cores
if device.type == "cpu":
jianan-gu marked this conversation as resolved.
Show resolved Hide resolved
#TODO: will return True once CPU backend upstream the supports
return False

return True


Expand Down Expand Up @@ -564,7 +568,7 @@ def matmul(

def matmul_4bit(A: torch.Tensor, B: torch.Tensor, quant_state: F.QuantState, out: Optional[torch.Tensor] = None, bias=None):
assert quant_state is not None
if A.numel() == A.shape[-1] and A.requires_grad == False:
if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type == "cuda":
jianan-gu marked this conversation as resolved.
Show resolved Hide resolved
if A.shape[-1] % quant_state.blocksize != 0:
warn(f'Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}')
return MatMul4Bit.apply(A, B, out, bias, quant_state)
Expand Down
9 changes: 9 additions & 0 deletions bitsandbytes/backends/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import Dict
import torch

from bitsandbytes.backends.base import Backend

backends: Dict[str, Backend] = {}
Copy link
Collaborator

@Titus-von-Koeller Titus-von-Koeller Feb 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if backends should be private?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why should it be?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it intend to avoid the usage of an import of the backends from the bitsandbytes package (and we may make it _backends) ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does a user currently select a backend? Currently, only CUDA is supported, but should there be a function like backend.set_backend("cuda")?

I guess we would do this through the "device_setup" process. The question here is if we can automatically detect the device the user is running in all cases? I think the only exception is probably if a user has both an accelerated device and a CPU. I think having, for example, Apple silicon and a regular GPU will not really happen. Are there any other scenarios that we are missing here and we need to think about?

I think for now, it looks fine, but I want to make sure we are not missing anything. In terms of usability, the best designs often come from early thought rather than later corrections. So it makes sense we think a bit about this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@TimDettmers To your point, I wonder if IPEX can be combined with CUDA/ROCm in such a way where as you mention, it's not clear what the user will want.

E.g. a situation where both torch.xpu.is_available() == true and torch.cuda.is_available() == true.

It's also my understanding that Intel GPU support may be upstreamed: pytorch/pytorch#114842

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jgong5 That's what's happening now in this PR 😁

return backends[A.device.type].dequantize_4bit(...)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey all, so Tim thinks that the backend should only be initialized once and therefore implemented as Singleton: It
shouldn't be initialized multiple times through a dictionary.

According to him, there's no use-case to exchange the backend at runtime. The only potential use-case might be that of having both a CPU and GPU backend at the same time, but from what Tim says, this is sth that we currently don't need yet and shouldn't worry about.

Just forwarding his statement for the sake of furthering the discussion.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From an engineering standpoint, I disagree with implementing it as a singleton (a class you can ever only initialize once). Doing that is more complex, a little non-Pythonic, and the current implementation has the same end result: there's a backend object that's only created once, and it's plugged into place in the backends dict.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From an engineering standpoint, I disagree with implementing it as a singleton (a class you can ever only initialize once). Doing that is more complex, a little non-Pythonic, and the current implementation has the same end result: there's a backend object that's only created once, and it's plugged into place in the backends dict.

I feel the same - no obvious benefit of constraining us with a single device. May I know what's the concern with dispatching device backend from the backend dict with the device on the tensor args? Dispatching according to the tensor's device type is something PyTorch ATen is also doing.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it not make sense to try to stay with the device of the source/destination tensors rather than select and initialize the device once as a singleton?

If you have multiple GPUs for example and want to share the compute with them, wouldn't you want to do .to(some_device) then call BnB?


def register_backend(backend_name: str, backend_instance: Backend):
backends[backend_name.lower()] = backend_instance
133 changes: 133 additions & 0 deletions bitsandbytes/backends/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
from abc import ABC, abstractmethod
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, the backend base contains too few functions or too many functions depending on the view. Currently, it does not provide abstractions for blockwise quantization, QLoRA-style double quantization, and 8-bit optimizers.

On the other hand, transform is a CUDA-specific function that probably does not need to be implemented by any other device.

This is definitely one thing that we need to discuss: what exact function do we abstract. We need to abstract everything that is needed by all devices and keep everything that is specific to CUDA in that particular backend.

from typing import Optional, Tuple

import torch

from bitsandbytes.utils import QuantState


class Backend(ABC):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cannot use an abstract base class here. This makes the interface too big to implement. We want people to be able to contribute sub-interfaces, for example, only implement the 4-bit functionality but not the 8-bit and 8-bit optimizer functionality. The intent of such a design it better captured by a base class that implements these functions with an NotImplementedError exception which have to be overridden by the backend.

I think the intend would be even clearer by having 4 backends: 4-bit, 8-bit, 8-bit optimizers, block-wise quantization. However, this will also introduce more bloat in terms of boilerplate and more classes. Not sure how to handle this and feedback would be appreciated. I think it might be better to have a single class and just highlight both as comments and in the documentation that not all functions need to be overridden for a solid contribution.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see now that the methods already throw a NotImplementedError. I think this is good already. So just removing the ABC would make it possible to implement sub-interfaces. I think to make the sub-interfaces clearer it would be great to have a NotImplementedError that shows the set of functions that need to be implemented. For example,

mm_dequant(...)
...
raise NotImplementedError("mm_dequant not implemented! \
This function is part of the 8-bit interface and it needs to be implemented along with: \
mm_dequant, igemmlt, extract_outliers, double_quant")

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's okay to partially implement a backend, then sure, we can make it a concrete base class with NotImplementeds thrown around.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't partial implementations fall back to CPU?

"""Base class for devices backends that will implement their own 8bits and 4bits functions."""

@abstractmethod
def double_quant(
self,
A,
col_stats=None,
row_stats=None,
out_col=None,
out_row=None,
threshold=0.0,
):
raise NotImplementedError

@abstractmethod
def transform(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only needed for CUDA. This will probably not be needed for any other device. See the discussion above.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was somewhat touched upon #898 (comment) – this PR doesn't yet move all of the CUDA-specific things into place, but I think that's fine and we can clean it up in near-future work...

self,
A,
to_order,
from_order="row",
out=None,
transpose=False,
state=None,
ld=None,
):
raise NotImplementedError

@abstractmethod
def igemmlt(self, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
raise NotImplementedError

@abstractmethod
def mm_dequant(
self,
A,
quant_state,
row_stats,
col_stats,
out=None,
new_row_stats=None,
new_col_stats=None,
bias=None,
):
raise NotImplementedError

@abstractmethod
def extract_outliers(self, A, SA, idx):
raise NotImplementedError

@abstractmethod
def quantize_4bit(
self,
A: torch.Tensor,
absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize=64,
compress_statistics=False,
quant_type="fp4",
quant_storage=torch.uint8,
) -> Tuple[torch.Tensor, QuantState]:
"""
Quantize tensor A in blocks of 4-bit values.

Quantizes tensor A by dividing it into blocks which are independently quantized to FP4.

Parameters
----------
A : torch.Tensor
The input tensor.
absmax : torch.Tensor
The absmax values.
out : torch.Tensor
The output tensor.
blocksize : int
The blocksize used in quantization.
quant_type : str
The 4-bit quantization data type {fp4, nf4}

Returns
-------
torch.Tensor:
Tensor with packed 4-bit values.
tuple(torch.Tensor, torch.Size, torch.dtype, int):
The quantization state to undo the quantization.
"""
raise NotImplementedError

@abstractmethod
def dequantize_4bit(
self,
A: torch.Tensor,
quant_state: Optional[QuantState] = None,
absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
blocksize: int = 64,
quant_type="fp4",
) -> torch.Tensor:
"""
Dequantizes FP4 blockwise quantized values.

Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize.

Parameters
----------
A : torch.Tensor
The input tensor (packed 4-bit values).
quant_state : QuantState
object with quantisation stats, incl. absmax values, original tensor shape and original dtype.
absmax : torch.Tensor
The absmax values.
out : torch.Tensor
Dequantized output tensor.
blocksize : int
The blocksize used in quantization.
quant_type : str
The 4-bit quantization data type {fp4, nf4}


Returns
-------
torch.Tensor:
Dequantized tensor.
"""
raise NotImplementedError
Loading
Loading