Add support for Llama3-70b (#101)
* Add support for Llama3-70b

* Fix unit tests

* assert model_name is one of llama-2 or llama-3 for weight sharding

* Fix lint

* Revert separate shardings for llama-2 and llama-3

* Fix lint
bhavya01 authored Jun 10, 2024 · 1 parent e07aee6 · commit 4535bdf
Showing 4 changed files with 36 additions and 8 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -98,6 +98,11 @@ python run_interactive.py --size=13b --model_name=$model_name --batch_size=64 --
 python run_interactive.py --size=8b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
 ```
 
+## Llama-3 70b
+```bash
+python run_interactive.py --size=70b --model_name=$model_name --batch_size=8 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
+```
+
 ## Gemma 7b
 ```bash
 python run_interactive.py --model_name=$model_name --size=7b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
11 changes: 6 additions & 5 deletions convert_checkpoints.py
@@ -186,20 +186,21 @@ def _merge_llama_weights(
         f"{len(tensors)} shards (shape = {tensors[0].shape}) for {key})"
     )
     state_dict_for_key = {}
-    weight_sharding_type = (
-        llama_model.Transformer.get_weight_sharding_type().items()
-    )
+
+    weight_sharding_type = llama_model.Transformer.get_weight_sharding_type(
+        model_name=FLAGS.model_name
+    ).items()
     for pattern, kind in weight_sharding_type:
       if not key.endswith(pattern):
         continue
       with torch.no_grad():
         if kind in ("ParallelEmbedding", "RowParallelLinear"):
           state_dict_for_key[key] = torch.cat(tensors, 1)
-        elif kind == "ColumnParallelLinear":
+        elif kind in ("ColumnParallelLinear", "VocabParallelEmbedding"):
           state_dict_for_key[key] = torch.cat(tensors, 0)
         else:
           if not all(
-              torch.allclose(tensors[0], tensor, atol=1e-6)
+              torch.allclose(tensors[0], tensor, atol=1e-2)
               for tensor in tensors[1:]
          ):
            raise ValueError(
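For orientation, here is a minimal standalone sketch of the per-key merge rule this hunk implements: shards of column-partitioned weights are concatenated along dim 1, row-partitioned weights along dim 0, and replicated weights are checked for (near-)equality. The `merge_key` function and the `shards` argument are illustrative names, not part of the repository.

```python
import torch


def merge_key(key, shards, kind, atol=1e-2):
  """Illustrative re-statement of the merge rule in _merge_llama_weights."""
  tensors = [shard[key] for shard in shards]
  if kind in ("ParallelEmbedding", "RowParallelLinear"):
    # Column-partitioned on disk: stitch shards back together along dim 1.
    return torch.cat(tensors, 1)
  if kind in ("ColumnParallelLinear", "VocabParallelEmbedding"):
    # Row-partitioned on disk: stitch shards back together along dim 0.
    return torch.cat(tensors, 0)
  # Unsharded (replicated) weights: every shard should hold the same values.
  if not all(torch.allclose(tensors[0], t, atol=atol) for t in tensors[1:]):
    raise ValueError(f"Replicated tensor {key} differs across shards")
  return tensors[0]
```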
13 changes: 13 additions & 0 deletions jetstream_pt/third_party/llama/model_args.py
@@ -90,6 +90,19 @@ def get_arg(
         "norm_eps": 1e-05,
         "rope_theta": 500000.0,
     }
+  elif model_name == "llama-3-70b":
+    data = {
+        "dim": 8192,
+        "ffn_dim_multiplier": 1.3,
+        "multiple_of": 4096,
+        "n_heads": 64,
+        "n_kv_heads": 8,
+        "n_layers": 80,
+        "norm_eps": 1e-05,
+        "vocab_size": 128256,
+        "rope_theta": 500000.0,
+    }
+
   return ModelArgs(
       max_seq_len=seqlen,
       max_batch_size=batch_size,
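For context, the feed-forward width these numbers imply can be worked out with the standard Llama sizing rule. This is an assumption about how `ffn_dim_multiplier` and `multiple_of` are consumed downstream, not code from this repository, but it reproduces the commonly cited 28672 FFN width for Llama-3-70B.

```python
# Standard Llama FFN sizing rule (assumed), applied to the llama-3-70b values.
dim, ffn_dim_multiplier, multiple_of = 8192, 1.3, 4096

hidden_dim = 4 * dim                                # 32768
hidden_dim = int(2 * hidden_dim / 3)                # 21845
hidden_dim = int(ffn_dim_multiplier * hidden_dim)   # 28398
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
print(hidden_dim)                                   # 28672
```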
15 changes: 12 additions & 3 deletions jetstream_pt/third_party/llama/model_exportable.py
@@ -259,13 +259,17 @@ def get_quantized_embedding_weight_to_scaler_map():
     }
 
   @staticmethod
-  def get_weight_sharding_type():
+  def get_weight_sharding_type(model_name: str = ""):
     # ParallelEmbedding is col partitioned across the shards.
+    # VocabParallelEmbedding is row partitioned across the shards.
     # ColumnParallelLinear is row partitioned across shards due to transpose.
     # RowParallelLinear is col partitioned across shards due to transpose.
     # None is no partitioning and tensor should be identical across shards
-    return {
-        "tok_embeddings.weight": "ParallelEmbedding",
+    expected_model_names = ("llama-2", "llama-3")
+    assert (
+        model_name in expected_model_names
+    ), f"Expected model_name to be one of {expected_model_names}"
+    sharding_dict = {
         "rope.freqs": None,
         "attention.wq.weight": "ColumnParallelLinear",
         "attention.wk.weight": "ColumnParallelLinear",
@@ -279,3 +283,8 @@ def get_weight_sharding_type():
         "norm.weight": None,
         "output.weight": "ColumnParallelLinear",
     }
+    if model_name == "llama-2":
+      sharding_dict["tok_embeddings.weight"] = "ParallelEmbedding"
+    elif model_name == "llama-3":
+      sharding_dict["tok_embeddings.weight"] = "VocabParallelEmbedding"
+    return sharding_dict
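A quick usage sketch of the new signature. The import path is assumed from the file layout, and the call site is illustrative rather than code from this repository; "gemma" stands in for any name other than llama-2 or llama-3.

```python
from jetstream_pt.third_party.llama import model_exportable as llama_model

# llama-3 maps the token embedding to VocabParallelEmbedding, so its merged
# weight is concatenated along dim 0 (the vocab axis) rather than dim 1.
sharding = llama_model.Transformer.get_weight_sharding_type(model_name="llama-3")
assert sharding["tok_embeddings.weight"] == "VocabParallelEmbedding"

# Any other model name trips the new assertion.
try:
  llama_model.Transformer.get_weight_sharding_type(model_name="gemma")
except AssertionError as err:
  print(err)
```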
