From 2de7c3302344e2d14312541cd2f0de487408d32f Mon Sep 17 00:00:00 2001
From: Chen Wu <72850361+CNTRYROA@users.noreply.github.com>
Date: Sat, 23 Nov 2024 13:13:59 +0800
Subject: [PATCH] [Model] Fix Baichuan BNB online quantization (#10572)

Signed-off-by: Chen Wu
Signed-off-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com>
---
 vllm/model_executor/models/baichuan.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py
index a923ed36a9db2..39cb5a8b2cbbe 100644
--- a/vllm/model_executor/models/baichuan.py
+++ b/vllm/model_executor/models/baichuan.py
@@ -350,6 +350,21 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     embedding_modules = {}
     embedding_padding_modules = []
 
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".W_pack.",
+        ".o_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".gate_proj.",
+        ".up_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     def __init__(
         self,
         *,
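
For context (not part of the patch): a minimal sketch of the in-flight BitsAndBytes quantization path this change fixes. The Baichuan checkpoint name, dtype, and sampling settings below are illustrative assumptions, not taken from the commit.

```python
# Sketch: load a Baichuan model with vLLM's in-flight ("online") BitsAndBytes
# quantization, which relies on the target-module list and stacked-params
# mapping added by this patch. Model name and settings are assumptions.
from vllm import LLM, SamplingParams

llm = LLM(
    model="baichuan-inc/Baichuan2-7B-Chat",  # assumed checkpoint for illustration
    trust_remote_code=True,                  # Baichuan uses custom modeling code
    quantization="bitsandbytes",             # quantize weights on the fly
    load_format="bitsandbytes",              # load weights via the BNB loader
)

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```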