Skip to content

Commit

Permalink
add FSDP QLoRA test and revert failing PR (#403)
Browse files Browse the repository at this point in the history
* add FSDP QLoRA test and revert failing PR

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* check pytorch version and cuda for ci

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

* revert linter

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
weifengpy authored Jun 21, 2024
1 parent dd35079 commit 2eb08be
Show file tree
Hide file tree
Showing 2 changed files with 446 additions and 250 deletions.
174 changes: 173 additions & 1 deletion test/dtypes/test_nf4.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
import copy
import logging
import unittest
from packaging import version
import math

import pytest
import torch
from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
CheckpointWrapper,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
Expand Down Expand Up @@ -431,6 +439,170 @@ def test_to_cpu(self):
inner_tensor = getattr(nf4_tensor, attr)
self.assertEqual(inner_tensor.device.type, "cpu")

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@parametrize("input_size", [512 * 512, (512 * 512,), (512, 512)])
def test_tensor_deepcopy(self, input_size: Union[Tuple[int, ...], int]):
    """Check that ``copy.deepcopy`` of an NF4 tensor preserves its contents.

    Args:
        input_size: Size passed to ``torch.randn``; either a bare int or a
            tuple of ints of any length (both 1-D and 2-D cases are
            parametrized above).
    """
    nf4_orig = to_nf4(torch.randn(input_size, device="cuda"))
    nf4_clone = copy.deepcopy(nf4_orig)
    # Compare through get_original_weight() on both sides so the check is on
    # the values each tensor reports, not on object identity.
    self.assertEqual(
        nf4_clone.get_original_weight(), nf4_orig.get_original_weight()
    )


class LoRALinear(nn.Module):
    """Linear layer with an NF4-converted base weight and a LoRA adapter.

    The base projection applies ``linear_nf4`` to a ``weight`` parameter that
    was converted with ``to_nf4``. The adapter is the bias-free pair
    ``lora_a`` (``in_dim -> rank``) and ``lora_b`` (``rank -> out_dim``),
    whose output is scaled by ``alpha / rank`` and added to the base output.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        weight: torch.Tensor,
        rank: int,
        alpha: float,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        self.rank, self.alpha = rank, alpha
        # Assigning an nn.Parameter as an attribute registers it under the
        # name "weight", so it appears first in named_parameters().
        self.weight = nn.Parameter(to_nf4(weight))
        self.dropout = nn.Dropout(p=dropout)
        # Adapter layers are created in a/b order; construction and the
        # re-initialisation below both draw from the global RNG.
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)
        self.initialize_parameters()

    def initialize_parameters(self):
        """Re-initialise both adapter weights with Kaiming-uniform values."""
        for adapter in (self.lora_a, self.lora_b):
            nn.init.kaiming_uniform_(adapter.weight, a=math.sqrt(5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the base NF4 projection of ``x`` plus the scaled LoRA term."""
        base_out = linear_nf4(input=x, weight=self.weight)
        scaling = self.alpha / self.rank
        adapter_out = scaling * self.lora_b(self.lora_a(self.dropout(x)))
        return base_out + adapter_out


class TestQLoRA(FSDPTest):
    """Compare FSDP2 training of a LoRA/NF4 transformer with an unsharded copy."""

    @property
    def world_size(self) -> int:
        """Number of ranks the distributed test harness is asked to use."""
        return 2

    @pytest.mark.skipif(
        # Compare parsed versions, not strings: lexicographically
        # "2.10.0" < "2.4.0", which would wrongly skip newer releases.
        # base_version is still used so dev/local suffixes are ignored.
        version.parse(version.parse(torch.__version__).base_version)
        < version.parse("2.4.0"),
        reason="torch >= 2.4 required",
    )
    @skip_if_lt_x_gpu(2)
    def test_qlora_fsdp2(self):
        """Run ``_test_qlora_fsdp2`` over every checkpointing/offload combination."""
        # Imported lazily so that the module still imports on torch builds
        # where the skip condition above applies.
        from torch.distributed._composable.fsdp import CPUOffloadPolicy, OffloadPolicy

        self.run_subtests(
            {
                "enable_activation_checkpointing": [False, True],
                "offload_policy": [
                    OffloadPolicy(),
                    CPUOffloadPolicy(pin_memory=True),
                    CPUOffloadPolicy(pin_memory=False),
                ],
            },
            self._test_qlora_fsdp2,
        )

    def _test_qlora_fsdp2(
        self,
        enable_activation_checkpointing: bool,
        offload_policy: "OffloadPolicy",
    ):
        """Train a sharded and an unsharded model side by side and compare losses.

        Args:
            enable_activation_checkpointing: When true, every
                ``TransformerBlock`` is wrapped for activation checkpointing
                and the resulting ``CheckpointWrapper`` modules are sharded
                instead of the blocks themselves.
            offload_policy: Passed to every ``fully_shard`` call as
                ``offload_policy``.
        """
        from torch.distributed._composable.fsdp import fully_shard
        from torch.testing._internal.distributed._tensor.common_dtensor import (
            ModelArgs,
            Transformer,
            TransformerBlock,
        )

        batch_size = 3
        lora_r = 8
        lora_alpha = 16
        vocab_size = 1024
        seq_len = 64
        model_args = ModelArgs(
            n_layers=3,
            n_heads=4,
            dim=1024,
            vocab_size=vocab_size,
            max_seq_len=seq_len,
            dropout_p=0,
        )
        # Same seed on every rank so all ranks build identical models.
        torch.manual_seed(42)
        with torch.device("cuda"):
            base_model = Transformer(model_args)
            for layer in base_model.layers:
                # Replace the attention projections first, then the
                # feed-forward ones. The order is kept fixed because each
                # LoRALinear draws its adapter weights from the seeded RNG.
                for parent, attrs in (
                    (layer.attention, ("wq", "wk", "wv", "wo")),
                    (layer.feed_forward, ("w1", "w2")),
                ):
                    for attr in attrs:
                        orig_linear = getattr(parent, attr)
                        # nn.Linear-style weights are (out_features, in_features).
                        out_dim, in_dim = orig_linear.weight.shape
                        setattr(
                            parent,
                            attr,
                            LoRALinear(
                                in_dim,
                                out_dim,
                                orig_linear.weight,
                                lora_r,
                                lora_alpha,
                            ),
                        )
        # Only the LoRA adapter weights are trainable; everything else is frozen.
        for name, param in base_model.named_parameters():
            param.requires_grad_(
                name.endswith(("lora_a.weight", "lora_b.weight"))
            )
        if enable_activation_checkpointing:
            apply_activation_checkpointing(
                base_model, auto_wrap_policy=ModuleWrapPolicy({TransformerBlock})
            )
        base_optim = torch.optim.AdamW(base_model.parameters(), lr=1e-2)

        fsdp_kwargs = {"offload_policy": offload_policy}
        # The sharded model starts as an exact copy of the reference model.
        fsdp_model = copy.deepcopy(base_model)
        # With checkpointing on, the wrapper is the unit to shard; otherwise
        # the transformer block itself is.
        shard_cls = (
            CheckpointWrapper if enable_activation_checkpointing else TransformerBlock
        )
        for m in fsdp_model.modules():
            if isinstance(m, shard_cls):
                fully_shard(m, **fsdp_kwargs)
        fully_shard(fsdp_model, **fsdp_kwargs)
        fsdp_optim = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-2)

        # Per-rank seed so each rank draws a different input batch.
        torch.manual_seed(42 + self.rank + 1)
        for iter_idx in range(5):
            inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
            # Alternate between set_to_none=True and False across iterations.
            set_to_none = iter_idx % 2 == 0

            fsdp_optim.zero_grad(set_to_none=set_to_none)
            fsdp_loss = fsdp_model(inp).sum()
            fsdp_loss.backward()
            fsdp_optim.step()

            base_optim.zero_grad(set_to_none=set_to_none)
            base_loss = base_model(inp).sum()
            base_loss.backward()
            # Average the reference model's gradients across ranks by hand
            # before stepping its optimizer.
            for param in base_model.parameters():
                if param.grad is not None:
                    torch.distributed.all_reduce(
                        param.grad, op=torch.distributed.ReduceOp.AVG
                    )
            base_optim.step()
            self.assertEqual(fsdp_loss, base_loss)


# Process the test classes that use @parametrize (presumably generating one
# concrete test method per parameter value — confirm against
# torch.testing._internal.common_utils.instantiate_parametrized_tests).
instantiate_parametrized_tests(TestNF4Linear)
instantiate_parametrized_tests(TestFSDPOps)
Expand Down
Loading

0 comments on commit 2eb08be

Please sign in to comment.