[Model] LoRA support added for command-r (#5178)
sergey-tinkoff authored Jun 18, 2024
1 parent 19091ef commit 07feecd
Showing 3 changed files with 50 additions and 6 deletions.
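
For context, a minimal sketch of how the new support could be exercised through vLLM's offline LoRA API (the checkpoint name, adapter name, and adapter path below are placeholders, not part of this commit):

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Enable LoRA on a command-r checkpoint; model and adapter paths are placeholders.
llm = LLM(model="CohereForAI/c4ai-command-r-v01", enable_lora=True, max_loras=1)

outputs = llm.generate(
    ["Write a haiku about oceans."],
    SamplingParams(temperature=0.0, max_tokens=64),
    # LoRARequest takes an adapter name, an integer id, and a local adapter path.
    lora_request=LoRARequest("command_r_adapter", 1, "/path/to/command-r-lora"),
)
print(outputs[0].outputs[0].text)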
6 changes: 6 additions & 0 deletions csrc/punica/bgmv/bgmv_config.h
100644 → 100755
@@ -69,6 +69,8 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 36864) \
f(in_T, out_T, W_T, narrow, 43264) \
f(in_T, out_T, W_T, narrow, 49152) \
f(in_T, out_T, W_T, narrow, 60544) \
f(in_T, out_T, W_T, narrow, 60672) \
f(in_T, out_T, W_T, narrow, 64000) \
f(in_T, out_T, W_T, narrow, 64256) \
f(in_T, out_T, W_T, narrow, 64512) \
@@ -78,6 +80,8 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 128000) \
f(in_T, out_T, W_T, narrow, 128256) \
f(in_T, out_T, W_T, narrow, 128512) \


// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
// and vllm/tests/lora/test_punica.py

@@ -144,6 +148,8 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, 36864, narrow) \
f(in_T, out_T, W_T, 43264, narrow) \
f(in_T, out_T, W_T, 49152, narrow) \
f(in_T, out_T, W_T, 60544, narrow) \
f(in_T, out_T, W_T, 60672, narrow) \
f(in_T, out_T, W_T, 64000, narrow) \
f(in_T, out_T, W_T, 64256, narrow) \
f(in_T, out_T, W_T, 64512, narrow) \
2 changes: 2 additions & 0 deletions tests/lora/test_punica.py
@@ -94,6 +94,8 @@ def _lora_ref_impl(
36864,
43264,
49152,
60544,
60672,
64000,
64256,
102400,
48 changes: 42 additions & 6 deletions vllm/model_executor/models/commandr.py
@@ -29,7 +29,7 @@
from transformers import CohereConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
@@ -265,10 +265,14 @@ def __init__(
config: CohereConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
@@ -302,18 +306,44 @@ def forward(

class CohereForCausalLM(nn.Module):

packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens"
]
embedding_modules = {"embed_tokens": "input_embeddings"}
embedding_padding_modules = []

def __init__(
self,
config: CohereConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.quant_config = quant_config
self.logits_processor = LogitsProcessor(config.vocab_size,
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
scale=config.logit_scale)
self.model = CohereModel(config, cache_config, quant_config)
self.model = CohereModel(config,
cache_config,
quant_config,
lora_config=lora_config)
self.sampler = Sampler()

@torch.no_grad()
@@ -330,8 +360,14 @@ def forward(

def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.model.embed_tokens.weight,
hidden_states, sampling_metadata)
is_not_lora = hasattr(self.model.embed_tokens, 'weight')
if is_not_lora:
embedding_weights = self.model.embed_tokens.weight
else:
embedding_weights = self.model.embed_tokens.base_layer.weight

logits = self.logits_processor(embedding_weights, hidden_states,
sampling_metadata)
return logits

def sample(
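
The compute_logits change above handles the case where LoRA has wrapped the tied embedding: a plain VocabParallelEmbedding exposes .weight directly, while the LoRA wrapper keeps the original module on .base_layer. The same lookup, as a standalone illustrative helper (the function name is ours, not part of the commit):

def _tied_embedding_weight(embed_tokens):
    # Plain embedding module: the weight lives on the module itself.
    if hasattr(embed_tokens, "weight"):
        return embed_tokens.weight
    # LoRA-wrapped embedding: the original module is stored as base_layer.
    return embed_tokens.base_layer.weight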
