Commit

fix
iankur committed Sep 14, 2024
1 parent 77dfe74 commit 7dc6262
Showing 1 changed file with 119 additions and 22 deletions.
141 changes: 119 additions & 22 deletions vqllm/utils/checkpointer.py
@@ -1,3 +1,4 @@
import gc
import os
import re
from pathlib import Path
@@ -7,7 +8,6 @@
import torchtune
from torchtune import utils
from torchtune.models import convert_weights
from torchtune.models.gemma import gemma_hf_to_tune, gemma_tune_to_hf
from torchtune.utils._checkpointing._checkpointer_utils import (
    ModelType,
    safe_torch_load,
@@ -72,6 +72,103 @@ def get_mapped_key(key: str, mapping_dict: Dict[str, str]) -> str:
    return new_key


def gemma_hf_to_tune(
    state_dict: Dict[str, torch.Tensor],
    num_heads: int = 8,
    num_kv_heads: int = 1,
    dim: int = 2048,
    head_dim: int = 256,
) -> Dict[str, torch.Tensor]:
    """
    Convert the state dict of a Gemma model from HF's format to TorchTune's format.
    State dicts from multiple checkpoint files should be consolidated into a single
    state dict before calling this function.
    The logic is identical to :func:`~torchtune.models.convert_weights.hf_to_tune`,
    but does not load the output projection weights.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in HF's format.
        num_heads (int): Number of heads in the model. Defaults to 8.
        num_kv_heads (int): Number of heads in the key/value projection layers. Defaults to 1.
        dim (int): Dimension of the model. Defaults to 2048.
        head_dim (int): Dimension of the attention head. This value is explicit in Gemma configs. Defaults to 256.

    Returns:
        Dict[str, torch.Tensor]: State dict in TorchTune's format.
    """
    converted_state_dict = {}

    def _permute(t, n_heads):
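        # HF and TorchTune lay out the rotary (RoPE) portions of the q/k
        # projections differently; this reorders the rows into TorchTune's
        # layout, using the same permutation as convert_weights.hf_to_tune.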
        return (
            t.view(n_heads, 2, head_dim // 2, dim)
            .transpose(1, 2)
            .reshape((head_dim * n_heads), dim)
        )

    for key, value in state_dict.items():
        if (
            "rotary_emb.inv_freq" not in key and "lm_head.weight" not in key
        ):  # Skip loading the position embeddings and output projection weights
            new_key = get_mapped_key(key, _FROM_HF)
            if "q_proj" in key:
                value = _permute(value, num_heads)
            elif "k_proj" in key:
                value = _permute(value, num_kv_heads)
            converted_state_dict[new_key] = value
    return converted_state_dict

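# Illustrative usage sketch (not part of this file): convert a consolidated HF
# Gemma state dict before loading it into a TorchTune model. The names
# "hf_state_dict" and "model" are placeholders; the keyword arguments repeat
# the function defaults shown above.
#
#     tune_sd = gemma_hf_to_tune(
#         hf_state_dict, num_heads=8, num_kv_heads=1, dim=2048, head_dim=256
#     )
#     model.load_state_dict(tune_sd)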

def gemma_tune_to_hf(
    state_dict: Dict[str, torch.Tensor],
    num_heads: int = 8,
    num_kv_heads: int = 1,
    dim: int = 2048,
    head_dim: int = 256,
) -> Dict[str, torch.Tensor]:
    """
    Convert the state dict of a Gemma model from TorchTune's format to Hugging Face's format.
    This function takes a state dict in TorchTune's format, which contains the weights of a
    Gemma model, and converts it into a format that can be loaded into a Hugging Face model.
    The logic is identical to :func:`~torchtune.models.convert_weights.tune_to_hf`, but also
    saves the tied output projection weights.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in TorchTune's format.
        num_heads (int, optional): Number of heads in the model. Defaults to 8.
        num_kv_heads (int, optional): Number of heads in the key/value projection layers. Defaults to 1.
        dim (int, optional): Dimension of the model. Defaults to 2048.
        head_dim (int): Dimension of the attention head. This value is explicit in Gemma configs. Defaults to 256.

    Returns:
        Dict[str, torch.Tensor]: State dict in Hugging Face's format.
    """
    converted_state_dict = {}
    inverted_mapping_dict = {v: k for k, v in _FROM_HF.items()}

    def _permute(t, n_heads):
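        # Inverse of the permutation applied in gemma_hf_to_tune: reorder the
        # q/k projection rows back into HF's rotary (RoPE) layout.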
        return (
            t.view(n_heads, head_dim // 2, 2, dim)
            .transpose(1, 2)
            .reshape((head_dim * n_heads), dim)
        )

    for key, value in state_dict.items():
        new_key = get_mapped_key(key, inverted_mapping_dict)
        if "q_proj" in key:
            value = _permute(value, num_heads)
        elif "k_proj" in key:
            value = _permute(value, num_kv_heads)
        elif "tok_embeddings" in key:
            # HF also uses tied weights, see
            # https://github.com/huggingface/transformers/blob/14ff5dd962c1bd0a4e3adaac347ba396d8df5add/src/transformers/models/gemma/convert_gemma_weights_to_hf.py#L104
            converted_state_dict["lm_head.weight"] = value
        converted_state_dict[new_key] = value
    return converted_state_dict

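# Illustrative usage sketch (not part of this file): convert TorchTune weights
# back to HF format before writing them out. "tune_state_dict" and
# "output_path" are placeholders.
#
#     hf_sd = gemma_tune_to_hf(
#         tune_state_dict, num_heads=8, num_kv_heads=1, dim=2048, head_dim=256
#     )
#     torch.save(hf_sd, output_path)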

def load_hf_checkpoint(self) -> Dict[str, Any]:
    self._weight_map = {}

@@ -101,27 +198,27 @@ def load_hf_checkpoint(self) -> Dict[str, Any]:
        del state_dict
        gc.collect()

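    # Gemma ties its output projection to the token embeddings, so route it
    # through the Gemma-specific converter above (which skips lm_head.weight);
    # all other model types go through the generic converter.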
    if self._model_type == ModelType.GEMMA:
        converted_state_dict[utils.MODEL_KEY] = gemma_hf_to_tune(
            merged_state_dict,
            num_heads=self._config["num_attention_heads"],
            num_kv_heads=self._config["num_key_value_heads"],
            dim=self._config["hidden_size"],
            head_dim=self._config["head_dim"],
        )
    else:
        converted_state_dict[utils.MODEL_KEY] = convert_weights.hf_to_tune(
            merged_state_dict,
            num_heads=self._config["num_attention_heads"],
            num_kv_heads=self._config["num_key_value_heads"],
            dim=self._config["hidden_size"],
            head_dim=self._config.get("head_dim", None),
        )

    if self._resume_from_checkpoint:
        recipe_state = safe_torch_load(self._recipe_checkpoint, mmap=False)
        converted_state_dict.update(recipe_state)
    return converted_state_dict


def save_hf_checkpoint(
