Skip to content

Commit

Permalink
[Bugfix] Add fully sharded layer for QKVParallelLinearWithLora (#5665)
Browse files Browse the repository at this point in the history
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
  • Loading branch information
jeejeelee and Yard1 authored Jun 21, 2024
1 parent c35e4a3 commit 67005a0
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 26 deletions.
14 changes: 9 additions & 5 deletions tests/lora/test_baichuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def test_baichuan_lora(baichuan_lora_files):


@pytest.mark.skip("Requires multiple GPUs")
def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_baichuan_tensor_parallel_equality(baichuan_lora_files, fully_sharded):
# Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < 4:
# pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
Expand All @@ -75,7 +76,8 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=1,
trust_remote_code=True)
trust_remote_code=True,
fully_sharded_loras=fully_sharded)
output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1)

del llm_tp1
Expand All @@ -87,7 +89,8 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=2,
trust_remote_code=True)
trust_remote_code=True,
fully_sharded_loras=fully_sharded)
output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2)

del llm_tp2
Expand All @@ -101,10 +104,11 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=4,
trust_remote_code=True)
trust_remote_code=True,
fully_sharded_loras=fully_sharded)
output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2)

del llm_tp4
cleanup()

assert output_tp1 == output_tp4
assert output_tp1 == output_tp4
7 changes: 5 additions & 2 deletions tests/lora/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from vllm.lora.fully_sharded_layers import (
ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora,
RowParallelLinearWithShardedLoRA)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
Expand Down Expand Up @@ -684,7 +685,9 @@ def create_column_parallel_packed_layer():
bias=False,
params_dtype=torch.float16)
linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = QKVParallelLinearWithLora(linear)
lora_linear = QKVParallelLinearWithLora(
linear
) if not fully_shard else QKVParallelLinearWithShardedLora(linear)

@dataclass
class FakeConfig:
Expand Down
58 changes: 55 additions & 3 deletions vllm/lora/fully_sharded_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora,
QKVParallelLinearWithLora,
RowParallelLinearWithLoRA)
from vllm.lora.punica import bgmv, dispatch_bgmv_low_level

Expand Down Expand Up @@ -90,11 +91,11 @@ def can_replace_layer(cls, source_layer: nn.Module,
def _mcp_apply(x, bias, layer):
"""
MergedColumnParallelLinearWithShardedLoRA and
QKVParallelLinearWithShardedLora share the same
MergedQKVParallelLinearWithShardedLora share the same
LoRa weight application method.
The main difference is the step by shard_size for lora_b which can
vary for QKVParallelLinearWithShardedLora but is constant for
vary for MergedQKVParallelLinearWithShardedLora but is constant for
MergedColumnParallelLinearWithShardedLoRA.
"""
# expecting 2 for column parallel and 3 for qkv
Expand Down Expand Up @@ -167,14 +168,65 @@ def can_replace_layer(cls, source_layer: nn.Module,
)


class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
"""
Differs from QKVParallelLinearWithLora by slicing the
LoRA A's also.
Based on S-LoRA, slicing happens along the rank dim.
"""

def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked.shape[2]
start_idx = tp_rank * shard_size
lora_a = lora_a[:, start_idx:start_idx + shard_size]
return lora_a

def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1,
output.shape[-1]), output.shape
buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
dtype=torch.float32,
device=x.device)

bgmv(buffer, x, self.lora_a_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
buffer = tensor_model_parallel_all_gather(buffer)
bgmv(output, buffer, self.lora_b_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
# now have column partitioned output

output = output.view(*out_orig_shape)
return output

@classmethod
@_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)


class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
"""
Differs from MergedQKVParallelLinearWithLora by slicing the
LoRA A's also.
Based on S-LoRA, slicing happens along the rank dim.
"""

def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
Expand Down
36 changes: 21 additions & 15 deletions vllm/lora/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,24 @@ def __init__(self, base_layer: QKVParallelLinear) -> None:
self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
self.base_layer.head_size)

def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank()
self.q_shard_id = tp_rank
self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
lora_b_q = lora_b[:, self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1)]
k_offset = self.q_proj_total_size
lora_b_k = lora_b[:, k_offset +
self.kv_proj_shard_size * self.kv_shard_id:k_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
v_offset = k_offset + self.kv_proj_total_size
lora_b_v = lora_b[:, v_offset +
self.kv_proj_shard_size * self.kv_shard_id:v_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
return lora_b

def set_lora(
self,
index: int,
Expand All @@ -650,21 +668,8 @@ def set_lora(
):
self.reset_lora(index)
if self.tp_size > 1:
tp_rank = get_tensor_model_parallel_rank()
self.q_shard_id = tp_rank
self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
lora_b_q = lora_b[:, self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1)]
k_offset = self.q_proj_total_size
lora_b_k = lora_b[:, k_offset + self.kv_proj_shard_size *
self.kv_shard_id:k_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
v_offset = k_offset + self.kv_proj_total_size
lora_b_v = lora_b[:, v_offset + self.kv_proj_shard_size *
self.kv_shard_id:v_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b)

self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
Expand All @@ -674,6 +679,7 @@ def set_lora(
lora_b.T, non_blocking=True)

@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
Expand Down
4 changes: 3 additions & 1 deletion vllm/lora/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from vllm.lora.fully_sharded_layers import (
ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora,
RowParallelLinearWithShardedLoRA)
# being imported for _all_lora_classes below
# yapf conflicts with isort for this block
# yapf: disable
Expand All @@ -35,6 +36,7 @@
RowParallelLinearWithLoRA,
LogitsProcessorWithLoRA,
ColumnParallelLinearWithShardedLoRA,
QKVParallelLinearWithShardedLora,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora,
RowParallelLinearWithShardedLoRA,
Expand Down

0 comments on commit 67005a0

Please sign in to comment.