Commit

update lm head; try test fix
dsikka committed Sep 3, 2024
1 parent: 3043b80 · commit: e666b22
Showing 2 changed files with 12 additions and 4 deletions.
tests/weight_loading/test_weight_loading.py (7 changes: 6 additions & 1 deletion)
```diff
@@ -1,5 +1,7 @@
 import os
 
+import torch
+
 MAX_MODEL_LEN = 1024
 MODEL_NAME = os.environ.get("MODEL_NAME",
                             "robertgshaw2/zephyr-7b-beta-channelwise-gptq")
@@ -8,9 +10,12 @@
 
 
 def test_weight_loading(vllm_runner):
+    """
+    Test parameter weight loading with tp>1.
+    """
     with vllm_runner(model_name=MODEL_NAME,
                      revision=REVISION,
-                     dtype="auto",
+                     dtype=torch.half if QUANTIZATION == "gptq" else "auto",
                      quantization=QUANTIZATION,
                      max_model_len=MAX_MODEL_LEN,
                      tensor_parallel_size=2) as model:
```
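For reference, a minimal standalone sketch of the dtype selection this hunk introduces. Reading `QUANTIZATION` from the environment is an assumption here, mirroring how `MODEL_NAME` is configured above; the default value is illustrative only.

```python
# Hypothetical sketch: QUANTIZATION is assumed to be an environment
# variable, mirroring MODEL_NAME above; the default is illustrative only.
import os

import torch

QUANTIZATION = os.environ.get("QUANTIZATION", "gptq")

# GPTQ checkpoints are generally served in half precision, so the test pins
# torch.half for GPTQ and lets vLLM infer the dtype ("auto") otherwise.
dtype = torch.half if QUANTIZATION == "gptq" else "auto"
print(dtype)  # torch.float16 when QUANTIZATION == "gptq"
```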
vllm/model_executor/layers/vocab_parallel_embedding.py (9 changes: 6 additions & 3 deletions)
```diff
@@ -10,6 +10,7 @@
     tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
+from vllm.model_executor.parameter import BasevLLMParameter
 from vllm.model_executor.utils import set_weight_attrs
 
 DEFAULT_VOCAB_PADDING_SIZE = 64
@@ -370,10 +371,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         # If param packed on the same dim we are sharding on, then
         # need to adjust offsets of loaded weight by pack_factor.
         if packed_dim is not None and packed_dim == output_dim:
+            packed_factor = param.packed_factor if isinstance(
+                param, BasevLLMParameter) else param.pack_factor
             assert loaded_weight.shape[output_dim] == (self.org_vocab_size //
-                                                       param.pack_factor)
-            start_idx = start_idx // param.pack_factor
-            shard_size = shard_size // param.pack_factor
+                                                       packed_factor)
+            start_idx = start_idx // packed_factor
+            shard_size = shard_size // packed_factor
         else:
             assert loaded_weight.shape[output_dim] == self.org_vocab_size
 
```
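To make the sharding arithmetic above concrete, here is a small worked sketch with made-up numbers: when a checkpoint packs several quantized values per stored element along the sharded vocab dimension, the shard offsets must be rescaled by that pack factor before slicing the loaded weight.

```python
# Worked example with hypothetical values: 8 four-bit weights packed per
# int32 element along the vocab (output) dimension, tensor parallel size 2.
org_vocab_size = 32000           # vocab rows, in unpacked units
packed_factor = 8                # unpacked values per stored element

# Rank 1's shard boundaries, in unpacked units.
start_idx, shard_size = 16000, 16000

# The loaded (packed) weight holds org_vocab_size // packed_factor rows,
# so the shard offsets must be converted to packed units before slicing.
assert org_vocab_size // packed_factor == 4000
start_idx //= packed_factor      # 16000 -> 2000
shard_size //= packed_factor     # 16000 -> 2000
```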
