An automatic differentiation library for generalized meta-learning and multilevel optimization
Docs | Tutorials | Examples | Paper | CASL Project

pip install betty-ml

Introduction

Betty is a PyTorch library for generalized meta-learning (GML) and multilevel optimization (MLO) that provides a simple and modular programming interface for a number of large-scale applications, including meta-learning, hyperparameter optimization, neural architecture search, data reweighting, and many more.

With Betty, users simply need to do two things to implement any GML/MLO program:

  1. Define each level's optimization problem using the Problem class.
  2. Define the hierarchical problem structure using the Engine class.

Quick Start

Problem

Basics

Each level problem can be defined with seven components: (1) module, (2) optimizer, (3) data loader, (4) loss function, (5) problem configuration, (6) name, and (7) other optional components (e.g. learning rate scheduler). The loss function (4) can be defined via the training_step method, while all other components can be provided through the class constructor. For example, an image classification problem can be defined as follows:

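import torch.nn.functional as F

from betty.problems import ImplicitProblem
from betty.configs import Config

# set up module, optimizer, data loader (i.e. (1)-(3));
# setup_classification() is a user-defined helper, not part of Betty
cls_module, cls_optimizer, cls_data_loader = setup_classification()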

class Classifier(ImplicitProblem):
    # set up loss function
    def training_step(self, batch):
        inputs, labels = batch
        outputs = self.module(inputs)
        loss = F.cross_entropy(outputs, labels)

        return loss

# set up problem configuration
cls_config = Config(type='darts', unroll_steps=1, log_step=100)

# Classifier problem class instantiation
cls_prob = Classifier(name='classifier',
                      module=cls_module,
                      optimizer=cls_optimizer,
                      train_data_loader=cls_data_loader,
                      config=cls_config)

Interactions between problems

In GML/MLO, each problem will often need to access modules from other problems to define its loss function. This can be achieved by using the name attribute as follows:

class HPO(ImplicitProblem):
    def training_step(self, batch):
        # set up hyperparameter optimization loss
        ...

# HPO problem class instantiation
hpo_prob = HPO(name='hpo', module=...)

class Classifier(ImplicitProblem):
    def training_step(self, batch):
        inputs, labels = batch
        outputs = self.module(inputs)
        loss = F.cross_entropy(outputs, labels)
        
        # the weight decay hyperparameter from another problem (HPO)
        # can be accessed through that problem's name, 'hpo'
        weight_decay = self.hpo()
        reg_loss = weight_decay * sum(
            [p.norm().pow(2) for p in self.module.parameters()]
        )
        
        return loss + reg_loss

cls_prob = Classifier(name='classifier', module=...)
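
The name-based access pattern above can be illustrated without Betty or PyTorch. The following is a minimal sketch, not Betty's actual implementation: `ToyProblem` and its `register` method are hypothetical stand-ins, and the modules are plain Python callables. It only mirrors the two behaviours visible in the snippets above, namely that a problem is reachable as an attribute named after it and that calling a problem forwards to its module.

```python
class ToyProblem:
    """Toy stand-in for a problem; NOT Betty's Problem class."""

    def __init__(self, name, module):
        self.name = name
        self.module = module  # any callable

    def __call__(self, *args, **kwargs):
        # Calling a problem forwards the call to its module.
        return self.module(*args, **kwargs)

    def register(self, other):
        # Hypothetical helper: expose another problem as an
        # attribute named after that problem.
        setattr(self, other.name, other)


# Hypothetical 'hpo' problem whose module returns a weight decay value.
hpo = ToyProblem(name="hpo", module=lambda: 0.01)
# Hypothetical 'classifier' problem whose module doubles its input.
classifier = ToyProblem(name="classifier", module=lambda x: 2 * x)

# Make each problem visible to the other under its name.
classifier.register(hpo)
hpo.register(classifier)

# Same access pattern as `self.hpo()` inside training_step.
weight_decay = classifier.hpo()
params = [1.0, -2.0, 3.0]
reg_loss = weight_decay * sum(p ** 2 for p in params)
```

In this sketch the registration is done by hand; in the README's examples no such call appears, so Betty presumably performs the equivalent wiring itself.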

Engine

Basics

The Engine class handles the hierarchical dependencies between problems. In GML/MLO, there are two types of dependencies: upper-to-lower (u2l) and lower-to-upper (l2u). Both types of dependencies can be defined with a Python dictionary, where the key is the starting node and the value is the list of destination nodes.

from betty import Engine
from betty.configs import EngineConfig

# set up all involved problems
problems = [cls_prob, hpo_prob]

# set up upper-to-lower and lower-to-upper dependencies
u2l = {hpo_prob: [cls_prob]}
l2u = {cls_prob: [hpo_prob]}
dependencies = {'u2l': u2l, 'l2u': l2u}

# set up Engine configuration
engine_config = EngineConfig(train_iters=10000, valid_step=100)

# instantiate Engine class
engine = Engine(problems=problems,
                dependencies=dependencies,
                config=engine_config)

# execute multilevel optimization
engine.run()
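
In the two-level example above, the `l2u` dictionary is simply the `u2l` dictionary with its edges reversed. The sketch below shows that relationship with a hypothetical helper, `invert_dependencies`, which is not part of Betty; strings stand in for the problem objects. It assumes every upper-to-lower edge has a matching lower-to-upper edge, which holds for this example but may not hold for every GML/MLO program.

```python
def invert_dependencies(u2l):
    """Reverse every edge of an upper-to-lower mapping.

    Hypothetical helper for illustration only; not part of Betty.
    """
    l2u = {}
    for upper, lowers in u2l.items():
        for lower in lowers:
            l2u.setdefault(lower, []).append(upper)
    return l2u


# Strings stand in for the hpo_prob and cls_prob objects.
u2l = {"hpo": ["classifier"]}
l2u = invert_dependencies(u2l)
dependencies = {"u2l": u2l, "l2u": l2u}
```

Writing both dictionaries out explicitly, as the README does, keeps the two dependency types independent when they do not mirror each other.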

Since Engine manages the whole GML/MLO program, you can also perform a global validation stage within it. All problems that comprise the GML/MLO program can again be accessed with their names.

import torch
import torch.nn.functional as F

class HPOEngine(Engine):
    # set up global validation
    @torch.no_grad()
    def validation(self):
        loss = 0
        # test_loader is assumed to be a user-defined data loader
        for inputs, labels in test_loader:
            outputs = self.classifier(inputs)
            loss += F.cross_entropy(outputs, labels)

        # Returned dict will be automatically logged after each validation
        return {'loss': loss}
...
engine = HPOEngine(problems=problems,
                   dependencies=dependencies,
                   config=engine_config)
engine.run()

Once all optimization problems are defined with the Problem class and the hierarchical dependencies between them with the Engine class, Betty handles the complicated internal mechanisms of GML/MLO, such as gradient calculation and optimization execution order. For more details and advanced features, see our Documentation and Tutorials.

Happy multilevel optimization programming!

Applications

We provide reference implementations of several GML/MLO applications (see Examples).

While each of these applications traditionally has a distinct implementation style, our implementations share the same code structure thanks to Betty. More examples are on the way!

Features

Gradient Approximation Methods

Training

  • Gradient accumulation
  • FP16 training
  • Distributed data-parallel training
  • Gradient clipping

Logging

Contributing

We welcome contributions from the community! Please see our contributing guidelines for details on how to contribute to Betty.

Citation

If you use Betty in your research, please cite our paper with the following BibTeX entry.

@article{choe2022betty,
  title={Betty: An Automatic Differentiation Library for Multilevel Optimization},
  author={Choe, Sang Keun and Neiswanger, Willie and Xie, Pengtao and Xing, Eric},
  journal={arXiv preprint arXiv:2207.02849},
  year={2022}
}

License

Betty is licensed under the Apache 2.0 License.