
DRAFT: allow for multiple hardware backends #1077

Conversation

Titus-von-Koeller
Collaborator

Due to Tim's extensive feedback and request to implement an alternative approach based on his preferences, this is a (rough) draft of my understanding of his solution sketch.

This isn't complete, but I'm publishing it for early feedback: At this point especially from Tim, as I want to verify that I didn't misunderstand anything before getting feedback from everyone else.

It's late and I'll have to continue with the rest in the morning (meaning that this is far from complete).


The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@@ -12,6 +12,7 @@
matmul_cublas,
mm_cublas,
)
from .backends import _backend as backend
Contributor

I don't think this renaming re-export should be here. Is there an excellent reason for lay users of the library to be able to do from bitsandbytes import backend?

from ._base import COOSparseTensor
from .nvidia import CudaBackend

_backend = CudaBackend(lib) if lib else None
Contributor

I think the initialization of the backend should happen on-demand, not eagerly and implicitly during import time. This is exacerbated by the re-export in __init__.py.

I think

_backend: Backend | None = None

def get_backend() -> Backend:
    global _backend
    if _backend is None:
        _backend = CudaBackend()
    return _backend

would be the better API.

(Note that I also think lib should be something the backend knows about, not something that gets passed in.)

Comment on lines +4 to +18
class COOSparseTensor:
    def __init__(self, rows, cols, nnz, rowidx, colidx, values):
        assert rowidx.dtype == torch.int32
        assert colidx.dtype == torch.int32
        assert values.dtype == torch.float16
        assert values.numel() == nnz
        assert rowidx.numel() == nnz
        assert colidx.numel() == nnz

        self.rows = rows
        self.cols = cols
        self.nnz = nnz
        self.rowidx = rowidx
        self.colidx = colidx
        self.values = values
Contributor

This looks like it should be a @dataclasses.dataclass.
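
For illustration, a minimal sketch of how that might look, reusing the field names from the snippet above (not part of the PR; the dtype comments mirror the existing asserts):

import dataclasses

import torch


@dataclasses.dataclass
class COOSparseTensor:
    rows: int
    cols: int
    nnz: int
    rowidx: torch.Tensor  # expected dtype: torch.int32, numel() == nnz
    colidx: torch.Tensor  # expected dtype: torch.int32, numel() == nnz
    values: torch.Tensor  # expected dtype: torch.float16, numel() == nnz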



class COOSparseTensor:
    def __init__(self, rows, cols, nnz, rowidx, colidx, values):
Contributor

Missing types?

Comment on lines +6 to +11
assert rowidx.dtype == torch.int32
assert colidx.dtype == torch.int32
assert values.dtype == torch.float16
assert values.numel() == nnz
assert rowidx.numel() == nnz
assert colidx.numel() == nnz
Contributor

If these checks are critical at runtime, they shouldn't be asserts but explicit if ...: raise ... statements, since someone may be running this library with python -O, which disables asserts.
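
For example, the checks above could become explicit exceptions along these lines (a sketch, not the PR's code):

if rowidx.dtype != torch.int32 or colidx.dtype != torch.int32:
    raise TypeError("rowidx and colidx must be torch.int32 tensors")
if values.dtype != torch.float16:
    raise TypeError("values must be a torch.float16 tensor")
if not (rowidx.numel() == colidx.numel() == values.numel() == nnz):
    raise ValueError("rowidx, colidx and values must each contain nnz elements")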

@@ -3,24 +3,63 @@
# This source code is licensed under the MIT license found in the
Contributor

I'll refrain from reviewing this file much – as said on Slack, the reformatting should not happen in this PR, since it's now very hard to see which changes are functional and which were reformatting.

I opened #1081 to reformat everything in one go.

dtype2bytes = {}
dtype2bytes[torch.float32] = 4
dtype2bytes[torch.float16] = 2
dtype2bytes[torch.bfloat16] = 2
dtype2bytes[torch.uint8] = 1
dtype2bytes[torch.int8] = 1

FIRST_CUDA_DEVICE = torch.device('cuda', index=0)
FIRST_CUDA_DEVICE = torch.device("cuda", index=0)
Contributor

This will probably need to move away from here in the future – after all, this file should end up having no mention of "cuda" at all.

assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
cblocksize = ct.c_int32(blocksize)
prev_device = pre_call(A.device)
code = code.to(A.device)
is_on_gpu([code, A, out, absmax])
if A.dtype == torch.float32:
lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
lib.cquantize_blockwise_fp32(
Contributor

Feels like these calls should also be delegated to a backend?

self.rowidx = rowidx
self.colidx = colidx
self.values = values
@deprecated
Contributor

Why does this (and the other similar trampolines) need to be deprecated? 🤔

"E731", # Do not use lambda
"F841", # Local assigned but not used (TODO: enable, these are likely bugs)
"RUF012", # Mutable class attribute annotations
"ISC001", # String concatination warning: may cause conflicts when used with the formatter
Contributor

Suggested change
"ISC001", # String concatination warning: may cause conflicts when used with the formatter
"ISC001", # String concatenation warning: may cause conflicts when used with the formatter

... but this shouldn't be added anyway IMO, the single-line concatenations should be fixed (and Ruff can autofix them).

@rickardp
Contributor

rickardp commented Feb 26, 2024

My main concern with this vs the original PR is the singleton backend. It feels like a fundamental architectural decision. I say this because I think, in addition to the actual code changes, assumptions about this will be made in many places (as well as in code using this library, making backtracking on this decision "breaking").

Is it certain that this is not coming back to bite us in the end? Are we sure we do not want to support heterogeneous device setups?

I would gravitate towards suggesting following the torch device as was suggested in the other PR, but maybe I do not understand the cons of a multiple backends approach.

Edit: I see now this comment from @TimDettmers:

There is also currently the problem that the backend design breaks paged buffers that have a CPU device type but need to be executed on a CUDA device. I think this can be fixed by assigning a singleton backend upon initialization rather than dynamically checking the device type.

Could we solve this by sending in a device parameter instead of or as an optional complement to inferring from the device? Or is the concern a breaking API change here?

Comment on lines +24 to +33
def __new__(cls, lib=None):
    if cls._instance is None:
        if lib is None:
            raise ValueError(
                "A 'lib' binary must be provided during the first initialization of BackendInterface."
            )
        cls._instance = super().__new__(cls)
        cls._instance.lib = (
            lib  # Set the binary name during the first and only instantiation
        )
Member

It's not super clear what exactly lib is supposed to be, both from the lack of typing information and the naming. Plus, I'm not sure every backend would need some kind of binary library implementation?

@matthewdouglas
Member

Edit: I see now this comment from @TimDettmers:

There is also currently the problem that the backend design breaks paged buffers that have a CPU device type but need to be executed on a CUDA device. I think this can be fixed by assigning a singleton backend upon initialization rather than dynamically checking the device type.

Could we solve this by sending in a device parameter instead of or as an optional complement to inferring from the device? Or is the concern a breaking API change here?

I think there's already a need to track a device index, so it seems feasible to add the device type as well, unless there's something I'm misunderstanding. E.g. I already see the signature void cprefetch(void *ptr, size_t bytes, int device) along with a page_deviceid property. Would we just be able to look at tensor.is_paged to resolve that?

It may also just make sense to start out by assuming we're using a CUDA or (maybe ROCm?) device with paged optimizers. It doesn't make sense for MPS or CPU backends. Intel GPU is the question mark for me on that. But even then, is another option to globally select a backend just for this? I.e. GlobalOptimManager.get_instance().set_backend("ipex")

@rickardp
Contributor

rickardp commented Mar 6, 2024

It doesn't make sense for MPS or CPU backends. Intel GPU is the question mark for me on that.

Paged optimizers are a pointless feature for all systems with true unified memory.

I.e. GlobalOptimManager.get_instance().set_backend("ipex")

On the other hand, even with NVIDIA's "virtual UMA", wouldn't it be possible to first allocate the tensor on the device it is meant to go to, and then access it from the CPU?

Or is the problem that PyTorch does not support it, which is why this would be needed? TBH I've never used these new CUDA features, so I was just looking at the docs (https://developer.nvidia.com/blog/unified-memory-cuda-beginners/)

@TimDettmers
Collaborator

Cross-posted reply to #898

  1. Device common abstraction interface

The four sub-interfaces look good to me. We probably want to create higher-level abstractions for 4-bit quantization/dequantization, and some other functions should be deprecated, such as percentile clipping and the non-tensor-core 8-bit matrix multiplication.

It seems more people like the tensor-driven dispatch. I find it difficult to read and debug, and it can lead to more accidental bugs compared to a singleton. There are also further problems. For example, for paged tensors the device is cpu, but they need to be executed on a cuda device. We can create special branching functions for this case, but then we already have a special branch for AMD and paged tensors. A singleton would have no branching at all, and bugs can be avoided if the user forgot to push the tensor to the correct device. I think a singleton is much superior in terms of reducing bugs and enhancing the user experience.

For a tensor-type dispatch, how would you handle the case where the user uses the wrong device? If a tensor is on the CPU, I guess it would silently be executed on the CPU (if supported). This would be slow without the user knowing that something is wrong, and we have no way of detecting it because we are just dispatching based on device type.

  2. Device-specific functions/classes

Some functions specific to the device would be specializations in the base class and would not be provided in the interface. So these should be "private" functions for the base class.

I think a runtime singleton would be overkill. The case where someone has multiple accelerators of different device types would be very rare (probably only developers have this). We should have a special manual hack for this rather than degrading the overall system for this use case. So, a static one-time singleton would be preferable from my view.

  3. Device common tool functions/classes

Agreed. I am not sure if all device types would work the same way. For example, does get_ptr() generalize to all devices? It might be a special case though, and if something is wrong we can fix it later. So I think moving everything to something like utils.py might work well.

I've noticed that the other implementation #1077 which has the same scope has just started to be reviewed. Have you decided which design will be applied? It would be nice if we could get your feedback soon, we would love to help make progress for both designs. Thx!

The PR #1107 was based on a discussion that I had with @Titus-von-Koeller. It was my preferred view at the time, but we should discuss what would be best for everyone and decide based on that.

Cross-posted reply to #1077

I think there's already a need to track a device index, so it seems feasible to add the device type as well, unless there's something I'm misunderstanding. E.g. I already see the signature void cprefetch(void *ptr, size_t bytes, int device) along with a page_deviceid property. Would we just be able to look at tensor.is_paged to resolve that?

I think this might be the way to go. If we use a singleton we do not need this, but for a dispatch scenario we need to do this and it is probably the best solution to do an if-else on tensor.is_paged.
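
As a rough sketch of that branch (assuming paged tensors expose is_paged and the page_deviceid mentioned above; names otherwise hypothetical):

import torch

def resolve_device(tensor: torch.Tensor) -> torch.device:
    # Paged buffers report device "cpu" but must be executed on their CUDA device.
    if getattr(tensor, "is_paged", False):
        return torch.device("cuda", tensor.page_deviceid)
    return tensor.device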

On the other hand, even with NVIDIA's "virtual UMA", wouldn't it be possible to first allocate the tensor on the device it is meant to go to, and then access it from the CPU?

This would be more general, and we could even abstract it to disk -> unified memory architectures like Apple silicon. It requires more work, but would be a great feature for Apple silicon to work with larger models on smaller Macbooks and the iPhone etc. The main problem here is that it requires quite a bit of work, but I think it could well be worth it to create this abstraction as it is more general and quite powerful (it can be used beyond optimizers, for example also for matrix multiplication / weights).

My main concern with this vs the original PR is the singleton backend. It feels like a fundamental architectural decision. I say this because I think, in addition to the actual code changes, assumptions about this will be made in many places (as well as in code using this library, making backtracking on this decision "breaking"). Is it certain that this is not coming back to bite us in the end? Are we sure we do not want to support heterogeneous device setups?

I think this is the main advantage of a non-static singleton. I would optimize for the user-side rather than the developer-side. A singleton can prevent a lot of accidental bugs and enhances the user experience for cases where the user forgot to move a tensor to the correct device. This can become a problem in particular if we support CPU devices.

I think the main discussion point should be which architectural design provides the best user experience. That is really what I want to optimize for with the design. Let's discuss.

PS: @Titus-von-Koeller had more details on the sub-interfaces and what will be deprecated and we will post soon.

@jiqing-feng
Contributor

jiqing-feng commented Mar 19, 2024

Hi @Titus-von-Koeller. Can you share the intended user behavior here? For example, how to assign the singleton or device.

If environment variables or some global variables assign the singleton, how should I use assisted_decoding, which has two models, one running on the GPU and another running on the CPU?

If we can assign a singleton on the object, like model.set_singleton(device), then what's the difference between that and model.to(device)?

Do you have an approach that won't break existing user behavior, so it can be easily adopted by users?

@akx
Contributor

akx commented Mar 20, 2024

I would optimize for the user-side
[...]
A singleton can prevent a lot of accidental bugs

What accidental bugs are you referring to? The user would not be interacting with the backend object by hand in any regular use case.

@matthewdouglas
Member

Not to be overly pedantic, but I wonder if we need to differentiate between backend and device, too. #747 has me thinking even more about this.

We have to consider what the PyTorch build supports, which devices are available, and which libraries/compute backend the user wants to (or should by default) use for those devices. I think there should be sensible defaults based on the hardware available, but maybe additional choices via a Python API and/or env variables for users to have more control if needed.

PyTorch is built for CUDA, ROCm, or MPS. We're not going to combine these.

Some sensible mutually exclusive default backends based on PyTorch build and library availability, in a priority order:

  1. macOS arm64, MPS
  2. Linux aarch64, CUDA
  3. Linux/Windows x86-64, CUDA or ROCm

We can do simple tests for those and pick at most one.
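
A sketch of what such a one-time check might look like (the priority logic here is an assumption based on the list above, not an agreed design):

import platform

import torch

def select_default_backend() -> str:
    # Pick at most one backend, in priority order, based on the PyTorch build.
    if platform.system() == "Darwin" and torch.backends.mps.is_available():
        return "mps"
    if torch.version.hip is not None:  # ROCm builds also expose "cuda" devices
        return "rocm"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"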

Next up is Intel GPUs (IPEX). That might depend on the torch version, as it seems like torch 2.3.0+ will have torch.xpu built-in, while older versions will take it from the ipex package. In this case we would load an IPEX backend, maybe with a SYCL binary to go with it from #747. It's not clear to me if this can coexist with CUDA/ROCm.

From there is XLA/PJRT (see #1119). My understanding there is limited, but I think AWS Trainium/Inferentia can be enabled through XLA and torch-neuronx.

For Intel Gaudi, the device type should end up being hpu, and we'd see a PyTorch package that is built for that, so I'm not sure that would ever be combined with other accelerators.

We should also always have some kind of minimal CPU implementation to fall back to, whether it's hyper-optimized or just naively implemented and missing some ops.

I think it is reasonable to detect support for these devices, load any Python modules, binaries, etc., and then make dispatch decisions based on the type of device (mps, cuda, xpu, xla, hpu, cpu), the device index, and some supplemental information, e.g. is_paged.
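
For instance, a small registry keyed by device type could back that dispatch (a sketch; the Backend protocol and registry are hypothetical, not part of this PR):

import torch

_backends: dict = {}  # e.g. {"cuda": CudaBackend(), "cpu": CpuBackend()}

def backend_for(tensor: torch.Tensor):
    # Dispatch on device type, special-casing paged buffers as discussed above.
    device_type = "cuda" if getattr(tensor, "is_paged", False) else tensor.device.type
    backend = _backends.get(device_type)
    if backend is None:
        raise RuntimeError(f"No backend registered for device type '{device_type}'")
    return backend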

Then within some of these maybe there's further options based on build, env vars, hardware support, etc. Do we want to let you do something odd like run with torch + ROCm but invoke SYCL kernels? I'll try to elaborate soon, but in general I think there's a balance to be made here.

@Titus-von-Koeller
Collaborator Author

Hey all,

just to be clear: I'm not super invested in the Singleton approach at all.

I quickly whipped up this PR solely to show my understanding of Tim's input and concretize it as a basis for discussion. For that, I feel it served its purpose and, personally, I'm not seeing any reason to keep two PRs open on this topic in parallel.

Therefore, I'm closing this PR here and am hoping to collaborate on #898 as put forward by Intel and iterate both in code and discussion together with the community there.

If there's still something useful in this branch here, we can consolidate those changes with #898.


Generally, in my talks with Tim about this last week, I made sure to be very insistent on addressing your feedback, especially on the topic of tensor-driven dispatch. I felt that there was quite a strong preference throughout the community for this, and I think there were very good arguments put forward. If the community decides together with us that this is the way to go, then I'm on board, too.

What is important to me is that we take it iteratively and build out the solution together, doing inclusive rounds of feedback while we're at it. By iterative I also mean that we should do so in a way that always tries to leave the code in a functional state.

An important topic related to this is setting up GitHub runners for CUDA, which I intend to focus on in the coming week. It would be helpful to have Intel-powered runners as well. Related to this is the general need to improve the test suite, with a strong focus on eradicating flakiness and finding a representative set of relatively quickly executing tests that allow us to take a test-driven approach. We could simply create a pytest marker for selecting these, allowing us to iterate quickly.

cc @BenjaminBossan, who has expressed interest in helping us out on the test suite refactorings; if you want, please also feel free to pitch in on these important architectural changes to enable different hardware backends.

In my understanding, the singleton approach would be closer to what BNB already does at the moment. One concern that I do have is that a switch to the tensor-driven approach is quite a fundamental change to the code base, leading to more work than intended, as we'll unintentionally break things. The existing code was built on this baseline assumption, and we would be fundamentally changing it: for me, the consequence is extensive changes throughout existing (functioning and battle-tested) code to account for this.

Tim is barely available until the end of May, and since taking over maintenance, we're only slowly gaining full ownership of the codebase. This means I'm a bit wary of such fundamental changes, and I'm not yet certain where the test suite can be fully trusted and where it needs improvement (other than the known problem of flakiness).

This doesn't necessarily mean that we should refrain from a more extensive refactor. It would actually be a good opportunity to take stock of everything that's there and refactor towards more clarity, maintainability and, well, multi-backend support. It's just something that we should approach very consciously. I would be very happy for any thoughts of yours on how to make sure this transition/refactor happens as reliably and smoothly as possible.

It is absolutely necessary that main always stays releasable and correct with a high degree of certainty. If there's a way to split things up into an iterative series of PRs, that would be preferable, but this might be at odds with experimenting until we've locked in a solution that is a good way forward. Any input here is also appreciated.

Another factor that comes into play is that we need to avoid performance regressions as well. This is not my area of expertise yet. I would be much happier if we had a measure there, immediately making us aware of regressions in the (currently) CUDA reference implementation. I don't necessarily foresee any issues, as the most performance-critical code remains unchanged within the kernels, but I also don't like flying blind.

Please feel free to constructively criticize and correct me wherever you think that's useful. I really appreciate direct feedback and conscious discourse. If you prefer, you can also contact me in private on Slack.

I would also like to take this opportunity to thank all of you for your thoughtful and helpful comments and contributions so far! We're very grateful to have you so invested in the future of BNB and are looking forward to our shared next steps.

@Titus-von-Koeller
Collaborator Author

Titus-von-Koeller commented Mar 21, 2024

Thanks @matthewdouglas for another very valuable and thought-through analysis. This is very helpful! I'll be referencing it in some of my next communications.
