diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index 50fab4cced37..c235cc766209 100644 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -385,7 +385,7 @@ def update_mp_params(self, child): return for param in [ "n_heads", "inner_dim", "num_heads", "num_kv", "num_attention_heads", "num_attn_heads", - "all_head_size", "embed_dim", "hidden_size", "num_key_value_heads" + "all_head_size", "embed_dim", "hidden_size", "num_key_value_heads", "num_kv_heads" ]: if hasattr(child, param): param_val = getattr(child, param)