Skip to content

Commit

Permalink
Fix Failing CI - Update bitsandbytes import (#1343)
Browse files Browse the repository at this point in the history
Update bitsandbytes import
  • Loading branch information
jainapurva authored Nov 25, 2024
1 parent 9bb1b23 commit 8b1b168
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions test/quantization/test_galore_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
except ImportError:
pytest.skip("triton is not installed", allow_module_level=True)

import bitsandbytes.functional as F
from bitsandbytes.functional import create_dynamic_map, quantize_blockwise, dequantize_blockwise
import torch

from torchao.prototype.galore.kernels import (
Expand Down Expand Up @@ -36,9 +36,9 @@
def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize):
g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01

qmap = F.create_dynamic_map(signed).to(g.device)
qmap = create_dynamic_map(signed).to(g.device)

ref_bnb, qstate = F.quantize_blockwise(g, code=qmap, blocksize=blocksize)
ref_bnb, qstate = quantize_blockwise(g, code=qmap, blocksize=blocksize)
bnb_norm = (g.reshape(-1, blocksize) / qstate.absmax[:, None]).reshape(g.shape)

tt_q, tt_norm, tt_absmax = triton_quantize_blockwise(
Expand Down Expand Up @@ -82,10 +82,10 @@ def test_galore_quantize_blockwise(dim1, dim2, dtype, signed, blocksize):
def test_galore_dequant_blockwise(dim1, dim2, dtype, signed, blocksize):
g = torch.randn(dim1, dim2, device="cuda", dtype=dtype) * 0.01

qmap = F.create_dynamic_map(signed).to(g.device)
qmap = create_dynamic_map(signed).to(g.device)

q, qstate = F.quantize_blockwise(g, code=qmap, blocksize=blocksize)
q, qstate = quantize_blockwise(g, code=qmap, blocksize=blocksize)

dq_ref = F.dequantize_blockwise(q, qstate)
dq_ref = dequantize_blockwise(q, qstate)
dq = triton_dequant_blockwise(q, qmap, qstate.absmax, group_size=blocksize)
assert torch.allclose(dq, dq_ref)

0 comments on commit 8b1b168

Please sign in to comment.