install cli into a command
qihqi committed Aug 30, 2024
1 parent 321f5aa commit 56da057
Showing 5 changed files with 83 additions and 106 deletions.
129 changes: 43 additions & 86 deletions README.md
@@ -14,13 +14,9 @@ Commandline flags might have changed between the release version and HEAD.
1. Ssh to Cloud TPU VM (using v5e-8 TPU VM)
a. Create a Cloud TPU VM if you haven’t
2. Download jetstream-pytorch github repo
3. Clone repo and install dependencies
4. Download and convert weights
5. Run checkpoint converter (quantizer)
6. Local run
7. Run the server
8. Run benchmarks
9. Typical Errors
3. Run the server
4. Run benchmarks
5. Typical Errors

# Ssh to Cloud TPU VM (using v5e-8 TPU VM)

@@ -49,108 +45,69 @@ cd jetstream-pytorch
source install_everything.sh
```

# Download and convert weights

## LLaMA
### Get official llama weights from meta-llama
# Run jetstream pytorch

Follow the instructions here:
* Llama-2: https://github.com/meta-llama/llama#download
* Llama-3: https://github.com/meta-llama/llama3/#download
## List out supported models

Downloading the weights also fetches a `tokenizer.model` file, which is the tokenizer we will use.

## Gemma
### Get Gemma Checkpoint from HuggingFace

Please sign the agreement on the HuggingFace website to access Gemma checkpoints. Download the Gemma PyTorch checkpoint using huggingface-cli. The Gemma tokenizer is included in the checkpoint.

```bash
# Install huggingface-cli and login if it's not set up.
pip install -U "huggingface_hub[cli]"
huggingface-cli login
huggingface-cli download google/gemma-7b-pytorch --local-dir $input_ckpt_dir
```

## Mixtral
### Get Mixtral Checkpoint from HuggingFace

Please sign the agreement on the HuggingFace website to access Mixtral checkpoints. Download the Mixtral PyTorch checkpoint using huggingface-cli. The Mixtral tokenizer is included in the checkpoint.

```bash
huggingface-cli download mistralai/Mixtral-8x7B-v0.1 --local-dir $input_ckpt_dir
jpt list
```

## Run weight safetensor convert
This will print out the list of supported models and variants:

There is limited support (only Llama models as of now) for accessing checkpoints on GCS. Accessing GCS takes a long time, so storing checkpoints locally is recommended.

```bash
export input_ckpt_dir=Original llama weights directory
export output_ckpt_dir=The output directory
export model_name="llama-3" # or "llama-2", "gemma", "mixtral"
export quantize_weights=True # Whether to quantize weights
export quantize_type="int8_per_channel" # "quantize_weights" needs to be turned on. Available quantize types: {"int8", "int4"} x {"per_channel", "blockwise"}; "int8_per_channel" is the default if not specified.
python -m convert_checkpoints --model_name=$model_name --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize_type=$quantize_type --quantize_weights=$quantize_weights
```


# Local run

Set tokenizer path
```bash
export tokenizer_path=tokenizer model file path
```

```
meta-llama/Llama-2-7b-chat-hf
meta-llama/Llama-2-7b-hf
meta-llama/Llama-2-13b-chat-hf
meta-llama/Llama-2-13b-hf
meta-llama/Llama-2-70b-hf
meta-llama/Llama-2-70b-chat-hf
meta-llama/Meta-Llama-3-8B
meta-llama/Meta-Llama-3-8B-Instruct
meta-llama/Meta-Llama-3-70B
meta-llama/Meta-Llama-3-70B-Instruct
google/gemma-2b
google/gemma-2b-it
google/gemma-7b
google/gemma-7b-it
mistralai/Mixtral-8x7B-v0.1
mistralai/Mixtral-8x7B-Instruct-v0.1
```

## Llama-2 7b
```bash
python run_interactive.py --size=7b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
```
To run the jetstream-pytorch server with one model:

```bash
jpt serve --model_id meta-llama/Meta-Llama-3-8B-Instruct
```

## Llama-2 13b
```bash
python run_interactive.py --size=13b --model_name=$model_name --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
```

## Llama-3 8b
```bash
python run_interactive.py --size=8b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
```

## Llama-3 70b
```bash
python run_interactive.py --size=70b --model_name=$model_name --batch_size=8 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
```
If it is the first time you run this model, it will download weights from HuggingFace.

## Gemma 7b
```bash
python run_interactive.py --model_name=$model_name --size=7b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
```
HuggingFace's Llama 3 weights are gated, so you need to either run `huggingface-cli login` to set your token, or pass your HuggingFace token explicitly.

## Mixtral 8x7b
```bash
python run_interactive.py --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
```

To pass your HuggingFace token explicitly, add the `--hf_token` flag:
```bash
jpt serve --model_id meta-llama/Meta-Llama-3-8B-Instruct --hf_token=...
```

To login using huggingface hub, run:

```bash
pip install -U "huggingface_hub[cli]"
huggingface-cli login
```
Then follow its prompt. After the weights are downloaded, subsequent runs will no longer require `--hf_token`.

# Run the server
Here is an example to run the server with the llama2 7B config.

```bash
python run_server.py --model_name=$model_name --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml"
```

Now you can fire gRPC requests to it.

Optional flags:
* `--shard_on_batch=1` Shards the model on the batch dimension, i.e. runs in data-parallel mode instead of model-parallel. This ignores the sharding config. It is recommended for the Gemma 2B model, because Gemma 2B is small enough to fit on a single TPU chip.
  To run this model in `int8` quantization, add `--quantize_weights=1`. Quantization is done on the fly as the weights load (see the combined example after this list).

* `--sharding_config=<path>` Uses an alternative sharding config instead of the ones in the `default_shardings` directory.
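
A combined invocation for, e.g., Gemma 2B might then look like the sketch below. This assumes `jpt serve` accepts the same flags documented above for the server scripts; treat it as illustrative rather than authoritative:

```bash
# Data-parallel serving of a small model with on-the-fly int8 quantization
jpt serve --model_id google/gemma-2b --shard_on_batch=1 --quantize_weights=1
```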
Weights downloaded from HuggingFace are stored by default in a `checkpoints` folder in the directory where `jpt` is executed.


# Run the server with ray
37 changes: 20 additions & 17 deletions jetstream_pt/cli.py
@@ -33,10 +33,10 @@ def shard_weights(env, weights, weight_shardings):
  sharded = {}
  for key, val in weights.items():
    sharding = env.sharding_by_axis(weight_shardings.get(key, -1))
    print("SHARDING", key, sharding)
    with jax.default_device(jax.devices("cpu")[0]):
      arr = torch_xla2.tensor.t2j(val)

    print("SHARDING", key, sharding)
    arr = jax.device_put(arr, sharding)
    sharded[key] = torchjax.to_torch(arr)
  return sharded
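
The pattern in this hunk is: materialize each weight on the CPU, then `jax.device_put` it onto the devices with a per-tensor sharding. `env.sharding_by_axis` is repo-specific, so the sketch below reconstructs the idea with plain JAX `NamedSharding`; the names and the rank-2 assumption are illustrative only:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over whatever devices are available (TPU chips in the real setup).
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))


def sharding_by_axis(axis: int, ndim: int = 2) -> NamedSharding:
  """Stand-in for env.sharding_by_axis: shard one axis, replicate if axis == -1."""
  if axis == -1:
    return NamedSharding(mesh, P())
  spec = [None] * ndim
  spec[axis] = "x"
  return NamedSharding(mesh, P(*spec))


# Materialize on CPU first (the repo uses torch_xla2.tensor.t2j(val) here),
# then place and shard the array across the mesh.
with jax.default_device(jax.devices("cpu")[0]):
  arr = jnp.ones((8, 4), dtype=jnp.float32)

arr = jax.device_put(arr, sharding_by_axis(0))
print(arr.sharding)
```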
@@ -207,22 +207,25 @@ def interactive():
print(tokenizer.decode(sampled_tokens_list))


def main(argv):
  """Entry point"""
  if len(argv) < 2:
    print("Invalid arguments. please specify 'list' or 'serve'")

  if argv[1] == "list":
    list_model()
  elif argv[1] == "serve":
    serve()
  elif argv[1] == "interactive":
    interactive()
  else:
    print(
        "Invalid arguments. please specify 'list', 'serve', or 'interactive'."
    )
def main():
  def main_real(argv):
    """Entry point"""
    if len(argv) < 2:
      print("Invalid arguments. please specify 'list' or 'serve'")

    if argv[1] == "list":
      list_model()
    elif argv[1] == "serve":
      serve()
    elif argv[1] == "interactive":
      interactive()
    else:
      print(
          "Invalid arguments. please specify 'list', 'serve', or 'interactive'."
      )
  app.run(main_real)
  return 0


if __name__ == "__main__":
  app.run(main)
  main()
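
The nested `main_real` exists because the console-script entry point added in `pyproject.toml` (see below) is called with no arguments, while absl's `app.run` expects to parse `sys.argv` and the defined flags itself. A minimal, standalone sketch of the same wiring, using illustrative names and a hypothetical flag:

```python
from absl import app, flags

# Hypothetical flag for illustration; the real CLI defines its own flags.
_MODEL_ID = flags.DEFINE_string("model_id", "", "HuggingFace model id to serve")


def _real_main(argv):
  """argv[0] is the program name; argv[1] is the subcommand."""
  if len(argv) < 2:
    print("please specify 'list' or 'serve'")
    return
  if argv[1] == "list":
    print("would list models here")
  elif argv[1] == "serve":
    print(f"would serve {_MODEL_ID.value} here")


def main():
  # Console scripts call main() with no args; app.run() parses sys.argv,
  # handles flags, runs _real_main, then exits via sys.exit, so the
  # trailing return is effectively unreachable (mirroring the diff above).
  app.run(_real_main)
  return 0


if __name__ == "__main__":
  main()
```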
9 changes: 8 additions & 1 deletion jetstream_pt/fetch_models.py
@@ -38,15 +38,18 @@ class ModelInfo:
model_class: torch.nn.Module
# information needed to allocate cache
num_layers: int
# number of kv heads
num_heads: int

head_dim: int
n_reps: int  # repetition for GQA


_llama2_7 = ModelInfo(llama_model.Transformer, 32, 32, 128, 1)
_llama2_13 = ModelInfo(llama_model.Transformer, 40, 40, 128, 1)
_llama2_70 = ModelInfo(llama_model.Transformer, 80, 8, 128, 4)
_llama2_70 = ModelInfo(llama_model.Transformer, 80, 8, 128, 8)
_llama3_8 = ModelInfo(llama_model.Transformer, 32, 8, 128, 4)
_llama3_70 = _llama2_70

_mixtral_87 = ModelInfo(mixtral_model.Transformer, 32, 8, 128, 4)

@@ -59,8 +62,12 @@ class ModelInfo:
"meta-llama/Llama-2-7b-hf": _llama2_7,
"meta-llama/Llama-2-13b-chat-hf": _llama2_13,
"meta-llama/Llama-2-13b-hf": _llama2_13,
"meta-llama/Llama-2-70b-hf": _llama2_70,
"meta-llama/Llama-2-70b-chat-hf": _llama2_70,
"meta-llama/Meta-Llama-3-8B": _llama3_8,
"meta-llama/Meta-Llama-3-8B-Instruct": _llama3_8,
"meta-llama/Meta-Llama-3-70B": _llama3_70,
"meta-llama/Meta-Llama-3-70B-Instruct": _llama3_70,
"google/gemma-2b": _gemma_2b,
"google/gemma-2b-it": _gemma_2b,
"google/gemma-7b": _gemma_7b,
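The `ModelInfo` fields are described as what is needed to allocate the cache. As a rough, illustrative estimate only (not the repo's actual allocation code), the KV-cache footprint implied by an entry such as `_llama3_8` scales like this:

```python
from dataclasses import dataclass


@dataclass
class ModelInfo:
  """Mirror of the fields in the diff above (model_class omitted for brevity)."""
  num_layers: int
  num_heads: int  # number of KV heads
  head_dim: int
  n_reps: int  # query heads per KV head (GQA repetition)


def kv_cache_bytes(info: ModelInfo, batch: int, cache_len: int, dtype_bytes: int = 2) -> int:
  # One K and one V tensor per layer, each of shape
  # (batch, num_kv_heads, cache_len, head_dim); illustrative formula only.
  per_layer = 2 * batch * info.num_heads * cache_len * info.head_dim * dtype_bytes
  return info.num_layers * per_layer


# _llama3_8 from the diff: 32 layers, 8 KV heads, head_dim 128, n_reps 4.
llama3_8b = ModelInfo(num_layers=32, num_heads=8, head_dim=128, n_reps=4)
print(kv_cache_bytes(llama3_8b, batch=128, cache_len=2048) / 2**30, "GiB in bf16")
```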
8 changes: 7 additions & 1 deletion jetstream_pt/third_party/llama/model_exportable.py
@@ -347,8 +347,12 @@ def from_hf_model_id(cls, model_id, env):
"meta-llama/Llama-2-7b-hf": "llama-2-7b",
"meta-llama/Llama-2-13b-chat-hf": "llama-2-13b",
"meta-llama/Llama-2-13b-hf": "llama-2-13b",
"meta-llama/Llama-2-70b-hf": "llama-2-70b",
"meta-llama/Llama-2-70b-chat-hf": "llama-2-70b",
"meta-llama/Meta-Llama-3-8B": "llama-3-8b",
"meta-llama/Meta-Llama-3-8B-Instruct": "llama-3-8b",
"meta-llama/Meta-Llama-3-70B": "llama-3-70b",
"meta-llama/Meta-Llama-3-70B-Instruct": "llama-3-70b",
}.get(model_id)
assert name
args = model_args.get_model_args(
@@ -380,4 +384,6 @@ def transform(val, n_heads):
updated[key] = transform(
value, self.params.n_kv_heads or self.params.n_heads
)
return super().convert_hf_weights(updated)
res = super().convert_hf_weights(updated)
res['freqs_cis'] = self.freqs_cis
return res
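The new lines attach the model's precomputed rotary-embedding table (`freqs_cis`) to the converted HuggingFace weights, since HF checkpoints do not ship it. For reference, a standard Llama-style precompute looks roughly like the sketch below; it is illustrative and not necessarily the exact code this repo uses:

```python
import torch


def precompute_freqs_cis(head_dim: int, max_seq_len: int, theta: float = 10000.0) -> torch.Tensor:
  """Complex-valued rotary table of shape (max_seq_len, head_dim // 2)."""
  freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2)[: head_dim // 2].float() / head_dim))
  t = torch.arange(max_seq_len)
  freqs = torch.outer(t, freqs).float()
  return torch.polar(torch.ones_like(freqs), freqs)
```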
6 changes: 5 additions & 1 deletion pyproject.toml
@@ -18,8 +18,12 @@ dependencies = [
"google-jetstream @ {root:uri}/deps/JetStream",
]


requires-python = ">=3.10"
license = {file = "LICENSE"}

[project.scripts]
jpt = "jetstream_pt.cli:main"

[tool.hatch.metadata]
allow-direct-references = true
allow-direct-references = true
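
With this `[project.scripts]` entry, installing the package (for example via `install_everything.sh`) puts a `jpt` executable on `PATH` that resolves to `jetstream_pt.cli:main`. A quick sanity check, assuming a standard pip-managed environment:

```bash
which jpt   # should point into your environment's bin/ directory
jpt list    # documented subcommand; prints the supported model ids
```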
