feat: Add support for smoothly resuming training from a saved checkpoint (#300)

* Add feature of resume training

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* Remove lastcheckpoints conditions

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* PR Changes

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* feat: resume tuning based on value from user's flag

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* test:added unit tests for resume tuning feature

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* test: PR changes of resume from checkpoint feature

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* fix: Modified test fn descriptions, added readme

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

---------

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
Co-authored-by: Anh Uong <anh.uong@ibm.com>
Abhishek-TAMU and anhuong authored Sep 16, 2024
1 parent 5dd5494 commit cd6ba00
Showing 4 changed files with 234 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
@@ -333,7 +333,7 @@ indent-string=' '
max-line-length=100

# Maximum number of lines in a module.
max-module-lines=1100
max-module-lines=1200

# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
5 changes: 5 additions & 0 deletions README.md
@@ -278,6 +278,11 @@ You can set `output_dir` to a local directory and set `save_model_dir` to COS to

To achieve the fastest training time, set `save_strategy="no"`, as saving no checkpoints except for the final model removes intermediate write operations altogether.

#### Resuming tuning from checkpoints
If the output directory already contains checkpoints, tuning automatically resumes from the latest checkpoint in the directory specified by the `output_dir` flag. To ignore existing checkpoints and start tuning from scratch, set the `resume_from_checkpoint` flag to `False`.

You can also use the `resume_from_checkpoint` flag to resume tuning from a specific checkpoint by providing the full path to the desired checkpoint as a string. This flag is passed as an argument to the [trainer.train()](https://github.com/huggingface/transformers/blob/db70426854fe7850f2c5834d633aff637f14772e/src/transformers/trainer.py#L1901) function of the `SFTTrainer`.
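
For illustration only (not part of this commit's README additions), the unit tests added below exercise the same behaviour programmatically. A minimal sketch, assuming the `MODEL_ARGS`, `DATA_ARGS`, and `TRAIN_ARGS` fixtures defined in `tests/test_sft_trainer.py` and the import path used there:

import copy
import tempfile

from tuning import sft_trainer  # assumed import path, mirroring the tests

with tempfile.TemporaryDirectory() as tempdir:
    train_args = copy.deepcopy(TRAIN_ARGS)
    train_args.output_dir = tempdir

    # First run writes checkpoint-* folders under output_dir
    sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)

    # Second run resumes from the latest checkpoint automatically,
    # because output_dir already contains checkpoints
    train_args.num_train_epochs += 5
    sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)

    # To ignore existing checkpoints and start from scratch instead:
    # train_args.resume_from_checkpoint = "False"
    # Or to resume from a specific (hypothetical) checkpoint path:
    # train_args.resume_from_checkpoint = f"{tempdir}/checkpoint-5"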

## Tuning Techniques:

### LoRA Tuning Example
208 changes: 208 additions & 0 deletions tests/test_sft_trainer.py
@@ -80,6 +80,214 @@
PEFT_LORA_ARGS = peft_config.LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05)


def test_resume_training_from_checkpoint():
    """
    Test that tuning resumes from the latest checkpoint, creating new checkpoints,
    and that the checkpoints created before resuming tuning are not affected.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        train_args = copy.deepcopy(TRAIN_ARGS)
        train_args.output_dir = tempdir

        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
        _validate_training(tempdir)

        # Get trainer state of latest checkpoint
        init_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
        assert init_trainer_state is not None

        # Resume training with higher epoch and same output dir
        train_args.num_train_epochs += 5
        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
        _validate_training(tempdir)

        # Get trainer state of latest checkpoint
        final_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
        assert final_trainer_state is not None

        assert final_trainer_state["epoch"] == init_trainer_state["epoch"] + 5
        assert final_trainer_state["global_step"] > init_trainer_state["global_step"]

        # Check that the loss of the 1st epoch from the first tuning run is unchanged
        # after resuming tuning, i.e. it was not overwritten
        assert len(init_trainer_state["log_history"]) > 0

        init_log_history = init_trainer_state["log_history"][0]
        assert init_log_history["epoch"] == 1

        final_log_history = final_trainer_state["log_history"][0]
        assert final_log_history["epoch"] == 1

        assert init_log_history["loss"] == final_log_history["loss"]


def test_resume_training_from_checkpoint_with_flag_true():
    """
    Test that tuning resumes from the latest checkpoint when the flag is true,
    creating new checkpoints, and that the checkpoints created before resuming
    tuning are not affected.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        train_args = copy.deepcopy(TRAIN_ARGS)
        train_args.output_dir = tempdir
        train_args.resume_from_checkpoint = "True"

        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
        _validate_training(tempdir)

        # Get trainer state of latest checkpoint
        init_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
        assert init_trainer_state is not None

        # Get training logs
        init_training_logs = _get_training_logs_by_epoch(tempdir)

        # Resume training with higher epoch and same output dir
        train_args.num_train_epochs += 5
        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
        _validate_training(tempdir)

        # Get trainer state of latest checkpoint
        final_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
        assert final_trainer_state is not None

        assert final_trainer_state["epoch"] == init_trainer_state["epoch"] + 5
        assert final_trainer_state["global_step"] > init_trainer_state["global_step"]

        final_training_logs = _get_training_logs_by_epoch(tempdir)

        assert (
            init_training_logs[0]["data"]["timestamp"]
            == final_training_logs[0]["data"]["timestamp"]
        )


def test_resume_training_from_checkpoint_with_flag_false():
    """
    Test that tuning starts from scratch when resume_from_checkpoint is set to False.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        train_args = copy.deepcopy(TRAIN_ARGS)
        train_args.output_dir = tempdir
        train_args.resume_from_checkpoint = "False"

        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
        _validate_training(tempdir)

        # Get trainer state of latest checkpoint
        init_trainer_state, _ = _get_latest_checkpoint_trainer_state(tempdir)
        assert init_trainer_state is not None

        # Get training log entries for epoch 1
        init_training_logs = _get_training_logs_by_epoch(tempdir, epoch=1)
        assert len(init_training_logs) == 1

        # Train again with higher epoch and same output dir
        train_args.num_train_epochs += 5
        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
        _validate_training(tempdir)

        # Get training log entries for epoch 1
        final_training_logs = _get_training_logs_by_epoch(tempdir, epoch=1)
        assert len(final_training_logs) == 2


def test_resume_training_from_checkpoint_with_flag_checkpoint_path_lora():
    """
    Test resuming tuning from a specified checkpoint path for LoRA tuning.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        train_args = copy.deepcopy(TRAIN_ARGS)
        lora_config = copy.deepcopy(PEFT_LORA_ARGS)
        train_args.output_dir = tempdir

        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, lora_config)
        _validate_training(tempdir)

        # Get trainer state and checkpoint_path of second-to-last checkpoint
        init_trainer_state, checkpoint_path = _get_latest_checkpoint_trainer_state(
            tempdir, checkpoint_index=-2
        )
        assert init_trainer_state is not None

        # Resume training with higher epoch and same output dir
        train_args.num_train_epochs += 5
        train_args.resume_from_checkpoint = checkpoint_path
        sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, lora_config)
        _validate_training(tempdir)

        # Get total_flos from the trainer state of checkpoint_path and check that it is unchanged
        final_trainer_state = None
        trainer_state_file = os.path.join(checkpoint_path, "trainer_state.json")
        with open(trainer_state_file, "r", encoding="utf-8") as f:
            final_trainer_state = json.load(f)

        assert final_trainer_state["total_flos"] == init_trainer_state["total_flos"]


def _get_latest_checkpoint_trainer_state(dir_path: str, checkpoint_index: int = -1):
    """
    Get the trainer state from the latest or specified checkpoint directory.
    The trainer state is returned along with the path to the checkpoint.

    Args:
        dir_path (str): The directory path where checkpoint folders are located.
        checkpoint_index (int, optional): The index of the checkpoint to retrieve,
                                          based on the checkpoint number. The default
                                          is -1, which returns the latest checkpoint.

    Returns:
        trainer_state: The trainer state loaded from `trainer_state.json` in the
                       checkpoint directory.
        last_checkpoint: The path to the checkpoint directory.
    """
    trainer_state = None
    last_checkpoint = None
    checkpoints = [
        os.path.join(dir_path, d)
        for d in os.listdir(dir_path)
        if d.startswith("checkpoint")
    ]
    if checkpoints:
        last_checkpoint = sorted(checkpoints, key=lambda x: int(x.split("-")[-1]))[
            checkpoint_index
        ]
        trainer_state_file = os.path.join(last_checkpoint, "trainer_state.json")
        with open(trainer_state_file, "r", encoding="utf-8") as f:
            trainer_state = json.load(f)
    return trainer_state, last_checkpoint


def _get_training_logs_by_epoch(dir_path: str, epoch: int = None):
    """
    Load and optionally filter the training_logs.jsonl file.
    If an epoch number is specified, the function filters the logs
    and returns only the entries corresponding to the specified epoch.

    Args:
        dir_path (str): The directory path where the `training_logs.jsonl` file is located.
        epoch (int, optional): The epoch number to filter logs by. If not specified,
                               all logs are returned.

    Returns:
        list: A list containing the training logs. If `epoch` is specified,
              only logs from the specified epoch are returned; otherwise, all logs are returned.
    """
    data_list = []
    with open(f"{dir_path}/training_logs.jsonl", "r", encoding="utf-8") as file:
        for line in file:
            json_data = json.loads(line)
            data_list.append(json_data)

    if epoch:
        mod_data_list = []
        for value in data_list:
            if value["data"]["epoch"] == epoch:
                mod_data_list.append(value)
        return mod_data_list
    return data_list
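

A side note on the assumed log format (illustrative only, not part of this commit's diff): the helper above only relies on each `training_logs.jsonl` line being a JSON object whose `data` field carries at least `epoch` and `timestamp`. The field names below come from the accessors in the helper; the concrete values are invented.

import json

# Hypothetical training_logs.jsonl entry; keys mirror the accessors used above,
# values are made up for illustration.
sample_line = '{"data": {"epoch": 1, "timestamp": "2024-09-16T00:00:00"}}'
entry = json.loads(sample_line)
assert entry["data"]["epoch"] == 1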


def test_run_train_requires_output_dir():
    """Check that training fails when an output dir is not provided."""
    updated_output_dir_train_args = copy.deepcopy(TRAIN_ARGS)
22 changes: 20 additions & 2 deletions tuning/sft_trainer.py
@@ -35,6 +35,7 @@
    LlamaTokenizerFast,
    TrainerCallback,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import is_accelerate_available
from trl import SFTConfig, SFTTrainer
import transformers
@@ -215,7 +216,7 @@ def train(
        ),
    )

    # add special tokens only when a custom tokenizer is not passed
    # Add special tokens only when a custom tokenizer is not passed
    if not model_args.tokenizer_name_or_path:
        # TODO: understand if we need to hardcode these here or just use defaults in model
        if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
@@ -366,7 +367,24 @@ def train(
        for x in framework.get_callbacks_and_ready_for_train(model, accelerator):
            trainer.add_callback(x)

    trainer.train()
    resume_from_checkpoint = None
    # If the resume flag was not passed (None), or the flag is true and
    # output_dir has checkpoints, get the last checkpoint from output_dir
    if (
        training_args.resume_from_checkpoint is None
        or training_args.resume_from_checkpoint.lower() == "true"
    ):
        resume_from_checkpoint = get_last_checkpoint(training_args.output_dir)
    else:
        # `training_args.resume_from_checkpoint` gives string values
        # Check whether the flag is false OR holds a checkpoint path to resume tuning from
        resume_from_checkpoint = (
            training_args.resume_from_checkpoint
            if training_args.resume_from_checkpoint.lower() != "false"
            else False
        )

    trainer.train(resume_from_checkpoint)

    return trainer
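
For reference (an illustrative sketch, not part of this commit's diff), the resolution rule above can be read as a small standalone helper; `_resolve_resume_flag` is a hypothetical name introduced only for this example:

from typing import Optional, Union

from transformers.trainer_utils import get_last_checkpoint


def _resolve_resume_flag(flag: Optional[str], output_dir: str) -> Union[str, bool, None]:
    """Map the user-facing string flag onto what trainer.train() accepts:
    None or "true"  -> latest checkpoint in output_dir (None if there is none),
    "false"         -> False, i.e. start tuning from scratch,
    anything else   -> treated as an explicit checkpoint path."""
    if flag is None or flag.lower() == "true":
        return get_last_checkpoint(output_dir)
    return flag if flag.lower() != "false" else False

Assuming the flag always arrives as a string or None, calling trainer.train(_resolve_resume_flag(training_args.resume_from_checkpoint, training_args.output_dir)) would then be equivalent to the block above.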
