This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

[LLM] Support woq model save and load (#1211)
changwangss authored Feb 1, 2024
1 parent ff501d0 commit 1c8078f
Showing 4 changed files with 15 additions and 16 deletions.
@@ -163,7 +163,7 @@
# ============AutoModel parameters==============
parser.add_argument("--load_in_4bit", type=bool, default=False)
parser.add_argument("--load_in_8bit", type=bool, default=False)
parser.add_argument("--_commit_hash", default="main", type=str)
parser.add_argument("--_commit_hash", default=None, type=str)
parser.add_argument("--trust_remote_code", type=bool, default=False)
parser.add_argument("--use_llm_runtime", action="store_true")
# =======================================
@@ -335,22 +335,20 @@
else:
user_model = AutoModelForCausalLM.from_pretrained(
args.model,
config=config,
trust_remote_code=args.trust_remote_code,
_commit_hash=args._commit_hash,
use_llm_runtime=args.use_llm_runtime,
)

# save model
if args.output_dir:
tokenizer.save_pretrained(args.output_dir)
if args.sq:
config.save_pretrained(args.output_dir)
user_model.save(args.output_dir)
elif args.mixed_precision:
user_model.config.save_pretrained(args.output_dir)
torch.save(
user_model.state_dict(), os.path.join(args.output_dir, "pytorch_model.bin")
)
elif args.mixed_precision or args.woq:
user_model.save_pretrained(args.output_dir)



# int8 model loading
if args.int8 or args.int8_bf16_mixed:
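The hunk above is what the commit title describes: a weight-only-quantized (WOQ) model is now saved with save_pretrained, the same way as a mixed-precision one, instead of writing the config and state_dict by hand. Below is a minimal round-trip sketch based only on the calls visible in this diff (the ITREX AutoModelForCausalLM wrapper, the --load_in_4bit flag, save_pretrained); the model id and output path are placeholders, and reloading from the saved directory through the same from_pretrained entry point is an assumption, not something shown in the hunk.

# Illustrative round-trip sketch, not part of the commit; names are placeholders.
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM

model_name = "facebook/opt-125m"     # placeholder model id
output_dir = "./saved_woq_model"     # placeholder output path

tokenizer = AutoTokenizer.from_pretrained(model_name)
# load_in_4bit mirrors the --load_in_4bit flag in the example script and
# triggers weight-only quantization in the ITREX wrapper.
user_model = AutoModelForCausalLM.from_pretrained(model_name, load_in_4bit=True)

# What the new "elif args.mixed_precision or args.woq" branch does: persist both.
tokenizer.save_pretrained(output_dir)
user_model.save_pretrained(output_dir)

# Assumed load path enabled by this commit: reload the quantized model
# from the saved directory.
user_model = AutoModelForCausalLM.from_pretrained(output_dir)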
10 changes: 2 additions & 8 deletions intel_extension_for_transformers/llm/quantization/nn/modules.py
@@ -115,12 +115,6 @@ def forward(self, x: torch.Tensor):
# weights are cast automatically as Int8Params, but the bias has to be cast manually
if self.bias is not None and self.bias.dtype != x.dtype:
self.bias.data = self.bias.data.to(x.dtype)

if getattr(self.weight, "quant_state", None) is None:
print(
"FP4 quantization state not initialized. Please call .quantize_weights()."
)

shape = list(x.size())
m = reduce(mul, shape[0:-1])
out = torch.zeros(m, self.out_features, dtype=x.dtype)
@@ -143,7 +137,8 @@ def forward(self, x: torch.Tensor):
return out

def set_weights_bias(self, weight_data, bias=None):
shape = weight_data.shape
if weight_data.is_meta:
weight_data = torch.ones(weight_data.shape, dtype=torch.float)
weight = torch.ops.bestlaop.woq_quantize(
weight_data,
True,
@@ -153,7 +148,6 @@ def set_weights_bias(self, weight_data, bias=None):
self.scale_dtype if self.scale_dtype is not None else "fp32",
False,
)
weight.resize_(shape)
self.weight = ParamsQBits(
data=weight,
requires_grad=False,
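The set_weights_bias change is the load-side counterpart: when a saved WOQ checkpoint is reloaded, the freshly constructed module can carry meta-device weights (shape and dtype only, no storage), so a dummy dense tensor is materialized before torch.ops.bestlaop.woq_quantize runs, and the real quantized values are filled in afterwards. A small stand-alone illustration of the is_meta guard in plain PyTorch (it does not reproduce the bestla op):

# Illustration of the is_meta guard; plain PyTorch, not the bestla kernel.
import torch

meta_weight = torch.empty(4, 8, device="meta")   # shape/dtype only, no storage
assert meta_weight.is_meta

weight_data = meta_weight
if weight_data.is_meta:
    # Materialize a same-shaped fp32 placeholder so a kernel that reads values
    # can run; the real (already quantized) weights are loaded into the module later.
    weight_data = torch.ones(weight_data.shape, dtype=torch.float)

print(weight_data.shape, weight_data.dtype)      # torch.Size([4, 8]) torch.float32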
5 changes: 5 additions & 0 deletions intel_extension_for_transformers/llm/quantization/utils.py
@@ -198,6 +198,11 @@ def _replace_linear(
module.weight.data,
None if module.bias is None else module.bias.data,
)
else:
model._modules[name].set_weights_bias(
module.weight.data,
None if module.bias is None else module.bias.data,
)
else:
if not hasattr(module, "qweight"):
n_pack = 8 // DTYPE_BITS_MAPPING[quantization_config.weight_dtype]
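For context, _replace_linear walks the model and swaps each linear layer for a quantized one; the new else branch extends the existing set_weights_bias call to the other side of a check whose condition is cut off above the hunk, so weights are still handed over on that path. A simplified sketch of the module-swap pattern it implements (make_qlinear is a hypothetical factory, not the ITREX API):

# Simplified sketch of the module-swap pattern, not the actual _replace_linear.
import torch.nn as nn

def replace_linear(model: nn.Module, make_qlinear):
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            # Build the quantized replacement and transfer the original parameters.
            qlinear = make_qlinear(module.in_features, module.out_features,
                                   bias=module.bias is not None)
            qlinear.set_weights_bias(
                module.weight.data,
                None if module.bias is None else module.bias.data,
            )
            setattr(model, name, qlinear)
        else:
            replace_linear(module, make_qlinear)  # recurse into submodules
    return model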
2 changes: 2 additions & 0 deletions intel_extension_for_transformers/transformers/utils/config.py
@@ -319,6 +319,8 @@ def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool =
json_file_path (`str` or `os.PathLike`):
Path to the JSON file in which this configuration instance's parameters will be saved.
"""
# set tokenizer to None since it cannot be serialized to JSON
self.tokenizer = None
with open(json_file_path, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string(use_diff=use_diff))

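The to_json_file tweak exists because the quantization config can hold a tokenizer object (set elsewhere in the library), and such an object is not JSON-serializable, so it is cleared before the config is written out. A tiny stand-alone illustration of the failure mode and the workaround (DummyTokenizer is a stand-in, not an ITREX class):

# Illustration of why the tokenizer attribute is dropped before serialization.
import json

class DummyTokenizer:            # stand-in for a real tokenizer object
    pass

config_dict = {"weight_dtype": "int4", "tokenizer": DummyTokenizer()}

try:
    json.dumps(config_dict)
except TypeError as err:
    print(f"not JSON-serializable: {err}")

config_dict["tokenizer"] = None  # what the patch does with self.tokenizer
print(json.dumps(config_dict))   # now serializes cleanly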
