Small fixes and refactoring (#1861)
Co-authored-by: Andrei-Aksionov <aksionau.andrei@gmail.com>
mseeger and Andrei-Aksionov authored Dec 24, 2024
1 parent 5670d46 commit db308ba
Showing 9 changed files with 108 additions and 59 deletions.
2 changes: 1 addition & 1 deletion extensions/thunder/unsloth/executor.py
@@ -240,7 +240,7 @@ def unsloth_apply_rope_meta(
Q: TensorProxy, cos: TensorProxy, sin: TensorProxy
) -> Tuple[TensorProxy, TensorProxy, TensorProxy, int, int, int]:
batch, n_heads, seq_len, head_dim = Q.shape
assert seq_len <= cos.shape[0]
assert seq_len <= cos.shape[-2]
BLOCK_SIZE, num_warps = kernels.calculate_settings(head_dim // 2)
div, mod = divmod(n_heads, kernels.rope_embedding.ROPE_GROUP_SIZE)
n_groups = div + (mod != 0)
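The tightened assert checks the second-to-last axis, so it holds whether or not `cos` carries a leading batch dimension, and the `divmod` line is a ceiling division of `n_heads` by the RoPE group size. A minimal sketch of both points, with made-up shapes and an assumed `ROPE_GROUP_SIZE` value (the real constant lives in the unsloth kernels module):

```python
import math

import torch

batch, n_heads, seq_len, head_dim = 2, 32, 128, 64

# `cos` without a batch dimension: (T, rope_n_elem)
cos_unbatched = torch.empty(4096, head_dim)
# `cos` with a batch dimension: (1, T, rope_n_elem)
cos_batched = cos_unbatched.unsqueeze(0)

# shape[-2] is the sequence axis in both layouts; shape[0] only in the first
assert seq_len <= cos_unbatched.shape[-2]
assert seq_len <= cos_batched.shape[-2]

# n_groups is a ceiling division of n_heads by the group size
ROPE_GROUP_SIZE = 4  # assumed value, for illustration only
div, mod = divmod(n_heads, ROPE_GROUP_SIZE)
n_groups = div + (mod != 0)
assert n_groups == math.ceil(n_heads / ROPE_GROUP_SIZE)
```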
4 changes: 2 additions & 2 deletions litgpt/adapter.py
@@ -132,8 +132,8 @@ def __init__(self, config: Config, block_idx: int) -> None:
self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
self.block_idx = block_idx
self.apply_sliding_window_attention = (
config.sliding_window_size is not None and
block_idx % config.sliding_window_layer_placing == 0
config.sliding_window_size is not None and
block_idx % config.sliding_window_layer_stride == 0
)
self.config = config

4 changes: 2 additions & 2 deletions litgpt/adapter_v2.py
@@ -179,8 +179,8 @@ def __init__(self, config: Config, block_idx: int) -> None:
self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
self.block_idx = block_idx
self.apply_sliding_window_attention = (
config.sliding_window_size is not None and
block_idx % config.sliding_window_layer_placing == 0
config.sliding_window_size is not None and
block_idx % config.sliding_window_layer_stride == 0
)

self.config = config
53 changes: 31 additions & 22 deletions litgpt/config.py
@@ -15,23 +15,23 @@
class Config:
name: str = ""
hf_config: dict = field(default_factory=dict)
scale_embeddings: bool = False
attention_scores_scalar: Optional[int] = None
# General size parameters
block_size: int = 4096
sliding_window_size: Optional[int] = None
sliding_window_layer_placing: Optional[Literal["all", "interleaved"]] = None
n_layer: int = 16
n_embd: int = 4096
vocab_size: int = 50254
padding_multiple: int = 512
padded_vocab_size: Optional[int] = None
n_layer: int = 16
# Transformer block (structure, normalizations)
norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
norm_eps: float = 1e-5
post_attention_norm: bool = False
post_mlp_norm: bool = False
parallel_residual: bool = True
shared_attention_norm: bool = False
# Transformer block (self-attention)
n_head: int = 32
head_size: Optional[int] = None
n_embd: int = 4096
rotary_percentage: float = 0.25
parallel_residual: bool = True
bias: bool = True
lm_head_bias: bool = False
attn_bias: bool = False
# to use multi-head attention (MHA), set this to `n_head` (default)
# to use multi-query attention (MQA), set this to 1
# to use grouped-query attention (GQA), set this to a value in between
@@ -53,20 +53,29 @@ class Config:
#
# credit https://arxiv.org/pdf/2305.13245.pdf
n_query_groups: Optional[int] = None
shared_attention_norm: bool = False
norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
post_attention_norm: bool = False
post_mlp_norm: bool = False
norm_eps: float = 1e-5
mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP"
gelu_approximate: str = "none"
intermediate_size: Optional[int] = None
rope_condense_ratio: int = 1
attn_bias: bool = False
attention_scores_scalar: Optional[int] = None
sliding_window_size: Optional[int] = None
sliding_window_layer_placing: Optional[Literal["all", "interleaved"]] = None
# if `attention_logit_softcapping` is used, cannot use optimized
# `torch.nn.functional.scaled_dot_product_attention` (which implements
# Flash attention), may result in higher memory and runtime footprint.
attention_logit_softcapping: Optional[float] = None
# Rotary position embedding (RoPE)
rope_base: int = 10000
rotary_percentage: float = 0.25
rope_condense_ratio: int = 1
rope_adjustments: Optional[dict] = None
# Transformer block (MLP)
intermediate_size: Optional[int] = None
bias: bool = True
mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP"
gelu_approximate: str = "none"
n_expert: int = 0
n_expert_per_token: int = 0
attention_logit_softcapping: Optional[float] = None
# GPT before/after blocks
scale_embeddings: bool = False
lm_head_bias: bool = False
final_logit_softcapping: Optional[float] = None

def __post_init__(self):
@@ -99,7 +108,7 @@ def __post_init__(self):
self.rope_n_elem = int(self.rotary_percentage * self.head_size)

if self.sliding_window_size is not None:
self.sliding_window_layer_placing = (
self.sliding_window_layer_stride = (
1 if (self.sliding_window_layer_placing is None or self.sliding_window_layer_placing == "all") else 2
)

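The `__post_init__` change above maps the user-facing `sliding_window_layer_placing` option onto a numeric `sliding_window_layer_stride`, which the attention modules in this commit then compare against `block_idx`. A minimal sketch of the mapping and the resulting per-layer gating (`layer_stride` and `uses_sliding_window` are illustrative helpers, not litgpt APIs, and the 6-layer setup is hypothetical):

```python
from typing import Optional


def layer_stride(placing: Optional[str]) -> int:
    # "all" (or unset) -> every layer; "interleaved" -> every second layer
    return 1 if (placing is None or placing == "all") else 2


def uses_sliding_window(sliding_window_size: Optional[int], stride: int, block_idx: int) -> bool:
    # mirrors `apply_sliding_window_attention` in the attention modules
    return sliding_window_size is not None and block_idx % stride == 0


stride = layer_stride("interleaved")  # -> 2
flags = [uses_sliding_window(4096, stride, i) for i in range(6)]
print(flags)  # [True, False, True, False, True, False]
```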
2 changes: 1 addition & 1 deletion litgpt/generate/base.py
@@ -230,7 +230,7 @@ def batched_generate_fn(
Args:
model: The model to use.
prompts: A 2D tensor of shape [batch_size, prompt_length].
max_returned_tokens: The maximum number of new tokens to return. Does not include the prompt tokens.
max_returned_tokens: The maximum number of tokens to return, including the prompt tokens.
sample_args: The dictionary of kwargs to pass to sample() for each token for each index in the batch.
stop_tokens: A tuple of stop sequences. If any of the sequences are generated, the generation stops early before max_returned_tokens.
include_prompt: Whether to output the prompt tokens.
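The corrected docstring changes the budget semantics of `max_returned_tokens`: it now counts the prompt tokens too, so the number of newly generated tokens is the budget minus the prompt length. A small worked example of that arithmetic (the numbers are illustrative):

```python
prompt_length = 12
max_returned_tokens = 50

# with the corrected semantics, at most this many new tokens are generated
max_new_tokens = max_returned_tokens - prompt_length
assert max_new_tokens == 38
```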
4 changes: 2 additions & 2 deletions litgpt/lora.py
@@ -628,8 +628,8 @@ def __init__(self, config: Config, block_idx: int) -> None:
# disabled by default
self.kv_cache: Optional[KVCache] = None
self.apply_sliding_window_attention = (
config.sliding_window_size is not None and
block_idx % config.sliding_window_layer_placing == 0
config.sliding_window_size is not None and
block_idx % config.sliding_window_layer_stride == 0
)

self.config = config
80 changes: 53 additions & 27 deletions litgpt/model.py
@@ -72,11 +72,30 @@ def _init_weights(self, module: nn.Module) -> None:
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Args:
idx (torch.Tensor): Input token indices, shape `(B, T)`
input_pos (torch.Tensor, optional): Contains input positions,
either with shape `(T,)` or `(B, T)`, if provided. This is used
for generative inference, where a KV cache is required. By
default, this assumes `input_pos == arange(T)` with all inputs
up to `T` provided upfront.
Returns:
torch.Tensor: Output (logits), shape `(B, T, config.padded_vocab_size)`
"""
if idx.dim() != 2:
raise ValueError(f"idx must have 2 dimensions, idx.shape = {idx.shape}")
T = idx.size(1)
if self.max_seq_length < T:
raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")

if input_pos is not None: # use the kv cache
if input_pos.dim() > 2:
# otherwise, things go wrong in `apply_rope`
raise ValueError(f"input_pos must have 1 or 2 dimensions, input_pos.shape = {input_pos.shape}")
if input_pos.shape[-1] != T:
raise ValueError(f"input_pos.shape[-1] = {input_pos.shape[-1]} != {T} = idx.shape[1], must be the same")
cos = batched_index_select(self.cos, 0, input_pos)
sin = batched_index_select(self.sin, 0, input_pos)
if self.mask_cache is None:
@@ -87,20 +106,22 @@ def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) -
# we get if input_pos has a batch dimension
mask = mask.squeeze(1)
else:
cos = self.cos[:T]
sin = self.sin[:T]
mask = None
# unsqueeze to have a batch dimension
cos = self.cos[:T].unsqueeze(0)
sin = self.sin[:T].unsqueeze(0)
# `cos`, `sin` have shape (1, T, config.rope_n_elem)
mask = None # defaults to causal mask

x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
x = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd)
if self.config.scale_embeddings:
x = x * torch.tensor(self.config.n_embd**0.5, dtype=x.dtype)

for block in self.transformer.h:
x = block(x, cos, sin, mask, input_pos)
x = self.transformer.ln_f(x)
x = self.lm_head(x) # (b, t, vocab_size)
x = self.lm_head(x) # (B, T, padded_vocab_size)
if self.config.final_logit_softcapping is not None:
x = torch.tanh(x / self.config.final_logit_softcapping) * self.config.final_logit_softcapping
x = do_softcapping(x, self.config.final_logit_softcapping)
return x
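Taken together, the new docstring and input validation describe two calling modes: a full-sequence forward pass (no `input_pos`) and incremental decoding against a KV cache (`input_pos` of shape `(T,)` or `(B, T)`). A hedged usage sketch, assuming a model built from the small test-style configuration used in this commit's tests, that `block_size=128` is an acceptable extra setting, and that `GPT.set_kv_cache` behaves as elsewhere in litgpt:

```python
import torch

from litgpt.config import Config
from litgpt.model import GPT

# small config mirroring the values asserted in the tests below
config = Config(block_size=128, padded_vocab_size=512, n_layer=2, n_head=4, head_size=2, n_embd=8)
model = GPT(config)

# full-sequence mode: idx is (B, T), logits come back as (B, T, padded_vocab_size)
idx = torch.randint(0, config.padded_vocab_size, (1, 16))
logits = model(idx)

# incremental mode: build a KV cache, then pass explicit positions
model.set_kv_cache(batch_size=1)
input_pos = torch.arange(16)  # shape (T,), must match idx.shape[1]
logits = model(idx, input_pos)
```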

@classmethod
@@ -122,10 +143,8 @@ def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tenso
elif num_params_present == 4:
# These parameters should always be used together so that we don't interfere with standard rope
extra_config = {
"original_max_seq_len": self.config.rope_adjustments["original_max_seq_len"],
"factor": self.config.rope_adjustments["factor"],
"low_freq_factor": self.config.rope_adjustments["low_freq_factor"],
"high_freq_factor": self.config.rope_adjustments["high_freq_factor"],
name: self.config.rope_adjustments[name]
for name in adjusted_params_required
}
else:
# Some but not all parameters are specified; raise an error
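The dict comprehension above replaces four hand-written entries, so the set of required keys lives in one place (`adjusted_params_required`). A standalone sketch of the same pattern; the key names match the ones removed in this hunk, while the numeric values are made up for illustration:

```python
adjusted_params_required = ("factor", "low_freq_factor", "high_freq_factor", "original_max_seq_len")

# the kind of dict a checkpoint config might supply via `rope_adjustments` (values invented)
rope_adjustments = {
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_seq_len": 8192,
}

extra_config = {name: rope_adjustments[name] for name in adjusted_params_required}
assert set(extra_config) == set(adjusted_params_required)
```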
@@ -231,12 +250,13 @@ def forward(
attention_output = self.post_attention_norm(attention_output)

if self.config.parallel_residual:
x_normed = x_normed if self.config.shared_attention_norm else self.norm_2(x)
x = self.mlp(x_normed) + attention_output + x
if not self.config.shared_attention_norm:
x_normed = self.norm_2(x)
x = attention_output + x
else:
x = attention_output + x
x = self.post_mlp_norm(self.mlp(self.norm_2(x))) + x
return x
x_normed = self.norm_2(x)
return self.post_mlp_norm(self.mlp(x_normed)) + x
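The refactored `Block.forward` converges both branches on a single `return self.post_mlp_norm(self.mlp(x_normed)) + x`; what differs is which tensor gets normalized for the MLP and when the attention residual is added. A simplified sketch of the two data flows (`block_forward` is an illustrative stand-in; the post-norms and KV cache plumbing are omitted):

```python
import torch


def block_forward(x, attn, mlp, norm_1, norm_2, parallel_residual, shared_attention_norm):
    x_normed = norm_1(x)
    attention_output = attn(x_normed)

    if parallel_residual:
        # the MLP sees a norm of the block *input* (shared with attention or not)
        if not shared_attention_norm:
            x_normed = norm_2(x)
        x = attention_output + x
    else:
        # the MLP sees a norm of the attention residual output
        x = attention_output + x
        x_normed = norm_2(x)
    return mlp(x_normed) + x


identity = torch.nn.Identity()
x = torch.randn(2, 4, 8)
parallel = block_forward(x, identity, identity, identity, identity, True, False)
sequential = block_forward(x, identity, identity, identity, identity, False, False)
```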


class CausalSelfAttention(nn.Module):
@@ -251,8 +271,8 @@ def __init__(self, config: Config, block_idx: int) -> None:
# disabled by default
self.kv_cache: Optional[KVCache] = None
self.apply_sliding_window_attention = (
config.sliding_window_size is not None and
block_idx % config.sliding_window_layer_placing == 0
config.sliding_window_size is not None and
block_idx % config.sliding_window_layer_stride == 0
)

self.config = config
@@ -275,15 +295,17 @@ def forward(
qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size)
qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs)

# split batched computation into three
# split batched computation into three:
# q: (B, n_query_groups, q_per_kv, T, hs)
# k, v: (B, n_query_groups, 1, T, hs)
q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)

# maybe repeat k and v for the non multi-head attention cases
# training: flash attention requires it
# inference: multi-query would require a full kv cache so avoid it to limit its memory usage
if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1):
k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
k = k.expand(*q.shape)
v = v.expand(*q.shape)

q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs)
k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs)
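The `expand(*q.shape)` simplification works because `q` is laid out as `(B, n_query_groups, q_per_kv, T, hs)` while `k`/`v` keep a singleton group-member axis. A shape-only sketch with hypothetical GQA sizes:

```python
import torch

B, T, head_size = 2, 16, 64
n_head, n_query_groups = 32, 8       # grouped-query attention
q_per_kv = n_head // n_query_groups  # 4 query heads share each kv head

q = torch.empty(B, n_query_groups, q_per_kv, T, head_size)
k = torch.empty(B, n_query_groups, 1, T, head_size)

# expanding the singleton axis to q_per_kv is a view, no copy is made
k = k.expand(*q.shape)
assert k.shape == (B, n_query_groups, q_per_kv, T, head_size)

# afterwards q and k collapse to (B, nh, T, hs) for the attention call
q = q.reshape(B, -1, T, head_size)
k = k.reshape(B, -1, T, head_size)
assert q.shape == k.shape == (B, n_head, T, head_size)
```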
@@ -331,11 +353,8 @@ def scaled_dot_product_attention(

# with softcapping we cannot use SDPA
if self.config.attention_logit_softcapping is not None:
scale = 1.0 / math.sqrt(self.config.attention_scores_scalar or self.config.head_size)
scores = q @ k.mT * scale
scores = (
torch.tanh(scores / self.config.attention_logit_softcapping) * self.config.attention_logit_softcapping
)
scores = do_softcapping(scores, self.config.attention_logit_softcapping)
if mask is None:
mask = torch.ones(q.size(2), q.size(2), dtype=q.dtype, device=q.device).triu(diagonal=1)
mask.masked_fill_(mask.bool(), torch.finfo(q.dtype).min)
@@ -496,10 +515,11 @@ def batched_index_select(t, dim, idx):
res = torch.index_select(t, dim, idx.reshape(-1)) # flat index
# split out single batch idx
res = res.view(*t.shape[:dim], -1, idx_size, *t.shape[dim + 1 :])
# move batch dim to front, this is np.rollaxis(res, dim, 0) for tensors
dims = [dim] + list(range(res.dim()))
del dims[dim + 1]
res = res.permute(dims)
if dim > 0:
# move batch dim to front, this is np.rollaxis(res, dim, 0) for tensors
dims = [dim] + list(range(res.dim()))
del dims[dim + 1]
res = res.permute(dims)
# unflatten batch dims
res = res.view(*batch_shape, *res.shape[1:])
return res
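The added `if dim > 0:` guard skips the permute when selection happens along the leading dimension, which is exactly how the RoPE caches are indexed (`batched_index_select(self.cos, 0, input_pos)`). A small sketch of the equivalent result for a 2-D `input_pos`, with hypothetical shapes:

```python
import torch

max_seq_len, rope_n_elem = 32, 8
cos = torch.randn(max_seq_len, rope_n_elem)  # the RoPE cache, indexed along dim 0

B, T = 2, 5
input_pos = torch.arange(T).expand(B, T)     # per-sample positions, shape (B, T)

# equivalent to stacking cos[input_pos[b]] over the batch
selected = torch.stack([cos[input_pos[b]] for b in range(B)])
assert selected.shape == (B, T, rope_n_elem)

# the helper above should produce the same (B, T, rope_n_elem) result:
#   batched_index_select(cos, 0, input_pos)
```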
@@ -556,6 +576,8 @@ def batched_index_copy_(t, dim, idx, val):


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
# x: (B, nh, T, hs)
# sin, cos: (B, T, hs) or (1, T, hs)
head_size = x.size(-1)
x1 = x[..., : head_size // 2] # (B, nh, T, hs/2)
x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2)
@@ -571,6 +593,10 @@ def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.T
return roped.to(dtype=x.dtype)
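The new shape comment pins down the broadcasting contract: `x` carries a head axis that `cos`/`sin` lack, so they need an extra singleton dimension before the elementwise products. A shape-only sketch of that broadcast (the rotation itself is elided in the hunk above, so only shapes are checked here):

```python
import torch

B, n_head, T, head_size = 2, 4, 16, 8
x = torch.randn(B, n_head, T, head_size)
cos = torch.randn(1, T, head_size)  # (1, T, hs), broadcast over the batch
sin = torch.randn(1, T, head_size)

# insert a head dimension so (1, 1, T, hs) broadcasts against (B, nh, T, hs)
assert (x * cos.unsqueeze(1)).shape == (B, n_head, T, head_size)
assert (x * sin.unsqueeze(1)).shape == (B, n_head, T, head_size)
```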


def do_softcapping(x: torch.Tensor, thresh: float) -> torch.Tensor:
return torch.tanh(x / thresh) * thresh
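The new `do_softcapping` helper replaces the two inlined `tanh(x / cap) * cap` expressions; it squashes values smoothly into `[-thresh, thresh]` while leaving small values nearly untouched. A quick numeric check (the one-liner is repeated so the sketch is self-contained):

```python
import torch


def do_softcapping(x: torch.Tensor, thresh: float) -> torch.Tensor:
    return torch.tanh(x / thresh) * thresh


x = torch.tensor([-1000.0, -1.0, 0.0, 1.0, 1000.0])
capped = do_softcapping(x, thresh=30.0)

assert capped.abs().max() <= 30.0                   # outputs never exceed the threshold
assert torch.allclose(capped[2], x[2])              # zero is unchanged
assert torch.allclose(capped[3], x[3], atol=1e-3)   # small values pass nearly unchanged
```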


class KVCache(nn.Module):
def __init__(
self,
8 changes: 7 additions & 1 deletion tests/test_generate.py
@@ -93,7 +93,13 @@ def test_main(fake_checkpoint_dir, monkeypatch, tensor_like):
pattern = rf".*^{re.escape(expected_output.strip())}$.*"
assert re.match(pattern, out.getvalue().strip(), re.DOTALL | re.MULTILINE)

assert "'padded_vocab_size': 512, 'n_layer': 2, 'n_head': 4" in err.getvalue()
err_value = err.getvalue()
expected_parts = [
"'padded_vocab_size': 512",
"'n_layer': 2",
"'n_head': 4",
]
assert all(part in err_value for part in expected_parts)


def test_cli():
10 changes: 9 additions & 1 deletion tests/test_generate_adapter.py
@@ -55,7 +55,15 @@ def test_main(fake_checkpoint_dir, monkeypatch, version, tensor_like):
pattern = rf".*^{re.escape(expected_output.strip())}$.*"
assert re.match(pattern, out.getvalue().strip(), re.DOTALL | re.MULTILINE)

assert "'padded_vocab_size': 512, 'n_layer': 2, 'n_head': 4, 'head_size': 2, 'n_embd': 8" in err.getvalue()
err_value = err.getvalue()
expected_parts = [
"'padded_vocab_size': 512",
"'n_layer': 2",
"'n_head': 4",
"'head_size': 2",
"'n_embd': 8",
]
assert all(part in err_value for part in expected_parts)


@pytest.mark.parametrize("version", ("", "_v2"))
