[Bugfix] Fix internlm2 tensor parallel inference #8055

Merged: 1 commit, Sep 2, 2024
47 changes: 34 additions & 13 deletions vllm/model_executor/models/internlm2.py
@@ -1,4 +1,5 @@
 # -*- coding: utf-8 -*-
+from functools import partial
 from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
@@ -7,7 +8,10 @@

 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              split_tensor_along_last_dim,
+                              tensor_model_parallel_all_gather)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -70,20 +74,21 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
-        tp_size = get_tensor_model_parallel_world_size()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
         self.total_num_heads = num_heads
-        assert self.total_num_heads % tp_size == 0
-        self.num_heads = self.total_num_heads // tp_size
+        assert self.total_num_heads % self.tp_size == 0
+        self.num_heads = self.total_num_heads // self.tp_size
         self.total_num_kv_heads = num_kv_heads
-        if self.total_num_kv_heads >= tp_size:
+        if self.total_num_kv_heads >= self.tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % tp_size == 0
+            assert self.total_num_kv_heads % self.tp_size == 0
         else:
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
-            assert tp_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+            assert self.tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
         self.head_dim = hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
@@ -122,11 +127,27 @@ def __init__(
                               quant_config=quant_config)
 
     def split_qkv(self, qkv: torch.Tensor):
-        qkv = qkv.view(-1, self.num_kv_heads, self.key_value_groups + 2, 128)
-        q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=2)
-        q = q.reshape(-1, self.q_size)
-        k = k.reshape(-1, self.kv_size)
-        v = v.reshape(-1, self.kv_size)
+        seq_len = qkv.shape[0]
+        if self.tp_size > 1:
+            qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size
+            qkv = tensor_model_parallel_all_gather(qkv)
+            qkv = torch.split(qkv, qkv_map, dim=-1)
+            qkv = qkv[::3] + qkv[1::3] + qkv[2::3]
+            qkv = torch.cat(qkv, dim=-1)
+
+        qkv = qkv.view(seq_len, self.total_num_kv_heads,
+                       self.key_value_groups + 2, self.head_dim)
+        q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2)
+        q = q.reshape(seq_len, self.q_size * self.tp_size)
+        k = k.reshape(seq_len, self.kv_size * self.tp_size)
+        v = v.reshape(seq_len, self.kv_size * self.tp_size)
+
+        if self.tp_size > 1:
+            splitter = partial(split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+            v = splitter(v)[self.tp_rank]
         return q, k, v
 
     def forward(
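
The core of the fix is the reordering step in split_qkv: under tensor parallelism each rank's wqkv output is its local [q | k | v] slice, so the patch all-gathers the slices, regroups the chunks into all q, then all k, then all v, splits the full tensor into q, k, and v, and finally hands each rank back its own shard via split_tensor_along_last_dim. Below is a minimal single-process sketch of just the chunk-regrouping idiom, with made-up sizes (q_size=4, kv_size=2, tp_size=2) and a plain torch.cat standing in for the real all-gather; it is an illustration, not the PR's code.

import torch

# Hypothetical sizes, chosen only to make the regrouping visible.
q_size, kv_size, tp_size = 4, 2, 2

# Stand-in for tensor_model_parallel_all_gather(qkv): the per-rank
# [q_r | k_r | v_r] slices concatenated along the last dim. Rank r's
# values are offset by 100 * r so they are easy to track.
per_rank = [
    torch.arange(q_size + 2 * kv_size) + 100 * r for r in range(tp_size)
]
gathered = torch.cat(per_rank, dim=-1).unsqueeze(0)  # shape [1, 16]

# Same idiom as the patched split_qkv: split into per-rank q/k/v chunks,
# then regroup as (all q chunks, all k chunks, all v chunks).
qkv_map = [q_size, kv_size, kv_size] * tp_size
chunks = torch.split(gathered, qkv_map, dim=-1)
reordered = chunks[::3] + chunks[1::3] + chunks[2::3]
full = torch.cat(reordered, dim=-1)
print(full.tolist())
# [[0, 1, 2, 3, 100, 101, 102, 103, 4, 5, 104, 105, 6, 7, 106, 107]]

For an end-to-end check of the fixed path, a tensor-parallel run of an internlm2 model in vLLM looks roughly like the following (assuming a 2-GPU machine; the checkpoint name is only an example):

from vllm import LLM

llm = LLM(model="internlm/internlm2-chat-7b",
          trust_remote_code=True,
          tensor_parallel_size=2)
print(llm.generate("The capital of France is")[0].outputs[0].text)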