@@ -663,7 +663,6 @@ def quantize_and_upload(
model_to_quantize = _untie_weights_and_save_locally(model_to_quantize)

# quantization

if "AWQ" in quant:
# awq will use torchao API directly
assert quant == "AWQ-INT4", "Only support AWQ-INT4 for now"
252 changes: 120 additions & 132 deletions test/prototype/test_awq.py
@@ -79,6 +79,7 @@ def forward(self, x):
"cpu": [Int4WeightOnlyConfig(group_size=128, int4_packing_format="opaque")],
"xpu": [Int4WeightOnlyConfig(group_size=128, int4_packing_format="plain_int32")],
}
configs = [(d, c) for d in devices for c in device_to_base_configs[d]]


class TestAWQ(TestCase):
@@ -95,109 +96,100 @@ def test_awq_config(self):
with self.assertRaisesRegex(ValueError, "is not one of"):
AWQConfig(base_config, step="not_supported")

@parametrize("device", devices)
def test_awq_functionality(self, device):
@parametrize("device,base_config", configs)
def test_awq_functionality(self, device, base_config):
dataset_size = 10
l1, l2, l3 = 512, 256, 128
original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs
sequence_length = 5

assert device in device_to_base_configs, "Unsupported device: {}".format(device)
base_configs = device_to_base_configs[device]

for base_config in base_configs:
m = ToyLinearModel(l1, l2, l3, device=device, dtype=original_dtype).eval()
m_baseline = copy.deepcopy(m)

dataset = m.example_inputs(
dataset_size,
sequence_length=sequence_length,
)
# for test, we use calibration_data = dataset so that awq is
# guaranteed to be better than baseline
# in reality, calibration_data will be a small subset or a different
# dataset
calibration_data = dataset
# concatenated inputs
input_cat = torch.cat(calibration_data, dim=-2)
ref_out = m(input_cat)

# baseline quantization
quantize_(m_baseline, base_config)

# awq quantization
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
quantize_(m, quant_config)

for example in calibration_data:
m(example)

quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
quantize_(m, quant_config)

# evaluating on calibration data set to remove any uncertainty
awq_out = m(input_cat)
baseline_out = m_baseline(input_cat)

loss_awq = (ref_out - awq_out).pow(2).mean().item()
loss_base = (ref_out - baseline_out).pow(2).mean().item()
assert loss_awq <= loss_base

@parametrize("device", devices)
def test_awq_loading(self, device):
m = ToyLinearModel(l1, l2, l3, device=device, dtype=original_dtype).eval()
m_baseline = copy.deepcopy(m)

dataset = m.example_inputs(
dataset_size,
sequence_length=sequence_length,
)
# for test, we use calibration_data = dataset so that awq is
# guaranteed to be better than baseline
# in reality, calibration_data will be a small subset or a different
# dataset
calibration_data = dataset
input_cat = torch.cat(calibration_data, dim=-2)
ref_out = m(input_cat)

# baseline quantization
quantize_(m_baseline, base_config)

# awq quantization
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
quantize_(m, quant_config)

for example in calibration_data:
m(example)

quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
quantize_(m, quant_config)

# evaluating on calibration data set to remove any uncertainty
awq_out = m(input_cat)
baseline_out = m_baseline(input_cat)

loss_awq = (ref_out - awq_out).pow(2).mean().item()
loss_base = (ref_out - baseline_out).pow(2).mean().item()
assert loss_awq <= loss_base

@parametrize("device,base_config", configs)
def test_awq_loading(self, device, base_config):
dataset_size = 10
l1, l2, l3 = 512, 256, 128
original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs
sequence_length = 5

assert device in device_to_base_configs, "Unsupported device: {}".format(device)
base_configs = device_to_base_configs[device]

for base_config in base_configs:
m = ToyLinearModel(l1, l2, l3, device=device, dtype=original_dtype).eval()
dataset = m.example_inputs(
dataset_size,
sequence_length=sequence_length,
)
# for test purposes, we don't need to take a subset
calibration_data = dataset
# concatenated inputs
input_cat = torch.cat(calibration_data, dim=-2)
m = ToyLinearModel(l1, l2, l3, device=device, dtype=original_dtype).eval()
dataset = m.example_inputs(
dataset_size,
sequence_length=sequence_length,
)
# for test purposes, we don't need to take a subset
calibration_data = dataset
# concatenated inputs
input_cat = torch.cat(calibration_data, dim=-2)

# calibrate
# calibrate

quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
quantize_(m, quant_config)
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
quantize_(m, quant_config)

for example in calibration_data:
m(example)
for example in calibration_data:
m(example)

# quantize
quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
quantize_(m, quant_config)
# quantize
quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
quantize_(m, quant_config)

with tempfile.NamedTemporaryFile() as f:
torch.save(m.state_dict(), f)
f.seek(0)
state_dict = torch.load(f)
with tempfile.NamedTemporaryFile() as f:
torch.save(m.state_dict(), f)
f.seek(0)
state_dict = torch.load(f)

loaded_model = ToyLinearModel(
l1, l2, l3, device=device, dtype=original_dtype
).eval()
loaded_model.load_state_dict(state_dict, assign=True)
loaded_model = ToyLinearModel(
l1, l2, l3, device=device, dtype=original_dtype
).eval()
loaded_model.load_state_dict(state_dict, assign=True)

m = torch.compile(m, fullgraph=True)
loaded_model = torch.compile(loaded_model, fullgraph=True)
m = torch.compile(m, fullgraph=True)
loaded_model = torch.compile(loaded_model, fullgraph=True)

awq_out = m(input_cat)
awq_save_load_out = loaded_model(input_cat)
awq_out = m(input_cat)
awq_save_load_out = loaded_model(input_cat)

assert awq_out is not None
assert awq_save_load_out is not None
assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
assert awq_out is not None
assert awq_save_load_out is not None
assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)

@parametrize("device", devices)
def test_awq_loading_vllm(self, device):
@parametrize("device,base_config", configs)
def test_awq_loading_vllm(self, device, base_config):
"""Simulate weight loading in vllm:
* prepare the model weights in the same format (awq weight)
* use weight.copy_(state_dict["weight"]) to copy over the quantized weights from checkpoint
Expand All @@ -209,55 +201,51 @@ def test_awq_loading_vllm(self, device):
original_dtype = torch.bfloat16 # tinygemm kernel only uses bfloat16 inputs
sequence_length = 5

assert device in device_to_base_configs, "Unsupported device: {}".format(device)
base_configs = device_to_base_configs[device]

for base_config in base_configs:
m = ToyLinearModel(l1, l2, l3, device=device, dtype=original_dtype).eval()
dataset = m.example_inputs(
dataset_size,
sequence_length=sequence_length,
)
# for test purposes, we don't need to take a subset
calibration_data = dataset
# concatenated inputs
input_cat = torch.cat(calibration_data, dim=-2)

# calibrate
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
quantize_(m, quant_config)

for example in calibration_data:
m(example)

# quantize
quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
quantize_(m, quant_config)

with tempfile.NamedTemporaryFile() as f:
torch.save(m.state_dict(), f)
f.seek(0)
state_dict = torch.load(f)

loaded_model = ToyLinearModel(
l1, l2, l3, device=device, dtype=original_dtype
).eval()
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING)
quantize_(loaded_model, quant_config)

loaded_model.linear1.weight.copy_(state_dict["linear1.weight"])
loaded_model.linear2.weight.copy_(state_dict["linear2.weight"])
loaded_model.linear3.weight.copy_(state_dict["linear3.weight"])

m = torch.compile(m, fullgraph=True)
loaded_model = torch.compile(loaded_model, fullgraph=True)

awq_out = m(input_cat)
awq_save_load_out = loaded_model(input_cat)

assert awq_out is not None
assert awq_save_load_out is not None
assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)
m = ToyLinearModel(l1, l2, l3, device=device, dtype=original_dtype).eval()
dataset = m.example_inputs(
dataset_size,
sequence_length=sequence_length,
)
# for test purposes, we don't need to take a subset
calibration_data = dataset
# concatenated inputs
input_cat = torch.cat(calibration_data, dim=-2)

# calibrate
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
quantize_(m, quant_config)

for example in calibration_data:
m(example)

# quantize
quant_config = AWQConfig(base_config, step=AWQStep.CONVERT)
quantize_(m, quant_config)

with tempfile.NamedTemporaryFile() as f:
torch.save(m.state_dict(), f)
f.seek(0)
state_dict = torch.load(f)

loaded_model = ToyLinearModel(
l1, l2, l3, device=device, dtype=original_dtype
).eval()
quant_config = AWQConfig(base_config, step=AWQStep.PREPARE_FOR_LOADING)
quantize_(loaded_model, quant_config)

loaded_model.linear1.weight.copy_(state_dict["linear1.weight"])
loaded_model.linear2.weight.copy_(state_dict["linear2.weight"])
loaded_model.linear3.weight.copy_(state_dict["linear3.weight"])

m = torch.compile(m, fullgraph=True)
loaded_model = torch.compile(loaded_model, fullgraph=True)

awq_out = m(input_cat)
awq_save_load_out = loaded_model(input_cat)

assert awq_out is not None
assert awq_save_load_out is not None
assert torch.allclose(awq_out, awq_save_load_out, atol=1e-2)


instantiate_parametrized_tests(TestAWQ)
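
For context, a minimal self-contained sketch of the parametrization pattern the tests now use: a comma-separated parameter string plus a list of tuples yields one generated test per (device, base_config) pair. The devices and configs below are placeholders rather than the ones defined in this file.

from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

# Placeholder values for illustration only; the real test builds these from
# available hardware and Int4WeightOnlyConfig variants.
devices = ["cpu"]
device_to_base_configs = {"cpu": ["config_a", "config_b"]}
configs = [(d, c) for d in devices for c in device_to_base_configs[d]]


class TestPairParametrization(TestCase):
    @parametrize("device,base_config", configs)
    def test_pair(self, device, base_config):
        # One generated test per tuple: ("cpu", "config_a") and ("cpu", "config_b").
        self.assertIn((device, base_config), configs)


instantiate_parametrized_tests(TestPairParametrization)

if __name__ == "__main__":
    run_tests()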
2 changes: 1 addition & 1 deletion torchao/_models/_eval.py
@@ -32,7 +32,7 @@ class TransformerEvalWrapper(eval_wrapper):
"""

def __init__(
self, model, tokenizer, max_seq_length, input_prep_func=None, device="cuda"
self, model, tokenizer, max_seq_length=1024, input_prep_func=None, device="cuda"
):
try:
super().__init__(device=device)
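
A hedged usage sketch of the new max_seq_length default in TransformerEvalWrapper. It assumes lm_eval and transformers are installed, uses gpt2 purely as an illustrative checkpoint, and only exercises the constructor.

from transformers import AutoModelForCausalLM, AutoTokenizer

from torchao._models._eval import TransformerEvalWrapper

# Illustrative small checkpoint; real callers pass the model under evaluation.
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# max_seq_length can now be omitted (defaults to 1024) ...
wrapper = TransformerEvalWrapper(model, tokenizer, device="cpu")
# ... or still set explicitly when a longer context is needed.
wrapper_long = TransformerEvalWrapper(model, tokenizer, max_seq_length=4096, device="cpu")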
4 changes: 2 additions & 2 deletions torchao/prototype/awq/example.py
@@ -336,7 +336,7 @@ def quantize_and_eval(
)

# Optional arguments with default values
parser.add_argument("--repo", type=str, help="Repository ID of the model.")
parser.add_argument("--model", type=str, help="Repository ID of the model.")
parser.add_argument(
"--quant",
type=str,
@@ -402,7 +402,7 @@ def quantize_and_eval(
# Convert precision argument to torch dtype
precision_dtype = getattr(torch, args.precision, torch.bfloat16)
result = quantize_and_eval(
args.repo,
args.model,
args.quant,
args.tasks,
args.max_seq_length,
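
For reference, a minimal invocation sketch using the renamed flag; both values are placeholders, and the script's other flags keep their existing defaults.

python torchao/prototype/awq/example.py --model <hf-repo-id> --quant <quant-option>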