
Multi-process RESTful API #328

Merged: 3 commits, Nov 28, 2023

Conversation

@mrwyattii (Contributor) commented Nov 27, 2023

Resolves #325, #324, #323, #314

Our RESTful API currently processes only a single request at a time. For example, consider running the following script against our current main:

import mii
import json
import subprocess
import time

# Stand up a MII deployment
model = "meta-llama/Llama-2-7b-hf"
client = mii.serve(
    model,
    deployment_name="test-dep",
    tensor_parallel=1,
    enable_restful_api=True,
    restful_api_port=8000,
)

# Define some queries
queries = [
    "Hello world!",
    "My name is",
    "DeepSpeed is",
    "Seattle is",
    "One day",
    "I like to",
    "My favorite food is",
    "The world is",
]

# Run with Python API
gen_tokens = 0
start_time = time.time()
outputs = client.generate(queries, ignore_eos=True, max_length=128)
end_time = time.time()
python_time = end_time - start_time
for output in outputs:
    gen_tokens += output.generated_length


# Run with RESTful API
procs = []
start_time = time.time()
for i in range(len(queries)):
    p = subprocess.Popen(
        [
            "curl",
            "-X",
            "POST",
            "-H",
            "Content-Type: application/json",
            "-d",
            f'{{"prompts": "{queries[i]}", "ignore_eos": true, "max_length": 128}}',
            "http://localhost:8000/mii/test-dep",
        ],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    procs.append(p)

# Check the outputs, verify we have generated text
rest_gen_tokens = 0
for p in procs:
    output, error = p.communicate()
    output = json.loads(output.decode("utf-8"))
    assert "generated_text" in output[0], "No generated text"
    rest_gen_tokens += output[0]["generated_length"]
end_time = time.time()
rest_time = end_time - start_time

assert rest_gen_tokens == gen_tokens, "RESTful API generated different number of tokens"

# Print results
print("Python API Results:")
print(f"\tTotal Time: {python_time:0.2f} seconds")
print(f"\tTotal Generated Tokens: {gen_tokens}")
print(f"\tTokens per second: {gen_tokens/python_time:0.2f}")
print("RESTful API Results:")
print(f"\tTotal Time: {rest_time:0.2f} seconds")
print(f"\tTotal Generated Tokens: {rest_gen_tokens}")
print(f"\tTokens per second: {rest_gen_tokens/rest_time:0.2f}")

client.terminate_server()

We see the following output:

Python API Results:
        Total Time: 3.14 seconds
        Total Generated Tokens: 993
        Tokens per second: 316.61
RESTful API Results:
        Total Time: 21.45 seconds
        Total Generated Tokens: 993
        Tokens per second: 46.29

With this PR, the RESTful API performance matches the Python API:

Python API Results:
        Total Time: 3.13 seconds
        Total Generated Tokens: 993
        Tokens per second: 316.92
RESTful API Results:
        Total Time: 3.19 seconds
        Total Generated Tokens: 993
        Tokens per second: 311.39

We use a default of 32 processes for serving the RESTful API. This can be changed with mii.serve(..., restful_processes=8). In our benchmarks, using more than 32 processes did not improve performance.
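The multi-process dispatch idea can be sketched with the standard library alone. This is not the PR's actual implementation: handle_request, the prompt list, and the worker count below are placeholders standing in for the real request-forwarding logic. A pool of worker processes handles incoming requests concurrently, so one in-flight request no longer serializes the rest of the queue:

```python
from concurrent.futures import ProcessPoolExecutor

def handle_request(prompt: str) -> str:
    # Stand-in for forwarding one prompt to the inference server;
    # here we just echo to illustrate the dispatch pattern.
    return f"generated: {prompt}"

if __name__ == "__main__":
    prompts = ["Hello world!", "My name is", "DeepSpeed is"]
    # Requests are mapped across a pool of worker processes instead of
    # being handled one at a time by a single process.
    with ProcessPoolExecutor(max_workers=4) as pool:
        results = list(pool.map(handle_request, prompts))
    assert len(results) == len(prompts)
```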

This PR also fixes a bug with older versions of Flask, where the object returned by the RESTful API could not be parsed into a Python dict with json.loads.
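The client-side expectation behind that fix can be illustrated with the standard json module alone. The payload below is a made-up example, not MII's actual response schema: whatever the server returns, the body must round-trip cleanly through json.loads on the client.

```python
import json

# Hypothetical response payload; MII's real response schema may differ.
payload = [{"generated_text": "Hello world! I am", "generated_length": 128}]

# Serializing explicitly guarantees the client can parse the body back
# into Python objects with json.loads, regardless of how a given web
# framework version would have stringified the object on its own.
body = json.dumps(payload)
parsed = json.loads(body)
assert parsed[0]["generated_length"] == 128
```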

@mrwyattii mrwyattii changed the title Multi-threaded RESTful API Multi-process RESTful API Nov 27, 2023
@mrwyattii mrwyattii marked this pull request as ready for review November 27, 2023 23:27

Successfully merging this pull request may close these issues.

Low throughput (0.61 reqs/sec) when served with RESTful API