diff --git a/README.md b/README.md
index ca6ec4b..065956a 100644
--- a/README.md
+++ b/README.md
@@ -14,13 +14,9 @@ Commandline Flags might have changed between the release version to HEAD.
1. Ssh to Cloud TPU VM (using v5e-8 TPU VM)
   a. Create a Cloud TPU VM if you haven’t
2. Download jetstream-pytorch github repo
-3. Clone repo and install dependencies
-4. Download and convert weights
-5. Run checkpoint converter (quantizer)
-6. Local run
-7. Run the server
-8. Run benchmarks
-9. Typical Errors
+3. Run the server
+4. Run benchmarks
+5. Typical Errors

# Ssh to Cloud TPU VM (using v5e-8 TPU VM)

@@ -49,108 +45,76 @@ cd jetstream-pytorch
source install_everything.sh
```

-# Download and convert weights
-## LLaMA
-### Get official llama weights from meta-llama
+# Run jetstream pytorch

-Following instructions here:
-* Llama-2: https://github.com/meta-llama/llama#download
-* Llama-3: https://github.com/meta-llama/llama3/#download
+## List out supported models

-After you have downloaded the weights, it will also download a `tokenizer.model` file that is
-the tokenizer that we will use.
-
-## Gemma
-### Get Gemma Checkpoint from HuggingFace
-
-Please sign agreement on Huggingface website to access Gemma checkpoints. Download Gemma PyTorch checkpoint using huggingface-cli. Gemma Tokenizer is included in the checkpoint.
-
-```bash
-# Install huggingface-cli and login if it's not set up.
-pip install -U "huggingface_hub[cli]"
-huggingface-cli login
-huggingface-cli download google/gemma-7b-pytorch --local-dir $input_ckpt_dir
```
-
-## Mixtral
-### Get Mixtral Checkpoint from HuggingFace
-
-Please sign agreement on Huggingface website to access Mixtral checkpoints. Download Mixtral PyTorch checkpoint using huggingface-cli. Mixtral Tokenizer is included in the checkpoint.
-
-```bash
-huggingface-cli download mistralai/Mixtral-8x7B-v0.1 --local-dir $input_ckpt_dir
+jpt list
```

-## Run weight safetensor convert
+This will print out the list of supported models and variants:

-There are limited support (only Llama models as of now) for accessing checkpoints on GCS. Accessing GCS takes a long time and therefore storing checkpoints to local is recommended.
-
-```bash
-export input_ckpt_dir=Original llama weights directory
-export output_ckpt_dir=The output directory
-export model_name="llama-3" # or "llama-2", "gemma", "mixtral"
-export quantize_weights=True # Whether to quantize weights
-export quantize_type="int8_per_channel" # "quantize_weights" needs to be turned on. Availabe quantize type: {"int8", "int4"} x {"per_channel", "blockwise"}, "int8_per_channel" is the default option if not specified.
-python -m convert_checkpoints --model_name=$model_name --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize_type=$quantize_type --quantize_weights=$quantize_weights
+```
+meta-llama/Llama-2-7b-chat-hf
+meta-llama/Llama-2-7b-hf
+meta-llama/Llama-2-13b-chat-hf
+meta-llama/Llama-2-13b-hf
+meta-llama/Llama-2-70b-hf
+meta-llama/Llama-2-70b-chat-hf
+meta-llama/Meta-Llama-3-8B
+meta-llama/Meta-Llama-3-8B-Instruct
+meta-llama/Meta-Llama-3-70B
+meta-llama/Meta-Llama-3-70B-Instruct
+google/gemma-2b
+google/gemma-2b-it
+google/gemma-7b
+google/gemma-7b-it
+mistralai/Mixtral-8x7B-v0.1
+mistralai/Mixtral-8x7B-Instruct-v0.1
```

+To run the jetstream-pytorch server with one model:

-# Local run
-
-Set tokenizer path
-```bash
-export tokenizer_path=tokenizer model file path
```
-
-## Llama-2 7b
-```bash
-python run_interactive.py --size=7b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
+jpt serve --model_id meta-llama/Meta-Llama-3-8B-Instruct
```

-## Llama-2 13b
-```bash
-python run_interactive.py --size=13b --model_name=$model_name --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
-```
+If this is the first time you run this model, its weights will be downloaded from
+HuggingFace.

-## Llama-3 8b
-```bash
-python run_interactive.py --size=8b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
-```
+HuggingFace's Llama 3 weights are gated, so you need to either run
+`huggingface-cli login` to set your token, or pass your HF token explicitly.

-## Llama-3 70b
-```bash
-python run_interactive.py --size=70b --model_name=$model_name --batch_size=8 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/llama.yaml
+To pass the HF token explicitly, add the `--hf_token` flag:
```
-
-## Gemma 7b
-```bash
-python run_interactive.py --model_name=$model_name --size=7b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
+jpt serve --model_id meta-llama/Meta-Llama-3-8B-Instruct --hf_token=...
```

-## Mixtral 8x7b
-```bash
-python run_interactive.py --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config=default_shardings/$model_name.yaml
-```
+To log in using the HuggingFace Hub CLI, run:
+```
+pip install -U "huggingface_hub[cli]"
+huggingface-cli login
+```
+Then follow its prompt.
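+
+If you prefer to set the token from Python rather than through the interactive
+prompt (for example in a setup script), a minimal sketch using the
+`huggingface_hub` package is shown below. This is an illustration rather than a
+jetstream-pytorch command, and the token value is a placeholder:
+
+```
+# Sketch: programmatic alternative to `huggingface-cli login`.
+from huggingface_hub import login
+
+login(token="hf_xxx")  # placeholder; use your own HuggingFace access token
+```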
-# Run the server
-Here is an example to run the server with llama2 7B config.
+After the weights have been downloaded,
+subsequent runs will no longer require the `--hf_token` flag.

-```bash
-python run_server.py --model_name=$model_name --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize_weights --quantize_type=$quantize_type --quantize_kv_cache=$quantize_weights --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml"
-```
+To run this model with `int8` quantization, add `--quantize_weights=1`.
+Quantization will be done on the fly as the weights load.

-Now you can fire gRPC to it.
+Weights downloaded from HuggingFace are stored by default in the `checkpoints` folder
+in the directory where `jpt` is executed.

-Optional flags:
-* `--shard_on_batch=1` This makes the model to shard on
-  the batch dimension. I.e. this runs in data parallel mode instead of model
-  parallel. This will ignore the sharding config. This is recommended for Gemma 2B
-  model, because Gemma 2B is small enough to fit on a single TPU chip.
+You can change where the weights are stored with the `--working_dir` flag.

-* `--sharding_config=` This makes use of alternative sharding config instead of
-  the ones in default_shardings directory.
+If you wish to use your own checkpoint, place it inside
+the `checkpoints/<org>/<model>/hf_original` dir (or the corresponding subdir in `--working_dir`). For example,
+Llama 2 7B checkpoints would live at `checkpoints/meta-llama/Llama-2-7b-hf/hf_original/*.safetensors`. You can replace these files with modified
+weights in HuggingFace format.

# Run the server with ray

diff --git a/jetstream_pt/cli.py b/jetstream_pt/cli.py
index 76dcace..ce49d55 100644
--- a/jetstream_pt/cli.py
+++ b/jetstream_pt/cli.py
@@ -33,10 +33,10 @@ def shard_weights(env, weights, weight_shardings):
  sharded = {}
  for key, val in weights.items():
    sharding = env.sharding_by_axis(weight_shardings.get(key, -1))
+    print("SHARDING", key, sharding)
    with jax.default_device(jax.devices("cpu")[0]):
      arr = torch_xla2.tensor.t2j(val)
-      print("SHARDING", key, sharding)
      arr = jax.device_put(arr, sharding)
    sharded[key] = torchjax.to_torch(arr)
  return sharded
@@ -207,22 +207,28 @@ def interactive():
    print(tokenizer.decode(sampled_tokens_list))


-def main(argv):
-  """Entry point"""
-  if len(argv) < 2:
-    print("Invalid arguments. please specify 'list' or 'serve'")
+def main():
+  """Main function."""

-  if argv[1] == "list":
-    list_model()
-  elif argv[1] == "serve":
-    serve()
-  elif argv[1] == "interactive":
-    interactive()
-  else:
-    print(
-        "Invalid arguments. please specify 'list', 'serve', or 'interactive'."
-    )
+  def main_real(argv):
+    """Entry point"""
+    if len(argv) < 2:
+      print("Invalid arguments. please specify 'list' or 'serve'")
+
+    if argv[1] == "list":
+      list_model()
+    elif argv[1] == "serve":
+      serve()
+    elif argv[1] == "interactive":
+      interactive()
+    else:
+      print(
+          "Invalid arguments. please specify 'list', 'serve', or 'interactive'."
+      )
+
+  app.run(main_real)
+  return 0


if __name__ == "__main__":
-  app.run(main)
+  main()
diff --git a/jetstream_pt/fetch_models.py b/jetstream_pt/fetch_models.py
index 6786b51..c3e2312 100644
--- a/jetstream_pt/fetch_models.py
+++ b/jetstream_pt/fetch_models.py
@@ -38,15 +38,18 @@ class ModelInfo:
  model_class: torch.nn.Module
  # information needed to allocate cache
  num_layers: int
-  num_heads: int
+  # number of kv heads
+  num_kv_heads: int
+
  head_dim: int
  n_reps: int  # repeatition for GQA


_llama2_7 = ModelInfo(llama_model.Transformer, 32, 32, 128, 1)
_llama2_13 = ModelInfo(llama_model.Transformer, 40, 40, 128, 1)
-_llama2_70 = ModelInfo(llama_model.Transformer, 80, 8, 128, 4)
+_llama2_70 = ModelInfo(llama_model.Transformer, 80, 8, 128, 8)
_llama3_8 = ModelInfo(llama_model.Transformer, 32, 8, 128, 4)
+_llama3_70 = _llama2_70

_mixtral_87 = ModelInfo(mixtral_model.Transformer, 32, 8, 128, 4)

@@ -59,8 +62,12 @@ class ModelInfo:
    "meta-llama/Llama-2-7b-hf": _llama2_7,
    "meta-llama/Llama-2-13b-chat-hf": _llama2_13,
    "meta-llama/Llama-2-13b-hf": _llama2_13,
+    "meta-llama/Llama-2-70b-hf": _llama2_70,
+    "meta-llama/Llama-2-70b-chat-hf": _llama2_70,
    "meta-llama/Meta-Llama-3-8B": _llama3_8,
    "meta-llama/Meta-Llama-3-8B-Instruct": _llama3_8,
+    "meta-llama/Meta-Llama-3-70B": _llama3_70,
+    "meta-llama/Meta-Llama-3-70B-Instruct": _llama3_70,
    "google/gemma-2b": _gemma_2b,
    "google/gemma-2b-it": _gemma_2b,
    "google/gemma-7b": _gemma_7b,
@@ -132,7 +139,7 @@ def construct_env_data_from_model_id(
  )
  env_data.cache_shape = (
      batch_size,
-      model_info.num_heads,
+      model_info.num_kv_heads,
      max_cache_length,
      model_info.head_dim,
  )
diff --git a/jetstream_pt/third_party/llama/model_exportable.py b/jetstream_pt/third_party/llama/model_exportable.py
index 791ff7a..7cebeb5 100644
--- a/jetstream_pt/third_party/llama/model_exportable.py
+++ b/jetstream_pt/third_party/llama/model_exportable.py
@@ -347,8 +347,12 @@ def from_hf_model_id(cls, model_id, env):
      "meta-llama/Llama-2-7b-hf": "llama-2-7b",
      "meta-llama/Llama-2-13b-chat-hf": "llama-2-13b",
      "meta-llama/Llama-2-13b-hf": "llama-2-13b",
+      "meta-llama/Llama-2-70b-hf": "llama-2-70b",
+      "meta-llama/Llama-2-70b-chat-hf": "llama-2-70b",
      "meta-llama/Meta-Llama-3-8B": "llama-3-8b",
      "meta-llama/Meta-Llama-3-8B-Instruct": "llama-3-8b",
+      "meta-llama/Meta-Llama-3-70B": "llama-3-70b",
+      "meta-llama/Meta-Llama-3-70B-Instruct": "llama-3-70b",
  }.get(model_id)
  assert name
  args = model_args.get_model_args(
@@ -380,4 +384,6 @@ def transform(val, n_heads):
      updated[key] = transform(
          value, self.params.n_kv_heads or self.params.n_heads
      )
-    return super().convert_hf_weights(updated)
+    res = super().convert_hf_weights(updated)
+    res["freqs_cis"] = self.freqs_cis
+    return res
diff --git a/pyproject.toml b/pyproject.toml
index c9f3568..2000ca3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,8 +18,12 @@ dependencies = [
    "google-jetstream @ {root:uri}/deps/JetStream",
]

+
requires-python = ">=3.10"
license = {file = "LICENSE"}

+[project.scripts]
+jpt = "jetstream_pt.cli:main"
+
[tool.hatch.metadata]
-allow-direct-references = true
\ No newline at end of file
+allow-direct-references = true
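
A note on the `ModelInfo` and cache-shape changes in `fetch_models.py` above: under grouped-query attention (GQA), the Llama 70B variants use 64 query heads that share 8 KV heads, so each KV head serves 64 // 8 = 8 query heads (hence `n_reps` moving from 4 to 8), and the KV cache is allocated per KV head rather than per query head. The standalone sketch below only checks that arithmetic; `ModelInfoSketch` and the example `batch_size`/`max_cache_length` values are illustrative stand-ins, not the repo's actual API.

```
# Standalone sketch (illustrative names, not the repo's API): sanity-check the
# GQA bookkeeping behind the ModelInfo fix above.
from dataclasses import dataclass


@dataclass
class ModelInfoSketch:
  num_layers: int
  num_kv_heads: int  # KV heads are shared across query heads under GQA
  head_dim: int
  n_reps: int  # how many query heads share each KV head


# Llama-2-70B / Llama-3-70B: 64 query heads over 8 KV heads -> n_reps = 8.
llama_70b = ModelInfoSketch(num_layers=80, num_kv_heads=8, head_dim=128, n_reps=8)
assert llama_70b.num_kv_heads * llama_70b.n_reps == 64

# The KV cache is allocated per KV head, matching the corrected cache_shape.
batch_size, max_cache_length = 8, 2048  # example values only
cache_shape = (
    batch_size,
    llama_70b.num_kv_heads,
    max_cache_length,
    llama_70b.head_dim,
)
print(cache_shape)  # (8, 8, 2048, 128)
```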