Mixtral enablement. #120

Merged: 17 commits, Jun 11, 2024
Changes shown from 13 commits
16 changes: 15 additions & 1 deletion README.md
@@ -64,12 +64,21 @@ huggingface-cli download google/gemma-7b-pytorch --local-dir $input_ckpt_dir

You need to manually edit the `config.json` in the checkpoint folder to make it valid JSON: replace `'` with `"` and remove the trailing `,` after the last item in each JSON object.
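For example, a minimal cleanup sketch (assuming the only problems are single quotes and trailing commas, and that `input_ckpt_dir` is exported as above):

```python
# Sketch: normalize a config.json that uses single quotes and trailing commas.
import json
import os
import re
from pathlib import Path

path = Path(os.environ["input_ckpt_dir"]) / "config.json"
text = path.read_text()
text = text.replace("'", '"')               # JSON requires double quotes
text = re.sub(r",(\s*[}\]])", r"\1", text)  # drop trailing commas before } or ]
json.loads(text)                            # sanity check: raises if still invalid
path.write_text(text)
```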

## Mixtral
### Get Mixtral Checkpoint from HuggingFace

Please sign the agreement on the Hugging Face website to access the Mixtral checkpoints, then download the Mixtral PyTorch checkpoint with `huggingface-cli`. The Mixtral tokenizer is included in the checkpoint.

```bash
huggingface-cli download mistralai/Mixtral-8x7B-v0.1 --local-dir $input_ckpt_dir
```

## Run weight safetensor convert

```bash
export input_ckpt_dir=Original llama weights directory
export output_ckpt_dir=The output directory
export model_name="llama-3" # or "llama-2", "gemma"
export model_name="llama-3" # or "llama-2", "gemma", "mistral"
Collaborator commented: change this to mixtral

Collaborator (author) replied: Thanks. I was confused about the name initially, which is why there is a mix of mistral and mixtral. I have also changed everything to Mixtral. Done.

export quantize_weights=True # Whether to quantize weights
export quantize_type="int8_per_channel" # "quantize_weights" needs to be turned on. Available quantize types: {"int8", "int4"} x {"per_channel", "blockwise"}; "int8_per_channel" is the default if not specified.
python -m convert_checkpoints --model_name=$model_name --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize_type=$quantize_type
@@ -108,6 +117,11 @@ python run_interactive.py --size=70b --model_name=$model_name --batch_size=8 --m
python run_interactive.py --model_name=$model_name --size=7b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
```

## Mixtral 8x7B

Use the same environment variables as in the conversion step above, with `model_name="mistral"`.
```bash
python run_interactive.py --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
```


# Run the server
Here is an example to run the server with llama2 7B config.
89 changes: 89 additions & 0 deletions convert_checkpoints.py
@@ -26,6 +26,7 @@
import hashlib
import json
import os
import re
import time

import torch
@@ -37,6 +38,8 @@
from jetstream_pt.config import FLAGS
from jetstream_pt.third_party.gemma import model as gemma_model
from jetstream_pt.third_party.llama import model_exportable as llama_model
from jetstream_pt.third_party.mistral import model as mistral_model, config as mistral_config

from safetensors import safe_open
from safetensors.torch import save_file

@@ -123,6 +126,10 @@ def _quantize_state_dict(
block_size = orig_block_size
n_bit = orig_n_bit
state_dict.update(updated_weights)
for k, v in state_dict.items():
if "layers" in k and "layers.0" not in k:
continue
print(f"After quantization the converted key: {k} and value: {v.shape} {v.dtype}")
return state_dict


@@ -462,6 +469,80 @@ def _get_gemma_state_dict(input_ckpt_dir):
return state_dict, model_config


def _get_mistral_state_dict(input_ckpt_dir):
ckpt_files = list(input_ckpt_dir.glob("*.pt"))
assert len(ckpt_files) == 8, "only expect 8 ckpt file for Mistral model."

start = time.perf_counter()
state_dict = {}
for file in sorted(ckpt_files):
ckpt = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
state_dict.update(ckpt)
end = time.perf_counter()
print(f"Loading checkpoints takes {end - start} seconds")

for k, v in state_dict.items():
if "layers" in k and "layers.0" not in k:
continue
print(f"The loaded key: {k} and value: {v.shape} {v.dtype}")

config = json.loads((input_ckpt_dir / "config.json").read_text())
print(f"Loaded config: {config}")
#config = mistral_config.ModelArgs.from_name("Mixtral-8x7B-v0.1")
weight_map = {
"tok_embeddings.weight": "tok_embeddings.weight",
"layers.{}.attention.wq.weight": "layers.{}.attention.wq.weight",
"layers.{}.attention.wk.weight": "layers.{}.attention.wk.weight",
"layers.{}.attention.wv.weight": "layers.{}.attention.wv.weight",
"layers.{}.attention.wo.weight": "layers.{}.attention.wo.weight",
"layers.{}.block_sparse_moe.w1": "layers.{}.block_sparse_moe.cond_ffn.w1",
"layers.{}.block_sparse_moe.w2": "layers.{}.block_sparse_moe.cond_ffn.w2",
"layers.{}.block_sparse_moe.w3": "layers.{}.block_sparse_moe.cond_ffn.w3",
"layers.{}.block_sparse_moe.gate.weight": "layers.{}.block_sparse_moe.gate.weight",
"layers.{}.attention_norm.weight": "layers.{}.attention_norm.weight",
"layers.{}.ffn_norm.weight": "layers.{}.ffn_norm.weight",
"norm.weight": "norm.weight",
"output.weight": "output.weight",
}
for key in list(state_dict.keys()):
if state_dict[key].dtype.is_complex and _OUTPUT_SAFETENSORS.value:
assert (
key == "freqs_cis"
), "Only expect key 'freqs_cis' in the state_dict has complex dtype."
# Remove "freqs_cis" since it has complex dtype, and safetensor doesn't support it.
# The "freqs_cis" will be reconstructed when it's loaded by inference engine.
state_dict.pop(key)
continue
prefix_to_remove = "model."
new_key = key
if key.startswith(prefix_to_remove):
new_key = new_key.removeprefix(prefix_to_remove)

if "layers" in key:
abstract_key = re.sub(r'.(\d+).', '.{}.', key)
layer_num = re.search(r'\d+', key).group(0)
new_key = weight_map[abstract_key]
new_key = new_key.format(layer_num)
if new_key is None:
continue

if new_key == key:
continue

if "w1" in key or "w3" in key:
state_dict[new_key] = state_dict.pop(key).reshape(config["num_local_experts"], config["intermediate_size"], config["hidden_size"]).contiguous()
elif "w2" in key:
state_dict[new_key] = state_dict.pop(key).reshape(config["num_local_experts"], config["intermediate_size"], config["hidden_size"]).permute(0, 2, 1).contiguous()
elif "gate" in key:
state_dict[new_key] = state_dict.pop(key).contiguous()
else:
state_dict[new_key] = state_dict.pop(key)
for k, v in state_dict.items():
if "layers" in k and "layers.0" not in k:
continue
print(f"The converted key: {k} and value: {v.shape} {v.dtype}")
return state_dict, config

def main(argv) -> None:
"""merge weights"""

@@ -473,6 +554,14 @@ def main(argv) -> None:
quantize_embedding_weight_map = (
gemma_model.GemmaModel.get_quantized_embedding_weight_to_scaler_map()
)
elif FLAGS.model_name == "mistral":
state_dict, params = _get_mistral_state_dict(_INPUT_CHECKPOINT_DIR.value)
quantize_linear_weight_map = (
mistral_model.Transformer.get_quantized_linear_weight_to_scaler_map()
)
quantize_embedding_weight_map = (
mistral_model.Transformer.get_quantized_embedding_weight_to_scaler_map()
)
else:
state_dict, params = _get_llama_state_dict(_INPUT_CHECKPOINT_DIR.value)
quantize_linear_weight_map = (
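For reference, here is a shape-only sketch of the expert-weight reshapes in `_get_mistral_state_dict` above, with toy sizes standing in for Mixtral-8x7B's 8 experts, intermediate size 14336, and hidden size 4096. The only assumption is that the source tensors carry `num_experts * intermediate * hidden` elements, which the reshape in the diff implies:

```python
# Shape-only sketch of the MoE expert-weight regrouping (toy dimensions).
import torch

num_experts, intermediate, hidden = 8, 16, 4  # stand-ins for 8, 14336, 4096

# w1/w3: regrouped into one block per expert -> (num_experts, intermediate, hidden)
w1 = torch.randn(num_experts * intermediate, hidden)
w1 = w1.reshape(num_experts, intermediate, hidden).contiguous()

# w2: same regrouping, then a per-expert transpose -> (num_experts, hidden, intermediate)
w2 = torch.randn(num_experts * intermediate, hidden)
w2 = w2.reshape(num_experts, intermediate, hidden).permute(0, 2, 1).contiguous()

print(w1.shape)  # torch.Size([8, 16, 4])
print(w2.shape)  # torch.Size([8, 4, 16])
```

After the permute, each expert's w2 is laid out as (hidden, intermediate), so `cond_ffn.w2: 2` in `default_shardings/mistral.yaml` below appears to shard the same intermediate dimension that `cond_ffn.w1: 1` does.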
32 changes: 32 additions & 0 deletions default_shardings/mistral.yaml
@@ -0,0 +1,32 @@

# Sharding config for mixtral
# Sharding should either be an int between 0 and rank - 1
# signifying the axis to shard or -1 / null signifying replicated


freqs_cis : -1 # torch.complex64 (2048, 64)
tok_embeddings.weight : 1 # torch.float32 (vocab_size, 4096)
tok_embeddings.weight_scaler : 0 # torch.bfloat16 (4096,)
layers.*.attention.wo.weight : 1 # torch.int8 (4096, 4096)
layers.*.attention.wo.weight_scaler : -1 # torch.bfloat16 (4096,)
layers.*.attention.wq.weight : 0 # torch.int8 (4096, 4096)
layers.*.attention.wq.weight_scaler : 0 # torch.bfloat16 (4096,)
layers.*.attention.wk.weight : 0 # torch.int8 (4096, 4096)
layers.*.attention.wk.weight_scaler : 0 # torch.bfloat16 (4096,)
layers.*.attention.wv.weight : 0 # torch.int8 (4096, 4096)
layers.*.attention.wv.weight_scaler : 0 # torch.bfloat16 (4096,)
layers.*.attention.wqkv.weight : 0 # torch.int8 (4096, 4096)
layers.*.attention.wqkv.weight_scaler : 0 # torch.bfloat16 (4096,)
layers.*.block_sparse_moe.gate.weight: -1
layers.*.block_sparse_moe.gate.weight_scaler: -1
layers.*.block_sparse_moe.cond_ffn.w1: 1
layers.*.block_sparse_moe.cond_ffn.w1_scaler: 1
layers.*.block_sparse_moe.cond_ffn.w2: 2
layers.*.block_sparse_moe.cond_ffn.w2_scaler: -1
layers.*.block_sparse_moe.cond_ffn.w3: 1
layers.*.block_sparse_moe.cond_ffn.w3_scaler: 1
layers.*.ffn_norm.weight : -1 # torch.float32 (4096,)
layers.*.attention_norm.weight : -1 # torch.float32 (4096,)
norm.weight : -1 # torch.float32 (4096,)
output.weight : 0 # torch.float32 (vocab_size, 4096)
output.weight_scaler : 0 # torch.float32 (4096,)
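To make the axis convention concrete, here is a minimal JAX sketch (an integer picks the tensor axis that is split across the device mesh, while -1/null keeps a full copy on every device); it assumes the device count divides the chosen dimension:

```python
# Sketch: sharding one toy MoE weight along axis 1 vs. replicating it.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
w1 = jnp.zeros((8, 16, 64))  # toy (experts, intermediate, hidden)

sharded = jax.device_put(w1, NamedSharding(mesh, P(None, "x", None)))  # axis 1
replicated = jax.device_put(w1, NamedSharding(mesh, P()))              # -1 / null
print(sharded.sharding)
print(replicated.sharding)
```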
8 changes: 7 additions & 1 deletion jetstream_pt/config.py
@@ -108,7 +108,13 @@ def create_engine_from_config_flags():
sharding_file_name = FLAGS.sharding_config
if not sharding_file_name:
sharding_file_name = (
"llama" if FLAGS.model_name.startswith("llama") else "gemma"
"llama"
if FLAGS.model_name.startswith("llama")
else "gemma"
if FLAGS.model_name.startswith("gemma")
else "mistral"
if FLAGS.model_name.startswith("mistral")
else None
)
if (
quant_config.enable_weight_quantization
32 changes: 27 additions & 5 deletions jetstream_pt/engine.py
@@ -37,6 +37,7 @@
from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData, QuantizationConfig
from jetstream_pt.third_party.llama import model_exportable as llama_model, model_args
from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model
from jetstream_pt.third_party.mistral import config as mistral_config, model as mistral_model


Mesh = jax.sharding.Mesh
@@ -359,7 +360,6 @@ def _insert_wrap(

start_insert = decode_state.current_position - prefix.seq_len
tokens = decode_state.tokens.at[slot].set(prefix.token)

start_insert = start_insert % self.env.cache_sequence_length
# pos < 0
update_indexes = (
@@ -641,12 +641,17 @@ def _load_from_safetensors(self, path):
def _load_from_state_dict(self, path):
state_dict = torch.load(path, map_location=torch.device("cpu"))
weights = {}
print(f"Loaded keys are : {state_dict.keys()}")
for key, model_weights in self.pt_model.state_dict().items():
if key == "freqs_cis":
continue
assert key in state_dict, f"key: {key} not found"
weights[key] = torchjax.from_torch(state_dict[key])
weights[key] = torch_xla2.tensor.t2j(state_dict[key])
assert tuple(model_weights.shape) == tuple(
weights[key].shape
), f"key: {key} error: {model_weights.shape} != {weights[key].shape}"

weights["freqs_cis"] = torch_xla2.tensor.t2j(self.pt_model.freqs_cis)
return weights
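A note on the `freqs_cis` handling above: the conversion script drops `freqs_cis` because safetensors cannot store complex tensors, so the engine rebuilds it from the model at load time. Below is a standalone sketch of the usual rotary-embedding precomputation, assuming head_dim 128 (4096 / 32), sequence length 2048, and the rope_base of 1000000.0 from the Mixtral config later in this PR; the model's own construction may differ in details:

```python
# Sketch: precomputing complex rotary-embedding frequencies (freqs_cis).
import torch

head_dim, seq_len, rope_base = 128, 2048, 1000000.0
inv_freq = 1.0 / (rope_base ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(seq_len, dtype=torch.float32)
angles = torch.outer(t, inv_freq)                         # (2048, 64)
freqs_cis = torch.polar(torch.ones_like(angles), angles)  # complex64
print(freqs_cis.dtype, freqs_cis.shape)  # torch.complex64 torch.Size([2048, 64])
```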

# pylint: disable-next=all
@@ -760,7 +765,7 @@ def create_pytorch_engine(
) -> PyTorchEngine:
"""Returns: The pytorch engine."""

supported_models = ["llama-2", "llama-3", "gemma"]
supported_models = ["llama-2", "llama-3", "gemma", "mistral"]
if model_name not in supported_models:
raise NotImplementedError(
f"Model name should be one of{','.join(supported_models)}"
@@ -772,7 +777,6 @@
jax.config.update("jax_traceback_filtering", "off")
torch_dtype = torch.bfloat16 if bf16_enable else torch.float32
torch.set_default_dtype(torch_dtype)

checkpoint_format = ""
checkpoint_path = ""

@@ -797,8 +801,14 @@

pt_model = None

sharding_file_name = ""
if not sharding_config:
sharding_file_name = "llama" if model_name.startswith("llama") else "gemma"
if model_name.startswith("llama"):
sharding_file_name = "llama"
elif model_name.startswith("gemma"):
sharding_file_name = "gemma"
elif model_name.startswith("mistral"):
sharding_file_name = "mistral"
sharding_config = os.path.join(
"default_shardings", sharding_file_name + ".yaml"
)
@@ -851,6 +861,18 @@
env = JetEngineEnvironment(env_data)
print(f"Enviroment variables: {vars(env)}")
pt_model = gemma_model.GemmaModel(args, env)
elif model_name == "mistral":
args = mistral_config.ModelArgs.from_name("Mixtral-8x7B-v0.1")
args.device = "meta"
env_data.cache_shape = (
batch_size,
args.n_local_heads,
max_cache_length,
args.dim // args.n_head,
)
env_data.num_layers = args.n_layer
env = JetEngineEnvironment(env_data)
pt_model = mistral_model.Transformer(args, env)
else:
raise RuntimeError(f"Model with name {model_name} not found")

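For a sense of scale, the `cache_shape` computed in the Mixtral branch above, with the README's interactive example (batch_size=128, max_cache_length=2048) and the Mixtral-8x7B-v0.1 config (n_local_heads=8, dim=4096, n_head=32, n_layer=32), works out as below; the memory figure is only an estimate, assuming unquantized bf16 K and V caches of this shape per layer:

```python
# Sketch: per-layer KV-cache shape and a rough total-memory estimate for Mixtral-8x7B.
batch_size, max_cache_length = 128, 2048
n_local_heads, dim, n_head, n_layer = 8, 4096, 32, 32

cache_shape = (batch_size, n_local_heads, max_cache_length, dim // n_head)
print(cache_shape)  # (128, 8, 2048, 128)

bytes_per_elem = 2  # bf16
total = 2 * n_layer * bytes_per_elem * batch_size * n_local_heads * max_cache_length * (dim // n_head)
print(f"{total / 2**30:.1f} GiB")  # 32.0 GiB for K and V across all layers
```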
2 changes: 1 addition & 1 deletion jetstream_pt/third_party/llama/model_exportable.py
@@ -200,10 +200,10 @@ def forward(
):
"""
tokens: the input token for decoding
input_pos: the decoding position relative to the start, which is the length of the decoding results
caches: kv caches
mask: causal mask to filter the attention results
start: the starting position for each slot
input_pos: the decoding position relative to the start, which is the length of the decoding results
ragged_batch_index: precomputed batch index for ragged attention
ragged_block_index: precomputed block index for ragged attention
"""
Empty file.
78 changes: 78 additions & 0 deletions jetstream_pt/third_party/mistral/config.py
@@ -0,0 +1,78 @@
# pylint: disable-all
# # Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Mixtral model config
import dataclasses
from dataclasses import dataclass


def find_multiple(n: int, k: int) -> int:
if n % k == 0:
return n
return n + k - (n % k)


@dataclass
class ModelArgs:
block_size: int = 2048
vocab_size: int = 32000
n_layer: int = 32
n_head: int = 32
dim: int = 4096
intermediate_size: int = None
n_local_heads: int = -1
head_dim: int = 64
rope_base: float = 10000
norm_eps: float = 1e-5
num_experts: int = 8
num_activated_experts: int = 2
device: str = "meta"

def __post_init__(self):
if self.n_local_heads == -1:
self.n_local_heads = self.n_head
if self.intermediate_size is None:
hidden_dim = 4 * self.dim
n_hidden = int(2 * hidden_dim / 3)
self.intermediate_size = find_multiple(n_hidden, 256)
self.head_dim = self.dim // self.n_head

@classmethod
def from_name(cls, name: str):
if name in transformer_configs:
return cls(**transformer_configs[name])
# fuzzy search
config = [
config
for config in transformer_configs
if config in str(name).upper() or config in str(name)
]
assert len(config) == 1, name
return cls(**transformer_configs[config[0]])


transformer_configs = {
"Mixtral-8x7B-v0.1": dict(
block_size=32768,
n_layer=32,
n_head=32,
n_local_heads=8,
dim=4096,
intermediate_size=14336,
rope_base=1000000.0,
num_experts=8,
num_activated_experts=2,
),
}
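A small usage sketch for the config above, assuming the module is importable under the path used elsewhere in this PR:

```python
# Sketch: resolving the Mixtral-8x7B-v0.1 entry from transformer_configs.
from jetstream_pt.third_party.mistral import config as mistral_config

args = mistral_config.ModelArgs.from_name("Mixtral-8x7B-v0.1")
print(args.head_dim)           # 128 (4096 // 32, set in __post_init__)
print(args.intermediate_size)  # 14336, taken directly from transformer_configs
print(args.n_local_heads)      # 8 KV heads vs. 32 query heads
```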