
Request: NTK rope support #479

Closed

lucasjinreal opened this issue Jul 17, 2023 · 10 comments

Comments

@lucasjinreal

lucasjinreal commented Jul 17, 2023

Hi, there are some very successful experiments showing that NTK-based RoPE can achieve good extrapolation ability without any finetuning.

I have tested it as well and it works well: a model trained with a 1024-token context can reach very impressive long-context ability with NTK RoPE.

Would you consider supporting it, since it (probably) doesn't require many changes?

However, the positional-encoding op is baked into the CUDA kernel.

Currently I can use torch code to check whether the context length is bigger than 2048 and then apply NTK, but wouldn't it be better if vLLM supported it out of the box?
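For context, the trick is "dynamic NTK" scaling: rather than interpolating positions, the RoPE base is enlarged as the sequence grows, so the rotary frequencies span a longer context. A minimal standalone sketch of the idea (function and variable names here are illustrative, not vLLM's API):

import torch

def ntk_scaled_inv_freq(dim: int, base: float, seq_len: int, trained_len: int) -> torch.Tensor:
    """Inverse RoPE frequencies, with the base enlarged when seq_len exceeds trained_len."""
    if seq_len > trained_len:
        alpha = seq_len / trained_len - 1          # same dynamic-alpha rule used in the patch later in this thread
        base = base * alpha ** (dim / (dim - 2))   # NTK-aware base scaling
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

# Build a cos/sin table for an 8k context from a model trained on 1k tokens.
inv_freq = ntk_scaled_inv_freq(dim=128, base=10000.0, seq_len=8192, trained_len=1024)
t = torch.arange(8192, dtype=torch.float32)
freqs = torch.einsum("i,j->ij", t, inv_freq)       # [seq_len, dim // 2]
cos, sin = freqs.cos(), freqs.sin()                # fed to the rotary embedding as usual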

@lucasjinreal
Author

I have looked a little deeper and found that this implementation is actually simple; there is no need to edit any .cu files.

I have drafted a version to support NTK; let's see if it works.

@lucasjinreal
Author

I have tested NTK support in vLLM and it works; extrapolation goes up to 8k without any finetuning.

@lucasjinreal
Author

Here is the main modification:

def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
        seq_len: int,
    ) -> torch.Tensor:
        """ PagedAttention forward pass with rotary embedding.

        Args:
            positions: shape = [num_tokens]
                        query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_heads * head_size]
            value: shape = [num_tokens, num_heads * head_size]
            key_cache: shape = [num_blocks, num_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
            input_metadata: metadata for paged attention.
            cache_event: event to wait for the cache operations to finish.

        Returns:
            shape = [num_tokens, num_heads * head_size]
        """

        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
        if seq_len > self.max_seq_len_cached:
            # Dynamic NTK: rebuild the cos/sin cache with an enlarged base once
            # the sequence exceeds the cached maximum length.
            t = torch.arange(seq_len, device=positions.device, dtype=self.inv_freq.dtype)
            dim = self.dim
            alpha = seq_len / 1024 - 1  # 1024 = the model's trained context length (hard-coded here)
            base = self.base * alpha ** (dim / (dim - 2))
            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(positions.device) / dim))

            freqs = torch.einsum("i,j->ij", t, inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(positions.device)
            cos = emb.cos()
            sin = emb.sin()
            cache = torch.cat((cos, sin), dim=-1).to(self.inv_freq.dtype)
            pos_encoding_ops.rotary_embedding_neox(
                positions,
                query,
                key,
                self.head_size,
                cache,
            )
        else:
            pos_encoding_ops.rotary_embedding_neox(
                positions,
                query,
                key,
                self.head_size,
                self.cos_sin_cache,
            )
        return super().forward(
            query,
            key,
            value,
            key_cache,
            value_cache,
            input_metadata,
            cache_event,
        )
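For intuition, a rough worked example of what the NTK branch above computes (dim = 128 and base = 10000 are assumptions; numbers are approximate):

dim, base, trained_len, seq_len = 128, 10000.0, 1024, 8192
alpha = seq_len / trained_len - 1                # 7.0
scaled_base = base * alpha ** (dim / (dim - 2))  # ~7.2e4, i.e. the base grows roughly 7x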

@81549361

Great, can you please tell me how to use it?

@abarcovschi

Do you know if this can be extended to a 16k context size? If so could you please provide the code necessary for this? @lucasjinreal

@ShadowTeamCN

ShadowTeamCN commented Aug 1, 2023

(Quoting the comment and code above.)

Does seq_len in this forward function equal key.size(0) + key_cache.size(0)?

@lucasjinreal
Author

@ShadowTeamCN I'm not sure; it should be the same as the sequence length on the torch side.

@PaynatPierre

(Quoting the code above.)

In which file do you make this change, exactly?

@EricLingRui

I passed in two samples in a batch, with lengths of 6 and 8 respectively.
When I print the positions value, it looks like:
[0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6, 7]
If so, I guess pos_encoding_ops would need to slice the cos_sin_cache values internally for each sequence separately.
So I feel it is difficult to implement NTK-aware scaling without changing the CUDA ops to adapt to batched inference.
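To illustrate the point with a standalone sketch (plain PyTorch, not vLLM code):

import torch

# In a batched prefill, positions restart at 0 for each sequence, so a single
# cos_sin_cache (and a single NTK alpha derived from one seq_len) is shared by
# sequences of different lengths.
seq_lens = [6, 8]
positions = torch.cat([torch.arange(n) for n in seq_lens])
print(positions.tolist())   # [0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6, 7]

# The kernel looks up cache rows by these position values, so per-sequence NTK
# scaling would need per-sequence caches or a kernel change, as noted above.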

@hmellor
Collaborator

hmellor commented Mar 8, 2024

Closing as RoPE is now supported. If this is incorrect, feel free to re-open this issue.
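(For later readers: a hedged example of how this is usually enabled without patching vLLM. It assumes vLLM picks up the HF-style rope_scaling entry from the model's config.json; the field layout follows the HF LLaMA convention of the time, so verify against the docs for your versions.)

import json

# Hedged example: enable dynamic-NTK RoPE scaling via the HF model config.
# "my-llama-model" is a placeholder path; field names follow the HF LLaMA convention.
with open("my-llama-model/config.json") as f:
    cfg = json.load(f)

cfg["rope_scaling"] = {"type": "dynamic", "factor": 2.0}  # dynamic NTK, 2x the trained context

with open("my-llama-model/config.json", "w") as f:
    json.dump(cfg, f, indent=2)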

@hmellor hmellor closed this as completed Mar 8, 2024