Update README for new CLI #178

Merged (3 commits) on Sep 10, 2024
134 changes: 49 additions & 85 deletions README.md
@@ -14,13 +14,9 @@ Commandline flags might have changed between the release version and HEAD.
1. Ssh to Cloud TPU VM (using v5e-8 TPU VM)
a. Create a Cloud TPU VM if you haven’t
2. Download jetstream-pytorch github repo
3. Clone repo and install dependencies
4. Download and convert weights
5. Run checkpoint converter (quantizer)
6. Local run
7. Run the server
8. Run benchmarks
9. Typical Errors
3. Run the server
4. Run benchmarks
5. Typical Errors

# Ssh to Cloud TPU VM (using v5e-8 TPU VM)

@@ -49,108 +45,76 @@ cd jetstream-pytorch
source install_everything.sh
```

# Download and convert weights

## LLaMA
### Get official llama weights from meta-llama
# Run jetstream pytorch

Following instructions here:
* Llama-2: https://github.com/meta-llama/llama#download
* Llama-3: https://github.com/meta-llama/llama3/#download
## List out supported models

Downloading the weights also fetches a `tokenizer.model` file, which is the
tokenizer that we will use.

## Gemma
### Get Gemma Checkpoint from HuggingFace

Please sign the agreement on the HuggingFace website to access Gemma checkpoints. Download the Gemma PyTorch checkpoint using huggingface-cli. The Gemma tokenizer is included in the checkpoint.

```bash
# Install huggingface-cli and login if it's not set up.
pip install -U "huggingface_hub[cli]"
huggingface-cli login
huggingface-cli download google/gemma-7b-pytorch --local-dir $input_ckpt_dir
```

## Mixtral
### Get Mixtral Checkpoint from HuggingFace

Please sign the agreement on the HuggingFace website to access Mixtral checkpoints. Download the Mixtral PyTorch checkpoint using huggingface-cli. The Mixtral tokenizer is included in the checkpoint.

```bash
huggingface-cli download mistralai/Mixtral-8x7B-v0.1 --local-dir $input_ckpt_dir
jpt list
```

## Run weight safetensor convert
This will print out the list of supported models and variants:

There is limited support (only Llama models as of now) for accessing checkpoints on GCS. Accessing GCS takes a long time, so storing checkpoints locally is recommended.

```bash
export input_ckpt_dir=Original llama weights directory
export output_ckpt_dir=The output directory
export model_name="llama-3" # or "llama-2", "gemma", "mixtral"
export quantize_weights=True # Whether to quantize weights
export quantize_type="int8_per_channel" # "quantize_weights" needs to be turned on. Available quantize types: {"int8", "int4"} x {"per_channel", "blockwise"}; "int8_per_channel" is the default if not specified.
python -m convert_checkpoints --model_name=$model_name --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize_type=$quantize_type --quantize_weights=$quantize_weights
```
meta-llama/Llama-2-7b-chat-hf
meta-llama/Llama-2-7b-hf
meta-llama/Llama-2-13b-chat-hf
meta-llama/Llama-2-13b-hf
meta-llama/Llama-2-70b-hf
meta-llama/Llama-2-70b-chat-hf
meta-llama/Meta-Llama-3-8B
meta-llama/Meta-Llama-3-8B-Instruct
meta-llama/Meta-Llama-3-70B
meta-llama/Meta-Llama-3-70B-Instruct
google/gemma-2b
google/gemma-2b-it
google/gemma-7b
google/gemma-7b-it
mistralai/Mixtral-8x7B-v0.1
mistralai/Mixtral-8x7B-Instruct-v0.1
```

To run the jetstream-pytorch server with one model:

# Local run

Set tokenizer path
```bash
export tokenizer_path=tokenizer model file path
```

## Llama-2 7b
```bash
python run_interactive.py --size=7b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
jpt serve --model_id meta-llama/Meta-Llama-3-8B-Instruct
```

## Llama-2 13b
```bash
python run_interactive.py --size=13b --model_name=$model_name --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
```
If it's the first time you run this model, it will download weights from
HuggingFace.

## Llama-3 8b
```bash
python run_interactive.py --size=8b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
```
HuggingFace's Llama3 weights are gated, so you need to either run
`huggingface-cli login` to set your token, or pass your hf_token explicitly.

## Llama-3 70b
```bash
python run_interactive.py --size=70b --model_name=$model_name --batch_size=8 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
```
To pass the hf token explicitly, add the `--hf_token` flag (see the `jpt serve` example below).

## Gemma 7b
```bash
python run_interactive.py --model_name=$model_name --size=7b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
jpt serve --model_id meta-llama/Meta-Llama-3-8B-Instruct --hf_token=...
```

## Mixtral 8x7b
```bash
python run_interactive.py --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
```
To log in using the HuggingFace hub, run:

```
pip install -U "huggingface_hub[cli]"
huggingface-cli login
```
Then follow its prompt.

# Run the server
Here is an example to run the server with llama2 7B config.
After the weights are downloaded, subsequent runs will no longer require `--hf_token`.

```bash
python run_server.py --model_name=$model_name --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml"
```
To run this model in `int8` quantization, add `--quantize_weights=1`.
Quantization will be done on the fly as the weights load.
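For example, a quantized serve invocation combining the flags above might look like this (a sketch; the model id is one of those shown by `jpt list`):

```bash
jpt serve --model_id meta-llama/Meta-Llama-3-8B-Instruct --quantize_weights=1
```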

Now you can fire gRPC requests at it.
Weights downloaded from HuggingFace will be stored by default in the `checkpoints` folder
Collaborator: Do we have an option to store the weights separately? We even have problems storing the weights directly on the GCP VM.

Author: For a GCS bucket, the checkpoint needs to be brought local or mounted using FUSE. The working dir can be edited; I added a paragraph to describe that.

Collaborator: It would be great if you could add how to change the working dir, because we also need to point it to the external SSD. I will approve the PR to unblock you for now.

in the place where `jpt` is executed.

Optional flags:
* `--shard_on_batch=1` This makes the model shard on
the batch dimension, i.e. it runs in data-parallel mode instead of model-parallel
mode. This ignores the sharding config. This is recommended for the Gemma 2B
model, because Gemma 2B is small enough to fit on a single TPU chip.
You can change where the weights are stored with the `--working_dir` flag (see the combined example after this list).

* `--sharding_config=<path>` This uses an alternative sharding config instead of
the ones in the default_shardings directory.
If you wish to use your own checkpoint, place it inside
the `checkpoints/<org>/<model>/hf_original` dir (or the corresponding subdir in `--working_dir`). For example,
Llama-2 7B checkpoints will be at `checkpoints/meta-llama/Llama-2-7b-hf/hf_original/*.safetensors`. You can replace these files with modified
weights in HuggingFace format.
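As a combined illustration of the optional flags (a sketch; the working-dir path is hypothetical):

```bash
# Data-parallel serving of Gemma 2B, keeping downloaded weights on an external disk.
jpt serve --model_id google/gemma-2b \
  --shard_on_batch=1 \
  --working_dir=/mnt/disks/ssd/jpt_checkpoints
```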


# Run the server with ray
38 changes: 22 additions & 16 deletions jetstream_pt/cli.py
@@ -33,10 +33,10 @@ def shard_weights(env, weights, weight_shardings):
  sharded = {}
  for key, val in weights.items():
    sharding = env.sharding_by_axis(weight_shardings.get(key, -1))
    print("SHARDING", key, sharding)
    with jax.default_device(jax.devices("cpu")[0]):
      arr = torch_xla2.tensor.t2j(val)

    print("SHARDING", key, sharding)
    arr = jax.device_put(arr, sharding)
    sharded[key] = torchjax.to_torch(arr)
  return sharded
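For context, the sharding above hinges on `jax.device_put` placing a host array according to a sharding spec. A minimal standalone sketch of that mechanism (it does not use the project's `env.sharding_by_axis` helper):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D device mesh and shard a weight matrix along axis 0.
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
sharding = NamedSharding(mesh, P("x", None))

w = jnp.zeros((1024, 512))        # stand-in for a converted torch weight
w = jax.device_put(w, sharding)   # distribute the shards across devices
print(w.sharding)
```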
@@ -207,22 +207,28 @@ def interactive():
print(tokenizer.decode(sampled_tokens_list))


def main(argv):
  """Entry point"""
  if len(argv) < 2:
    print("Invalid arguments. please specify 'list' or 'serve'")
def main():
  """Main function."""

  if argv[1] == "list":
    list_model()
  elif argv[1] == "serve":
    serve()
  elif argv[1] == "interactive":
    interactive()
  else:
    print(
        "Invalid arguments. please specify 'list', 'serve', or 'interactive'."
    )
  def main_real(argv):
    """Entry point"""
    if len(argv) < 2:
      print("Invalid arguments. please specify 'list' or 'serve'")

    if argv[1] == "list":
      list_model()
    elif argv[1] == "serve":
      serve()
    elif argv[1] == "interactive":
      interactive()
    else:
      print(
          "Invalid arguments. please specify 'list', 'serve', or 'interactive'."
      )

  app.run(main_real)
  return 0


if __name__ == "__main__":
  app.run(main)
  main()
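For context, a console-script entry point (see the `pyproject.toml` change below) calls its target with no arguments, so a zero-argument `main()` that internally hands the real entry function to absl's `app.run` keeps the CLI usable both as the installed `jpt` command and via direct execution. A minimal sketch of the same pattern (illustrative, not the project's exact code):

```python
from absl import app


def _run(argv):
  # absl passes argv after flag parsing; argv[0] is the program name.
  print("positional args:", argv[1:])


def main():
  # Zero-argument wrapper suitable for a [project.scripts] entry point.
  app.run(_run)
  return 0


if __name__ == "__main__":
  main()
```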
13 changes: 10 additions & 3 deletions jetstream_pt/fetch_models.py
@@ -38,15 +38,18 @@ class ModelInfo:
  model_class: torch.nn.Module
  # information needed to allocate cache
  num_layers: int
  num_heads: int
  # number of kv heads
  num_kv_heads: int

  head_dim: int
  n_reps: int  # repetition for GQA


_llama2_7 = ModelInfo(llama_model.Transformer, 32, 32, 128, 1)
_llama2_13 = ModelInfo(llama_model.Transformer, 40, 40, 128, 1)
_llama2_70 = ModelInfo(llama_model.Transformer, 80, 8, 128, 4)
_llama2_70 = ModelInfo(llama_model.Transformer, 80, 8, 128, 8)
_llama3_8 = ModelInfo(llama_model.Transformer, 32, 8, 128, 4)
_llama3_70 = _llama2_70

_mixtral_87 = ModelInfo(mixtral_model.Transformer, 32, 8, 128, 4)

@@ -59,8 +62,12 @@ class ModelInfo:
"meta-llama/Llama-2-7b-hf": _llama2_7,
"meta-llama/Llama-2-13b-chat-hf": _llama2_13,
"meta-llama/Llama-2-13b-hf": _llama2_13,
"meta-llama/Llama-2-70b-hf": _llama2_70,
"meta-llama/Llama-2-70b-chat-hf": _llama2_70,
"meta-llama/Meta-Llama-3-8B": _llama3_8,
"meta-llama/Meta-Llama-3-8B-Instruct": _llama3_8,
"meta-llama/Meta-Llama-3-70B": _llama3_70,
"meta-llama/Meta-Llama-3-70B-Instruct": _llama3_70,
"google/gemma-2b": _gemma_2b,
"google/gemma-2b-it": _gemma_2b,
"google/gemma-7b": _gemma_7b,
@@ -132,7 +139,7 @@ def construct_env_data_from_model_id(
  )
  env_data.cache_shape = (
      batch_size,
      model_info.num_heads,
      model_info.num_kv_heads,
      max_cache_length,
      model_info.head_dim,
  )
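For context on the `_llama2_70` change above: with grouped-query attention, `n_reps` is the number of query heads served by each KV head. Llama-2 70B has 64 query heads and 8 KV heads, so `n_reps = 64 / 8 = 8`; the old value of 4 matches Llama-3 8B (32 query heads, 8 KV heads). The KV cache is likewise allocated per KV head, roughly as sketched below (values assume the `_llama3_8` entry above):

```python
# Illustrative cache-shape calculation using the ModelInfo fields shown above.
batch_size, max_cache_length = 32, 2048   # example serving parameters
num_kv_heads, head_dim = 8, 128           # from _llama3_8
cache_shape = (batch_size, num_kv_heads, max_cache_length, head_dim)
```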
8 changes: 7 additions & 1 deletion jetstream_pt/third_party/llama/model_exportable.py
@@ -347,8 +347,12 @@ def from_hf_model_id(cls, model_id, env):
"meta-llama/Llama-2-7b-hf": "llama-2-7b",
"meta-llama/Llama-2-13b-chat-hf": "llama-2-13b",
"meta-llama/Llama-2-13b-hf": "llama-2-13b",
"meta-llama/Llama-2-70b-hf": "llama-2-70b",
"meta-llama/Llama-2-70b-chat-hf": "llama-2-70b",
"meta-llama/Meta-Llama-3-8B": "llama-3-8b",
"meta-llama/Meta-Llama-3-8B-Instruct": "llama-3-8b",
"meta-llama/Meta-Llama-3-70B": "llama-3-70b",
"meta-llama/Meta-Llama-3-70B-Instruct": "llama-3-70b",
}.get(model_id)
assert name
args = model_args.get_model_args(
@@ -380,4 +384,6 @@ def transform(val, n_heads):
        updated[key] = transform(
            value, self.params.n_kv_heads or self.params.n_heads
        )
    return super().convert_hf_weights(updated)
    res = super().convert_hf_weights(updated)
    res["freqs_cis"] = self.freqs_cis
    return res
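For context on the `freqs_cis` line: HuggingFace safetensors checkpoints contain only learned weights, while the reference Llama model also expects precomputed rotary-embedding frequencies, so they are attached to the converted state here. A rough sketch of the standard Llama-style precomputation (the exact helper in this repo may differ):

```python
import torch


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
  """Complex rotary frequencies for head dimension `dim` over `end` positions."""
  freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
  t = torch.arange(end, dtype=torch.float32)
  freqs = torch.outer(t, freqs)
  return torch.polar(torch.ones_like(freqs), freqs)  # cos + i*sin
```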
6 changes: 5 additions & 1 deletion pyproject.toml
@@ -18,8 +18,12 @@ dependencies = [
"google-jetstream @ {root:uri}/deps/JetStream",
]


requires-python = ">=3.10"
license = {file = "LICENSE"}

[project.scripts]
jpt = "jetstream_pt.cli:main"

[tool.hatch.metadata]
allow-direct-references = true
allow-direct-references = true
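For reference, the new `[project.scripts]` entry exposes `jetstream_pt.cli:main` as a `jpt` console command once the package is installed. A sketch of how that is typically exercised (assuming an editable install from the repo root; the README's `install_everything.sh` handles environment setup):

```bash
pip install -e .      # or: source install_everything.sh
jpt list
jpt serve --model_id meta-llama/Meta-Llama-3-8B-Instruct
```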