diff --git a/litgpt/adapter.py b/litgpt/adapter.py index 628217b61c..bef77ece1b 100644 --- a/litgpt/adapter.py +++ b/litgpt/adapter.py @@ -151,7 +151,7 @@ def scaled_dot_product_attention( ak, av = self.adapter_kv_cache else: prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd) - aqkv = self.attn(prefix) + aqkv = self.qkv(prefix) q_per_kv = self.config.n_head // self.config.n_query_groups aqkv = aqkv.view(1, aT, self.config.n_query_groups, q_per_kv + 2, self.config.head_size) aqkv = aqkv.permute(0, 2, 3, 1, 4) diff --git a/litgpt/adapter_v2.py b/litgpt/adapter_v2.py index 7c94a8d630..9b975260f0 100644 --- a/litgpt/adapter_v2.py +++ b/litgpt/adapter_v2.py @@ -21,6 +21,7 @@ from litgpt.adapter import CausalSelfAttention as BaseCausalSelfAttention from litgpt.adapter import Config as BaseConfig from litgpt.model import KVCache +from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble from litgpt.utils import map_old_state_dict_weights @@ -163,7 +164,7 @@ def __init__(self, config: Config, block_idx: int) -> None: nn.Module.__init__(self) shape = (config.n_head + 2 * config.n_query_groups) * config.head_size # key, query, value projections for all heads, but in a batch - self.attn = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias or config.attn_bias) + self.qkv = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias or config.attn_bias) # output projection # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head` self.proj = AdapterV2Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias) @@ -186,10 +187,10 @@ def __init__(self, config: Config, block_idx: int) -> None: self.config = config def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: - """For compatibility with base checkpoints.""" + """For compatibility with base and/or legacy checkpoints.""" mapping = { - "attn.weight": "attn.linear.weight", - "attn.bias": "attn.linear.bias", + "qkv.weight": "qkv.linear.weight", + "qkv.bias": "qkv.linear.bias", "proj.weight": "proj.linear.weight", "proj.bias": "proj.linear.bias", } @@ -197,6 +198,13 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa # For compatibility with older checkpoints if (key := prefix + "gating_factor") in state_dict and state_dict[key].size(1) == self.config.n_head: state_dict[key] = state_dict[key].permute(0, 2, 1, 3) + + for attr in ("weight", "bias"): + legacy_key = f"{prefix}attn.linear.{attr}" + current_key = f"{prefix}qkv.linear.{attr}" + if legacy_key in state_dict: + state_dict[current_key] = qkv_reassemble(state_dict.pop(legacy_key), self.config) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) diff --git a/litgpt/generate/tp.py b/litgpt/generate/tp.py index c76d4f27c9..7b45ffd014 100644 --- a/litgpt/generate/tp.py +++ b/litgpt/generate/tp.py @@ -3,31 +3,30 @@ import logging import sys import time +import warnings from functools import partial from pathlib import Path from pprint import pprint from typing import Literal, Optional, Union -import warnings import lightning as L -from lightning_utilities.core.imports import RequirementCache import torch import torch._dynamo.config import torch._inductor.config from lightning.fabric.plugins import BitsandbytesPrecision from lightning.fabric.utilities import rank_zero_only +from lightning_utilities.core.imports import RequirementCache import litgpt.generate.base as generate_base 
-from litgpt.model import GPT from litgpt.config import Config -from litgpt.tokenizer import Tokenizer -from litgpt.model import CausalSelfAttention, GptNeoxMLP, LLaMAMLP, LLaMAMoE +from litgpt.model import GPT, CausalSelfAttention, GptNeoxMLP, LLaMAMLP, LLaMAMoE from litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style +from litgpt.tokenizer import Tokenizer from litgpt.utils import ( check_nvlink_connectivity, check_valid_checkpoint_dir, extend_checkpoint_dir, - get_default_supported_precision + get_default_supported_precision, ) @@ -71,7 +70,7 @@ def tensor_parallel_mlp(fabric: L.Fabric, mlp: Union[GptNeoxMLP, LLaMAMLP, LLaMA def tensor_parallel_attn(fabric: L.Fabric, attn: CausalSelfAttention) -> None: - tensor_parallel_linear(fabric, attn.attn, "colwise") + tensor_parallel_linear(fabric, attn.qkv, "colwise") tensor_parallel_linear(fabric, attn.proj, "rowwise") attn.register_forward_hook(partial(all_reduce_output, fabric.world_size)) diff --git a/litgpt/lora.py b/litgpt/lora.py index db48175eac..beca761c48 100644 --- a/litgpt/lora.py +++ b/litgpt/lora.py @@ -58,6 +58,7 @@ from litgpt.model import Block as BaseBlock from litgpt.model import CausalSelfAttention as BaseCausalSelfAttention from litgpt.model import KVCache +from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble from litgpt.utils import map_old_state_dict_weights @@ -267,18 +268,14 @@ def lora_ind(self) -> torch.Tensor: # Indices are needed to properly pad weight updates with zeros. if not hasattr(self, "_lora_ind"): enable_q, enable_k, enable_v = self.enable_lora - qkv_group_size = self.n_head // self.n_query_groups + 2 - candidate_indices = range(self.linear.out_features) + kv_embd_size = self.linear.in_features // (self.n_head // self.n_query_groups) lora_ind = [] if enable_q: - q_ind = [x for x in candidate_indices if (x // self.head_size) % qkv_group_size < qkv_group_size - 2] - lora_ind.extend(q_ind) + lora_ind.extend(range(0, self.linear.in_features)) if enable_k: - k_ind = [x for x in candidate_indices if (x // self.head_size) % qkv_group_size == qkv_group_size - 2] - lora_ind.extend(k_ind) + lora_ind.extend(range(self.linear.in_features, self.linear.in_features + kv_embd_size)) if enable_v: - v_ind = [x for x in candidate_indices if (x // self.head_size) % qkv_group_size == qkv_group_size - 1] - lora_ind.extend(v_ind) + lora_ind.extend(range(self.linear.in_features + kv_embd_size, self.linear.out_features)) self.register_buffer( "_lora_ind", torch.tensor(lora_ind, device=self.linear.weight.device), persistent=False ) @@ -298,27 +295,6 @@ def zero_pad(self, x: torch.Tensor) -> torch.Tensor: ________________________________________ | query | key | value | ---------------------------------------- - For Llama2's GQA support, Q, K, and V weights are interleaved, so that weights for grouped - queries are adjacent to their associated key and value weights. - For example, suppose we have n_head = 12 with 3 query groups. - Then along the embedding dimension the interleaved weights would look like - - [Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V], - - where each Q, K, and V has size head_size. 
- - In this case, the previously-described weight update applies separately to each - individual block, so the update will take the form - - [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ...], - [.............................................................................], - [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ...]] - ↑ ↑ ↑ ↑ ↑ ↑ - ________________________________________________________________________________ - | q block 1 | k block 1 | v block 1 | q block 2 | k block 2 | v block 2 | ... - -------------------------------------------------------------------------------- - Note that in the above diagram, the size of each q block will equal q_per_kv - times the size of each k and v block. Args: x: tensor with weights update that will be padded with zeros if necessary @@ -391,7 +367,9 @@ def get_lora_AB(self) -> torch.Tensor: lora = self.conv1d( self.lora_A.data.unsqueeze(0), # (4, 128) -> (1, 4, 128) self.lora_B.data.unsqueeze(-1), # (256, 2) -> (256, 2, 1) - ).squeeze(0) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128) + ).squeeze( + 0 + ) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128) return self.zero_pad(lora.T * self.scaling).T # (256, 128) after zero_pad (384, 128) def merge(self) -> None: @@ -430,7 +408,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: after_B = self.conv1d( after_A.transpose(-2, -1), # (64, 64, 4) -> (64, 4, 64) self.lora_B.unsqueeze(-1), # (256, 2) -> (256, 2, 1) - ).transpose(-2, -1) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256) + ).transpose( + -2, -1 + ) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256) lora = self.zero_pad(after_B) * self.scaling # (64, 64, 256) after zero_pad (64, 64, 384) return pretrained + lora @@ -602,7 +582,7 @@ def __init__(self, config: Config, block_idx: int) -> None: nn.Module.__init__(self) shape = (config.n_head + 2 * config.n_query_groups) * config.head_size # key, query, value projections for all heads, but in a batch - self.attn = LoRAQKVLinear( + self.qkv = LoRAQKVLinear( in_features=config.n_embd, out_features=shape, r=config.lora_r, @@ -628,21 +608,28 @@ def __init__(self, config: Config, block_idx: int) -> None: # disabled by default self.kv_cache: Optional[KVCache] = None self.apply_sliding_window_attention = ( - config.sliding_window_size is not None and - block_idx % config.sliding_window_layer_stride == 0 + config.sliding_window_size is not None and + block_idx % config.sliding_window_layer_stride == 0 ) self.config = config def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: - """For compatibility with base checkpoints.""" + """For compatibility with base and/or legacy checkpoints.""" mapping = { - "attn.weight": "attn.linear.weight", - "attn.bias": "attn.linear.bias", + "qkv.weight": "qkv.linear.weight", + "qkv.bias": "qkv.linear.bias", "proj.weight": "proj.linear.weight", "proj.bias": "proj.linear.bias", } state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + + for attr in ("weight", "bias"): + legacy_key = f"{prefix}attn.linear.{attr}" + current_key = f"{prefix}qkv.linear.{attr}" + if legacy_key in state_dict: + state_dict[current_key] = qkv_reassemble(state_dict.pop(legacy_key), self.config) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) @@ -758,4 +745,4 @@ def merge_lora_weights(model: GPT) -> None: """Merge LoRA weights into the full-rank weights to speed up inference.""" for module in model.modules(): if 
isinstance(module, LoRALinear): - module.merge() \ No newline at end of file + module.merge() diff --git a/litgpt/model.py b/litgpt/model.py index 643ba59a71..cbdf2a4bdd 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -7,13 +7,14 @@ """ import math -from typing import Any, Optional, Tuple +from typing import Any, Dict, Optional, Tuple import torch import torch.nn as nn from typing_extensions import Self from litgpt.config import Config +from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble class GPT(nn.Module): @@ -44,8 +45,10 @@ def max_seq_length(self, value: int) -> None: This allows setting a smaller number to avoid allocating unused memory """ if value > self.config.block_size: - raise ValueError(f"Cannot attend to {value}, block size is only {self.config.block_size}." - " This is likely because the input text exceeds the supported context length of this model.") + raise ValueError( + f"Cannot attend to {value}, block size is only {self.config.block_size}." + " This is likely because the input text exceeds the supported context length of this model." + ) self._max_seq_length = value if not hasattr(self, "cos"): # first call @@ -148,7 +151,9 @@ def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tenso } else: # Some but not all parameters are specified; raise an error - missing_params = [param for param, present in zip(adjusted_params_required, params_present) if not present] + missing_params = [ + param for param, present in zip(adjusted_params_required, params_present) if not present + ] raise ValueError( f"The following adjusted RoPE parameters are missing in rope_adjustments: {', '.join(missing_params)}. " "All adjusted RoPE parameters must be specified together." @@ -180,7 +185,11 @@ def set_kv_cache( # initialize the kv cache for all blocks for block in self.transformer.h: block.attn.kv_cache = block.attn.build_kv_cache( - batch_size, max_seq_length, rope_cache_length, device, dtype, + batch_size, + max_seq_length, + rope_cache_length, + device, + dtype, ) if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length: @@ -262,17 +271,20 @@ def forward( class CausalSelfAttention(nn.Module): def __init__(self, config: Config, block_idx: int) -> None: super().__init__() - shape = (config.n_head + 2 * config.n_query_groups) * config.head_size - # key, query, value projections for all heads, but in a batch - self.attn = nn.Linear(config.n_embd, shape, bias=config.bias or config.attn_bias) + # key, query and value projections for all heads, but in a batch + self.qkv = nn.Linear( + config.n_embd, + (config.n_head + 2 * config.n_query_groups) * config.head_size, # support for grouped/multi queries + bias=config.bias or config.attn_bias, + ) # output projection # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head` self.proj = nn.Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias) # disabled by default self.kv_cache: Optional[KVCache] = None self.apply_sliding_window_attention = ( - config.sliding_window_size is not None and - block_idx % config.sliding_window_layer_stride == 0 + config.sliding_window_size is not None and + block_idx % config.sliding_window_layer_stride == 0 ) self.config = config @@ -285,42 +297,60 @@ def forward( mask: Optional[torch.Tensor] = None, input_pos: Optional[torch.Tensor] = None, ) -> torch.Tensor: - B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) - - qkv = self.attn(x) - - # assemble into a number of 
query groups to support MHA, MQA and GQA together (see `config.n_query_groups`) - q_per_kv = self.config.n_head // self.config.n_query_groups - total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value - qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - - # split batched computation into three: - # q: (B, n_query_groups, q_per_kv, T, hs) - # k, v: (B, n_query_groups, 1, T, hs) - q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - - # maybe repeat k and v if for the non multi-head attention cases - # training: flash attention requires it - # inference: multi-query would require a full kv cache so avoid it to limit its memory usage - if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1): - k = k.expand(*q.shape) - v = v.expand(*q.shape) - - q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - + # Notation: + # - B | batch size + # - T | time-step (sequence length) + # - C | model's embeddings size (n_embd) + # - C* | attentions's embeddings size + # - nh_(q,k,v) | number of heads for query, key and value + # - hs | head size + + B, T, C = x.size() + + # Perform a single multiplication operation using a combined QKV matrix to calculate `query`, `key`, and `value` + # instead of individually multiplying the input `x` with the respective weight matrices. + qkv = self.qkv(x) # (B, T, 3xC*) + + # Define query, key and value sizes. + # If grouped/multi query is enabled, these sizes are not equal (see the diagram in `lit_gpt/config.py::Config`). + query_size = self.config.n_head * self.config.head_size + key_size = value_size = self.config.n_query_groups * self.config.head_size + # Split qkv into query, key and value matrices. + q, k, v = qkv.split((query_size, key_size, value_size), dim=-1) # 3x(B, T, C*) + + # To place the num_heads (nh) dimension right after the batch (B) dimension, the first step is to decouple the + # embedding size (C) into num_heads (nh) and head_size (hs). + q = q.view(B, T, self.config.n_head, self.config.head_size) # (B, T, nh_q, hs) + k = k.view(B, T, self.config.n_query_groups, self.config.head_size) # (B, T, nh_k, hs) + v = v.view(B, T, self.config.n_query_groups, self.config.head_size) # (B, T, nh_v, hs) + + # The tensors `query`, `key`, and `value` are now accurately structured: within each batch element (B), there are + # multiple heads (nh), and within each head, there is a sequence of elements (T), each represented by a vector + # of size `hs`. + q = q.transpose(1, 2) # (B, nh_q, T, hs) + k = k.transpose(1, 2) # (B, nh_k, T, hs) + v = v.transpose(1, 2) # (B, nh_v, T, hs) + + # Unlike standard positional embeddings rotary embeddings must be applied at every layer. q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) + q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) # (B, nh_q, T, hs) + k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) # (B, nh_k, T, hs) + # Apply kv-cache during inference. 
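+        # NOTE: with grouped/multi-query attention the cache stores only `n_query_groups` key/value
+        # heads (see `build_kv_cache` below); they are repeated or broadcast across the query heads
+        # after the lookup.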
if input_pos is not None: if not isinstance(self.kv_cache, KVCache): raise TypeError("You need to call `gpt.set_kv_cache()`") k, v = self.kv_cache(input_pos, k, v) + # Grouped queries: balance the number of heads across all three matrices. + # NOTE: flash attention requires it in training mode. + # Multi-query: this step can be skipped since there is only 1 head, allowing us to use broadcasting. + if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1): + q_per_kv = self.config.n_head // self.config.n_query_groups + k = k.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs) + v = v.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs) + if self.apply_sliding_window_attention: """ Global Window Sliding window Sliding window @@ -339,12 +369,16 @@ def forward( sliding_window_bias.masked_fill_(sliding_window_bias.bool(), float("-inf")) mask += sliding_window_bias + # Efficient attention using Flash Attention CUDA kernels. + # NOTE: efficient implementation is disabled if `mask` is not None or softcapping is enabled. + # ↓ (B, nh, T, hs) @ (B, nh, T, hs).mT --> (B, nh, T, T) @ (B, nh, T, hs) --> (B, nh, T, hs) y = self.scaled_dot_product_attention(q, k, v, mask) - y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side + # Re-assemble all head outputs side by side. + y = y.reshape(B, T, self.config.head_size * self.config.n_head) - # output projection - return self.proj(y) + # Output projection. + return self.proj(y) # (B, T, C) def scaled_dot_product_attention( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None @@ -375,8 +409,7 @@ def build_kv_cache( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> "KVCache": - heads = 1 if self.config.n_query_groups == 1 else self.config.n_head - v_shape = (batch_size, heads, max_seq_length, self.config.head_size) + v_shape = (batch_size, self.config.n_query_groups, max_seq_length, self.config.head_size) if rope_cache_length is None: if self.config.rotary_percentage != 1.0: raise TypeError("Please pass the `rope_cache_length=gpt.cos.size(-1)` value") @@ -384,12 +417,23 @@ def build_kv_cache( else: k_shape = ( batch_size, - heads, + self.config.n_query_groups, max_seq_length, rope_cache_length + self.config.head_size - self.config.rope_n_elem, ) return KVCache(k_shape, v_shape, device=device, dtype=dtype) + def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: + """For compatibility with legacy checkpoints.""" + + for attr in ("weight", "bias"): + legacy_key = f"{prefix}attn.{attr}" + current_key = f"{prefix}qkv.{attr}" + if legacy_key in state_dict: + state_dict[current_key] = qkv_reassemble(state_dict.pop(legacy_key), self.config) + + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + class GptNeoxMLP(nn.Module): def __init__(self, config: Config) -> None: diff --git a/litgpt/scripts/convert_hf_checkpoint.py b/litgpt/scripts/convert_hf_checkpoint.py index 6125840ed9..fbcfa871a6 100644 --- a/litgpt/scripts/convert_hf_checkpoint.py +++ b/litgpt/scripts/convert_hf_checkpoint.py @@ -2,37 +2,38 @@ import gc import json +import os +import re from collections import defaultdict from functools import partial -import os from pathlib import Path from pprint import pprint from typing import Dict, List, Optional, Tuple, Union -from tqdm import tqdm import torch from lightning.fabric.utilities.load import _NotYetLoadedTensor as 
NotYetLoadedTensor +from tqdm import tqdm from litgpt.config import Config from litgpt.utils import extend_checkpoint_dir, incremental_save, lazy_load, save_config def copy_weights_gpt_neox( + config: Config, state_dict: Dict[str, torch.Tensor], hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, dtype: Optional[torch.dtype] = None, pbar: Optional[tqdm] = None, progress_per_file: Optional[float] = None, - debug_mode: Optional[bool] = False - + debug_mode: Optional[bool] = False, ) -> None: weight_map = { "gpt_neox.embed_in.weight": "transformer.wte.weight", "gpt_neox.layers.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", "gpt_neox.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", - "gpt_neox.layers.{}.attention.query_key_value.bias": "transformer.h.{}.attn.attn.bias", - "gpt_neox.layers.{}.attention.query_key_value.weight": "transformer.h.{}.attn.attn.weight", + "gpt_neox.layers.{}.attention.query_key_value.bias": "transformer.h.{}.attn.qkv.bias", + "gpt_neox.layers.{}.attention.query_key_value.weight": "transformer.h.{}.attn.qkv.weight", "gpt_neox.layers.{}.attention.dense.bias": "transformer.h.{}.attn.proj.bias", "gpt_neox.layers.{}.attention.dense.weight": "transformer.h.{}.attn.proj.weight", "gpt_neox.layers.{}.attention.rotary_emb.inv_freq": None, @@ -52,16 +53,16 @@ def copy_weights_gpt_neox( if progress_per_file is not None: progress_per_file = progress_per_file / max(1, len(hf_weights)) - for name, param in hf_weights.items(): - if "gpt_neox.layers" in name: - from_name, number = layer_template(name, 2) - to_name = weight_map[from_name] - if to_name is None: - continue - to_name = to_name.format(number) - else: - to_name = weight_map[name] - param = load_param(param, name, dtype, verbose=debug_mode) + for from_name, param in hf_weights.items(): + name_template, layer_idx = layer_template(from_name) + to_name = weight_map[name_template] + if to_name is None: + continue + to_name = to_name.format(layer_idx) + param = load_param(param, from_name, dtype, verbose=debug_mode) + if from_name.endswith((".query_key_value.weight", ".query_key_value.bias")): + # Reassemble [q, k, v, q, k, v, ...] --> [q, q, ..., k, k, ..., v, v, ...] 
+ param = qkv_reassemble(param, config) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -71,18 +72,18 @@ def copy_weights_gpt_neox( def copy_weights_falcon( - model_name: str, + config: Config, state_dict: Dict[str, torch.Tensor], hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, dtype: Optional[torch.dtype] = None, pbar: Optional[tqdm] = None, progress_per_file: Optional[float] = None, - debug_mode: Optional[bool] = False + debug_mode: Optional[bool] = False, ) -> None: weight_map = { "transformer.word_embeddings.weight": "transformer.wte.weight", - "transformer.h.{}.self_attention.query_key_value.weight": "transformer.h.{}.attn.attn.weight", + "transformer.h.{}.self_attention.query_key_value.weight": "transformer.h.{}.attn.qkv.weight", "transformer.h.{}.self_attention.dense.weight": "transformer.h.{}.attn.proj.weight", "transformer.h.{}.mlp.dense_h_to_4h.weight": "transformer.h.{}.mlp.fc.weight", "transformer.h.{}.mlp.dense_4h_to_h.weight": "transformer.h.{}.mlp.proj.weight", @@ -91,14 +92,14 @@ def copy_weights_falcon( "lm_head.weight": "lm_head.weight", } # the original model definition is different for each size - if "7b" in model_name: + if "7b" in config.name: weight_map.update( { "transformer.h.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", "transformer.h.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", } ) - elif "40b" in model_name or "180B" in model_name: + elif "40b" in config.name or "180B" in config.name: weight_map.update( { "transformer.h.{}.ln_attn.bias": "transformer.h.{}.norm_1.bias", @@ -113,16 +114,17 @@ def copy_weights_falcon( if progress_per_file is not None: progress_per_file = progress_per_file / max(1, len(hf_weights)) - for name, param in hf_weights.items(): - if "transformer.h" in name: - from_name, number = layer_template(name, 2) - to_name = weight_map[from_name].format(number) - else: - to_name = weight_map[name] - param = load_param(param, name, dtype, verbose=debug_mode) + for from_name, param in hf_weights.items(): + name_template, layer_idx = layer_template(from_name) + to_name = weight_map[name_template].format(layer_idx) + param = load_param(param, from_name, dtype, verbose=debug_mode) + if from_name.endswith((".query_key_value.weight", ".query_key_value.bias")): + # Reassemble [q, k, v, q, k, v, ...] --> [q, q, ..., k, k, ..., v, v, ...] 
+ param = qkv_reassemble(param, config) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param + if progress_per_file is not None: pbar.update(progress_per_file) @@ -136,19 +138,19 @@ def copy_weights_hf_llama( dtype: Optional[torch.dtype] = None, pbar: Optional[tqdm] = None, progress_per_file: Optional[float] = None, - debug_mode: Optional[bool] = False + debug_mode: Optional[bool] = False, ) -> None: weight_map = { "model.embed_tokens.weight": "transformer.wte.weight", - "model.layers.{}.input_layernorm.weight": "transformer.h.{l}.norm_1.weight", - "model.layers.{}.input_layernorm.bias": "transformer.h.{l}.norm_1.bias", + "model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", + "model.layers.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", "model.layers.{}.self_attn.q_proj.weight": None, "model.layers.{}.self_attn.k_proj.weight": None, "model.layers.{}.self_attn.v_proj.weight": None, - "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{l}.attn.proj.weight", + "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight", "model.layers.{}.self_attn.rotary_emb.inv_freq": None, - "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{l}.norm_2.weight", - "model.layers.{}.post_attention_layernorm.bias": "transformer.h.{l}.norm_2.bias", + "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight", + "model.layers.{}.post_attention_layernorm.bias": "transformer.h.{}.norm_2.bias", "model.norm.weight": "transformer.ln_f.weight", "model.norm.bias": "transformer.ln_f.bias", "lm_head.weight": "lm_head.weight", @@ -156,18 +158,18 @@ def copy_weights_hf_llama( if config.mlp_class_name == "LLaMAMoE": weight_map.update( { - "model.layers.{}.block_sparse_moe.gate.weight": "transformer.h.{l}.mlp.gate.weight", - "model.layers.{}.block_sparse_moe.experts.{}.w1.weight": "transformer.h.{l}.mlp.experts.{e}.fc_1.weight", - "model.layers.{}.block_sparse_moe.experts.{}.w3.weight": "transformer.h.{l}.mlp.experts.{e}.fc_2.weight", - "model.layers.{}.block_sparse_moe.experts.{}.w2.weight": "transformer.h.{l}.mlp.experts.{e}.proj.weight", + "model.layers.{}.block_sparse_moe.gate.weight": "transformer.h.{}.mlp.gate.weight", + "model.layers.{}.block_sparse_moe.experts.{}.w1.weight": "transformer.h.{}.mlp.experts.{}.fc_1.weight", + "model.layers.{}.block_sparse_moe.experts.{}.w3.weight": "transformer.h.{}.mlp.experts.{}.fc_2.weight", + "model.layers.{}.block_sparse_moe.experts.{}.w2.weight": "transformer.h.{}.mlp.experts.{}.proj.weight", } ) elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP"): weight_map.update( { - "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{l}.mlp.fc_1.weight", - "model.layers.{}.mlp.up_proj.weight": "transformer.h.{l}.mlp.fc_2.weight", - "model.layers.{}.mlp.down_proj.weight": "transformer.h.{l}.mlp.proj.weight", + "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight", + "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight", + "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight", } ) else: @@ -176,26 +178,17 @@ def copy_weights_hf_llama( if progress_per_file is not None: progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights)) - for name, param in hf_weights.items(): - if "model.layers" in name: - from_name, l = layer_template(name, 2) - e = None - if "block_sparse_moe.experts" in name: - from_name, e = layer_template(from_name, 5) - qkv = qkv_weights.setdefault(l, [None, None, 
None]) - if "q_proj" in name: - qkv[0] = param - elif "k_proj" in name: - qkv[1] = param - elif "v_proj" in name: - qkv[2] = param - to_name = weight_map[from_name] - if to_name is None: - continue - to_name = to_name.format(l=l, e=e) - else: - to_name = weight_map[name] - param = load_param(param, name, dtype, verbose=debug_mode) + for from_name, param in hf_weights.items(): + name_template, *ids = layer_template(from_name, num_matches=2) + to_name = weight_map[name_template] + param = load_param(param, from_name, dtype, verbose=debug_mode) + if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): + qkv = qkv_weights.setdefault(ids[0], defaultdict(dict)) + weight_name, weight_type = from_name.split(".")[-2:] + qkv[weight_type][weight_name] = param + if to_name is None: + continue + to_name = to_name.format(*ids) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -206,28 +199,24 @@ def copy_weights_hf_llama( if "lm_head.weight" not in state_dict: state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"] - # convert separate q, k, v matrices into an interleaved qkv - for i, (q, k, v) in list(qkv_weights.items()): - if q is None or k is None or v is None: - # split across different .bin files - continue - q = load_param(q, f"layer {i} q", dtype, verbose=debug_mode) - k = load_param(k, f"layer {i} k", dtype, verbose=debug_mode) - v = load_param(v, f"layer {i} v", dtype, verbose=debug_mode) - q_per_kv = config.n_head // config.n_query_groups - qs = torch.split(q, config.head_size * q_per_kv) - ks = torch.split(k, config.head_size) - vs = torch.split(v, config.head_size) - cycled = [t for group in zip(qs, ks, vs) for t in group] - qkv = torch.cat(cycled) - state_dict[f"transformer.h.{i}.attn.attn.weight"] = qkv - del qkv_weights[i] - if progress_per_file is not None: - pbar.update(progress_per_file) + for i in list(qkv_weights): + for weight_type in list(qkv_weights[i]): + qkv = qkv_weights[i][weight_type] + if len(qkv) != 3: + # qkv is splitted across different .bin files + continue + q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode) + k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode) + v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode) + qkv = torch.cat((q, k, v)) + state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv + del qkv_weights[i][weight_type] + + if progress_per_file is not None: + pbar.update(progress_per_file) def copy_weights_gemma_2( - config: Config, qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]], state_dict: Dict[str, torch.Tensor], hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], @@ -235,7 +224,7 @@ def copy_weights_gemma_2( dtype: Optional[torch.dtype] = None, pbar: Optional[tqdm] = None, progress_per_file: Optional[float] = None, - debug_mode: Optional[bool] = False + debug_mode: Optional[bool] = False, ) -> None: weight_map = { "model.embed_tokens.weight": "transformer.wte.weight", @@ -257,20 +246,17 @@ def copy_weights_gemma_2( if progress_per_file is not None: progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights)) - for name, param in hf_weights.items(): - if "model.layers" in name: - from_name, l_idx = layer_template(name, 2) - qkv = qkv_weights.setdefault(l_idx, defaultdict(dict)) - if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): - weight_name, weight_type = from_name.split(".")[-2:] - qkv[weight_type][weight_name] = param - 
to_name = weight_map[from_name] - if to_name is None: - continue - to_name = to_name.format(l_idx) - else: - to_name = weight_map[name] - param = load_param(param, name, dtype) + for from_name, param in hf_weights.items(): + name_template, *ids = layer_template(from_name, num_matches=2) + to_name = weight_map[name_template] + param = load_param(param, from_name, dtype, verbose=debug_mode) + if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): + qkv = qkv_weights.setdefault(ids[0], defaultdict(dict)) + weight_name, weight_type = from_name.split(".")[-2:] + qkv[weight_type][weight_name] = param + if to_name is None: + continue + to_name = to_name.format(*ids) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -281,24 +267,19 @@ def copy_weights_gemma_2( if "lm_head.weight" not in state_dict: state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"] - # convert separate q, k, v matrices into an interleaved qkv for i in list(qkv_weights): for weight_type in list(qkv_weights[i]): qkv = qkv_weights[i][weight_type] if len(qkv) != 3: - # split across different .bin files + # qkv is splitted across different .bin files continue - q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype) - k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype) - v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype) - q_per_kv = config.n_head // config.n_query_groups - qs = torch.split(q, config.head_size * q_per_kv) - ks = torch.split(k, config.head_size) - vs = torch.split(v, config.head_size) - cycled = [t for group in zip(qs, ks, vs) for t in group] - qkv = torch.cat(cycled) - state_dict[f"transformer.h.{i}.attn.attn.{weight_type}"] = qkv + q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode) + k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode) + v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode) + qkv = torch.cat((q, k, v)) + state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv del qkv_weights[i][weight_type] + if progress_per_file is not None: pbar.update(progress_per_file) @@ -312,7 +293,7 @@ def copy_weights_phi( dtype: Optional[torch.dtype] = None, pbar: Optional[tqdm] = None, progress_per_file: Optional[float] = None, - debug_mode: Optional[bool] = False + debug_mode: Optional[bool] = False, ) -> None: if any(layer_name.startswith(("layers.", "transformer.")) for layer_name in hf_weights): raise ValueError( @@ -344,7 +325,7 @@ def copy_weights_phi( if config.name.startswith("Phi-3"): weight_map.update( { - "model.layers.{}.self_attn.qkv_proj.weight": "transformer.h.{}.attn.attn.weight", + "model.layers.{}.self_attn.qkv_proj.weight": "transformer.h.{}.attn.qkv.weight", "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight", "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight", "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight", @@ -355,35 +336,27 @@ def copy_weights_phi( if progress_per_file is not None: progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights)) - for name, param in hf_weights.items(): - if name.startswith("model.layers."): - from_name, l = layer_template(name, 2) - qkv = qkv_weights.setdefault(l, defaultdict(dict)) - if "qkv_proj" in from_name: - weight = load_param(param, f"layer {l} qkv", dtype) - weight = qkv_reassemble(weight, config) - to_name = weight_map[from_name].format(l) - 
state_dict[to_name] = weight - continue - if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): - weight_name, weight_type = from_name.split(".")[-2:] - qkv[weight_type][weight_name] = param - elif from_name.endswith("gate_up_proj.weight"): - weight = load_param(param, f"layer {l} gate_up_proj", dtype) - fc_1, fc_2 = weight.chunk(2, dim=0) - state_dict[f"transformer.h.{l}.mlp.fc_1.weight"] = fc_1 - state_dict[f"transformer.h.{l}.mlp.fc_2.weight"] = fc_2 - continue - to_name = weight_map[from_name] - if to_name is None: - continue - to_name = to_name.format(l) - else: - to_name = weight_map[name] - param = load_param(param, name, dtype, verbose=debug_mode) + for from_name, param in hf_weights.items(): + name_template, layer_idx = layer_template(from_name) + param = load_param(param, from_name, dtype, verbose=debug_mode) + if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): + qkv = qkv_weights.setdefault(layer_idx, defaultdict(dict)) + weight_name, weight_type = from_name.split(".")[-2:] + qkv[weight_type][weight_name] = param + elif from_name.endswith("gate_up_proj.weight"): + weight = load_param(param, f"layer {layer_idx} gate_up_proj", dtype, verbose=debug_mode) + fc_1, fc_2 = weight.chunk(2, dim=0) + state_dict[f"transformer.h.{layer_idx}.mlp.fc_1.weight"] = fc_1 + state_dict[f"transformer.h.{layer_idx}.mlp.fc_2.weight"] = fc_2 + continue + to_name = weight_map[name_template] + if to_name is None: + continue + to_name = to_name.format(layer_idx) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param + if progress_per_file is not None: pbar.update(progress_per_file) @@ -391,19 +364,15 @@ def copy_weights_phi( for weight_type in list(qkv_weights[i]): qkv = qkv_weights[i][weight_type] if len(qkv) != 3: - # split across different .bin files + # qkv is splitted across different .bin files continue q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode) k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode) v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode) - q_per_kv = config.n_head // config.n_query_groups - qs = torch.split(q, config.head_size * q_per_kv) - ks = torch.split(k, config.head_size) - vs = torch.split(v, config.head_size) - cycled = [t for group in zip(qs, ks, vs) for t in group] - qkv = torch.cat(cycled) - state_dict[f"transformer.h.{i}.attn.attn.{weight_type}"] = qkv + qkv = torch.cat((q, k, v)) + state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv del qkv_weights[i][weight_type] + if progress_per_file is not None: pbar.update(progress_per_file) @@ -439,20 +408,17 @@ def copy_weights_qwen_2_5( if progress_per_file is not None: progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights)) - for name, param in hf_weights.items(): - if "model.layers" in name: - from_name, l = layer_template(name, 2) - qkv = qkv_weights.setdefault(l, defaultdict(dict)) - if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): - weight_name, weight_type = from_name.split(".")[-2:] - qkv[weight_type][weight_name] = param - to_name = weight_map[from_name] - if to_name is None: - continue - to_name = to_name.format(l) - else: - to_name = weight_map[name] - param = load_param(param, name, dtype, verbose=debug_mode) + for from_name, param in hf_weights.items(): + name_template, *ids = layer_template(from_name, num_matches=2) + to_name = weight_map[name_template] + param = load_param(param, from_name, dtype, 
verbose=debug_mode) + if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): + qkv = qkv_weights.setdefault(ids[0], defaultdict(dict)) + weight_name, weight_type = from_name.split(".")[-2:] + qkv[weight_type][weight_name] = param + if to_name is None: + continue + to_name = to_name.format(*ids) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -467,49 +433,52 @@ def copy_weights_qwen_2_5( for weight_type in list(qkv_weights[i]): qkv = qkv_weights[i][weight_type] if len(qkv) != 3: - # split across different .bin files + # qkv is splitted across different .bin files continue q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode) k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode) v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode) - q_per_kv = config.n_head // config.n_query_groups - qs = torch.split(q, config.head_size * q_per_kv) - ks = torch.split(k, config.head_size) - vs = torch.split(v, config.head_size) - cycled = [t for group in zip(qs, ks, vs) for t in group] - qkv = torch.cat(cycled) - state_dict[f"transformer.h.{i}.attn.attn.{weight_type}"] = qkv + qkv = torch.cat((q, k, v)) + state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv del qkv_weights[i][weight_type] + if progress_per_file is not None: pbar.update(progress_per_file) -def qkv_reassemble(param: Union[torch.Tensor, NotYetLoadedTensor], config: Config) -> torch.Tensor: + +def qkv_reassemble( + param: Union[torch.Tensor, NotYetLoadedTensor], config: Config +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Reassemble from a normal to an interleaved placement in a QKV matrix. - [Q, Q, ..., K, K, ..., V, V, ...] --> [Q, K, V, Q, K, V, ...] + [Q, K, V, Q, K, V, ...] --> [Q, Q, ..., K, K, ..., V, V, ...] """ - q, k, v = param.split( - ( - config.n_head * config.head_size, - config.n_query_groups * config.head_size, - config.n_query_groups * config.head_size, - ) - ) - qs = q.split(config.n_head // config.n_query_groups * config.head_size) - ks = k.split(config.head_size) - vs = v.split(config.head_size) - interleaved = [t for group in zip(qs, ks, vs) for t in group] - return torch.cat(interleaved) - - -def layer_template(layer_name: str, idx: int) -> Tuple[str, int]: - split = layer_name.split(".") - number = int(split[idx]) - split[idx] = "{}" - from_name = ".".join(split) - return from_name, number - - -def load_param(param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype], verbose=False) -> torch.Tensor: + q_per_kv = config.n_head // config.n_query_groups + qs = [] + ks = [] + vs = [] + for chunk in torch.chunk(param, config.n_query_groups): + split = torch.split(chunk, [config.head_size * q_per_kv, config.head_size, config.head_size]) + qs.append(split[0]) + ks.append(split[1]) + vs.append(split[2]) + q = torch.cat(qs) + k = torch.cat(ks) + v = torch.cat(vs) + return torch.cat((q, k, v)) + + + +def layer_template(layer_name: str, num_matches: int = 1) -> Tuple[str, int]: + pattern = r"\.(\d+)\." 
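+    # Illustrative example (not part of the patch): with num_matches=2,
+    # "model.layers.3.block_sparse_moe.experts.7.w1.weight" yields
+    # ("model.layers.{}.block_sparse_moe.experts.{}.w1.weight", 3, 7);
+    # names without a numeric index fall through to the (layer_name, -1) branch below.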
+ if not (search_res := re.findall(pattern, layer_name)): + return layer_name, -1 + layer_name_template = re.sub(pattern, ".{}.", layer_name, count=num_matches) + return layer_name_template, *(int(x) for x in search_res[:num_matches]) + + +def load_param( + param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: Optional[torch.dtype], verbose: bool =False +) -> torch.Tensor: if hasattr(param, "_load_tensor"): # support tensors loaded via `lazy_load()` if verbose: @@ -522,13 +491,14 @@ def load_param(param: Union[torch.Tensor, NotYetLoadedTensor], name: str, dtype: return param + @torch.inference_mode() def convert_hf_checkpoint( checkpoint_dir: Path, *, model_name: Optional[str] = None, dtype: Optional[str] = None, - debug_mode: Optional[bool] = False + debug_mode: Optional[bool] = False, ) -> None: """ Convert a Hugging Face Transformers checkpoint into a LitGPT compatible checkpoint. @@ -554,10 +524,10 @@ def convert_hf_checkpoint( save_config(config, checkpoint_dir) if "falcon" in model_name: - copy_fn = partial(copy_weights_falcon, model_name) + copy_fn = partial(copy_weights_falcon, config) elif model_name.lower().startswith("gemma-2"): qkv_weights = {} - copy_fn = partial(copy_weights_gemma_2, config, qkv_weights) + copy_fn = partial(copy_weights_gemma_2, qkv_weights) elif model_name.lower().startswith("phi"): # holder to reconstitute the split q, k, v qkv_weights = {} @@ -571,7 +541,7 @@ def convert_hf_checkpoint( qkv_weights = {} copy_fn = partial(copy_weights_hf_llama, config, qkv_weights) else: - copy_fn = copy_weights_gpt_neox + copy_fn = partial(copy_weights_gpt_neox, config) # initialize a new empty state dict to hold our new weights sd = {} @@ -604,14 +574,26 @@ def convert_hf_checkpoint( total_size = max(1, sum(os.path.getsize(bin_file) for bin_file in bin_files)) total_progress = 100 - with tqdm(total=total_progress, desc="Initializing", bar_format="{desc}{percentage:3.0f}%|{bar}| {elapsed}<{remaining}, {rate_fmt}") as pbar: + with tqdm( + total=total_progress, + desc="Initializing", + bar_format="{desc}{percentage:3.0f}%|{bar}| {elapsed}<{remaining}, {rate_fmt}", + ) as pbar: for bin_file in sorted(bin_files): pbar.set_description(f"Loading weights: {bin_file.name}") current_file_size = os.path.getsize(bin_file) progress_per_file = (current_file_size / total_size) * total_progress hf_weights = lazy_load(bin_file) - copy_fn(sd, hf_weights, saver=saver, dtype=dtype, pbar=pbar, progress_per_file=progress_per_file, debug_mode=debug_mode) + copy_fn( + sd, + hf_weights, + saver=saver, + dtype=dtype, + pbar=pbar, + progress_per_file=progress_per_file, + debug_mode=debug_mode, + ) gc.collect() if pbar.n < total_progress: diff --git a/litgpt/scripts/convert_lit_checkpoint.py b/litgpt/scripts/convert_lit_checkpoint.py index a994b3022a..f276e3ae31 100644 --- a/litgpt/scripts/convert_lit_checkpoint.py +++ b/litgpt/scripts/convert_lit_checkpoint.py @@ -5,7 +5,7 @@ from functools import partial from pathlib import Path from pprint import pprint -from typing import Dict, Optional, Tuple, Union +from typing import Dict, Optional, Union import torch from lightning.fabric.utilities.load import _NotYetLoadedTensor as NotYetLoadedTensor @@ -16,14 +16,14 @@ def copy_weights_falcon( - model_name: str, + config: Config, state_dict: Dict[str, torch.Tensor], lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, ) -> None: weight_map = { "transformer.wte.weight": "transformer.word_embeddings.weight", - 
"transformer.h.{}.attn.attn.weight": "transformer.h.{}.self_attention.query_key_value.weight", + "transformer.h.{}.attn.qkv.weight": "transformer.h.{}.self_attention.query_key_value.weight", "transformer.h.{}.attn.proj.weight": "transformer.h.{}.self_attention.dense.weight", "transformer.h.{}.mlp.fc.weight": "transformer.h.{}.mlp.dense_h_to_4h.weight", "transformer.h.{}.mlp.proj.weight": "transformer.h.{}.mlp.dense_4h_to_h.weight", @@ -32,14 +32,14 @@ def copy_weights_falcon( "lm_head.weight": "lm_head.weight", } # the original model definition is different for each size - if "7b" in model_name: + if "7b" in config.name: weight_map.update( { "transformer.h.{}.norm_1.bias": "transformer.h.{}.input_layernorm.bias", "transformer.h.{}.norm_1.weight": "transformer.h.{}.input_layernorm.weight", } ) - elif "40b" in model_name or "180B" in model_name: + elif "40b" in config.name or "180B" in config.name: weight_map.update( { "transformer.h.{}.norm_1.bias": "transformer.h.{}.ln_attn.bias", @@ -51,19 +51,20 @@ def copy_weights_falcon( else: raise NotImplementedError - for name, param in lit_weights.items(): - if "transformer.h" in name: - from_name, number = layer_template(name, 2) - to_name = weight_map[from_name].format(number) - else: - to_name = weight_map[name] - param = load_param(param, name, None) + for from_name, param in lit_weights.items(): + name_template, layer_idx = layer_template(from_name) + to_name = weight_map[name_template].format(layer_idx) + param = load_param(param, from_name, None) + if from_name.endswith((".attn.qkv.weight", ".attn.qkv.bias")): + # Reassemble [q, q, ..., k, k, ..., v, v, ...] --> [q, k, v, q, k, v, ...] + param = qkv_reassemble(param, config) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param def copy_weights_gpt_neox( + config: Config, state_dict: Dict[str, torch.Tensor], lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], saver: Optional[incremental_save] = None, @@ -72,8 +73,8 @@ def copy_weights_gpt_neox( "transformer.wte.weight": "gpt_neox.embed_in.weight", "transformer.h.{}.norm_1.bias": "gpt_neox.layers.{}.input_layernorm.bias", "transformer.h.{}.norm_1.weight": "gpt_neox.layers.{}.input_layernorm.weight", - "transformer.h.{}.attn.attn.bias": "gpt_neox.layers.{}.attention.query_key_value.bias", - "transformer.h.{}.attn.attn.weight": "gpt_neox.layers.{}.attention.query_key_value.weight", + "transformer.h.{}.attn.qkv.bias": "gpt_neox.layers.{}.attention.query_key_value.bias", + "transformer.h.{}.attn.qkv.weight": "gpt_neox.layers.{}.attention.query_key_value.weight", "transformer.h.{}.attn.proj.bias": "gpt_neox.layers.{}.attention.dense.bias", "transformer.h.{}.attn.proj.weight": "gpt_neox.layers.{}.attention.dense.weight", "transformer.h.{}.norm_2.bias": "gpt_neox.layers.{}.post_attention_layernorm.bias", @@ -87,13 +88,13 @@ def copy_weights_gpt_neox( "lm_head.weight": "embed_out.weight", } - for name, param in lit_weights.items(): - if "transformer.h" in name: - from_name, number = layer_template(name, 2) - to_name = weight_map[from_name].format(number) - else: - to_name = weight_map[name] - param = load_param(param, name, None) + for from_name, param in lit_weights.items(): + name_template, layer_idx = layer_template(from_name) + to_name = weight_map[name_template].format(layer_idx) + param = load_param(param, from_name, None) + if from_name.endswith((".attn.qkv.weight", ".attn.qkv.bias")): + # Reassemble [q, q, ..., k, k, ..., v, v, ...] --> [q, k, v, q, k, v, ...] 
+ param = qkv_reassemble(param, config) if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -108,11 +109,11 @@ def copy_weights_llama( ) -> None: weight_map = { "transformer.wte.weight": "model.embed_tokens.weight", - "transformer.h.{}.norm_1.weight": "model.layers.{l}.input_layernorm.weight", - "transformer.h.{}.norm_1.bias": "model.layers.{l}.input_layernorm.bias", - "transformer.h.{}.attn.proj.weight": "model.layers.{l}.self_attn.o_proj.weight", - "transformer.h.{}.norm_2.weight": "model.layers.{l}.post_attention_layernorm.weight", - "transformer.h.{}.norm_2.bias": "model.layers.{l}.post_attention_layernorm.bias", + "transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight", + "transformer.h.{}.norm_1.bias": "model.layers.{}.input_layernorm.bias", + "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight", + "transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight", + "transformer.h.{}.norm_2.bias": "model.layers.{}.post_attention_layernorm.bias", "transformer.ln_f.weight": "model.norm.weight", "transformer.ln_f.bias": "model.norm.bias", "lm_head.weight": "lm_head.weight", @@ -120,48 +121,46 @@ def copy_weights_llama( if config.mlp_class_name == "LLaMAMoE": weight_map.update( { - "transformer.h.{}.mlp.gate.weight": "model.layers.{l}.block_sparse_moe.gate.weight", - "transformer.h.{}.mlp.experts.{}.fc_1.weight": "model.layers.{l}.block_sparse_moe.experts.{e}.w1.weight", - "transformer.h.{}.mlp.experts.{}.fc_2.weight": "model.layers.{l}.block_sparse_moe.experts.{e}.w3.weight", - "transformer.h.{}.mlp.experts.{}.proj.weight": "model.layers.{l}.block_sparse_moe.experts.{e}.w2.weight", + "transformer.h.{}.mlp.gate.weight": "model.layers.{}.block_sparse_moe.gate.weight", + "transformer.h.{}.mlp.experts.{}.fc_1.weight": "model.layers.{}.block_sparse_moe.experts.{}.w1.weight", + "transformer.h.{}.mlp.experts.{}.fc_2.weight": "model.layers.{}.block_sparse_moe.experts.{}.w3.weight", + "transformer.h.{}.mlp.experts.{}.proj.weight": "model.layers.{}.block_sparse_moe.experts.{}.w2.weight", } ) elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP"): weight_map.update( { - "transformer.h.{}.mlp.fc_1.weight": "model.layers.{l}.mlp.gate_proj.weight", - "transformer.h.{}.mlp.fc_2.weight": "model.layers.{l}.mlp.up_proj.weight", - "transformer.h.{}.mlp.proj.weight": "model.layers.{l}.mlp.down_proj.weight", + "transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight", + "transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight", + "transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight", } ) else: raise NotImplementedError - for name, param in lit_weights.items(): - if name == "lm_head.weight" and untie_weights: + for from_name, param in lit_weights.items(): + if from_name == "lm_head.weight" and untie_weights: continue - if name.endswith(".attn.attn.weight"): - from_name, l = layer_template(name, 2) - q = "model.layers.{}.self_attn.q_proj.weight".format(l) - k = "model.layers.{}.self_attn.k_proj.weight".format(l) - v = "model.layers.{}.self_attn.v_proj.weight".format(l) - qkv = load_param(param, name, None) - qp, kp, vp = qkv_split(qkv, config) - for to_name, param in zip((q, k, v), (qp, kp, vp)): - if saver is not None: - param = saver.store_early(param) - state_dict[to_name] = param + name_template, *ids = layer_template(from_name, num_matches=2) + param = load_param(param, from_name, None) + if from_name.endswith(".attn.qkv.weight"): + to_names = ( + 
"model.layers.{}.self_attn.q_proj.weight".format(*ids), + "model.layers.{}.self_attn.k_proj.weight".format(*ids), + "model.layers.{}.self_attn.v_proj.weight".format(*ids), + ) + params = param.split( + ( + config.n_head * config.head_size, + config.n_query_groups * config.head_size, + config.n_query_groups * config.head_size, + ) + ) else: - if "transformer.h" in name: - from_name, l = layer_template(name, 2) - e = None - if "mlp.experts" in name: - from_name, e = layer_template(from_name, 5) - to_name = weight_map[from_name] - to_name = to_name.format(l=l, e=e) - else: - to_name = weight_map[name] - param = load_param(param, name, None) + to_names = (weight_map[name_template].format(*ids),) + params = (param,) + + for to_name, param in zip(to_names, params): if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -188,31 +187,29 @@ def copy_weights_gemma_2( "lm_head.weight": "lm_head.weight", } - for name, param in lit_weights.items(): - if name == "lm_head.weight" and untie_weights: + for from_name, param in lit_weights.items(): + if from_name == "lm_head.weight" and untie_weights: continue - if name.endswith(".attn.attn.weight"): - from_name, layer_idx = layer_template(name, 2) - q = "model.layers.{}.self_attn.q_proj.weight".format(layer_idx) - k = "model.layers.{}.self_attn.k_proj.weight".format(layer_idx) - v = "model.layers.{}.self_attn.v_proj.weight".format(layer_idx) - qkv = load_param(param, name, None) - qp, kp, vp = qkv_split(qkv, config) - for to_name, param in zip((q, k, v), (qp, kp, vp)): - if saver is not None: - param = saver.store_early(param) - state_dict[to_name] = param + name_template, *ids = layer_template(from_name, num_matches=2) + param = load_param(param, from_name, None) + if from_name.endswith(".attn.qkv.weight"): + to_names = ( + "model.layers.{}.self_attn.q_proj.weight".format(*ids), + "model.layers.{}.self_attn.k_proj.weight".format(*ids), + "model.layers.{}.self_attn.v_proj.weight".format(*ids), + ) + params = param.split( + ( + config.n_head * config.head_size, + config.n_query_groups * config.head_size, + config.n_query_groups * config.head_size, + ) + ) else: - if "transformer.h" in name: - from_name, layer_idx = layer_template(name, 2) - e = None - if "mlp.experts" in name: - from_name, e = layer_template(from_name, 5) - to_name = weight_map[from_name] - to_name = to_name.format(layer_idx) - else: - to_name = weight_map[name] - param = load_param(param, name, None) + to_names = (weight_map[name_template].format(*ids),) + params = (param,) + + for to_name, param in zip(to_names, params): if saver is not None: param = saver.store_early(param) state_dict[to_name] = param @@ -239,11 +236,10 @@ def copy_weights_phi( "lm_head.weight": "lm_head.weight", "lm_head.bias": "lm_head.bias", } - if config.name.startswith("Phi-3"): weight_map.update( { - "transformer.h.{}.attn.attn.weight": "model.layers.{}.self_attn.qkv_proj.weight", + "transformer.h.{}.attn.qkv.weight": "model.layers.{}.self_attn.qkv_proj.weight", "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight", "transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight", "transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight", @@ -252,51 +248,48 @@ def copy_weights_phi( ) gate_up_proj_weights = defaultdict(dict) - for name, param in lit_weights.items(): - if name.endswith((".attn.attn.weight", ".attn.attn.bias")): - from_name, l_idx = layer_template(name, 2) - qkv = load_param(param, name, None) - qp, kp, vp = 
qkv_split(qkv, config) + for from_name, param in lit_weights.items(): + name_template, layer_idx = layer_template(from_name) + param = load_param(param, from_name, None) + if from_name.endswith((".attn.qkv.weight", ".attn.qkv.bias")): if config.name.startswith("Phi-3"): - qkv_reassembled = torch.concat([qp, kp, vp], dim=0) - to_name = weight_map[from_name].format(l_idx) - if saver is not None: - qkv_reassembled = saver.store_early(qkv_reassembled) - state_dict[to_name] = qkv_reassembled + to_names = (weight_map[name_template].format(layer_idx),) + params = (param,) else: - weight_type = name.split(".")[-1] # weight or bias - q = f"model.layers.{l_idx}.self_attn.q_proj.{weight_type}" - k = f"model.layers.{l_idx}.self_attn.k_proj.{weight_type}" - v = f"model.layers.{l_idx}.self_attn.v_proj.{weight_type}" - for to_name, param in zip((q, k, v), (qp, kp, vp)): - if saver is not None: - param = saver.store_early(param) - state_dict[to_name] = param - elif name.endswith((".fc_1.weight", ".fc_2.weight")): - from_name, l_idx = layer_template(name, 2) - weight = load_param(param, name, None) - weight_name = name.split(".")[-2] - gate_up_proj_weights[l_idx][weight_name] = weight + weight_type = from_name.split(".")[-1] # weight or bias + to_names = ( + f"model.layers.{{}}.self_attn.q_proj.{weight_type}".format(layer_idx), + f"model.layers.{{}}.self_attn.k_proj.{weight_type}".format(layer_idx), + f"model.layers.{{}}.self_attn.v_proj.{weight_type}".format(layer_idx), + ) + params = param.split( + ( + config.n_head * config.head_size, + config.n_query_groups * config.head_size, + config.n_query_groups * config.head_size, + ) + ) + elif from_name.endswith((".fc_1.weight", ".fc_2.weight")): + weight = load_param(param, from_name, None) + weight_name = from_name.split(".")[-2] + gate_up_proj_weights[layer_idx][weight_name] = weight else: - if "transformer.h" in name: - from_name, l_idx = layer_template(name, 2) - to_name = weight_map[from_name] - to_name = to_name.format(l_idx) - else: - to_name = weight_map[name] - param = load_param(param, name, None) + to_names = (weight_map[name_template].format(layer_idx),) + params = (param,) + + for to_name, param in zip(to_names, params): if saver is not None: param = saver.store_early(param) state_dict[to_name] = param if config.name.startswith("Phi-3"): - for i in list(gate_up_proj_weights): - fc_1_weight = gate_up_proj_weights[i]["fc_1"] - fc_2_weight = gate_up_proj_weights[i]["fc_2"] + for layer_idx in list(gate_up_proj_weights): + fc_1_weight = gate_up_proj_weights[layer_idx]["fc_1"] + fc_2_weight = gate_up_proj_weights[layer_idx]["fc_2"] weight = torch.concat([fc_1_weight, fc_2_weight], dim=0) - layer_name = f"model.layers.{i}.mlp.gate_up_proj.weight" + layer_name = f"model.layers.{layer_idx}.mlp.gate_up_proj.weight" state_dict[layer_name] = weight - del gate_up_proj_weights[i] + del gate_up_proj_weights[layer_idx] def copy_weights_qwen_2_5( config: Config, @@ -317,50 +310,51 @@ def copy_weights_qwen_2_5( "lm_head.weight": "lm_head.weight", } - for name, param in lit_weights.items(): - if name == "lm_head.weight" and untie_weights: + for from_name, param in lit_weights.items(): + if from_name == "lm_head.weight" and untie_weights: continue - if name.endswith((".attn.attn.weight", ".attn.attn.bias")): - from_name, l_idx = layer_template(name, 2) - qkv = load_param(param, name, None) - qp, kp, vp = qkv_split(qkv, config) - - weight_type = name.split(".")[-1] # weight or bias - q = f"model.layers.{l_idx}.self_attn.q_proj.{weight_type}" - k = 
f"model.layers.{l_idx}.self_attn.k_proj.{weight_type}" - v = f"model.layers.{l_idx}.self_attn.v_proj.{weight_type}" - for to_name, param in zip((q, k, v), (qp, kp, vp)): - if saver is not None: - param = saver.store_early(param) - state_dict[to_name] = param + name_template, *ids = layer_template(from_name, num_matches=2) + param = load_param(param, from_name, None) + if from_name.endswith((".attn.qkv.weight", ".attn.qkv.bias")): + weight_type = from_name.split(".")[-1] # weight or bias + to_names = ( + "model.layers.{}.self_attn.q_proj.{}".format(*ids, weight_type), + "model.layers.{}.self_attn.k_proj.{}".format(*ids, weight_type), + "model.layers.{}.self_attn.v_proj.{}".format(*ids, weight_type), + ) + params = param.split( + ( + config.n_head * config.head_size, + config.n_query_groups * config.head_size, + config.n_query_groups * config.head_size, + ) + ) else: - if "transformer.h" in name: - from_name, l_idx = layer_template(name, 2) - to_name = weight_map[from_name] - to_name = to_name.format(l_idx) - else: - to_name = weight_map[name] - param = load_param(param, name, None) + to_names = (weight_map[name_template].format(*ids),) + params = (param,) + + for to_name, param in zip(to_names, params): if saver is not None: param = saver.store_early(param) state_dict[to_name] = param -def qkv_split( - param: Union[torch.Tensor, NotYetLoadedTensor], config: Config -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - q_per_kv = config.n_head // config.n_query_groups - qs = [] - ks = [] - vs = [] - for chunk in torch.chunk(param, config.n_query_groups): - split = torch.split(chunk, [config.head_size * q_per_kv, config.head_size, config.head_size]) - qs.append(split[0]) - ks.append(split[1]) - vs.append(split[2]) - q = torch.cat(qs) - k = torch.cat(ks) - v = torch.cat(vs) - return q, k, v + +def qkv_reassemble(param: Union[torch.Tensor, NotYetLoadedTensor], config: Config) -> torch.Tensor: + """Reassemble from a normal to an interleaved placement in a QKV matrix. + [Q, Q, ..., K, K, ..., V, V, ...] --> [Q, K, V, Q, K, V, ...] + """ + q, k, v = param.split( + ( + config.n_head * config.head_size, + config.n_query_groups * config.head_size, + config.n_query_groups * config.head_size, + ) + ) + qs = q.split(config.n_head // config.n_query_groups * config.head_size) + ks = k.split(config.head_size) + vs = v.split(config.head_size) + interleaved = [t for group in zip(qs, ks, vs) for t in group] + return torch.cat(interleaved) def check_conversion_supported(lit_weights: Dict[str, torch.Tensor]) -> None: @@ -382,7 +376,7 @@ def convert_lit_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None: output_path = output_dir / "model.pth" if "falcon" in config.name: - copy_fn = partial(copy_weights_falcon, config.name) + copy_fn = partial(copy_weights_falcon, config) elif config.name.startswith("Gemma-2"): copy_fn = partial(copy_weights_gemma_2, config) elif config.name.lower().startswith("phi"): @@ -393,7 +387,7 @@ def convert_lit_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None: untie_weights = "Gemma" in config.name copy_fn = partial(copy_weights_llama, config, untie_weights=untie_weights) else: - copy_fn = copy_weights_gpt_neox + copy_fn = partial(copy_weights_gpt_neox, config) # initialize a new empty state dict to hold our new weights sd = {} diff --git a/tests/test_adapter.py b/tests/test_adapter.py index da422f6288..9deb7be1f7 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -1,6 +1,7 @@ # Copyright Lightning AI. 
Licensed under the Apache License 2.0, see LICENSE file. import os from contextlib import redirect_stdout +from copy import deepcopy from dataclasses import asdict from io import StringIO from unittest import mock @@ -19,10 +20,11 @@ import litgpt.adapter as gpt_adapter import litgpt.finetune.adapter as module import litgpt.model as gpt -from litgpt.adapter import GPT, Config, adapter_filter +from litgpt.adapter import GPT, CausalSelfAttention, Config, adapter_filter from litgpt.args import EvalArgs, TrainArgs from litgpt.data import Alpaca from litgpt.scripts.convert_hf_checkpoint import copy_weights_gemma_2, copy_weights_hf_llama +from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved from tests.conftest import RunIf @@ -192,7 +194,7 @@ def test_adapter_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca "transformer.h.0.norm_1.weight", "transformer.h.0.norm_1.bias", "transformer.h.0.attn.gating_factor", - "transformer.h.0.attn.attn.bias", + "transformer.h.0.attn.qkv.bias", "transformer.h.0.attn.proj.bias", "transformer.h.0.attn.adapter_wte.weight", "transformer.h.0.norm_2.weight", @@ -202,7 +204,7 @@ def test_adapter_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca "transformer.h.1.norm_1.weight", "transformer.h.1.norm_1.bias", "transformer.h.1.attn.gating_factor", - "transformer.h.1.attn.attn.bias", + "transformer.h.1.attn.qkv.bias", "transformer.h.1.attn.proj.bias", "transformer.h.1.attn.adapter_wte.weight", "transformer.h.1.norm_2.weight", @@ -214,11 +216,11 @@ def test_adapter_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca }, "torch.uint8": { "lm_head.weight", - "transformer.h.0.attn.attn.weight", + "transformer.h.0.attn.qkv.weight", "transformer.h.0.attn.proj.weight", "transformer.h.0.mlp.fc.weight", "transformer.h.0.mlp.proj.weight", - "transformer.h.1.attn.attn.weight", + "transformer.h.1.attn.qkv.weight", "transformer.h.1.attn.proj.weight", "transformer.h.1.mlp.fc.weight", "transformer.h.1.mlp.proj.weight", @@ -345,7 +347,7 @@ def test_against_original_gemma_2(model_name, device, dtype): # Gemma weights are shipped without `lm_head.weight` theirs_state_dict.pop("lm_head.weight") state_dict = {} - copy_weights_gemma_2(ours_config, {}, state_dict, theirs_state_dict) + copy_weights_gemma_2({}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) @@ -355,3 +357,25 @@ def test_against_original_gemma_2(model_name, device, dtype): ours_y = ours_model(x) theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) + + +def test_load_legacy_state_dict(): + """Check that a legacy state dict (with an interleaved placement in QKV matrix) can be loaded into a model with CausalSelfAttention layers.""" + config = Config( + n_embd=32, + n_head=4, + head_size=8, + n_query_groups=4, + bias=True, + ) + + attention_1 = CausalSelfAttention(config=config, block_idx=0) + + # make weights to be as-like in a legacy checkpoint, with `attn.attn.weight` instead of `attn.qkv.weight` + # and make them interleaved + state_dict = deepcopy(attention_1.state_dict()) + state_dict["attn.weight"] = make_qkv_interleaved(state_dict.pop("qkv.weight"), config) + state_dict["attn.bias"] = make_qkv_interleaved(state_dict.pop("qkv.bias"), config) + + attention_2 = CausalSelfAttention(config=config, block_idx=0) + attention_2.load_state_dict(state_dict) diff --git a/tests/test_adapter_v2.py b/tests/test_adapter_v2.py index 
aec205155d..ca00a5d641 100644 --- a/tests/test_adapter_v2.py +++ b/tests/test_adapter_v2.py @@ -1,6 +1,7 @@ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. import os from contextlib import redirect_stdout +from copy import deepcopy from io import StringIO from unittest import mock from unittest.mock import Mock @@ -19,11 +20,12 @@ import litgpt.config as config_module import litgpt.finetune.adapter_v2 as module from litgpt.adapter_v2 import GPT as AdapterV2GPT -from litgpt.adapter_v2 import Config, adapter_filter +from litgpt.adapter_v2 import CausalSelfAttention, Config, adapter_filter from litgpt.args import EvalArgs, TrainArgs from litgpt.data import Alpaca from litgpt.model import GPT as BaseGPT from litgpt.scripts.convert_hf_checkpoint import copy_weights_gemma_2, copy_weights_hf_llama +from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved from tests.conftest import RunIf @@ -33,10 +35,10 @@ def test_config_identical(): base_model = BaseGPT.from_name(name) adapter_model = AdapterV2GPT.from_name(name) - assert not hasattr(base_model.transformer.h[2].attn.attn, "adapter_bias") - assert not hasattr(base_model.transformer.h[2].attn.attn, "adapter_scale") - assert hasattr(adapter_model.transformer.h[2].attn.attn, "adapter_bias") - assert hasattr(adapter_model.transformer.h[2].attn.attn, "adapter_scale") + assert not hasattr(base_model.transformer.h[2].attn.qkv, "adapter_bias") + assert not hasattr(base_model.transformer.h[2].attn.qkv, "adapter_scale") + assert hasattr(adapter_model.transformer.h[2].attn.qkv, "adapter_bias") + assert hasattr(adapter_model.transformer.h[2].attn.qkv, "adapter_scale") def test_adapter_v2_filter(tmp_path): @@ -56,8 +58,8 @@ def test_adapter_v2_filter(tmp_path): } for layer in range(3): for param in ( - "attn.attn.adapter_bias", - "attn.attn.adapter_scale", + "attn.qkv.adapter_bias", + "attn.qkv.adapter_scale", "attn.proj.adapter_bias", "attn.proj.adapter_scale", "mlp.fc.adapter_bias", @@ -297,7 +299,7 @@ def test_against_original_gemma_2(model_name): # Gemma weights are shipped without `lm_head.weight` theirs_state_dict.pop("lm_head.weight") state_dict = {} - copy_weights_gemma_2(ours_config, {}, state_dict, theirs_state_dict) + copy_weights_gemma_2({}, state_dict, theirs_state_dict) ours_model = AdapterV2GPT(ours_config).to(device) keys = ours_model.load_state_dict(state_dict, strict=False) assert not keys.unexpected_keys @@ -364,27 +366,27 @@ def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alp "torch.uint8": { "transformer.h.0.mlp.fc.linear.weight", "transformer.h.1.mlp.proj.linear.weight", - "transformer.h.1.attn.attn.linear.weight", + "transformer.h.1.attn.qkv.linear.weight", "transformer.h.0.attn.proj.linear.weight", "lm_head.linear.weight", "transformer.h.1.attn.proj.linear.weight", "transformer.h.0.mlp.proj.linear.weight", - "transformer.h.0.attn.attn.linear.weight", + "transformer.h.0.attn.qkv.linear.weight", "transformer.h.1.mlp.fc.linear.weight", }, "torch.float16": { - "transformer.h.1.attn.attn.adapter_bias", + "transformer.h.1.attn.qkv.adapter_bias", "transformer.h.1.mlp.proj.adapter_bias", - "transformer.h.0.attn.attn.adapter_bias", + "transformer.h.0.attn.qkv.adapter_bias", "transformer.h.0.norm_1.bias", - "transformer.h.0.attn.attn.linear.bias", + "transformer.h.0.attn.qkv.linear.bias", "transformer.h.1.attn.adapter_wte.weight", "transformer.ln_f.weight", "transformer.h.0.mlp.fc.linear.bias", "transformer.h.0.mlp.proj.linear.bias", 
"transformer.h.1.mlp.fc.linear.bias", "transformer.h.0.attn.proj.adapter_scale", - "transformer.h.0.attn.attn.adapter_scale", + "transformer.h.0.attn.qkv.adapter_scale", "transformer.h.1.norm_2.bias", "transformer.h.1.attn.proj.adapter_scale", "transformer.h.0.norm_2.bias", @@ -406,9 +408,9 @@ def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alp "lm_head.adapter_bias", "transformer.h.1.norm_2.weight", "transformer.h.0.attn.adapter_wte.weight", - "transformer.h.1.attn.attn.adapter_scale", + "transformer.h.1.attn.qkv.adapter_scale", "transformer.h.1.mlp.fc.adapter_scale", - "transformer.h.1.attn.attn.linear.bias", + "transformer.h.1.attn.qkv.linear.bias", "transformer.wte.weight", "transformer.wte.norm.weight", "transformer.wte.norm.bias", @@ -435,20 +437,20 @@ def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alp "transformer.ln_f.bias", "lm_head.adapter_scale", "transformer.h.1.norm_2.weight", - "transformer.h.0.attn.attn.adapter_scale", + "transformer.h.0.attn.qkv.adapter_scale", "transformer.h.0.mlp.proj.adapter_bias", "transformer.h.0.attn.gating_factor", "transformer.h.1.norm_1.bias", "transformer.h.1.mlp.fc.adapter_bias", "transformer.h.1.mlp.proj.adapter_scale", "transformer.h.0.mlp.fc.adapter_scale", - "transformer.h.1.attn.attn.adapter_bias", + "transformer.h.1.attn.qkv.adapter_bias", "transformer.h.0.norm_2.weight", "transformer.h.1.norm_2.bias", "transformer.h.0.norm_1.weight", "transformer.h.0.attn.proj.adapter_scale", "transformer.h.1.mlp.proj.adapter_bias", - "transformer.h.0.attn.attn.adapter_bias", + "transformer.h.0.attn.qkv.adapter_bias", "transformer.h.0.attn.adapter_wte.weight", "transformer.ln_f.weight", "transformer.h.1.attn.gating_factor", @@ -458,10 +460,31 @@ def test_adapter_v2_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alp "transformer.h.0.norm_1.bias", "transformer.h.0.norm_2.bias", "transformer.h.1.norm_1.weight", - "transformer.h.1.attn.attn.adapter_scale", + "transformer.h.1.attn.qkv.adapter_scale", } } logs = stdout.getvalue() assert "of trainable parameters: 552" in logs assert "of non-trainable parameters: 1,808" in logs + +def test_load_legacy_state_dict(): + """Check that a legacy state dict (with an interleaved placement in QKV matrix) can be loaded into a model with CausalSelfAttention layers.""" + config = Config( + n_embd=32, + n_head=4, + head_size=8, + n_query_groups=4, + bias=True, + ) + + attention_1 = CausalSelfAttention(config=config, block_idx=0) + + # make weights to be as-like in a legacy checkpoint, with `attn.attn.weight` instead of `attn.qkv.weight` + # and make them interleaved + state_dict = deepcopy(attention_1.state_dict()) + state_dict["attn.linear.weight"] = make_qkv_interleaved(state_dict.pop("qkv.linear.weight"), config) + state_dict["attn.linear.bias"] = make_qkv_interleaved(state_dict.pop("qkv.linear.bias"), config) + + attention_2 = CausalSelfAttention(config=config, block_idx=0) + attention_2.load_state_dict(state_dict) diff --git a/tests/test_convert_hf_checkpoint.py b/tests/test_convert_hf_checkpoint.py index 08749e521d..38b41f711d 100644 --- a/tests/test_convert_hf_checkpoint.py +++ b/tests/test_convert_hf_checkpoint.py @@ -6,7 +6,7 @@ import torch from litgpt import Config -from litgpt.scripts.convert_hf_checkpoint import convert_hf_checkpoint, copy_weights_hf_llama +from litgpt.scripts.convert_hf_checkpoint import convert_hf_checkpoint, copy_weights_hf_llama, qkv_reassemble def test_llama2_70b_conversion(): @@ -17,10 +17,10 @@ def 
test_llama2_70b_conversion(): "model.layers.0.mlp.gate_proj.weight": (28672, 8192), "model.layers.0.mlp.up_proj.weight": (28672, 8192), "model.layers.0.post_attention_layernorm.weight": (8192,), - "model.layers.0.self_attn.k_proj.weight": (1024, 8192), - "model.layers.0.self_attn.o_proj.weight": (8192, 8192), "model.layers.0.self_attn.q_proj.weight": (8192, 8192), + "model.layers.0.self_attn.k_proj.weight": (1024, 8192), "model.layers.0.self_attn.v_proj.weight": (1024, 8192), + "model.layers.0.self_attn.o_proj.weight": (8192, 8192), "model.layers.1.input_layernorm.weight": (8192,), "model.layers.1.mlp.down_proj.weight": (8192, 28672), "model.layers.1.mlp.gate_proj.weight": (28672, 8192), @@ -56,14 +56,14 @@ def test_llama2_70b_conversion(): weight_map = {k: torch.empty(s) for k, s in shapes.items()} copy_weights_hf_llama(config, qkv_weights, holder, weight_map) - # we are only testing 5 layers - assert len(qkv_weights) == 5 + # NOTE: there are 5 layers, but only in the first layer we have `q`, `k` and `v` + assert len(qkv_weights) == 1 # there are no loaded qkv weights assert all(v is None for qkv in qkv_weights.values() for v in qkv) # the shapes are correct holder = {k: tuple(t.shape) for k, t in holder.items()} assert holder == { - "transformer.h.0.attn.attn.weight": (10240, 8192), + "transformer.h.0.attn.qkv.weight": (10240, 8192), "transformer.h.0.attn.proj.weight": (8192, 8192), "transformer.h.0.mlp.fc_1.weight": (28672, 8192), "transformer.h.0.mlp.fc_2.weight": (28672, 8192), @@ -101,14 +101,18 @@ def test_llama2_70b_conversion(): } -def test_convert_hf_checkpoint(tmp_path): +@pytest.mark.parametrize("model_name", ("pythia-14m", "falcon-7b", "Llama-2-7b-hf", "phi-2")) +def test_convert_hf_checkpoint(tmp_path, model_name): with pytest.raises(ValueError, match="to contain .bin"): - convert_hf_checkpoint(checkpoint_dir=tmp_path, model_name="pythia-14m") + convert_hf_checkpoint(checkpoint_dir=tmp_path, model_name=model_name) bin_file = tmp_path / "foo.bin" bin_file.touch() with mock.patch("litgpt.scripts.convert_hf_checkpoint.lazy_load") as load: - convert_hf_checkpoint(checkpoint_dir=tmp_path, model_name="pythia-14m") + # bypass if-statement for weight tying + if model_name == "Llama-2-7b-hf": + load.return_value = {"model.embed_tokens.weight": torch.rand((10, 10))} + convert_hf_checkpoint(checkpoint_dir=tmp_path, model_name=model_name) load.assert_called_with(bin_file) assert {p.name for p in tmp_path.glob("*")} == {"foo.bin", "model_config.yaml", "lit_model.pth"} @@ -119,43 +123,40 @@ def test_convert_hf_checkpoint(tmp_path): def test_qkv_reassemble(): - from litgpt import Config - from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble - # MHA config = Config(n_embd=4, n_head=4) - qkv = torch.tensor( + qkv_interleaved = torch.tensor( [ [0, 1, 2, 3], # query - [4, 5, 6, 7], # query - [8, 9, 10, 11], # query - [12, 13, 14, 15], # query [16, 17, 18, 19], # key - [20, 21, 22, 23], # key - [24, 25, 26, 27], # key - [28, 29, 30, 31], # key [32, 33, 34, 35], # value + [4, 5, 6, 7], # query + [20, 21, 22, 23], # key [36, 37, 38, 39], # value + [8, 9, 10, 11], # query + [24, 25, 26, 27], # key [40, 41, 42, 43], # value + [12, 13, 14, 15], # query + [28, 29, 30, 31], # key [44, 45, 46, 47], # value ] ) - qkv_interleaved = qkv_reassemble(qkv, config) + qkv = qkv_reassemble(qkv_interleaved, config) torch.testing.assert_close( - qkv_interleaved, + qkv, torch.tensor( [ [0, 1, 2, 3], # query - [16, 17, 18, 19], # key - [32, 33, 34, 35], # value [4, 5, 6, 7], # query - [20, 21, 22, 23], # 
key - [36, 37, 38, 39], # value [8, 9, 10, 11], # query - [24, 25, 26, 27], # key - [40, 41, 42, 43], # value [12, 13, 14, 15], # query + [16, 17, 18, 19], # key + [20, 21, 22, 23], # key + [24, 25, 26, 27], # key [28, 29, 30, 31], # key + [32, 33, 34, 35], # value + [36, 37, 38, 39], # value + [40, 41, 42, 43], # value [44, 45, 46, 47], # value ] ), @@ -163,30 +164,30 @@ def test_qkv_reassemble(): # GQA config = Config(n_embd=4, n_head=4, n_query_groups=2) - qkv = torch.tensor( + qkv_interleaved = torch.tensor( [ [0, 1, 2, 3], # query [4, 5, 6, 7], # query + [16, 17, 18, 19], # key + [24, 25, 26, 27], # value [8, 9, 10, 11], # query [12, 13, 14, 15], # query - [16, 17, 18, 19], # key [20, 21, 22, 23], # key - [24, 25, 26, 27], # value [28, 29, 30, 31], # value ] ) - qkv_interleaved = qkv_reassemble(qkv, config) + qkv = qkv_reassemble(qkv_interleaved, config) torch.testing.assert_close( - qkv_interleaved, + qkv, torch.tensor( [ [0, 1, 2, 3], # query [4, 5, 6, 7], # query - [16, 17, 18, 19], # key - [24, 25, 26, 27], # value [8, 9, 10, 11], # query [12, 13, 14, 15], # query + [16, 17, 18, 19], # key [20, 21, 22, 23], # key + [24, 25, 26, 27], # value [28, 29, 30, 31], # value ] ), @@ -194,7 +195,7 @@ def test_qkv_reassemble(): # MQA config = Config(n_embd=4, n_head=4, n_query_groups=1) - qkv = torch.tensor( + qkv_interleaved = torch.tensor( [ [0, 1, 2, 3], # query [4, 5, 6, 7], # query @@ -204,9 +205,9 @@ def test_qkv_reassemble(): [20, 21, 22, 23], # value ] ) - qkv_interleaved = qkv_reassemble(qkv, config) + qkv = qkv_reassemble(qkv_interleaved, config) torch.testing.assert_close( - qkv_interleaved, + qkv, torch.tensor( [ [0, 1, 2, 3], # query diff --git a/tests/test_convert_lit_checkpoint.py b/tests/test_convert_lit_checkpoint.py index 5809f0063d..9e0cd93c35 100644 --- a/tests/test_convert_lit_checkpoint.py +++ b/tests/test_convert_lit_checkpoint.py @@ -15,6 +15,10 @@ from transformers.models.llama import LlamaConfig, LlamaForCausalLM from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM from transformers.models.olmo import OlmoConfig, OlmoForCausalLM +from transformers.models.phi.configuration_phi import PhiConfig +from transformers.models.phi.modeling_phi import PhiForCausalLM +from transformers.models.phi3.configuration_phi3 import Phi3Config +from transformers.models.phi3.modeling_phi3 import Phi3ForCausalLM from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM from litgpt import GPT, Config @@ -27,13 +31,14 @@ copy_weights_llama, copy_weights_phi, copy_weights_qwen_2_5, - qkv_split, + qkv_reassemble, ) from tests.conftest import RunIf -def test_convert_lit_checkpoint(tmp_path): - ours_config = Config.from_name("Llama-2-7b-hf", block_size=8, n_layer=2, n_embd=32, n_head=2, padding_multiple=128) +@pytest.mark.parametrize("model_name", ("pythia-14m", "falcon-7b", "Llama-2-7b-hf", "phi-2")) +def test_convert_lit_checkpoint(tmp_path, model_name): + ours_config = Config.from_name(model_name, block_size=8, n_layer=2, n_embd=32, n_head=2, padding_multiple=128) ours_model = GPT(ours_config) checkpoint_path = tmp_path / "lit_model.pth" config_path = tmp_path / "model_config.yaml" @@ -70,7 +75,7 @@ def test_against_falcon_40b(): ours_model = GPT(ours_config) ours_state_dict = ours_model.state_dict() theirs_state_dict = {} - copy_weights_falcon("40b", theirs_state_dict, ours_state_dict) + copy_weights_falcon(ours_config, theirs_state_dict, ours_state_dict) theirs_model = FalconForCausalLM(theirs_config) # assign must be set to True for 
torch.testing.assert_close to pass @@ -105,7 +110,7 @@ def test_against_original_gpt_neox(): ours_model = GPT(ours_config) ours_state_dict = ours_model.state_dict() theirs_state_dict = {} - copy_weights_gpt_neox(theirs_state_dict, ours_state_dict) + copy_weights_gpt_neox(ours_config, theirs_state_dict, ours_state_dict) theirs_model = GPTNeoXForCausalLM(theirs_config) # strict=False because we don't save the rotary embeddings inv frequency keys = theirs_model.load_state_dict(theirs_state_dict, strict=False) @@ -196,6 +201,7 @@ def test_against_mixtral(model_name): theirs_y = theirs_model(x)["logits"] torch.testing.assert_close(ours_y, theirs_y) + @torch.inference_mode() @pytest.mark.parametrize("model_name", ("OLMo-1B-hf", "OLMo-7B-hf")) def test_against_olmo(model_name): @@ -239,6 +245,7 @@ def test_against_olmo(model_name): theirs_y = theirs_model(x)["logits"] torch.testing.assert_close(ours_y, theirs_y) + @torch.inference_mode() def test_against_original_open_llama_3b(): ours_config = Config.from_name("open_llama_3b", n_layer=2, n_head=8, n_embd=32, intermediate_size=86) @@ -270,9 +277,6 @@ def test_against_original_open_llama_3b(): @torch.inference_mode() @pytest.mark.parametrize("model_name", ("phi-1_5", "phi-2")) def test_against_hf_phi(model_name): - from transformers.models.phi.configuration_phi import PhiConfig - from transformers.models.phi.modeling_phi import PhiForCausalLM - ours_config = Config.from_name( model_name, padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256, rotary_percentage=0.5 ) @@ -308,9 +312,6 @@ def test_against_hf_phi(model_name): @torch.inference_mode() @pytest.mark.parametrize("model_name", ("Phi-3-mini-4k-instruct",)) def test_against_hf_phi_3(model_name): - from transformers.models.phi3.configuration_phi3 import Phi3Config - from transformers.models.phi3.modeling_phi3 import Phi3ForCausalLM - ours_config = Config.from_name(model_name, padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256) T = 5 theirs_config = Phi3Config( @@ -425,7 +426,10 @@ def test_against_original_gemma(model_name, device, dtype): theirs_state_dict = {} copy_weights_llama(ours_config, theirs_state_dict, ours_state_dict, untie_weights=True) theirs_model = GemmaForCausalLM(theirs_config).to(device) - theirs_model.load_state_dict(theirs_state_dict, strict=False,) + theirs_model.load_state_dict( + theirs_state_dict, + strict=False, + ) # test end to end x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device) @@ -587,41 +591,41 @@ def test_against_original_qwen_2_5(model_name, device, dtype): theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) -def test_qkv_split(): +def test_qkv_reassemble(): # MHA config = Config(n_embd=4, n_head=4) - qkv_interleaved = torch.tensor( + qkv = torch.tensor( [ [0, 1, 2, 3], # query - [16, 17, 18, 19], # key - [32, 33, 34, 35], # value [4, 5, 6, 7], # query - [20, 21, 22, 23], # key - [36, 37, 38, 39], # value [8, 9, 10, 11], # query - [24, 25, 26, 27], # key - [40, 41, 42, 43], # value [12, 13, 14, 15], # query + [16, 17, 18, 19], # key + [20, 21, 22, 23], # key + [24, 25, 26, 27], # key [28, 29, 30, 31], # key + [32, 33, 34, 35], # value + [36, 37, 38, 39], # value + [40, 41, 42, 43], # value [44, 45, 46, 47], # value ] ) - qkv = torch.cat(qkv_split(qkv_interleaved, config)) + qkv_interleaved = qkv_reassemble(qkv, config) torch.testing.assert_close( - qkv, + qkv_interleaved, torch.tensor( [ [0, 1, 2, 3], # query - [4, 5, 6, 7], # query - [8, 9, 10, 11], # 
query - [12, 13, 14, 15], # query [16, 17, 18, 19], # key - [20, 21, 22, 23], # key - [24, 25, 26, 27], # key - [28, 29, 30, 31], # key [32, 33, 34, 35], # value + [4, 5, 6, 7], # query + [20, 21, 22, 23], # key [36, 37, 38, 39], # value + [8, 9, 10, 11], # query + [24, 25, 26, 27], # key [40, 41, 42, 43], # value + [12, 13, 14, 15], # query + [28, 29, 30, 31], # key [44, 45, 46, 47], # value ] ), @@ -629,30 +633,30 @@ def test_qkv_split(): # GQA config = Config(n_embd=4, n_head=4, n_query_groups=2) - qkv_interleaved = torch.tensor( + qkv = torch.tensor( [ [0, 1, 2, 3], # query [4, 5, 6, 7], # query - [16, 17, 18, 19], # key - [24, 25, 26, 27], # value [8, 9, 10, 11], # query [12, 13, 14, 15], # query + [16, 17, 18, 19], # key [20, 21, 22, 23], # key + [24, 25, 26, 27], # value [28, 29, 30, 31], # value ] ) - qkv = torch.cat(qkv_split(qkv_interleaved, config)) + qkv_interleaved = qkv_reassemble(qkv, config) torch.testing.assert_close( - qkv, + qkv_interleaved, torch.tensor( [ [0, 1, 2, 3], # query [4, 5, 6, 7], # query + [16, 17, 18, 19], # key + [24, 25, 26, 27], # value [8, 9, 10, 11], # query [12, 13, 14, 15], # query - [16, 17, 18, 19], # key [20, 21, 22, 23], # key - [24, 25, 26, 27], # value [28, 29, 30, 31], # value ] ), @@ -660,7 +664,7 @@ def test_qkv_split(): # MQA config = Config(n_embd=4, n_head=4, n_query_groups=1) - qkv_interleaved = torch.tensor( + qkv = torch.tensor( [ [0, 1, 2, 3], # query [4, 5, 6, 7], # query @@ -670,9 +674,9 @@ def test_qkv_split(): [20, 21, 22, 23], # value ] ) - qkv = torch.cat(qkv_split(qkv_interleaved, config)) + qkv_interleaved = qkv_reassemble(qkv, config) torch.testing.assert_close( - qkv, + qkv_interleaved, torch.tensor( [ [0, 1, 2, 3], # query diff --git a/tests/test_generate_sequentially.py b/tests/test_generate_sequentially.py index 51bc9d2fe1..2d7603eb60 100644 --- a/tests/test_generate_sequentially.py +++ b/tests/test_generate_sequentially.py @@ -12,13 +12,13 @@ import pytest import torch import yaml -from tests.conftest import RunIf from lightning import Fabric from litgpt import Config from litgpt.generate.sequentially import layer_to_device, replace_device, sequential from litgpt.model import GPT, Block from litgpt.scripts.download import download_from_hub +from tests.conftest import RunIf @pytest.mark.parametrize( @@ -117,8 +117,8 @@ def _test_model_1device(accelerator): "cos": device_str, "sin": device_str, "lm_head.weight": device_str, - "transformer.h.0.attn.attn.bias": device_str, - "transformer.h.0.attn.attn.weight": device_str, + "transformer.h.0.attn.qkv.bias": device_str, + "transformer.h.0.attn.qkv.weight": device_str, "transformer.h.0.attn.proj.bias": device_str, "transformer.h.0.attn.proj.weight": device_str, "transformer.h.0.mlp.fc.bias": device_str, @@ -131,8 +131,8 @@ def _test_model_1device(accelerator): "transformer.h.0.norm_2.weight": device_str, "transformer.h.0.attn.kv_cache.k": device_str, "transformer.h.0.attn.kv_cache.v": device_str, - "transformer.h.1.attn.attn.bias": device_str, - "transformer.h.1.attn.attn.weight": device_str, + "transformer.h.1.attn.qkv.bias": device_str, + "transformer.h.1.attn.qkv.weight": device_str, "transformer.h.1.attn.proj.bias": device_str, "transformer.h.1.attn.proj.weight": device_str, "transformer.h.1.mlp.fc.bias": device_str, @@ -187,8 +187,8 @@ def test_model_forward_hooks(): "transformer.wte.weight": "cuda:0", "transformer.h.0.norm_1.weight": "cuda:0", "transformer.h.0.norm_1.bias": "cuda:0", - "transformer.h.0.attn.attn.weight": "cuda:0", - "transformer.h.0.attn.attn.bias": 
"cuda:0", + "transformer.h.0.attn.qkv.weight": "cuda:0", + "transformer.h.0.attn.qkv.bias": "cuda:0", "transformer.h.0.attn.proj.weight": "cuda:0", "transformer.h.0.attn.proj.bias": "cuda:0", "transformer.h.0.norm_2.weight": "cuda:0", @@ -199,8 +199,8 @@ def test_model_forward_hooks(): "transformer.h.0.mlp.proj.bias": "cuda:0", "transformer.h.1.norm_1.weight": "cuda:0", "transformer.h.1.norm_1.bias": "cuda:0", - "transformer.h.1.attn.attn.weight": "cuda:0", - "transformer.h.1.attn.attn.bias": "cuda:0", + "transformer.h.1.attn.qkv.weight": "cuda:0", + "transformer.h.1.attn.qkv.bias": "cuda:0", "transformer.h.1.attn.proj.weight": "cuda:0", "transformer.h.1.attn.proj.bias": "cuda:0", "transformer.h.1.norm_2.weight": "cuda:0", @@ -211,8 +211,8 @@ def test_model_forward_hooks(): "transformer.h.1.mlp.proj.bias": "cuda:0", "transformer.h.2.norm_1.weight": "cuda:0", "transformer.h.2.norm_1.bias": "cuda:0", - "transformer.h.2.attn.attn.weight": "cuda:0", - "transformer.h.2.attn.attn.bias": "cuda:0", + "transformer.h.2.attn.qkv.weight": "cuda:0", + "transformer.h.2.attn.qkv.bias": "cuda:0", "transformer.h.2.attn.proj.weight": "cuda:0", "transformer.h.2.attn.proj.bias": "cuda:0", "transformer.h.2.norm_2.weight": "cuda:0", @@ -223,8 +223,8 @@ def test_model_forward_hooks(): "transformer.h.2.mlp.proj.bias": "cuda:0", "transformer.h.3.norm_1.weight": "cuda:1", "transformer.h.3.norm_1.bias": "cuda:1", - "transformer.h.3.attn.attn.weight": "cuda:1", - "transformer.h.3.attn.attn.bias": "cuda:1", + "transformer.h.3.attn.qkv.weight": "cuda:1", + "transformer.h.3.attn.qkv.bias": "cuda:1", "transformer.h.3.attn.proj.weight": "cuda:1", "transformer.h.3.attn.proj.bias": "cuda:1", "transformer.h.3.norm_2.weight": "cuda:1", @@ -235,8 +235,8 @@ def test_model_forward_hooks(): "transformer.h.3.mlp.proj.bias": "cuda:1", "transformer.h.4.norm_1.weight": "cuda:1", "transformer.h.4.norm_1.bias": "cuda:1", - "transformer.h.4.attn.attn.weight": "cuda:1", - "transformer.h.4.attn.attn.bias": "cuda:1", + "transformer.h.4.attn.qkv.weight": "cuda:1", + "transformer.h.4.attn.qkv.bias": "cuda:1", "transformer.h.4.attn.proj.weight": "cuda:1", "transformer.h.4.attn.proj.bias": "cuda:1", "transformer.h.4.norm_2.weight": "cuda:1", @@ -247,8 +247,8 @@ def test_model_forward_hooks(): "transformer.h.4.mlp.proj.bias": "cuda:1", "transformer.h.5.norm_1.weight": "cuda:1", "transformer.h.5.norm_1.bias": "cuda:1", - "transformer.h.5.attn.attn.weight": "cuda:1", - "transformer.h.5.attn.attn.bias": "cuda:1", + "transformer.h.5.attn.qkv.weight": "cuda:1", + "transformer.h.5.attn.qkv.bias": "cuda:1", "transformer.h.5.attn.proj.weight": "cuda:1", "transformer.h.5.attn.proj.bias": "cuda:1", "transformer.h.5.norm_2.weight": "cuda:1", diff --git a/tests/test_lora.py b/tests/test_lora.py index 079d900d0b..c417d588a4 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -1,6 +1,7 @@ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. 
import os from contextlib import redirect_stdout +from copy import deepcopy from io import StringIO from itertools import product from unittest import mock @@ -23,10 +24,19 @@ from litgpt.args import EvalArgs, TrainArgs from litgpt.data import Alpaca from litgpt.lora import GPT as LoRAGPT +from litgpt.lora import ( + CausalSelfAttention, + Config, + LoRALinear, + LoRAQKVLinear, + lora_filter, + mark_only_lora_as_trainable, + merge_lora_weights, +) from litgpt.lora import CausalSelfAttention as LoRACausalSelfAttention -from litgpt.lora import Config, LoRALinear, LoRAQKVLinear, lora_filter, mark_only_lora_as_trainable, merge_lora_weights from litgpt.model import GPT as BaseGPT from litgpt.scripts.convert_hf_checkpoint import copy_weights_gemma_2, copy_weights_hf_llama +from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved from tests.conftest import RunIf @@ -100,11 +110,11 @@ def test_lora_mqa_gqa(): ) assert config.n_query_groups == config.n_head model = LoRAGPT(config) - attn = model.transformer.h[0].attn.attn + attn = model.transformer.h[0].attn.qkv for p in attn.linear.parameters(): torch.nn.init.zeros_(p) torch.nn.init.ones_(attn.lora_B) - lora_ind = [0, 1, 6, 7, 12, 13, 18, 19, 4, 5, 10, 11, 16, 17, 22, 23] + lora_ind = [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23] assert attn.linear.weight.shape == (24, 8) assert attn.lora_A.shape == (4, 8) assert attn.lora_B.shape == (16, 2) @@ -121,7 +131,7 @@ def test_lora_mqa_gqa(): # MQA config.n_query_groups = 1 model = LoRAGPT(config) - attn = model.transformer.h[0].attn.attn + attn = model.transformer.h[0].attn.qkv for p in attn.linear.parameters(): torch.nn.init.zeros_(p) torch.nn.init.ones_(attn.lora_B) @@ -142,11 +152,11 @@ def test_lora_mqa_gqa(): # GQA config.n_query_groups = 2 model = LoRAGPT(config) - attn = model.transformer.h[0].attn.attn + attn = model.transformer.h[0].attn.qkv for p in attn.linear.parameters(): torch.nn.init.zeros_(p) torch.nn.init.ones_(attn.lora_B) - lora_ind = [0, 1, 2, 3, 8, 9, 10, 11, 6, 7, 14, 15] + lora_ind = [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15] assert attn.linear.weight.shape == (16, 8) assert attn.lora_A.shape == (4, 8) assert attn.lora_B.shape == (12, 2) @@ -169,12 +179,12 @@ def test_lora_filter(tmp_path): saved = torch.load(save_path)["model"] expected = { - "transformer.h.1.attn.attn.lora_B", - "transformer.h.2.attn.attn.lora_B", - "transformer.h.2.attn.attn.lora_A", - "transformer.h.1.attn.attn.lora_A", - "transformer.h.0.attn.attn.lora_A", - "transformer.h.0.attn.attn.lora_B", + "transformer.h.1.attn.qkv.lora_B", + "transformer.h.2.attn.qkv.lora_B", + "transformer.h.2.attn.qkv.lora_A", + "transformer.h.1.attn.qkv.lora_A", + "transformer.h.0.attn.qkv.lora_A", + "transformer.h.0.attn.qkv.lora_B", } assert set(saved) == expected @@ -665,7 +675,7 @@ def test_against_original_gemma_2(model_name): # Gemma weights are shipped without `lm_head.weight` theirs_state_dict.pop("lm_head.weight") state_dict = {} - copy_weights_gemma_2(ours_config, {}, state_dict, theirs_state_dict) + copy_weights_gemma_2({}, state_dict, theirs_state_dict) ours_model = LoRAGPT(ours_config).to(device) ours_model.load_state_dict(state_dict) @@ -740,29 +750,29 @@ def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_pa dtype_to_name[str(layer.dtype)].add(name) assert dtype_to_name == { "torch.uint8": { - "transformer.h.0.attn.attn.linear.weight", + "transformer.h.0.attn.qkv.linear.weight", "transformer.h.0.attn.proj.linear.weight", 
"transformer.h.0.mlp.fc.linear.weight", "transformer.h.1.mlp.proj.linear.weight", "transformer.h.0.mlp.proj.linear.weight", - "transformer.h.1.attn.attn.linear.weight", + "transformer.h.1.attn.qkv.linear.weight", "lm_head.linear.weight", "transformer.h.1.attn.proj.linear.weight", "transformer.h.1.mlp.fc.linear.weight", }, "torch.float16": { - "transformer.h.0.attn.attn.lora_B", + "transformer.h.0.attn.qkv.lora_B", "transformer.h.0.norm_2.weight", "transformer.wte.weight", "transformer.wte.norm.weight", "transformer.wte.norm.bias", "transformer.h.1.mlp.fc.linear.bias", "transformer.ln_f.bias", - "transformer.h.1.attn.attn.lora_B", + "transformer.h.1.attn.qkv.lora_B", "transformer.h.1.attn.proj.linear.bias", "transformer.h.1.norm_1.weight", - "transformer.h.1.attn.attn.linear.bias", - "transformer.h.1.attn.attn.lora_A", + "transformer.h.1.attn.qkv.linear.bias", + "transformer.h.1.attn.qkv.lora_A", "transformer.h.1.norm_1.bias", "transformer.h.1.norm_2.bias", "transformer.h.0.attn.proj.linear.bias", @@ -771,11 +781,11 @@ def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_pa "transformer.h.0.mlp.fc.linear.bias", "transformer.h.0.norm_2.bias", "transformer.ln_f.weight", - "transformer.h.0.attn.attn.lora_A", + "transformer.h.0.attn.qkv.lora_A", "transformer.h.1.norm_2.weight", "transformer.h.1.mlp.proj.linear.bias", "transformer.h.0.norm_1.weight", - "transformer.h.0.attn.attn.linear.bias", + "transformer.h.0.attn.qkv.linear.bias", }, } @@ -787,10 +797,10 @@ def test_lora_bitsandbytes(monkeypatch, tmp_path, fake_checkpoint_dir, alpaca_pa dtype_to_name[str(layer.dtype)].add(name) assert dtype_to_name == { "torch.float16": { - "transformer.h.1.attn.attn.lora_A", - "transformer.h.0.attn.attn.lora_A", - "transformer.h.0.attn.attn.lora_B", - "transformer.h.1.attn.attn.lora_B", + "transformer.h.1.attn.qkv.lora_A", + "transformer.h.0.attn.qkv.lora_A", + "transformer.h.0.attn.qkv.lora_B", + "transformer.h.1.attn.qkv.lora_B", } } @@ -835,11 +845,12 @@ def test_lora_model_fsdp_init(): def test_zero_pad_cpu_and_mocked_mps(): - in_features = 128 - out_features = 384 head_size = 64 n_head = 12 n_query_groups = 3 + in_features = 128 + kv_embed_dim = in_features // (n_head // n_query_groups) + out_features = in_features + 2 * kv_embed_dim enable_lora = [True, False, True] r = 4 @@ -850,12 +861,12 @@ def test_zero_pad_cpu_and_mocked_mps(): n_head=n_head, n_query_groups=n_query_groups, r=r, - enable_lora=enable_lora + enable_lora=enable_lora, ) batch_size = 64 seq_len = 64 - embed_dim = 320 + embed_dim = 160 x = torch.randn(batch_size, seq_len, embed_dim) result_cpu = model.zero_pad(x) @@ -868,3 +879,29 @@ def test_zero_pad_cpu_and_mocked_mps(): assert result_cpu.shape == result_mps.shape, "Shape mismatch between CPU and MPS" assert torch.allclose(result_cpu, result_mps), "Tensor values mismatch between CPU and MPS" + + + +def test_load_legacy_state_dict(): + """Check that a legacy state dict (with an interleaved placement in QKV matrix) can be loaded into a model with CausalSelfAttention layers.""" + config = Config( + n_embd=32, + n_head=4, + head_size=8, + n_query_groups=4, + bias=True, + lora_r=8, + lora_alpha=16, + lora_dropout=0.1 + ) + + attention_1 = CausalSelfAttention(config=config, block_idx=0) + + # make weights to be as-like in a legacy checkpoint, with `attn.attn.weight` instead of `attn.qkv.weight` + # and make them interleaved + state_dict = deepcopy(attention_1.state_dict()) + state_dict["attn.linear.weight"] = make_qkv_interleaved(state_dict.pop("qkv.linear.weight"), 
config) + state_dict["attn.linear.bias"] = make_qkv_interleaved(state_dict.pop("qkv.linear.bias"), config) + + attention_2 = CausalSelfAttention(config=config, block_idx=0) + attention_2.load_state_dict(state_dict) diff --git a/tests/test_model.py b/tests/test_model.py index 9a21f0d34d..abd1a767bf 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -2,9 +2,9 @@ from copy import deepcopy from functools import partial +from unittest import mock import pytest -from unittest import mock import torch from lightning import Fabric from lightning.fabric.utilities.imports import _IS_WINDOWS @@ -31,8 +31,8 @@ from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM import litgpt.config as config_module -from litgpt.model import batched_index_copy_ from litgpt import GPT, Config +from litgpt.model import CausalSelfAttention, batched_index_copy_ from litgpt.scripts.convert_hf_checkpoint import ( copy_weights_falcon, copy_weights_gemma_2, @@ -41,6 +41,7 @@ copy_weights_phi, copy_weights_qwen_2_5, ) +from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble as make_qkv_interleaved from tests.conftest import RunIf @@ -97,7 +98,7 @@ def test_against_gpt_neox_model(rotary_pct, batch_size, n_embd, parallel_residua state_dict = {} theirs_model = GPTNeoXForCausalLM(theirs_config).to(device) # load the hf initialization into our model - copy_weights_gpt_neox(state_dict, theirs_model.state_dict()) + copy_weights_gpt_neox(ours_config, state_dict, theirs_model.state_dict()) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) @@ -152,7 +153,7 @@ def test_against_hf_falcon(kwargs, device, dtype): theirs_model = FalconForCausalLM(theirs_config).to(device) theirs_state_dict = theirs_model.state_dict() state_dict = {} - copy_weights_falcon(kwargs["name"], state_dict, theirs_state_dict) + copy_weights_falcon(ours_config, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) @@ -556,6 +557,7 @@ def test_against_hf_mixtral(model_name): theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) + @torch.inference_mode() @pytest.mark.parametrize("model_name", ("OLMo-1B-hf", "OLMo-7B-hf")) @pytest.mark.parametrize( @@ -614,6 +616,7 @@ def test_against_olmo(model_name, device, dtype): theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float torch.testing.assert_close(ours_y, theirs_y) + @torch.inference_mode() @pytest.mark.parametrize( ("device", "dtype"), @@ -779,7 +782,7 @@ def test_against_original_gemma_2(model_name, device, dtype): # Gemma weights are shipped without `lm_head.weight` theirs_state_dict.pop("lm_head.weight") state_dict = {} - copy_weights_gemma_2(ours_config, {}, state_dict, theirs_state_dict) + copy_weights_gemma_2({}, state_dict, theirs_state_dict) ours_model = GPT(ours_config).to(device) ours_model.load_state_dict(state_dict) @@ -1298,3 +1301,24 @@ def test_batched_index_copy_modes(): val_3_mps = val_3 batched_index_copy_(t3_mps, dim_3, idx_3_mps, val_3_mps) assert torch.allclose(t3_cpu, t3_mps), "Mismatch with negative dimension on mocked MPS" + +def test_load_legacy_state_dict(): + """Check that a legacy state dict (with an interleaved placement in QKV matrix) can be loaded into a model with CausalSelfAttention layers.""" + config = Config( + n_embd=32, + n_head=4, + head_size=8, + n_query_groups=4, + bias=True, + ) + + attention_1 = CausalSelfAttention(config=config, block_idx=0) + + # make weights to be 
laid out as in a legacy checkpoint, with `attn.attn.weight` instead of `attn.qkv.weight`
+    # and make them interleaved
+    state_dict = deepcopy(attention_1.state_dict())
+    state_dict["attn.weight"] = make_qkv_interleaved(state_dict.pop("qkv.weight"), config)
+    state_dict["attn.bias"] = make_qkv_interleaved(state_dict.pop("qkv.bias"), config)
+
+    attention_2 = CausalSelfAttention(config=config, block_idx=0)
+    attention_2.load_state_dict(state_dict)
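
The snippet below is a minimal, self-contained illustration of the two QKV layout operations this patch standardizes on: splitting the new grouped `attn.qkv` parameter into query/key/value blocks with the `(n_head * head_size, n_query_groups * head_size, n_query_groups * head_size)` sizes used by the `copy_weights_*` functions, and mapping a grouped weight back to the legacy interleaved layout via `qkv_reassemble` from `litgpt/scripts/convert_lit_checkpoint.py` (imported as `make_qkv_interleaved` in the tests above). It is a sketch that assumes the patch is applied; the toy `Config` mirrors the GQA case in `test_qkv_reassemble`, and `expected_row_order` is an illustrative name, not part of the patch.

import torch

from litgpt import Config
from litgpt.scripts.convert_lit_checkpoint import qkv_reassemble

# Toy GQA config as in test_qkv_reassemble: 4 heads, 2 query groups,
# head_size = n_embd // n_head = 1, so each row below is one head's projection.
config = Config(n_embd=4, n_head=4, n_query_groups=2)

# Grouped layout used by the new `attn.qkv` parameter:
# rows 0-3 are queries, rows 4-5 are keys, rows 6-7 are values.
qkv_grouped = torch.arange(8 * 4).reshape(8, 4)

# The split sizes used throughout the updated copy_weights_* functions.
q, k, v = qkv_grouped.split(
    (
        config.n_head * config.head_size,
        config.n_query_groups * config.head_size,
        config.n_query_groups * config.head_size,
    )
)
assert q.shape[0] == 4 and k.shape[0] == 2 and v.shape[0] == 2

# Legacy interleaved layout, as crafted for the legacy-checkpoint tests:
# per query group -> its queries, then its key, then its value.
qkv_interleaved = qkv_reassemble(qkv_grouped, config)
expected_row_order = [0, 1, 4, 6, 2, 3, 5, 7]
assert torch.equal(qkv_interleaved, qkv_grouped[expected_row_order])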