This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

[LLM] Support woq model save and load #1211

Merged · 4 commits · Feb 1, 2024
@@ -163,7 +163,7 @@
 # ============AutoModel parameters==============
 parser.add_argument("--load_in_4bit", type=bool, default=False)
 parser.add_argument("--load_in_8bit", type=bool, default=False)
-parser.add_argument("--_commit_hash", default="main", type=str)
+parser.add_argument("--_commit_hash", default=None, type=str)
 parser.add_argument("--trust_remote_code", type=bool, default=False)
 parser.add_argument("--use_llm_runtime", action="store_true")
 # =======================================
@@ -335,22 +335,20 @@
 else:
     user_model = AutoModelForCausalLM.from_pretrained(
         args.model,
         config=config,
         trust_remote_code=args.trust_remote_code,
         _commit_hash=args._commit_hash,
         use_llm_runtime=args.use_llm_runtime,
     )

 # save model
 if args.output_dir:
     tokenizer.save_pretrained(args.output_dir)
     if args.sq:
         config.save_pretrained(args.output_dir)
         user_model.save(args.output_dir)
-    elif args.mixed_precision:
-        user_model.config.save_pretrained(args.output_dir)
-        torch.save(
-            user_model.state_dict(), os.path.join(args.output_dir, "pytorch_model.bin")
-        )
+    elif args.mixed_precision or args.woq:
+        user_model.save_pretrained(args.output_dir)

 # int8 model loading
 if args.int8 or args.int8_bf16_mixed:
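With this change, a weight-only-quantized (woq) model is saved through the standard save_pretrained path instead of a raw torch.save of the state dict. Below is a minimal sketch of the intended save/load round trip, assuming the extension's AutoModelForCausalLM accepts a local woq checkpoint directory and a WeightOnlyQuantConfig-style quantization config; the model id, directory, and kwargs are illustrative and not taken from this PR.

# Sketch only: assumes a woq checkpoint written by save_pretrained() can be
# reloaded from a local directory via the extension's from_pretrained().
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import (
    AutoModelForCausalLM,
    WeightOnlyQuantConfig,  # assumed config class for weight-only quantization
)

model_name = "facebook/opt-125m"   # illustrative model id
output_dir = "./opt-125m-woq"      # illustrative save location

# Quantize with weight-only quantization, then persist model and tokenizer.
woq_config = WeightOnlyQuantConfig(weight_dtype="int4_fullrange")  # assumed kwargs
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)

model.save_pretrained(output_dir)       # quantized weights + config
tokenizer.save_pretrained(output_dir)

# Reload the already-quantized checkpoint from disk.
reloaded = AutoModelForCausalLM.from_pretrained(output_dir)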
10 changes: 2 additions & 8 deletions intel_extension_for_transformers/llm/quantization/nn/modules.py
@@ -115,12 +115,6 @@ def forward(self, x: torch.Tensor):
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
         if self.bias is not None and self.bias.dtype != x.dtype:
             self.bias.data = self.bias.data.to(x.dtype)
-
-        if getattr(self.weight, "quant_state", None) is None:
-            print(
-                "FP4 quantization state not initialized. Please call .quantize_weights()."
-            )
-
         shape = list(x.size())
         m = reduce(mul, shape[0:-1])
         out = torch.zeros(m, self.out_features, dtype=x.dtype)
@@ -143,7 +137,8 @@ def forward(self, x: torch.Tensor):
         return out

     def set_weights_bias(self, weight_data, bias=None):
-        shape = weight_data.shape
+        if weight_data.is_meta:
+            weight_data = torch.ones(weight_data.shape, dtype=torch.float)
         weight = torch.ops.bestlaop.woq_quantize(
             weight_data,
             True,
@@ -153,7 +148,6 @@ def set_weights_bias(self, weight_data, bias=None):
             self.scale_dtype if self.scale_dtype is not None else "fp32",
             False,
         )
-        weight.resize_(shape)
         self.weight = ParamsQBits(
             data=weight,
             requires_grad=False,
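The new is_meta guard covers the case where the module was created with empty (meta-device) weights, for example under accelerate's init_empty_weights(), so there is no real data to quantize yet. A standalone PyTorch illustration of what the check sees (not the extension's code path):

import torch

# A tensor on the "meta" device carries shape and dtype but no storage,
# so its values cannot be read or quantized directly.
empty = torch.empty(4, 8, device="meta")
print(empty.is_meta)  # True

# set_weights_bias substitutes a real placeholder of the same shape before
# calling the quantize kernel; the actual values are filled in later at load time.
placeholder = torch.ones(empty.shape, dtype=torch.float)
print(placeholder.is_meta)  # False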
5 changes: 5 additions & 0 deletions intel_extension_for_transformers/llm/quantization/utils.py
@@ -198,6 +198,11 @@ def _replace_linear(
                             module.weight.data,
                             None if module.bias is None else module.bias.data,
                         )
+                    else:
+                        model._modules[name].set_weights_bias(
+                            module.weight.data,
+                            None if module.bias is None else module.bias.data,
+                        )
                 else:
                     if not hasattr(module, "qweight"):
                         n_pack = 8 // DTYPE_BITS_MAPPING[quantization_config.weight_dtype]
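For orientation, _replace_linear follows the common module-swapping pattern: recurse over the model, swap each nn.Linear for a quantized counterpart, and hand it the original weights, which the added else branch now also does on this path. A stripped-down sketch of that pattern in plain PyTorch, using a stand-in QuantLinear rather than the extension's QuantizedLinearQBits:

import torch
import torch.nn as nn

class QuantLinear(nn.Linear):
    """Stand-in for a quantized linear layer; real code would quantize here."""
    def set_weights_bias(self, weight_data, bias=None):
        # The extension calls its woq quantize kernel here; this sketch just copies.
        self.weight = nn.Parameter(weight_data, requires_grad=False)
        if bias is not None:
            self.bias = nn.Parameter(bias, requires_grad=False)

def replace_linear(model: nn.Module) -> nn.Module:
    for name, module in model.named_children():
        replace_linear(module)  # recurse into submodules first
        if isinstance(module, nn.Linear) and not isinstance(module, QuantLinear):
            new_module = QuantLinear(module.in_features, module.out_features,
                                     bias=module.bias is not None)
            new_module.set_weights_bias(
                module.weight.data,
                None if module.bias is None else module.bias.data,
            )
            model._modules[name] = new_module
    return model

model = replace_linear(nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)))
print(model)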
2 changes: 2 additions & 0 deletions intel_extension_for_transformers/transformers/utils/config.py
@@ -319,6 +319,8 @@ def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool =
             json_file_path (`str` or `os.PathLike`):
                 Path to the JSON file in which this configuration instance's parameters will be saved.
         """
+        # set tokenizer to None due to it doesn't support write to json
+        self.tokenizer = None
         with open(json_file_path, "w", encoding="utf-8") as writer:
             writer.write(self.to_json_string(use_diff=use_diff))
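The tokenizer is dropped before serialization simply because a tokenizer object is not JSON-serializable, so keeping it on the config would make to_json_file raise. A quick standalone illustration (the FakeTokenizer class is hypothetical and unrelated to the extension):

import json

class FakeTokenizer:
    """Stand-in for the tokenizer object a quantization config may hold."""
    pass

config_dict = {"weight_dtype": "int4", "tokenizer": FakeTokenizer()}

try:
    json.dumps(config_dict)
except TypeError as err:
    print(f"serialization failed: {err}")  # ... is not JSON serializable

config_dict["tokenizer"] = None  # mirrors what to_json_file now does before writing
print(json.dumps(config_dict))   # {"weight_dtype": "int4", "tokenizer": null}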