Commit bfd6c5f

fixed randint range
vayuda committed Jun 12, 2024
1 parent 0849496 commit bfd6c5f
Showing 1 changed file with 5 additions and 5 deletions.
benchmarks/benchmark_bitpacking.py (10 changes: 5 additions & 5 deletions)
@@ -21,11 +21,11 @@ def benchmark(function, args, num_runs):

 def test_vs_existing():
     def new_(scale):
-        fake_tensor = torch.randint(2**8-1, (1, scale,scale), dtype=torch.uint8).cuda()
+        fake_tensor = torch.randint(2**8, (1, scale,scale), dtype=torch.uint8).cuda()
         packed = pack(fake_tensor, 4, dim=1)
         unpacked = unpack(packed, 4, dim=1)
     def old_(scale):
-        fake_tensor = torch.randint(2**8-1, (1, scale,scale), dtype=torch.uint8).cuda()
+        fake_tensor = torch.randint(2**8, (1, scale,scale), dtype=torch.uint8).cuda()
         packed = pack_uint4(fake_tensor)
         unpacked = unpack_uint4(packed)

@@ -55,9 +55,9 @@ class W4A16_symmetric_weight_only(torch.nn.Module):
     def __init__(self, scale):
         super().__init__()
         assert scale % 4 == 0
-        self.l1 = torch.randint(2**8-1,(scale, scale), dtype=torch.uint8).cuda()
+        self.l1 = torch.randint(2**8,(scale, scale), dtype=torch.uint8).cuda()
         self.s1 = torch.tensor((scale),dtype=torch.float16).cuda()
-        self.l2 = torch.randint(2**8-1,(scale//2, scale//4), dtype=torch.uint8).cuda()
+        self.l2 = torch.randint(2**8,(scale//2, scale//4), dtype=torch.uint8).cuda()
         self.s2 = torch.tensor((scale//4),dtype=torch.float16).cuda()


@@ -79,7 +79,7 @@ def forward(self, x):
     b = torch.compile(b, fullgraph=True)

     test_input = torch.randn(scale*2, dtype=torch.float16).cuda()
-    forward_args = [test_input]
+    forward_args = [test_input]
     b.forward(test_input)
     print("scale: ", scale)
     print("fp16 time: ", benchmark(a.forward, forward_args, 100))
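The body of the benchmark helper is collapsed in this view; only its signature, benchmark(function, args, num_runs), appears in the first hunk header. A hypothetical implementation consistent with that signature (an assumption, not the file's actual code) might look like:

    import time
    import torch

    def benchmark(function, args, num_runs):
        # Hypothetical timing helper; synchronize so queued GPU work is
        # included in the wall-clock measurement (requires a CUDA device).
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(num_runs):
            function(*args)
        torch.cuda.synchronize()
        return (time.perf_counter() - start) * 1000 / num_runs  # avg ms per run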
