8-bit Adam (#463)
gau-nernst authored Jul 3, 2024
1 parent d1e15b4 commit 739952b
Showing 8 changed files with 806 additions and 0 deletions.
211 changes: 211 additions & 0 deletions benchmarks/benchmark_adam_8bit.py
@@ -0,0 +1,211 @@
# pip install timm wandb tqdm datasets
# To fine-tune a pre-trained ViT-Base on the resisc45 dataset with BF16 AMP, using the default Adam optimizer from PyTorch core:
#
# python benchmarks/benchmark_adam_8bit.py \
#   --model "timm/vit_base_patch16_224.augreg_in21k" \
#   --amp bf16 \
#   --optim Adam
#
# To use the bnb 8-bit optimizer, set --optim Adam8bitBnb. To use the 8-bit optimizer implemented in torchao, set --optim Adam8bitAo.
# To profile and export a Chrome trace, set --profile.
# To enable the cosine learning rate scheduler, set --cosine_lr_scheduler.

import argparse
import math
from contextlib import nullcontext
from pathlib import Path

import bitsandbytes as bnb
import datasets
import timm
import torch
import torch.nn.functional as F
from torch.profiler import ProfilerActivity, profile
from torch.utils.data import DataLoader
from torchvision.transforms import v2
from tqdm import tqdm

from torchao.prototype.optim_8bit import Adam8bit


# Cosine decay from lr down to 0, with linear warmup over the first `warmup` fraction of total steps.
class CosineSchedule:
def __init__(self, lr: float, total_steps: int, warmup: float = 0.05) -> None:
self.lr = lr
self.final_lr = 0
self.total_steps = total_steps
self.warmup_steps = round(total_steps * warmup)

def get_lr(self, step: int) -> float:
if step < self.warmup_steps:
return self.lr * step / self.warmup_steps
if step < self.total_steps:
progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
return self.final_lr + 0.5 * (self.lr - self.final_lr) * (1 + math.cos(progress * math.pi))
return self.final_lr


class WandbLogger:
def __init__(self, args):
if args.project is not None and not args.profile:
import wandb

Path("wandb_logs").mkdir(exist_ok=True)
self.run = wandb.init(project=args.project, name=args.run_name, config=args, dir="wandb_logs")

else:
self.run = None

def log(self, *args, **kwargs):
if self.run is not None:
self.run.log(*args, **kwargs)


def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True)

parser.add_argument("--amp", default="none")
parser.add_argument("--channels_last", action="store_true")
parser.add_argument("--compile", action="store_true")

parser.add_argument("--n_epochs", type=int, default=10)
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--n_workers", type=int, default=4)

parser.add_argument("--optim", default="Adam")
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--weight_decay", type=float, default=0)
parser.add_argument("--cosine_lr_scheduler", action="store_true")

parser.add_argument("--project")
parser.add_argument("--run_name", default="debug")
parser.add_argument("--profile", action="store_true")
return parser


def get_dloader(args, training: bool):
transforms = [v2.ToImage()]

if training:
transforms.extend([v2.RandomResizedCrop(224), v2.RandomHorizontalFlip()])
else:
transforms.extend([v2.Resize(256), v2.CenterCrop(224)])

transforms.append(v2.ToDtype(torch.float32, scale=True))
transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
transforms = v2.Compose(transforms)

# use dataset from HF so download is fast
ds = datasets.load_dataset("timm/resisc45", split="train" if training else "validation")
ds = ds.select_columns(["image", "label"])
ds.set_transform(lambda x: dict(image=transforms(x["image"]), label=x["label"]))

return DataLoader(
ds,
batch_size=args.batch_size,
shuffle=training,
num_workers=args.n_workers,
pin_memory=training,
drop_last=training,
)


def get_amp_ctx(amp):
dtype = dict(bf16=torch.bfloat16, fp16=torch.float16, none=None)[amp]
return torch.autocast("cuda", dtype=dtype, enabled=amp != "none")


@torch.no_grad()
def evaluate_model(model, args):
model.eval()
val_dloader = get_dloader(args, False)

all_labels = []
all_preds = []

for batch in tqdm(val_dloader, dynamic_ncols=True, desc=f"Evaluating"):
all_labels.append(batch["label"].clone())
if args.channels_last:
batch["image"] = batch["image"].to(memory_format=torch.channels_last)

with get_amp_ctx(args.amp):
all_preds.append(model(batch["image"].cuda()).argmax(1).cpu())

all_labels = torch.cat(all_labels, dim=0)
all_preds = torch.cat(all_preds, dim=0)

acc = (all_labels == all_preds).float().mean()
return acc


if __name__ == "__main__":
args = get_parser().parse_args()

if args.profile:
args.n_epochs = 1

for k, v in vars(args).items():
print(f"{k}: {v}")

# wandb is only enabled when args.project is set and args.profile is False
logger = WandbLogger(args)
dloader = get_dloader(args, True)
print(f"Train dataset: {len(dloader.dataset):,} images")

model = timm.create_model(args.model, pretrained=True, num_classes=45).cuda()
if args.channels_last:
model.to(memory_format=torch.channels_last)
if args.compile:
model.compile(fullgraph=True)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

OPTIM_MAP = dict(
Adam=torch.optim.Adam,
Adam8bitBnb=bnb.optim.Adam8bit,
Adam8bitAo=Adam8bit,
)
optim = OPTIM_MAP[args.optim](model.parameters(), args.lr, weight_decay=args.weight_decay)
lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs)

grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16")

step = 0
for epoch_idx in range(args.n_epochs):
model.train()
prof = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) if args.profile else nullcontext()

with prof:
for batch in tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}"):
if args.channels_last:
batch["image"] = batch["image"].to(memory_format=torch.channels_last)

with get_amp_ctx(args.amp):
loss = F.cross_entropy(model(batch["image"].cuda()), batch["label"].cuda())
grad_scaler.scale(loss).backward()

if args.cosine_lr_scheduler:
lr = lr_schedule.get_lr(step)
for param_group in optim.param_groups:
param_group["lr"] = lr

if step % 100 == 0:
logger.log(dict(loss=loss.item(), lr=optim.param_groups[0]["lr"]), step=step)

grad_scaler.step(optim)
grad_scaler.update()
optim.zero_grad()

step += 1

if args.profile and step == 20:
break

if args.profile:
prof.export_chrome_trace("trace.json")

else:
val_acc = evaluate_model(model, args)
print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}")
logger.log(dict(val_acc=val_acc), step=step)

print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / (1 << 30):.2f} GB")
84 changes: 84 additions & 0 deletions test/prototype/test_optim_8bit.py
@@ -0,0 +1,84 @@
import copy

import pytest
import torch
from torch import nn
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
parametrize,
run_tests,
)
from torchao.prototype import optim_8bit
from torchao.prototype.optim_8bit.subclass import quantize_8bit_with_qmap, QMAP_SIGNED
from torchao.utils import TORCH_VERSION_AFTER_2_3

try:
import bitsandbytes as bnb
except ImportError:
bnb = None


_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


class TestDTQ8bit(TestCase):
@parametrize("device", _DEVICES)
def test_quantize_8bit_with_qmap_correctness(self, device):
x = torch.randn(32, 1024, device=device)
qmap = torch.tensor(QMAP_SIGNED, device=device)

actual_codes, actual_scale = quantize_8bit_with_qmap(x, qmap, 256, implementation=1)
expected_codes, expected_scale = quantize_8bit_with_qmap(x, qmap, 256, implementation=0)

torch.testing.assert_close(actual_codes, expected_codes)
torch.testing.assert_close(actual_scale, expected_scale)

@parametrize("device", _DEVICES)
def test_quantize_8bit_with_qmap_compile(self, device):
x = torch.randn(32, 1024, device=device)
qmap = torch.tensor(QMAP_SIGNED, device=device)

actual_codes, actual_scale = torch.compile(quantize_8bit_with_qmap, fullgraph=True)(x, qmap, 256)
expected_codes, expected_scale = quantize_8bit_with_qmap(x, qmap, 256)

torch.testing.assert_close(actual_codes, expected_codes)
torch.testing.assert_close(actual_scale, expected_scale)


@pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA")
@pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3")
class TestOptim8bit(TestCase):
@parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
def test_adam_8bit_correctness(self, optim_name):
device = "cuda"
model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
model2 = copy.deepcopy(model1)

optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
optim2 = getattr(optim_8bit, optim_name)(model2.parameters())

for _ in range(2):
x = torch.randn(4, 32, device=device)

loss1 = model1(x).sum()
loss1.backward()
optim1.step()
optim1.zero_grad()

loss2 = model2(x).sum()
loss2.backward()
optim2.step()
optim2.zero_grad()

for p1, p2 in zip(model1.parameters(), model2.parameters()):
torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)


instantiate_parametrized_tests(TestDTQ8bit)
instantiate_parametrized_tests(TestOptim8bit)


if __name__ == "__main__":
run_tests()
1 change: 1 addition & 0 deletions torchao/prototype/README.md
@@ -10,6 +10,7 @@
- `galore/kernels` - `triton` kernels that fuse various steps of the `GaLore` algorithm
- `galore/docs` - implementation notes and discussion of issues faced in kernel design.
- [`quant_llm`](quant_llm) - FP16 x FPx mixed matmul kernel per [FP6-LLM](https://arxiv.org/abs/2401.14112)
- [`optim_8bit`](optim_8bit) - re-implementation of 8-bit optimizers from [bitsandbytes](https://github.com/TimDettmers/bitsandbytes).

#### Roadmap

38 changes: 38 additions & 0 deletions torchao/prototype/optim_8bit/README.md
@@ -0,0 +1,38 @@
# 8-bit optimizers

This folder implements 8-bit optimizers using dynamic tree quantization as outlined in https://arxiv.org/abs/2110.02861. The implementation is fully done in Python (with a tensor subclass) and relies on `torch.compile()` to generate efficient fused kernels.
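
For intuition, here is a minimal sketch of block-wise 8-bit quantization with a signed codebook (not the library's actual implementation; the function names below are made up for illustration). Each block is normalized by its absolute maximum, and every normalized value is mapped to the nearest entry of a sorted 256-value lookup table:

```python
import torch

def quantize_blockwise_8bit(x: torch.Tensor, qmap: torch.Tensor, block_size: int = 2048):
    # Sketch only. Assumes x.numel() is divisible by block_size and that qmap is
    # a sorted 1-D tensor of 256 code values spanning [-1, 1].
    blocks = x.float().reshape(-1, block_size)        # (n_blocks, block_size)
    scale = blocks.abs().amax(dim=1, keepdim=True)    # per-block absmax
    normalized = blocks / scale.clamp(min=1e-12)      # values now in [-1, 1]

    # Nearest-code lookup: bucketize against the sorted qmap, then pick the
    # closer of the two neighbouring codes.
    idx = torch.bucketize(normalized, qmap).clamp(1, 255)
    lower, upper = qmap[idx - 1], qmap[idx]
    codes = torch.where(normalized - lower < upper - normalized, idx - 1, idx)
    return codes.to(torch.uint8), scale

def dequantize_blockwise_8bit(codes, scale, qmap, shape):
    # Decode: look up each code's value and undo the per-block scaling.
    return (qmap[codes.long()] * scale).reshape(shape)
```

Roughly speaking, the optimizer state is kept in this quantized form inside a tensor subclass, and `torch.compile()` is relied on to fuse the dequantize → update → re-quantize steps of each optimizer step.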

## Usage

This is a drop-in replacement for `torch.optim.Adam`:

```python
from torchao.prototype.optim_8bit import Adam8bit

model = ...
optim = Adam8bit(model.parameters())
```

You can also change the quantization block size (default 2048) by passing `block_size=value` to the optimizer.
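For example (a hypothetical configuration; the values are chosen only for illustration):

```python
optim = Adam8bit(model.parameters(), lr=1e-4, block_size=256)
```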

**Other optimizers**: AdamW is also available as `AdamW8bit`.

NOTE: this requires PyTorch >= 2.3
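
If you want to fail fast on older versions, the version flag used by the tests can be checked at import time (a sketch, adjust to your setup):

```python
from torchao.utils import TORCH_VERSION_AFTER_2_3

assert TORCH_VERSION_AFTER_2_3, "torchao 8-bit optimizers require PyTorch >= 2.3"
```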

## Benchmarks

A benchmark script for fine-tuning a [timm](https://github.com/huggingface/pytorch-image-models) model on the [resisc45](https://huggingface.co/datasets/timm/resisc45) dataset is available at [benchmarks/benchmark_adam_8bit.py](../../../benchmarks/benchmark_adam_8bit.py).
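For example, the invocation documented at the top of that script, switched to the torchao optimizer, is `python benchmarks/benchmark_adam_8bit.py --model "timm/vit_base_patch16_224.augreg_in21k" --amp bf16 --optim Adam8bitAo`.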

Results for fine-tuning ViT-B with BF16 AMP on a 4070Ti SUPER:

Adam impl | max memory (GB) | training time | accuracy
----------|-----------------|---------------|----------
PyTorch | 5.26 | 9m 11s | 93.62%
bnb 8-bit | 4.78 | 9m 10s | 93.06%
ao 8-bit | 4.78 | 9m 15s | 94.14%

**Known issue**: When the learning rate is updated every step (e.g. with a cosine learning rate scheduler), training is slower. This is because we have to convert the learning rate to a CUDA tensor (which incurs an expensive memory transfer), since `torch.compile()` treats a Python float as a constant and triggers a recompile whenever the value changes.
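
As a minimal sketch of the underlying behaviour (not the library's code; `scaled_update` below is a made-up example): a compiled function specializes on a Python float argument and recompiles when its value changes, whereas a 0-dim CUDA tensor keeps the same graph at the cost of a host-to-device copy whenever the schedule writes a new value into it.

```python
import torch

@torch.compile
def scaled_update(param, grad, lr):
    # When lr is a Python float, its value is baked into the compiled graph;
    # when lr is a tensor, it becomes a regular graph input.
    return param - lr * grad

param = torch.randn(16, device="cuda")
grad = torch.randn(16, device="cuda")

# Python float: every new value is treated as a new constant -> recompilation each step.
for lr in (1e-4, 9e-5, 8e-5):
    scaled_update(param, grad, lr)

# 0-dim CUDA tensor: compiled once; updating it costs a CPU -> GPU copy instead.
lr_tensor = torch.tensor(1e-4, device="cuda")
for lr in (1e-4, 9e-5, 8e-5):
    lr_tensor.fill_(lr)
    scaled_update(param, grad, lr_tensor)
```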

## Credits

Credits to Tim Dettmers for creating the wonderful bitsandbytes library.
2 changes: 2 additions & 0 deletions torchao/prototype/optim_8bit/__init__.py
@@ -0,0 +1,2 @@
from .adam import Adam8bit
from .adamw import AdamW8bit