
Commit

[low-bit optim] Upcast everything to FP32 for internal calculations (#1068)

* fix dtype

* Update regression_test.yml

---------

Co-authored-by: Mark Saroufim <marksaroufim@gmail.com>
gau-nernst and msaroufim authored Oct 14, 2024
1 parent e7b33bc commit afc0a02
Showing 2 changed files with 16 additions and 15 deletions.
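
Background for the change, with a small illustrative sketch (not part of the diff; torch is the only dependency, and the shapes, dtypes, and iteration count are arbitrary): as the removed NOTE in adam.py explains, the dtype of the Adam math previously followed whatever the inputs happened to be. The quantized state subclasses dequantize to FP32, so the update ran in FP32 only incidentally, while a plain low-precision state or param could pull the moment updates into BF16. The sketch below compares an exponential moving average accumulated in BF16 against the same update done in FP32, the kind of rounding drift the explicit upcast avoids.

import torch

beta1 = 0.9
grad = torch.full((1000,), 1e-3)

exp_avg_bf16 = torch.zeros(1000, dtype=torch.bfloat16)
exp_avg_f32 = torch.zeros(1000, dtype=torch.float32)

for _ in range(200):
    # low-precision accumulation: lerp carried out entirely in BF16
    exp_avg_bf16 = exp_avg_bf16.lerp(grad.bfloat16(), 1 - beta1)
    # FP32 accumulation, matching what the patched optimizer now does internally
    exp_avg_f32 = exp_avg_f32.lerp(grad, 1 - beta1)

# largest deviation of the BF16 estimate from the FP32 reference
print((exp_avg_bf16.float() - exp_avg_f32).abs().max().item())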
2 changes: 1 addition & 1 deletion .github/workflows/regression_test.yml
@@ -67,7 +67,7 @@ jobs:

     uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
     with:
-      timeout: 60
+      timeout: 120
       runner: ${{ matrix.runs-on }}
       gpu-arch-type: ${{ matrix.gpu-arch-type }}
       gpu-arch-version: ${{ matrix.gpu-arch-version }}
29 changes: 15 additions & 14 deletions torchao/prototype/low_bit_optim/adam.py
@@ -109,9 +109,6 @@ def step(self, closure=None):

 # this will work with any optim state tensor subclass that implements aten.lerp.Scalar and aten.copy_.default
 # and param tensor subclass that implements aten.add_.Tensor, and aten.addcdiv_.default
-# NOTE: right now all of our optimizer state subclasses will dequant to FP32, thus adam computation
-# will be done in FP32 (not purposely). we should explicitly cast all inputs to FP32 to ensure FP32
-# computation. will need to benchmark to ensure no slowdown.
 def single_param_adam(
     p: Tensor,
     grad: Tensor,
@@ -126,32 +123,36 @@ def single_param_adam(
     eps: float,
     is_adamw: bool,
 ):
+    # compute in FP32 for accurate calculations
+    p_f32 = p.float()
+    grad_f32 = grad.float()
+
     if not is_adamw:
-        grad = grad.add(p, alpha=weight_decay)
+        grad_f32 = grad_f32.add(p_f32, alpha=weight_decay)

     bias_correction1 = 1 - beta1**step
     bias_correction2 = 1 - beta2**step

     # keep high precision copy for param update
-    new_exp_avg = exp_avg.lerp(grad, 1 - beta1)
-    new_exp_avg_sq = exp_avg_sq.lerp(grad.square(), 1 - beta2)
+    exp_avg_f32 = exp_avg.float().lerp(grad_f32, 1 - beta1)
+    exp_avg_sq_f32 = exp_avg_sq.float().lerp(grad_f32.square(), 1 - beta2)

-    exp_avg.copy_(new_exp_avg)
-    exp_avg_sq.copy_(new_exp_avg_sq)
+    exp_avg.copy_(exp_avg_f32)
+    exp_avg_sq.copy_(exp_avg_sq_f32)

     if max_exp_avg_sq is not None:
-        new_max_exp_avg_sq = torch.maximum(max_exp_avg_sq, new_exp_avg_sq)
-        max_exp_avg_sq.copy_(new_max_exp_avg_sq)
-        denom = (new_max_exp_avg_sq.sqrt() / bias_correction2.sqrt()).add_(eps)
+        max_exp_avg_sq_f32 = torch.maximum(max_exp_avg_sq.float(), exp_avg_sq_f32)
+        max_exp_avg_sq.copy_(max_exp_avg_sq_f32)
+        denom = (max_exp_avg_sq_f32.sqrt() / bias_correction2.sqrt()).add_(eps)
     else:
-        denom = (new_exp_avg_sq.sqrt() / bias_correction2.sqrt()).add_(eps)
+        denom = (exp_avg_sq_f32.sqrt() / bias_correction2.sqrt()).add_(eps)

     step_size = lr / bias_correction1
     if is_adamw:
         # merge weight decay and param update in a single .add_() to make this work with quantized param
-        p.add_(-lr * weight_decay * p - step_size * new_exp_avg / denom)
+        p.add_(-lr * weight_decay * p_f32 - step_size * exp_avg_f32 / denom)
     else:
-        p.addcdiv_(new_exp_avg, denom, value=-step_size)
+        p.addcdiv_(exp_avg_f32, denom, value=-step_size)


 class Adam8bit(_AdamBase):
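
For reference, a minimal usage sketch of the affected optimizer (not from the repository; it assumes a CUDA device and that Adam8bit is importable from torchao.prototype.low_bit_optim as of this commit; model size and hyperparameters are illustrative):

import torch
from torchao.prototype.low_bit_optim import Adam8bit

# BF16 model: the 8-bit optimizer state stays quantized, while single_param_adam
# now upcasts param, grad, and state to FP32 for the update itself
model = torch.nn.Linear(256, 256, device="cuda", dtype=torch.bfloat16)
optim = Adam8bit(model.parameters(), lr=1e-3)

x = torch.randn(32, 256, device="cuda", dtype=torch.bfloat16)
loss = model(x).sum()
loss.backward()
optim.step()
optim.zero_grad()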
