From 2f2ecd9a4042c340bb4116e1700f6d6371c729f3 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 4 Sep 2024 20:31:48 -0700
Subject: [PATCH] [Misc] Clean up RoPE forward_native (#8076)

---
 .../model_executor/layers/rotary_embedding.py | 95 ++++---------------
 1 file changed, 19 insertions(+), 76 deletions(-)

diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index c5a0278e485d4..d323f6cc432a2 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -28,7 +28,6 @@
 import torch.nn as nn

 from vllm.model_executor.custom_op import CustomOp
-from vllm.platforms import current_platform


 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -48,21 +47,29 @@ def _apply_rotary_emb(
     x: torch.Tensor,
     cos: torch.Tensor,
     sin: torch.Tensor,
+    is_neox_style: bool,
 ) -> torch.Tensor:
     """
     Args:
         x: [num_tokens, num_heads, head_size]
         cos: [num_tokens, head_size // 2]
         sin: [num_tokens, head_size // 2]
+        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
+            positional embeddings.
     """
-    orig_dtype = x.dtype
-    x = x.float()
-    x1, x2 = torch.chunk(x, 2, dim=-1)
-    cos = cos.unsqueeze(-2)
-    sin = sin.unsqueeze(-2)
+    cos = cos.unsqueeze(-2).to(x.dtype)
+    sin = sin.unsqueeze(-2).to(x.dtype)
+    if is_neox_style:
+        x1, x2 = torch.chunk(x, 2, dim=-1)
+    else:
+        x1 = x[..., ::2]
+        x2 = x[..., 1::2]
     o1 = x1 * cos - x2 * sin
     o2 = x2 * cos + x1 * sin
-    return torch.cat((o1, o2), dim=-1).to(orig_dtype)
+    if is_neox_style:
+        return torch.cat((o1, o2), dim=-1)
+    else:
+        return torch.stack((o1, o2), dim=-1).flatten(-2)


 class RotaryEmbedding(CustomOp):
@@ -87,10 +94,9 @@ def __init__(

         cache = self._compute_cos_sin_cache()
         cache = cache.to(dtype)
+        self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)

-        self.use_native2 = current_platform.is_tpu() and is_neox_style
-
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         """Compute the inverse frequency."""
         # NOTE(woosuk): To exactly match the HF implementation, we need to
@@ -119,59 +125,7 @@ def forward_native(
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """A PyTorch-native implementation equivalent to forward().
-
-        This method mimics the implementation of the custom CUDA kernel
-        used in `forward_cuda()`.
-        """
-        query = query.view(*query.shape[:-1], -1, self.head_size)
-        key = key.view(*key.shape[:-1], -1, self.head_size)
-
-        query_rot = query[..., :self.rotary_dim]
-        key_rot = key[..., :self.rotary_dim]
-        if self.rotary_dim < self.head_size:
-            query_pass = query[..., self.rotary_dim:]
-            key_pass = key[..., self.rotary_dim:]
-
-        self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
-            positions.device, dtype=query.dtype)
-        cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
-                                     if offsets is not None else positions]
-        cos, sin = cos_sin.chunk(2, dim=-1)
-        if self.is_neox_style:
-            # NOTE(woosuk): Here we assume that the positions tensor has the
-            # shape [batch_size, seq_len].
-            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
-            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
-        else:
-            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
-            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
-
-        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
-        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
-        key_rot = key_rot * cos + rotate_fn(key_rot) * sin
-
-        if self.rotary_dim < self.head_size:
-            query = torch.cat((query_rot, query_pass), dim=-1)
-            key = torch.cat((key_rot, key_pass), dim=-1)
-        else:
-            query = query_rot
-            key = key_rot
-        query = query.flatten(-2)
-        key = key.flatten(-2)
-        return query, key
-
-    def forward_native2(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Another PyTorch-native implementation of forward().
-
-        This method might perform better than `forward_native()` when compiled.
-        """
+        """A PyTorch-native implementation of forward()."""
         if offsets is not None:
             positions = positions + offsets
         positions = positions.flatten()
@@ -183,14 +137,14 @@ def forward_native2(
         query = query.view(num_tokens, -1, self.head_size)
         query_rot = query[..., :self.rotary_dim]
         query_pass = query[..., self.rotary_dim:]
-        query_rot = _apply_rotary_emb(query_rot, cos, sin)
+        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
         query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

         key_shape = key.shape
         key = key.view(num_tokens, -1, self.head_size)
         key_rot = key[..., :self.rotary_dim]
         key_pass = key[..., self.rotary_dim:]
-        key_rot = _apply_rotary_emb(key_rot, cos, sin)
+        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key

@@ -203,7 +157,7 @@ def forward_cuda(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         from vllm import _custom_ops as ops

-        self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
+        self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                                    dtype=query.dtype)
         # ops.rotary_embedding()/batched_rotary_embedding()
         # are in-place operations that update the query and key tensors.
@@ -240,17 +194,6 @@ def forward_xpu(
                                  self.cos_sin_cache, self.is_neox_style)
         return query, key

-    def forward_tpu(
-        self,
-        positions: torch.Tensor,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        offsets: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        forward_fn = (self.forward_native2
-                      if self.use_native2 else self.forward_native)
-        return forward_fn(positions, query, key, offsets)
-
     def extra_repr(self) -> str:
         s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
         s += f", max_position_embeddings={self.max_position_embeddings}"
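Supplementary note (not part of the commit): the patch folds the Neox-style and GPT-J-style paths into a single `_apply_rotary_emb(x, cos, sin, is_neox_style)` helper and has `forward_native` dispatch through it. Below is a minimal standalone sketch that mirrors that helper so the two branches can be sanity-checked in isolation; the `apply_rotary_emb` name, shapes, and example tensors are illustrative only, and the rotation math is copied from the added lines above.

import torch


def apply_rotary_emb(
    x: torch.Tensor,      # [num_tokens, num_heads, head_size]
    cos: torch.Tensor,    # [num_tokens, head_size // 2]
    sin: torch.Tensor,    # [num_tokens, head_size // 2]
    is_neox_style: bool,
) -> torch.Tensor:
    # Broadcast cos/sin over the head dimension, in the input dtype.
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    if is_neox_style:
        # Neox style: rotate the two contiguous halves of each head.
        x1, x2 = torch.chunk(x, 2, dim=-1)
    else:
        # GPT-J style: rotate interleaved even/odd elements.
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    if is_neox_style:
        return torch.cat((o1, o2), dim=-1)
    # Re-interleave the rotated pairs for GPT-J style.
    return torch.stack((o1, o2), dim=-1).flatten(-2)


if __name__ == "__main__":
    # Illustrative shapes only.
    num_tokens, num_heads, head_size = 4, 2, 8
    x = torch.randn(num_tokens, num_heads, head_size)
    angles = torch.randn(num_tokens, head_size // 2)
    for neox in (True, False):
        out = apply_rotary_emb(x, angles.cos(), angles.sin(), is_neox_style=neox)
        assert out.shape == x.shape

One design point visible in the diff: the refactored helper applies the rotation in the input's own dtype (casting only `cos`/`sin` with `.to(x.dtype)`) rather than upcasting the whole tensor to float32 and casting back, as the removed code did.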