feat: Trainer controller framework (foundation-model-stack#45)
* feat: Extended gitignore to include backup files and folders

Signed-off-by: Padmanabha V Seshadri <seshapad@in.ibm.com>
Co-authored-by: Ashok Pon Kumar <ashokponkumar@gmail.com>
Co-authored-by: Dushyant Behl <dushyantbehl@hotmail.com>

* feat: Policy driven training control

Signed-off-by: Padmanabha V Seshadri <seshapad@in.ibm.com>
Co-authored-by: Ashok Pon Kumar <ashokponkumar@gmail.com>
Co-authored-by: Dushyant Behl <dushyantbehl@hotmail.com>

* feat: Addressed review comments related to exceptions and abstract class inheritance

Signed-off-by: Padmanabha V Seshadri <seshapad@in.ibm.com>

* feat: Design changes to trainer controller including validations, schema etc

Signed-off-by: Padmanabha V Seshadri <seshapad@in.ibm.com>

* feat: trainer controller revamped

Signed-off-by: Padmanabha V Seshadri <seshapad@in.ibm.com>
Co-authored-by: Ashok Pon Kumar <ashokponkumar@gmail.com>
Co-authored-by: Dushyant Behl <dushyantbehl@hotmail.com>

* feat: Documentation and some test case bug fixes

Signed-off-by: Padmanabha V Seshadri <seshapad@in.ibm.com>

* fix: Formatting issues to make build succeed

Signed-off-by: Padmanabha V Seshadri <seshapad@in.ibm.com>

* fix: Add rule validation to make eval safe again

Signed-off-by: Padmanabha V Seshadri <seshapad@in.ibm.com>

* fix: Removed default package typing from requirements.txt

Signed-off-by: Padmanabha V Seshadri <seshapad@in.ibm.com>

* feat: Added test cases, data, some exception handling

Signed-off-by: Padmanabha V Seshadri <seshapad@in.ibm.com>

* fix: Addressed the action filter bug and added a test case for it

Signed-off-by: Padmanabha V Seshadri <seshapad@in.ibm.com>

* fix: bugs in operation validate

Signed-off-by: Padmanabha V Seshadri <seshapad@in.ibm.com>

* adr: Architecture document for trainer-controller

Signed-off-by: Padmanabha V Seshadri <seshapad@in.ibm.com>

* fix: Trainer controller examples directory renamed

Signed-off-by: Padmanabha V Seshadri <seshapad@in.ibm.com>

* fix: Prefix regex corrected

Signed-off-by: Padmanabha V Seshadri <seshapad@in.ibm.com>

* adr: Details on key collisions added

Signed-off-by: Padmanabha V Seshadri <seshapad@in.ibm.com>

* fix: rebase issues related to aim callback addressed

Signed-off-by: Padmanabha V Seshadri <seshapad@in.ibm.com>

* fix: brackets missing comma

Signed-off-by: Padmanabha V Seshadri <seshapad@in.ibm.com>

* fix: Addressed lint comments

Signed-off-by: Padmanabha V Seshadri <seshapad@in.ibm.com>

* fix: Added lint disable directives

Signed-off-by: Padmanabha V Seshadri <seshapad@in.ibm.com>

* fix: Reformatted files from black

Signed-off-by: Padmanabha V Seshadri <seshapad@in.ibm.com>

* fix: Resolved cyclic package dependencies

Signed-off-by: Padmanabha V Seshadri <seshapad@in.ibm.com>

---------

Signed-off-by: Padmanabha V Seshadri <seshapad@in.ibm.com>
Co-authored-by: Ashok Pon Kumar <ashokponkumar@gmail.com>
Co-authored-by: Dushyant Behl <dushyantbehl@hotmail.com>
3 people authored and jbusche committed Apr 9, 2024
1 parent 207fe3a commit aa93d7d
Showing 35 changed files with 1,612 additions and 2 deletions.
8 changes: 8 additions & 0 deletions .gitignore
@@ -13,6 +13,9 @@ test
.vscode/
.idea/

# AIM files
.aim

# Env files
.env

@@ -28,3 +31,8 @@ venv/

# Aim
.aim

# Backup files and folders
*.bkp
*.bkp.*
*bkp*
129 changes: 129 additions & 0 deletions architecture_records/001-trainer-controller-framework.md
@@ -0,0 +1,129 @@
# Trainer Controller Framework

**Decider(s)**: Alexander Brooks (alex.brooks@ibm.com), Sukriti Sharma (sukriti.sharma4@ibm.com), Raghu Ganti (rganti@us.ibm.com), Padmanabha Venkatagiri Seshadri (seshapad@in.ibm.com), Dushyant Behl (dushyantbehl@in.ibm.com)
**Date (YYYY-MM-DD)**: 2024-03-05
**Obsoletes ADRs**: NA
**Modified By ADRs**: NA
**Relevant Issues**: [537](https://github.ibm.com/ai-foundation/watson-fm-stack-tracker/issues/537), [323](https://github.ibm.com/ai-foundation/watson-fm-stack-tracker/issues/323)

- [Summary and Objective](#summary-and-objective)
- [Motivation](#motivation)
- [User Benefit](#user-benefit)
- [Decision](#decision)
- [Alternatives Considered](#alternatives-considered)
- [Consequences](#consequences)
- [Detailed Design](#detailed-design)

## Summary and Objective

To create a framework for controlling the trainer loop using user-defined rules and metrics.

### Motivation

- Issue [537](https://github.ibm.com/ai-foundation/watson-fm-stack-tracker/issues/537) raised the need to stop an ongoing training run when some stopping criterion is satisfied (e.g. validation loss reaching a target, loss increasing with each epoch, or loss increasing over the last 100 steps).
- Hugging Face provides an [EarlyStoppingCallback](https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/trainer_callback.py#L543), but it fires only on `evaluate` events and only compares an instantaneous metric value to a threshold.
- Therefore, there is a need for a mechanism to capture user-defined custom stopping criteria, which could involve multiple metrics.
- Beyond stopping criteria, there are other useful control operations on training: for instance, whether the trainer should save, log or evaluate at a given point, or whether resources should be scaled dynamically so that training runs faster. There is therefore a need to capture all of these use-cases in a single framework, which this PR provides.

### User Benefit

Users can control the training loop by defining custom rules. This conserves resources by killing training jobs with run-away loss, and supports collecting debugging data (log-on-demand), checkpoint-on-demand, and similar scenarios where intervention in the training loop is required.

## Decision

### Alternatives Considered

We considered the following alternatives:
- Defining the metrics as functions instead of classes. This was dropped because it was less expressive and did not preserve the state of the computation (as in the case of windowing mechanisms), which can be updated more efficiently from the evolving logs.
- Modifying the trainer loop directly instead of using callbacks. This was dropped because it would require a custom Hugging Face stack with the said modifications, and would make the framework tightly coupled.

## Consequences

Following are the advantages and limitations of our design approach:

### Advantages:
- We use the trainer callback approach, which can be used in plug-and-play form with the trainer. In addition, we have designed this framework as an independent package.
- The rules and metrics are flexible and can be defined by the user with limited coding effort (only in the case of custom metrics), and with no coding effort when using existing metrics.

### Impact on performance:
The callback is invoked at various events of the trainer loop, so whatever is computed within it adds overhead to, and could affect the performance of, each loop iteration.


## Detailed Design

### Usage and customization
We have implemented a trainer callback (see [here](https://huggingface.co/docs/transformers/v4.37.2/en/main_classes/callback)) which accepts a training control definition file (in YAML format) that facilitates the definition of:
1. Rules to control the training loop
2. Trigger points at which the above rules are evaluated
3. Control operations and actions that are performed when a rule evaluates to true

The trainer controller configuration is structured as shown below: a list of metric definitions under `controller-metrics`, a list of operations and their actions under `operations`, and a list of controllers, each of which defines its rules, triggers and control operations.
```
controller-metrics:
  <controller-metric-name>:
    <controller-metric-handler-class>:
      <arg1>: <value>
      ...
operations:
  <operation-name>:
    <operation-handler-class>:
      <arg1>: <value>
      ...
controllers:
  - name: <controller-name>
    triggers:
      - <event-1>
      ...
    rule: <rule-string>
    operations:
      - <operation-action-1>
      ...
```
The `controller-metrics` and `operations` sections are optional. We provide a set of built-in controller-metrics and operations which can be referred to without defining them. For example, the configuration below defines a controller-metric called `loss`, which refers to the built-in `Loss` controller-metric class with custom arguments (in this case, none), but defines no `operations`; it only refers to a built-in operation.
```
controller-metrics:
  loss:
    Loss:
controllers:
  - name: loss-controller
    triggers:
      - on_log
    rule: loss < 1.0
    operations:
      - hfcontrols.should_training_stop
```
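To make the example concrete, here is a minimal, self-contained sketch of how a controller like `loss-controller` could be evaluated once the YAML has been parsed. This is an illustration, not the framework's actual implementation: the config is shown as an already-parsed dict, and the restricted `eval` stands in for the framework's validated rule evaluation.

```python
# Hypothetical sketch: evaluating a parsed controller config against metric state.
controller = {
    "name": "loss-controller",
    "triggers": ["on_log"],
    "rule": "loss < 1.0",
    "operations": ["hfcontrols.should_training_stop"],
}

def evaluate_controller(controller, metrics, event_name):
    """Return the operation-actions to fire for this event, if the rule holds."""
    if event_name not in controller["triggers"]:
        return []
    # Evaluate the rule with the metric variables as the only names in scope.
    # (The real framework validates rules before evaluating them.)
    if eval(controller["rule"], {"__builtins__": {}}, dict(metrics)):
        return controller["operations"]
    return []

print(evaluate_controller(controller, {"loss": 0.42}, "on_log"))
# -> ['hfcontrols.should_training_stop']
print(evaluate_controller(controller, {"loss": 2.5}, "on_log"))
# -> []
```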
For defining custom handler classes, we provide an interface defined as an abstract class, shown below, with two abstract methods: `validate()`, which defines the validation conditions, and `compute()`, which computes the metric. `compute()` is typed to return `Any`, but it should return only the key-value pairs that are used in the rule(s) defined in the configuration.

Further, the `__init__` method of the class should accept variable arguments in the form of key-value pairs. Note that the keys used in the arguments of the configuration must not conflict with any keys used by the Hugging Face trainer callback; please use unique keys as argument names.
```
class MetricHandler(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    def validate(self) -> bool:
        pass

    @abc.abstractmethod
    def compute(self, event_name: str, **kwargs) -> Any:
        pass
```
These classes can be user-defined. To add a new metric class, implement the above structure and register it with the trainer controller framework using the `register_metric_handlers()` method. To use the metric handler class, add its class name and arguments to the configuration file.
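As an illustration, a windowed-loss metric handler might look like the following. This is a hypothetical sketch: the `MetricHandler` base is reproduced inline for self-containment, and the `WindowedLoss` class, its `window_size` argument, and the windowing logic are assumptions for illustration, not the framework's actual code.

```python
import abc
from collections import deque
from typing import Any

class MetricHandler(metaclass=abc.ABCMeta):
    """Inline copy of the abstract interface described above."""
    @abc.abstractmethod
    def validate(self) -> bool: ...
    @abc.abstractmethod
    def compute(self, event_name: str, **kwargs) -> Any: ...

class WindowedLoss(MetricHandler):
    """Hypothetical metric: mean loss over the last `window_size` log events."""
    def __init__(self, **kwargs):
        self._window = deque(maxlen=int(kwargs.get("window_size", 100)))

    def validate(self) -> bool:
        # Only valid if a positive window size was configured.
        return self._window.maxlen is not None and self._window.maxlen > 0

    def compute(self, event_name: str, **kwargs) -> Any:
        self._window.append(kwargs["loss"])
        # Return key-value pairs usable as variables in rules.
        return {"windowed_loss": sum(self._window) / len(self._window)}

handler = WindowedLoss(window_size=2)
handler.compute("on_log", loss=4.0)
print(handler.compute("on_log", loss=2.0))  # {'windowed_loss': 3.0}
```

A rule could then refer to the returned `windowed_loss` variable, e.g. `windowed_loss < 1.0`.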

Similarly, there is an abstract operation class, `Operation`, which can be inherited to define custom operations, as illustrated below:
```
class CustomOperation(Operation):
    def should_perform_action_xyz(self, args):
        pass
```
Every action defined in a custom operation must be a method whose name is prefixed with `should_`. The controller automatically picks up these methods and invokes them when they are referred to in the configuration. Custom operations can be registered using the `register_operation_handlers()` method.
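The `should_` prefix convention can be implemented with simple reflection. The sketch below is illustrative only: the `Operation` base is a stand-in for the framework's class, and the `get_actions()` helper is an assumption about how discovery might work, not the framework's actual API.

```python
class Operation:
    """Stand-in for the framework's operation base class."""
    def get_actions(self):
        # Discover every action method by its `should_` name prefix.
        return {
            name: getattr(self, name)
            for name in dir(self)
            if name.startswith("should_") and callable(getattr(self, name))
        }

class CustomOperation(Operation):
    def should_perform_action_xyz(self, **kwargs) -> bool:
        # Decide whether the corresponding trainer control flag should be set.
        return True

actions = CustomOperation().get_actions()
print(sorted(actions))  # ['should_perform_action_xyz']
```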

A `rule` is a Python expression stating a condition over metric variables. For example, in the configuration above, `loss` is the variable and the rule applies a threshold to it.
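Because rules are Python expressions, naive `eval` would be dangerous; this PR's malicious-rule test configs (`__import__('os').system('clear')`, `input(...)`) exercise exactly that risk, and a fix commit adds rule validation. One hedged sketch of such validation, assuming an AST-whitelist approach rather than the framework's actual validator:

```python
import ast

# Node types a rule is allowed to contain; anything else
# (calls, attribute access, subscripts) is rejected before evaluation.
_ALLOWED = (
    ast.Expression, ast.Compare, ast.BoolOp, ast.BinOp, ast.UnaryOp,
    ast.Name, ast.Load, ast.Constant, ast.And, ast.Or, ast.Not,
    ast.Lt, ast.LtE, ast.Gt, ast.GtE, ast.Eq, ast.NotEq,
    ast.Add, ast.Sub, ast.Mult, ast.Div, ast.USub,
)

def is_safe_rule(rule: str) -> bool:
    """Accept a rule only if every AST node is on the whitelist."""
    try:
        tree = ast.parse(rule, mode="eval")
    except SyntaxError:
        return False
    return all(isinstance(node, _ALLOWED) for node in ast.walk(tree))

print(is_safe_rule("loss < 1.0"))                        # True
print(is_safe_rule("__import__('os').system('clear')"))  # False: Call node
print(is_safe_rule("input('Please enter your password:')"))  # False: Call node
```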

`operations` lists the operation-actions to be performed when the rule evaluates to true. The convention for referring to an operation-action is `<operation-name>.<action-name>`. In this example, `<operation-name>` refers to the built-in operation `hfcontrols`, and `<action-name>` to one of its actions, `should_training_stop`.

### High-level architecture
The figure below shows a high-level design diagram. Users interact with the framework through the following touch-points:
- **Registration**: The registration mechanism decouples the metrics and operators from the trainer framework. A user can implement a custom metric or operator and register it through the registration methods mentioned previously, which makes the framework highly extensible.

- **Configuration**: The trainer controller configuration supplies the definitions for triggers, rules, operations and metrics needed to orchestrate the enactment of a particular control policy. These details are split up and passed to the respective modules by the trainer controller, as shown in the figure.

- **Events**: Events supply the state and arguments required for the metric handlers to perform metric computation at the events they are registered for. The framework callback lists all handlers prefixed with `on_` and loads them as event handlers; every metric declares one or more events from this list of valid handlers. The computed metric variables are stored in the global state of the trainer controller and independently picked up by the operations, which could be triggered on an entirely different set of events. This decouples the control loops for metrics and operations: a metric could be computed on event A, while an operation is triggered on event B. The controller rules, which use the metric variables from the trainer controller state, are evaluated, and the specified actions are performed based on the outcomes.
![High-Level Design Diagram: Trainer Controller Framework](arch.png)
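The metric/operation decoupling described above can be pictured as a shared state written by metric handlers on one event and read by operations on another. A minimal illustrative sketch, with hypothetical event-handler names and state layout (not the framework's code):

```python
class TrainerControllerState:
    """Shared variable store connecting metric handlers to operations."""
    def __init__(self):
        self.variables = {}

state = TrainerControllerState()

def on_log(logs):
    # Metric computed on event A and stored in the shared state.
    state.variables["loss"] = logs["loss"]

def on_step_end():
    # Operation rule checked on event B, reading variables written earlier.
    return state.variables.get("loss", float("inf")) < 1.0  # should stop?

on_log({"loss": 0.8})
print(on_step_end())  # True
```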
Binary file added architecture_records/arch.png
5 changes: 5 additions & 0 deletions examples/trainercontroller_configs/Readme.md
@@ -0,0 +1,5 @@
# How-To
To use one of these files with the trainer, execute `sft_trainer.py` with the following option:
```
--trainer_controller_config_file "examples/trainercontroller_configs/<file-name>"
```
10 changes: 10 additions & 0 deletions examples/trainercontroller_configs/loss.yaml
@@ -0,0 +1,10 @@
controller-metrics:
  loss:
    Loss:
controllers:
  - name: loss-controller
    triggers:
      - on_log
    rule: loss < 1.0
    operations:
      - hfcontrols.should_training_stop
2 changes: 1 addition & 1 deletion pytest.ini
@@ -1,4 +1,4 @@
[pytest]
# Register tests from `build` dir, removing `build` from norecursedirs default list,
# see https://doc.pytest.org/en/latest/reference/reference.html#confval-norecursedirs
norecursedirs = *.egg .* _darcs CVS dist node_modules venv {arch}
norecursedirs = *.egg .* _darcs CVS dist node_modules venv {arch}
1 change: 1 addition & 0 deletions requirements.txt
@@ -9,3 +9,4 @@ trl
peft>=0.8.0
datasets>=2.15.0
fire

51 changes: 51 additions & 0 deletions tests/data/trainercontroller/__init__.py
@@ -0,0 +1,51 @@
# Copyright The IBM Tuning Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helpful datasets for configuring individual unit tests.
"""
# Standard
import os

### Constants used for data
_DATA_DIR = os.path.join(os.path.dirname(__file__))
TRAINER_CONFIG_TEST_LOSS_ON_THRESHOLD_YAML = os.path.join(
_DATA_DIR, "loss_on_threshold.yaml"
)
TRAINER_CONFIG_TEST_MALICIOUS_OS_RULE_YAML = os.path.join(
_DATA_DIR, "loss_with_malicious_os_rule.yaml"
)
TRAINER_CONFIG_TEST_MALICIOUS_INPUT_RULE_YAML = os.path.join(
_DATA_DIR, "loss_with_malicious_input_rule.yaml"
)
TRAINER_CONFIG_TEST_INVALID_TRIGGER_YAML = os.path.join(
_DATA_DIR, "loss_invalid_trigger.yaml"
)
TRAINER_CONFIG_TEST_INVALID_OPERATION_YAML = os.path.join(
_DATA_DIR, "loss_invalid_operation.yaml"
)
TRAINER_CONFIG_TEST_INVALID_OPERATION_ACTION_YAML = os.path.join(
_DATA_DIR, "loss_invalid_operation_action.yaml"
)
TRAINER_CONFIG_TEST_INVALID_METRIC_YAML = os.path.join(
_DATA_DIR, "loss_invalid_metric.yaml"
)
TRAINER_CONFIG_TEST_CUSTOM_METRIC_YAML = os.path.join(
_DATA_DIR, "loss_custom_metric.yaml"
)
TRAINER_CONFIG_TEST_CUSTOM_OPERATION_YAML = os.path.join(
_DATA_DIR, "loss_custom_operation.yaml"
)
TRAINER_CONFIG_TEST_CUSTOM_OPERATION_INVALID_ACTION_YAML = os.path.join(
_DATA_DIR, "loss_custom_operation_invalid_action.yaml"
)
10 changes: 10 additions & 0 deletions tests/data/trainercontroller/loss_custom_metric.yaml
@@ -0,0 +1,10 @@
controller-metrics:
  testflag:
    CustomMetric:
controllers:
  - name: loss-controller-custom-metric
    triggers:
      - on_log
    rule: testflag == True
    operations:
      - hfcontrols.should_training_stop
13 changes: 13 additions & 0 deletions tests/data/trainercontroller/loss_custom_operation.yaml
@@ -0,0 +1,13 @@
controller-metrics:
  loss:
    Loss:
operations:
  customoperation:
    CustomOperation:
controllers:
  - name: loss-controller-custom-operation
    triggers:
      - on_log
    rule: loss < 1.0
    operations:
      - customoperation.should_perform_action_xyz
13 changes: 13 additions & 0 deletions tests/data/trainercontroller/loss_custom_operation_invalid_action.yaml
@@ -0,0 +1,13 @@
controller-metrics:
  loss:
    Loss:
operations:
  customoperation:
    CustomOperationInvalidAction:
controllers:
  - name: loss-controller-custom-operation-invalid-action
    triggers:
      - on_log
    rule: loss < 1.0
    operations:
      - customoperation.should_
10 changes: 10 additions & 0 deletions tests/data/trainercontroller/loss_invalid_metric.yaml
@@ -0,0 +1,10 @@
controller-metrics:
  loss:
    MissingMetricClass:
controllers:
  - name: loss-controller-invalid-metric
    triggers:
      - on_log
    rule: loss < 1.0
    operations:
      - hfcontrols.should_training_stop
10 changes: 10 additions & 0 deletions tests/data/trainercontroller/loss_invalid_operation.yaml
@@ -0,0 +1,10 @@
controller-metrics:
  loss:
    Loss:
controllers:
  - name: loss-controller-invalid-operation
    triggers:
      - on_log
    rule: loss < 1.0
    operations:
      - missingop.should_training_stop
10 changes: 10 additions & 0 deletions tests/data/trainercontroller/loss_invalid_operation_action.yaml
@@ -0,0 +1,10 @@
controller-metrics:
  loss:
    Loss:
controllers:
  - name: loss-controller-invalid-operation-action
    triggers:
      - on_log
    rule: loss < 1.0
    operations:
      - hfcontrols.missingaction
10 changes: 10 additions & 0 deletions tests/data/trainercontroller/loss_invalid_trigger.yaml
@@ -0,0 +1,10 @@
controller-metrics:
  loss:
    Loss:
controllers:
  - name: loss-controller-invalid-trigger
    triggers:
      - log_it_all_incorrect_trigger_name
    rule: loss < 1.0
    operations:
      - hfcontrols.should_training_stop
10 changes: 10 additions & 0 deletions tests/data/trainercontroller/loss_on_threshold.yaml
@@ -0,0 +1,10 @@
controller-metrics:
  loss:
    Loss:
controllers:
  - name: loss-controller
    triggers:
      - on_log
    rule: loss < 1.0
    operations:
      - hfcontrols.should_training_stop
10 changes: 10 additions & 0 deletions tests/data/trainercontroller/loss_with_malicious_input_rule.yaml
@@ -0,0 +1,10 @@
controller-metrics:
  loss:
    Loss:
controllers:
  - name: loss-controller-wrong-input-rule
    triggers:
      - on_log
    rule: input('Please enter your password:')
    operations:
      - hfcontrols.should_training_stop
10 changes: 10 additions & 0 deletions tests/data/trainercontroller/loss_with_malicious_os_rule.yaml
@@ -0,0 +1,10 @@
controller-metrics:
  loss:
    Loss:
controllers:
  - name: loss-controller-wrong-os-rule
    triggers:
      - on_log
    rule: __import__('os').system('clear')
    operations:
      - hfcontrols.should_training_stop