
FEAT: Xavier: Share KV cache between VLLM replicas #2732

Merged: 18 commits into xorbitsai:main, Jan 10, 2025

Conversation

ChengjieLi28
Contributor

@ChengjieLi28 ChengjieLi28 commented Jan 3, 2025

Xavier: Share KV cache between VLLM replicas

Naming

The name is derived from Professor X (Charles Francis Xavier) of the Marvel Comics X-Men series. The project name starts with "X", and like Professor X, whose powerful mind controls information, it metaphorically refers to the project's management of data scheduling in vLLM.

Purpose

When vLLM runs with multiple replicas, long prompts can incur a lengthy prefill. If another replica has already computed the KV cache for a prompt, those results can be transferred and reused directly instead of being recomputed.

Usage

Simply add the parameter enable_xavier=True when starting the vllm model.
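For example, a hypothetical launch sequence through the xinference CLI might look like the following. The model name, size, and the exact convention for forwarding engine parameters are assumptions for illustration, not taken from this PR:

```shell
# Start the local server with a concrete IP (see the Gloo limitation below)
xinference-local -H 192.168.1.10

# Launch the model with the vLLM engine and 2 replicas, with Xavier enabled
# (extra engine parameters such as enable_xavier are forwarded to vLLM)
xinference launch --model-engine vllm --model-name qwen2.5-instruct \
  --size-in-billions 7 --replica 2 --enable_xavier True
```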

Test

Use this script to generate a long prompt for the LLM (roughly 9k+ prompt tokens):

from faker import Faker
import pandas as pd  # note: DataFrame.to_markdown also requires the tabulate package


def gen_data(lines: int) -> str:
    """Generate a markdown table with `lines` rows of fake personal data."""
    faker = Faker()
    data = {
        "Name": [faker.name() for _ in range(lines)],
        "Age": [faker.random_int(min=15, max=80) for _ in range(lines)],
        "Occupation": [faker.job() for _ in range(lines)],
        "Country": [faker.country() for _ in range(lines)],
        "Email": [faker.email() for _ in range(lines)],
        "Address": [faker.address() for _ in range(lines)],
        "Phone Number": [faker.phone_number() for _ in range(lines)],
    }
    df = pd.DataFrame(data)
    return df.to_markdown(index=False)

LONG_PROMPT = "You are a helpful assistant that recognizes the content of tables in markdown format. Here is a table as follows.\n# Table\n" + f"""
{gen_data(100)}
"""
q1 = "Question: What is the name and country of ID 23? Your answer: The name and country of ID 23 are "
q2 = "Question: What is the name and country of ID 96? Your answer: The name and country of ID 96 are "

Use LONG_PROMPT+q1 and LONG_PROMPT+q2 as prompts to interact with the model separately for each query.
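The second query benefits because both requests share the entire LONG_PROMPT prefix character-for-character, so the KV blocks computed for it on one replica are reusable on the other. A minimal sketch of the prefix overlap, using a short stand-in string instead of the generated table:

```python
import os.path

# Stand-in for the long generated table prompt above.
LONG_PROMPT = "You are a helpful assistant...\n# Table\n| Name | Age |...\n"
q1 = "Question: What is the name and country of ID 23? Your answer: The name and country of ID 23 are "
q2 = "Question: What is the name and country of ID 96? Your answer: The name and country of ID 96 are "

# The two full prompts only diverge at the differing ID digits,
# so everything up to that point is a shared, cacheable prefix.
shared = os.path.commonprefix([LONG_PROMPT + q1, LONG_PROMPT + q2])
print(len(shared) >= len(LONG_PROMPT))  # True: the entire long prefix is shared
```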

Test results:

  • Env: two RTX 3090 Ti with NVLink
  • Model: Qwen2.5-Instruct 7B, 2 replicas (one replica per card)

First query (no cache, full prefill), E2E time for LONG_PROMPT+q1: ~2.96 s
Second query (KV cache transferred), E2E time for LONG_PROMPT+q2: ~1.33 s
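From the numbers above, the arithmetic end-to-end speedup of the cache-transfer path over full prefill is:

```python
cold = 2.96  # first query: full prefill, no cache (seconds)
warm = 1.33  # second query: KV cache transferred from the other replica
speedup = cold / warm
print(f"{speedup:.2f}x end-to-end")  # 2.23x end-to-end
```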

Limitations

  • Rollback for xinference is not currently supported (it will be supported in the future).
  • Enabling Xavier also enables vLLM's enable_prefix_caching; vLLM >= 0.6.5 is required.
  • Gloo cannot bind the wildcard address 0.0.0.0, so when starting xinference you need to use the machine's actual IP address, for example: xinference-local -H 192.168.xx.xx.
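Because Gloo needs a concrete address, a common trick for finding the IP to pass to `xinference-local -H` is to open a routed UDP socket and read its local address. This helper is hypothetical and not part of the PR:

```python
import socket


def local_ip() -> str:
    """Best-effort discovery of a non-wildcard local IP address."""
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        # Connecting a UDP socket sends no packets; it only selects a route,
        # letting us read the outgoing interface's address.
        s.connect(("8.8.8.8", 80))
        return s.getsockname()[0]
    except OSError:
        return "127.0.0.1"  # fallback when no external route exists
    finally:
        s.close()


print(local_ip())
```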

@XprobeBot XprobeBot added this to the v1.x milestone Jan 3, 2025
@ChengjieLi28 ChengjieLi28 changed the title FEAT: [WIP] Xavier: Share KV cache between VLLM replicas FEAT: Xavier: Share KV cache between VLLM replicas Jan 9, 2025
@ChengjieLi28 ChengjieLi28 marked this pull request as ready for review January 9, 2025 11:12
Contributor

@qinxuye qinxuye left a comment
LGTM

@ChengjieLi28 ChengjieLi28 merged commit 545ee12 into xorbitsai:main Jan 10, 2025
12 of 13 checks passed
@codingl2k1
Contributor

A corner case (a race between eviction and transfer):

  • The model begins evicting a block.
  • Another replica queries the block from the block tracker.
  • The evicted block is unregistered from the block tracker.
  • The block is actually evicted.
  • The transfer of the (now stale) block proceeds.

When transferring blocks, the block may be evicted or replaced by a new one in the meantime. It's better to use a block hash during transfers: if the block has been evicted or the hash does not match, we can simply handle it as a cache miss.
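The hash-check idea can be sketched as follows (all class and method names are hypothetical, not the PR's actual API): the receiver supplies the hash it expects, and an evicted or replaced block degrades to a cache miss rather than a wrong result.

```python
import hashlib


def block_hash(payload: bytes) -> str:
    return hashlib.sha256(payload).hexdigest()


class BlockTracker:
    """Toy block tracker mapping block ids to (hash, payload)."""

    def __init__(self) -> None:
        self._blocks: dict[int, tuple[str, bytes]] = {}

    def register(self, block_id: int, payload: bytes) -> None:
        self._blocks[block_id] = (block_hash(payload), payload)

    def unregister(self, block_id: int) -> None:
        self._blocks.pop(block_id, None)

    def transfer(self, block_id: int, expected_hash: str):
        """Return the payload only if the block still exists and matches."""
        entry = self._blocks.get(block_id)
        if entry is None:
            return None  # evicted mid-transfer -> treat as cache miss
        actual_hash, payload = entry
        if actual_hash != expected_hash:
            return None  # replaced by a new block -> treat as cache miss
        return payload
```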

@qinxuye
Contributor

qinxuye commented Jan 10, 2025

How can we produce the corner case?

@codingl2k1
Contributor

How can we produce the corner case?

We can add mock logic to reproduce it. For example, call evict, or modify the block (to simulate block replacement), on the model's block while it is being queried.

@qinxuye
Contributor

qinxuye commented Jan 10, 2025

OK, how about opening a new issue to track this?

@codingl2k1
Contributor

codingl2k1 commented Jan 10, 2025

OK, how about opening a new issue to track this?

Let me open an issue.
