feat: Add post_process_vLLM_adapters_new_tokens function to main
Signed-off-by: Will Johnson <mwjohnson728@gmail.com>
willmj committed Sep 18, 2024
1 parent 4c9bb95 commit af191d1
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion tuning/sft_trainer.py
@@ -40,6 +40,7 @@
import transformers

# Local
from build.utils import get_highest_checkpoint
from tuning.config import configs, peft_config
from tuning.config.acceleration_configs import (
    AccelerationFrameworkConfig,
@@ -68,7 +69,9 @@
    is_pretokenized_dataset,
    validate_data_args,
)

from tuning.utils.merge_model_utils import (
    post_process_vLLM_adapters_new_tokens,
)

def train(
model_args: configs.ModelArguments,
@@ -633,6 +636,22 @@ def main():
        )
        sys.exit(INTERNAL_ERROR_EXIT_CODE)

    # post process lora
    if isinstance(tune_config, peft_config.LoraConfig):
        try:
            checkpoint_dir = job_config.get("save_model_dir")
            if not checkpoint_dir:
                checkpoint_dir = os.path.join(
                    training_args.output_dir,
                    get_highest_checkpoint(training_args.output_dir),
                )
            print(training_args)
            print(f"Post processing LoRA adapters in {checkpoint_dir}")
            post_process_vLLM_adapters_new_tokens(path_to_checkpoint=checkpoint_dir)
        except Exception as e:  # pylint: disable=broad-except
            logging.error(traceback.format_exc())
            write_termination_log(
                f"Exception encountered while lora post-processing model: {e}"
            )
            sys.exit(INTERNAL_ERROR_EXIT_CODE)

if __name__ == "__main__":
    main()
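
For context on what the new post-processing step is for: when fine-tuning adds tokens to the tokenizer, the resized embedding and LM-head rows are stored with the checkpoint, while vLLM expects the new-token rows in a separate embeddings file next to the LoRA adapter weights. The sketch below only illustrates that idea; the helper name split_new_token_embeddings, the file names, and the assumption that new tokens occupy the last embedding rows are hypothetical, not the actual post_process_vLLM_adapters_new_tokens implementation.

import os

from safetensors import safe_open
from safetensors.torch import save_file


def split_new_token_embeddings(path_to_checkpoint: str, num_added_tokens: int) -> None:
    """Hypothetical helper: move embedding rows for newly added tokens out of the
    adapter weights into a new_embeddings.safetensors file alongside the adapter."""
    adapter_file = os.path.join(path_to_checkpoint, "adapter_model.safetensors")
    kept, new_embeddings = {}, {}
    with safe_open(adapter_file, framework="pt") as f:
        for name in f.keys():
            tensor = f.get_tensor(name)
            if "embed_tokens" in name or "lm_head" in name:
                # Assumption: the added tokens occupy the last rows of the resized matrix.
                new_embeddings[name] = tensor[-num_added_tokens:].clone()
            else:
                kept[name] = tensor
    save_file(kept, adapter_file)
    save_file(
        new_embeddings, os.path.join(path_to_checkpoint, "new_embeddings.safetensors")
    )

The real helper is imported from tuning.utils.merge_model_utils and, as the diff shows, only needs the checkpoint path.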

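When save_model_dir is not set in the job config, the new code falls back to the latest checkpoint under training_args.output_dir via build.utils.get_highest_checkpoint. A minimal sketch of such a helper, assuming the usual Hugging Face checkpoint-<step> directory naming (the repository's actual implementation may differ):

import os
import re


def get_highest_checkpoint(dir_path: str) -> str:
    """Assumed behavior: return the checkpoint-<step> subdirectory name with the
    highest step number, e.g. "checkpoint-500"."""
    highest_step = -1
    highest_checkpoint = ""
    for entry in os.listdir(dir_path):
        match = re.fullmatch(r"checkpoint-(\d+)", entry)
        if match and int(match.group(1)) > highest_step:
            highest_step = int(match.group(1))
            highest_checkpoint = entry
    return highest_checkpoint

Returning just the directory name, rather than a full path, matches how the diff joins the result onto training_args.output_dir.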