load gemma checkpoint
iankur committed Sep 13, 2024
1 parent bd08400 commit 77dfe74
Showing 3 changed files with 79 additions and 8 deletions.
2 changes: 1 addition & 1 deletion recipes/run_vq_model_ablation.sh
@@ -134,5 +134,5 @@ tune run recipes/eleuther_eval.py \
   $VQ_KEY=True $VQ_VALUE=True $FAST_QUANTIZER=True \
   $RESIDUAL_CODEBOOKS=$K $CODES=$C $CODE_DIM=$dhat \
   $REORDER_CHANNEL=True checkpointer.checkpoint_dir=$CKPT_DIR \
-  checkpointer.checkpoint_files=[hf_model_0001_0.pt,hf_model_0002_0.pt,hf_model_0003_0.pt] \
+  checkpointer.checkpoint_files=[hf_model_0001_0.pt,hf_model_0002_0.pt,hf_model_0003_0.pt,hf_model_0004_0.pt,hf_model_0005_0.pt] \
   $WANDB_PROJECT $WANDB_GROUP $WANDB_NAME="eval_vq_gemma_7b"
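Note: the shard list above is hardcoded, and this commit grows it from three to five files for the gemma-7b checkpoint. As a convenience sketch (not part of this commit), the list could instead be derived from the checkpoint directory; list_hf_shards is a hypothetical helper and assumes the hf_model_XXXX_0.pt naming used in the script:

from pathlib import Path
from typing import List

def list_hf_shards(ckpt_dir: str) -> List[str]:
    # Assumption: shards follow the hf_model_{idx:04}_0.pt naming used above.
    return sorted(p.name for p in Path(ckpt_dir).glob("hf_model_*_0.pt"))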
4 changes: 4 additions & 0 deletions vqllm/models/_modules.py
@@ -383,3 +383,7 @@ def forward(
         if return_vq_loss:
             return output, commitment_loss
         return output
+
+    def update_quantizer_weight(self):
+        for layer in self.layers:
+            layer.update_quantizer_weight()
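For context, the new update_quantizer_weight is a plain fan-out over the module's layers. A minimal self-contained sketch of the pattern, where VQLayer and its no-op body are illustrative placeholders rather than the repo's actual quantizer update:

import torch.nn as nn

class VQLayer(nn.Module):
    """Placeholder layer: the real per-layer update lives in vqllm's modules."""

    def update_quantizer_weight(self) -> None:
        # Assumption: e.g. copy EMA codebook statistics into the live codebook.
        pass

class VQModel(nn.Module):
    def __init__(self, num_layers: int) -> None:
        super().__init__()
        self.layers = nn.ModuleList([VQLayer() for _ in range(num_layers)])

    def update_quantizer_weight(self) -> None:
        # Mirrors the method added above: delegate to every layer.
        for layer in self.layers:
            layer.update_quantizer_weight()

model = VQModel(num_layers=2)
model.update_quantizer_weight()  # no-op with placeholder layers; shows the call path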
81 changes: 74 additions & 7 deletions vqllm/utils/checkpointer.py
@@ -7,6 +7,11 @@
 import torchtune
 from torchtune import utils
 from torchtune.models import convert_weights
+from torchtune.models.gemma import gemma_hf_to_tune, gemma_tune_to_hf
+from torchtune.utils._checkpointing._checkpointer_utils import (
+    ModelType,
+    safe_torch_load,
+)
 from torchtune.utils.logging import get_logger
 
 logger = get_logger("DEBUG")
@@ -67,19 +72,80 @@ def get_mapped_key(key: str, mapping_dict: Dict[str, str]) -> str:
     return new_key
 
 
+def load_hf_checkpoint(self) -> Dict[str, Any]:
+    self._weight_map = {}
+
+    # merged state_dict contains keys and weights from all the checkpoint files
+    merged_state_dict: Dict[str, torch.Tensor] = {}
+
+    # converted_state_dict is the final state_dict passed to the recipe after the
+    # keys are converted into the torchtune format. This optionally also contains
+    # the recipe state and adapter weights
+    converted_state_dict: Dict[str, Dict[str, torch.Tensor]] = {}
+
+    for cpt_idx, cpt_path in enumerate(self._checkpoint_paths):
+        state_dict = safe_torch_load(cpt_path)
+        for key, value in state_dict.items():
+            # Ensure that the state dict is a flat dict of keys and tensors. Breaking this assumption
+            # will break recipe code
+            if not isinstance(value, torch.Tensor):
+                raise ValueError(
+                    f"Expected all values in the state dict to be torch.Tensor. "
+                    f"Found {type(value)} instead."
+                )
+            # idx is written in the 4 digit format (eg: 0001, 0002, etc.)
+            self._weight_map[key] = f"{cpt_idx+1:04}"
+        merged_state_dict.update(state_dict)
+
+        # delete the state_dict to free up memory; TODO check if this del is needed
+        del state_dict
+        gc.collect()
+
+    if self._model_type == ModelType.GEMMA:
+        converted_state_dict[utils.MODEL_KEY] = gemma_hf_to_tune(
+            merged_state_dict,
+            num_heads=self._config["num_attention_heads"],
+            num_kv_heads=self._config["num_key_value_heads"],
+            dim=self._config["hidden_size"],
+            head_dim=self._config["head_dim"],
+        )
+    else:
+        converted_state_dict[utils.MODEL_KEY] = convert_weights.hf_to_tune(
+            merged_state_dict,
+            num_heads=self._config["num_attention_heads"],
+            num_kv_heads=self._config["num_key_value_heads"],
+            dim=self._config["hidden_size"],
+            head_dim=self._config.get("head_dim", None),
+        )
+
+    if self._resume_from_checkpoint:
+        recipe_state = safe_torch_load(self._recipe_checkpoint, mmap=False)
+        converted_state_dict.update(recipe_state)
+    return converted_state_dict
+
+
 def save_hf_checkpoint(
     self,
     state_dict: Dict[str, Any],
     epoch: int,
     intermediate_checkpoint: bool = False,
 ) -> None:
-    state_dict[utils.MODEL_KEY] = convert_weights.tune_to_hf(
-        state_dict[utils.MODEL_KEY],
-        num_heads=self._config["num_attention_heads"],
-        num_kv_heads=self._config["num_key_value_heads"],
-        dim=self._config["hidden_size"],
-        head_dim=self._config["head_dim"],
-    )
+    if self._model_type == ModelType.GEMMA:
+        state_dict[utils.MODEL_KEY] = gemma_tune_to_hf(
+            state_dict[utils.MODEL_KEY],
+            num_heads=self._config["num_attention_heads"],
+            num_kv_heads=self._config["num_key_value_heads"],
+            dim=self._config["hidden_size"],
+            head_dim=self._config["head_dim"],
+        )
+    else:
+        state_dict[utils.MODEL_KEY] = convert_weights.tune_to_hf(
+            state_dict[utils.MODEL_KEY],
+            num_heads=self._config["num_attention_heads"],
+            num_kv_heads=self._config["num_key_value_heads"],
+            dim=self._config["hidden_size"],
+            head_dim=self._config["head_dim"],
+        )
     self._output_dir.mkdir(exist_ok=True)
 
     # split the state_dict into separate dicts, one for each output checkpoint file
@@ -109,5 +175,6 @@ def save_hf_checkpoint(
 torchtune.models.convert_weights._FROM_META = _FROM_META
 torchtune.models.convert_weights.get_mapped_key = get_mapped_key
 FullModelHFCheckpointer = torchtune.utils._checkpointing.FullModelHFCheckpointer
+FullModelHFCheckpointer.load_checkpoint = load_hf_checkpoint
 FullModelHFCheckpointer.save_checkpoint = save_hf_checkpoint
 FullModelMetaCheckpointer = torchtune.utils._checkpointing.FullModelMetaCheckpointer
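Taken together: importing this module patches torchtune's FullModelHFCheckpointer so Gemma checkpoints round-trip through gemma_hf_to_tune / gemma_tune_to_hf, while other HF models keep the generic conversion. A hedged usage sketch; the paths are placeholders and the constructor arguments are assumptions based on the torchtune APIs this commit imports:

# Usage sketch: importing vqllm.utils.checkpointer applies the patch above.
import vqllm.utils.checkpointer  # noqa: F401  (side effect: patches load/save)
from torchtune.utils._checkpointing import FullModelHFCheckpointer
from torchtune.utils._checkpointing._checkpointer_utils import ModelType

checkpointer = FullModelHFCheckpointer(
    checkpoint_dir="/path/to/gemma-7b",  # placeholder
    checkpoint_files=[f"hf_model_{i:04}_0.pt" for i in range(1, 6)],
    model_type=ModelType.GEMMA,
    output_dir="/tmp/vq_gemma_out",  # placeholder
)
state = checkpointer.load_checkpoint()  # runs the patched load_hf_checkpoint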
