From db308ba1f1aaf7cc2a56753f19d2cb33a4255f9b Mon Sep 17 00:00:00 2001 From: Matthias Seeger Date: Tue, 24 Dec 2024 20:24:35 +0100 Subject: [PATCH] Small fixes and refactoring (#1861) Co-authored-by: Andrei-Aksionov --- extensions/thunder/unsloth/executor.py | 2 +- litgpt/adapter.py | 4 +- litgpt/adapter_v2.py | 4 +- litgpt/config.py | 53 ++++++++++------- litgpt/generate/base.py | 2 +- litgpt/lora.py | 4 +- litgpt/model.py | 80 +++++++++++++++++--------- tests/test_generate.py | 8 ++- tests/test_generate_adapter.py | 10 +++- 9 files changed, 108 insertions(+), 59 deletions(-) diff --git a/extensions/thunder/unsloth/executor.py b/extensions/thunder/unsloth/executor.py index a0ed54598a..1779daf8ee 100644 --- a/extensions/thunder/unsloth/executor.py +++ b/extensions/thunder/unsloth/executor.py @@ -240,7 +240,7 @@ def unsloth_apply_rope_meta( Q: TensorProxy, cos: TensorProxy, sin: TensorProxy ) -> Tuple[TensorProxy, TensorProxy, TensorProxy, int, int, int]: batch, n_heads, seq_len, head_dim = Q.shape - assert seq_len <= cos.shape[0] + assert seq_len <= cos.shape[-2] BLOCK_SIZE, num_warps = kernels.calculate_settings(head_dim // 2) div, mod = divmod(n_heads, kernels.rope_embedding.ROPE_GROUP_SIZE) n_groups = div + (mod != 0) diff --git a/litgpt/adapter.py b/litgpt/adapter.py index 8523cec814..628217b61c 100644 --- a/litgpt/adapter.py +++ b/litgpt/adapter.py @@ -132,8 +132,8 @@ def __init__(self, config: Config, block_idx: int) -> None: self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None self.block_idx = block_idx self.apply_sliding_window_attention = ( - config.sliding_window_size is not None and - block_idx % config.sliding_window_layer_placing == 0 + config.sliding_window_size is not None and + block_idx % config.sliding_window_layer_stride == 0 ) self.config = config diff --git a/litgpt/adapter_v2.py b/litgpt/adapter_v2.py index 1ad3d40b9d..7c94a8d630 100644 --- a/litgpt/adapter_v2.py +++ b/litgpt/adapter_v2.py @@ -179,8 +179,8 @@ def __init__(self, config: Config, block_idx: int) -> None: self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None self.block_idx = block_idx self.apply_sliding_window_attention = ( - config.sliding_window_size is not None and - block_idx % config.sliding_window_layer_placing == 0 + config.sliding_window_size is not None and + block_idx % config.sliding_window_layer_stride == 0 ) self.config = config diff --git a/litgpt/config.py b/litgpt/config.py index af4098c2d3..a4a70c8238 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -15,23 +15,23 @@ class Config: name: str = "" hf_config: dict = field(default_factory=dict) - scale_embeddings: bool = False - attention_scores_scalar: Optional[int] = None + # General size parameters block_size: int = 4096 - sliding_window_size: Optional[int] = None - sliding_window_layer_placing: Optional[Literal["all", "interleaved"]] = None + n_layer: int = 16 + n_embd: int = 4096 vocab_size: int = 50254 padding_multiple: int = 512 padded_vocab_size: Optional[int] = None - n_layer: int = 16 + # Transformer block (structure, normalizations) + norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm" + norm_eps: float = 1e-5 + post_attention_norm: bool = False + post_mlp_norm: bool = False + parallel_residual: bool = True + shared_attention_norm: bool = False + # Transformer block (self-attention) n_head: int = 32 head_size: Optional[int] = None - n_embd: int = 4096 - rotary_percentage: float = 0.25 - parallel_residual: bool = True - bias: bool = True - lm_head_bias: bool = False - attn_bias: bool = False # to use multi-head attention (MHA), set this to `n_head` (default) # to use multi-query attention (MQA), set this to 1 # to use grouped-query attention (GQA), set this to a value in between @@ -53,20 +53,29 @@ class Config: # # credit https://arxiv.org/pdf/2305.13245.pdf n_query_groups: Optional[int] = None - shared_attention_norm: bool = False - norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm" - post_attention_norm: bool = False - post_mlp_norm: bool = False - norm_eps: float = 1e-5 - mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP" - gelu_approximate: str = "none" - intermediate_size: Optional[int] = None - rope_condense_ratio: int = 1 + attn_bias: bool = False + attention_scores_scalar: Optional[int] = None + sliding_window_size: Optional[int] = None + sliding_window_layer_placing: Optional[Literal["all", "interleaved"]] = None + # if `attention_logit_softcapping` is used, cannot use optimized + # `torch.nn.functional.scaled_dot_product_attention` (which implements + # Flash attention), may result in higher memory and runtime footprint. + attention_logit_softcapping: Optional[float] = None + # Rotary position embedding (RoPE) rope_base: int = 10000 + rotary_percentage: float = 0.25 + rope_condense_ratio: int = 1 rope_adjustments: Optional[dict] = None + # Transformer block (MLP) + intermediate_size: Optional[int] = None + bias: bool = True + mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP" + gelu_approximate: str = "none" n_expert: int = 0 n_expert_per_token: int = 0 - attention_logit_softcapping: Optional[float] = None + # GPT before/after blocks + scale_embeddings: bool = False + lm_head_bias: bool = False final_logit_softcapping: Optional[float] = None def __post_init__(self): @@ -99,7 +108,7 @@ def __post_init__(self): self.rope_n_elem = int(self.rotary_percentage * self.head_size) if self.sliding_window_size is not None: - self.sliding_window_layer_placing = ( + self.sliding_window_layer_stride = ( 1 if (self.sliding_window_layer_placing is None or self.sliding_window_layer_placing == "all") else 2 ) diff --git a/litgpt/generate/base.py b/litgpt/generate/base.py index d349502489..866947beea 100644 --- a/litgpt/generate/base.py +++ b/litgpt/generate/base.py @@ -230,7 +230,7 @@ def batched_generate_fn( Args: model: The model to use. prompts: A 2D tensor of shape [batch_size, prompt_length]. - max_returned_tokens: The maximum number of new tokens to return. Does not include the prompt tokens. + max_returned_tokens: The maximum number of tokens to return, including the prompt tokens. sample_args: The dictionary of kwargs to pass to sample() for each each token for each index in the batch. stop_tokens: A tuple of stop sequences. If any of the sequences are generated, the generation stops early before max_returned_tokens. include_prompt: Whether to output the prompt tokens. diff --git a/litgpt/lora.py b/litgpt/lora.py index 18a472337b..db48175eac 100644 --- a/litgpt/lora.py +++ b/litgpt/lora.py @@ -628,8 +628,8 @@ def __init__(self, config: Config, block_idx: int) -> None: # disabled by default self.kv_cache: Optional[KVCache] = None self.apply_sliding_window_attention = ( - config.sliding_window_size is not None and - block_idx % config.sliding_window_layer_placing == 0 + config.sliding_window_size is not None and + block_idx % config.sliding_window_layer_stride == 0 ) self.config = config diff --git a/litgpt/model.py b/litgpt/model.py index 17b3b4ab04..643ba59a71 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -72,11 +72,30 @@ def _init_weights(self, module: nn.Module) -> None: torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Args: + idx (torch.Tensor): Input token indices, shape `(B, T)` + input_pos (torch.Tensor, optional): Contains input positions, + either with shape `(T,)` or `(B, T)`, if provided. This is used + for generative inference, where a KV cache is required. By + default, this assumes `input_dim == arange(T)` with all inputs + up to `T` provided upfront. + + Returns: + torch.Tensor: Output (logits), shape `(B, T, config.padded_vocab_size)` + """ + if idx.dim() != 2: + raise ValueError(f"idx must have 2 dimensions, idx.shape = {idx.shape}") T = idx.size(1) if self.max_seq_length < T: raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.") if input_pos is not None: # use the kv cache + if input_pos.dim() > 2: + # otherwise, things go wrong in `apply_rope` + raise ValueError(f"input_pos must have 1 or 2 dimensions, input_pos.shape = {input_pos.shape}") + if input_pos.shape[-1] != T: + raise ValueError(f"input_pos.shape[-1] = {input_pos.shape[-1]} != {T} = idx.shape[1], must be the same") cos = batched_index_select(self.cos, 0, input_pos) sin = batched_index_select(self.sin, 0, input_pos) if self.mask_cache is None: @@ -87,20 +106,22 @@ def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) - # we get if input_pos has a batch dimension mask = mask.squeeze(1) else: - cos = self.cos[:T] - sin = self.sin[:T] - mask = None + # unsqueeze to have a batch dimension + cos = self.cos[:T].unsqueeze(0) + sin = self.sin[:T].unsqueeze(0) + # `cos`, `sin` have shape (1, T, config.rope_n_elem) + mask = None # defaults to causal mask - x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + x = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd) if self.config.scale_embeddings: x = x * torch.tensor(self.config.n_embd**0.5, dtype=x.dtype) for block in self.transformer.h: x = block(x, cos, sin, mask, input_pos) x = self.transformer.ln_f(x) - x = self.lm_head(x) # (b, t, vocab_size) + x = self.lm_head(x) # (B, T, padded_vocab_size) if self.config.final_logit_softcapping is not None: - x = torch.tanh(x / self.config.final_logit_softcapping) * self.config.final_logit_softcapping + x = do_softcapping(x, self.config.final_logit_softcapping) return x @classmethod @@ -122,10 +143,8 @@ def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tenso elif num_params_present == 4: # These parameters should always be used together so that we don't interfere with standard rope extra_config = { - "original_max_seq_len": self.config.rope_adjustments["original_max_seq_len"], - "factor": self.config.rope_adjustments["factor"], - "low_freq_factor": self.config.rope_adjustments["low_freq_factor"], - "high_freq_factor": self.config.rope_adjustments["high_freq_factor"], + name: self.config.rope_adjustments[name] + for name in adjusted_params_required } else: # Some but not all parameters are specified; raise an error @@ -231,12 +250,13 @@ def forward( attention_output = self.post_attention_norm(attention_output) if self.config.parallel_residual: - x_normed = x_normed if self.config.shared_attention_norm else self.norm_2(x) - x = self.mlp(x_normed) + attention_output + x + if not self.config.shared_attention_norm: + x_normed = self.norm_2(x) + x = attention_output + x else: x = attention_output + x - x = self.post_mlp_norm(self.mlp(self.norm_2(x))) + x - return x + x_normed = self.norm_2(x) + return self.post_mlp_norm(self.mlp(x_normed)) + x class CausalSelfAttention(nn.Module): @@ -251,8 +271,8 @@ def __init__(self, config: Config, block_idx: int) -> None: # disabled by default self.kv_cache: Optional[KVCache] = None self.apply_sliding_window_attention = ( - config.sliding_window_size is not None and - block_idx % config.sliding_window_layer_placing == 0 + config.sliding_window_size is not None and + block_idx % config.sliding_window_layer_stride == 0 ) self.config = config @@ -275,15 +295,17 @@ def forward( qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - # split batched computation into three + # split batched computation into three: + # q: (B, n_query_groups, q_per_kv, T, hs) + # k, v: (B, n_query_groups, 1, T, hs) q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) # maybe repeat k and v if for the non multi-head attention cases # training: flash attention requires it # inference: multi-query would require a full kv cache so avoid it to limit its memory usage if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1): - k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) - v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size) + k = k.expand(*q.shape) + v = v.expand(*q.shape) q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) @@ -331,11 +353,8 @@ def scaled_dot_product_attention( # with softcapping we cannot use SDPA if self.config.attention_logit_softcapping is not None: - scale = 1.0 / math.sqrt(self.config.attention_scores_scalar or self.config.head_size) scores = q @ k.mT * scale - scores = ( - torch.tanh(scores / self.config.attention_logit_softcapping) * self.config.attention_logit_softcapping - ) + scores = do_softcapping(scores, self.config.attention_logit_softcapping) if mask is None: mask = torch.ones(q.size(2), q.size(2), dtype=q.dtype, device=q.device).triu(diagonal=1) mask.masked_fill_(mask.bool(), torch.finfo(q.dtype).min) @@ -496,10 +515,11 @@ def batched_index_select(t, dim, idx): res = torch.index_select(t, dim, idx.reshape(-1)) # flat index # split out single batch idx res = res.view(*t.shape[:dim], -1, idx_size, *t.shape[dim + 1 :]) - # move batch dim to front, this is np.rollaxis(res, dim, 0) for tensors - dims = [dim] + list(range(res.dim())) - del dims[dim + 1] - res = res.permute(dims) + if dim > 0: + # move batch dim to front, this is np.rollaxis(res, dim, 0) for tensors + dims = [dim] + list(range(res.dim())) + del dims[dim + 1] + res = res.permute(dims) # unflatten batch dims res = res.view(*batch_shape, *res.shape[1:]) return res @@ -556,6 +576,8 @@ def batched_index_copy_(t, dim, idx, val): def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + # x: (B, nh, T, hs) + # sin, cos: (B, T, hs) or (1, T, hs) head_size = x.size(-1) x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) @@ -571,6 +593,10 @@ def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.T return roped.to(dtype=x.dtype) +def do_softcapping(x: torch.Tensor, thresh: float) -> torch.Tensor: + return torch.tanh(x / thresh) * thresh + + class KVCache(nn.Module): def __init__( self, diff --git a/tests/test_generate.py b/tests/test_generate.py index 6fc561b945..592f2c3acf 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -93,7 +93,13 @@ def test_main(fake_checkpoint_dir, monkeypatch, tensor_like): pattern = rf".*^{re.escape(expected_output.strip())}$.*" assert re.match(pattern, out.getvalue().strip(), re.DOTALL | re.MULTILINE) - assert "'padded_vocab_size': 512, 'n_layer': 2, 'n_head': 4" in err.getvalue() + err_value = err.getvalue() + expected_parts = [ + "'padded_vocab_size': 512", + "'n_layer': 2", + "'n_head': 4", + ] + assert all(part in err_value for part in expected_parts) def test_cli(): diff --git a/tests/test_generate_adapter.py b/tests/test_generate_adapter.py index 6e57ff0c5e..a40672d03e 100644 --- a/tests/test_generate_adapter.py +++ b/tests/test_generate_adapter.py @@ -55,7 +55,15 @@ def test_main(fake_checkpoint_dir, monkeypatch, version, tensor_like): pattern = rf".*^{re.escape(expected_output.strip())}$.*" assert re.match(pattern, out.getvalue().strip(), re.DOTALL | re.MULTILINE) - assert "'padded_vocab_size': 512, 'n_layer': 2, 'n_head': 4, 'head_size': 2, 'n_embd': 8" in err.getvalue() + err_value = err.getvalue() + expected_parts = [ + "'padded_vocab_size': 512", + "'n_layer': 2", + "'n_head': 4", + "'head_size': 2", + "'n_embd': 8", + ] + assert all(part in err_value for part in expected_parts) @pytest.mark.parametrize("version", ("", "_v2"))