Improve FSDP support for low-bit optimizers #538

Merged: 7 commits, Jul 26, 2024
Changes from all commits
9 changes: 9 additions & 0 deletions test/prototype/test_low_bit_optim.py
@@ -226,6 +226,15 @@ def _test_fsdp2(self, optim_cls):
         base_optim.step()
         self.assertEqual(fsdp_loss, base_loss)
 
+        base_param = base_optim.param_groups[0]["params"][0]
+        base_exp_avg = base_optim.state[base_param]["exp_avg"]
+
+        fsdp_param = fsdp_optim.param_groups[0]["params"][0]
+        fsdp_exp_avg = fsdp_optim.state[fsdp_param]["exp_avg"]
+        full_fsdp_exp_avg = fsdp_exp_avg.full_tensor()
+
+        self.assertEqual(base_exp_avg.dequantize(), full_fsdp_exp_avg.dequantize())
+
 
 instantiate_parametrized_tests(TestQuantize)
 instantiate_parametrized_tests(TestOptim)
10 changes: 5 additions & 5 deletions torchao/prototype/low_bit_optim/adam.py
@@ -39,11 +39,11 @@ def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
     def _new_buffer(self, p: Tensor, signed: bool):
         if p.numel() >= 4096 and p.numel() % self.block_size == 0:
             if isinstance(p, DTensor):
-                out = torch.empty_like(p)
-                out._local_tensor = self._subclass_zeros(
-                    out._local_tensor,
-                    signed,
-                    self.block_size,
+                out = DTensor.from_local(
+                    local_tensor=self._subclass_zeros(p.to_local(), signed, self.block_size),
+                    device_mesh=p.device_mesh,
+                    placements=p.placements,
+                    run_check=False,
                 )
             else:
                 out = self._subclass_zeros(p, signed, self.block_size)
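Why the change works: the old code created a DTensor with torch.empty_like(p) and then overwrote its private _local_tensor field, so the quantized subclass buffer was never installed through DTensor's own constructor. The new code quantizes the local shard first (p.to_local()) and wraps it with the public DTensor.from_local() API. A minimal single-rank sketch of that pattern (illustrative only, not code from this PR; it assumes a CPU gloo process group so it runs without torchrun):

import os
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Shard
from torch.distributed.device_mesh import init_device_mesh

# single-process "distributed" setup so the sketch runs as-is
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

local = torch.zeros(4096)  # stands in for the quantized buffer from _subclass_zeros
out = DTensor.from_local(local, device_mesh=mesh, placements=[Shard(0)], run_check=False)
print(type(out), out.shape)  # global logical shape, assembled from the local shards

dist.destroy_process_group()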
10 changes: 5 additions & 5 deletions torchao/prototype/low_bit_optim/adamw.py
@@ -39,11 +39,11 @@ def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
     def _new_buffer(self, p: Tensor, signed: bool):
         if p.numel() >= 4096 and p.numel() % self.block_size == 0:
             if isinstance(p, DTensor):
-                out = torch.empty_like(p)
-                out._local_tensor = self._subclass_zeros(
-                    out._local_tensor,
-                    signed,
-                    self.block_size,
+                out = DTensor.from_local(
+                    local_tensor=self._subclass_zeros(p.to_local(), signed, self.block_size),
+                    device_mesh=p.device_mesh,
+                    placements=p.placements,
+                    run_check=False,
                 )
             else:
                 out = self._subclass_zeros(p, signed, self.block_size)
57 changes: 49 additions & 8 deletions torchao/prototype/low_bit_optim/subclass_4bit.py
@@ -8,7 +8,8 @@


 aten = torch.ops.aten
-
+c10d_functional = torch.ops.c10d_functional
+_c10d_functional = torch.ops._c10d_functional
 
 # https://github.com/thu-ml/low-bit-optimizers/blob/e3e2854728e498c2a606e3fdb88daa27ae94f9a6/lpmm/configs/2nd_moment_group_128.yml
 # NOTE: power-1 is linear
@@ -31,17 +32,29 @@ def __new__(cls, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape):
         )
 
     def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape):
+        """Create quantized 4-bit optimizer state as proposed in https://arxiv.org/abs/2309.01507
+
+        Args
+            codes: quantized and packed 4-bit data stored as uint8.
+            scale: scale data for block-wise quantization.
+            qmap: lookup table that maps between quantized value (code) and float value.
+            signed: whether the tensor is signed or unsigned.
+            shape: shape of original float tensor.
+
+        NOTE: To get block-wise scale, the original float tensor is first reshaped to (-1, block_size).
+        Thus, the last dimension of the original float tensor is not necessarily divisible by block size.
+        Given `codes` and `scale`, `block_size` is calculated as `codes.numel() * 2 // scale.numel()`.
+        The extra `* 2` is because `codes` is 4-bit data packed in 8-bit storage.
+        """
         assert codes.dtype is torch.uint8
         assert codes.ndim == 1  # flattened buffer
         assert scale.ndim == 1
         self.codes = codes
         self.scale = scale
         self.qmap = qmap
         self.signed = signed
         self._shape = shape
-
-    @property
-    def block_size(self):
-        return self.codes.numel() * 2 // self.scale.numel()
+        self.block_size = codes.numel() * 2 // scale.numel()
Review comment from drisspg (Contributor):

curious q: Is there some description of the codes/scale tensors and their relation to each other?

I can see the pattern that codes has 0.5x (4-bit) and 1x (8-bit) of block_size * scale.numel(). But does this assert square blocks? I think some description here would be helpful.

Reply from the author (Collaborator):

I will add some description. Basically, for 8-bit and FP8, codes has the same shape as the "outer shape", while for 4-bit, since there is bit-packing, I find it easier to let codes be a flattened 1D buffer and keep track of the shape manually.

To get the scale, the float tensor is actually flattened first and reshaped to (-1, block_size). This relaxes the requirement that the last dimension must be divisible by block_size -> now we only need numel (total size) to be divisible by block_size. This matters especially when the block size is large (8-bit optim uses block_size=2048, as done in bnb). Since the optim update is element-wise, we don't really need to care whether the original tensor is 1D, 2D, or n-D (though an n-D tensor may have structure that makes flattening unwise). I believe the original implementation in bnb does this as well.

-> scale is always a 1D tensor, with size = original_tensor.numel() // block_size. (A sketch of this flatten-and-reshape scheme follows this thread.)

Reply from the author (Collaborator):

@drisspg Added some docs. Lmk if it is still unclear.
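
A short sketch of the flatten-and-reshape scheme described above (illustrative only: the function name and the absmax choice of scale are assumptions, not code from this PR):

import torch

def blockwise_scale(x: torch.Tensor, block_size: int) -> torch.Tensor:
    # flatten first: only x.numel(), not the last dim, must be divisible by block_size
    assert x.numel() % block_size == 0
    blocks = x.flatten().view(-1, block_size)  # (num_blocks, block_size)
    return blocks.abs().amax(dim=1)            # 1D scale, numel == x.numel() // block_size

x = torch.randn(64, 128)                       # numel = 8192; last dim 128 < 2048 is fine
scale = blockwise_scale(x, block_size=2048)    # scale.shape == (4,)

# 4-bit packing stores two codes per uint8, so codes.numel() == x.numel() // 2,
# and block_size is recovered as codes.numel() * 2 // scale.numel().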


     def __tensor_flatten__(self):
         return self.tensor_attrs, [self.signed, self._shape]
@@ -113,9 +126,37 @@ def _(func, *args, **kwargs):
     return func(*args, **kwargs)
 
 
+# this is needed for DTensor.from_local() and for flattening tensor
 @OptimState4bit.implements(aten.view.default)
 def _(func, *args, **kwargs):
     x, shape = args
-    if len(shape) > 1 or shape[0] != -1:
-        raise ValueError(f"{x.__class__.__name__} only supports .view() with shape=[-1]")
-    return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, (x.numel(),))
+
+    if tuple(x.shape) == tuple(shape):
+        return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, x._shape)
+
+    if len(shape) == 1 and shape[0] == -1:
+        return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, (x.numel(),))
+
+    raise ValueError(f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]")
+
+
+# this is needed for DTensor.full_tensor()
+@OptimState4bit.implements([
+    c10d_functional.all_gather_into_tensor.default,
+    _c10d_functional.all_gather_into_tensor.default,
+    c10d_functional.wait_tensor.default,
+    _c10d_functional.wait_tensor.default,
+])
+def _(func, *args, **kwargs):
+    x = args[0]
+    if not isinstance(x, OptimState4bit):
+        raise ValueError(f"expecting a OptimState4bit but found {type(x)}")
+
+    codes = func(x.codes, *args[1:], **kwargs)
+    scale = func(x.scale, *args[1:], **kwargs)
+
+    # adjust the first dim
+    shape = (x._shape[0] * codes.numel() // x.codes.numel(),) + x._shape[1:]
+
+    # assume tensors from all ranks have the same signedness
+    return OptimState4bit(codes, scale, x.qmap.clone(), x.signed, shape)
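
The first-dim adjustment above follows from all_gather_into_tensor concatenating shards along dim 0: codes.numel() grows by the world size while the trailing dims of the logical shape stay fixed. A hypothetical worked example of the arithmetic (numbers chosen for illustration):

# logical state of shape (1024, 64), sharded on dim 0 across 4 ranks
world_size = 4
local_shape = (256, 64)            # per-rank logical shape (x._shape)
local_codes_numel = 256 * 64 // 2  # two 4-bit codes packed per uint8 byte

gathered_codes_numel = local_codes_numel * world_size
new_dim0 = local_shape[0] * gathered_codes_numel // local_codes_numel
assert (new_dim0,) + local_shape[1:] == (1024, 64)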
51 changes: 44 additions & 7 deletions torchao/prototype/low_bit_optim/subclass_8bit.py
@@ -6,14 +6,13 @@


 aten = torch.ops.aten
+c10d_functional = torch.ops.c10d_functional
+_c10d_functional = torch.ops._c10d_functional
 
 QMAP_SIGNED = create_dynamic_map(signed=True)
 QMAP_UNSIGNED = create_dynamic_map(signed=False)
 
 
-# dynamic tree quantization
-# https://arxiv.org/pdf/1511.04561
+# https://arxiv.org/abs/2110.02861
 class OptimState8bit(Tensor):
     implements = classmethod(_implements)
     tensor_attrs = ["codes", "scale", "qmap"]
@@ -28,15 +27,25 @@ def __new__(cls, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool):
         )
 
     def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool):
+        """Create quantized 8-bit optimizer state as proposed in https://arxiv.org/abs/2110.02861
+
+        Args
+            codes: quantized 8-bit data stored as uint8. Has the same shape as the original float tensor.
+            scale: scale data for block-wise quantization.
+            qmap: lookup table that maps between quantized value (code) and float value.
+            signed: whether the tensor is signed or unsigned.
+
+        NOTE: To get block-wise scale, the original float tensor is first reshaped to (-1, block_size).
+        Thus, the last dimension of the original float tensor is not necessarily divisible by block size.
+        Given `codes` and `scale`, `block_size` is calculated as `codes.numel() // scale.numel()`.
+        """
         assert codes.dtype is torch.uint8
         assert scale.ndim == 1
         self.codes = codes
         self.scale = scale
         self.qmap = qmap
         self.signed = signed
-
-    @property
-    def block_size(self):
-        return self.codes.numel() // self.scale.numel()
+        self.block_size = codes.numel() // scale.numel()
 
     def __tensor_flatten__(self):
         return self.tensor_attrs, [self.signed]
@@ -97,3 +106,31 @@ def _(func, *args, **kwargs):
 def _(func, *args, **kwargs):
     args = [x.dequantize() if isinstance(x, OptimState8bit) else x for x in args]
     return func(*args, **kwargs)
+
+
+# this is needed for DTensor.from_local()
+@OptimState8bit.implements(aten.view.default)
+def _(func, *args, **kwargs):
+    x, shape = args
+    return OptimState8bit(x.codes.view(shape), x.scale, x.qmap, x.signed)
+
+
+# this is needed for DTensor.full_tensor()
+@OptimState8bit.implements([
+    c10d_functional.all_gather_into_tensor.default,
+    _c10d_functional.all_gather_into_tensor.default,
+    c10d_functional.wait_tensor.default,
+    _c10d_functional.wait_tensor.default,
+])
+def _(func, *args, **kwargs):
+    x = args[0]
+    if not isinstance(x, OptimState8bit):
+        raise ValueError(f"expecting a OptimState8bit but found {type(x)}")
+
+    # assume tensors from all ranks have the same signedness
+    return OptimState8bit(
+        func(x.codes, *args[1:], **kwargs),
+        func(x.scale, *args[1:], **kwargs),
+        x.qmap.clone(),
+        x.signed,
+    )
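
For reference, a hedged sketch of what block-wise dequantization of this state looks like, pieced together from the docstring above (a stand-in for OptimState8bit.dequantize(), not its actual implementation):

import torch

def dequantize_8bit(codes: torch.Tensor, scale: torch.Tensor, qmap: torch.Tensor) -> torch.Tensor:
    block_size = codes.numel() // scale.numel()
    floats = qmap[codes.flatten().long()]                     # code -> float via the lookup table
    floats = floats.view(-1, block_size) * scale.view(-1, 1)  # undo the per-block scale
    return floats.view(codes.shape)                           # codes keeps the original shape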
45 changes: 41 additions & 4 deletions torchao/prototype/low_bit_optim/subclass_fp8.py
@@ -4,6 +4,9 @@


 aten = torch.ops.aten
+c10d_functional = torch.ops.c10d_functional
+_c10d_functional = torch.ops._c10d_functional
+
 DTYPE = torch.float8_e4m3fn
 
 
@@ -32,13 +35,21 @@ def __new__(cls, codes: Tensor, scale: Tensor):
         )
 
     def __init__(self, codes: Tensor, scale: Tensor):
+        """Create quantized FP8 optimizer state.
+
+        Args
+            codes: quantized FP8 E4M3FN data. Has the same shape as the original float tensor.
+            scale: scale data for block-wise quantization.
+
+        NOTE: To get block-wise scale, the original float tensor is first reshaped to (-1, block_size).
+        Thus, the last dimension of the original float tensor is not necessarily divisible by block size.
+        Given `codes` and `scale`, `block_size` is calculated as `codes.numel() // scale.numel()`.
+        """
         assert codes.dtype is DTYPE
         assert scale.ndim == 1
         self.codes = codes
         self.scale = scale
-
-    @property
-    def block_size(self):
-        return self.codes.numel() // self.scale.numel()
+        self.block_size = codes.numel() // scale.numel()
 
     def __tensor_flatten__(self):
         return self.tensor_attrs, []
@@ -99,3 +110,29 @@ def _(func, *args, **kwargs):
 def _(func, *args, **kwargs):
     args = [x.dequantize() if isinstance(x, OptimStateFp8) else x for x in args]
     return func(*args, **kwargs)
+
+
+# this is needed for DTensor.from_local()
+@OptimStateFp8.implements(aten.view.default)
+def _(func, *args, **kwargs):
+    x, shape = args
+    return OptimStateFp8(x.codes.view(shape), x.scale)
+
+
+# this is needed for DTensor.full_tensor()
+@OptimStateFp8.implements([
+    c10d_functional.all_gather_into_tensor.default,
+    _c10d_functional.all_gather_into_tensor.default,
+    c10d_functional.wait_tensor.default,
+    _c10d_functional.wait_tensor.default,
+])
+def _(func, *args, **kwargs):
+    x = args[0]
+    if not isinstance(x, OptimStateFp8):
+        raise ValueError(f"expecting a OptimStateFp8 but found {type(x)}")
+
+    return OptimStateFp8(
+        func(x.codes, *args[1:], **kwargs),
+        func(x.scale, *args[1:], **kwargs),
+    )