Commit 739952b (parent d1e15b4), showing 8 changed files with 806 additions and 0 deletions.
@@ -0,0 +1,211 @@
# pip install timm wandb tqdm datasets
# To fine-tune a pre-trained ViT-Base on the resisc45 dataset with BF16 AMP, using the default Adam optimizer from PyTorch core:
#
# python benchmarks/benchmark_adam_8bit.py \
#     --model "timm/vit_base_patch16_224.augreg_in21k" \
#     --amp bf16 \
#     --optim Adam
#
# To use the bnb 8-bit optimizer, set --optim Adam8bitBnb. To use the 8-bit optimizer implemented in torchao, set --optim Adam8bitAo.
# To profile and export a Chrome trace, set --profile.
# To enable the cosine learning rate scheduler, set --cosine_lr_scheduler.
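# As an example, a command along these lines (flag names are taken from the argument parser below;
# the model name and values are illustrative) fine-tunes with the torchao 8-bit Adam and a cosine LR schedule:
#
# python benchmarks/benchmark_adam_8bit.py \
#     --model "timm/vit_base_patch16_224.augreg_in21k" \
#     --amp bf16 \
#     --optim Adam8bitAo \
#     --cosine_lr_scheduler
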
import argparse
import math
from contextlib import nullcontext
from pathlib import Path

import bitsandbytes as bnb
import datasets
import timm
import torch
import torch.nn.functional as F
from torch.profiler import ProfilerActivity, profile
from torch.utils.data import DataLoader
from torchvision.transforms import v2
from tqdm import tqdm

from torchao.prototype.optim_8bit import Adam8bit

class CosineSchedule:
    def __init__(self, lr: float, total_steps: int, warmup: float = 0.05) -> None:
        self.lr = lr
        self.final_lr = 0
        self.total_steps = total_steps
        self.warmup_steps = round(total_steps * warmup)

    def get_lr(self, step: int) -> float:
        if step < self.warmup_steps:
            return self.lr * step / self.warmup_steps
        if step < self.total_steps:
            progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
            return self.final_lr + 0.5 * (self.lr - self.final_lr) * (1 + math.cos(progress * math.pi))
        return self.final_lr

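# A quick worked example of the schedule above (values are illustrative):
#   sched = CosineSchedule(lr=1e-3, total_steps=1000)   # warmup_steps = round(1000 * 0.05) = 50
#   sched.get_lr(0)    -> 0.0     (start of linear warmup)
#   sched.get_lr(50)   -> 1e-3    (warmup finished, cosine decay begins)
#   sched.get_lr(525)  -> 5e-4    (halfway through the decay)
#   sched.get_lr(1000) -> 0       (final_lr)
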
class WandbLogger:
    def __init__(self, args):
        if args.project is not None and not args.profile:
            import wandb

            Path("wandb_logs").mkdir(exist_ok=True)
            self.run = wandb.init(project=args.project, name=args.run_name, config=args, dir="wandb_logs")

        else:
            self.run = None

    def log(self, *args, **kwargs):
        if self.run is not None:
            self.run.log(*args, **kwargs)

def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)

    parser.add_argument("--amp", default="none")
    parser.add_argument("--channels_last", action="store_true")
    parser.add_argument("--compile", action="store_true")

    parser.add_argument("--n_epochs", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--n_workers", type=int, default=4)

    parser.add_argument("--optim", default="Adam")
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--weight_decay", type=float, default=0)
    parser.add_argument("--cosine_lr_scheduler", action="store_true")

    parser.add_argument("--project")
    parser.add_argument("--run_name", default="debug")
    parser.add_argument("--profile", action="store_true")
    return parser

def get_dloader(args, training: bool):
    transforms = [v2.ToImage()]

    if training:
        transforms.extend([v2.RandomResizedCrop(224), v2.RandomHorizontalFlip()])
    else:
        transforms.extend([v2.Resize(256), v2.CenterCrop(224)])

    transforms.append(v2.ToDtype(torch.float32, scale=True))
    transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
    transforms = v2.Compose(transforms)

    # use dataset from HF so download is fast
    ds = datasets.load_dataset("timm/resisc45", split="train" if training else "validation")
    ds = ds.select_columns(["image", "label"])
    ds.set_transform(lambda x: dict(image=transforms(x["image"]), label=x["label"]))

    return DataLoader(
        ds,
        batch_size=args.batch_size,
        shuffle=training,
        num_workers=args.n_workers,
        pin_memory=training,
        drop_last=training,
    )

def get_amp_ctx(amp):
    dtype = dict(bf16=torch.bfloat16, fp16=torch.float16, none=None)[amp]
    return torch.autocast("cuda", dtype=dtype, enabled=amp != "none")

@torch.no_grad()
def evaluate_model(model, args):
    model.eval()
    val_dloader = get_dloader(args, False)

    all_labels = []
    all_preds = []

    for batch in tqdm(val_dloader, dynamic_ncols=True, desc="Evaluating"):
        all_labels.append(batch["label"].clone())
        if args.channels_last:
            batch["image"] = batch["image"].to(memory_format=torch.channels_last)

        with get_amp_ctx(args.amp):
            all_preds.append(model(batch["image"].cuda()).argmax(1).cpu())

    all_labels = torch.cat(all_labels, dim=0)
    all_preds = torch.cat(all_preds, dim=0)

    acc = (all_labels == all_preds).float().mean()
    return acc

if __name__ == "__main__":
    args = get_parser().parse_args()

    if args.profile:
        args.n_epochs = 1

    for k, v in vars(args).items():
        print(f"{k}: {v}")

    # wandb is only enabled when args.project is set and args.profile is False
    logger = WandbLogger(args)
    dloader = get_dloader(args, True)
    print(f"Train dataset: {len(dloader.dataset):,} images")

    model = timm.create_model(args.model, pretrained=True, num_classes=45).cuda()
    if args.channels_last:
        model.to(memory_format=torch.channels_last)
    if args.compile:
        model.compile(fullgraph=True)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    OPTIM_MAP = dict(
        Adam=torch.optim.Adam,
        Adam8bitBnb=bnb.optim.Adam8bit,
        Adam8bitAo=Adam8bit,
    )
    optim = OPTIM_MAP[args.optim](model.parameters(), args.lr, weight_decay=args.weight_decay)
    lr_schedule = CosineSchedule(args.lr, len(dloader) * args.n_epochs)

    # GradScaler is only enabled for FP16 AMP; otherwise its methods are no-ops
    grad_scaler = torch.amp.GradScaler("cuda", enabled=args.amp == "fp16")

    step = 0
    for epoch_idx in range(args.n_epochs):
        model.train()
        prof = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) if args.profile else nullcontext()

        with prof:
            for batch in tqdm(dloader, dynamic_ncols=True, desc=f"Epoch {epoch_idx + 1}/{args.n_epochs}"):
                if args.channels_last:
                    batch["image"] = batch["image"].to(memory_format=torch.channels_last)

                with get_amp_ctx(args.amp):
                    loss = F.cross_entropy(model(batch["image"].cuda()), batch["label"].cuda())
                grad_scaler.scale(loss).backward()

                if args.cosine_lr_scheduler:
                    lr = lr_schedule.get_lr(step)
                    for param_group in optim.param_groups:
                        param_group["lr"] = lr

                if step % 100 == 0:
                    logger.log(dict(loss=loss.item(), lr=optim.param_groups[0]["lr"]), step=step)

                grad_scaler.step(optim)
                grad_scaler.update()
                optim.zero_grad()

                step += 1

                # when profiling, stop after the first 20 steps
                if args.profile and step == 20:
                    break

        if args.profile:
            prof.export_chrome_trace("trace.json")

        else:
            val_acc = evaluate_model(model, args)
            print(f"Epoch {epoch_idx + 1}/{args.n_epochs}: val_acc={val_acc.item() * 100:.2f}")
            logger.log(dict(val_acc=val_acc), step=step)

    print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / (1 << 30):.2f} GB")
@@ -0,0 +1,84 @@
import copy

import pytest
import torch
from torch import nn
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torchao.prototype import optim_8bit
from torchao.prototype.optim_8bit.subclass import quantize_8bit_with_qmap, QMAP_SIGNED
from torchao.utils import TORCH_VERSION_AFTER_2_3

try:
    import bitsandbytes as bnb
except ImportError:
    bnb = None


_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])

class TestDTQ8bit(TestCase):
    @parametrize("device", _DEVICES)
    def test_quantize_8bit_with_qmap_correctness(self, device):
        x = torch.randn(32, 1024, device=device)
        qmap = torch.tensor(QMAP_SIGNED, device=device)

        actual_codes, actual_scale = quantize_8bit_with_qmap(x, qmap, 256, implementation=1)
        expected_codes, expected_scale = quantize_8bit_with_qmap(x, qmap, 256, implementation=0)

        torch.testing.assert_close(actual_codes, expected_codes)
        torch.testing.assert_close(actual_scale, expected_scale)

    @parametrize("device", _DEVICES)
    def test_quantize_8bit_with_qmap_compile(self, device):
        x = torch.randn(32, 1024, device=device)
        qmap = torch.tensor(QMAP_SIGNED, device=device)

        actual_codes, actual_scale = torch.compile(quantize_8bit_with_qmap, fullgraph=True)(x, qmap, 256)
        expected_codes, expected_scale = quantize_8bit_with_qmap(x, qmap, 256)

        torch.testing.assert_close(actual_codes, expected_codes)
        torch.testing.assert_close(actual_scale, expected_scale)

@pytest.mark.skipif(bnb is None, reason="bitsandbytes is not available")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works on CUDA")
@pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3")
class TestOptim8bit(TestCase):
    @parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
    def test_adam_8bit_correctness(self, optim_name):
        device = "cuda"
        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
        model2 = copy.deepcopy(model1)

        optim1 = getattr(bnb.optim, optim_name)(model1.parameters())
        optim2 = getattr(optim_8bit, optim_name)(model2.parameters())

        for _ in range(2):
            x = torch.randn(4, 32, device=device)

            loss1 = model1(x).sum()
            loss1.backward()
            optim1.step()
            optim1.zero_grad()

            loss2 = model2(x).sum()
            loss2.backward()
            optim2.step()
            optim2.zero_grad()

        for p1, p2 in zip(model1.parameters(), model2.parameters()):
            torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)


instantiate_parametrized_tests(TestDTQ8bit)
instantiate_parametrized_tests(TestOptim8bit)


if __name__ == "__main__":
    run_tests()
@@ -0,0 +1,38 @@
# 8-bit optimizers

This folder implements 8-bit optimizers using dynamic tree quantization as outlined in https://arxiv.org/abs/2110.02861. The implementation is done fully in Python (with a tensor subclass) and relies on `torch.compile()` to generate efficient fused kernels.

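As a rough illustration, the block-wise quantization primitive exercised by the tests in this commit can be called directly. The snippet below is a minimal sketch: the tensor shape is arbitrary, and the block size simply matches the default mentioned under Usage.

```python
import torch
from torchao.prototype.optim_8bit.subclass import quantize_8bit_with_qmap, QMAP_SIGNED

x = torch.randn(32, 1024)                               # stand-in for an optimizer state tensor
qmap = torch.tensor(QMAP_SIGNED)                        # signed dynamic tree quantization map
codes, scale = quantize_8bit_with_qmap(x, qmap, 2048)   # 8-bit codes plus block-wise scales
```
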
## Usage

This is a drop-in replacement for `torch.optim.Adam`:

```python
from torchao.prototype.optim_8bit import Adam8bit

model = ...
optim = Adam8bit(model.parameters())
```

You can also change the quantization block size (default 2048) by passing `block_size=value` to the optimizer.

**Other optimizers**: AdamW is also available as `AdamW8bit`.

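For example (a minimal sketch; the block size here is only an illustrative value):

```python
from torchao.prototype.optim_8bit import Adam8bit, AdamW8bit

optim = Adam8bit(model.parameters(), block_size=1024)  # smaller quantization blocks
optim = AdamW8bit(model.parameters())                  # AdamW variant with decoupled weight decay
```
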
NOTE: this requires PyTorch >= 2.3.

## Benchmarks

A benchmark script for fine-tuning a [timm](https://github.com/huggingface/pytorch-image-models) model on the [resisc45](https://huggingface.co/datasets/timm/resisc45) dataset is available at [benchmarks/benchmark_adam_8bit.py](../../../benchmarks/benchmark_adam_8bit.py).

Results for fine-tuning ViT-B with BF16 AMP on a 4070Ti SUPER:

Adam impl | max memory (GB) | training time | accuracy
----------|-----------------|---------------|---------
PyTorch   | 5.26            | 9m 11s        | 93.62%
bnb 8-bit | 4.78            | 9m 10s        | 93.06%
ao 8-bit  | 4.78            | 9m 15s        | 94.14%

**Known issue**: when the learning rate is updated every step (e.g. with a cosine learning rate scheduler), training is slower. This is because we have to convert the learning rate to a CUDA tensor (which incurs an expensive memory-transfer cost), since `torch.compile()` treats a Python float as a constant and triggers a recompile whenever its value changes.

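The sketch below illustrates the general `torch.compile()` behavior described above; it is not the optimizer's internal code, and the function and values are made up. It assumes a CUDA device is available.

```python
import torch

@torch.compile
def scale_by_lr(x, lr):
    return x * lr

x = torch.randn(16, device="cuda")

# Python float argument: the value is treated as a constant and baked into the
# compiled graph, so passing a new value each step can trigger a recompile.
for lr in (1e-4, 9e-5, 8e-5):
    scale_by_lr(x, lr)

# 0-dim CUDA tensor argument: the graph is compiled once, but writing the new
# value into the tensor each step costs a host-to-device copy.
lr_tensor = torch.tensor(0.0, device="cuda")
for lr in (1e-4, 9e-5, 8e-5):
    lr_tensor.copy_(torch.tensor(lr))
    scale_by_lr(x, lr_tensor)
```
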
## Credits

Credits to Tim Dettmers for creating the wonderful bitsandbytes library.
@@ -0,0 +1,2 @@
from .adam import Adam8bit
from .adamw import AdamW8bit