fix: utilities to post process checkpoint for LoRA (#338)
* utilities to post process checkpoint for LoRA

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* improve code comments

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* Add unit test and fix some lint errors

Signed-off-by: Angel Luu <angel.luu@us.ibm.com>

* lint: fix more fmt errors

Signed-off-by: Angel Luu <angel.luu@us.ibm.com>

* feat: Add post_process_vLLM_adapters_new_tokens function to main

Signed-off-by: Will Johnson <mwjohnson728@gmail.com>

* fmt

Signed-off-by: Will Johnson <mwjohnson728@gmail.com>

* fix: Add post processing flag so post processing is only done for vLLM

Signed-off-by: Will Johnson <mwjohnson728@gmail.com>

* fix: get num_added_tokens from resize function (#344)

* get num_added_tokens

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* remove extra code

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

---------

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* Ran fmt and also removed unnecessary files from test artifact

Signed-off-by: Angel Luu <angel.luu@us.ibm.com>

* fix: unit tests

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fix: Adding tokens in special_tokens_dict

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* fix: Add additional arg to tests to reflect new flag post_process_vllm

Signed-off-by: Will Johnson <mwjohnson728@gmail.com>

* fmt

Signed-off-by: Will Johnson <mwjohnson728@gmail.com>

* feat: Refactor post-processing of adapters (#345)

* refactor saving tokens metadata

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* remove extra check

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* post processing script

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* post processing script

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fix: unit test args

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* undo post_process_vLLM flag

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

---------

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* add test for LoRA tuning from main

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fix formatting

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* correcting post processing script

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fix:post-process in place

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* update documentation for post-processing

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fix:formatting

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fix:linting

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* more warnings /exceptions in script

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* check for no tokens added

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fix:linting

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* additional unit test

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* add more tests

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fix:tokenizer test

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fix:linting and docstrings

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fix:return type of trainer

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* test: enable tests and fix copytree

Signed-off-by: Anh Uong <anh.uong@ibm.com>

* use copy function from build

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* fix:linting and formatting

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* make build a module

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* add back old copy function

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

---------

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
Signed-off-by: Angel Luu <angel.luu@us.ibm.com>
Signed-off-by: Will Johnson <mwjohnson728@gmail.com>
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
Signed-off-by: Anh Uong <anh.uong@ibm.com>
Co-authored-by: Angel Luu <angel.luu@us.ibm.com>
Co-authored-by: Will Johnson <mwjohnson728@gmail.com>
Co-authored-by: Abhishek <maurya.abhishek@ibm.com>
Co-authored-by: Anh Uong <anh.uong@ibm.com>
5 people authored Sep 25, 2024
1 parent c0c4355 commit 7714dfc
Showing 16 changed files with 97,521 additions and 27 deletions.
29 changes: 29 additions & 0 deletions README.md
@@ -436,6 +436,35 @@ Example 3:

</details>

#### Post-processing needed for inference on vLLM

To run inference with LoRA adapters on vLLM, any new token embeddings added during tuning need to be moved out of `adapter_model.safetensors` into a separate file, `new_embeddings.safetensors`. The `adapter_model.safetensors` file should then contain only the LoRA weights and no modified embedding vectors. This is required to support vLLM's paradigm that one base model can serve multiple adapters: the new token embedding vectors are appended to the embedding matrix that vLLM reads from the base model.
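
A minimal sketch of that split is shown below. This is illustrative only: the `embed_tokens`/`lm_head` key matching and the trailing-row selection are assumptions that vary by model architecture, and the repo's actual implementation is `post_process_vLLM_adapters_new_tokens` in `tuning/utils/merge_model_utils.py`.

```python
# Sketch: move embedding rows for newly added tokens out of the adapter file.
import os

from safetensors import safe_open
from safetensors.torch import save_file


def split_new_token_embeddings(checkpoint_dir: str, num_added_tokens: int) -> None:
    adapter_path = os.path.join(checkpoint_dir, "adapter_model.safetensors")
    lora_weights, new_embeddings = {}, {}
    with safe_open(adapter_path, framework="pt") as f:
        for key in f.keys():
            tensor = f.get_tensor(key)
            if "embed_tokens" in key or "lm_head" in key:
                # Keep only the trailing rows, i.e. the newly added tokens.
                new_embeddings[key] = tensor[-num_added_tokens:].clone()
            else:
                lora_weights[key] = tensor
    save_file(lora_weights, adapter_path)
    save_file(
        new_embeddings, os.path.join(checkpoint_dir, "new_embeddings.safetensors")
    )
```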

To enable this post-processing, the tuning script `sft_trainer.py` generates a file `added_tokens_info.json` alongside the model artifacts. After tuning, you can run the script `post_process_adapters_vLLM.py`:

```bash
# model_path: path to saved model artifacts containing 'added_tokens_info.json'
# output_model_path: optional; directory in which to store the modified
#                    artifacts (omit to modify in place)
python scripts/post_process_adapters_vLLM.py \
--model_path "/testing/tuning/output/post-process-LoRA-saved" \
--output_model_path "/testing/tuning/output/post-process-LoRA-modified"
```

<details>
<summary>Alternatively, if using the SDK:</summary>

```python
# function in tuning/utils/merge_model_utils.py
from tuning.utils.merge_model_utils import post_process_vLLM_adapters_new_tokens

post_process_vLLM_adapters_new_tokens(
path_to_checkpoint="/testing/tuning/output/post-process-LoRA-saved",
modified_checkpoint_path=None,
num_added_tokens=1,
)
# where num_added_tokens is returned by sft_trainer.train()
```
</details>
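
If `num_added_tokens` was not captured from `train()`, it can be read back from the `added_tokens_info.json` file that `sft_trainer.py` saves, mirroring what the post-processing script does:

```python
# Read the token count saved during tuning, then post-process in place.
import json
import os

from tuning.utils.merge_model_utils import post_process_vLLM_adapters_new_tokens

model_path = "/testing/tuning/output/post-process-LoRA-saved"
with open(os.path.join(model_path, "added_tokens_info.json"), encoding="utf-8") as f:
    num_added_tokens = json.load(f)["num_new_tokens"]

if num_added_tokens > 0:
    post_process_vLLM_adapters_new_tokens(
        path_to_checkpoint=model_path,
        modified_checkpoint_path=None,  # None modifies the checkpoint in place
        num_added_tokens=num_added_tokens,
    )
```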

_________________________


6 changes: 5 additions & 1 deletion build/utils.py
@@ -24,12 +24,16 @@
 import shutil


-def copy_checkpoint(source, destination):
+def copy_checkpoint(source, destination, exclude_files: list[str] = None):
     if not os.path.exists(destination):
         os.makedirs(destination)
     shutil.copystat(source, destination)
     # Have a list of directory objects, now iterate over them.
+    if exclude_files is None:
+        exclude_files = []
     for item in os.listdir(source):
+        if item in exclude_files:
+            continue
         source_file = os.path.join(source, item)
         destination_file = os.path.join(destination, item)
         if os.path.isdir(source_file):
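
For illustration, the new `exclude_files` parameter can be used to copy checkpoint artifacts while skipping adapter weights that are post-processed separately (a sketch; the paths are hypothetical):

```python
# Hypothetical usage of copy_checkpoint with the new exclude_files parameter.
from build.utils import copy_checkpoint

copy_checkpoint(
    source="/tuning/output/checkpoint-100",
    destination="/tuning/output/post-processed/checkpoint-100",
    exclude_files=["adapter_model.safetensors"],
)
```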
94 changes: 94 additions & 0 deletions scripts/post_process_adapters_vLLM.py
@@ -0,0 +1,94 @@
""" Script to post-process tuned LoRA adapters for inference on vLLM.
vLLM requires that any token embeddings added while tuning be moved to a new file \
called new_embeddings.safetensors. \
See the description in utility function \
/tuning/utils/merge_model_utils/post_process_vLLM_adapters_new_tokens for more details.
This script takes a path to tuned model artifacts containing adapters \
(or checkpoints with adapters) and the file 'added_tokens_info.json' produced while tuning. \
It will perform the post-processing as needed for inferencing on vLLM.
"""
# Standard
import argparse
import json
import logging
import os
import sys

# Local
from tuning.utils.merge_model_utils import (
copy_files_to_directory,
post_process_vLLM_adapters_new_tokens,
)


### Main & arg parsing
def main():
parser = argparse.ArgumentParser(
description="Post processes LoRA adapters due to addition of new tokens, as needed by vLLM"
)
parser.add_argument(
"--model_path",
help="Path to tuned model containing either one or multiple checkpoints. \
Path should have file added_tokens_info.json produced by tuning. \
Hint: This will be either output_dir or save_model_dir arguments while tuning. \
If multiple checkpoints are present, each checkpoint folder name \
should begin with 'checkpoint-'",
required=True,
)
parser.add_argument(
"--output_model_path",
help="Output directory where post-processed artifacts will be stored. \
If not provided, artifacts will be modified in place",
default=None,
)
args = parser.parse_args()

if args.output_model_path is None:
output_model_path = args.model_path
else:
output_model_path = args.output_model_path
if os.path.exists(os.path.join(args.model_path, "added_tokens_info.json")):
with open(
os.path.join(args.model_path, "added_tokens_info.json"), encoding="utf-8"
) as json_data:
added_tokens_info = json.load(json_data)
num_added_tokens = added_tokens_info["num_new_tokens"]
else:
raise ValueError(
"file added_tokens_info.json not in model_path. \
Cannot post-processes"
)
if num_added_tokens == 0:
logging.info("No new tokens added, hence post-processing not needed")
sys.exit(0)

found_adapters = 0
if os.path.exists(os.path.join(args.model_path, "adapter_model.safetensors")):
found_adapters = 1
post_process_vLLM_adapters_new_tokens(
args.model_path, output_model_path, num_added_tokens
)
# if multiple checkpoints in directory, process each checkpoint
found_checkpoints = 0
for _, dirs, _ in os.walk(args.model_path, topdown=False):
for name in dirs:
if "checkpoint-" in name.lower():
post_process_vLLM_adapters_new_tokens(
os.path.join(args.model_path, name),
os.path.join(output_model_path, name),
num_added_tokens,
)
found_checkpoints = 1
if found_checkpoints and output_model_path != args.model_path:
copy_files_to_directory(
args.model_path,
output_model_path,
exclude_files=["adapter_model.safetensors"],
)
if not found_adapters and not found_checkpoints:
logging.warning("No adapters were found to process in model path provided")


if __name__ == "__main__":
main()
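
For reference, the script expects `added_tokens_info.json` to hold a single count under the key `num_new_tokens`. A minimal sketch of producing such a file (in this repo `sft_trainer.py` writes it automatically; this helper is purely illustrative):

```python
# Illustrative helper: persist the number of newly added tokens so the
# post-processing script can read it later.
import json
import os


def write_added_tokens_info(output_dir: str, num_new_tokens: int) -> None:
    path = os.path.join(output_dir, "added_tokens_info.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump({"num_new_tokens": num_new_tokens}, f)
```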
29 changes: 29 additions & 0 deletions tests/artifacts/tuned_llama_with_added_tokens/adapter_config.json
@@ -0,0 +1,29 @@
{
    "alpha_pattern": {},
    "auto_mapping": null,
    "base_model_name_or_path": "Maykeye/TinyLLama-v0",
    "bias": "none",
    "fan_in_fan_out": false,
    "inference_mode": true,
    "init_lora_weights": true,
    "layer_replication": null,
    "layers_pattern": null,
    "layers_to_transform": null,
    "loftq_config": {},
    "lora_alpha": 32,
    "lora_dropout": 0.05,
    "megatron_config": null,
    "megatron_core": "megatron.core",
    "modules_to_save": null,
    "peft_type": "LORA",
    "r": 8,
    "rank_pattern": {},
    "revision": null,
    "target_modules": [
        "v_proj",
        "q_proj"
    ],
    "task_type": "CAUSAL_LM",
    "use_dora": false,
    "use_rslora": false
}
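
As a quick sanity check, the adapter config above can be loaded with PEFT (a sketch, assuming `peft` is installed and the path is relative to the repo root):

```python
# Load the LoRA adapter config shipped as a test artifact.
from peft import PeftConfig

config = PeftConfig.from_pretrained("tests/artifacts/tuned_llama_with_added_tokens")
print(config.base_model_name_or_path)  # Maykeye/TinyLLama-v0
print(sorted(config.target_modules))   # ['q_proj', 'v_proj']
```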
Binary file not shown.
3 changes: 3 additions & 0 deletions tests/artifacts/tuned_llama_with_added_tokens/added_tokens.json
@@ -0,0 +1,3 @@
{
    "<pad>": 32000
}
30 changes: 30 additions & 0 deletions tests/artifacts/tuned_llama_with_added_tokens/special_tokens_map.json
@@ -0,0 +1,30 @@
{
    "bos_token": {
        "content": "<s>",
        "lstrip": false,
        "normalized": false,
        "rstrip": false,
        "single_word": false
    },
    "eos_token": {
        "content": "</s>",
        "lstrip": false,
        "normalized": false,
        "rstrip": false,
        "single_word": false
    },
    "pad_token": {
        "content": "<pad>",
        "lstrip": false,
        "normalized": false,
        "rstrip": false,
        "single_word": false
    },
    "unk_token": {
        "content": "<unk>",
        "lstrip": false,
        "normalized": false,
        "rstrip": false,
        "single_word": false
    }
}
