Add torchtune converter and README changes (pytorch#444)
* Add torchtune converter

* Update
kartikayk authored and malfet committed Jul 17, 2024
1 parent a99dc12 commit c44ca50
Showing 2 changed files with 210 additions and 0 deletions.
31 changes: 31 additions & 0 deletions README.md
@@ -267,6 +267,37 @@ Read the [iOS documentation](docs/iOS.md) for more details on iOS.

Read the [Android documentation](docs/Android.md) for more details on Android.

## Fine-tuned models from torchtune

torchchat supports running inference with models fine-tuned using [torchtune](https://github.com/pytorch/torchtune). To do so, we first need to convert the checkpoints into a format supported by torchchat.

Below is a simple workflow for running inference on a fine-tuned Llama3 model. For more details on how to fine-tune Llama3, see the instructions [here](https://github.com/pytorch/torchtune?tab=readme-ov-file#llama3).

```bash
# Install torchtune
pip install torchtune

# Download the Llama3 model
tune download meta-llama/Meta-Llama-3-8B \
  --output-dir ./Meta-Llama-3-8B \
  --hf-token <ACCESS TOKEN>

# Run LoRA fine-tuning on a single device. This assumes the config points to <checkpoint_dir> above
tune run lora_finetune_single_device --config llama3/8B_lora_single_device

# Convert the fine-tuned checkpoint to a format compatible with torchchat
python3 build/convert_torchtune_checkpoint.py \
  --checkpoint-dir ./Meta-Llama-3-8B \
  --checkpoint-files meta_model_0.pt \
  --model-name llama3_8B \
  --checkpoint-format meta

# Run inference on a single GPU
python3 torchchat.py generate \
  --checkpoint-path ./Meta-Llama-3-8B/model.pth \
  --device cuda
```
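
If the fine-tune used torchtune's HF-format checkpointer instead, the same converter can merge and convert the shards: pass every shard to `--checkpoint-files`, set `--checkpoint-format hf`, and give `--model-name` so the attention weights can be permuted and fused correctly. The invocation below is a sketch; the shard file names are illustrative and will vary by run.

```bash
# Convert HF-format shards (file names are illustrative)
python3 build/convert_torchtune_checkpoint.py \
  --checkpoint-dir ./Meta-Llama-3-8B \
  --checkpoint-files hf_model_0001_0.pt hf_model_0002_0.pt \
  --model-name llama3_8B \
  --checkpoint-format hf
```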

## Acknowledgements
Thank you to the [community](docs/ACKNOWLEDGEMENTS.md) for all the awesome libraries and tools
you've built around local LLM inference.
179 changes: 179 additions & 0 deletions build/convert_torchtune_checkpoint.py
@@ -0,0 +1,179 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import re
import sys
from pathlib import Path
from typing import Dict, List, Optional

import torch

# support running without installing as a package
wd = Path(__file__).parent.parent
sys.path.append(str(wd.resolve()))
sys.path.append(str((wd / "build").resolve()))

logger = logging.getLogger(__name__)

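# Attention geometry for each supported model; these values are passed to
# from_hf() to un-permute and fuse the HF attention projections.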
MODEL_CONFIGS = {
    "llama2_7B": {"num_heads": 32, "num_kv_heads": 32, "dim": 4096},
    "llama3_8B": {"num_heads": 32, "num_kv_heads": 8, "dim": 4096},
}

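# Mapping from HF parameter names to the torchchat naming scheme; a value of
# None means the tensor is dropped during conversion.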
WEIGHT_MAP = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
    "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
    "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
    "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
    "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
    "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
    "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
    "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
    "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
    "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
    "model.norm.weight": "norm.weight",
    "lm_head.weight": "output.weight",
}


def from_hf(
    merged_result: Dict[str, torch.Tensor],
    num_heads: int = 32,
    num_kv_heads: int = 32,
    dim: int = 4096,
) -> Dict[str, torch.Tensor]:
    """
    Utility function which converts the given state_dict from the HF format
    to one that is compatible with torchchat. The HF format involves a
    permutation of the query and key tensors, and reversing it requires
    additional arguments such as num_heads, num_kv_heads and dim.
    """

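    # HF checkpoints store the q/k projections permuted for HF's
    # rotary-embedding implementation; permute() reverses that permutation to
    # recover the layout torchchat expects.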
    def permute(w, n_heads):
        head_dim = dim // n_heads
        return (
            w.view(n_heads, 2, head_dim // 2, dim)
            .transpose(1, 2)
            .reshape(head_dim * n_heads, dim)
        )

    # Replace the keys with the version compatible with torchchat
    final_result = {}
    for key, value in merged_result.items():
        if "layers" in key:
            abstract_key = re.sub(r"(\d+)", "{}", key)
            layer_num = re.search(r"\d+", key).group(0)
            new_key = WEIGHT_MAP[abstract_key]
            if new_key is None:
                continue
            new_key = new_key.format(layer_num)
        else:
            new_key = WEIGHT_MAP[key]

        final_result[new_key] = value

    # torchchat expects a fused q, k and v matrix
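    # For example, with llama3_8B (dim=4096, 32 query heads, 8 KV heads, so
    # head_dim=128): wq is [4096, 4096], wk and wv are [1024, 4096], and the
    # fused wqkv is [6144, 4096].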
    for key in tuple(final_result.keys()):
        if "wq" in key:
            q = final_result[key]
            k = final_result[key.replace("wq", "wk")]
            v = final_result[key.replace("wq", "wv")]
            q = permute(q, num_heads)
            k = permute(k, num_kv_heads)
            final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
            del final_result[key]
            del final_result[key.replace("wq", "wk")]
            del final_result[key.replace("wq", "wv")]
    return final_result


@torch.inference_mode()
def convert_torchtune_checkpoint(
    *,
    checkpoint_dir: Path,
    checkpoint_files: List[str],
    checkpoint_format: str,
    model_name: Optional[str] = None,
) -> None:

    # Sanity-check all of the params
    if not checkpoint_dir.is_dir():
        raise RuntimeError(f"{checkpoint_dir} is not a directory")

    if len(checkpoint_files) == 0:
        raise RuntimeError("No checkpoint files provided")

    for file in checkpoint_files:
        if not (Path.joinpath(checkpoint_dir, file)).is_file():
            raise RuntimeError(f"{checkpoint_dir / file} is not a file")

    # If the model is already in meta format, simply rename it
    if checkpoint_format == 'meta':
        if len(checkpoint_files) > 1:
            raise RuntimeError("Multiple meta format checkpoint files not supported")

        checkpoint_path = Path.joinpath(checkpoint_dir, checkpoint_files[0])
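        # Load the checkpoint once to verify that it deserializes cleanly,
        # then free it immediately; the file itself is simply renamed below.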
        loaded_result = torch.load(
            checkpoint_path, map_location="cpu", mmap=True, weights_only=True
        )
        del loaded_result

        os.rename(checkpoint_path, Path.joinpath(checkpoint_dir, "model.pth"))

    # If the model is in HF format, merge all of the checkpoints and then convert
    elif checkpoint_format == 'hf':
        if model_name is None:
            raise RuntimeError("--model-name is required for HF-format checkpoints")

        merged_result = {}
        for file in checkpoint_files:
            state_dict = torch.load(
                Path.joinpath(checkpoint_dir, file), map_location="cpu", mmap=True, weights_only=True
            )
            merged_result.update(state_dict)

        model_config = MODEL_CONFIGS[model_name]
        final_result = from_hf(merged_result, **model_config)

        print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}. This may take a while.")
        torch.save(final_result, Path.joinpath(checkpoint_dir, "model.pth"))
        print("Done.")



if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Convert torchtune checkpoint.")
    parser.add_argument(
        "--checkpoint-dir",
        type=Path,
        required=True,
        help="Directory containing the checkpoint files",
    )
    parser.add_argument(
        "--checkpoint-files",
        nargs='+',
        required=True,
        help="One or more checkpoint files (HF-format checkpoints can be sharded across several files)",
    )
    parser.add_argument(
        "--checkpoint-format",
        type=str,
        required=True,
        choices=['meta', 'hf'],
        help="Format of the checkpoint saved by the torchtune checkpointer",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        choices=['llama2_7B', 'llama3_8B'],
        help="Model name; required when --checkpoint-format is 'hf'",
    )

    args = parser.parse_args()
    convert_torchtune_checkpoint(
        checkpoint_dir=args.checkpoint_dir,
        checkpoint_files=args.checkpoint_files,
        checkpoint_format=args.checkpoint_format,
        model_name=args.model_name,
    )
