From 314d5e0f3d3bce0629aa5aae0b0cfdd17c170c28 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Fri, 10 Nov 2023 02:14:47 -0800 Subject: [PATCH 01/39] init device abstraction --- bitsandbytes/autograd/_functions.py | 26 +++++--- bitsandbytes/cextension.py | 20 +++--- bitsandbytes/functional.py | 94 +++++++++++++++++++++++++++-- bitsandbytes/nn/modules.py | 75 ++++++++++++++++++++--- 4 files changed, 179 insertions(+), 36 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index f8403cf24..7f2920bfb 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -223,6 +223,9 @@ def backward(ctx, grad_output): def supports_igemmlt(device: torch.device) -> bool: + if device is "cpu": + return True + """check if this device supports the optimized int8 kernel""" if torch.cuda.get_device_capability(device=device) < (7, 5): return False @@ -267,6 +270,7 @@ class MatmulLtState: idx = None is_training = True has_fp16_weights = True + memory_efficient_backward = False use_pool = False formatB = F.get_special_format_str() @@ -294,7 +298,8 @@ class MatMul8bitLt(torch.autograd.Function): @staticmethod def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): - using_igemmlt = supports_igemmlt(A.device) and not state.force_no_igemmlt + device = A.device + using_igemmlt = supports_igemmlt(device) and not state.force_no_igemmlt # default of pytorch behavior if inputs are empty ctx.is_empty = False if prod(A.shape) == 0: @@ -303,9 +308,9 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): ctx.B = B ctx.bias = bias if A.shape[-1] == B.shape[0]: - return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device) + return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=device) else: - return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device) + return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=device) # 1. Quantize A # 2. Quantize B @@ -318,13 +323,14 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): state.outlier_pool = GlobalOutlierPooler.get_instance() # Cast A to fp16 - if A.dtype != torch.float16: + ctx.cast_dtype = torch.bfloat16 if device is "cpu" else torch.float16 + if A.dtype != ctx.cast_dtype: warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization") # 1. 
Quantize A if len(A.shape) == 3: A = A.reshape(-1, A.shape[-1]) - CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold) + CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(ctx.cast_dtype), threshold=state.threshold) if state.threshold > 0.0 and coo_tensorA is not None: if state.has_fp16_weights: @@ -337,7 +343,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): else: if state.CxB is None and using_igemmlt: # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions - # we also need to convert it to the turing/ampere format + # we also need to convert it to the turing/ampere format if using cuda state.CxB, state.SB = F.transform(state.CB, to_order=formatB) else: if not state.has_fp16_weights and state.CxB is None and using_igemmlt: @@ -359,7 +365,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): state.SCB, state.SCBt, coo_tensorB, - ) = F.double_quant(B.to(torch.float16)) + ) = F.double_quant(B.to(ctx.cast_dtype)) if using_igemmlt: state.CxB, state.SB = F.transform(CB, to_order=formatB) else: @@ -399,7 +405,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): if using_igemmlt: C32A, SA = F.transform(CA, "col32") out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) - if bias is None or bias.dtype == torch.float16: + if bias is None or bias.dtype in [torch.float16, torch.bfloat16]: # we apply the fused bias here output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) output = output.to(A.dtype) @@ -458,7 +464,7 @@ def backward(ctx, grad_output): if len(grad_output.shape) == 3: grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() - Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) + Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(ctx.cast_dtype)) if req_gradB: CxAt, SAt = F.transform(CAt, formatB, transpose=True) C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True) @@ -565,7 +571,7 @@ def matmul( def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None): assert quant_state is not None - if A.numel() == A.shape[-1] and A.requires_grad == False: + if A.numel() == A.shape[-1] and A.requires_grad == False and A.device is "cuda": absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state if A.shape[-1] % blocksize != 0: warn(f'Some matrices hidden dimension is not a multiple of {blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}') diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index d52a6d607..42fe44387 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -7,13 +7,13 @@ from bitsandbytes.cuda_setup.main import CUDASetup +if torch.cuda.is_available(): + setup = CUDASetup.get_instance() + if setup.initialized != True: + setup.run_cuda_setup() -setup = CUDASetup.get_instance() -if setup.initialized != True: - setup.run_cuda_setup() - -lib = setup.lib -try: + lib = setup.lib + if lib is None and torch.cuda.is_available(): CUDASetup.get_instance().generate_instructions() CUDASetup.get_instance().print_log_stack() @@ -30,12 +30,10 @@ lib.get_cusparse.restype = ct.c_void_p lib.cget_managed_ptr.restype = ct.c_void_p COMPILED_WITH_CUDA = True -except AttributeError as ex: - warn("The installed version of bitsandbytes was compiled without GPU support. 
" - "8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.") - COMPILED_WITH_CUDA = False - print(str(ex)) +else: + warn("The installed version of bitsandbytes was compiled without GPU support. Will" + "run with CPU support") # print the setup details after checking for errors so we do not print twice #if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0': diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 96f8ce4e6..b223f0703 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -772,7 +772,7 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False): return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4') -def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor: +def cuda_quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor: """ Quantize tensor A in blocks of 4-bit values. @@ -1699,7 +1699,7 @@ def batched_igemm( return out -def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): +def cuda_igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): shapeA = SA[0] shapeB = SB[0] dimsA = len(shapeA) @@ -1796,7 +1796,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): return out, Sout -def mm_dequant( +def cuda_mm_dequant( A, quant_state, row_stats, @@ -1980,7 +1980,7 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half): return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) -def double_quant( +def cuda_double_quant( A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 ): device = A.device @@ -2076,7 +2076,7 @@ def double_quant( return out_row, out_col, row_stats, col_stats, coo_tensor -def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): +def cuda_transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): prev_device = pre_call(A.device) if state is None: state = (A.shape, from_order) else: from_order = state[1] @@ -2372,7 +2372,7 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half): return x.to(dtype) -def extract_outliers(A, SA, idx): +def cuda_extract_outliers(A, SA, idx): shapeA = SA[0] formatA = SA[1] assert formatA in ["col_turing", "col_ampere"] @@ -2402,3 +2402,85 @@ def pipeline_test(A, batch_size): out = torch.zeros_like(A) lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size)) return out + + +# 8 bits +def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): + if A.device is "cuda": + return cuda_double_quant(A=A, col_stats=col_stats, row_stats=row_stats, out_col=out_col, out_row=out_row, threshold=threshold) + elif A.device is "cpu": + pass + elif A.device is "xpu": + pass + else: + pass + +def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): + if A.device is "cuda": + cuda_transform(A, to_order, from_order=from_order, out=out, transpose=transpose, state=state, ld=ld) + elif A.device is "cpu": + pass + elif A.device is "xpu": + pass + else: + pass + +def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): + if A.device is "cuda": + cuda_igemmlt(A, B, SA, SB, out=out, 
Sout=Sout, dtype=dtype) + elif A.device is "cpu": + pass + elif A.device is "xpu": + pass + else: + pass + +def mm_dequant( + A, + quant_state, + row_stats, + col_stats, + out=None, + new_row_stats=None, + new_col_stats=None, + bias=None +): + if A.device is "cuda": + cuda_mm_dequant(A, quant_state, row_stats, col_stats, out=out, new_row_stats=new_row_stats, new_col_stats=new_col_stats, bias=bias) + elif A.device is "cpu": + pass + elif A.device is "xpu": + pass + else: + pass + +# 4 bits +def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor: + if A.device is "cuda": + cuda_quantize_4bit(A, absmax = absmax, out = out, blocksize=blocksize, compress_statistics=compress_statistics, quant_type=quant_type) + elif A.device is "cpu": + pass + elif A.device is "xpu": + pass + else: + pass + +def dequantize_4bit(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor: + if A.device is "cuda": + cuda_dequantize_4bit(A, quant_state = quant_state, absmax = absmax, out = out, blocksize = blocksize, quant_type=quant_type) + elif A.device is "cpu": + pass + elif A.device is "xpu": + pass + else: + pass + +def extract_outliers(A, SA, idx): + if A.device is "cuda": + cuda_extract_outliers(A, SA, idx) + elif A.device is "cpu": + pass + elif A.device is "xpu": + pass + else: + pass \ No newline at end of file diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 3d34bb45f..d024ddd9e 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -152,6 +152,14 @@ def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, self.data = data return self + def cpu(self, device): + w = self.data.contiguous().half() + w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type) + self.data = w_4bit + self.quant_state = quant_state + + return self + def cuda(self, device): w = self.data.contiguous().half().cuda(device) w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type) @@ -160,6 +168,14 @@ def cuda(self, device): return self + def xpu(self, device): + w = self.data.contiguous().half().to("xpu") + w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type) + self.data = w_4bit + self.quant_state = quant_state + + return self + @overload def to(self: T, device: Optional[Union[int, device]] = ..., dtype: Optional[Union[dtype, str]] = ..., non_blocking: bool = ...,) -> T: ... @@ -174,9 +190,15 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) 
-> T: def to(self, *args, **kwargs): device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - - if (device is not None and device.type == "cuda" and self.data.device.type == "cpu"): - return self.cuda(device) + + if device is not None and device.type == "cpu": + return self.cpu(device) + + if (device is not None and device.type != "cpu" and self.data.device.type == "cpu"): + if device.type == "cuda": + return self.cuda(device) + elif device.type == "xpu": + return self.xpu(device) else: s = self.quant_state if s is not None: @@ -287,6 +309,39 @@ def __new__( data = torch.empty(0) return torch.Tensor._make_subclass(cls, data, requires_grad) + + def cpu(self, device): + if self.has_fp16_weights: + return super() + else: + # we store the 8-bit rows-major weight + # we convert this weight to the turning/ampere weight during the first inference pass + B = self.data.contiguous().half() + CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) + del CBt + del SCBt + self.data = CB + setattr(self, "CB", CB) + setattr(self, "SCB", SCB) + + return self + + def cpu(self, device): + if self.has_fp16_weights: + return super().to("xpu") + else: + # we store the 8-bit rows-major weight + # we convert this weight to the turning/ampere weight during the first inference pass + B = self.data.contiguous().half().to("xpu") + CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) + del CBt + del SCBt + self.data = CB + setattr(self, "CB", CB) + setattr(self, "SCB", SCB) + + return self + def cuda(self, device): if self.has_fp16_weights: return super().cuda(device) @@ -325,12 +380,14 @@ def to(self, *args, **kwargs): *args, **kwargs ) - if ( - device is not None - and device.type == "cuda" - and self.data.device.type == "cpu" - ): - return self.cuda(device) + if device is not None and device.type == "cpu": + return self.cpu(device) + + if (device is not None and device.type != "cpu" and self.data.device.type == "cpu"): + if device.type == "cuda": + return self.cuda(device) + elif device.type == "xpu": + return self.xpu(device) else: new_param = Int8Params( super().to( From 2e9550aa0dcfa6f68dc345e2052f68899e6c8d20 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Tue, 28 Nov 2023 00:37:12 -0800 Subject: [PATCH 02/39] refinement --- bitsandbytes/autograd/_functions.py | 6 ++-- bitsandbytes/functional.py | 25 --------------- bitsandbytes/nn/modules.py | 48 +++++++---------------------- 3 files changed, 14 insertions(+), 65 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 7f2920bfb..21b814bf5 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -322,10 +322,10 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): if state.outlier_pool is None: state.outlier_pool = GlobalOutlierPooler.get_instance() - # Cast A to fp16 + # Cast A to fp16 if not on CPU ctx.cast_dtype = torch.bfloat16 if device is "cpu" else torch.float16 if A.dtype != ctx.cast_dtype: - warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization") + warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to {ctx.cast_dtype} during quantization") # 1. 
Quantize A if len(A.shape) == 3: @@ -460,7 +460,7 @@ def backward(ctx, grad_output): # compute grad_bias first before changing grad_output dtype grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias) - # Cast grad_output to fp16 + # Cast grad_output if len(grad_output.shape) == 3: grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index b223f0703..c52fd0230 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2371,7 +2371,6 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half): x += offset return x.to(dtype) - def cuda_extract_outliers(A, SA, idx): shapeA = SA[0] formatA = SA[1] @@ -2408,20 +2407,12 @@ def pipeline_test(A, batch_size): def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): if A.device is "cuda": return cuda_double_quant(A=A, col_stats=col_stats, row_stats=row_stats, out_col=out_col, out_row=out_row, threshold=threshold) - elif A.device is "cpu": - pass - elif A.device is "xpu": - pass else: pass def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): if A.device is "cuda": cuda_transform(A, to_order, from_order=from_order, out=out, transpose=transpose, state=state, ld=ld) - elif A.device is "cpu": - pass - elif A.device is "xpu": - pass else: pass @@ -2447,10 +2438,6 @@ def mm_dequant( ): if A.device is "cuda": cuda_mm_dequant(A, quant_state, row_stats, col_stats, out=out, new_row_stats=new_row_stats, new_col_stats=new_col_stats, bias=bias) - elif A.device is "cpu": - pass - elif A.device is "xpu": - pass else: pass @@ -2458,29 +2445,17 @@ def mm_dequant( def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor: if A.device is "cuda": cuda_quantize_4bit(A, absmax = absmax, out = out, blocksize=blocksize, compress_statistics=compress_statistics, quant_type=quant_type) - elif A.device is "cpu": - pass - elif A.device is "xpu": - pass else: pass def dequantize_4bit(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor: if A.device is "cuda": cuda_dequantize_4bit(A, quant_state = quant_state, absmax = absmax, out = out, blocksize = blocksize, quant_type=quant_type) - elif A.device is "cpu": - pass - elif A.device is "xpu": - pass else: pass def extract_outliers(A, SA, idx): if A.device is "cuda": cuda_extract_outliers(A, SA, idx) - elif A.device is "cpu": - pass - elif A.device is "xpu": - pass else: pass \ No newline at end of file diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index d024ddd9e..be90c829e 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -153,23 +153,17 @@ def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, return self def cpu(self, device): - w = self.data.contiguous().half() - w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type) - self.data = w_4bit - self.quant_state = quant_state + warnings.warn("CPU Params4bit will be soon supported, return raw Params4bit for now") return self - def cuda(self, device): - w = self.data.contiguous().half().cuda(device) - w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type) - self.data 
= w_4bit - self.quant_state = quant_state + def xpu(self, device): + warnings.warn("XPU Params4bit will be soon supported, return raw Params4bit for now") return self - def xpu(self, device): - w = self.data.contiguous().half().to("xpu") + def cuda(self, device): + w = self.data.contiguous().half().cuda(device) w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type) self.data = w_4bit self.quant_state = quant_state @@ -311,34 +305,13 @@ def __new__( def cpu(self, device): - if self.has_fp16_weights: - return super() - else: - # we store the 8-bit rows-major weight - # we convert this weight to the turning/ampere weight during the first inference pass - B = self.data.contiguous().half() - CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) - del CBt - del SCBt - self.data = CB - setattr(self, "CB", CB) - setattr(self, "SCB", SCB) + warnings.warn("XPU Int8Params will be soon supported, return raw Int8Params for now") return self - def cpu(self, device): - if self.has_fp16_weights: - return super().to("xpu") - else: - # we store the 8-bit rows-major weight - # we convert this weight to the turning/ampere weight during the first inference pass - B = self.data.contiguous().half().to("xpu") - CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) - del CBt - del SCBt - self.data = CB - setattr(self, "CB", CB) - setattr(self, "SCB", SCB) + + def xpu(self, device): + warnings.warn("XPU Int8Params will be soon supported, return raw Int8Params for now") return self @@ -423,7 +396,8 @@ def __init__(self, input_features, output_features, bias=True, has_fp16_weights= self.index = index self.state.threshold = threshold - self.state.has_fp16_weights = has_fp16_weights + # fp16 not supports on CPU yet + self.state.has_fp16_weights = has_fp16_weights if device is not "cpu" else False self.state.memory_efficient_backward = memory_efficient_backward if threshold > 0.0 and not has_fp16_weights: self.state.use_pool = True From 68fd024206d1e1012083df131bff2fa34add665c Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Wed, 29 Nov 2023 23:53:34 -0800 Subject: [PATCH 03/39] device stepup --- bitsandbytes/__init__.py | 2 +- bitsandbytes/__main__.py | 4 +- bitsandbytes/autograd/_functions.py | 13 ++-- bitsandbytes/cextension.py | 22 +++--- .../cpu}/__init__.py | 0 bitsandbytes/device_setup/cpu/main.py | 40 ++++++++++ bitsandbytes/device_setup/cuda/__init__.py | 0 .../cuda}/env_vars.py | 0 .../{cuda_setup => device_setup/cuda}/main.py | 2 +- bitsandbytes/device_setup/xpu/__init__.py | 0 bitsandbytes/device_setup/xpu/main.py | 15 ++++ bitsandbytes/functional.py | 73 +++++++++++-------- bitsandbytes/nn/modules.py | 5 +- tests/test_cuda_setup_evaluator.py | 2 +- 14 files changed, 123 insertions(+), 55 deletions(-) rename bitsandbytes/{cuda_setup => device_setup/cpu}/__init__.py (100%) create mode 100644 bitsandbytes/device_setup/cpu/main.py create mode 100644 bitsandbytes/device_setup/cuda/__init__.py rename bitsandbytes/{cuda_setup => device_setup/cuda}/env_vars.py (100%) rename bitsandbytes/{cuda_setup => device_setup/cuda}/main.py (99%) create mode 100644 bitsandbytes/device_setup/xpu/__init__.py create mode 100644 bitsandbytes/device_setup/xpu/main.py diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index f35a3b582..1469e5a10 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE 
file in the root directory of this source tree. -from . import cuda_setup, utils, research +from . import device_setup, utils, research from .autograd._functions import ( MatmulLtState, bmm_cublas, diff --git a/bitsandbytes/__main__.py b/bitsandbytes/__main__.py index 523d02301..0626752dd 100644 --- a/bitsandbytes/__main__.py +++ b/bitsandbytes/__main__.py @@ -97,8 +97,8 @@ def print_debug_info() -> None: from . import COMPILED_WITH_CUDA, PACKAGE_GITHUB_URL -from .cuda_setup.env_vars import to_be_ignored -from .cuda_setup.main import get_compute_capabilities +from .device_setup.cuda.env_vars import to_be_ignored +from .device_setup.cuda.main import get_compute_capabilities print_header("OTHER") diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 21b814bf5..f99f87312 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -223,8 +223,6 @@ def backward(ctx, grad_output): def supports_igemmlt(device: torch.device) -> bool: - if device is "cpu": - return True """check if this device supports the optimized int8 kernel""" if torch.cuda.get_device_capability(device=device) < (7, 5): @@ -233,6 +231,10 @@ def supports_igemmlt(device: torch.device) -> bool: nvidia16_models = ('GTX 1630', 'GTX 1650', 'GTX 1660') # https://en.wikipedia.org/wiki/GeForce_16_series if any(model_name in device_name for model_name in nvidia16_models): return False # these devices are technically cuda 7.5-capable, but they lack tensor cores + + if device == "cpu": + return False + return True @@ -291,7 +293,6 @@ def tile_indices(self): self._tile_indices = get_tile_inds(self.formatB, self.CxB.device) return self._tile_indices - class MatMul8bitLt(torch.autograd.Function): # forward is the same, but we added the fallback for pre-turing GPUs # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") @@ -322,8 +323,8 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): if state.outlier_pool is None: state.outlier_pool = GlobalOutlierPooler.get_instance() - # Cast A to fp16 if not on CPU - ctx.cast_dtype = torch.bfloat16 if device is "cpu" else torch.float16 + # Cast A to fp16 + ctx.cast_dtype = torch.float16 if A.dtype != ctx.cast_dtype: warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to {ctx.cast_dtype} during quantization") @@ -571,7 +572,7 @@ def matmul( def matmul_4bit(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None): assert quant_state is not None - if A.numel() == A.shape[-1] and A.requires_grad == False and A.device is "cuda": + if A.numel() == A.shape[-1] and A.requires_grad == False and A.device == "cuda": absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state if A.shape[-1] % blocksize != 0: warn(f'Some matrices hidden dimension is not a multiple of {blocksize} and efficient inference kernels are not supported for these (slow). 
Matrix input size found: {A.shape}') diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 42fe44387..76bfa6647 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -4,16 +4,15 @@ from pathlib import Path from warnings import warn +from bitsandbytes.device_setup.cuda.main import CUDASetup -from bitsandbytes.cuda_setup.main import CUDASetup +setup = CUDASetup.get_instance() +if setup.initialized != True: + setup.run_cuda_setup() -if torch.cuda.is_available(): - setup = CUDASetup.get_instance() - if setup.initialized != True: - setup.run_cuda_setup() +lib = setup.lib - lib = setup.lib - +try: if lib is None and torch.cuda.is_available(): CUDASetup.get_instance().generate_instructions() CUDASetup.get_instance().print_log_stack() @@ -30,10 +29,11 @@ lib.get_cusparse.restype = ct.c_void_p lib.cget_managed_ptr.restype = ct.c_void_p COMPILED_WITH_CUDA = True - -else: - warn("The installed version of bitsandbytes was compiled without GPU support. Will" - "run with CPU support") +except AttributeError as ex: + warn("The installed version of bitsandbytes was compiled without CUDA GPU support. " + "8-bit optimizers, 8-bit multiplication, and CUDA GPU quantization are unavailable.") + COMPILED_WITH_CUDA = False + print(str(ex)) # print the setup details after checking for errors so we do not print twice #if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0': diff --git a/bitsandbytes/cuda_setup/__init__.py b/bitsandbytes/device_setup/cpu/__init__.py similarity index 100% rename from bitsandbytes/cuda_setup/__init__.py rename to bitsandbytes/device_setup/cpu/__init__.py diff --git a/bitsandbytes/device_setup/cpu/main.py b/bitsandbytes/device_setup/cpu/main.py new file mode 100644 index 000000000..8e44a6ae9 --- /dev/null +++ b/bitsandbytes/device_setup/cpu/main.py @@ -0,0 +1,40 @@ +from packaging import version +import importlib.metadata +from warnings import warn +def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]: + # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version + package_exists = importlib.util.find_spec(pkg_name) is not None + package_version = "N/A" + if package_exists: + try: + package_version = importlib.metadata.version(pkg_name) + package_exists = True + except importlib.metadata.PackageNotFoundError: + package_exists = False + if return_version: + return package_exists, package_version + else: + return package_exists + +_torch_version = "N/A" +_torch_available = False +_torch_available, _torch_version = _is_package_available("torch", return_version=True) +_ipex_available, _ipex_version = _is_package_available("intel_extension_for_pytorch", return_version=True) + +def is_ipex_cpu_available(): + def get_major_and_minor_from_version(full_version): + return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor) + + if not _torch_available or not _ipex_available: + return False + + torch_major_and_minor = get_major_and_minor_from_version(_torch_version) + ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version) + if torch_major_and_minor != ipex_major_and_minor: + warn( + f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*," + f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again." 
+ "Refer to https://intel.github.io/intel-extension-for-pytorch/ for more details." + ) + return False + return True \ No newline at end of file diff --git a/bitsandbytes/device_setup/cuda/__init__.py b/bitsandbytes/device_setup/cuda/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bitsandbytes/cuda_setup/env_vars.py b/bitsandbytes/device_setup/cuda/env_vars.py similarity index 100% rename from bitsandbytes/cuda_setup/env_vars.py rename to bitsandbytes/device_setup/cuda/env_vars.py diff --git a/bitsandbytes/cuda_setup/main.py b/bitsandbytes/device_setup/cuda/main.py similarity index 99% rename from bitsandbytes/cuda_setup/main.py rename to bitsandbytes/device_setup/cuda/main.py index f3edf4c73..fd639beb7 100644 --- a/bitsandbytes/cuda_setup/main.py +++ b/bitsandbytes/device_setup/cuda/main.py @@ -254,7 +254,7 @@ def determine_cuda_runtime_lib_path() -> Union[Path, None]: 1. active conda env 2. LD_LIBRARY_PATH 3. any other env vars, while ignoring those that - - are known to be unrelated (see `bnb.cuda_setup.env_vars.to_be_ignored`) + - are known to be unrelated (see `bnb.device_setup.cuda.env_vars.to_be_ignored`) - don't contain the path separator `/` If multiple libraries are found in part 3, we optimistically try one, diff --git a/bitsandbytes/device_setup/xpu/__init__.py b/bitsandbytes/device_setup/xpu/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bitsandbytes/device_setup/xpu/main.py b/bitsandbytes/device_setup/xpu/main.py new file mode 100644 index 000000000..f13e6cb2d --- /dev/null +++ b/bitsandbytes/device_setup/xpu/main.py @@ -0,0 +1,15 @@ +from .cpu.main import is_ipex_cpu_available +from warnings import warn + +def is_ipex_xpu_available(): + if is_ipex_cpu_available(): + import intel_extension_for_pytorch + else: + return False + + if torch.xpu.is_available(): + return True + else: + warn("The installed version of intel_extension_for_pytorch is not supporting XPU device, " + " or the XPU device is unavailable.") + return False diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index c52fd0230..cfa8b91b0 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -15,9 +15,26 @@ from functools import reduce # Required in Python 3 from typing import Tuple from torch import Tensor +from warnings import warn +from .cextension import COMPILED_WITH_CUDA -from .cextension import COMPILED_WITH_CUDA, lib - +# CUDA specific lib +if COMPILED_WITH_CUDA: + from .cextension import lib + +from bitsandbytes.device_setup.cpu.main import is_ipex_cpu_available +from bitsandbytes.device_setup.xpu.main import is_ipex_xpu_available +if not is_ipex_cpu_available(): + warn( + "Intel Extension for PyTorch CPU/XPU supports are not available." + "Please refer to https://intel.github.io/intel-extension-for-pytorch/ for installation." + ) +else: + if not is_ipex_xpu_available(): + warn( + "Intel Extension for PyTorch CPU support is available, while XPU is not." + ) + import intel_extension_for_pytorch as ipex # math.prod not compatible with python < 3.8 def prod(iterable): @@ -2403,28 +2420,24 @@ def pipeline_test(A, batch_size): return out -# 8 bits +# 8 bits functions def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): - if A.device is "cuda": + if A.device == "cuda": return cuda_double_quant(A=A, col_stats=col_stats, row_stats=row_stats, out_col=out_col, out_row=out_row, threshold=threshold) else: - pass + raise RuntimeError("double_quant on non-CUDA devices (CPU, XPU...) 
will be soon supported but not yet, aborting...") def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): - if A.device is "cuda": - cuda_transform(A, to_order, from_order=from_order, out=out, transpose=transpose, state=state, ld=ld) + if A.device == "cuda": + return cuda_transform(A, to_order, from_order=from_order, out=out, transpose=transpose, state=state, ld=ld) else: - pass + raise RuntimeError("transform on non-CUDA devices (CPU, XPU...) will be soon supported but not yet, aborting...") def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): - if A.device is "cuda": - cuda_igemmlt(A, B, SA, SB, out=out, Sout=Sout, dtype=dtype) - elif A.device is "cpu": - pass - elif A.device is "xpu": - pass + if A.device == "cuda": + return cuda_igemmlt(A, B, SA, SB, out=out, Sout=Sout, dtype=dtype) else: - pass + raise RuntimeError("igemmlt on non-CUDA devices (CPU, XPU...) will be soon supported but not yet, aborting...") def mm_dequant( A, @@ -2436,26 +2449,26 @@ def mm_dequant( new_col_stats=None, bias=None ): - if A.device is "cuda": + if A.device == "cuda": cuda_mm_dequant(A, quant_state, row_stats, col_stats, out=out, new_row_stats=new_row_stats, new_col_stats=new_col_stats, bias=bias) else: - pass + raise RuntimeError("mm_dequant on non-CUDA devices (CPU, XPU...) will be soon supported but not yet, aborting...") -# 4 bits -def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor: - if A.device is "cuda": - cuda_quantize_4bit(A, absmax = absmax, out = out, blocksize=blocksize, compress_statistics=compress_statistics, quant_type=quant_type) +def extract_outliers(A, SA, idx): + if A.device == "cuda": + return cuda_extract_outliers(A, SA, idx) else: - pass + raise RuntimeError("extract_outliers on non-CUDA devices (CPU, XPU...) will be soon supported but not yet, aborting...") -def dequantize_4bit(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor: - if A.device is "cuda": - cuda_dequantize_4bit(A, quant_state = quant_state, absmax = absmax, out = out, blocksize = blocksize, quant_type=quant_type) +# 4 bits functions +def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor: + if A.device == "cuda": + return cuda_quantize_4bit(A, absmax = absmax, out = out, blocksize=blocksize, compress_statistics=compress_statistics, quant_type=quant_type) else: - pass + raise RuntimeError("quantize_4bit on non-CUDA devices (CPU, XPU...) will be soon supported but not yet, aborting...") -def extract_outliers(A, SA, idx): - if A.device is "cuda": - cuda_extract_outliers(A, SA, idx) +def dequantize_4bit(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor: + if A.device == "cuda": + return cuda_dequantize_4bit(A, quant_state = quant_state, absmax = absmax, out = out, blocksize = blocksize, quant_type=quant_type) else: - pass \ No newline at end of file + raise RuntimeError("dequantize_4bit on non-CUDA devices (CPU, XPU...) 
will be soon supported but not yet, aborting...") \ No newline at end of file diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index be90c829e..96c3fd326 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -305,7 +305,7 @@ def __new__( def cpu(self, device): - warnings.warn("XPU Int8Params will be soon supported, return raw Int8Params for now") + warnings.warn("CPU Int8Params will be soon supported, return raw Int8Params for now") return self @@ -396,8 +396,7 @@ def __init__(self, input_features, output_features, bias=True, has_fp16_weights= self.index = index self.state.threshold = threshold - # fp16 not supports on CPU yet - self.state.has_fp16_weights = has_fp16_weights if device is not "cpu" else False + self.state.has_fp16_weights = has_fp16_weights self.state.memory_efficient_backward = memory_efficient_backward if threshold > 0.0 and not has_fp16_weights: self.state.use_pool = True diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index e875bcd2b..166fb9890 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -15,7 +15,7 @@ def test_manual_override(): os.environ['CUDA_VERSION']='122' assert str(manual_cuda_path) in os.environ['LD_LIBRARY_PATH'] import bitsandbytes as bnb - loaded_lib = bnb.cuda_setup.main.CUDASetup.get_instance().binary_name + loaded_lib = bnb.device_setup.cuda.main.CUDASetup.get_instance().binary_name assert loaded_lib == 'libbitsandbytes_cuda122.so' From 65b17a266d6175951e09978b760752c8fbc7cb05 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Thu, 30 Nov 2023 16:05:32 +0800 Subject: [PATCH 04/39] Update modules.py --- bitsandbytes/nn/modules.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 7c7c98375..d9a0e7434 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -161,8 +161,7 @@ def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], self.quant_state = QuantState.from_dict(qs_dict=quantized_stats, device=device) self.blocksize = self.quant_state.blocksize self.compress_statistics = self.quant_state.nested - - return self + self.quant_type = self.quant_state.quant_type def cpu(self, device): warnings.warn("CPU Params4bit will be soon supported, return raw Params4bit for now") From c5044e01aade4d39192efe6815cfdee446e83bfb Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Fri, 1 Dec 2023 17:11:36 +0800 Subject: [PATCH 05/39] Update bitsandbytes/functional.py Co-authored-by: Jiong Gong --- bitsandbytes/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index eab555a0b..69035f033 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2542,7 +2542,7 @@ def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, if A.device == "cuda": return cuda_double_quant(A=A, col_stats=col_stats, row_stats=row_stats, out_col=out_col, out_row=out_row, threshold=threshold) else: - raise RuntimeError("double_quant on non-CUDA devices (CPU, XPU...) 
will be soon supported but not yet, aborting...") + raise RuntimeError("double_quant is not supported on non-CUDA devices") def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): if A.device == "cuda": From b23789ab77d6c3703360be90aadf361d4c7a8239 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Sun, 3 Dec 2023 09:21:47 -0800 Subject: [PATCH 06/39] add backends --- bitsandbytes/autograd/_functions.py | 17 +- bitsandbytes/backends.py | 38 + .../device_setup/{cpu => }/__init__.py | 0 bitsandbytes/device_setup/cpu/main.py | 40 - bitsandbytes/device_setup/xpu/__init__.py | 0 bitsandbytes/device_setup/xpu/main.py | 15 - bitsandbytes/functional.py | 1050 ++++++++--------- 7 files changed, 533 insertions(+), 627 deletions(-) create mode 100644 bitsandbytes/backends.py rename bitsandbytes/device_setup/{cpu => }/__init__.py (100%) delete mode 100644 bitsandbytes/device_setup/cpu/main.py delete mode 100644 bitsandbytes/device_setup/xpu/__init__.py delete mode 100644 bitsandbytes/device_setup/xpu/main.py diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 548e6577c..757aafb4f 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -223,7 +223,6 @@ def backward(ctx, grad_output): def supports_igemmlt(device: torch.device) -> bool: - """check if this device supports the optimized int8 kernel""" if torch.cuda.get_device_capability(device=device) < (7, 5): return False @@ -231,8 +230,8 @@ def supports_igemmlt(device: torch.device) -> bool: nvidia16_models = ('GTX 1630', 'GTX 1650', 'GTX 1660') # https://en.wikipedia.org/wiki/GeForce_16_series if any(model_name in device_name for model_name in nvidia16_models): return False # these devices are technically cuda 7.5-capable, but they lack tensor cores - if device == "cpu": + #TODO: will return True once CPU backend upstream the supports return False return True @@ -272,7 +271,6 @@ class MatmulLtState: idx = None is_training = True has_fp16_weights = True - memory_efficient_backward = False use_pool = False formatB = F.get_special_format_str() @@ -324,14 +322,13 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): state.outlier_pool = GlobalOutlierPooler.get_instance() # Cast A to fp16 - ctx.cast_dtype = torch.float16 - if A.dtype != ctx.cast_dtype: - warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to {ctx.cast_dtype} during quantization") + if A.dtype != torch.float16: + warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization") # 1. 
Quantize A if len(A.shape) == 3: A = A.reshape(-1, A.shape[-1]) - CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(ctx.cast_dtype), threshold=state.threshold) + CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold) if state.threshold > 0.0 and coo_tensorA is not None: if state.has_fp16_weights: @@ -366,7 +363,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): state.SCB, state.SCBt, coo_tensorB, - ) = F.double_quant(B.to(ctx.cast_dtype)) + ) = F.double_quant(B.to(torch.float16)) if using_igemmlt: state.CxB, state.SB = F.transform(CB, to_order=formatB) else: @@ -461,11 +458,11 @@ def backward(ctx, grad_output): # compute grad_bias first before changing grad_output dtype grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias) - # Cast grad_output + # Cast grad_output to fp16 if len(grad_output.shape) == 3: grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() - Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(ctx.cast_dtype)) + Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) if req_gradB: CxAt, SAt = F.transform(CAt, formatB, transpose=True) C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True) diff --git a/bitsandbytes/backends.py b/bitsandbytes/backends.py new file mode 100644 index 000000000..eb7bda484 --- /dev/null +++ b/bitsandbytes/backends.py @@ -0,0 +1,38 @@ +class Backends: + """ + An dict class for device backends that registered with 8bits and 4bits functions. + + The values of this device backends are lowercase strings, e.g., ``"cuda"``. They can + be accessed as attributes with key-value, e.g., ``Backends.device["cuda"]``. + + """ + + def __init__(self): + self.devices = {} + + @classmethod + def register_backend(backend_name: str, backend_class): + assert backend_name.lower() in { + "cpu", + "cuda", + "xpu", + }, "register device backend choices in [cpu, cuda, xpu]" + + # check 8bits or 4bits functionality, at least one is compelete + if ( + hasattr(backend_class, "double_quant") + and hasattr(backend_class, "transform") + and hasattr(backend_class, "igemmlt") + and hasattr(backend_class, "mm_dequant") + and hasattr(backend_class, "extract_outliers") + ): + self.devices[backend_name.lower()] = backend_class + + elif hasattr(backend_class, "quantize_4bit") and hasattr( + backend_class, "dequantize_4bit" + ): + self.devices[backend_name.lower()] = backend_classq + else: + assert ( + False + ), f"register device backend {backend_name.lower()} but its functionality is not compelete" diff --git a/bitsandbytes/device_setup/cpu/__init__.py b/bitsandbytes/device_setup/__init__.py similarity index 100% rename from bitsandbytes/device_setup/cpu/__init__.py rename to bitsandbytes/device_setup/__init__.py diff --git a/bitsandbytes/device_setup/cpu/main.py b/bitsandbytes/device_setup/cpu/main.py deleted file mode 100644 index 8e44a6ae9..000000000 --- a/bitsandbytes/device_setup/cpu/main.py +++ /dev/null @@ -1,40 +0,0 @@ -from packaging import version -import importlib.metadata -from warnings import warn -def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]: - # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version - package_exists = importlib.util.find_spec(pkg_name) is not None - package_version = "N/A" - if package_exists: - try: - package_version = importlib.metadata.version(pkg_name) - package_exists = True - except 
importlib.metadata.PackageNotFoundError: - package_exists = False - if return_version: - return package_exists, package_version - else: - return package_exists - -_torch_version = "N/A" -_torch_available = False -_torch_available, _torch_version = _is_package_available("torch", return_version=True) -_ipex_available, _ipex_version = _is_package_available("intel_extension_for_pytorch", return_version=True) - -def is_ipex_cpu_available(): - def get_major_and_minor_from_version(full_version): - return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor) - - if not _torch_available or not _ipex_available: - return False - - torch_major_and_minor = get_major_and_minor_from_version(_torch_version) - ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version) - if torch_major_and_minor != ipex_major_and_minor: - warn( - f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*," - f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again." - "Refer to https://intel.github.io/intel-extension-for-pytorch/ for more details." - ) - return False - return True \ No newline at end of file diff --git a/bitsandbytes/device_setup/xpu/__init__.py b/bitsandbytes/device_setup/xpu/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bitsandbytes/device_setup/xpu/main.py b/bitsandbytes/device_setup/xpu/main.py deleted file mode 100644 index f13e6cb2d..000000000 --- a/bitsandbytes/device_setup/xpu/main.py +++ /dev/null @@ -1,15 +0,0 @@ -from .cpu.main import is_ipex_cpu_available -from warnings import warn - -def is_ipex_xpu_available(): - if is_ipex_cpu_available(): - import intel_extension_for_pytorch - else: - return False - - if torch.xpu.is_available(): - return True - else: - warn("The installed version of intel_extension_for_pytorch is not supporting XPU device, " - " or the XPU device is unavailable.") - return False diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 69035f033..c30e1b651 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -19,26 +19,12 @@ from warnings import warn from .cextension import COMPILED_WITH_CUDA from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict - +from backends import Backends # CUDA specific lib if COMPILED_WITH_CUDA: from .cextension import lib -from bitsandbytes.device_setup.cpu.main import is_ipex_cpu_available -from bitsandbytes.device_setup.xpu.main import is_ipex_xpu_available -if not is_ipex_cpu_available(): - warn( - "Intel Extension for PyTorch CPU/XPU supports are not available." - "Please refer to https://intel.github.io/intel-extension-for-pytorch/ for installation." - ) -else: - if not is_ipex_xpu_available(): - warn( - "Intel Extension for PyTorch CPU support is available, while XPU is not." - ) - import intel_extension_for_pytorch as ipex - # math.prod not compatible with python < 3.8 def prod(iterable): return reduce(operator.mul, iterable, 1) @@ -908,169 +894,12 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False): return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4') -def cuda_quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor: - """ - Quantize tensor A in blocks of 4-bit values. 
- - Quantizes tensor A by dividing it into blocks which are independently quantized to FP4. - - Parameters - ---------- - A : torch.Tensor - The input tensor. - absmax : torch.Tensor - The absmax values. - out : torch.Tensor - The output tensor (8-bit). - blocksize : int - The blocksize used in quantization. - quant_type : str - The 4-bit quantization data type {fp4, nf4} - - Returns - ------- - torch.Tensor: - The 8-bit tensor with packed 4-bit values. - tuple(torch.Tensor, torch.Size, torch.dtype, int): - The quantization state to undo the quantization. - """ - if A.device.type != 'cuda': - raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}') - if quant_type not in ['fp4', 'nf4']: - raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') - - n = A.numel() - input_shape = A.shape - - if absmax is None: - blocks = n // blocksize - blocks += 1 if n % blocksize > 0 else 0 - absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) - - - if out is None: - out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device) - - assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] - - prev_device = pre_call(A.device) - is_on_gpu([A, out, absmax]) - - if A.dtype == torch.float32: - if quant_type == 'fp4': - lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) - else: - lib.cquantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) - elif A.dtype == torch.float16: - if quant_type == 'fp4': - lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) - else: - lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) - elif A.dtype == torch.bfloat16: - if quant_type == 'fp4': - lib.cquantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) - else: - lib.cquantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - post_call(A.device) - - code = get_4bit_type(quant_type, device=A.device) - - if compress_statistics: - offset = absmax.mean() - absmax -= offset - qabsmax, state2 = quantize_blockwise(absmax, blocksize=256) - del absmax - state = QuantState(absmax=qabsmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, offset=offset, state2=state2) - else: - state = QuantState(absmax=absmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, ) - - return out, state - def dequantize_fp4(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor: return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4') def dequantize_nf4(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor: return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4') -def dequantize_4bit(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor: - """ - Dequantizes FP4 blockwise quantized values. 
- - Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize. - - Parameters - ---------- - A : torch.Tensor - The input 8-bit tensor (packed 4-bit values). - quant_state : QuantState - object with quantisation stats, incl. absmax values, original tensor shape and original dtype. - absmax : torch.Tensor - The absmax values. - out : torch.Tensor - Dequantized output tensor. - blocksize : int - The blocksize used in quantization. - quant_type : str - The 4-bit quantization data type {fp4, nf4} - - - Returns - ------- - torch.Tensor: - Dequantized tensor. - """ - if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: - raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") - if quant_type not in ['fp4', 'nf4']: - raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') - - if quant_state is None: - assert absmax is not None and out is not None - - quant_state = QuantState(absmax=absmax, shape=out.shape, dtype=out.dtype, blocksize=blocksize, quant_type=quant_type) - - else: - absmax = quant_state.absmax - - - if quant_state.nested: - absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) - absmax += quant_state.offset - if absmax.dtype != torch.float32: absmax = absmax.float() - - if out is None: - out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device) - - n = out.numel() - - device = pre_call(A.device) - is_on_gpu([A, absmax, out]) - if out.dtype == torch.float32: - if quant_state.quant_type == 'fp4': - lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) - else: - lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) - elif out.dtype == torch.float16: - if quant_state.quant_type == 'fp4': - lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) - else: - lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) - elif out.dtype == torch.bfloat16: - if quant_state.quant_type == 'fp4': - lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) - else: - lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - post_call(A.device) - - is_transposed = (True if A.shape[0] == 1 else False) - if is_transposed: return out.t() - else: return out - - def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor: if code is None: if "dynamic" not in name2qmap: @@ -1833,198 +1662,6 @@ def batched_igemm( return out -def cuda_igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): - shapeA = SA[0] - shapeB = SB[0] - dimsA = len(shapeA) - dimsB = len(shapeB) - assert dimsB == 2, 'Only two dimensional matrices are supported for argument B' - if dimsA == 2: - m = shapeA[0] - elif dimsA == 3: - m = shapeA[0] * shapeA[1] - - rows = n = shapeB[0] - assert prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}' - - # if the tensor is empty, return a transformed empty tensor with the right dimensions - if 
shapeA[0] == 0 and dimsA == 2: - return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16) - elif shapeA[1] == 0 and dimsA == 3: - return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16) - - if dimsA == 2 and out is None: - out, Sout = get_transform_buffer( - (shapeA[0], shapeB[0]), dtype, A.device, "col32", "row" - ) - elif dimsA == 3 and out is None: - out, Sout = get_transform_buffer( - (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row" - ) - - assert dimsB != 3, "len(B.shape)==3 not supported" - assert A.device.type == "cuda" - assert B.device.type == "cuda" - assert A.dtype == torch.int8 - assert B.dtype == torch.int8 - assert out.dtype == dtype - assert SA[1] == "col32" - assert SB[1] in ["col_turing", "col_ampere"] - assert Sout[1] == "col32" - assert ( - shapeA[-1] == shapeB[-1] - ), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}" - formatB = SB[1] - prev_device = A.device - torch.cuda.set_device(A.device) - - ptr = CUBLAS_Context.get_instance().get_context(A.device) - ptrA = get_ptr(A) - ptrB = get_ptr(B) - ptrC = get_ptr(out) - - k = shapeA[-1] - lda = ct.c_int32(m * 32) - if formatB == "col_turing": - # turing: tiles with rows filled up to multiple of 8 rows by 32 columns - # n = rows - ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32) - else: - # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns - # n = rows - ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32) - - ldc = ct.c_int32(m * 32) - m = ct.c_int32(m) - n = ct.c_int32(n) - k = ct.c_int32(k) - - has_error = 0 - ptrRowScale = get_ptr(None) - is_on_gpu([A, B, out]) - if formatB == 'col_turing': - if dtype == torch.int32: - has_error = lib.cigemmlt_turing_32( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) - else: - has_error = lib.cigemmlt_turing_8( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) - elif formatB == "col_ampere": - if dtype == torch.int32: - has_error = lib.cigemmlt_ampere_32( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) - else: - has_error = lib.cigemmlt_ampere_8( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) - - if has_error == 1: - print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}') - raise Exception('cublasLt ran into an error!') - - torch.cuda.set_device(prev_device) - - return out, Sout - - -def cuda_mm_dequant( - A, - quant_state, - row_stats, - col_stats, - out=None, - new_row_stats=None, - new_col_stats=None, - bias=None -): - assert A.dtype == torch.int32 - if bias is not None: assert bias.dtype == torch.float16 - out_shape = quant_state[0] - if len(out_shape) == 3: - out_shape = (out_shape[0] * out_shape[1], out_shape[2]) - - if out is None: - out = torch.empty(out_shape, dtype=torch.float16, device=A.device) - if new_row_stats is None: - new_row_stats = torch.empty( - out_shape[0], dtype=torch.float32, device=A.device - ) - if new_col_stats is None: - new_col_stats = torch.empty( - out_shape[1], dtype=torch.float32, device=A.device - ) - assert ( - new_row_stats.shape[0] == row_stats.shape[0] - ), f"{new_row_stats.shape} vs {row_stats.shape}" - assert ( - new_col_stats.shape[0] == col_stats.shape[0] - ), f"{new_col_stats.shape} vs {col_stats.shape}" - - prev_device = pre_call(A.device) - ptrA = get_ptr(A) - ptrOut = get_ptr(out) - ptrRowStats = get_ptr(row_stats) - ptrColStats = get_ptr(col_stats) - ptrNewRowStats = get_ptr(new_row_stats) - 
ptrNewColStats = get_ptr(new_col_stats) - ptrBias = get_ptr(bias) - numRows = ct.c_int32(out_shape[0]) - numCols = ct.c_int32(out_shape[1]) - - is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias]) - lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, ptrBias, numRows, numCols) - post_call(prev_device) - - return out - - -def get_colrow_absmax( - A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0 -): - assert A.dtype == torch.float16 - device = A.device - - cols = A.shape[-1] - if len(A.shape) == 3: - rows = A.shape[0] * A.shape[1] - else: - rows = A.shape[0] - - col_tiles = (cols + 255) // 256 - tiled_rows = ((rows + 15) // 16) * 16 - if row_stats is None: - row_stats = torch.empty( - (rows,), dtype=torch.float32, device=device - ).fill_(-50000.0) - if col_stats is None: - col_stats = torch.empty( - (cols,), dtype=torch.float32, device=device - ).fill_(-50000.0) - - if nnz_block_ptr is None and threshold > 0.0: - nnz_block_ptr = torch.zeros( - ((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device - ) - - ptrA = get_ptr(A) - ptrRowStats = get_ptr(row_stats) - ptrColStats = get_ptr(col_stats) - ptrNnzrows = get_ptr(nnz_block_ptr) - rows = ct.c_int32(rows) - cols = ct.c_int32(cols) - - prev_device = pre_call(A.device) - is_on_gpu([A, row_stats, col_stats, nnz_block_ptr]) - lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols) - post_call(prev_device) - - if threshold > 0.0: - nnz_block_ptr.cumsum_(0) - - return row_stats, col_stats, nnz_block_ptr class COOSparseTensor: @@ -2113,147 +1750,6 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half): values = torch.zeros((nnz,), dtype=dtype, device=device) return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) - -def cuda_double_quant( - A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 -): - device = A.device - assert A.dtype == torch.half - assert device.type == "cuda" - prev_device = pre_call(A.device) - - cols = A.shape[-1] - if len(A.shape) == 3: - rows = A.shape[0] * A.shape[1] - else: - rows = A.shape[0] - - if row_stats is None or col_stats is None: - row_stats, col_stats, nnz_row_ptr = get_colrow_absmax( - A, threshold=threshold - ) - - if out_col is None: - out_col = torch.zeros(A.shape, device=device, dtype=torch.int8) - if out_row is None: - out_row = torch.zeros(A.shape, device=device, dtype=torch.int8) - - coo_tensor = None - ptrA = get_ptr(A) - ptrColStats = get_ptr(col_stats) - ptrRowStats = get_ptr(row_stats) - ptrOutCol = get_ptr(out_col) - ptrOutRow = get_ptr(out_row) - - is_on_gpu([A, col_stats, row_stats, out_col, out_row]) - if threshold > 0.0: - nnz = nnz_row_ptr[-1].item() - if nnz > 0: - coo_tensor = coo_zeros( - A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device - ) - ptrRowIdx = get_ptr(coo_tensor.rowidx) - ptrColIdx = get_ptr(coo_tensor.colidx) - ptrVal = get_ptr(coo_tensor.values) - ptrRowPtr = get_ptr(nnz_row_ptr) - - lib.cdouble_rowcol_quant( - ptrA, - ptrRowStats, - ptrColStats, - ptrOutCol, - ptrOutRow, - ptrRowIdx, - ptrColIdx, - ptrVal, - ptrRowPtr, - ct.c_float(threshold), - ct.c_int32(rows), - ct.c_int32(cols), - ) - val, idx = torch.sort(coo_tensor.rowidx) - coo_tensor.rowidx = val - coo_tensor.colidx = coo_tensor.colidx[idx] - coo_tensor.values = coo_tensor.values[idx] - else: - lib.cdouble_rowcol_quant( - ptrA, - ptrRowStats, - ptrColStats, - ptrOutCol, - ptrOutRow, - None, - None, - None, - None, - ct.c_float(0.0), - 
ct.c_int32(rows), - ct.c_int32(cols), - ) - else: - lib.cdouble_rowcol_quant( - ptrA, - ptrRowStats, - ptrColStats, - ptrOutCol, - ptrOutRow, - None, - None, - None, - None, - ct.c_float(threshold), - ct.c_int32(rows), - ct.c_int32(cols), - ) - post_call(prev_device) - - return out_row, out_col, row_stats, col_stats, coo_tensor - - -def cuda_transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): - prev_device = pre_call(A.device) - if state is None: state = (A.shape, from_order) - else: from_order = state[1] - if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) - else: new_state = (state[0], to_order) # (shape, order) - - shape = state[0] - if len(shape) == 2: - dim1 = ct.c_int32(shape[0]) - dim2 = ct.c_int32(shape[1]) - else: - dim1 = ct.c_int32(shape[0] * shape[1]) - dim2 = ct.c_int32(shape[2]) - - is_on_gpu([A, out]) - if to_order == 'col32': - if transpose: - lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) - else: - lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2) - elif to_order == "col_turing": - if transpose: - lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2) - else: - lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2) - elif to_order == "col_ampere": - if transpose: - lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2) - else: - lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2) - elif to_order == "row": - if from_order == "col_turing": - lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2) - elif from_order == "col_ampere": - lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) - else: - raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}') - - post_call(prev_device) - - return out, new_state - - def spmm_coo(cooA, B, out=None): if out is None: out = torch.empty( @@ -2505,56 +2001,494 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half): x += offset return x.to(dtype) -def cuda_extract_outliers(A, SA, idx): - shapeA = SA[0] - formatA = SA[1] - assert formatA in ["col_turing", "col_ampere"] - assert A.device.type == "cuda" +def pipeline_test(A, batch_size): + out = torch.zeros_like(A) + lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size)) + return out - out = torch.zeros( - (shapeA[0], idx.numel()), dtype=torch.int8, device=A.device - ) +class CUDABackend: + @classmethod + def double_quant( + A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 + ): + device = A.device + assert A.dtype == torch.half + assert device.type == "cuda" + prev_device = pre_call(A.device) - idx_size = ct.c_int32(idx.numel()) - rows = ct.c_int32(shapeA[0]) - cols = ct.c_int32(shapeA[1]) - ptrA = get_ptr(A) - ptrIdx = get_ptr(idx) - ptrOut = get_ptr(out) + cols = A.shape[-1] + if len(A.shape) == 3: + rows = A.shape[0] * A.shape[1] + else: + rows = A.shape[0] - prev_device = pre_call(A.device) - if formatA == 'col_turing': - lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) - elif formatA == "col_ampere": - lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) - post_call(prev_device) + if row_stats is None or col_stats is None: + row_stats, col_stats, nnz_row_ptr = get_colrow_absmax( + A, threshold=threshold + ) - return out + if out_col is None: + out_col = torch.zeros(A.shape, device=device, dtype=torch.int8) + if out_row is None: + out_row = 
torch.zeros(A.shape, device=device, dtype=torch.int8) + + coo_tensor = None + ptrA = get_ptr(A) + ptrColStats = get_ptr(col_stats) + ptrRowStats = get_ptr(row_stats) + ptrOutCol = get_ptr(out_col) + ptrOutRow = get_ptr(out_row) + + is_on_gpu([A, col_stats, row_stats, out_col, out_row]) + if threshold > 0.0: + nnz = nnz_row_ptr[-1].item() + if nnz > 0: + coo_tensor = coo_zeros( + A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device + ) + ptrRowIdx = get_ptr(coo_tensor.rowidx) + ptrColIdx = get_ptr(coo_tensor.colidx) + ptrVal = get_ptr(coo_tensor.values) + ptrRowPtr = get_ptr(nnz_row_ptr) + + lib.cdouble_rowcol_quant( + ptrA, + ptrRowStats, + ptrColStats, + ptrOutCol, + ptrOutRow, + ptrRowIdx, + ptrColIdx, + ptrVal, + ptrRowPtr, + ct.c_float(threshold), + ct.c_int32(rows), + ct.c_int32(cols), + ) + val, idx = torch.sort(coo_tensor.rowidx) + coo_tensor.rowidx = val + coo_tensor.colidx = coo_tensor.colidx[idx] + coo_tensor.values = coo_tensor.values[idx] + else: + lib.cdouble_rowcol_quant( + ptrA, + ptrRowStats, + ptrColStats, + ptrOutCol, + ptrOutRow, + None, + None, + None, + None, + ct.c_float(0.0), + ct.c_int32(rows), + ct.c_int32(cols), + ) + else: + lib.cdouble_rowcol_quant( + ptrA, + ptrRowStats, + ptrColStats, + ptrOutCol, + ptrOutRow, + None, + None, + None, + None, + ct.c_float(threshold), + ct.c_int32(rows), + ct.c_int32(cols), + ) + post_call(prev_device) -def pipeline_test(A, batch_size): - out = torch.zeros_like(A) - lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size)) - return out + return out_row, out_col, row_stats, col_stats, coo_tensor + @classmethod + def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): + prev_device = pre_call(A.device) + if state is None: state = (A.shape, from_order) + else: from_order = state[1] + if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) + else: new_state = (state[0], to_order) # (shape, order) + + shape = state[0] + if len(shape) == 2: + dim1 = ct.c_int32(shape[0]) + dim2 = ct.c_int32(shape[1]) + else: + dim1 = ct.c_int32(shape[0] * shape[1]) + dim2 = ct.c_int32(shape[2]) + is_on_gpu([A, out]) + if to_order == 'col32': + if transpose: + lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) + else: + lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2) + elif to_order == "col_turing": + if transpose: + lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2) + else: + lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2) + elif to_order == "col_ampere": + if transpose: + lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2) + else: + lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2) + elif to_order == "row": + if from_order == "col_turing": + lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2) + elif from_order == "col_ampere": + lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) + else: + raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}') -# 8 bits functions -def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): - if A.device == "cuda": - return cuda_double_quant(A=A, col_stats=col_stats, row_stats=row_stats, out_col=out_col, out_row=out_row, threshold=threshold) + post_call(prev_device) + + return out, new_state + @classmethod + def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): + shapeA = SA[0] + shapeB = 
SB[0] + dimsA = len(shapeA) + dimsB = len(shapeB) + assert dimsB == 2, 'Only two dimensional matrices are supported for argument B' + if dimsA == 2: + m = shapeA[0] + elif dimsA == 3: + m = shapeA[0] * shapeA[1] + + rows = n = shapeB[0] + assert prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}' + + # if the tensor is empty, return a transformed empty tensor with the right dimensions + if shapeA[0] == 0 and dimsA == 2: + return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16) + elif shapeA[1] == 0 and dimsA == 3: + return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16) + + if dimsA == 2 and out is None: + out, Sout = get_transform_buffer( + (shapeA[0], shapeB[0]), dtype, A.device, "col32", "row" + ) + elif dimsA == 3 and out is None: + out, Sout = get_transform_buffer( + (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row" + ) + + assert dimsB != 3, "len(B.shape)==3 not supported" + assert A.device.type == "cuda" + assert B.device.type == "cuda" + assert A.dtype == torch.int8 + assert B.dtype == torch.int8 + assert out.dtype == dtype + assert SA[1] == "col32" + assert SB[1] in ["col_turing", "col_ampere"] + assert Sout[1] == "col32" + assert ( + shapeA[-1] == shapeB[-1] + ), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}" + formatB = SB[1] + prev_device = A.device + torch.cuda.set_device(A.device) + + ptr = CUBLAS_Context.get_instance().get_context(A.device) + ptrA = get_ptr(A) + ptrB = get_ptr(B) + ptrC = get_ptr(out) + + k = shapeA[-1] + lda = ct.c_int32(m * 32) + if formatB == "col_turing": + # turing: tiles with rows filled up to multiple of 8 rows by 32 columns + # n = rows + ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32) else: - raise RuntimeError("double_quant is not supported on non-CUDA devices") + # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns + # n = rows + ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32) + + ldc = ct.c_int32(m * 32) + m = ct.c_int32(m) + n = ct.c_int32(n) + k = ct.c_int32(k) + + has_error = 0 + ptrRowScale = get_ptr(None) + is_on_gpu([A, B, out]) + if formatB == 'col_turing': + if dtype == torch.int32: + has_error = lib.cigemmlt_turing_32( + ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc + ) + else: + has_error = lib.cigemmlt_turing_8( + ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc + ) + elif formatB == "col_ampere": + if dtype == torch.int32: + has_error = lib.cigemmlt_ampere_32( + ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc + ) + else: + has_error = lib.cigemmlt_ampere_8( + ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc + ) -def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): - if A.device == "cuda": - return cuda_transform(A, to_order, from_order=from_order, out=out, transpose=transpose, state=state, ld=ld) + if has_error == 1: + print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}') + raise Exception('cublasLt ran into an error!') + + torch.cuda.set_device(prev_device) + + return out, Sout + @classmethod + def mm_dequant( + A, + quant_state, + row_stats, + col_stats, + out=None, + new_row_stats=None, + new_col_stats=None, + bias=None + ): + assert A.dtype == torch.int32 + if bias is not None: assert bias.dtype == torch.float16 + out_shape = quant_state[0] + if len(out_shape) == 3: + out_shape = (out_shape[0] * out_shape[1], out_shape[2]) + + if 
out is None: + out = torch.empty(out_shape, dtype=torch.float16, device=A.device) + if new_row_stats is None: + new_row_stats = torch.empty( + out_shape[0], dtype=torch.float32, device=A.device + ) + if new_col_stats is None: + new_col_stats = torch.empty( + out_shape[1], dtype=torch.float32, device=A.device + ) + assert ( + new_row_stats.shape[0] == row_stats.shape[0] + ), f"{new_row_stats.shape} vs {row_stats.shape}" + assert ( + new_col_stats.shape[0] == col_stats.shape[0] + ), f"{new_col_stats.shape} vs {col_stats.shape}" + + prev_device = pre_call(A.device) + ptrA = get_ptr(A) + ptrOut = get_ptr(out) + ptrRowStats = get_ptr(row_stats) + ptrColStats = get_ptr(col_stats) + ptrNewRowStats = get_ptr(new_row_stats) + ptrNewColStats = get_ptr(new_col_stats) + ptrBias = get_ptr(bias) + numRows = ct.c_int32(out_shape[0]) + numCols = ct.c_int32(out_shape[1]) + + is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias]) + lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, ptrBias, numRows, numCols) + post_call(prev_device) + + return out + @classmethod + def extract_outliers(A, SA, idx): + shapeA = SA[0] + formatA = SA[1] + assert formatA in ["col_turing", "col_ampere"] + assert A.device.type == "cuda" + + out = torch.zeros( + (shapeA[0], idx.numel()), dtype=torch.int8, device=A.device + ) + + idx_size = ct.c_int32(idx.numel()) + rows = ct.c_int32(shapeA[0]) + cols = ct.c_int32(shapeA[1]) + ptrA = get_ptr(A) + ptrIdx = get_ptr(idx) + ptrOut = get_ptr(out) + + prev_device = pre_call(A.device) + if formatA == 'col_turing': + lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) + elif formatA == "col_ampere": + lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) + post_call(prev_device) + + return out + @classmethod + def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor: + """ + Quantize tensor A in blocks of 4-bit values. + + Quantizes tensor A by dividing it into blocks which are independently quantized to FP4. + + Parameters + ---------- + A : torch.Tensor + The input tensor. + absmax : torch.Tensor + The absmax values. + out : torch.Tensor + The output tensor (8-bit). + blocksize : int + The blocksize used in quantization. + quant_type : str + The 4-bit quantization data type {fp4, nf4} + + Returns + ------- + torch.Tensor: + The 8-bit tensor with packed 4-bit values. + tuple(torch.Tensor, torch.Size, torch.dtype, int): + The quantization state to undo the quantization. 
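        Example
        -------
        A usage sketch, assuming a CUDA device and that the module-level
        bitsandbytes.functional.quantize_4bit wrapper dispatches to this backend
        (the tensor sizes below are illustrative only)::

            import torch
            import bitsandbytes.functional as F

            W = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
            W4, state = F.quantize_4bit(W, blocksize=64, quant_type="nf4")
            # W4 packs two 4-bit codes per byte: shape ((W.numel() + 1) // 2, 1), dtype torch.uint8
            # state is the QuantState needed by dequantize_4bit (absmax, shape, dtype, blocksize, quant_type)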
+ """ + if A.device.type != 'cuda': + raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}') + if quant_type not in ['fp4', 'nf4']: + raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') + + n = A.numel() + input_shape = A.shape + + if absmax is None: + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) + + + if out is None: + out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device) + + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] + + prev_device = pre_call(A.device) + is_on_gpu([A, out, absmax]) + + if A.dtype == torch.float32: + if quant_type == 'fp4': + lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + else: + lib.cquantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + elif A.dtype == torch.float16: + if quant_type == 'fp4': + lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + else: + lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + elif A.dtype == torch.bfloat16: + if quant_type == 'fp4': + lib.cquantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + else: + lib.cquantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) else: - raise RuntimeError("transform on non-CUDA devices (CPU, XPU...) will be soon supported but not yet, aborting...") + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + post_call(A.device) -def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): - if A.device == "cuda": - return cuda_igemmlt(A, B, SA, SB, out=out, Sout=Sout, dtype=dtype) + code = get_4bit_type(quant_type, device=A.device) + + if compress_statistics: + offset = absmax.mean() + absmax -= offset + qabsmax, state2 = quantize_blockwise(absmax, blocksize=256) + del absmax + state = QuantState(absmax=qabsmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, offset=offset, state2=state2) else: - raise RuntimeError("igemmlt on non-CUDA devices (CPU, XPU...) will be soon supported but not yet, aborting...") + state = QuantState(absmax=absmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, ) + + return out, state + @classmethod + def dequantize_4bit(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor: + """ + Dequantizes FP4 blockwise quantized values. + + Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize. + + Parameters + ---------- + A : torch.Tensor + The input 8-bit tensor (packed 4-bit values). + quant_state : QuantState + object with quantisation stats, incl. absmax values, original tensor shape and original dtype. + absmax : torch.Tensor + The absmax values. + out : torch.Tensor + Dequantized output tensor. + blocksize : int + The blocksize used in quantization. + quant_type : str + The 4-bit quantization data type {fp4, nf4} + + + Returns + ------- + torch.Tensor: + Dequantized tensor. 
+ """ + if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: + raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") + if quant_type not in ['fp4', 'nf4']: + raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') + + if quant_state is None: + assert absmax is not None and out is not None + + quant_state = QuantState(absmax=absmax, shape=out.shape, dtype=out.dtype, blocksize=blocksize, quant_type=quant_type) + + else: + absmax = quant_state.absmax + + + if quant_state.nested: + absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) + absmax += quant_state.offset + if absmax.dtype != torch.float32: absmax = absmax.float() + + if out is None: + out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device) + + n = out.numel() + + device = pre_call(A.device) + is_on_gpu([A, absmax, out]) + if out.dtype == torch.float32: + if quant_state.quant_type == 'fp4': + lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + else: + lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + elif out.dtype == torch.float16: + if quant_state.quant_type == 'fp4': + lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + else: + lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + elif out.dtype == torch.bfloat16: + if quant_state.quant_type == 'fp4': + lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + else: + lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + post_call(A.device) + + is_transposed = (True if A.shape[0] == 1 else False) + if is_transposed: return out.t() + else: return out + + +Backends.register_backend("cuda", CUDABackend) + +# 8 bits common functions +def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): + assert A.device in Backends.device, f"Device backend for {A.device} is not supported" + return Backends.device[A.device].double_quant(A, col_stats=col_stats, row_stats=row_stats, out_col=out_col, out_row=out_row, threshold=threshold) + +def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): + assert A.device in Backends.device, f"Device backend for {A.device} is not supported" + return Backends.device[A.device].transform(A, to_order, from_order=from_order, out=out, transpose=transpose, state=state, ld=ld) + +def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): + assert A.device in Backends.device, f"Device backend for {A.device} is not supported" + return Backends.device[A.device].igemmlt(A, B, SA, SB, out=out, Sout=Sout, dtype=dtype) def mm_dequant( A, @@ -2566,26 +2500,18 @@ def mm_dequant( new_col_stats=None, bias=None ): - if A.device == "cuda": - cuda_mm_dequant(A, quant_state, row_stats, col_stats, out=out, new_row_stats=new_row_stats, new_col_stats=new_col_stats, bias=bias) - else: - raise RuntimeError("mm_dequant on non-CUDA devices 
(CPU, XPU...) will be soon supported but not yet, aborting...") + assert A.device in Backends.device, f"Device backend for {A.device} is not supported" + return Backends.device[A.device].mm_dequant(A, quant_state, row_stats, col_stats, out=out, new_row_stats=new_row_stats, new_col_stats=new_col_stats, bias=bias) def extract_outliers(A, SA, idx): - if A.device == "cuda": - return cuda_extract_outliers(A, SA, idx) - else: - raise RuntimeError("extract_outliers on non-CUDA devices (CPU, XPU...) will be soon supported but not yet, aborting...") + assert A.device in Backends.device, f"Device backend for {A.device} is not supported" + return Backends.device[A.device].extract_outliers(A, SA, idx) -# 4 bits functions -def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor: - if A.device == "cuda": - return cuda_quantize_4bit(A, absmax = absmax, out = out, blocksize=blocksize, compress_statistics=compress_statistics, quant_type=quant_type) - else: - raise RuntimeError("quantize_4bit on non-CUDA devices (CPU, XPU...) will be soon supported but not yet, aborting...") +# 4 bits common functions +def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4'): + assert A.device in Backends.device, f"Device backend for {A.device} is not supported" + return Backends.device[A.device].quantize_4bit(A, absmax = absmax, out = out, blocksize=blocksize, compress_statistics=compress_statistics, quant_type=quant_type) -def dequantize_4bit(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor: - if A.device == "cuda": - return cuda_dequantize_4bit(A, quant_state = quant_state, absmax = absmax, out = out, blocksize = blocksize, quant_type=quant_type) - else: - raise RuntimeError("dequantize_4bit on non-CUDA devices (CPU, XPU...) 
will be soon supported but not yet, aborting...") \ No newline at end of file +def dequantize_4bit(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4'): + assert A.device in Backends.device, f"Device backend for {A.device} is not supported" + return Backends.device[A.device].dequantize_4bit(A, quant_state=quant_state, absmax=absmax, out=out, blocksize=blocksize, quant_type=quant_type) \ No newline at end of file From b2a4d54e398a147b6e3ba797934bc023595b8d7a Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Sun, 3 Dec 2023 19:54:15 -0800 Subject: [PATCH 07/39] add quant to device when init weight paam --- bitsandbytes/backends.py | 11 +++------ bitsandbytes/nn/modules.py | 50 +++++++------------------------------- 2 files changed, 13 insertions(+), 48 deletions(-) diff --git a/bitsandbytes/backends.py b/bitsandbytes/backends.py index eb7bda484..69c2c458c 100644 --- a/bitsandbytes/backends.py +++ b/bitsandbytes/backends.py @@ -18,21 +18,18 @@ def register_backend(backend_name: str, backend_class): "xpu", }, "register device backend choices in [cpu, cuda, xpu]" - # check 8bits or 4bits functionality, at least one is compelete + # check 8bits and 4bits interfaces if ( hasattr(backend_class, "double_quant") and hasattr(backend_class, "transform") and hasattr(backend_class, "igemmlt") and hasattr(backend_class, "mm_dequant") and hasattr(backend_class, "extract_outliers") + and hasattr(backend_class, "quantize_4bit") + and hasattr(backend_class, "dequantize_4bit") ): self.devices[backend_name.lower()] = backend_class - - elif hasattr(backend_class, "quantize_4bit") and hasattr( - backend_class, "dequantize_4bit" - ): - self.devices[backend_name.lower()] = backend_classq else: assert ( False - ), f"register device backend {backend_name.lower()} but its functionality is not compelete" + ), f"register device backend {backend_name.lower()} but its interfaces are not compelete" diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index d9a0e7434..5de83e867 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -163,16 +163,8 @@ def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], self.compress_statistics = self.quant_state.nested self.quant_type = self.quant_state.quant_type - def cpu(self, device): - warnings.warn("CPU Params4bit will be soon supported, return raw Params4bit for now") - return self - - def xpu(self, device): - warnings.warn("XPU Params4bit will be soon supported, return raw Params4bit for now") - return self - - def cuda(self, device): - w = self.data.contiguous().half().cuda(device) + def quantize_to_device(self, device): + w = self.data.contiguous().half().to(device) w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type) self.data = w_4bit self.quant_state = quant_state @@ -194,14 +186,8 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) 
-> T: def to(self, *args, **kwargs): device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - if device is not None and device.type == "cpu": - return self.cpu(device) - - if (device is not None and device.type != "cpu" and self.data.device.type == "cpu"): - if device.type == "cuda": - return self.cuda(device) - elif device.type == "xpu": - return self.xpu(device) + if (device is not None and self.data.device.type == "cpu"): + return self.quantize_to_device(device) else: if self.quant_state is not None: self.quant_state.to(device) @@ -309,25 +295,13 @@ def __new__( data = torch.empty(0) return torch.Tensor._make_subclass(cls, data, requires_grad) - - def cpu(self, device): - warnings.warn("CPU Int8Params will be soon supported, return raw Int8Params for now") - - return self - - - def xpu(self, device): - warnings.warn("XPU Int8Params will be soon supported, return raw Int8Params for now") - - return self - - def cuda(self, device): + def quantize_to_device(self, device): if self.has_fp16_weights: - return super().cuda(device) + return super().to(device) else: # we store the 8-bit rows-major weight # we convert this weight to the turning/ampere weight during the first inference pass - B = self.data.contiguous().half().cuda(device) + B = self.data.contiguous().half().to(device) CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) del CBt del SCBt @@ -359,14 +333,8 @@ def to(self, *args, **kwargs): *args, **kwargs ) - if device is not None and device.type == "cpu": - return self.cpu(device) - - if (device is not None and device.type != "cpu" and self.data.device.type == "cpu"): - if device.type == "cuda": - return self.cuda(device) - elif device.type == "xpu": - return self.xpu(device) + if (device is not None and self.data.device.type == "cpu"): + return self.quantize_to_device(device) else: new_param = Int8Params( super().to( From c44cf065cc38c898c048ab461c3162879a815406 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Sun, 3 Dec 2023 22:15:15 -0800 Subject: [PATCH 08/39] minor fix --- bitsandbytes/functional.py | 44 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index c30e1b651..95295d6a1 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -933,6 +933,50 @@ def dequantize( out = dequantize_no_absmax(A, state[1], out) return out * state[0] +def get_colrow_absmax( + A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0 +): + assert A.dtype == torch.float16 + device = A.device + + cols = A.shape[-1] + if len(A.shape) == 3: + rows = A.shape[0] * A.shape[1] + else: + rows = A.shape[0] + + col_tiles = (cols + 255) // 256 + tiled_rows = ((rows + 15) // 16) * 16 + if row_stats is None: + row_stats = torch.empty( + (rows,), dtype=torch.float32, device=device + ).fill_(-50000.0) + if col_stats is None: + col_stats = torch.empty( + (cols,), dtype=torch.float32, device=device + ).fill_(-50000.0) + + if nnz_block_ptr is None and threshold > 0.0: + nnz_block_ptr = torch.zeros( + ((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device + ) + + ptrA = get_ptr(A) + ptrRowStats = get_ptr(row_stats) + ptrColStats = get_ptr(col_stats) + ptrNnzrows = get_ptr(nnz_block_ptr) + rows = ct.c_int32(rows) + cols = ct.c_int32(cols) + + prev_device = pre_call(A.device) + is_on_gpu([A, row_stats, col_stats, nnz_block_ptr]) + lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols) + 
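    # A hedged reference for the statistics the kernel call above fills in
    # (pure-PyTorch view of a 2D, row-major A; the tiled nnz_block_ptr
    # bookkeeping is ignored here):
    #   row_stats ~= A.abs().amax(dim=1).float()   # per-row absmax
    #   col_stats ~= A.abs().amax(dim=0).float()   # per-column absmax
    # When threshold > 0, entries with |A[i, j]| > threshold are treated as
    # outliers and excluded from these maxima; nnz_block_ptr counts them so
    # that double_quant can route them through the sparse COO path.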
post_call(prev_device) + + if threshold > 0.0: + nnz_block_ptr.cumsum_(0) + + return row_stats, col_stats, nnz_block_ptr def quantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: ''' From 365491a573e340f4c589694537c8b463290375a4 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Mon, 4 Dec 2023 00:41:57 -0800 Subject: [PATCH 09/39] mv cuda to common backends --- bitsandbytes/backends.py | 663 ++++++++++++++++++++++++++++++++++++- bitsandbytes/functional.py | 564 +++---------------------------- 2 files changed, 706 insertions(+), 521 deletions(-) diff --git a/bitsandbytes/backends.py b/bitsandbytes/backends.py index 69c2c458c..e2899e840 100644 --- a/bitsandbytes/backends.py +++ b/bitsandbytes/backends.py @@ -1,3 +1,21 @@ +import torch +from torch import Tensor +from bitsandbytes.functional import ( + pre_call, + post_call, + get_colrow_absmax, + get_ptr, + is_on_gpu, + coo_zeros, + get_transform_buffer, + prod, + get_4bit_type, + quantize_blockwise, + dequantize_blockwise, +) +from bitsandbytes.functional import CUBLAS_Context, QuantState + + class Backends: """ An dict class for device backends that registered with 8bits and 4bits functions. @@ -7,11 +25,10 @@ class Backends: """ - def __init__(self): - self.devices = {} + devices = {} @classmethod - def register_backend(backend_name: str, backend_class): + def register_backend(self, backend_name: str, backend_class): assert backend_name.lower() in { "cpu", "cuda", @@ -28,8 +45,646 @@ def register_backend(backend_name: str, backend_class): and hasattr(backend_class, "quantize_4bit") and hasattr(backend_class, "dequantize_4bit") ): - self.devices[backend_name.lower()] = backend_class + Backends.devices[backend_name.lower()] = backend_class else: assert ( False ), f"register device backend {backend_name.lower()} but its interfaces are not compelete" + + +class CUDABackend: + @classmethod + def double_quant( + A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 + ): + device = A.device + assert A.dtype == torch.half + assert device.type == "cuda" + prev_device = pre_call(A.device) + + cols = A.shape[-1] + if len(A.shape) == 3: + rows = A.shape[0] * A.shape[1] + else: + rows = A.shape[0] + + if row_stats is None or col_stats is None: + row_stats, col_stats, nnz_row_ptr = get_colrow_absmax( + A, threshold=threshold + ) + + if out_col is None: + out_col = torch.zeros(A.shape, device=device, dtype=torch.int8) + if out_row is None: + out_row = torch.zeros(A.shape, device=device, dtype=torch.int8) + + coo_tensor = None + ptrA = get_ptr(A) + ptrColStats = get_ptr(col_stats) + ptrRowStats = get_ptr(row_stats) + ptrOutCol = get_ptr(out_col) + ptrOutRow = get_ptr(out_row) + + is_on_gpu([A, col_stats, row_stats, out_col, out_row]) + if threshold > 0.0: + nnz = nnz_row_ptr[-1].item() + if nnz > 0: + coo_tensor = coo_zeros( + A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device + ) + ptrRowIdx = get_ptr(coo_tensor.rowidx) + ptrColIdx = get_ptr(coo_tensor.colidx) + ptrVal = get_ptr(coo_tensor.values) + ptrRowPtr = get_ptr(nnz_row_ptr) + + lib.cdouble_rowcol_quant( + ptrA, + ptrRowStats, + ptrColStats, + ptrOutCol, + ptrOutRow, + ptrRowIdx, + ptrColIdx, + ptrVal, + ptrRowPtr, + ct.c_float(threshold), + ct.c_int32(rows), + ct.c_int32(cols), + ) + val, idx = torch.sort(coo_tensor.rowidx) + coo_tensor.rowidx = val + coo_tensor.colidx = coo_tensor.colidx[idx] + coo_tensor.values = coo_tensor.values[idx] + else: + lib.cdouble_rowcol_quant( + ptrA, + ptrRowStats, + ptrColStats, + ptrOutCol, + ptrOutRow, + None, + None, 
+ None, + None, + ct.c_float(0.0), + ct.c_int32(rows), + ct.c_int32(cols), + ) + else: + lib.cdouble_rowcol_quant( + ptrA, + ptrRowStats, + ptrColStats, + ptrOutCol, + ptrOutRow, + None, + None, + None, + None, + ct.c_float(threshold), + ct.c_int32(rows), + ct.c_int32(cols), + ) + post_call(prev_device) + + return out_row, out_col, row_stats, col_stats, coo_tensor + + @classmethod + def transform( + A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None + ): + prev_device = pre_call(A.device) + if state is None: + state = (A.shape, from_order) + else: + from_order = state[1] + if out is None: + out, new_state = get_transform_buffer( + state[0], A.dtype, A.device, to_order, state[1], transpose + ) + else: + new_state = (state[0], to_order) # (shape, order) + + shape = state[0] + if len(shape) == 2: + dim1 = ct.c_int32(shape[0]) + dim2 = ct.c_int32(shape[1]) + else: + dim1 = ct.c_int32(shape[0] * shape[1]) + dim2 = ct.c_int32(shape[2]) + + is_on_gpu([A, out]) + if to_order == "col32": + if transpose: + lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) + else: + lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2) + elif to_order == "col_turing": + if transpose: + lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2) + else: + lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2) + elif to_order == "col_ampere": + if transpose: + lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2) + else: + lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2) + elif to_order == "row": + if from_order == "col_turing": + lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2) + elif from_order == "col_ampere": + lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) + else: + raise NotImplementedError( + f"Transform function not implemented: From {from_order} to {to_order}" + ) + + post_call(prev_device) + + return out, new_state + + @classmethod + def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): + shapeA = SA[0] + shapeB = SB[0] + dimsA = len(shapeA) + dimsB = len(shapeB) + assert dimsB == 2, "Only two dimensional matrices are supported for argument B" + if dimsA == 2: + m = shapeA[0] + elif dimsA == 3: + m = shapeA[0] * shapeA[1] + + rows = n = shapeB[0] + assert ( + prod(list(shapeA)) > 0 + ), f"Input tensor dimensions need to be > 0: {shapeA}" + + # if the tensor is empty, return a transformed empty tensor with the right dimensions + if shapeA[0] == 0 and dimsA == 2: + return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16) + elif shapeA[1] == 0 and dimsA == 3: + return torch.empty( + tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16 + ) + + if dimsA == 2 and out is None: + out, Sout = get_transform_buffer( + (shapeA[0], shapeB[0]), dtype, A.device, "col32", "row" + ) + elif dimsA == 3 and out is None: + out, Sout = get_transform_buffer( + (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row" + ) + + assert dimsB != 3, "len(B.shape)==3 not supported" + assert A.device.type == "cuda" + assert B.device.type == "cuda" + assert A.dtype == torch.int8 + assert B.dtype == torch.int8 + assert out.dtype == dtype + assert SA[1] == "col32" + assert SB[1] in ["col_turing", "col_ampere"] + assert Sout[1] == "col32" + assert ( + shapeA[-1] == shapeB[-1] + ), f"Matmullt only supports A @ B^T. 
Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}" + formatB = SB[1] + prev_device = A.device + torch.cuda.set_device(A.device) + + ptr = CUBLAS_Context.get_instance().get_context(A.device) + ptrA = get_ptr(A) + ptrB = get_ptr(B) + ptrC = get_ptr(out) + + k = shapeA[-1] + lda = ct.c_int32(m * 32) + if formatB == "col_turing": + # turing: tiles with rows filled up to multiple of 8 rows by 32 columns + # n = rows + ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32) + else: + # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns + # n = rows + ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32) + + ldc = ct.c_int32(m * 32) + m = ct.c_int32(m) + n = ct.c_int32(n) + k = ct.c_int32(k) + + has_error = 0 + ptrRowScale = get_ptr(None) + is_on_gpu([A, B, out]) + if formatB == "col_turing": + if dtype == torch.int32: + has_error = lib.cigemmlt_turing_32( + ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc + ) + else: + has_error = lib.cigemmlt_turing_8( + ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc + ) + elif formatB == "col_ampere": + if dtype == torch.int32: + has_error = lib.cigemmlt_ampere_32( + ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc + ) + else: + has_error = lib.cigemmlt_ampere_8( + ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc + ) + + if has_error == 1: + print( + f"A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}" + ) + raise Exception("cublasLt ran into an error!") + + torch.cuda.set_device(prev_device) + + return out, Sout + + @classmethod + def mm_dequant( + A, + quant_state, + row_stats, + col_stats, + out=None, + new_row_stats=None, + new_col_stats=None, + bias=None, + ): + assert A.dtype == torch.int32 + if bias is not None: + assert bias.dtype == torch.float16 + out_shape = quant_state[0] + if len(out_shape) == 3: + out_shape = (out_shape[0] * out_shape[1], out_shape[2]) + + if out is None: + out = torch.empty(out_shape, dtype=torch.float16, device=A.device) + if new_row_stats is None: + new_row_stats = torch.empty( + out_shape[0], dtype=torch.float32, device=A.device + ) + if new_col_stats is None: + new_col_stats = torch.empty( + out_shape[1], dtype=torch.float32, device=A.device + ) + assert ( + new_row_stats.shape[0] == row_stats.shape[0] + ), f"{new_row_stats.shape} vs {row_stats.shape}" + assert ( + new_col_stats.shape[0] == col_stats.shape[0] + ), f"{new_col_stats.shape} vs {col_stats.shape}" + + prev_device = pre_call(A.device) + ptrA = get_ptr(A) + ptrOut = get_ptr(out) + ptrRowStats = get_ptr(row_stats) + ptrColStats = get_ptr(col_stats) + ptrNewRowStats = get_ptr(new_row_stats) + ptrNewColStats = get_ptr(new_col_stats) + ptrBias = get_ptr(bias) + numRows = ct.c_int32(out_shape[0]) + numCols = ct.c_int32(out_shape[1]) + + is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias]) + lib.cdequant_mm_int32_fp16( + ptrA, + ptrRowStats, + ptrColStats, + ptrOut, + ptrNewRowStats, + ptrNewColStats, + ptrBias, + numRows, + numCols, + ) + post_call(prev_device) + + return out + + @classmethod + def extract_outliers(A, SA, idx): + shapeA = SA[0] + formatA = SA[1] + assert formatA in ["col_turing", "col_ampere"] + assert A.device.type == "cuda" + + out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device) + + idx_size = ct.c_int32(idx.numel()) + rows = ct.c_int32(shapeA[0]) + cols = ct.c_int32(shapeA[1]) + ptrA = get_ptr(A) + ptrIdx = get_ptr(idx) + ptrOut = get_ptr(out) + + prev_device = pre_call(A.device) + if 
formatA == "col_turing": + lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) + elif formatA == "col_ampere": + lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) + post_call(prev_device) + + return out + + @classmethod + def quantize_4bit( + A: Tensor, + absmax: Tensor = None, + out: Tensor = None, + blocksize=64, + compress_statistics=False, + quant_type="fp4", + ) -> Tensor: + """ + Quantize tensor A in blocks of 4-bit values. + + Quantizes tensor A by dividing it into blocks which are independently quantized to FP4. + + Parameters + ---------- + A : torch.Tensor + The input tensor. + absmax : torch.Tensor + The absmax values. + out : torch.Tensor + The output tensor (8-bit). + blocksize : int + The blocksize used in quantization. + quant_type : str + The 4-bit quantization data type {fp4, nf4} + + Returns + ------- + torch.Tensor: + The 8-bit tensor with packed 4-bit values. + tuple(torch.Tensor, torch.Size, torch.dtype, int): + The quantization state to undo the quantization. + """ + if A.device.type != "cuda": + raise NotImplementedError( + f"Device type not supported for FP4 quantization: {A.device.type}" + ) + if quant_type not in ["fp4", "nf4"]: + raise NotImplementedError( + f"4-bit quantization data type {quant_type} is not implemented." + ) + + n = A.numel() + input_shape = A.shape + + if absmax is None: + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) + + if out is None: + out = torch.zeros(((n + 1) // 2, 1), dtype=torch.uint8, device=A.device) + + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] + + prev_device = pre_call(A.device) + is_on_gpu([A, out, absmax]) + + if A.dtype == torch.float32: + if quant_type == "fp4": + lib.cquantize_blockwise_fp32_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) + else: + lib.cquantize_blockwise_fp32_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) + elif A.dtype == torch.float16: + if quant_type == "fp4": + lib.cquantize_blockwise_fp16_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) + else: + lib.cquantize_blockwise_fp16_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) + elif A.dtype == torch.bfloat16: + if quant_type == "fp4": + lib.cquantize_blockwise_bf16_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) + else: + lib.cquantize_blockwise_bf16_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int32(blocksize), + ct.c_int(n), + ) + else: + raise ValueError( + f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" + ) + post_call(A.device) + + code = get_4bit_type(quant_type, device=A.device) + + if compress_statistics: + offset = absmax.mean() + absmax -= offset + qabsmax, state2 = quantize_blockwise(absmax, blocksize=256) + del absmax + state = QuantState( + absmax=qabsmax, + shape=input_shape, + dtype=A.dtype, + blocksize=blocksize, + code=code, + quant_type=quant_type, + offset=offset, + state2=state2, + ) + else: + state = QuantState( + absmax=absmax, + shape=input_shape, + dtype=A.dtype, + blocksize=blocksize, + code=code, + quant_type=quant_type, + ) + + return out, state + + @classmethod + def dequantize_4bit( + A: 
Tensor, + quant_state: QuantState = None, + absmax: Tensor = None, + out: Tensor = None, + blocksize: int = 64, + quant_type="fp4", + ) -> Tensor: + """ + Dequantizes FP4 blockwise quantized values. + + Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize. + + Parameters + ---------- + A : torch.Tensor + The input 8-bit tensor (packed 4-bit values). + quant_state : QuantState + object with quantisation stats, incl. absmax values, original tensor shape and original dtype. + absmax : torch.Tensor + The absmax values. + out : torch.Tensor + Dequantized output tensor. + blocksize : int + The blocksize used in quantization. + quant_type : str + The 4-bit quantization data type {fp4, nf4} + + + Returns + ------- + torch.Tensor: + Dequantized tensor. + """ + if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: + raise ValueError( + f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]" + ) + if quant_type not in ["fp4", "nf4"]: + raise NotImplementedError( + f"4-bit quantization data type {quant_type} is not implemented." + ) + + if quant_state is None: + assert absmax is not None and out is not None + + quant_state = QuantState( + absmax=absmax, + shape=out.shape, + dtype=out.dtype, + blocksize=blocksize, + quant_type=quant_type, + ) + + else: + absmax = quant_state.absmax + + if quant_state.nested: + absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) + absmax += quant_state.offset + if absmax.dtype != torch.float32: + absmax = absmax.float() + + if out is None: + out = torch.empty( + quant_state.shape, dtype=quant_state.dtype, device=A.device + ) + + n = out.numel() + + device = pre_call(A.device) + is_on_gpu([A, absmax, out]) + if out.dtype == torch.float32: + if quant_state.quant_type == "fp4": + lib.cdequantize_blockwise_fp32_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) + else: + lib.cdequantize_blockwise_fp32_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) + elif out.dtype == torch.float16: + if quant_state.quant_type == "fp4": + lib.cdequantize_blockwise_fp16_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) + else: + lib.cdequantize_blockwise_fp16_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) + elif out.dtype == torch.bfloat16: + if quant_state.quant_type == "fp4": + lib.cdequantize_blockwise_bf16_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) + else: + lib.cdequantize_blockwise_bf16_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) + else: + raise ValueError( + f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}" + ) + post_call(A.device) + + is_transposed = True if A.shape[0] == 1 else False + if is_transposed: + return out.t() + else: + return out + + +Backends.register_backend("cuda", CUDABackend) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 95295d6a1..178fa8614 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -19,7 +19,6 @@ from warnings import warn from .cextension import COMPILED_WITH_CUDA from bitsandbytes.utils import 
pack_dict_to_tensor, unpack_tensor_to_dict -from backends import Backends # CUDA specific lib if COMPILED_WITH_CUDA: @@ -887,52 +886,6 @@ def get_4bit_type(typename, device=None, blocksize=64): return data.to(device) - -def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False): - return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4') - -def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False): - return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4') - -def dequantize_fp4(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor: - return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4') - -def dequantize_nf4(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor: - return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4') - -def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor: - if code is None: - if "dynamic" not in name2qmap: - name2qmap["dynamic"] = create_dynamic_map().to(A.device) - code = name2qmap["dynamic"] - code = code.to(A.device) - - absmax = torch.abs(A).max() - if absmax.dtype != torch.float32: absmax = absmax.float() - inp = A / absmax - out = quantize_no_absmax(inp, code, out) - return out, (absmax, code) - - -def dequantize( - A: Tensor, - state: Tuple[Tensor, Tensor] = None, - absmax: Tensor = None, - code: Tensor = None, - out: Tensor = None, -) -> Tensor: - assert state is not None or absmax is not None - if code is None and state is None: - if "dynamic" not in name2qmap: - name2qmap["dynamic"] = create_dynamic_map().to(A.device) - code = name2qmap["dynamic"] - code = code.to(A.device) - - if state is None: - state = (absmax, code) - out = dequantize_no_absmax(A, state[1], out) - return out * state[0] - def get_colrow_absmax( A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0 ): @@ -1035,6 +988,39 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Tensor = None) -> Tensor: post_call(prev_device) return out +def quantize(A: Tensor, code: Tensor = None, out: Tensor = None) -> Tensor: + if code is None: + if "dynamic" not in name2qmap: + name2qmap["dynamic"] = create_dynamic_map().to(A.device) + code = name2qmap["dynamic"] + code = code.to(A.device) + + absmax = torch.abs(A).max() + if absmax.dtype != torch.float32: absmax = absmax.float() + inp = A / absmax + out = quantize_no_absmax(inp, code, out) + return out, (absmax, code) + + +def dequantize( + A: Tensor, + state: Tuple[Tensor, Tensor] = None, + absmax: Tensor = None, + code: Tensor = None, + out: Tensor = None, +) -> Tensor: + assert state is not None or absmax is not None + if code is None and state is None: + if "dynamic" not in name2qmap: + name2qmap["dynamic"] = create_dynamic_map().to(A.device) + code = name2qmap["dynamic"] + code = code.to(A.device) + + if state is None: + state = (absmax, code) + out = dequantize_no_absmax(A, state[1], out) + return out * state[0] + def optimizer_update_32bit( optimizer_name: str, @@ -2050,476 +2036,8 @@ def pipeline_test(A, batch_size): lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size)) return out -class CUDABackend: - @classmethod - def double_quant( - A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 - ): - device = A.device - assert A.dtype == torch.half - assert 
device.type == "cuda" - prev_device = pre_call(A.device) - cols = A.shape[-1] - if len(A.shape) == 3: - rows = A.shape[0] * A.shape[1] - else: - rows = A.shape[0] - - if row_stats is None or col_stats is None: - row_stats, col_stats, nnz_row_ptr = get_colrow_absmax( - A, threshold=threshold - ) - - if out_col is None: - out_col = torch.zeros(A.shape, device=device, dtype=torch.int8) - if out_row is None: - out_row = torch.zeros(A.shape, device=device, dtype=torch.int8) - - coo_tensor = None - ptrA = get_ptr(A) - ptrColStats = get_ptr(col_stats) - ptrRowStats = get_ptr(row_stats) - ptrOutCol = get_ptr(out_col) - ptrOutRow = get_ptr(out_row) - - is_on_gpu([A, col_stats, row_stats, out_col, out_row]) - if threshold > 0.0: - nnz = nnz_row_ptr[-1].item() - if nnz > 0: - coo_tensor = coo_zeros( - A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device - ) - ptrRowIdx = get_ptr(coo_tensor.rowidx) - ptrColIdx = get_ptr(coo_tensor.colidx) - ptrVal = get_ptr(coo_tensor.values) - ptrRowPtr = get_ptr(nnz_row_ptr) - - lib.cdouble_rowcol_quant( - ptrA, - ptrRowStats, - ptrColStats, - ptrOutCol, - ptrOutRow, - ptrRowIdx, - ptrColIdx, - ptrVal, - ptrRowPtr, - ct.c_float(threshold), - ct.c_int32(rows), - ct.c_int32(cols), - ) - val, idx = torch.sort(coo_tensor.rowidx) - coo_tensor.rowidx = val - coo_tensor.colidx = coo_tensor.colidx[idx] - coo_tensor.values = coo_tensor.values[idx] - else: - lib.cdouble_rowcol_quant( - ptrA, - ptrRowStats, - ptrColStats, - ptrOutCol, - ptrOutRow, - None, - None, - None, - None, - ct.c_float(0.0), - ct.c_int32(rows), - ct.c_int32(cols), - ) - else: - lib.cdouble_rowcol_quant( - ptrA, - ptrRowStats, - ptrColStats, - ptrOutCol, - ptrOutRow, - None, - None, - None, - None, - ct.c_float(threshold), - ct.c_int32(rows), - ct.c_int32(cols), - ) - post_call(prev_device) - - return out_row, out_col, row_stats, col_stats, coo_tensor - @classmethod - def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): - prev_device = pre_call(A.device) - if state is None: state = (A.shape, from_order) - else: from_order = state[1] - if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) - else: new_state = (state[0], to_order) # (shape, order) - - shape = state[0] - if len(shape) == 2: - dim1 = ct.c_int32(shape[0]) - dim2 = ct.c_int32(shape[1]) - else: - dim1 = ct.c_int32(shape[0] * shape[1]) - dim2 = ct.c_int32(shape[2]) - - is_on_gpu([A, out]) - if to_order == 'col32': - if transpose: - lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) - else: - lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2) - elif to_order == "col_turing": - if transpose: - lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2) - else: - lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2) - elif to_order == "col_ampere": - if transpose: - lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2) - else: - lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2) - elif to_order == "row": - if from_order == "col_turing": - lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2) - elif from_order == "col_ampere": - lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) - else: - raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}') - - post_call(prev_device) - - return out, new_state - @classmethod - def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): - shapeA = SA[0] - shapeB = SB[0] - dimsA = 
len(shapeA) - dimsB = len(shapeB) - assert dimsB == 2, 'Only two dimensional matrices are supported for argument B' - if dimsA == 2: - m = shapeA[0] - elif dimsA == 3: - m = shapeA[0] * shapeA[1] - - rows = n = shapeB[0] - assert prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}' - - # if the tensor is empty, return a transformed empty tensor with the right dimensions - if shapeA[0] == 0 and dimsA == 2: - return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16) - elif shapeA[1] == 0 and dimsA == 3: - return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16) - - if dimsA == 2 and out is None: - out, Sout = get_transform_buffer( - (shapeA[0], shapeB[0]), dtype, A.device, "col32", "row" - ) - elif dimsA == 3 and out is None: - out, Sout = get_transform_buffer( - (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row" - ) - - assert dimsB != 3, "len(B.shape)==3 not supported" - assert A.device.type == "cuda" - assert B.device.type == "cuda" - assert A.dtype == torch.int8 - assert B.dtype == torch.int8 - assert out.dtype == dtype - assert SA[1] == "col32" - assert SB[1] in ["col_turing", "col_ampere"] - assert Sout[1] == "col32" - assert ( - shapeA[-1] == shapeB[-1] - ), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}" - formatB = SB[1] - prev_device = A.device - torch.cuda.set_device(A.device) - - ptr = CUBLAS_Context.get_instance().get_context(A.device) - ptrA = get_ptr(A) - ptrB = get_ptr(B) - ptrC = get_ptr(out) - - k = shapeA[-1] - lda = ct.c_int32(m * 32) - if formatB == "col_turing": - # turing: tiles with rows filled up to multiple of 8 rows by 32 columns - # n = rows - ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32) - else: - # ampere: tiles with rows filled up to multiple of 32 rows by 32 columns - # n = rows - ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32) - - ldc = ct.c_int32(m * 32) - m = ct.c_int32(m) - n = ct.c_int32(n) - k = ct.c_int32(k) - - has_error = 0 - ptrRowScale = get_ptr(None) - is_on_gpu([A, B, out]) - if formatB == 'col_turing': - if dtype == torch.int32: - has_error = lib.cigemmlt_turing_32( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) - else: - has_error = lib.cigemmlt_turing_8( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) - elif formatB == "col_ampere": - if dtype == torch.int32: - has_error = lib.cigemmlt_ampere_32( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) - else: - has_error = lib.cigemmlt_ampere_8( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) - - if has_error == 1: - print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}') - raise Exception('cublasLt ran into an error!') - - torch.cuda.set_device(prev_device) - - return out, Sout - @classmethod - def mm_dequant( - A, - quant_state, - row_stats, - col_stats, - out=None, - new_row_stats=None, - new_col_stats=None, - bias=None - ): - assert A.dtype == torch.int32 - if bias is not None: assert bias.dtype == torch.float16 - out_shape = quant_state[0] - if len(out_shape) == 3: - out_shape = (out_shape[0] * out_shape[1], out_shape[2]) - - if out is None: - out = torch.empty(out_shape, dtype=torch.float16, device=A.device) - if new_row_stats is None: - new_row_stats = torch.empty( - out_shape[0], dtype=torch.float32, device=A.device - ) - if new_col_stats is None: - new_col_stats = torch.empty( - out_shape[1], dtype=torch.float32, device=A.device - ) - 
assert ( - new_row_stats.shape[0] == row_stats.shape[0] - ), f"{new_row_stats.shape} vs {row_stats.shape}" - assert ( - new_col_stats.shape[0] == col_stats.shape[0] - ), f"{new_col_stats.shape} vs {col_stats.shape}" - - prev_device = pre_call(A.device) - ptrA = get_ptr(A) - ptrOut = get_ptr(out) - ptrRowStats = get_ptr(row_stats) - ptrColStats = get_ptr(col_stats) - ptrNewRowStats = get_ptr(new_row_stats) - ptrNewColStats = get_ptr(new_col_stats) - ptrBias = get_ptr(bias) - numRows = ct.c_int32(out_shape[0]) - numCols = ct.c_int32(out_shape[1]) - - is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias]) - lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, ptrBias, numRows, numCols) - post_call(prev_device) - - return out - @classmethod - def extract_outliers(A, SA, idx): - shapeA = SA[0] - formatA = SA[1] - assert formatA in ["col_turing", "col_ampere"] - assert A.device.type == "cuda" - - out = torch.zeros( - (shapeA[0], idx.numel()), dtype=torch.int8, device=A.device - ) - - idx_size = ct.c_int32(idx.numel()) - rows = ct.c_int32(shapeA[0]) - cols = ct.c_int32(shapeA[1]) - ptrA = get_ptr(A) - ptrIdx = get_ptr(idx) - ptrOut = get_ptr(out) - - prev_device = pre_call(A.device) - if formatA == 'col_turing': - lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) - elif formatA == "col_ampere": - lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) - post_call(prev_device) - - return out - @classmethod - def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4') -> Tensor: - """ - Quantize tensor A in blocks of 4-bit values. - - Quantizes tensor A by dividing it into blocks which are independently quantized to FP4. - - Parameters - ---------- - A : torch.Tensor - The input tensor. - absmax : torch.Tensor - The absmax values. - out : torch.Tensor - The output tensor (8-bit). - blocksize : int - The blocksize used in quantization. - quant_type : str - The 4-bit quantization data type {fp4, nf4} - - Returns - ------- - torch.Tensor: - The 8-bit tensor with packed 4-bit values. - tuple(torch.Tensor, torch.Size, torch.dtype, int): - The quantization state to undo the quantization. 
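# --- Illustrative sketch (editor's note, not part of this patch) -----------
# The public entry points remain bitsandbytes.functional.quantize_4bit /
# dequantize_4bit and the quantize_fp4 / quantize_nf4 convenience wrappers
# relocated in this patch; only the implementation behind them moves into the
# per-device backend. A round trip could look like the following, assuming a
# CUDA-enabled build of bitsandbytes with this series applied:
import torch
import bitsandbytes.functional as F

if torch.cuda.is_available():
    w = torch.randn(4096, 1024, dtype=torch.float16, device="cuda")
    packed, state = F.quantize_nf4(w, blocksize=64)   # packed uint8, two 4-bit values per byte
    w_restored = F.dequantize_nf4(packed, state)      # back to fp16, lossy reconstruction
    assert w_restored.shape == w.shape and w_restored.dtype == w.dtype
# ---------------------------------------------------------------------------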
- """ - if A.device.type != 'cuda': - raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}') - if quant_type not in ['fp4', 'nf4']: - raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') - - n = A.numel() - input_shape = A.shape - - if absmax is None: - blocks = n // blocksize - blocks += 1 if n % blocksize > 0 else 0 - absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) - - - if out is None: - out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device) - - assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] - - prev_device = pre_call(A.device) - is_on_gpu([A, out, absmax]) - - if A.dtype == torch.float32: - if quant_type == 'fp4': - lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) - else: - lib.cquantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) - elif A.dtype == torch.float16: - if quant_type == 'fp4': - lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) - else: - lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) - elif A.dtype == torch.bfloat16: - if quant_type == 'fp4': - lib.cquantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) - else: - lib.cquantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - post_call(A.device) - - code = get_4bit_type(quant_type, device=A.device) - - if compress_statistics: - offset = absmax.mean() - absmax -= offset - qabsmax, state2 = quantize_blockwise(absmax, blocksize=256) - del absmax - state = QuantState(absmax=qabsmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, offset=offset, state2=state2) - else: - state = QuantState(absmax=absmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, ) - - return out, state - @classmethod - def dequantize_4bit(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4') -> Tensor: - """ - Dequantizes FP4 blockwise quantized values. - - Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize. - - Parameters - ---------- - A : torch.Tensor - The input 8-bit tensor (packed 4-bit values). - quant_state : QuantState - object with quantisation stats, incl. absmax values, original tensor shape and original dtype. - absmax : torch.Tensor - The absmax values. - out : torch.Tensor - Dequantized output tensor. - blocksize : int - The blocksize used in quantization. - quant_type : str - The 4-bit quantization data type {fp4, nf4} - - - Returns - ------- - torch.Tensor: - Dequantized tensor. - """ - if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: - raise ValueError(f"The blockwise of {blocksize} is not supported. 
Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") - if quant_type not in ['fp4', 'nf4']: - raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') - - if quant_state is None: - assert absmax is not None and out is not None - - quant_state = QuantState(absmax=absmax, shape=out.shape, dtype=out.dtype, blocksize=blocksize, quant_type=quant_type) - - else: - absmax = quant_state.absmax - - - if quant_state.nested: - absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) - absmax += quant_state.offset - if absmax.dtype != torch.float32: absmax = absmax.float() - - if out is None: - out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device) - - n = out.numel() - - device = pre_call(A.device) - is_on_gpu([A, absmax, out]) - if out.dtype == torch.float32: - if quant_state.quant_type == 'fp4': - lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) - else: - lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) - elif out.dtype == torch.float16: - if quant_state.quant_type == 'fp4': - lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) - else: - lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) - elif out.dtype == torch.bfloat16: - if quant_state.quant_type == 'fp4': - lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) - else: - lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") - post_call(A.device) - - is_transposed = (True if A.shape[0] == 1 else False) - if is_transposed: return out.t() - else: return out - - -Backends.register_backend("cuda", CUDABackend) +from bitsandbytes.backends import Backends # 8 bits common functions def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): @@ -2558,4 +2076,16 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz def dequantize_4bit(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4'): assert A.device in Backends.device, f"Device backend for {A.device} is not supported" - return Backends.device[A.device].dequantize_4bit(A, quant_state=quant_state, absmax=absmax, out=out, blocksize=blocksize, quant_type=quant_type) \ No newline at end of file + return Backends.device[A.device].dequantize_4bit(A, quant_state=quant_state, absmax=absmax, out=out, blocksize=blocksize, quant_type=quant_type) + +def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False): + return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4') + +def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False): + return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4') + +def dequantize_fp4(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor: + return 
dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4') + +def dequantize_nf4(A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64) -> Tensor: + return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4') From 4050fe387e8330f1f5a10735c5651377450ee9fe Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Mon, 4 Dec 2023 00:49:06 -0800 Subject: [PATCH 10/39] format fix --- bitsandbytes/cextension.py | 3 +++ bitsandbytes/functional.py | 8 ++------ bitsandbytes/nn/modules.py | 2 +- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 76bfa6647..72fbf18f0 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -4,8 +4,10 @@ from pathlib import Path from warnings import warn + from bitsandbytes.device_setup.cuda.main import CUDASetup + setup = CUDASetup.get_instance() if setup.initialized != True: setup.run_cuda_setup() @@ -35,6 +37,7 @@ COMPILED_WITH_CUDA = False print(str(ex)) + # print the setup details after checking for errors so we do not print twice #if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0': #setup.print_log_stack() diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 178fa8614..3c74b30cc 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -15,14 +15,10 @@ from functools import reduce # Required in Python 3 from typing import Tuple, Any, Dict from torch import Tensor - -from warnings import warn -from .cextension import COMPILED_WITH_CUDA from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict -# CUDA specific lib -if COMPILED_WITH_CUDA: - from .cextension import lib +from .cextension import COMPILED_WITH_CUDA, lib + # math.prod not compatible with python < 3.8 def prod(iterable): diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 5de83e867..9c798fe39 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -185,7 +185,7 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) 
-> T: def to(self, *args, **kwargs): device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - + if (device is not None and self.data.device.type == "cpu"): return self.quantize_to_device(device) else: From 30175d1967a7e52fa5b3aac85b46f9eee5143ac2 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Mon, 4 Dec 2023 00:56:18 -0800 Subject: [PATCH 11/39] format fix --- bitsandbytes/cextension.py | 1 - 1 file changed, 1 deletion(-) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 72fbf18f0..d7088e398 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -13,7 +13,6 @@ setup.run_cuda_setup() lib = setup.lib - try: if lib is None and torch.cuda.is_available(): CUDASetup.get_instance().generate_instructions() From e17549e222b6e6712a4a413e4f2ca2410691ed99 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Mon, 4 Dec 2023 01:09:34 -0800 Subject: [PATCH 12/39] use device.type --- bitsandbytes/functional.py | 28 ++++++++++++++-------------- bitsandbytes/nn/modules.py | 1 + 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 3c74b30cc..036026c77 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2037,16 +2037,16 @@ def pipeline_test(A, batch_size): # 8 bits common functions def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): - assert A.device in Backends.device, f"Device backend for {A.device} is not supported" - return Backends.device[A.device].double_quant(A, col_stats=col_stats, row_stats=row_stats, out_col=out_col, out_row=out_row, threshold=threshold) + assert A.device.type in Backends.devices, f"Device backend for {A.device.type} is not supported" + return Backends.devices[A.device.type].double_quant(A, col_stats=col_stats, row_stats=row_stats, out_col=out_col, out_row=out_row, threshold=threshold) def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): - assert A.device in Backends.device, f"Device backend for {A.device} is not supported" - return Backends.device[A.device].transform(A, to_order, from_order=from_order, out=out, transpose=transpose, state=state, ld=ld) + assert A.device.type in Backends.devices, f"Device backend for {A.device.type} is not supported" + return Backends.devices[A.device.type].transform(A, to_order, from_order=from_order, out=out, transpose=transpose, state=state, ld=ld) def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): - assert A.device in Backends.device, f"Device backend for {A.device} is not supported" - return Backends.device[A.device].igemmlt(A, B, SA, SB, out=out, Sout=Sout, dtype=dtype) + assert A.device.type in Backends.devices, f"Device backend for {A.device.type} is not supported" + return Backends.devices[A.device.type].igemmlt(A, B, SA, SB, out=out, Sout=Sout, dtype=dtype) def mm_dequant( A, @@ -2058,21 +2058,21 @@ def mm_dequant( new_col_stats=None, bias=None ): - assert A.device in Backends.device, f"Device backend for {A.device} is not supported" - return Backends.device[A.device].mm_dequant(A, quant_state, row_stats, col_stats, out=out, new_row_stats=new_row_stats, new_col_stats=new_col_stats, bias=bias) + assert A.device.type in Backends.devices, f"Device backend for {A.device.type} is not supported" + return Backends.devices[A.device.type].mm_dequant(A, quant_state, row_stats, col_stats, out=out, new_row_stats=new_row_stats, new_col_stats=new_col_stats, bias=bias) def 
extract_outliers(A, SA, idx): - assert A.device in Backends.device, f"Device backend for {A.device} is not supported" - return Backends.device[A.device].extract_outliers(A, SA, idx) + assert A.device.type in Backends.devices, f"Device backend for {A.device.type} is not supported" + return Backends.devices[A.device.type].extract_outliers(A, SA, idx) # 4 bits common functions def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4'): - assert A.device in Backends.device, f"Device backend for {A.device} is not supported" - return Backends.device[A.device].quantize_4bit(A, absmax = absmax, out = out, blocksize=blocksize, compress_statistics=compress_statistics, quant_type=quant_type) + assert A.device.type in Backends.devices, f"Device backend for {A.device.type} is not supported" + return Backends.devices[A.device.type].quantize_4bit(A, absmax = absmax, out = out, blocksize=blocksize, compress_statistics=compress_statistics, quant_type=quant_type) def dequantize_4bit(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4'): - assert A.device in Backends.device, f"Device backend for {A.device} is not supported" - return Backends.device[A.device].dequantize_4bit(A, quant_state=quant_state, absmax=absmax, out=out, blocksize=blocksize, quant_type=quant_type) + assert A.device.type in Backends.devices, f"Device backend for {A.device.type} is not supported" + return Backends.devices[A.device.type].dequantize_4bit(A, quant_state=quant_state, absmax=absmax, out=out, blocksize=blocksize, quant_type=quant_type) def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False): return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'fp4') diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 9c798fe39..2bde6b6d2 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -162,6 +162,7 @@ def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], self.blocksize = self.quant_state.blocksize self.compress_statistics = self.quant_state.nested self.quant_type = self.quant_state.quant_type + return self def quantize_to_device(self, device): w = self.data.contiguous().half().to(device) From a53bc318efe9109f9dcfd6eda579c4051d5ef813 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Mon, 4 Dec 2023 01:16:19 -0800 Subject: [PATCH 13/39] minor fix --- bitsandbytes/autograd/_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 757aafb4f..d6e38b1e5 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -230,7 +230,7 @@ def supports_igemmlt(device: torch.device) -> bool: nvidia16_models = ('GTX 1630', 'GTX 1650', 'GTX 1660') # https://en.wikipedia.org/wiki/GeForce_16_series if any(model_name in device_name for model_name in nvidia16_models): return False # these devices are technically cuda 7.5-capable, but they lack tensor cores - if device == "cpu": + if device.type == "cpu": #TODO: will return True once CPU backend upstream the supports return False From 80c598c3ca95c9ffcd07410d2f87df07f6479ed4 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Mon, 4 Dec 2023 01:41:58 -0800 Subject: [PATCH 14/39] backend refinement --- bitsandbytes/backends/__init__.py | 41 +++++++++++++++++++ .../{backends.py => backends/cuda.py} 
| 39 ------------------ 2 files changed, 41 insertions(+), 39 deletions(-) create mode 100644 bitsandbytes/backends/__init__.py rename bitsandbytes/{backends.py => backends/cuda.py} (94%) diff --git a/bitsandbytes/backends/__init__.py b/bitsandbytes/backends/__init__.py new file mode 100644 index 000000000..7a4306e92 --- /dev/null +++ b/bitsandbytes/backends/__init__.py @@ -0,0 +1,41 @@ +from .cuda import CUDABackend + + +class Backends: + """ + An dict class for device backends that registered with 8bits and 4bits functions. + + The values of this device backends are lowercase strings, e.g., ``"cuda"``. They can + be accessed as attributes with key-value, e.g., ``Backends.device["cuda"]``. + + """ + + devices = {} + + @classmethod + def register_backend(self, backend_name: str, backend_class): + assert backend_name.lower() in { + "cpu", + "cuda", + "xpu", + }, "register device backend choices in [cpu, cuda, xpu]" + + # check 8bits and 4bits interfaces + if ( + hasattr(backend_class, "double_quant") + and hasattr(backend_class, "transform") + and hasattr(backend_class, "igemmlt") + and hasattr(backend_class, "mm_dequant") + and hasattr(backend_class, "extract_outliers") + and hasattr(backend_class, "quantize_4bit") + and hasattr(backend_class, "dequantize_4bit") + ): + Backends.devices[backend_name.lower()] = backend_class + else: + assert ( + False + ), f"register device backend {backend_name.lower()} but its interfaces are not compelete" + + + +Backends.register_backend("cuda", CUDABackend) diff --git a/bitsandbytes/backends.py b/bitsandbytes/backends/cuda.py similarity index 94% rename from bitsandbytes/backends.py rename to bitsandbytes/backends/cuda.py index e2899e840..84b1b70ca 100644 --- a/bitsandbytes/backends.py +++ b/bitsandbytes/backends/cuda.py @@ -15,43 +15,6 @@ ) from bitsandbytes.functional import CUBLAS_Context, QuantState - -class Backends: - """ - An dict class for device backends that registered with 8bits and 4bits functions. - - The values of this device backends are lowercase strings, e.g., ``"cuda"``. They can - be accessed as attributes with key-value, e.g., ``Backends.device["cuda"]``. 
- - """ - - devices = {} - - @classmethod - def register_backend(self, backend_name: str, backend_class): - assert backend_name.lower() in { - "cpu", - "cuda", - "xpu", - }, "register device backend choices in [cpu, cuda, xpu]" - - # check 8bits and 4bits interfaces - if ( - hasattr(backend_class, "double_quant") - and hasattr(backend_class, "transform") - and hasattr(backend_class, "igemmlt") - and hasattr(backend_class, "mm_dequant") - and hasattr(backend_class, "extract_outliers") - and hasattr(backend_class, "quantize_4bit") - and hasattr(backend_class, "dequantize_4bit") - ): - Backends.devices[backend_name.lower()] = backend_class - else: - assert ( - False - ), f"register device backend {backend_name.lower()} but its interfaces are not compelete" - - class CUDABackend: @classmethod def double_quant( @@ -686,5 +649,3 @@ def dequantize_4bit( else: return out - -Backends.register_backend("cuda", CUDABackend) From 59facc84c29984adddef5609b6ba7deeb817c614 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Mon, 4 Dec 2023 20:23:56 -0800 Subject: [PATCH 15/39] minor fix --- bitsandbytes/autograd/_functions.py | 14 ++++++------- bitsandbytes/backends/cuda.py | 28 +++++++++++++++++++++----- bitsandbytes/device_setup/cuda/main.py | 2 +- bitsandbytes/functional.py | 2 +- 4 files changed, 32 insertions(+), 14 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index d6e38b1e5..43668dd82 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -291,14 +291,14 @@ def tile_indices(self): self._tile_indices = get_tile_inds(self.formatB, self.CxB.device) return self._tile_indices + class MatMul8bitLt(torch.autograd.Function): # forward is the same, but we added the fallback for pre-turing GPUs # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None") @staticmethod def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): - device = A.device - using_igemmlt = supports_igemmlt(device) and not state.force_no_igemmlt + using_igemmlt = supports_igemmlt(A.device) and not state.force_no_igemmlt # default of pytorch behavior if inputs are empty ctx.is_empty = False if prod(A.shape) == 0: @@ -307,9 +307,9 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): ctx.B = B ctx.bias = bias if A.shape[-1] == B.shape[0]: - return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=device) + return torch.empty(A.shape[:-1] + B.shape[1:], dtype=A.dtype, device=A.device) else: - return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=device) + return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device) # 1. Quantize A # 2. 
Quantize B @@ -341,7 +341,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): else: if state.CxB is None and using_igemmlt: # B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions - # we also need to convert it to the turing/ampere format if using cuda + # we also need to convert it to the turing/ampere format state.CxB, state.SB = F.transform(state.CB, to_order=formatB) else: if not state.has_fp16_weights and state.CxB is None and using_igemmlt: @@ -403,7 +403,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): if using_igemmlt: C32A, SA = F.transform(CA, "col32") out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB) - if bias is None or bias.dtype in [torch.float16, torch.bfloat16]: + if bias is None or bias.dtype == torch.float16: # we apply the fused bias here output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) output = output.to(A.dtype) @@ -568,7 +568,7 @@ def matmul( def matmul_4bit(A: tensor, B: tensor, quant_state: F.QuantState, out: tensor = None, bias=None): assert quant_state is not None - if A.numel() == A.shape[-1] and A.requires_grad == False and A.device == "cuda": + if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type == "cuda": if A.shape[-1] % quant_state.blocksize != 0: warn(f'Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}') return MatMul4Bit.apply(A, B, out, bias, quant_state) diff --git a/bitsandbytes/backends/cuda.py b/bitsandbytes/backends/cuda.py index 84b1b70ca..acf1aa5de 100644 --- a/bitsandbytes/backends/cuda.py +++ b/bitsandbytes/backends/cuda.py @@ -1,5 +1,6 @@ import torch from torch import Tensor +import ctypes as ct from bitsandbytes.functional import ( pre_call, post_call, @@ -14,11 +15,19 @@ dequantize_blockwise, ) from bitsandbytes.functional import CUBLAS_Context, QuantState +from bitsandbytes.cextension import lib + class CUDABackend: @classmethod def double_quant( - A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 + cls, + A, + col_stats=None, + row_stats=None, + out_col=None, + out_row=None, + threshold=0.0, ): device = A.device assert A.dtype == torch.half @@ -114,7 +123,14 @@ def double_quant( @classmethod def transform( - A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None + cls, + A, + to_order, + from_order="row", + out=None, + transpose=False, + state=None, + ld=None, ): prev_device = pre_call(A.device) if state is None: @@ -167,7 +183,7 @@ def transform( return out, new_state @classmethod - def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): + def igemmlt(cls, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): shapeA = SA[0] shapeB = SB[0] dimsA = len(shapeA) @@ -271,6 +287,7 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): @classmethod def mm_dequant( + cls, A, quant_state, row_stats, @@ -332,7 +349,7 @@ def mm_dequant( return out @classmethod - def extract_outliers(A, SA, idx): + def extract_outliers(cls, A, SA, idx): shapeA = SA[0] formatA = SA[1] assert formatA in ["col_turing", "col_ampere"] @@ -358,6 +375,7 @@ def extract_outliers(A, SA, idx): @classmethod def quantize_4bit( + cls, A: Tensor, absmax: Tensor = None, out: Tensor = None, @@ -509,6 +527,7 @@ def quantize_4bit( @classmethod def dequantize_4bit( + cls, A: Tensor, quant_state: QuantState = None, absmax: Tensor = None, @@ -648,4 +667,3 @@ def 
dequantize_4bit( return out.t() else: return out - diff --git a/bitsandbytes/device_setup/cuda/main.py b/bitsandbytes/device_setup/cuda/main.py index fd639beb7..cf1cf7796 100644 --- a/bitsandbytes/device_setup/cuda/main.py +++ b/bitsandbytes/device_setup/cuda/main.py @@ -125,7 +125,7 @@ def run_cuda_setup(self): self.binary_name = binary_name self.manual_override() - package_dir = Path(__file__).parent.parent + package_dir = Path(__file__).parent.parent.parent binary_path = package_dir / self.binary_name try: diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 036026c77..4cacfc983 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2068,7 +2068,7 @@ def extract_outliers(A, SA, idx): # 4 bits common functions def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4'): assert A.device.type in Backends.devices, f"Device backend for {A.device.type} is not supported" - return Backends.devices[A.device.type].quantize_4bit(A, absmax = absmax, out = out, blocksize=blocksize, compress_statistics=compress_statistics, quant_type=quant_type) + return Backends.devices[A.device.type].quantize_4bit(A, absmax=absmax, out=out, blocksize=blocksize, compress_statistics=compress_statistics, quant_type=quant_type) def dequantize_4bit(A: Tensor, quant_state: Tuple[Tensor, Tensor] = None, absmax: Tensor = None, out: Tensor = None, blocksize: int = 64, quant_type='fp4'): assert A.device.type in Backends.devices, f"Device backend for {A.device.type} is not supported" From 066d0dc39663b5bebb467e1fd51ac395f8c38bc4 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Mon, 4 Dec 2023 22:56:04 -0800 Subject: [PATCH 16/39] final refinement --- bitsandbytes/backends/__init__.py | 5 ++--- bitsandbytes/nn/modules.py | 22 +++++++++++++--------- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/bitsandbytes/backends/__init__.py b/bitsandbytes/backends/__init__.py index 7a4306e92..fd046d506 100644 --- a/bitsandbytes/backends/__init__.py +++ b/bitsandbytes/backends/__init__.py @@ -13,7 +13,7 @@ class Backends: devices = {} @classmethod - def register_backend(self, backend_name: str, backend_class): + def register_backend(cls, backend_name: str, backend_class): assert backend_name.lower() in { "cpu", "cuda", @@ -30,12 +30,11 @@ def register_backend(self, backend_name: str, backend_class): and hasattr(backend_class, "quantize_4bit") and hasattr(backend_class, "dequantize_4bit") ): - Backends.devices[backend_name.lower()] = backend_class + cls.devices[backend_name.lower()] = backend_class else: assert ( False ), f"register device backend {backend_name.lower()} but its interfaces are not compelete" - Backends.register_backend("cuda", CUDABackend) diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 2bde6b6d2..ddc40cfa6 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -164,8 +164,8 @@ def from_prequantized(cls, data: torch.Tensor, quantized_stats: Dict[str, Any], self.quant_type = self.quant_state.quant_type return self - def quantize_to_device(self, device): - w = self.data.contiguous().half().to(device) + def cuda(self, device): + w = self.data.contiguous().half().cuda(device) w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type) self.data = w_4bit self.quant_state = quant_state @@ -187,8 +187,8 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) 
-> T: def to(self, *args, **kwargs): device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs) - if (device is not None and self.data.device.type == "cpu"): - return self.quantize_to_device(device) + if (device is not None and device.type == "cuda" and self.data.device.type == "cpu"): + return self.cuda(device) else: if self.quant_state is not None: self.quant_state.to(device) @@ -296,13 +296,13 @@ def __new__( data = torch.empty(0) return torch.Tensor._make_subclass(cls, data, requires_grad) - def quantize_to_device(self, device): + def cuda(self, device): if self.has_fp16_weights: - return super().to(device) + return super().cuda(device) else: # we store the 8-bit rows-major weight # we convert this weight to the turning/ampere weight during the first inference pass - B = self.data.contiguous().half().to(device) + B = self.data.contiguous().half().cuda(device) CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) del CBt del SCBt @@ -334,8 +334,12 @@ def to(self, *args, **kwargs): *args, **kwargs ) - if (device is not None and self.data.device.type == "cpu"): - return self.quantize_to_device(device) + if ( + device is not None + and device.type == "cuda" + and self.data.device.type == "cpu" + ): + return self.cuda(device) else: new_param = Int8Params( super().to( From cebd83c10e4c4847e448ad0949af033bd8d1c4ef Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Tue, 6 Feb 2024 07:21:01 -0800 Subject: [PATCH 17/39] refine backend register with base-backend --- bitsandbytes/backends/__init__.py | 28 +++----- bitsandbytes/backends/basic_backend.py | 92 ++++++++++++++++++++++++++ bitsandbytes/backends/cuda.py | 8 ++- 3 files changed, 107 insertions(+), 21 deletions(-) create mode 100644 bitsandbytes/backends/basic_backend.py diff --git a/bitsandbytes/backends/__init__.py b/bitsandbytes/backends/__init__.py index fd046d506..496e7d671 100644 --- a/bitsandbytes/backends/__init__.py +++ b/bitsandbytes/backends/__init__.py @@ -1,5 +1,4 @@ -from .cuda import CUDABackend - +from bitsandbytes.cextension import COMPILED_WITH_CUDA class Backends: """ @@ -13,28 +12,17 @@ class Backends: devices = {} @classmethod - def register_backend(cls, backend_name: str, backend_class): + def register_backend(cls, backend_name: str, backend_instance): assert backend_name.lower() in { "cpu", "cuda", "xpu", }, "register device backend choices in [cpu, cuda, xpu]" - # check 8bits and 4bits interfaces - if ( - hasattr(backend_class, "double_quant") - and hasattr(backend_class, "transform") - and hasattr(backend_class, "igemmlt") - and hasattr(backend_class, "mm_dequant") - and hasattr(backend_class, "extract_outliers") - and hasattr(backend_class, "quantize_4bit") - and hasattr(backend_class, "dequantize_4bit") - ): - cls.devices[backend_name.lower()] = backend_class - else: - assert ( - False - ), f"register device backend {backend_name.lower()} but its interfaces are not compelete" - + cls.devices[backend_name.lower()] = backend_instance -Backends.register_backend("cuda", CUDABackend) +if COMPILED_WITH_CUDA: + from .cuda import CUDABackend + cuda_backend = CUDABackend(torch.device("cuda").type) + Backends.register_backend(cuda_backend.get_name(), cuda_backend) +# TODO: register more backends support \ No newline at end of file diff --git a/bitsandbytes/backends/basic_backend.py b/bitsandbytes/backends/basic_backend.py new file mode 100644 index 000000000..8565c5f73 --- /dev/null +++ b/bitsandbytes/backends/basic_backend.py @@ -0,0 +1,92 @@ +from abc import ABC, abstractmethod 
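# --- Illustrative sketch (editor's note, not part of this patch) -----------
# The DeviceBackends ABC added below is the contract a device backend must
# satisfy before it can be registered. A hypothetical, not-yet-functional CPU
# backend could plug in roughly as follows; the NullCPUBackend class and the
# "cpu" registration are assumptions for illustration, while DeviceBackends,
# get_name and Backends.register_backend are the names this patch introduces.
import torch

from bitsandbytes.backends import Backends
from bitsandbytes.backends.basic_backend import DeviceBackends


class NullCPUBackend(DeviceBackends):
    """Skeleton backend: satisfies the interface but provides no kernels yet."""

    def get_name(self) -> str:
        return "cpu"

    @classmethod
    def double_quant(cls, A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
        raise NotImplementedError("CPU int8 double_quant kernel not implemented")

    @classmethod
    def transform(cls, A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None):
        raise NotImplementedError

    @classmethod
    def igemmlt(cls, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32):
        raise NotImplementedError

    @classmethod
    def mm_dequant(cls, A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None):
        raise NotImplementedError

    @classmethod
    def extract_outliers(cls, A, SA, idx):
        raise NotImplementedError

    @classmethod
    def quantize_4bit(cls, A, absmax=None, out=None, blocksize=64, compress_statistics=False, quant_type="fp4", quant_storage=torch.uint8):
        raise NotImplementedError

    @classmethod
    def dequantize_4bit(cls, A, quant_state=None, absmax=None, out=None, blocksize=64, quant_type="fp4"):
        raise NotImplementedError


# Registering the instance means the functional-level wrappers in this series,
# which look backends up by A.device.type, would resolve "cpu" tensors here
# (calls still raise until real kernels replace the placeholders above).
Backends.register_backend("cpu", NullCPUBackend())
# ---------------------------------------------------------------------------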
+import torch +from typing import Optional, Tuple +from bitsandbytes.functional import QuantState + + +class DeviceBackends(ABC): + """Base class for devices backends that will implement their own 8bits and 4bits functions.""" + + @abstractmethod + def get_name(self) -> str: + """Name of the device as the backend support.""" + raise NotImplementedError + + @classmethod + @abstractmethod + def double_quant( + cls, + A, + col_stats=None, + row_stats=None, + out_col=None, + out_row=None, + threshold=0.0, + ): + raise NotImplementedError + + @classmethod + @abstractmethod + def transform( + cls, + A, + to_order, + from_order="row", + out=None, + transpose=False, + state=None, + ld=None, + ): + raise NotImplementedError + + @classmethod + @abstractmethod + def igemmlt(cls, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): + raise NotImplementedError + + @classmethod + @abstractmethod + def mm_dequant( + cls, + A, + quant_state, + row_stats, + col_stats, + out=None, + new_row_stats=None, + new_col_stats=None, + bias=None, + ): + raise NotImplementedError + + @classmethod + @abstractmethod + def extract_outliers(cls, A, SA, idx): + raise NotImplementedError + + @classmethod + @abstractmethod + def quantize_4bit( + cls, + A: torch.Tensor, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=64, + compress_statistics=False, + quant_type="fp4", + quant_storage=torch.uint8, + ) -> Tuple[torch.Tensor, QuantState]: + raise NotImplementedError + + @classmethod + @abstractmethod + def dequantize_4bit( + cls, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 64, + quant_type="fp4", + ) -> torch.Tensor: + raise NotImplementedError diff --git a/bitsandbytes/backends/cuda.py b/bitsandbytes/backends/cuda.py index f90c3d1e9..7680bf2a1 100644 --- a/bitsandbytes/backends/cuda.py +++ b/bitsandbytes/backends/cuda.py @@ -18,9 +18,15 @@ ) from bitsandbytes.functional import CUBLAS_Context, QuantState from bitsandbytes.cextension import lib +from .basic_backend import DeviceBackends +class CUDABackend(DeviceBackends): + def __init__(self, backend_name: str): + self.backend_name = backend_name + + def get_name(self) -> str: + return self.backend_name -class CUDABackend: @classmethod def double_quant( cls, A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 From d20c01764d4980699667795c9711c7d505b9db1c Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Tue, 6 Feb 2024 23:26:19 +0800 Subject: [PATCH 18/39] minor clean format --- tests/test_cuda_setup_evaluator.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index 914b7414a..e3620bf41 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -19,11 +19,3 @@ def test_manual_override(requires_cuda): import bitsandbytes as bnb loaded_lib = bnb.device_setup.cuda.main.CUDASetup.get_instance().binary_name #assert loaded_lib == 'libbitsandbytes_cuda122.so' - - - - - - - - From b41c1c4d68a6c4b2154c582ec01c2d2b5bd36f63 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Tue, 6 Feb 2024 21:36:50 -0800 Subject: [PATCH 19/39] format in CI --- bitsandbytes/__init__.py | 2 +- bitsandbytes/backends/__init__.py | 3 ++- bitsandbytes/backends/basic_backend.py | 4 +++- bitsandbytes/backends/cuda.py | 27 +++++++++++++++----------- bitsandbytes/cextension.py | 1 + bitsandbytes/functional.py | 2 +- 6 files 
changed, 24 insertions(+), 15 deletions(-) diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 1045070cd..512fd2455 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import device_setup, utils, research +from . import device_setup, research, utils from .autograd._functions import ( MatmulLtState, bmm_cublas, diff --git a/bitsandbytes/backends/__init__.py b/bitsandbytes/backends/__init__.py index 496e7d671..bf8a76cba 100644 --- a/bitsandbytes/backends/__init__.py +++ b/bitsandbytes/backends/__init__.py @@ -1,5 +1,6 @@ from bitsandbytes.cextension import COMPILED_WITH_CUDA + class Backends: """ An dict class for device backends that registered with 8bits and 4bits functions. @@ -25,4 +26,4 @@ def register_backend(cls, backend_name: str, backend_instance): from .cuda import CUDABackend cuda_backend = CUDABackend(torch.device("cuda").type) Backends.register_backend(cuda_backend.get_name(), cuda_backend) -# TODO: register more backends support \ No newline at end of file +# TODO: register more backends support diff --git a/bitsandbytes/backends/basic_backend.py b/bitsandbytes/backends/basic_backend.py index 8565c5f73..b97723d81 100644 --- a/bitsandbytes/backends/basic_backend.py +++ b/bitsandbytes/backends/basic_backend.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod -import torch from typing import Optional, Tuple + +import torch + from bitsandbytes.functional import QuantState diff --git a/bitsandbytes/backends/cuda.py b/bitsandbytes/backends/cuda.py index 7680bf2a1..965138a69 100644 --- a/bitsandbytes/backends/cuda.py +++ b/bitsandbytes/backends/cuda.py @@ -1,25 +1,30 @@ -import torch -from torch import Tensor import ctypes as ct from typing import Optional, Tuple + +import torch +from torch import Tensor + +from bitsandbytes.cextension import lib from bitsandbytes.functional import ( - pre_call, - post_call, + CUBLAS_Context, + QuantState, + coo_zeros, + dequantize_blockwise, + dtype2bytes, + get_4bit_type, get_colrow_absmax, get_ptr, - is_on_gpu, - coo_zeros, get_transform_buffer, + is_on_gpu, + post_call, + pre_call, prod, - get_4bit_type, quantize_blockwise, - dequantize_blockwise, - dtype2bytes, ) -from bitsandbytes.functional import CUBLAS_Context, QuantState -from bitsandbytes.cextension import lib + from .basic_backend import DeviceBackends + class CUDABackend(DeviceBackends): def __init__(self, backend_name: str): self.backend_name = backend_name diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 0848784c0..dab34982e 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -2,6 +2,7 @@ from warnings import warn import torch + from bitsandbytes.device_setup.cuda.main import CUDASetup setup = CUDASetup.get_instance() diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index c2fb491dd..f8a9723cb 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2081,6 +2081,7 @@ def pipeline_test(A, batch_size): from bitsandbytes.backends import Backends + # 8 bits common functions def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): assert A.device.type in Backends.devices, f"Device backend for {A.device.type} is not supported" @@ -2127,4 +2128,3 @@ def quantize_4bit( def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: 
Optional[torch.Tensor] = None, blocksize: int = 64, quant_type='fp4') -> Tensor: assert A.device.type in Backends.devices, f"Device backend for {A.device.type} is not supported" return Backends.devices[A.device.type].dequantize_4bit(A, quant_state=quant_state, absmax=absmax, out=out, blocksize=blocksize, quant_type=quant_type) - From 1ab611e889a2ad069093e355fb6921486b59856c Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Tue, 6 Feb 2024 21:59:50 -0800 Subject: [PATCH 20/39] minor fix for format --- bitsandbytes/backends/__init__.py | 2 +- bitsandbytes/functional.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/bitsandbytes/backends/__init__.py b/bitsandbytes/backends/__init__.py index bf8a76cba..793d98dc5 100644 --- a/bitsandbytes/backends/__init__.py +++ b/bitsandbytes/backends/__init__.py @@ -1,5 +1,5 @@ from bitsandbytes.cextension import COMPILED_WITH_CUDA - +import torch class Backends: """ diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index f8a9723cb..e6649ba34 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -16,6 +16,7 @@ from .cextension import COMPILED_WITH_CUDA, lib +from bitsandbytes.backends import Backends # math.prod not compatible with python < 3.8 def prod(iterable): @@ -2079,9 +2080,6 @@ def pipeline_test(A, batch_size): lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size)) return out -from bitsandbytes.backends import Backends - - # 8 bits common functions def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): assert A.device.type in Backends.devices, f"Device backend for {A.device.type} is not supported" From b933f9f1c686979d6dbf9ea97c753561162459e9 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Wed, 7 Feb 2024 23:00:06 +0800 Subject: [PATCH 21/39] refactor base backend registering Co-authored-by: Aarni Koskela --- bitsandbytes/backends/__init__.py | 29 +++++++---------------------- 1 file changed, 7 insertions(+), 22 deletions(-) diff --git a/bitsandbytes/backends/__init__.py b/bitsandbytes/backends/__init__.py index 793d98dc5..084cfa3e0 100644 --- a/bitsandbytes/backends/__init__.py +++ b/bitsandbytes/backends/__init__.py @@ -1,29 +1,14 @@ -from bitsandbytes.cextension import COMPILED_WITH_CUDA +import typing import torch -class Backends: - """ - An dict class for device backends that registered with 8bits and 4bits functions. - - The values of this device backends are lowercase strings, e.g., ``"cuda"``. They can - be accessed as attributes with key-value, e.g., ``Backends.device["cuda"]``. 
- - """ - - devices = {} +from bitsandbytes.cextension import COMPILED_WITH_CUDA +from bitsandbytes.backends.base import Backend - @classmethod - def register_backend(cls, backend_name: str, backend_instance): - assert backend_name.lower() in { - "cpu", - "cuda", - "xpu", - }, "register device backend choices in [cpu, cuda, xpu]" +backends: Dict[str, Backend] = {} - cls.devices[backend_name.lower()] = backend_instance +def register_backend(backend_name: str, backend_instance: Backend): + backends[backend_name.lower()] = backend_instance if COMPILED_WITH_CUDA: from .cuda import CUDABackend - cuda_backend = CUDABackend(torch.device("cuda").type) - Backends.register_backend(cuda_backend.get_name(), cuda_backend) -# TODO: register more backends support + register_backend("cuda", CUDABackend()) From 8b4baaa4ac53dc5051ba29cd5cd4093f7b149aad Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Wed, 7 Feb 2024 07:38:24 -0800 Subject: [PATCH 22/39] refine structures of backends --- bitsandbytes/backends/__init__.py | 2 +- bitsandbytes/backends/base.py | 133 ++++++++++++ bitsandbytes/backends/basic_backend.py | 94 --------- bitsandbytes/backends/cuda.py | 88 ++------ bitsandbytes/functional.py | 275 ++++++++++--------------- bitsandbytes/utils.py | 120 ++++++++++- 6 files changed, 372 insertions(+), 340 deletions(-) create mode 100644 bitsandbytes/backends/base.py delete mode 100644 bitsandbytes/backends/basic_backend.py diff --git a/bitsandbytes/backends/__init__.py b/bitsandbytes/backends/__init__.py index 084cfa3e0..0ae01a3d3 100644 --- a/bitsandbytes/backends/__init__.py +++ b/bitsandbytes/backends/__init__.py @@ -1,4 +1,4 @@ -import typing +from typing import Dict import torch from bitsandbytes.cextension import COMPILED_WITH_CUDA diff --git a/bitsandbytes/backends/base.py b/bitsandbytes/backends/base.py new file mode 100644 index 000000000..8232d17c1 --- /dev/null +++ b/bitsandbytes/backends/base.py @@ -0,0 +1,133 @@ +from abc import ABC, abstractmethod +from typing import Optional, Tuple + +import torch + +from bitsandbytes.utils import QuantState + + +class Backend(ABC): + """Base class for devices backends that will implement their own 8bits and 4bits functions.""" + + @abstractmethod + def double_quant( + self, + A, + col_stats=None, + row_stats=None, + out_col=None, + out_row=None, + threshold=0.0, + ): + raise NotImplementedError + + @abstractmethod + def transform( + self, + A, + to_order, + from_order="row", + out=None, + transpose=False, + state=None, + ld=None, + ): + raise NotImplementedError + + @abstractmethod + def igemmlt(self, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): + raise NotImplementedError + + @abstractmethod + def mm_dequant( + self, + A, + quant_state, + row_stats, + col_stats, + out=None, + new_row_stats=None, + new_col_stats=None, + bias=None, + ): + raise NotImplementedError + + @abstractmethod + def extract_outliers(self, A, SA, idx): + raise NotImplementedError + + @abstractmethod + def quantize_4bit( + self, + A: torch.Tensor, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=64, + compress_statistics=False, + quant_type="fp4", + quant_storage=torch.uint8, + ) -> Tuple[torch.Tensor, QuantState]: + """ + Quantize tensor A in blocks of 4-bit values. + + Quantizes tensor A by dividing it into blocks which are independently quantized to FP4. + + Parameters + ---------- + A : torch.Tensor + The input tensor. + absmax : torch.Tensor + The absmax values. + out : torch.Tensor + The output tensor. 
+ blocksize : int + The blocksize used in quantization. + quant_type : str + The 4-bit quantization data type {fp4, nf4} + + Returns + ------- + torch.Tensor: + Tensor with packed 4-bit values. + tuple(torch.Tensor, torch.Size, torch.dtype, int): + The quantization state to undo the quantization. + """ + raise NotImplementedError + + @abstractmethod + def dequantize_4bit( + self, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 64, + quant_type="fp4", + ) -> torch.Tensor: + """ + Dequantizes FP4 blockwise quantized values. + + Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize. + + Parameters + ---------- + A : torch.Tensor + The input tensor (packed 4-bit values). + quant_state : QuantState + object with quantisation stats, incl. absmax values, original tensor shape and original dtype. + absmax : torch.Tensor + The absmax values. + out : torch.Tensor + Dequantized output tensor. + blocksize : int + The blocksize used in quantization. + quant_type : str + The 4-bit quantization data type {fp4, nf4} + + + Returns + ------- + torch.Tensor: + Dequantized tensor. + """ + raise NotImplementedError diff --git a/bitsandbytes/backends/basic_backend.py b/bitsandbytes/backends/basic_backend.py deleted file mode 100644 index b97723d81..000000000 --- a/bitsandbytes/backends/basic_backend.py +++ /dev/null @@ -1,94 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Optional, Tuple - -import torch - -from bitsandbytes.functional import QuantState - - -class DeviceBackends(ABC): - """Base class for devices backends that will implement their own 8bits and 4bits functions.""" - - @abstractmethod - def get_name(self) -> str: - """Name of the device as the backend support.""" - raise NotImplementedError - - @classmethod - @abstractmethod - def double_quant( - cls, - A, - col_stats=None, - row_stats=None, - out_col=None, - out_row=None, - threshold=0.0, - ): - raise NotImplementedError - - @classmethod - @abstractmethod - def transform( - cls, - A, - to_order, - from_order="row", - out=None, - transpose=False, - state=None, - ld=None, - ): - raise NotImplementedError - - @classmethod - @abstractmethod - def igemmlt(cls, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): - raise NotImplementedError - - @classmethod - @abstractmethod - def mm_dequant( - cls, - A, - quant_state, - row_stats, - col_stats, - out=None, - new_row_stats=None, - new_col_stats=None, - bias=None, - ): - raise NotImplementedError - - @classmethod - @abstractmethod - def extract_outliers(cls, A, SA, idx): - raise NotImplementedError - - @classmethod - @abstractmethod - def quantize_4bit( - cls, - A: torch.Tensor, - absmax: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - blocksize=64, - compress_statistics=False, - quant_type="fp4", - quant_storage=torch.uint8, - ) -> Tuple[torch.Tensor, QuantState]: - raise NotImplementedError - - @classmethod - @abstractmethod - def dequantize_4bit( - cls, - A: torch.Tensor, - quant_state: Optional[QuantState] = None, - absmax: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - blocksize: int = 64, - quant_type="fp4", - ) -> torch.Tensor: - raise NotImplementedError diff --git a/bitsandbytes/backends/cuda.py b/bitsandbytes/backends/cuda.py index 965138a69..248d1e4c1 100644 --- a/bitsandbytes/backends/cuda.py +++ b/bitsandbytes/backends/cuda.py @@ -2,12 +2,10 @@ from typing import Optional, Tuple 
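# --- Illustrative sketch (editor's note, not part of this patch) -----------
# After this refactor the registry is the plain module-level dict
# bitsandbytes.backends.backends, filled by register_backend(), and the CUDA
# implementation is an instance (CUDABackend()) rather than a bag of
# classmethods. A caller or test can therefore resolve a backend explicitly;
# get_backend_for below is a hypothetical helper, and the example assumes a
# CUDA-enabled build with this patch series applied.
import torch

from bitsandbytes.backends import backends


def get_backend_for(tensor: torch.Tensor):
    """Return the registered backend instance for the tensor's device type."""
    device_type = tensor.device.type  # e.g. "cuda", "cpu", "xpu"
    if device_type not in backends:
        raise NotImplementedError(f"no bitsandbytes backend registered for {device_type!r}")
    return backends[device_type]


if torch.cuda.is_available():
    w = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
    backend = get_backend_for(w)
    # Roughly the call the functional-level wrapper would make, routed by hand here.
    packed, quant_state = backend.quantize_4bit(w, blocksize=64, quant_type="nf4")
# ---------------------------------------------------------------------------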
import torch -from torch import Tensor from bitsandbytes.cextension import lib from bitsandbytes.functional import ( CUBLAS_Context, - QuantState, coo_zeros, dequantize_blockwise, dtype2bytes, @@ -22,19 +20,14 @@ quantize_blockwise, ) -from .basic_backend import DeviceBackends +from bitsandbytes.utils import QuantState +from .base import Backend -class CUDABackend(DeviceBackends): - def __init__(self, backend_name: str): - self.backend_name = backend_name - def get_name(self) -> str: - return self.backend_name - - @classmethod +class CUDABackend(Backend): def double_quant( - cls, A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 + self, A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 ): device = A.device assert A.dtype == torch.half @@ -128,8 +121,7 @@ def double_quant( return out_row, out_col, row_stats, col_stats, coo_tensor - @classmethod - def transform(cls, A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): + def transform(self, A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): prev_device = pre_call(A.device) if state is None: state = (A.shape, from_order) else: from_order = state[1] @@ -172,8 +164,7 @@ def transform(cls, A, to_order, from_order='row', out=None, transpose=False, sta return out, new_state - @classmethod - def igemmlt(cls, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): + def igemmlt(self, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): shapeA = SA[0] shapeB = SB[0] dimsA = len(shapeA) @@ -272,9 +263,8 @@ def igemmlt(cls, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): return out, Sout - @classmethod def mm_dequant( - cls, + self, A, quant_state, row_stats, @@ -324,8 +314,7 @@ def mm_dequant( return out - @classmethod - def extract_outliers(cls, A, SA, idx): + def extract_outliers(self, A, SA, idx): shapeA = SA[0] formatA = SA[1] assert formatA in ["col_turing", "col_ampere"] @@ -351,42 +340,16 @@ def extract_outliers(cls, A, SA, idx): return out - @classmethod def quantize_4bit( - cls, - A: Tensor, + self, + A: torch.Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, quant_type='fp4', quant_storage=torch.uint8, - ) -> Tuple[Tensor, QuantState]: - """ - Quantize tensor A in blocks of 4-bit values. - - Quantizes tensor A by dividing it into blocks which are independently quantized to FP4. - - Parameters - ---------- - A : torch.Tensor - The input tensor. - absmax : torch.Tensor - The absmax values. - out : torch.Tensor - The output tensor. - blocksize : int - The blocksize used in quantization. - quant_type : str - The 4-bit quantization data type {fp4, nf4} - - Returns - ------- - torch.Tensor: - Tensor with packed 4-bit values. - tuple(torch.Tensor, torch.Size, torch.dtype, int): - The quantization state to undo the quantization. - """ + ) -> Tuple[torch.Tensor, QuantState]: if A.device.type != 'cuda': raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}') if quant_type not in ['fp4', 'nf4']: @@ -442,34 +405,7 @@ def quantize_4bit( return out, state - @classmethod - def dequantize_4bit(cls, A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64, quant_type='fp4') -> Tensor: - """ - Dequantizes FP4 blockwise quantized values. - - Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize. 
- - Parameters - ---------- - A : torch.Tensor - The input tensor (packed 4-bit values). - quant_state : QuantState - object with quantisation stats, incl. absmax values, original tensor shape and original dtype. - absmax : torch.Tensor - The absmax values. - out : torch.Tensor - Dequantized output tensor. - blocksize : int - The blocksize used in quantization. - quant_type : str - The 4-bit quantization data type {fp4, nf4} - - - Returns - ------- - torch.Tensor: - Dequantized tensor. - """ + def dequantize_4bit(self, A: torch.Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64, quant_type='fp4') -> torch.Tensor: if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") if quant_type not in ['fp4', 'nf4']: diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index e6649ba34..b75eac67e 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -16,7 +16,9 @@ from .cextension import COMPILED_WITH_CUDA, lib -from bitsandbytes.backends import Backends +from bitsandbytes.utils import QuantState + +from bitsandbytes.backends import backends # math.prod not compatible with python < 3.8 def prod(iterable): @@ -589,125 +591,6 @@ def estimate_quantiles(A: Tensor, out: Optional[torch.Tensor] = None, offset: fl return out -class QuantState: - """container for quantization state components to work with Params4bit and similar classes""" - valid_quant_types = ('fp4', 'nf4') - valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types] - valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state', 'quant_type', - 'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset'] - - def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None): - self.absmax = absmax - self.shape = shape - self.code = code - self.dtype = dtype - self.blocksize = blocksize - self.quant_type = quant_type - self.offset = offset - self.state2 = state2 - self.nested = state2 is not None - - def __get_item__(self, idx): - """ - ensures compatibility with older quant state scheme with nested lists. - assumes the following layout: - state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type] - state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type] - """ - if self.nested: - list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, [self.offset, self.state2], self.quant_type] - else: - list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, None, self.quant_type] - return list_repr[idx] - - @classmethod - def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState': - """ - unpacks components of state_dict into QuantState - where necessary, convert into strings, torch.dtype, ints, etc. - - qs_dict: based on state_dict, with only relevant keys, striped of prefixes. - - item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items. 
- """ - - # unpacking tensor with non-tensor components - qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)] - if not len(qs_key) and 'quant_type' not in qs_dict: - raise ValueError("Expected packed or unpacked quant_state items, found neither") - elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys: - raise ValueError(f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.") - - # unpacking minor and non-tensor quant state items if necessary - if len(qs_key) == 1: - first_qs_key = qs_key[0] - qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(first_qs_key))) - - qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()} # strip prefixes - assert set(qs_dict.keys()).issubset(cls.valid_qs_keys) - - if 'nested_absmax' in qs_dict: - offset = torch.tensor(float(qs_dict['nested_offset'])).to(device) - state2 = cls( - absmax=qs_dict['nested_absmax'].to(device), - blocksize=qs_dict['nested_blocksize'], - code=qs_dict['nested_quant_map'].to(device), - dtype=getattr(torch, qs_dict['nested_dtype']), - ) - else: - offset, state2 = None, None - - quant_state = cls( - quant_type=qs_dict['quant_type'], - absmax=qs_dict['absmax'].to(device), - blocksize=qs_dict['blocksize'], - code=qs_dict['quant_map'].to(device), - dtype=getattr(torch, qs_dict['dtype']), - shape=torch.Size(qs_dict['shape']) if qs_dict['shape'] is not None else None, - offset=offset, - state2=state2, - ) - return quant_state - - def as_dict(self, packed=False): - """ - returns dict of tensors and strings to use in serialization via _save_to_state_dict() - param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving - """ - qs_dict = { - 'quant_type': self.quant_type, - 'absmax': self.absmax, - 'blocksize': self.blocksize, - 'quant_map': self.code, - 'dtype': str(self.dtype).strip('torch.'), - 'shape': tuple(self.shape), - } - if self.nested: - qs_dict.update({ - 'nested_absmax': self.state2.absmax, - 'nested_blocksize': self.state2.blocksize, - 'nested_quant_map': self.state2.code.clone(), # un-shared to avoid restoring it after shared tensors are removed by safetensors - 'nested_dtype': str(self.state2.dtype).strip('torch.'), - 'nested_offset': self.offset.item(), - }) - if not packed: - return qs_dict - - # packed format allows serialization of non-tensor components, critical for saving in safetensors format - qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)} - non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)} - qs_packed_dict["quant_state." 
+ "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict) - return qs_packed_dict - - def to(self, device): - # make sure the quantization state is on the right device - self.absmax = self.absmax.to(device) - if self.nested: - self.offset = self.offset.to(device) - self.state2.absmax = self.state2.absmax.to(device) - self.state2.code = self.state2.code.to(device) - - def quantize_blockwise( A: Tensor, code: Optional[torch.Tensor] = None, @@ -918,12 +801,81 @@ def quantize_fp4(A: Tensor, absmax: Optional[torch.Tensor] = None, out: Optional def quantize_nf4(A: Tensor, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8): return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4', quant_storage) + +def quantize_4bit( + A: Tensor, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=64, + compress_statistics=False, + quant_type='fp4', + quant_storage=torch.uint8, +) -> Tuple[Tensor, QuantState]: + """ + Quantize tensor A in blocks of 4-bit values. + + Quantizes tensor A by dividing it into blocks which are independently quantized to FP4. + + Parameters + ---------- + A : torch.Tensor + The input tensor. + absmax : torch.Tensor + The absmax values. + out : torch.Tensor + The output tensor. + blocksize : int + The blocksize used in quantization. + quant_type : str + The 4-bit quantization data type {fp4, nf4} + + Returns + ------- + torch.Tensor: + Tensor with packed 4-bit values. + tuple(torch.Tensor, torch.Size, torch.dtype, int): + The quantization state to undo the quantization. + """ + assert A.device.type in backends, f"Device backend for {A.device.type} is not supported" + return backends[A.device.type].quantize_4bit(A, absmax=absmax, out=out, blocksize=blocksize, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage) + def dequantize_fp4(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64) -> Tensor: return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'fp4') def dequantize_nf4(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64) -> Tensor: return dequantize_4bit(A, quant_state, absmax, out, blocksize, 'nf4') +def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64, quant_type='fp4') -> Tensor: + """ + Dequantizes FP4 blockwise quantized values. + + Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize. + + Parameters + ---------- + A : torch.Tensor + The input tensor (packed 4-bit values). + quant_state : QuantState + object with quantisation stats, incl. absmax values, original tensor shape and original dtype. + absmax : torch.Tensor + The absmax values. + out : torch.Tensor + Dequantized output tensor. + blocksize : int + The blocksize used in quantization. + quant_type : str + The 4-bit quantization data type {fp4, nf4} + + + Returns + ------- + torch.Tensor: + Dequantized tensor. 
+ """ + assert A.device.type in backends, f"Device backend for {A.device.type} is not supported" + return backends[A.device.type].dequantize_4bit(A, quant_state=quant_state, absmax=absmax, out=out, blocksize=blocksize, quant_type=quant_type) + + def quantize( A: Tensor, code: Optional[torch.Tensor] = None, @@ -1690,6 +1642,25 @@ def batched_igemm( return out +def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): + assert A.device.type in backends, f"Device backend for {A.device.type} is not supported" + return backends[A.device.type].igemmlt(A, B, SA, SB, out=out, Sout=Sout, dtype=dtype) + + +def mm_dequant( + A, + quant_state, + row_stats, + col_stats, + out=None, + new_row_stats=None, + new_col_stats=None, + bias=None +): + assert A.device.type in backends, f"Device backend for {A.device.type} is not supported" + return backends[A.device.type].mm_dequant(A, quant_state, row_stats, col_stats, out=out, new_row_stats=new_row_stats, new_col_stats=new_col_stats, bias=bias) + + def get_colrow_absmax( A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0 ): @@ -1823,6 +1794,16 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half): return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) +def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): + assert A.device.type in backends, f"Device backend for {A.device.type} is not supported" + return backends[A.device.type].double_quant(A, col_stats=col_stats, row_stats=row_stats, out_col=out_col, out_row=out_row, threshold=threshold) + + +def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): + assert A.device.type in backends, f"Device backend for {A.device.type} is not supported" + return backends[A.device.type].transform(A, to_order, from_order=from_order, out=out, transpose=transpose, state=state, ld=ld) + + def spmm_coo(cooA, B, out=None): if out is None: out = torch.empty( @@ -2075,54 +2056,12 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half): return x.to(dtype) -def pipeline_test(A, batch_size): - out = torch.zeros_like(A) - lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size)) - return out - -# 8 bits common functions -def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): - assert A.device.type in Backends.devices, f"Device backend for {A.device.type} is not supported" - return Backends.devices[A.device.type].double_quant(A, col_stats=col_stats, row_stats=row_stats, out_col=out_col, out_row=out_row, threshold=threshold) - -def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): - assert A.device.type in Backends.devices, f"Device backend for {A.device.type} is not supported" - return Backends.devices[A.device.type].transform(A, to_order, from_order=from_order, out=out, transpose=transpose, state=state, ld=ld) - -def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): - assert A.device.type in Backends.devices, f"Device backend for {A.device.type} is not supported" - return Backends.devices[A.device.type].igemmlt(A, B, SA, SB, out=out, Sout=Sout, dtype=dtype) - -def mm_dequant( - A, - quant_state, - row_stats, - col_stats, - out=None, - new_row_stats=None, - new_col_stats=None, - bias=None -): - assert A.device.type in Backends.devices, f"Device backend for {A.device.type} is not supported" - return Backends.devices[A.device.type].mm_dequant(A, quant_state, row_stats, col_stats, out=out, 
new_row_stats=new_row_stats, new_col_stats=new_col_stats, bias=bias) - def extract_outliers(A, SA, idx): - assert A.device.type in Backends.devices, f"Device backend for {A.device.type} is not supported" - return Backends.devices[A.device.type].extract_outliers(A, SA, idx) + assert A.device.type in backends, f"Device backend for {A.device.type} is not supported" + return backends[A.device.type].extract_outliers(A, SA, idx) -# 4 bits common functions -def quantize_4bit( - A: Tensor, - absmax: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - blocksize=64, - compress_statistics=False, - quant_type='fp4', - quant_storage=torch.uint8, -) -> Tuple[Tensor, QuantState]: - assert A.device.type in Backends.devices, f"Device backend for {A.device.type} is not supported" - return Backends.devices[A.device.type].quantize_4bit(A, absmax=absmax, out=out, blocksize=blocksize, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage) -def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64, quant_type='fp4') -> Tensor: - assert A.device.type in Backends.devices, f"Device backend for {A.device.type} is not supported" - return Backends.devices[A.device.type].dequantize_4bit(A, quant_state=quant_state, absmax=absmax, out=out, blocksize=blocksize, quant_type=quant_type) +def pipeline_test(A, batch_size): + out = torch.zeros_like(A) + lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size)) + return out \ No newline at end of file diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index 0582f7fc0..8c42ddfed 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -1,7 +1,7 @@ import json import shlex import subprocess -from typing import Tuple +from typing import Tuple, Dict, Any import torch @@ -200,3 +200,121 @@ def unpack_tensor_to_dict(tensor_data): unpacked_dict = json.loads(json_str) return unpacked_dict + +class QuantState: + """container for quantization state components to work with Params4bit and similar classes""" + valid_quant_types = ('fp4', 'nf4') + valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types] + valid_qs_keys = ['absmax', 'quant_map', 'nested_absmax', 'nested_quant_map', 'quant_state', 'quant_type', + 'blocksize', 'dtype', 'shape', 'nested_blocksize', 'nested_dtype', 'nested_offset'] + + def __init__(self, absmax, shape=None, code=None, blocksize=None, quant_type=None, dtype=None, offset=None, state2=None): + self.absmax = absmax + self.shape = shape + self.code = code + self.dtype = dtype + self.blocksize = blocksize + self.quant_type = quant_type + self.offset = offset + self.state2 = state2 + self.nested = state2 is not None + + def __get_item__(self, idx): + """ + ensures compatibility with older quant state scheme with nested lists. 
+ assumes the following layout: + state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type] + state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type] + """ + if self.nested: + list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, [self.offset, self.state2], self.quant_type] + else: + list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, None, self.quant_type] + return list_repr[idx] + + @classmethod + def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> 'QuantState': + """ + unpacks components of state_dict into QuantState + where necessary, convert into strings, torch.dtype, ints, etc. + + qs_dict: based on state_dict, with only relevant keys, striped of prefixes. + + item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items. + """ + + # unpacking tensor with non-tensor components + qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)] + if not len(qs_key) and 'quant_type' not in qs_dict: + raise ValueError("Expected packed or unpacked quant_state items, found neither") + elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys: + raise ValueError(f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.") + + # unpacking minor and non-tensor quant state items if necessary + if len(qs_key) == 1: + first_qs_key = qs_key[0] + qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(first_qs_key))) + + qs_dict = {k.split('.')[-1]: v for k, v in qs_dict.items()} # strip prefixes + assert set(qs_dict.keys()).issubset(cls.valid_qs_keys) + + if 'nested_absmax' in qs_dict: + offset = torch.tensor(float(qs_dict['nested_offset'])).to(device) + state2 = cls( + absmax=qs_dict['nested_absmax'].to(device), + blocksize=qs_dict['nested_blocksize'], + code=qs_dict['nested_quant_map'].to(device), + dtype=getattr(torch, qs_dict['nested_dtype']), + ) + else: + offset, state2 = None, None + + quant_state = cls( + quant_type=qs_dict['quant_type'], + absmax=qs_dict['absmax'].to(device), + blocksize=qs_dict['blocksize'], + code=qs_dict['quant_map'].to(device), + dtype=getattr(torch, qs_dict['dtype']), + shape=torch.Size(qs_dict['shape']) if qs_dict['shape'] is not None else None, + offset=offset, + state2=state2, + ) + return quant_state + + def as_dict(self, packed=False): + """ + returns dict of tensors and strings to use in serialization via _save_to_state_dict() + param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving + """ + qs_dict = { + 'quant_type': self.quant_type, + 'absmax': self.absmax, + 'blocksize': self.blocksize, + 'quant_map': self.code, + 'dtype': str(self.dtype).strip('torch.'), + 'shape': tuple(self.shape), + } + if self.nested: + qs_dict.update({ + 'nested_absmax': self.state2.absmax, + 'nested_blocksize': self.state2.blocksize, + 'nested_quant_map': self.state2.code.clone(), # un-shared to avoid restoring it after shared tensors are removed by safetensors + 'nested_dtype': str(self.state2.dtype).strip('torch.'), + 'nested_offset': self.offset.item(), + }) + if not packed: + return qs_dict + + # packed format allows serialization of non-tensor components, critical for saving in safetensors format + qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)} + non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)} + qs_packed_dict["quant_state." 
+ "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict) + return qs_packed_dict + + def to(self, device): + # make sure the quantization state is on the right device + self.absmax = self.absmax.to(device) + if self.nested: + self.offset = self.offset.to(device) + self.state2.absmax = self.state2.absmax.to(device) + self.state2.code = self.state2.code.to(device) From 0905ad743f887ef396fedb0b364cfec04d8acd26 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Thu, 8 Feb 2024 07:32:13 -0800 Subject: [PATCH 23/39] fix import issue --- bitsandbytes/__init__.py | 4 +++- bitsandbytes/backends/__init__.py | 5 ----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 512fd2455..e7eb6af6f 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -17,7 +17,9 @@ if COMPILED_WITH_CUDA: from .optim import adam - + from .backends import register_backend, backends + from .backends.cuda import CUDABackend + register_backend("cuda", CUDABackend()) __pdoc__ = { "libbitsandbytes": False, "optim.optimizer.Optimizer8bit": False, diff --git a/bitsandbytes/backends/__init__.py b/bitsandbytes/backends/__init__.py index 0ae01a3d3..015b719cc 100644 --- a/bitsandbytes/backends/__init__.py +++ b/bitsandbytes/backends/__init__.py @@ -1,14 +1,9 @@ from typing import Dict import torch -from bitsandbytes.cextension import COMPILED_WITH_CUDA from bitsandbytes.backends.base import Backend backends: Dict[str, Backend] = {} def register_backend(backend_name: str, backend_instance: Backend): backends[backend_name.lower()] = backend_instance - -if COMPILED_WITH_CUDA: - from .cuda import CUDABackend - register_backend("cuda", CUDABackend()) From 145a8357c3063b94a15f02f12d062db6b478de1d Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Thu, 8 Feb 2024 23:33:38 +0800 Subject: [PATCH 24/39] minor clean --- bitsandbytes/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index e7eb6af6f..3f0db3536 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -17,7 +17,7 @@ if COMPILED_WITH_CUDA: from .optim import adam - from .backends import register_backend, backends + from .backends import register_backend from .backends.cuda import CUDABackend register_backend("cuda", CUDABackend()) __pdoc__ = { From d270832cb16c8d83d4a312bc569f8ae03b6cb2b3 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Mon, 12 Feb 2024 19:33:30 -0800 Subject: [PATCH 25/39] fix CI python format --- bitsandbytes/__init__.py | 2 +- bitsandbytes/backends/__init__.py | 1 + bitsandbytes/backends/cuda.py | 1 - bitsandbytes/functional.py | 10 ++++------ bitsandbytes/utils.py | 2 +- 5 files changed, 7 insertions(+), 9 deletions(-) diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 3f0db3536..c42b4a274 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -16,9 +16,9 @@ from .nn import modules if COMPILED_WITH_CUDA: - from .optim import adam from .backends import register_backend from .backends.cuda import CUDABackend + from .optim import adam register_backend("cuda", CUDABackend()) __pdoc__ = { "libbitsandbytes": False, diff --git a/bitsandbytes/backends/__init__.py b/bitsandbytes/backends/__init__.py index 015b719cc..3a33d24ca 100644 --- a/bitsandbytes/backends/__init__.py +++ b/bitsandbytes/backends/__init__.py @@ -1,4 +1,5 @@ from typing import Dict + import torch from bitsandbytes.backends.base import Backend diff --git 
a/bitsandbytes/backends/cuda.py b/bitsandbytes/backends/cuda.py index 248d1e4c1..6ba02d009 100644 --- a/bitsandbytes/backends/cuda.py +++ b/bitsandbytes/backends/cuda.py @@ -19,7 +19,6 @@ prod, quantize_blockwise, ) - from bitsandbytes.utils import QuantState from .base import Backend diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index b75eac67e..9dbd5c1f0 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -6,19 +6,17 @@ from functools import reduce # Required in Python 3 import itertools import operator -from typing import Any, Dict, Optional, Tuple +from typing import Optional, Tuple import numpy as np import torch from torch import Tensor -from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict +from bitsandbytes.backends import backends +from bitsandbytes.utils import QuantState from .cextension import COMPILED_WITH_CUDA, lib -from bitsandbytes.utils import QuantState - -from bitsandbytes.backends import backends # math.prod not compatible with python < 3.8 def prod(iterable): @@ -2064,4 +2062,4 @@ def extract_outliers(A, SA, idx): def pipeline_test(A, batch_size): out = torch.zeros_like(A) lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size)) - return out \ No newline at end of file + return out diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index 8c42ddfed..032bb31e5 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -1,7 +1,7 @@ import json import shlex import subprocess -from typing import Tuple, Dict, Any +from typing import Any, Dict, Tuple import torch From 68e785908adaaf3c1b0d06fbf0fba6ce7445df12 Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Thu, 15 Feb 2024 21:04:40 +0000 Subject: [PATCH 26/39] fix py38 vers incompatibility from other PR --- tests/helpers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/helpers.py b/tests/helpers.py index 46c6ef93d..f82a8631f 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,13 +1,13 @@ from itertools import product import random -from typing import Any +from typing import Any, List import torch test_dims_rng = random.Random(42) -def get_test_dims(min: int, max: int, *, n: int) -> list[int]: +def get_test_dims(min: int, max: int, *, n: int) -> List[int]: return [test_dims_rng.randint(min, max) for _ in range(n)] From 012b565dea120fe40a353493881058f6ea0a48b5 Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Fri, 16 Feb 2024 22:23:06 +0000 Subject: [PATCH 27/39] update pre-commit --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index edcbc9b6b..4fb5cf528 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.2.0 + rev: v0.2.1 hooks: - id: ruff args: @@ -18,6 +18,6 @@ repos: args: - --fix=lf - repo: https://github.com/crate-ci/typos - rev: v1.17.2 + rev: typos-v0.10.21 hooks: - id: typos From 8fa27f60b2f7c9fd5398a6eff8da0eafe9ed8f1e Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Fri, 16 Feb 2024 22:23:59 +0000 Subject: [PATCH 28/39] cuda.py: harmonize whitespace --- bitsandbytes/backends/cuda.py | 40 ++++++++++++++++++++++++++++------- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git 
a/bitsandbytes/backends/cuda.py b/bitsandbytes/backends/cuda.py index 6ba02d009..4b9ae4b87 100644 --- a/bitsandbytes/backends/cuda.py +++ b/bitsandbytes/backends/cuda.py @@ -122,10 +122,15 @@ def double_quant( def transform(self, A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): prev_device = pre_call(A.device) - if state is None: state = (A.shape, from_order) - else: from_order = state[1] - if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) - else: new_state = (state[0], to_order) # (shape, order) + if state is None: + state = (A.shape, from_order) + else: + from_order = state[1] + + if out is None: + out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) + else: + new_state = (state[0], to_order) # (shape, order) shape = state[0] if len(shape) == 2: @@ -141,21 +146,25 @@ def transform(self, A, to_order, from_order='row', out=None, transpose=False, st lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) else: lib.ctransform_row2col32(get_ptr(A), get_ptr(out), dim1, dim2) + elif to_order == "col_turing": if transpose: lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, dim2) else: lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2) + elif to_order == "col_ampere": if transpose: lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2) else: lib.ctransform_row2ampere(get_ptr(A), get_ptr(out), dim1, dim2) + elif to_order == "row": if from_order == "col_turing": lib.ctransform_turing2row(get_ptr(A), get_ptr(out), dim1, dim2) elif from_order == "col_ampere": lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) + else: raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}') @@ -168,6 +177,7 @@ def igemmlt(self, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): shapeB = SB[0] dimsA = len(shapeA) dimsB = len(shapeB) + assert dimsB == 2, 'Only two dimensional matrices are supported for argument B' if dimsA == 2: m = shapeA[0] @@ -204,6 +214,7 @@ def igemmlt(self, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): assert ( shapeA[-1] == shapeB[-1] ), f"Matmullt only supports A @ B^T. 
Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}" + formatB = SB[1] prev_device = A.device torch.cuda.set_device(A.device) @@ -232,6 +243,7 @@ def igemmlt(self, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): has_error = 0 ptrRowScale = get_ptr(None) is_on_gpu([A, B, out]) + if formatB == 'col_turing': if dtype == torch.int32: has_error = lib.cigemmlt_turing_32( @@ -241,6 +253,7 @@ def igemmlt(self, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): has_error = lib.cigemmlt_turing_8( ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc ) + elif formatB == "col_ampere": if dtype == torch.int32: has_error = lib.cigemmlt_ampere_32( @@ -331,10 +344,12 @@ def extract_outliers(self, A, SA, idx): ptrOut = get_ptr(out) prev_device = pre_call(A.device) + if formatA == 'col_turing': lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) elif formatA == "col_ampere": lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) + post_call(prev_device) return out @@ -362,7 +377,6 @@ def quantize_4bit( blocks += 1 if n % blocksize > 0 else 0 absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) - if out is None: mod = dtype2bytes[quant_storage] * 2 out = torch.zeros(((n+1)//mod, 1), dtype=quant_storage, device=A.device) @@ -377,18 +391,22 @@ def quantize_4bit( lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) else: lib.cquantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + elif A.dtype == torch.float16: if quant_type == 'fp4': lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) else: lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + elif A.dtype == torch.bfloat16: if quant_type == 'fp4': lib.cquantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) else: lib.cquantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + post_call(A.device) code = get_4bit_type(quant_type, device=A.device) @@ -399,14 +417,16 @@ def quantize_4bit( qabsmax, state2 = quantize_blockwise(absmax, blocksize=256) del absmax state = QuantState(absmax=qabsmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, offset=offset, state2=state2) + else: - state = QuantState(absmax=absmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, ) + state = QuantState(absmax=absmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type) return out, state def dequantize_4bit(self, A: torch.Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64, quant_type='fp4') -> torch.Tensor: if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: raise ValueError(f"The blockwise of {blocksize} is not supported. 
Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") + if quant_type not in ['fp4', 'nf4']: raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') @@ -414,11 +434,9 @@ def dequantize_4bit(self, A: torch.Tensor, quant_state: Optional[QuantState] = N assert absmax is not None and out is not None quant_state = QuantState(absmax=absmax, shape=out.shape, dtype=out.dtype, blocksize=blocksize, quant_type=quant_type) - else: absmax = quant_state.absmax - if quant_state.nested: absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) absmax += quant_state.offset @@ -431,25 +449,31 @@ def dequantize_4bit(self, A: torch.Tensor, quant_state: Optional[QuantState] = N device = pre_call(A.device) is_on_gpu([A, absmax, out]) + if out.dtype == torch.float32: if quant_state.quant_type == 'fp4': lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) else: lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + elif out.dtype == torch.float16: if quant_state.quant_type == 'fp4': lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) else: lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + elif out.dtype == torch.bfloat16: if quant_state.quant_type == 'fp4': lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) else: lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + post_call(A.device) is_transposed = (True if A.shape[0] == 1 else False) + if is_transposed: return out.t() else: return out From 2c04d4821a90f8a0bd2b1bfff0cd73e83006d7d8 Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Fri, 16 Feb 2024 22:25:02 +0000 Subject: [PATCH 29/39] delete dead code --- bitsandbytes/cextension.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index dab34982e..db9c05779 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -32,8 +32,3 @@ "8-bit optimizers, 8-bit multiplication, and CUDA GPU quantization are unavailable.") COMPILED_WITH_CUDA = False print(str(ex)) - - -# print the setup details after checking for errors so we do not print twice -#if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0': - #setup.print_log_stack() From c1846557a0388c553d13a0372a2be3e0d9720acb Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Fri, 16 Feb 2024 22:29:14 +0000 Subject: [PATCH 30/39] fix whitespace --- bitsandbytes/backends/cuda.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/bitsandbytes/backends/cuda.py b/bitsandbytes/backends/cuda.py index 4b9ae4b87..4fa1946e9 100644 --- a/bitsandbytes/backends/cuda.py +++ b/bitsandbytes/backends/cuda.py @@ -152,7 +152,7 @@ def transform(self, A, to_order, from_order='row', out=None, transpose=False, st lib.ctransform_row2turingT(get_ptr(A), get_ptr(out), dim1, 
dim2) else: lib.ctransform_row2turing(get_ptr(A), get_ptr(out), dim1, dim2) - + elif to_order == "col_ampere": if transpose: lib.ctransform_row2ampereT(get_ptr(A), get_ptr(out), dim1, dim2) @@ -349,7 +349,7 @@ def extract_outliers(self, A, SA, idx): lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) elif formatA == "col_ampere": lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) - + post_call(prev_device) return out @@ -403,7 +403,7 @@ def quantize_4bit( lib.cquantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) else: lib.cquantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) - + else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") @@ -474,6 +474,6 @@ def dequantize_4bit(self, A: torch.Tensor, quant_state: Optional[QuantState] = N post_call(A.device) is_transposed = (True if A.shape[0] == 1 else False) - + if is_transposed: return out.t() else: return out From 03b53d7eb4558f80507155b8da148c177774d483 Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Fri, 16 Feb 2024 22:29:24 +0000 Subject: [PATCH 31/39] fix typo --- csrc/kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/kernels.cu b/csrc/kernels.cu index df8488389..65aa14896 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3073,7 +3073,7 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * //// 4. do dequantization from register of B into second pair of registers //// 5. store (4) into fragment //// 6. matmul aggregate into fragment C -//// 7. aggreecate files of C into shared memory block C +//// 7. aggregate files of C into shared memory block C //// 8. sum (7) //// 9. 
write outputs to matmul output matrix //} From ba7a1620bef5231c8817218c0b24b580e3e80f25 Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Fri, 16 Feb 2024 22:31:59 +0000 Subject: [PATCH 32/39] remove exstraneous import --- bitsandbytes/backends/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/bitsandbytes/backends/__init__.py b/bitsandbytes/backends/__init__.py index 3a33d24ca..5fb2fc130 100644 --- a/bitsandbytes/backends/__init__.py +++ b/bitsandbytes/backends/__init__.py @@ -1,7 +1,5 @@ from typing import Dict -import torch - from bitsandbytes.backends.base import Backend backends: Dict[str, Backend] = {} From d162998ee5b7875e8b8bbd75780106757a467fc2 Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Sat, 17 Feb 2024 00:28:11 +0000 Subject: [PATCH 33/39] factor out ensure_backend_is_available, exc instead of assert --- bitsandbytes/backends/__init__.py | 5 +++++ bitsandbytes/functional.py | 16 ++++++++-------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/bitsandbytes/backends/__init__.py b/bitsandbytes/backends/__init__.py index 5fb2fc130..d35021b1e 100644 --- a/bitsandbytes/backends/__init__.py +++ b/bitsandbytes/backends/__init__.py @@ -6,3 +6,8 @@ def register_backend(backend_name: str, backend_instance: Backend): backends[backend_name.lower()] = backend_instance + +def ensure_backend_is_available(device_type: str): + """Check if a backend is available for the given device type.""" + if device_type.lower() not in backends: + raise NotImplementedError(f"Device backend for {device_type} is currently not supported.") diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 9dbd5c1f0..e94265e53 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -12,7 +12,7 @@ import torch from torch import Tensor -from bitsandbytes.backends import backends +from bitsandbytes.backends import backends, ensure_backend_is_available from bitsandbytes.utils import QuantState from .cextension import COMPILED_WITH_CUDA, lib @@ -834,7 +834,7 @@ def quantize_4bit( tuple(torch.Tensor, torch.Size, torch.dtype, int): The quantization state to undo the quantization. """ - assert A.device.type in backends, f"Device backend for {A.device.type} is not supported" + ensure_backend_is_available(A.device.type) return backends[A.device.type].quantize_4bit(A, absmax=absmax, out=out, blocksize=blocksize, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage) def dequantize_fp4(A: Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64) -> Tensor: @@ -870,7 +870,7 @@ def dequantize_4bit(A: Tensor, quant_state: Optional[QuantState] = None, absmax: torch.Tensor: Dequantized tensor. 
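Compared with the bare asserts it replaces, the helper gives callers an explicit error when a tensor lives on a device type with no registered backend. A sketch of the new failure mode, assuming a build where only the CUDA backend is registered:

# With only "cuda" registered, dispatching on a CPU tensor now raises a
# descriptive NotImplementedError instead of tripping an AssertionError.
import torch

import bitsandbytes.functional as F

try:
    F.quantize_4bit(torch.randn(8, 8))   # CPU tensor; no "cpu" backend registered yet
except NotImplementedError as err:
    print(err)                           # Device backend for cpu is currently not supported.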
""" - assert A.device.type in backends, f"Device backend for {A.device.type} is not supported" + ensure_backend_is_available(A.device.type) return backends[A.device.type].dequantize_4bit(A, quant_state=quant_state, absmax=absmax, out=out, blocksize=blocksize, quant_type=quant_type) @@ -1641,7 +1641,7 @@ def batched_igemm( def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): - assert A.device.type in backends, f"Device backend for {A.device.type} is not supported" + ensure_backend_is_available(A.device.type) return backends[A.device.type].igemmlt(A, B, SA, SB, out=out, Sout=Sout, dtype=dtype) @@ -1655,7 +1655,7 @@ def mm_dequant( new_col_stats=None, bias=None ): - assert A.device.type in backends, f"Device backend for {A.device.type} is not supported" + ensure_backend_is_available(A.device.type) return backends[A.device.type].mm_dequant(A, quant_state, row_stats, col_stats, out=out, new_row_stats=new_row_stats, new_col_stats=new_col_stats, bias=bias) @@ -1793,12 +1793,12 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half): def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): - assert A.device.type in backends, f"Device backend for {A.device.type} is not supported" + ensure_backend_is_available(A.device.type) return backends[A.device.type].double_quant(A, col_stats=col_stats, row_stats=row_stats, out_col=out_col, out_row=out_row, threshold=threshold) def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): - assert A.device.type in backends, f"Device backend for {A.device.type} is not supported" + ensure_backend_is_available(A.device.type) return backends[A.device.type].transform(A, to_order, from_order=from_order, out=out, transpose=transpose, state=state, ld=ld) @@ -2055,7 +2055,7 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half): def extract_outliers(A, SA, idx): - assert A.device.type in backends, f"Device backend for {A.device.type} is not supported" + ensure_backend_is_available(A.device.type) return backends[A.device.type].extract_outliers(A, SA, idx) From 2cd9718cdff2fc5da4a19ff7c912426b93f8f094 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Wed, 21 Feb 2024 16:56:38 +0800 Subject: [PATCH 34/39] Remove minor device filter to avoid confusion --- bitsandbytes/autograd/_functions.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index edf330f14..6cbb6efd9 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -230,10 +230,6 @@ def supports_igemmlt(device: torch.device) -> bool: nvidia16_models = ('GTX 1630', 'GTX 1650', 'GTX 1660') # https://en.wikipedia.org/wiki/GeForce_16_series if any(model_name in device_name for model_name in nvidia16_models): return False # these devices are technically cuda 7.5-capable, but they lack tensor cores - if device.type == "cpu": - #TODO: will return True once CPU backend upstream the supports - return False - return True @@ -568,7 +564,7 @@ def matmul( def matmul_4bit(A: torch.Tensor, B: torch.Tensor, quant_state: F.QuantState, out: Optional[torch.Tensor] = None, bias=None): assert quant_state is not None - if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type == "cuda": + if A.numel() == A.shape[-1] and A.requires_grad == False: if A.shape[-1] % quant_state.blocksize != 0: warn(f'Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels 
are not supported for these (slow). Matrix input size found: {A.shape}') return MatMul4Bit.apply(A, B, out, bias, quant_state) From adfb5e20d57aaaba5cda7c94d7f24f0f77f4a5f1 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Wed, 27 Mar 2024 20:21:48 -0700 Subject: [PATCH 35/39] clean up device setup --- bitsandbytes/device_setup/__init__.py | 0 bitsandbytes/device_setup/cuda/__init__.py | 0 bitsandbytes/device_setup/cuda/env_vars.py | 53 --- bitsandbytes/device_setup/cuda/main.py | 393 --------------------- 4 files changed, 446 deletions(-) delete mode 100644 bitsandbytes/device_setup/__init__.py delete mode 100644 bitsandbytes/device_setup/cuda/__init__.py delete mode 100644 bitsandbytes/device_setup/cuda/env_vars.py delete mode 100644 bitsandbytes/device_setup/cuda/main.py diff --git a/bitsandbytes/device_setup/__init__.py b/bitsandbytes/device_setup/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bitsandbytes/device_setup/cuda/__init__.py b/bitsandbytes/device_setup/cuda/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/bitsandbytes/device_setup/cuda/env_vars.py b/bitsandbytes/device_setup/cuda/env_vars.py deleted file mode 100644 index 4b2549653..000000000 --- a/bitsandbytes/device_setup/cuda/env_vars.py +++ /dev/null @@ -1,53 +0,0 @@ -import os -from typing import Dict - - -def to_be_ignored(env_var: str, value: str) -> bool: - ignorable = { - "PWD", # PWD: this is how the shell keeps track of the current working dir - "OLDPWD", - "SSH_AUTH_SOCK", # SSH stuff, therefore unrelated - "SSH_TTY", - "GOOGLE_VM_CONFIG_LOCK_FILE", # GCP: requires elevated permissions, causing problems in VMs and Jupyter notebooks - "HOME", # Linux shell default - "TMUX", # Terminal Multiplexer - "XDG_DATA_DIRS", # XDG: Desktop environment stuff - "XDG_GREETER_DATA_DIR", # XDG: Desktop environment stuff - "XDG_RUNTIME_DIR", - "MAIL", # something related to emails - "SHELL", # binary for currently invoked shell - "DBUS_SESSION_BUS_ADDRESS", # hardware related - "PATH", # this is for finding binaries, not libraries - "LESSOPEN", # related to the `less` command - "LESSCLOSE", - "_", # current Python interpreter - } - return env_var in ignorable - - -def might_contain_a_path(candidate: str) -> bool: - return os.sep in candidate - - -def is_active_conda_env(env_var: str) -> bool: - return "CONDA_PREFIX" == env_var - - -def is_other_conda_env_var(env_var: str) -> bool: - return "CONDA" in env_var - - -def is_relevant_candidate_env_var(env_var: str, value: str) -> bool: - return is_active_conda_env(env_var) or ( - might_contain_a_path(value) and not - is_other_conda_env_var(env_var) and not - to_be_ignored(env_var, value) - ) - - -def get_potentially_lib_path_containing_env_vars() -> Dict[str, str]: - return { - env_var: value - for env_var, value in os.environ.items() - if is_relevant_candidate_env_var(env_var, value) - } diff --git a/bitsandbytes/device_setup/cuda/main.py b/bitsandbytes/device_setup/cuda/main.py deleted file mode 100644 index 36224d2f9..000000000 --- a/bitsandbytes/device_setup/cuda/main.py +++ /dev/null @@ -1,393 +0,0 @@ -""" -extract factors the build is dependent on: -[X] compute capability - [ ] TODO: Q - What if we have multiple GPUs of different makes? 
-- CUDA version -- Software: - - CPU-only: only CPU quantization functions (no optimizer, no matrix multiply) - - CuBLAS-LT: full-build 8-bit optimizer - - no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`) - -evaluation: - - if paths faulty, return meaningful error - - else: - - determine CUDA version - - determine capabilities - - based on that set the default path -""" - -import ctypes as ct -import errno -import os -from pathlib import Path -import platform -from typing import Set, Union -from warnings import warn - -import torch - -from .env_vars import get_potentially_lib_path_containing_env_vars - -DYNAMIC_LIBRARY_SUFFIX = { "Darwin": ".dylib", "Windows": ".dll", "Linux": ".so"}.get(platform.system(), ".so") -if platform.system() == "Windows": # Windows - CUDA_RUNTIME_LIBS = ["nvcuda.dll"] -else: # Linux or other - # these are the most common libs names - # libcudart.so is missing by default for a conda install with PyTorch 2.0 and instead - # we have libcudart.so.11.0 which causes a lot of errors before - # not sure if libcudart.so.12.0 exists in pytorch installs, but it does not hurt - CUDA_RUNTIME_LIBS = ["libcudart.so", "libcudart.so.11.0", "libcudart.so.12.0", "libcudart.so.12.1", "libcudart.so.12.2"] - - -class CUDASetup: - _instance = None - - def __init__(self): - raise RuntimeError("Call get_instance() instead") - - def generate_instructions(self): - if getattr(self, 'error', False): return - print(self.error) - self.error = True - if not self.cuda_available: - self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA library was not detected or CUDA not installed.') - self.add_log_entry('CUDA SETUP: Solution 1): Your paths are probably not up-to-date. You can update them via: sudo ldconfig.') - self.add_log_entry('CUDA SETUP: Solution 2): If you do not have sudo rights, you can do the following:') - self.add_log_entry('CUDA SETUP: Solution 2a): Find the cuda library via: find / -name libcuda.so 2>/dev/null') - self.add_log_entry('CUDA SETUP: Solution 2b): Once the library is found add it to the LD_LIBRARY_PATH: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:FOUND_PATH_FROM_2a') - self.add_log_entry('CUDA SETUP: Solution 2c): For a permanent solution add the export from 2b into your .bashrc file, located at ~/.bashrc') - self.add_log_entry('CUDA SETUP: Solution 3): For a missing CUDA runtime library (libcudart.so), use `find / -name libcudart.so* and follow with step (2b)') - return - - if self.cudart_path is None: - self.add_log_entry('CUDA SETUP: Problem: The main issue seems to be that the main CUDA runtime library was not detected.') - self.add_log_entry('CUDA SETUP: Solution 1: To solve the issue the libcudart.so location needs to be added to the LD_LIBRARY_PATH variable') - self.add_log_entry('CUDA SETUP: Solution 1a): Find the cuda runtime library via: find / -name libcudart.so 2>/dev/null') - self.add_log_entry('CUDA SETUP: Solution 1b): Once the library is found add it to the LD_LIBRARY_PATH: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:FOUND_PATH_FROM_1a') - self.add_log_entry('CUDA SETUP: Solution 1c): For a permanent solution add the export from 1b into your .bashrc file, located at ~/.bashrc') - self.add_log_entry('CUDA SETUP: Solution 2: If no library was found in step 1a) you need to install CUDA.') - self.add_log_entry('CUDA SETUP: Solution 2a): Download CUDA install script: wget https://raw.githubusercontent.com/TimDettmers/bitsandbytes/main/cuda_install.sh') - self.add_log_entry('CUDA SETUP: Solution 2b): Install desired CUDA 
version to desired location. The syntax is bash cuda_install.sh CUDA_VERSION PATH_TO_INSTALL_INTO.') - self.add_log_entry('CUDA SETUP: Solution 2b): For example, "bash cuda_install.sh 113 ~/local/" will download CUDA 11.3 and install into the folder ~/local') - - return - - make_cmd = f'CUDA_VERSION={self.cuda_version_string}' - if len(self.cuda_version_string) < 3: - make_cmd += ' make cuda92' - elif self.cuda_version_string == '110': - make_cmd += ' make cuda110' - elif self.cuda_version_string[:2] == '11' and int(self.cuda_version_string[2]) > 0: - make_cmd += ' make cuda11x' - elif self.cuda_version_string[:2] == '12' and 1 >= int(self.cuda_version_string[2]) >= 0: - make_cmd += ' make cuda12x' - elif self.cuda_version_string == '100': - self.add_log_entry('CUDA SETUP: CUDA 10.0 not supported. Please use a different CUDA version.') - self.add_log_entry('CUDA SETUP: Before you try again running bitsandbytes, make sure old CUDA 10.0 versions are uninstalled and removed from $LD_LIBRARY_PATH variables.') - return - - - has_cublaslt = is_cublasLt_compatible(self.cc) - if not has_cublaslt: - make_cmd += '_nomatmul' - - self.add_log_entry('CUDA SETUP: Something unexpected happened. Please compile from source:') - self.add_log_entry('git clone https://github.com/TimDettmers/bitsandbytes.git') - self.add_log_entry('cd bitsandbytes') - self.add_log_entry(make_cmd) - self.add_log_entry('python setup.py install') - - def initialize(self): - if not getattr(self, 'initialized', False): - self.has_printed = False - self.lib = None - self.initialized = False - self.error = False - - def manual_override(self): - if not torch.cuda.is_available(): - return - override_value = os.environ.get('BNB_CUDA_VERSION') - if not override_value: - return - - binary_name_stem, _, binary_name_ext = self.binary_name.rpartition(".") - # `binary_name_stem` will now be e.g. `/foo/bar/libbitsandbytes_cuda118`; - # let's remove any trailing numbers: - binary_name_stem = binary_name_stem.rstrip("0123456789") - # `binary_name_stem` will now be e.g. `/foo/bar/libbitsandbytes_cuda`; - # let's tack the new version number and the original extension back on. 
- self.binary_name = f"{binary_name_stem}{override_value}.{binary_name_ext}" - - warn( - f'\n\n{"=" * 80}\n' - 'WARNING: Manual override via BNB_CUDA_VERSION env variable detected!\n' - 'BNB_CUDA_VERSION=XXX can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.\n' - 'If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=\n' - 'If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH\n' - 'For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH: Set[Path]: - return {Path(ld_path) for ld_path in paths_list_candidate.split(os.pathsep) if ld_path} - - -def remove_non_existent_dirs(candidate_paths: Set[Path]) -> Set[Path]: - existent_directories: Set[Path] = set() - for path in candidate_paths: - try: - if path.exists(): - existent_directories.add(path) - except PermissionError: - # Handle the PermissionError first as it is a subtype of OSError - # https://docs.python.org/3/library/exceptions.html#exception-hierarchy - pass - except OSError as exc: - if exc.errno != errno.ENAMETOOLONG: - raise exc - - non_existent_directories: Set[Path] = candidate_paths - existent_directories - if non_existent_directories: - CUDASetup.get_instance().add_log_entry( - f"The following directories listed in your path were found to be non-existent: {non_existent_directories}", - is_warning=False, - ) - - return existent_directories - - -def get_cuda_runtime_lib_paths(candidate_paths: Set[Path]) -> Set[Path]: - paths = set() - for libname in CUDA_RUNTIME_LIBS: - for path in candidate_paths: - try: - if (path / libname).is_file(): - paths.add(path / libname) - except PermissionError: - pass - return paths - - -def resolve_paths_list(paths_list_candidate: str) -> Set[Path]: - """ - Searches a given environmental var for the CUDA runtime library, - i.e. `libcudart.so`. - """ - return remove_non_existent_dirs(extract_candidate_paths(paths_list_candidate)) - - -def find_cuda_lib_in(paths_list_candidate: str) -> Set[Path]: - return get_cuda_runtime_lib_paths( - resolve_paths_list(paths_list_candidate) - ) - - -def warn_in_case_of_duplicates(results_paths: Set[Path]) -> None: - if len(results_paths) > 1: - warning_msg = ( - f"Found duplicate {CUDA_RUNTIME_LIBS} files: {results_paths}.. " - "We select the PyTorch default libcudart.so, which is {torch.version.cuda}," - "but this might mismatch with the CUDA version that is needed for bitsandbytes." - "To override this behavior set the BNB_CUDA_VERSION= environmental variable" - "For example, if you want to use the CUDA version 122" - "BNB_CUDA_VERSION=122 python ..." - "OR set the environmental variable in your .bashrc: export BNB_CUDA_VERSION=122" - "In the case of a manual override, make sure you set the LD_LIBRARY_PATH, e.g." - "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2") - CUDASetup.get_instance().add_log_entry(warning_msg, is_warning=True) - - -def determine_cuda_runtime_lib_path() -> Union[Path, None]: - """ - Searches for a cuda installations, in the following order of priority: - 1. active conda env - 2. LD_LIBRARY_PATH - 3. any other env vars, while ignoring those that - - are known to be unrelated (see `bnb.device_setup.cuda.env_vars.to_be_ignored`) - - don't contain the path separator `/` - - If multiple libraries are found in part 3, we optimistically try one, - while giving a warning message. 
- """ - candidate_env_vars = get_potentially_lib_path_containing_env_vars() - - cuda_runtime_libs = set() - if "CONDA_PREFIX" in candidate_env_vars: - conda_libs_path = Path(candidate_env_vars["CONDA_PREFIX"]) / "lib" - - conda_cuda_libs = find_cuda_lib_in(str(conda_libs_path)) - warn_in_case_of_duplicates(conda_cuda_libs) - - if conda_cuda_libs: - cuda_runtime_libs.update(conda_cuda_libs) - - CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["CONDA_PREFIX"]} did not contain ' - f'{CUDA_RUNTIME_LIBS} as expected! Searching further paths...', is_warning=True) - - if "LD_LIBRARY_PATH" in candidate_env_vars: - lib_ld_cuda_libs = find_cuda_lib_in(candidate_env_vars["LD_LIBRARY_PATH"]) - - if lib_ld_cuda_libs: - cuda_runtime_libs.update(lib_ld_cuda_libs) - warn_in_case_of_duplicates(lib_ld_cuda_libs) - - CUDASetup.get_instance().add_log_entry(f'{candidate_env_vars["LD_LIBRARY_PATH"]} did not contain ' - f'{CUDA_RUNTIME_LIBS} as expected! Searching further paths...', is_warning=True) - - remaining_candidate_env_vars = { - env_var: value for env_var, value in candidate_env_vars.items() - if env_var not in {"CONDA_PREFIX", "LD_LIBRARY_PATH"} - } - - cuda_runtime_libs = set() - for env_var, value in remaining_candidate_env_vars.items(): - cuda_runtime_libs.update(find_cuda_lib_in(value)) - - if len(cuda_runtime_libs) == 0: - CUDASetup.get_instance().add_log_entry('CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching in backup paths...') - cuda_runtime_libs.update(find_cuda_lib_in('/usr/local/cuda/lib64')) - - warn_in_case_of_duplicates(cuda_runtime_libs) - - cuda_setup = CUDASetup.get_instance() - cuda_setup.add_log_entry(f'DEBUG: Possible options found for libcudart.so: {cuda_runtime_libs}') - - return next(iter(cuda_runtime_libs)) if cuda_runtime_libs else None - - -# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION -def get_cuda_version(): - major, minor = map(int, torch.version.cuda.split(".")) - - if major < 11: - CUDASetup.get_instance().add_log_entry('CUDA SETUP: CUDA version lower than 11 are currently not supported for LLM.int8(). You will be only to use 8-bit optimizers and quantization routines!!') - - return f'{major}{minor}' - -def get_compute_capabilities(): - ccs = [] - for i in range(torch.cuda.device_count()): - cc_major, cc_minor = torch.cuda.get_device_capability(torch.cuda.device(i)) - ccs.append(f"{cc_major}.{cc_minor}") - - ccs.sort(key=lambda v: tuple(map(int, str(v).split(".")))) - - return ccs - - -def evaluate_cuda_setup(): - cuda_setup = CUDASetup.get_instance() - if 'BITSANDBYTES_NOWELCOME' not in os.environ or str(os.environ['BITSANDBYTES_NOWELCOME']) == '0': - cuda_setup.add_log_entry('') - cuda_setup.add_log_entry('='*35 + 'BUG REPORT' + '='*35) - cuda_setup.add_log_entry(('Welcome to bitsandbytes. 
For bug reports, please run\n\npython -m bitsandbytes\n\n'), - ('and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues')) - cuda_setup.add_log_entry('='*80) - - if not torch.cuda.is_available(): - return f'libbitsandbytes_cpu{DYNAMIC_LIBRARY_SUFFIX}', None, None, None - - cudart_path = determine_cuda_runtime_lib_path() - cc = get_compute_capabilities()[-1] # we take the highest capability - cuda_version_string = get_cuda_version() - - cuda_setup.add_log_entry(f"CUDA SETUP: PyTorch settings found: CUDA_VERSION={cuda_version_string}, Highest Compute Capability: {cc}.") - cuda_setup.add_log_entry( - "CUDA SETUP: To manually override the PyTorch CUDA version please see:" - "https://github.com/TimDettmers/bitsandbytes/blob/main/how_to_use_nonpytorch_cuda.md" - ) - - - # 7.5 is the minimum CC vor cublaslt - has_cublaslt = is_cublasLt_compatible(cc) - - # TODO: - # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible) - # (2) Multiple CUDA versions installed - - # we use ls -l instead of nvcc to determine the cuda version - # since most installations will have the libcudart.so installed, but not the compiler - - binary_name = f"libbitsandbytes_cuda{cuda_version_string}" - if not has_cublaslt: - # if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt - binary_name += "_nocublaslt" - - binary_name = f"{binary_name}{DYNAMIC_LIBRARY_SUFFIX}" - - return binary_name, cudart_path, cc, cuda_version_string From 6f08879a2bf2d094f75013a3d0e791c662604240 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Thu, 28 Mar 2024 11:25:13 +0800 Subject: [PATCH 36/39] clean --- bitsandbytes/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index ec25eb0bc..0229e59e2 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -1,7 +1,7 @@ import json import shlex import subprocess -from typing import Any, Dict, Tuple +from typing import Tuple import torch From a9e454885a6e6999bfe6960bd40f7abf843da6bd Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Wed, 27 Mar 2024 20:28:30 -0700 Subject: [PATCH 37/39] fix utils --- bitsandbytes/functional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 8a0c7dbae..a80e56011 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -6,14 +6,14 @@ from functools import reduce # Required in Python 3 import itertools import operator -from typing import Optional, Tuple +from typing import Any, Dict, Optional, Tuple import numpy as np import torch from torch import Tensor from bitsandbytes.backends import backends, ensure_backend_is_available - +from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict from .cextension import lib From 84f67d260ca7bc59113419dcdebc6ea729f29129 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Wed, 27 Mar 2024 20:35:28 -0700 Subject: [PATCH 38/39] link QuantState in F. 
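This patch moves `QuantState` out of `bitsandbytes/functional.py` into `bitsandbytes/utils.py` and re-exports it so existing `F.QuantState` references keep working. The snippet below is a small illustrative round-trip through the packed serialization path defined on the class in the diff that follows; it assumes a working bitsandbytes install and builds the state from dummy tensors rather than a real quantize_4bit output.

import torch

import bitsandbytes.functional as F  # QuantState stays reachable as F.QuantState

# Hand-built state with dummy values; a real one is produced by quantize_4bit.
qs = F.QuantState(
    absmax=torch.tensor([0.5, 1.0]),
    shape=(2, 4),
    code=torch.linspace(-1.0, 1.0, 16),
    blocksize=64,
    quant_type="nf4",
    dtype=torch.float16,
)

# Pack the non-tensor components into a single tensor, as done for safetensors saving...
packed = qs.as_dict(packed=True)

# ...and rebuild the state from that dict, as a checkpoint loader would.
restored = F.QuantState.from_dict(packed, device=torch.device("cpu"))
assert restored == qs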
--- bitsandbytes/functional.py | 182 +------------------------------------ bitsandbytes/utils.py | 177 +++++++++++++++++++++++++++++++++++- 2 files changed, 180 insertions(+), 179 deletions(-) diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index a80e56011..38459981b 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -6,14 +6,14 @@ from functools import reduce # Required in Python 3 import itertools import operator -from typing import Any, Dict, Optional, Tuple +from typing import Optional, Tuple import numpy as np import torch from torch import Tensor from bitsandbytes.backends import backends, ensure_backend_is_available -from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict +from bitsandbytes.utils import QuantState from .cextension import lib @@ -617,182 +617,8 @@ def estimate_quantiles( return out - -class QuantState: - """container for quantization state components to work with Params4bit and similar classes""" - - valid_quant_types = ("fp4", "nf4") - valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types] - valid_qs_keys = [ - "absmax", - "quant_map", - "nested_absmax", - "nested_quant_map", - "quant_state", - "quant_type", - "blocksize", - "dtype", - "shape", - "nested_blocksize", - "nested_dtype", - "nested_offset", - ] - - def __init__( - self, - absmax, - shape=None, - code=None, - blocksize=None, - quant_type=None, - dtype=None, - offset=None, - state2=None, - ): - self.absmax = absmax - self.shape = shape - self.code = code - self.dtype = dtype - self.blocksize = blocksize - self.quant_type = quant_type - self.offset = offset - self.state2 = state2 - self.nested = state2 is not None - - def __get_item__(self, idx): - """ - ensures compatibility with older quant state scheme with nested lists. - assumes the following layout: - state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type] - state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type] - """ - if self.nested: - list_repr = [ - self.absmax, - self.shape, - self.dtype, - self.blocksize, - [self.offset, self.state2], - self.quant_type, - ] - else: - list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, None, self.quant_type] - return list_repr[idx] - - @classmethod - def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> "QuantState": - """ - unpacks components of state_dict into QuantState - where necessary, convert into strings, torch.dtype, ints, etc. - - qs_dict: based on state_dict, with only relevant keys, striped of prefixes. - - item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items. 
- """ - - # unpacking tensor with non-tensor components - qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)] - if not len(qs_key) and "quant_type" not in qs_dict: - raise ValueError("Expected packed or unpacked quant_state items, found neither") - elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys: - raise ValueError( - f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.", - ) - - # unpacking minor and non-tensor quant state items if necessary - if len(qs_key) == 1: - first_qs_key = qs_key[0] - qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(first_qs_key))) - - qs_dict = {k.split(".")[-1]: v for k, v in qs_dict.items()} # strip prefixes - assert set(qs_dict.keys()).issubset(cls.valid_qs_keys) - - if "nested_absmax" in qs_dict: - offset = torch.tensor(float(qs_dict["nested_offset"])).to(device) - state2 = cls( - absmax=qs_dict["nested_absmax"].to(device), - blocksize=qs_dict["nested_blocksize"], - code=qs_dict["nested_quant_map"].to(device), - dtype=getattr(torch, qs_dict["nested_dtype"]), - ) - else: - offset, state2 = None, None - - quant_state = cls( - quant_type=qs_dict["quant_type"], - absmax=qs_dict["absmax"].to(device), - blocksize=qs_dict["blocksize"], - code=qs_dict["quant_map"].to(device), - dtype=getattr(torch, qs_dict["dtype"]), - shape=torch.Size(qs_dict["shape"]) if qs_dict["shape"] is not None else None, - offset=offset, - state2=state2, - ) - return quant_state - - def as_dict(self, packed=False): - """ - returns dict of tensors and strings to use in serialization via _save_to_state_dict() - param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving - """ - qs_dict = { - "quant_type": self.quant_type, - "absmax": self.absmax, - "blocksize": self.blocksize, - "quant_map": self.code, - "dtype": str(self.dtype).strip("torch."), - "shape": tuple(self.shape), - } - if self.nested: - qs_dict.update( - { - "nested_absmax": self.state2.absmax, - "nested_blocksize": self.state2.blocksize, - "nested_quant_map": self.state2.code.clone(), # un-shared to avoid restoring it after shared tensors are removed by safetensors - "nested_dtype": str(self.state2.dtype).strip("torch."), - "nested_offset": self.offset.item(), - }, - ) - if not packed: - return qs_dict - - # packed format allows serialization of non-tensor components, critical for saving in safetensors format - qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)} - non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)} - qs_packed_dict["quant_state." 
+ "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict) - return qs_packed_dict - - def to(self, device): - # make sure the quantization state is on the right device - self.absmax = self.absmax.to(device) - if self.nested: - self.offset = self.offset.to(device) - self.state2.absmax = self.state2.absmax.to(device) - self.state2.code = self.state2.code.to(device) - - def __eq__(self, other): - if not isinstance(other, QuantState): - return False - - return ( - torch.allclose(self.absmax, other.absmax, atol=1e-6) - and self.shape == other.shape - and torch.allclose(self.code, other.code, atol=1e-6) - and self.dtype == other.dtype - and self.blocksize == other.blocksize - and self.quant_type == other.quant_type - and ( - self.offset == other.offset - if self.offset is not None and other.offset is not None - else self.offset is other.offset - ) - and ( - self.state2 == other.state2 - if self.state2 is not None and other.state2 is not None - else self.state2 is other.state2 - ) - ) - +# maintain the compatibility as F.QuantState +QuantState = QuantState def quantize_blockwise( A: Tensor, diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index 0229e59e2..29a5cfea3 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -1,7 +1,7 @@ import json import shlex import subprocess -from typing import Tuple +from typing import Any, Dict, Tuple import torch @@ -198,3 +198,178 @@ def unpack_tensor_to_dict(tensor_data): unpacked_dict = json.loads(json_str) return unpacked_dict + +class QuantState: + """container for quantization state components to work with Params4bit and similar classes""" + + valid_quant_types = ("fp4", "nf4") + valid_qs_type_keys = [f"bitsandbytes__{x}" for x in valid_quant_types] + valid_qs_keys = [ + "absmax", + "quant_map", + "nested_absmax", + "nested_quant_map", + "quant_state", + "quant_type", + "blocksize", + "dtype", + "shape", + "nested_blocksize", + "nested_dtype", + "nested_offset", + ] + + def __init__( + self, + absmax, + shape=None, + code=None, + blocksize=None, + quant_type=None, + dtype=None, + offset=None, + state2=None, + ): + self.absmax = absmax + self.shape = shape + self.code = code + self.dtype = dtype + self.blocksize = blocksize + self.quant_type = quant_type + self.offset = offset + self.state2 = state2 + self.nested = state2 is not None + + def __get_item__(self, idx): + """ + ensures compatibility with older quant state scheme with nested lists. + assumes the following layout: + state = [qabsmax, input_shape, A.dtype, blocksize, [offset, state2], quant_type] + state2 = [absmax, input_shape, A.dtype, blocksize, None, quant_type] + """ + if self.nested: + list_repr = [ + self.absmax, + self.shape, + self.dtype, + self.blocksize, + [self.offset, self.state2], + self.quant_type, + ] + else: + list_repr = [self.absmax, self.shape, self.dtype, self.blocksize, None, self.quant_type] + return list_repr[idx] + + @classmethod + def from_dict(cls, qs_dict: Dict[str, Any], device: torch.device) -> "QuantState": + """ + unpacks components of state_dict into QuantState + where necessary, convert into strings, torch.dtype, ints, etc. + + qs_dict: based on state_dict, with only relevant keys, striped of prefixes. + + item with key `quant_state.bitsandbytes__[nf4/fp4]` may contain minor and non-tensor quant state items. 
+ """ + + # unpacking tensor with non-tensor components + qs_key = [k for k, v in qs_dict.items() if "quant_state" in k and isinstance(v, torch.Tensor)] + if not len(qs_key) and "quant_type" not in qs_dict: + raise ValueError("Expected packed or unpacked quant_state items, found neither") + elif len(qs_key) != 1 or qs_key[0].split(".")[-1] not in cls.valid_qs_type_keys: + raise ValueError( + f"There should be exactly one `quant_state` item with ending from {cls.valid_qs_type_keys}.\nDetected {qs_key}.", + ) + + # unpacking minor and non-tensor quant state items if necessary + if len(qs_key) == 1: + first_qs_key = qs_key[0] + qs_dict.update(unpack_tensor_to_dict(qs_dict.pop(first_qs_key))) + + qs_dict = {k.split(".")[-1]: v for k, v in qs_dict.items()} # strip prefixes + assert set(qs_dict.keys()).issubset(cls.valid_qs_keys) + + if "nested_absmax" in qs_dict: + offset = torch.tensor(float(qs_dict["nested_offset"])).to(device) + state2 = cls( + absmax=qs_dict["nested_absmax"].to(device), + blocksize=qs_dict["nested_blocksize"], + code=qs_dict["nested_quant_map"].to(device), + dtype=getattr(torch, qs_dict["nested_dtype"]), + ) + else: + offset, state2 = None, None + + quant_state = cls( + quant_type=qs_dict["quant_type"], + absmax=qs_dict["absmax"].to(device), + blocksize=qs_dict["blocksize"], + code=qs_dict["quant_map"].to(device), + dtype=getattr(torch, qs_dict["dtype"]), + shape=torch.Size(qs_dict["shape"]) if qs_dict["shape"] is not None else None, + offset=offset, + state2=state2, + ) + return quant_state + + def as_dict(self, packed=False): + """ + returns dict of tensors and strings to use in serialization via _save_to_state_dict() + param: packed -- returns dict[str, torch.Tensor] for state_dict fit for safetensors saving + """ + qs_dict = { + "quant_type": self.quant_type, + "absmax": self.absmax, + "blocksize": self.blocksize, + "quant_map": self.code, + "dtype": str(self.dtype).strip("torch."), + "shape": tuple(self.shape), + } + if self.nested: + qs_dict.update( + { + "nested_absmax": self.state2.absmax, + "nested_blocksize": self.state2.blocksize, + "nested_quant_map": self.state2.code.clone(), # un-shared to avoid restoring it after shared tensors are removed by safetensors + "nested_dtype": str(self.state2.dtype).strip("torch."), + "nested_offset": self.offset.item(), + }, + ) + if not packed: + return qs_dict + + # packed format allows serialization of non-tensor components, critical for saving in safetensors format + qs_packed_dict = {k: v for k, v in qs_dict.items() if isinstance(v, torch.Tensor)} + non_tensor_dict = {k: v for k, v in qs_dict.items() if not isinstance(v, torch.Tensor)} + qs_packed_dict["quant_state." 
+ "bitsandbytes__" + self.quant_type] = pack_dict_to_tensor(non_tensor_dict) + return qs_packed_dict + + def to(self, device): + # make sure the quantization state is on the right device + self.absmax = self.absmax.to(device) + if self.nested: + self.offset = self.offset.to(device) + self.state2.absmax = self.state2.absmax.to(device) + self.state2.code = self.state2.code.to(device) + + def __eq__(self, other): + if not isinstance(other, QuantState): + return False + + return ( + torch.allclose(self.absmax, other.absmax, atol=1e-6) + and self.shape == other.shape + and torch.allclose(self.code, other.code, atol=1e-6) + and self.dtype == other.dtype + and self.blocksize == other.blocksize + and self.quant_type == other.quant_type + and ( + self.offset == other.offset + if self.offset is not None and other.offset is not None + else self.offset is other.offset + ) + and ( + self.state2 == other.state2 + if self.state2 is not None and other.state2 is not None + else self.state2 is other.state2 + ) + ) From 9ff6c638ef5bb1e07e2d00c5327e018d5d05e2f1 Mon Sep 17 00:00:00 2001 From: Titus von Koeller <9048635+Titus-von-Koeller@users.noreply.github.com> Date: Wed, 3 Apr 2024 09:58:31 +0000 Subject: [PATCH 39/39] pre-commit run --all-files --- bitsandbytes/__init__.py | 3 +- bitsandbytes/backends/__init__.py | 2 + bitsandbytes/backends/cuda.py | 245 ++++++++++++++++++------------ bitsandbytes/functional.py | 50 +++--- bitsandbytes/utils.py | 1 + install_cuda.py | 8 +- 6 files changed, 187 insertions(+), 122 deletions(-) diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index c3d6a19e7..019a4f6ab 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. from . import research, utils -from .cextension import lib from .autograd._functions import ( MatmulLtState, bmm_cublas, @@ -13,12 +12,14 @@ matmul_cublas, mm_cublas, ) +from .cextension import lib from .nn import modules if lib and lib.compiled_with_cuda: from .backends import register_backend from .backends.cuda import CUDABackend from .optim import adam + register_backend("cuda", CUDABackend()) __pdoc__ = { "libbitsandbytes": False, diff --git a/bitsandbytes/backends/__init__.py b/bitsandbytes/backends/__init__.py index d35021b1e..30f08073a 100644 --- a/bitsandbytes/backends/__init__.py +++ b/bitsandbytes/backends/__init__.py @@ -4,9 +4,11 @@ backends: Dict[str, Backend] = {} + def register_backend(backend_name: str, backend_instance: Backend): backends[backend_name.lower()] = backend_instance + def ensure_backend_is_available(device_type: str): """Check if a backend is available for the given device type.""" if device_type.lower() not in backends: diff --git a/bitsandbytes/backends/cuda.py b/bitsandbytes/backends/cuda.py index 4fa1946e9..c76bcaebd 100644 --- a/bitsandbytes/backends/cuda.py +++ b/bitsandbytes/backends/cuda.py @@ -25,9 +25,7 @@ class CUDABackend(Backend): - def double_quant( - self, A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0 - ): + def double_quant(self, A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): device = A.device assert A.dtype == torch.half assert device.type == "cuda" @@ -40,9 +38,7 @@ def double_quant( rows = A.shape[0] if row_stats is None or col_stats is None: - row_stats, col_stats, nnz_row_ptr = get_colrow_absmax( - A, threshold=threshold - ) + row_stats, col_stats, nnz_row_ptr = get_colrow_absmax(A, threshold=threshold) if out_col is None: out_col = 
torch.zeros(A.shape, device=device, dtype=torch.int8) @@ -60,9 +56,7 @@ def double_quant( if threshold > 0.0: nnz = nnz_row_ptr[-1].item() if nnz > 0: - coo_tensor = coo_zeros( - A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device - ) + coo_tensor = coo_zeros(A.shape[0], A.shape[1], nnz_row_ptr[-1].item(), device) ptrRowIdx = get_ptr(coo_tensor.rowidx) ptrColIdx = get_ptr(coo_tensor.colidx) ptrVal = get_ptr(coo_tensor.values) @@ -120,7 +114,7 @@ def double_quant( return out_row, out_col, row_stats, col_stats, coo_tensor - def transform(self, A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): + def transform(self, A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None): prev_device = pre_call(A.device) if state is None: state = (A.shape, from_order) @@ -130,7 +124,7 @@ def transform(self, A, to_order, from_order='row', out=None, transpose=False, st if out is None: out, new_state = get_transform_buffer(state[0], A.dtype, A.device, to_order, state[1], transpose) else: - new_state = (state[0], to_order) # (shape, order) + new_state = (state[0], to_order) # (shape, order) shape = state[0] if len(shape) == 2: @@ -141,7 +135,7 @@ def transform(self, A, to_order, from_order='row', out=None, transpose=False, st dim2 = ct.c_int32(shape[2]) is_on_gpu([A, out]) - if to_order == 'col32': + if to_order == "col32": if transpose: lib.ctransform_row2col32T(get_ptr(A), get_ptr(out), dim1, dim2) else: @@ -166,7 +160,7 @@ def transform(self, A, to_order, from_order='row', out=None, transpose=False, st lib.ctransform_ampere2row(get_ptr(A), get_ptr(out), dim1, dim2) else: - raise NotImplementedError(f'Transform function not implemented: From {from_order} to {to_order}') + raise NotImplementedError(f"Transform function not implemented: From {from_order} to {to_order}") post_call(prev_device) @@ -178,14 +172,14 @@ def igemmlt(self, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): dimsA = len(shapeA) dimsB = len(shapeB) - assert dimsB == 2, 'Only two dimensional matrices are supported for argument B' + assert dimsB == 2, "Only two dimensional matrices are supported for argument B" if dimsA == 2: m = shapeA[0] elif dimsA == 3: m = shapeA[0] * shapeA[1] rows = n = shapeB[0] - assert prod(list(shapeA)) > 0, f'Input tensor dimensions need to be > 0: {shapeA}' + assert prod(list(shapeA)) > 0, f"Input tensor dimensions need to be > 0: {shapeA}" # if the tensor is empty, return a transformed empty tensor with the right dimensions if shapeA[0] == 0 and dimsA == 2: @@ -194,13 +188,9 @@ def igemmlt(self, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16) if dimsA == 2 and out is None: - out, Sout = get_transform_buffer( - (shapeA[0], shapeB[0]), dtype, A.device, "col32", "row" - ) + out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, "col32", "row") elif dimsA == 3 and out is None: - out, Sout = get_transform_buffer( - (shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row" - ) + out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row") assert dimsB != 3, "len(B.shape)==3 not supported" assert A.device.type == "cuda" @@ -244,50 +234,37 @@ def igemmlt(self, A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): ptrRowScale = get_ptr(None) is_on_gpu([A, B, out]) - if formatB == 'col_turing': + if formatB == "col_turing": if dtype == torch.int32: - has_error = lib.cigemmlt_turing_32( - ptr, m, n, k, 
ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) + has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) else: - has_error = lib.cigemmlt_turing_8( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) + has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) elif formatB == "col_ampere": if dtype == torch.int32: - has_error = lib.cigemmlt_ampere_32( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) + has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) else: - has_error = lib.cigemmlt_ampere_8( - ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc - ) + has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` raise NotImplementedError("igemmlt not available (probably built with NO_CUBLASLT)") if has_error: - print(f'A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}') - raise Exception('cublasLt ran into an error!') + print( + f"A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}" + ) + raise Exception("cublasLt ran into an error!") torch.cuda.set_device(prev_device) return out, Sout def mm_dequant( - self, - A, - quant_state, - row_stats, - col_stats, - out=None, - new_row_stats=None, - new_col_stats=None, - bias=None + self, A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None ): assert A.dtype == torch.int32 - if bias is not None: assert bias.dtype == torch.float16 + if bias is not None: + assert bias.dtype == torch.float16 out_shape = quant_state[0] if len(out_shape) == 3: out_shape = (out_shape[0] * out_shape[1], out_shape[2]) @@ -295,19 +272,11 @@ def mm_dequant( if out is None: out = torch.empty(out_shape, dtype=torch.float16, device=A.device) if new_row_stats is None: - new_row_stats = torch.empty( - out_shape[0], dtype=torch.float32, device=A.device - ) + new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device) if new_col_stats is None: - new_col_stats = torch.empty( - out_shape[1], dtype=torch.float32, device=A.device - ) - assert ( - new_row_stats.shape[0] == row_stats.shape[0] - ), f"{new_row_stats.shape} vs {row_stats.shape}" - assert ( - new_col_stats.shape[0] == col_stats.shape[0] - ), f"{new_col_stats.shape} vs {col_stats.shape}" + new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device) + assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}" + assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}" prev_device = pre_call(A.device) ptrA = get_ptr(A) @@ -321,7 +290,9 @@ def mm_dequant( numCols = ct.c_int32(out_shape[1]) is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias]) - lib.cdequant_mm_int32_fp16(ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, ptrBias, numRows, numCols) + lib.cdequant_mm_int32_fp16( + ptrA, ptrRowStats, ptrColStats, ptrOut, ptrNewRowStats, ptrNewColStats, ptrBias, numRows, numCols + ) post_call(prev_device) return out @@ -332,9 +303,7 @@ def extract_outliers(self, A, SA, idx): assert formatA in ["col_turing", "col_ampere"] assert A.device.type == "cuda" - out = torch.zeros( - (shapeA[0], idx.numel()), dtype=torch.int8, device=A.device - ) + out = torch.zeros((shapeA[0], 
idx.numel()), dtype=torch.int8, device=A.device) idx_size = ct.c_int32(idx.numel()) rows = ct.c_int32(shapeA[0]) @@ -345,7 +314,7 @@ def extract_outliers(self, A, SA, idx): prev_device = pre_call(A.device) - if formatA == 'col_turing': + if formatA == "col_turing": lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) elif formatA == "col_ampere": lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) @@ -361,13 +330,13 @@ def quantize_4bit( out: Optional[torch.Tensor] = None, blocksize=64, compress_statistics=False, - quant_type='fp4', + quant_type="fp4", quant_storage=torch.uint8, ) -> Tuple[torch.Tensor, QuantState]: - if A.device.type != 'cuda': - raise NotImplementedError(f'Device type not supported for FP4 quantization: {A.device.type}') - if quant_type not in ['fp4', 'nf4']: - raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') + if A.device.type != "cuda": + raise NotImplementedError(f"Device type not supported for FP4 quantization: {A.device.type}") + if quant_type not in ["fp4", "nf4"]: + raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.") n = A.numel() input_shape = A.shape @@ -379,7 +348,7 @@ def quantize_4bit( if out is None: mod = dtype2bytes[quant_storage] * 2 - out = torch.zeros(((n+1)//mod, 1), dtype=quant_storage, device=A.device) + out = torch.zeros(((n + 1) // mod, 1), dtype=quant_storage, device=A.device) assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] @@ -387,22 +356,34 @@ def quantize_4bit( is_on_gpu([A, out, absmax]) if A.dtype == torch.float32: - if quant_type == 'fp4': - lib.cquantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + if quant_type == "fp4": + lib.cquantize_blockwise_fp32_fp4( + get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n) + ) else: - lib.cquantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + lib.cquantize_blockwise_fp32_nf4( + get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n) + ) elif A.dtype == torch.float16: - if quant_type == 'fp4': - lib.cquantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + if quant_type == "fp4": + lib.cquantize_blockwise_fp16_fp4( + get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n) + ) else: - lib.cquantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + lib.cquantize_blockwise_fp16_nf4( + get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n) + ) elif A.dtype == torch.bfloat16: - if quant_type == 'fp4': - lib.cquantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + if quant_type == "fp4": + lib.cquantize_blockwise_bf16_fp4( + get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n) + ) else: - lib.cquantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n)) + lib.cquantize_blockwise_bf16_nf4( + get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int32(blocksize), ct.c_int(n) + ) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") @@ -416,31 
+397,55 @@ def quantize_4bit( absmax -= offset qabsmax, state2 = quantize_blockwise(absmax, blocksize=256) del absmax - state = QuantState(absmax=qabsmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type, offset=offset, state2=state2) + state = QuantState( + absmax=qabsmax, + shape=input_shape, + dtype=A.dtype, + blocksize=blocksize, + code=code, + quant_type=quant_type, + offset=offset, + state2=state2, + ) else: - state = QuantState(absmax=absmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type) + state = QuantState( + absmax=absmax, shape=input_shape, dtype=A.dtype, blocksize=blocksize, code=code, quant_type=quant_type + ) return out, state - def dequantize_4bit(self, A: torch.Tensor, quant_state: Optional[QuantState] = None, absmax: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, blocksize: int = 64, quant_type='fp4') -> torch.Tensor: + def dequantize_4bit( + self, + A: torch.Tensor, + quant_state: Optional[QuantState] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 64, + quant_type="fp4", + ) -> torch.Tensor: if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: - raise ValueError(f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]") + raise ValueError( + f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]" + ) - if quant_type not in ['fp4', 'nf4']: - raise NotImplementedError(f'4-bit quantization data type {quant_type} is not implemented.') + if quant_type not in ["fp4", "nf4"]: + raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented.") if quant_state is None: assert absmax is not None and out is not None - quant_state = QuantState(absmax=absmax, shape=out.shape, dtype=out.dtype, blocksize=blocksize, quant_type=quant_type) + quant_state = QuantState( + absmax=absmax, shape=out.shape, dtype=out.dtype, blocksize=blocksize, quant_type=quant_type + ) else: absmax = quant_state.absmax if quant_state.nested: absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2) absmax += quant_state.offset - if absmax.dtype != torch.float32: absmax = absmax.float() + if absmax.dtype != torch.float32: + absmax = absmax.float() if out is None: out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device) @@ -451,29 +456,73 @@ def dequantize_4bit(self, A: torch.Tensor, quant_state: Optional[QuantState] = N is_on_gpu([A, absmax, out]) if out.dtype == torch.float32: - if quant_state.quant_type == 'fp4': - lib.cdequantize_blockwise_fp32_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + if quant_state.quant_type == "fp4": + lib.cdequantize_blockwise_fp32_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) else: - lib.cdequantize_blockwise_fp32_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + lib.cdequantize_blockwise_fp32_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) elif out.dtype == torch.float16: - if quant_state.quant_type == 'fp4': - lib.cdequantize_blockwise_fp16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + if quant_state.quant_type == "fp4": + 
lib.cdequantize_blockwise_fp16_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) else: - lib.cdequantize_blockwise_fp16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + lib.cdequantize_blockwise_fp16_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) elif out.dtype == torch.bfloat16: - if quant_state.quant_type == 'fp4': - lib.cdequantize_blockwise_bf16_fp4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + if quant_state.quant_type == "fp4": + lib.cdequantize_blockwise_bf16_fp4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) else: - lib.cdequantize_blockwise_bf16_nf4(get_ptr(None), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n)) + lib.cdequantize_blockwise_bf16_nf4( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + ) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) - is_transposed = (True if A.shape[0] == 1 else False) + is_transposed = True if A.shape[0] == 1 else False - if is_transposed: return out.t() - else: return out + if is_transposed: + return out.t() + else: + return out diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 38459981b..6bb02944d 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -14,6 +14,7 @@ from bitsandbytes.backends import backends, ensure_backend_is_available from bitsandbytes.utils import QuantState + from .cextension import lib @@ -617,9 +618,11 @@ def estimate_quantiles( return out + # maintain the compatibility as F.QuantState QuantState = QuantState + def quantize_blockwise( A: Tensor, code: Optional[torch.Tensor] = None, @@ -977,7 +980,15 @@ def quantize_4bit( The quantization state to undo the quantization. """ ensure_backend_is_available(A.device.type) - return backends[A.device.type].quantize_4bit(A, absmax=absmax, out=out, blocksize=blocksize, compress_statistics=compress_statistics, quant_type=quant_type, quant_storage=quant_storage) + return backends[A.device.type].quantize_4bit( + A, + absmax=absmax, + out=out, + blocksize=blocksize, + compress_statistics=compress_statistics, + quant_type=quant_type, + quant_storage=quant_storage, + ) def dequantize_fp4( @@ -1035,7 +1046,9 @@ def dequantize_4bit( Dequantized tensor. 
""" ensure_backend_is_available(A.device.type) - return backends[A.device.type].dequantize_4bit(A, quant_state=quant_state, absmax=absmax, out=out, blocksize=blocksize, quant_type=quant_type) + return backends[A.device.type].dequantize_4bit( + A, quant_state=quant_state, absmax=absmax, out=out, blocksize=blocksize, quant_type=quant_type + ) def quantize( @@ -1876,18 +1889,18 @@ def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): return backends[A.device.type].igemmlt(A, B, SA, SB, out=out, Sout=Sout, dtype=dtype) -def mm_dequant( - A, - quant_state, - row_stats, - col_stats, - out=None, - new_row_stats=None, - new_col_stats=None, - bias=None -): +def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None): ensure_backend_is_available(A.device.type) - return backends[A.device.type].mm_dequant(A, quant_state, row_stats, col_stats, out=out, new_row_stats=new_row_stats, new_col_stats=new_col_stats, bias=bias) + return backends[A.device.type].mm_dequant( + A, + quant_state, + row_stats, + col_stats, + out=out, + new_row_stats=new_row_stats, + new_col_stats=new_col_stats, + bias=bias, + ) def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0): @@ -2009,12 +2022,16 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half): def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): ensure_backend_is_available(A.device.type) - return backends[A.device.type].double_quant(A, col_stats=col_stats, row_stats=row_stats, out_col=out_col, out_row=out_row, threshold=threshold) + return backends[A.device.type].double_quant( + A, col_stats=col_stats, row_stats=row_stats, out_col=out_col, out_row=out_row, threshold=threshold + ) -def transform(A, to_order, from_order='row', out=None, transpose=False, state=None, ld=None): +def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None): ensure_backend_is_available(A.device.type) - return backends[A.device.type].transform(A, to_order, from_order=from_order, out=out, transpose=transpose, state=state, ld=ld) + return backends[A.device.type].transform( + A, to_order, from_order=from_order, out=out, transpose=transpose, state=state, ld=ld + ) def spmm_coo(cooA, B, out=None): @@ -2280,7 +2297,6 @@ def extract_outliers(A, SA, idx): return backends[A.device.type].extract_outliers(A, SA, idx) - def pipeline_test(A, batch_size): out = torch.zeros_like(A) lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size)) diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py index 29a5cfea3..92744dead 100644 --- a/bitsandbytes/utils.py +++ b/bitsandbytes/utils.py @@ -199,6 +199,7 @@ def unpack_tensor_to_dict(tensor_data): return unpacked_dict + class QuantState: """container for quantization state components to work with Params4bit and similar classes""" diff --git a/install_cuda.py b/install_cuda.py index a5d09356d..cf7c8ee71 100644 --- a/install_cuda.py +++ b/install_cuda.py @@ -77,9 +77,7 @@ def main(): download_path = "/tmp" # default download path if len(sys.argv) < 2: - print( - "Usage: python install_cuda.py [user/system] [download_path]" - ) + print("Usage: python install_cuda.py [user/system] [download_path]") sys.exit(1) version = sys.argv[1] @@ -100,9 +98,7 @@ def main(): elif version in cuda_versions: install_cuda(version, base_path, download_path) else: - print( - f"Invalid CUDA version: {version}. 
Available versions are: {', '.join(cuda_versions.keys())}" - ) + print(f"Invalid CUDA version: {version}. Available versions are: {', '.join(cuda_versions.keys())}") sys.exit(1)
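Taken together, the series routes the front-end functions in functional.py through a device-keyed backend registry (bitsandbytes/backends/__init__.py above), dispatching on A.device.type instead of calling the CUDA library directly. The sketch below is a minimal, self-contained illustration of that dispatch pattern; the names DemoBackend and the placeholder double_quant body are made up for the example and are not the library's API.

from typing import Dict

import torch


class DemoBackend:
    # Stand-in for a device-specific implementation such as CUDABackend.
    def double_quant(self, A: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
        # Placeholder body; a real backend returns quantized data plus row/col statistics.
        return A.clamp(-127, 127).round().to(torch.int8)


backends: Dict[str, DemoBackend] = {}


def register_backend(backend_name: str, backend_instance: DemoBackend) -> None:
    # Keys are lower-cased torch device types ("cpu", "cuda", ...).
    backends[backend_name.lower()] = backend_instance


def ensure_backend_is_available(device_type: str) -> None:
    if device_type.lower() not in backends:
        raise NotImplementedError(f"Device backend for {device_type} is currently not supported.")


def double_quant(A: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
    # Front-end entry point: dispatch on the input tensor's device type,
    # mirroring backends[A.device.type].double_quant(...) in functional.py.
    ensure_backend_is_available(A.device.type)
    return backends[A.device.type].double_quant(A, threshold=threshold)


if __name__ == "__main__":
    register_backend("cpu", DemoBackend())
    x = torch.randn(4, 4) * 100
    print(double_quant(x).dtype)  # torch.int8

Keying the registry on torch device strings keeps the CPU path (the bfloat16 cast added to MatMul8bitLt in the first patch) and the CUDA path behind one functional surface, so callers never branch on device themselves.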