
Add FP8 Adam #482

Merged
merged 14 commits on Jul 7, 2024
7 changes: 4 additions & 3 deletions benchmarks/benchmark_low_bit_adam.py
@@ -28,15 +28,16 @@
from torchvision.transforms import v2
from tqdm import tqdm

from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit
from torchao.prototype import low_bit_optim

# lpmm doesn't have Adam, only AdamW
OPTIM_MAP = dict(
    Adam=torch.optim.Adam,
    Adam8bitBnb=bnb.optim.Adam8bit,
    Adam8bitAo=Adam8bit,
    Adam8bitAo=low_bit_optim.Adam8bit,
    AdamFp8Ao=low_bit_optim.AdamFp8,
    Adam4bitLpmm=partial(lpmm.optim.AdamW, weight_decay=0, fused=True),
    Adam4bitAo=Adam4bit,
    Adam4bitAo=low_bit_optim.Adam4bit,
    Adam4bitRank1Lpmm=partial(lpmm.optim.AdamW, weight_decay=0, qconfig=argparse.Namespace(scale_type="rank1")),
)

16 changes: 16 additions & 0 deletions test/prototype/test_low_bit_optim.py
@@ -139,6 +139,22 @@ def test_optim_4bit_correctness(self, optim_name):
        for p1, p2 in zip(model1.parameters(), model2.parameters()):
            torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)

    @pytest.mark.xfail(not TORCH_VERSION_AFTER_2_3, reason="torch.compile() fails for PyTorch < 2.3")
    @parametrize("optim_name", ["AdamFp8", "AdamWFp8"])
    @parametrize("device", _DEVICES)
    def test_optim_fp8_smoke(self, optim_name, device):
        if device == "cuda" and torch.cuda.get_device_capability() < (8, 9):
            pytest.skip("FP8 requires compute capability >= 8.9")

        model = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
        optim = getattr(low_bit_optim, optim_name)(model.parameters())

        x = torch.randn(4, 32, device=device)
        loss = model(x).sum()
        loss.backward()
        optim.step()
        optim.zero_grad()


instantiate_parametrized_tests(TestQuantize)
instantiate_parametrized_tests(TestOptim)
8 changes: 5 additions & 3 deletions torchao/prototype/low_bit_optim/README.md
@@ -4,6 +4,7 @@ This folder implements:

- 8-bit optimizers as outlined in https://arxiv.org/abs/2110.02861
- 4-bit optimizers as outlined in https://arxiv.org/abs/2309.01507
- FP8 optimizers using the native `torch.float8_e4m3fn` dtype (experimental)

The implementation is done entirely in Python (with tensor subclasses) and relies on `torch.compile()` to generate efficient fused kernels.
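
As a rough illustration of the tensor-subclass design, the sketch below peeks at the optimizer state after a single step. It assumes the moments live under the conventional `exp_avg` / `exp_avg_sq` state keys, which this diff does not show, so treat it as a sketch rather than a guaranteed API:

```python
import torch
from torch import nn
from torchao.prototype.low_bit_optim import Adam8bit

model = nn.Linear(4096, 4096)
optim = Adam8bit(model.parameters())

loss = model(torch.randn(8, 4096)).sum()
loss.backward()
optim.step()

p = next(model.parameters())
# Assuming the usual Adam state-key names, the moments should be held in a
# quantized tensor subclass rather than a plain fp32 tensor.
print(type(optim.state[p]["exp_avg"]))
```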

@@ -18,12 +19,12 @@ model = ...
optim = Adam8bit(model.parameters())
```

To use 4-bit Adam, replace the above with `Adam4bit`. You can also change quantization block size by passing `block_size=value` to the optimizer. By default, block size is 2048 for 8-bit optimizers, and 128 for 4-bit optimizers.
To use 4-bit Adam, replace the above with `Adam4bit`; likewise, use `AdamFp8` for the FP8 variant. You can also change the quantization block size by passing `block_size=value` to the optimizer. By default, the block size is 2048 for 8-bit and FP8 optimizers, and 128 for 4-bit optimizers.
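
For example, a minimal sketch of switching to the FP8 variant with an explicit block size (the model and values are illustrative; see the hardware note below):

```python
import torch
from torch import nn
from torchao.prototype.low_bit_optim import AdamFp8

model = nn.Linear(4096, 4096).cuda()
# block_size=2048 is also the default; it is passed explicitly here only to show the knob.
optim = AdamFp8(model.parameters(), lr=1e-3, block_size=2048)

loss = model(torch.randn(8, 4096, device="cuda")).sum()
loss.backward()
optim.step()
optim.zero_grad()
```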

**Other optimizers**: AdamW is also available as `AdamW8bit` and `AdamW4bit`. Other optimizers can be added based on demand.
**Other optimizers**: AdamW is also available as `AdamW8bit`, `AdamW4bit`, and `AdamWFp8`. Other optimizers can be added based on demand.

NOTE:
- The low-bit optimizers require PyTorch >= 2.3
- The low-bit optimizers require PyTorch >= 2.3. FP8 optimizers require CUDA compute capability >= 8.9.
- For 4-bit optimizers, we don't implement rank-1 normalization for quantizing 2nd moment as originally done in the paper.
- **Known issue**: When the learning rate is updated every step (e.g. with a cosine learning rate scheduler), training is slower. This is because we have to convert the learning rate to a CUDA tensor (which incurs an expensive memory-transfer cost), since torch.compile() treats a Python float as a constant and would trigger a recompile whenever the value changes. A sketch of this pattern is shown below.
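
A minimal, illustrative sketch of the pattern that triggers this issue (the scheduler choice and model size are arbitrary; the slowdown happens inside the optimizer, not the scheduler):

```python
import torch
from torch import nn
from torchao.prototype.low_bit_optim import AdamFp8

model = nn.Linear(4096, 4096).cuda()
optim = AdamFp8(model.parameters())
# The schedule changes the learning rate on every step, which is exactly the case described above.
sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=1000)

for _ in range(10):
    loss = model(torch.randn(8, 4096, device="cuda")).sum()
    loss.backward()
    optim.step()
    optim.zero_grad()
    sched.step()
```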

@@ -38,6 +39,7 @@ Adam impl | max memory (GB) | time taken for 2nd epoch | accuracy
PyTorch | 12.94 | 8m 18s | 91.14
bnb 8-bit | 8.31 | 6m 50s | 90.67
ao 8-bit | 8.32 | 9m 04s | 90.71
ao FP8 E4M3 | 8.32 | 6m 38s | 91.08
lpmm 4-bit | 7.72 | 5m 59s | 89.97
ao 4-bit | 7.72 | 7m 00s | 89.94
lpmm 4-bit (*) | 7.73 | 11m 10s | 89.71
4 changes: 2 additions & 2 deletions torchao/prototype/low_bit_optim/__init__.py
@@ -1,2 +1,2 @@
from .adam import Adam8bit, Adam4bit
from .adamw import AdamW8bit, AdamW4bit
from .adam import Adam8bit, Adam4bit, AdamFp8
from .adamw import AdamW8bit, AdamW4bit, AdamWFp8
20 changes: 20 additions & 0 deletions torchao/prototype/low_bit_optim/adam.py
@@ -6,6 +6,7 @@

from .subclass_8bit import maybe_new_8bit_zero_buffer
from .subclass_4bit import maybe_new_4bit_zero_buffer
from .subclass_fp8 import maybe_new_fp8_zero_buffer


class _Adam(Optimizer):
@@ -155,3 +156,22 @@ def __init__(
        super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size)

    _new_buffer = staticmethod(maybe_new_4bit_zero_buffer)


class AdamFp8(_Adam):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        amsgrad=False,
        *,
        block_size=2048
    ) -> None:
        super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size)

    @staticmethod
    def _new_buffer(p: Tensor, signed: bool, block_size: int):
        return maybe_new_fp8_zero_buffer(p, block_size)
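
As a quick, hedged check of what `_new_buffer` produces (relying only on `maybe_new_fp8_zero_buffer` from `subclass_fp8.py`, added further down in this PR): parameters that are large enough and divisible by the block size get an `OptimStateFp8` state buffer, while everything else falls back to a plain tensor.

```python
import torch
from torchao.prototype.low_bit_optim.subclass_fp8 import (
    OptimStateFp8,
    maybe_new_fp8_zero_buffer,
)

big = torch.zeros(4096)    # numel >= 4096 and divisible by block_size -> quantized state
small = torch.zeros(100)   # too small -> plain zeros_like buffer

print(type(maybe_new_fp8_zero_buffer(big, block_size=2048)))    # OptimStateFp8
print(type(maybe_new_fp8_zero_buffer(small, block_size=2048)))  # torch.Tensor
```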
20 changes: 20 additions & 0 deletions torchao/prototype/low_bit_optim/adamw.py
@@ -6,6 +6,7 @@

from .subclass_8bit import maybe_new_8bit_zero_buffer
from .subclass_4bit import maybe_new_4bit_zero_buffer
from .subclass_fp8 import maybe_new_fp8_zero_buffer


class _AdamW(Optimizer):
@@ -154,3 +155,22 @@ def __init__(
        super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size)

    _new_buffer = staticmethod(maybe_new_4bit_zero_buffer)


class AdamWFp8(_AdamW):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        amsgrad=False,
        *,
        block_size=2048
    ) -> None:
        super().__init__(params, lr, betas, eps, weight_decay, amsgrad, block_size=block_size)

    @staticmethod
    def _new_buffer(p: Tensor, signed: bool, block_size: int):
        return maybe_new_fp8_zero_buffer(p, block_size)
106 changes: 106 additions & 0 deletions torchao/prototype/low_bit_optim/subclass_fp8.py
@@ -0,0 +1,106 @@
import torch
from torch import Tensor
from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE


aten = torch.ops.aten
DTYPE = torch.float8_e4m3fn


def quantize_fp8(input: Tensor, block_size: int):
    input = input.view(-1, block_size)
    scale = input.abs().amax(-1).clip(1e-12) / torch.finfo(DTYPE).max
    input = input / scale.view(-1, 1)
    codes = input.to(DTYPE).view(-1)
    return codes, scale


class OptimStateFp8(Tensor):
    implements = classmethod(_implements)
    tensor_attrs = ["codes", "scale"]

    @staticmethod
    def __new__(cls, codes: Tensor, scale: Tensor):
        return Tensor._make_wrapper_subclass(
            cls,
            codes.shape,
            device=codes.device,
            requires_grad=False,
        )

    def __init__(self, codes: Tensor, scale: Tensor):
        assert codes.dtype is DTYPE
        self.codes = codes
        self.scale = scale

    @property
    def block_size(self):
        return self.codes.numel() // self.scale.numel()

    def __tensor_flatten__(self):
        return self.tensor_attrs, []

    @classmethod
    def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None):
        return cls(*[tensor_data_dict[name] for name in cls.tensor_attrs], *tensor_attributes)

    def dequantize(self, output_dtype=None):
        float_data = self.codes.float()
        float_data = float_data.view(-1, self.block_size) * self.scale.view(-1, 1)

        dtype = output_dtype or torch.get_default_dtype()
        return float_data.view(self.codes.shape).to(dtype)

    @classmethod
    def zeros(cls, shape, block_size: int = 2048, device=None):
        codes = torch.zeros(shape, dtype=DTYPE, device=device)
        scale = torch.zeros(codes.numel() // block_size, device=device)
        return cls(codes, scale)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(block_size={self.block_size}, "
            f"shape={tuple(self.shape)}, device={self.device}, requires_grad={self.requires_grad})"
        )

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]:
            return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs)

        raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run {func}, this is not supported")


@OptimStateFp8.implements(aten.copy_.default)
def _(func, *args, **kwargs):
    dst = args[0]
    src = args[1]

    if isinstance(dst, OptimStateFp8) and isinstance(src, OptimStateFp8):
        assert dst.block_size == src.block_size
        dst.codes.copy_(src.codes)
        dst.scale.copy_(src.scale)

    elif isinstance(dst, OptimStateFp8):
        codes, scale = quantize_fp8(src, dst.block_size)
        dst.codes.copy_(codes)
        dst.scale.copy_(scale)

    else:
        dst.copy_(src.dequantize())

    return dst


@OptimStateFp8.implements(aten.lerp.Scalar)
def _(func, *args, **kwargs):
    args = [x.dequantize() if isinstance(x, OptimStateFp8) else x for x in args]
    return func(*args, **kwargs)


def maybe_new_fp8_zero_buffer(p: Tensor, block_size: int = 2048):
    if p.numel() >= 4096 and p.numel() % block_size == 0:
        out = OptimStateFp8.zeros(p.shape, block_size, device=p.device)
    else:
        out = torch.zeros_like(p)
    return out
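
For reference, a minimal round trip through the block-wise FP8 quantization above (the values are illustrative; the tensor length just needs to be a multiple of the block size):

```python
import torch
from torchao.prototype.low_bit_optim.subclass_fp8 import OptimStateFp8, quantize_fp8

x = torch.randn(4096)
codes, scale = quantize_fp8(x, block_size=2048)  # one fp32 scale per 2048-element block
state = OptimStateFp8(codes, scale)

x_approx = state.dequantize(torch.float32)
print(state)                       # OptimStateFp8(block_size=2048, shape=(4096,), ...)
print((x - x_approx).abs().max())  # small block-wise quantization error
```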