feat: trainer controller revamped
Signed-off-by: Padmanabha V Seshadri <seshapad@in.ibm.com>
Co-authored-by: Ashok Pon Kumar <ashokponkumar@gmail.com>
Co-authored-by: Dushyant Behl <dushyantbehl@hotmail.com>
3 people committed Mar 17, 2024
1 parent 25ef858 commit 7edf675
Showing 24 changed files with 479 additions and 626 deletions.
6 changes: 0 additions & 6 deletions examples/trainer-controller-configs/Readme.md
@@ -3,9 +3,3 @@ To use one of these files with the trainer, execute the `sft_trainer.py` with the following option:
```
--trainer_controller_config_file "examples/trainer-controller-configs/<file-name>"
```

# Note on trainer controller configuration examples
- `trainercontroller_config_step.yaml`: Defines a step-level trainer controller that computes the loss at every step and stops training if the loss increases for three consecutive steps.
- `trainercontroller_config_epoch.yaml`: Defines an epoch-level trainer controller that computes the loss at every epoch. It compares the current epoch loss with the previous epoch loss and stops training if the current loss is higher (see the sketch after this list).
- `trainercontroller_config_epoch_threshold.yaml`: Defines a trainer controller similar to the previous case, but adds a threshold constraint.
- `trainercontroller_config_evaluate.yaml`: Defines a trainer controller that behaves like the Hugging Face `EarlyStoppingCallback`, found [here](https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/trainer_callback.py#L543).
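
For intuition, here is a minimal sketch of the epoch-level rule these configs describe. The helper name and log-history format are illustrative (they follow the tests in this commit), not the controller's actual YAML rule engine:

```
def should_stop_on_epoch(log_history):
    """Stop when the latest epoch's loss exceeds the previous epoch's loss."""
    # Keep only integer-epoch entries, mirroring how the epoch-level tests
    # in this commit skip fractional epochs before calling on_epoch_end.
    epoch_losses = [entry["loss"] for entry in log_history
                    if float(entry["epoch"]).is_integer()]
    return len(epoch_losses) >= 2 and epoch_losses[-1] > epoch_losses[-2]
```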
10 changes: 10 additions & 0 deletions examples/trainer-controller-configs/loss.yaml
@@ -0,0 +1,10 @@
controller-metrics:
  loss:
    Loss:
controllers:
  - name: loss-controller
    triggers:
      - on_log
    rule: loss > 2.5
    operations:
      - hfcontrols.should_training_stop
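
To make the config's semantics concrete, here is a minimal sketch of the control flow it implies, written as a plain Hugging Face `TrainerCallback`. The class name is hypothetical and the threshold is hard-coded for illustration; the actual controller parses the metric class, rule, and operations from the YAML:

```
from transformers import TrainerCallback

class LossRuleSketch(TrainerCallback):
    """Hypothetical stand-in for the rule `loss > 2.5` combined with the
    `hfcontrols.should_training_stop` operation."""

    def __init__(self, threshold=2.5):
        self.threshold = threshold

    def on_log(self, args, state, control, **kwargs):
        # Trigger `on_log`: inspect the newest log entry for a loss value.
        if state.log_history and state.log_history[-1].get("loss", 0.0) > self.threshold:
            # Operation: flip the built-in Hugging Face control flag.
            control.should_training_stop = True
        return control
```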

This file was deleted.

This file was deleted.

This file was deleted.

111 changes: 33 additions & 78 deletions tests/trainercontroller/test_tuning_trainercontroller.py
@@ -1,84 +1,39 @@
# Third Party
import pytest
import math
# Copyright The IBM Tuning Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

# Local
import tuning.trainercontroller as tc
import tuning.config.configs as config
from transformers import TrainerControl, TrainerState, IntervalStrategy

def test_step_loss():
    test_data = [{'loss': 2.0, 'eval_loss': 2.0, 'epoch': 0.1},
                 {'loss': 2.1, 'eval_loss': 2.1, 'epoch': 0.25},
                 {'loss': 2.3, 'eval_loss': 2.3, 'epoch': 0.5}]
    outcomes = [False, False, True]
    training_args = config.TrainingArguments(output_dir='')
    trainer_controller_args = config.TrainerControllerArguments()
    training_args.logging_strategy = IntervalStrategy.STEPS
    training_args.logging_steps = 1
    trainer_controller_args.trainer_controller_config_file = 'examples/trainer-controller-configs/trainercontroller_config_step.yaml'
    tc_callback = tc.TrainerControllerCallback(trainer_controller_args, training_args)
    control = TrainerControl()
    control.should_training_stop = False
    state = TrainerState()
    state.log_history = []
    for i in range(len(test_data)):
        state.log_history.append(test_data[i])
        control = tc_callback.on_step_end(training_args, state, control)
        assert control.should_training_stop == outcomes[i]

def test_epoch_loss():
    test_data = [{'loss': 2.0, 'eval_loss': 2.0, 'epoch': 0.1},
                 {'loss': 2.1, 'eval_loss': 2.1, 'epoch': 0.25},
                 {'loss': 2.3, 'eval_loss': 2.3, 'epoch': 0.5},
                 {'loss': 2.35, 'eval_loss': 2.35, 'epoch': 0.75},
                 {'loss': 2.4, 'eval_loss': 2.35, 'epoch': 1.0},
                 {'loss': 2.45, 'eval_loss': 2.4, 'epoch': 1.25},
                 {'loss': 2.5, 'eval_loss': 2.45, 'epoch': 1.5},
                 {'loss': 2.55, 'eval_loss': 2.5, 'epoch': 1.75},
                 {'loss': 2.6, 'eval_loss': 2.55, 'epoch': 2.0}]
    outcomes = [False, False, False, False, False, False, False, False, True]
    training_args = config.TrainingArguments(output_dir='')
    trainer_controller_args = config.TrainerControllerArguments()
    training_args.logging_strategy = IntervalStrategy.STEPS
    training_args.logging_steps = 1
    trainer_controller_args.trainer_controller_config_file = 'examples/trainer-controller-configs/trainercontroller_config_epoch.yaml'
    tc_callback = tc.TrainerControllerCallback(trainer_controller_args, training_args)
    control = TrainerControl()
    control.should_training_stop = False
    state = TrainerState()
    state.log_history = []
    for i in range(len(test_data)):
        state.log_history.append(test_data[i])
        if (math.ceil(test_data[i]['epoch']) - test_data[i]['epoch']) > 0:
            continue
        control = tc_callback.on_epoch_end(training_args, state, control)
        assert control.should_training_stop == outcomes[i]

def test_epoch_threshold_loss():
    test_data = [{'loss': 2.1, 'eval_loss': 2.0, 'epoch': 0.1},
                 {'loss': 2.1, 'eval_loss': 2.1, 'epoch': 0.25},
                 {'loss': 2.05, 'eval_loss': 2.3, 'epoch': 0.5},
                 {'loss': 2.05, 'eval_loss': 2.35, 'epoch': 0.75},
                 {'loss': 2.02, 'eval_loss': 2.35, 'epoch': 1.0},
                 {'loss': 2.03, 'eval_loss': 2.4, 'epoch': 1.25},
                 {'loss': 2.01, 'eval_loss': 2.45, 'epoch': 1.5},
                 {'loss': 2.0, 'eval_loss': 2.5, 'epoch': 1.75},
                 {'loss': 2.09, 'eval_loss': 2.55, 'epoch': 2.0}]
    outcomes = [False, False, False, False, False, False, False, False, True]
    training_args = config.TrainingArguments(output_dir='')
    trainer_controller_args = config.TrainerControllerArguments()
    training_args.logging_strategy = IntervalStrategy.STEPS
    training_args.logging_steps = 1
    trainer_controller_args.trainer_controller_config_file = 'examples/trainer-controller-configs/trainercontroller_config_epoch_threshold.yaml'
    tc_callback = tc.TrainerControllerCallback(trainer_controller_args, training_args)
    control = TrainerControl()
    control.should_training_stop = False
    state = TrainerState()
    state.log_history = []
    for i in range(len(test_data)):
        state.log_history.append(test_data[i])
        if (math.ceil(test_data[i]['epoch']) - test_data[i]['epoch']) > 0:
            continue
        control = tc_callback.on_epoch_end(training_args, state, control)
        assert control.should_training_stop == outcomes[i]

def test_step_loss_on_threshold():
    test_data = [{'loss': 2.0, 'epoch': 0.1},
                 {'loss': 2.1, 'epoch': 0.25},
                 {'loss': 1.3, 'epoch': 0.5},
                 {'loss': 0.9, 'epoch': 0.6}]
    training_args = config.TrainingArguments(
        output_dir='',
        logging_strategy=IntervalStrategy.STEPS,
        logging_steps=1,
    )
    tc_callback = tc.TrainerControllerCallback('examples/trainer-controller-configs/loss.yaml')
    control = TrainerControl(should_training_stop=False)
    state = TrainerState(log_history=[])
    tc_callback.on_init_end(args=training_args)
    state.log_history = test_data
    tc_callback.on_step_end(args=training_args, state=state, control=control)
    assert control.should_training_stop == True
4 changes: 2 additions & 2 deletions tuning/config/configs.py
@@ -76,10 +76,10 @@ class TrainingArguments(transformers.TrainingArguments):
@dataclass
class TrainerControllerArguments():
    trainer_controller_config_file: str = field(
        default="trainercontroller_config.yaml",
        default=None,
        metadata={
            "help": (
                "Trainer controller configuration file in YAML format."
                "Trainer controller configuration file (e.g. trainercontroller_config.yaml) in YAML format."
            )
        },
    )
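
For context on how this dataclass is typically consumed, a sketch using the standard `HfArgumentParser` pattern (the exact wiring inside `sft_trainer.py` may differ):

```
from transformers import HfArgumentParser
from tuning.config import configs as config

parser = HfArgumentParser((config.TrainerControllerArguments,))
(trainer_controller_args,) = parser.parse_args_into_dataclasses(
    ["--trainer_controller_config_file", "examples/trainer-controller-configs/loss.yaml"]
)
# With no flag passed, the field now defaults to None, so callers can
# skip attaching the controller callback entirely.
print(trainer_controller_args.trainer_controller_config_file)
```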
12 changes: 4 additions & 8 deletions tuning/sft_trainer.py
@@ -225,14 +225,10 @@ def train(
    file_logger_callback = FileLoggingCallback(logger)
    callbacks = [aim_callback, file_logger_callback]

    if os.path.exists(trainer_controller_args.trainer_controller_config_file):
        with open(trainer_controller_args.trainer_controller_config_file, "r") as f:
            trainer_controller_config = yaml.safe_load(f)
        tc_callback = TrainerControllerCallback(trainer_controller_config)
        callbacks.append(tc_callback)
    else:
        raise FileNotFoundError(
            "Trainer controller configuration [%s] does NOT exist"
            % trainer_controller_args.trainer_controller_config_file
        )

    if trainer_controller_args.trainer_controller_config_file is not None:
        tc_callback = TrainerControllerCallback(trainer_controller_args.trainer_controller_config_file)
        callbacks.append(tc_callback)

    if train_args.packing:
        logger.info("Packing is set to True")
        data_collator = None
17 changes: 17 additions & 0 deletions tuning/trainercontroller/__init__.py
@@ -1 +1,18 @@
# Copyright The IBM Tuning Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

from .callback import TrainerControllerCallback