feat: Add post processing logic to accelerate launch (#351)
* feat: Add post processing logic to accelerate launch

Signed-off-by: Will Johnson <mwjohnson728@gmail.com>

* docs: Add post processing arg and explanation

Signed-off-by: Will Johnson <mwjohnson728@gmail.com>

* test: Add test for post processing in accelerate launch

Signed-off-by: Will Johnson <mwjohnson728@gmail.com>

* fix: Remove comma from example

Signed-off-by: Will Johnson <mwjohnson728@gmail.com>

* fix: small changes from review

Signed-off-by: Will Johnson <mwjohnson728@gmail.com>

---------

Signed-off-by: Will Johnson <mwjohnson728@gmail.com>
willmj authored Sep 25, 2024
1 parent 7714dfc commit 8676d01
Showing 3 changed files with 89 additions and 2 deletions.
6 changes: 4 additions & 2 deletions build/README.md
@@ -38,14 +38,16 @@ For example, the below config is used for running with two GPUs and FSDP for fine tuning:
"per_device_train_batch_size": 4,
"learning_rate": 1e-5,
"response_template": "\n### Label:",
"dataset_text_field": "output"
"dataset_text_field": "output",
"lora_post_process_for_vllm": true
}
```

-Users should always set `num_processes` to be explicit about the number of processes to run tuning on. When `num_processes` is greater than 1, the [FSDP config](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/fixtures/accelerate_fsdp_defaults.yaml) is used by default. Thus in the above example, you don't need to pass in the FSDP flags since they match the ones used in the default FSDP config. You can also set your own default values by specifying your own config file using key `config_file`. Any of these values in configs can be overwritten by passing in flags via `accelerate_launch_args` in the JSON config.
+`num_processes` defaults to the number of GPUs allocated for tuning, unless the user sets `SET_NUM_PROCESSES_TO_NUM_GPUS` to `False`. When `num_processes` is greater than 1, the [FSDP config](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/fixtures/accelerate_fsdp_defaults.yaml) is used by default. Thus in the above example, you don't need to pass in the FSDP flags since they match the ones used in the default FSDP config. You can also set your own default values by specifying your own config file using key `config_file`. Any of these values in configs can be overwritten by passing in flags via `accelerate_launch_args` in the JSON config.

Note that `num_processes`, which is the total number of processes to be launched in parallel, should match the number of GPUs to run on. The number of GPUs used can also be set via the environment variable `CUDA_VISIBLE_DEVICES`. If `num_processes=1`, the script will assume a single GPU.

+If tuning LoRA adapters for inference on vLLM, set `lora_post_process_for_vllm` to `true`. This post-processes the LoRA adapters so they can be served on vLLM: any new token embedding weights added during tuning are moved out of the adapter into a separate file, `new_embeddings.safetensors`, which vLLM requires.

## Building the Image

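After a run with `lora_post_process_for_vllm` enabled, the split-out embedding weights described above can be sanity-checked with the `safetensors` library. A minimal sketch, not part of this commit; the path is hypothetical and stands in for your configured `save_model_dir`:

```python
# Inspect the token-embedding tensors that post-processing moved out of the
# LoRA adapter for vLLM. Assumes the `safetensors` package is installed.
from safetensors import safe_open

# Hypothetical path; substitute your configured save_model_dir.
path = "/output/tuned-model/new_embeddings.safetensors"

with safe_open(path, framework="pt", device="cpu") as f:
    for key in f.keys():
        print(key, f.get_tensor(key).shape)
```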
53 changes: 53 additions & 0 deletions build/accelerate_launch.py
@@ -23,6 +23,7 @@
import subprocess
import sys
import traceback
+import json
from pathlib import Path

# Third Party
@@ -32,6 +33,9 @@
from build.utils import (
    process_accelerate_launch_args,
)
+from tuning.utils.merge_model_utils import (
+    post_process_vLLM_adapters_new_tokens,
+)
from tuning.utils.config_utils import get_json_config
from tuning.utils.error_logging import (
    write_termination_log,
@@ -115,6 +119,55 @@ def main():
        write_termination_log(f"Unhandled exception during training. {e}")
        sys.exit(INTERNAL_ERROR_EXIT_CODE)

+    peft_method = job_config.get("peft_method")
+
+    if job_config.get("lora_post_process_for_vllm") and peft_method == "lora":
+        save_model_dir = job_config.get("save_model_dir")
+        if save_model_dir:
+            if os.path.exists(
+                os.path.join(save_model_dir, "added_tokens_info.json")
+            ):
+                with open(
+                    os.path.join(save_model_dir, "added_tokens_info.json"),
+                    encoding="utf-8",
+                ) as json_data:
+                    added_tokens_info = json.load(json_data)
+                    num_added_tokens = added_tokens_info["num_new_tokens"]
+
+                if os.path.exists(
+                    os.path.join(save_model_dir, "adapter_model.safetensors")
+                ):
+                    post_process_vLLM_adapters_new_tokens(
+                        save_model_dir, save_model_dir, num_added_tokens
+                    )
+            else:
+                logging.warning(
+                    "Failed to post-process: file added_tokens_info.json not in path %s",
+                    save_model_dir,
+                )
+
+        if (
+            os.path.exists(os.path.join(output_dir, "added_tokens_info.json"))
+            and job_config.get("save_strategy") != "no"
+        ):
+            with open(
+                os.path.join(output_dir, "added_tokens_info.json"), encoding="utf-8"
+            ) as json_data:
+                added_tokens_info = json.load(json_data)
+                num_added_tokens = added_tokens_info["num_new_tokens"]
+            # if multiple checkpoints in directory, process each checkpoint
+            for _, dirs, _ in os.walk(output_dir, topdown=False):
+                for name in dirs:
+                    if "checkpoint-" in name.lower():
+                        post_process_vLLM_adapters_new_tokens(
+                            os.path.join(output_dir, name),
+                            os.path.join(output_dir, name),
+                            num_added_tokens,
+                        )
+        else:
+            logging.warning(
+                "Failed to post-process: file added_tokens_info.json not in path %s",
+                output_dir,
+            )

    # The .complete file will signal to users that we are finished copying
    # files over
    if os.path.exists(output_dir):
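If a tuning run finished without the flag set, the same helper can be applied after the fact. A minimal sketch under the same assumptions the launcher makes above (the directory already holds `adapter_model.safetensors` and `added_tokens_info.json`; `adapter_dir` is a hypothetical path):

```python
# Standard
import json
import os

# First Party
from tuning.utils.merge_model_utils import post_process_vLLM_adapters_new_tokens

adapter_dir = "/output/tuned-model"  # hypothetical path to a LoRA adapter

# Read how many tokens were added to the vocabulary during tuning.
with open(os.path.join(adapter_dir, "added_tokens_info.json"), encoding="utf-8") as f:
    num_added_tokens = json.load(f)["num_new_tokens"]

# Post-process in place: writes new_embeddings.safetensors alongside
# adapter_model.safetensors so the adapter can be served on vLLM.
post_process_vLLM_adapters_new_tokens(adapter_dir, adapter_dir, num_added_tokens)
```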
32 changes: 32 additions & 0 deletions tests/build/test_launch_script.py
@@ -214,6 +214,38 @@ def test_lora_save_model_dir_same_dir_as_output_dir_save_strategy_no():
        assert len(checkpoints) == 0


+def test_lora_with_lora_post_process_for_vllm_set_to_true():
+    with tempfile.TemporaryDirectory() as tempdir:
+        setup_env(tempdir)
+        TRAIN_KWARGS = {
+            **BASE_LORA_KWARGS,
+            **{
+                "output_dir": tempdir,
+                "save_model_dir": tempdir,
+                "lora_post_process_for_vllm": True,
+            },
+        }
+        serialized_args = serialize_args(TRAIN_KWARGS)
+        os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args
+
+        assert main() == 0
+        # check that model and logs exist in output_dir
+        _validate_termination_files_when_tuning_succeeds(tempdir)
+        _validate_training_output(tempdir, "lora")
+
+        for _, dirs, _ in os.walk(tempdir, topdown=False):
+            for name in dirs:
+                if "checkpoint-" in name.lower():
+                    new_embeddings_file_path = os.path.join(
+                        tempdir, name, "new_embeddings.safetensors"
+                    )
+                    assert os.path.exists(new_embeddings_file_path)
+
+        # check for new_embeddings.safetensors
+        new_embeddings_file_path = os.path.join(tempdir, "new_embeddings.safetensors")
+        assert os.path.exists(new_embeddings_file_path)


def test_bad_script_path():
    """Check for appropriate error for an invalid training script location"""
    with tempfile.TemporaryDirectory() as tempdir:
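The new test can be run on its own with `pytest tests/build/test_launch_script.py::test_lora_with_lora_post_process_for_vllm_set_to_true` (assuming the repo's test dependencies are installed). It checks that post-processing wrote `new_embeddings.safetensors` both into every `checkpoint-*` directory under `output_dir` and into `save_model_dir`.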
