Skip to content

Commit

Permalink
[int4-quant] Execute weights shuffling on CPU until MPS memory issue …
Browse files Browse the repository at this point in the history
…is resolved (pytorch#552)
  • Loading branch information
manuelcandales authored Jul 29, 2024
1 parent 4563492 commit 4abe4b8
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions torchao/quantization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,13 @@ def groupwise_affine_quantize_tensor_from_qparams(

int_data = quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT)
if TORCH_VERSION_AFTER_2_5:
int_data_device_type = int_data.device.type
# Move to cpu, until issue with MPS memory management of temporary tensors is resolved
if int_data_device_type == 'mps':
int_data = int_data.cpu()
int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
if int_data_device_type == 'mps':
int_data = int_data.to(device='mps')
return int_data

def groupwise_affine_dequantize_tensor_from_qparams(
Expand Down

0 comments on commit 4abe4b8

Please sign in to comment.