Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix sharing of sampling params in multiple seq groups #33

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

maxdebayser
Copy link
Contributor

The sampling params object cannot be shared between multiple sequence groups even though the configuration is identical because some logits processors are stateful.

  enum ResponseFormat {
    // Plain text, no constraints
    TEXT = 0;
    // Valid json
    JSON = 1;
  }

  message StringChoices {
    repeated string choices = 1;
  }

  // Mutually-exclusive guided decoding options
  oneof guided {
    // Output will be in the specified format
    ResponseFormat format = 3;
    // Output will follow the provided JSON schema
    string json_schema = 4;
    // Output will follow the provided regex pattern
    string regex = 5;
    // Output will be exactly one of the specified choices
    StringChoices choice = 6;
    // Output will follow the provided context free grammar
    string grammar = 7;
  }

Signed-off-by: Nick Hill <nickhill@us.ibm.com>
@maxdebayser maxdebayser requested a review from njhill May 27, 2024 20:05
@njhill
Copy link
Member

njhill commented May 27, 2024

Thanks @maxdebayser for investigating this! I'm not sure that this is the best fix though. I think we'd want to avoid validating the same set of sampling parameters multiple times.

The sampling params are actually also already defensively copied per sequence group inside the engine here, but logits processors are intentionally omitted since they can be arbitrary and the cost of cloning on every request can be prohibitively high (depending on the processor implementation).

I feel like what should be passed in is actually something more like a logit processor factory. Which itself is stateless and can be included in a sampling params object that's reused. Then internally when each seq group is created it would use the factory to create the logits processor.

@maxdebayser
Copy link
Contributor Author

maxdebayser commented May 28, 2024

Thanks for pointing that out, I wasn't aware of that. The repetitive construction of the sampling params will happen anyway in unary requests, but it makes sense not to pay that cost in batch requests.

I went back to issue vllm-project/vllm#3087 to understand the history of the problem and what is kind of weird is that the code snippet of the reporter looks like it should also reproduce the new problem but doesn't.

from vllm import LLM, SamplingParams
from outlines.serve.vllm import JSONLogitsProcessor
from pydantic import BaseModel, conlist
import datetime as dt
import os
MODEL_NAME=os.getenv("MODEL_NAME")

class Output(BaseModel):
    names: conlist(str, max_length=5)
    organizations: conlist(str, max_length=5)
    locations: conlist(str, max_length=5)
    miscellanous: conlist(str, max_length=5)

llm = LLM(MODEL_NAME, max_model_len=2048, gpu_memory_utilization=0.9)
logits_processor = JSONLogitsProcessor(schema=Output, llm=llm.llm_engine)
logits_processor.fsm.vocabulary = list(logits_processor.fsm.vocabulary)
prompt = """
Locate all the names, organizations, locations and other miscellaneous entities in the following sentence: 
"Charles went and saw Anna at the coffee shop Starbucks, which was based in a small town in Germany called Essen."
"""
sampling_params = SamplingParams(max_tokens=128, temperature=0, logits_processors=[logits_processor])

t0 = dt.datetime.now()
llm.generate([prompt] * 4, sampling_params=sampling_params)
time_elapsed = (dt.datetime.now() - t0).total_seconds()
print(f"Generation took {time_elapsed:,} seconds.")

@maxdebayser
Copy link
Contributor Author

Actually the JSONLogitsProcessor in the code above uses an entirely different code path. What is causing the error in our case is the CFGLogitsProcessor in vllm. It has a CFGFSM that rebuilds internal RegexFSMs based on the string it has seen so far.

This allows the implementation of stateful logit processors.
By passing the sequence ID, the decision on how to manage the state
is the responsibility of the author of the processor.
Without a sequence ID we would have to defensively create a copy
of the processor for each sequence just to safeguard against
accidental state sharing.

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
@maxdebayser maxdebayser force-pushed the guided_decoding_concurrency_grpc branch from 00c830f to 2758017 Compare May 28, 2024 18:54
@maxdebayser
Copy link
Contributor Author

@njhill, here goes another fix idea: instead of writing a generic code that instantiates processors for each sequence just in case there might be a stateful one, we can make this state explicit with an extra sequence id argument to the logits processors. In this way the code stays as optimal as before and the author of a logits processor can decide how to optimize for his particular case.

I've overwritten my previous commit in this branch with this alternative implementation.

Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
@njhill
Copy link
Member

njhill commented May 29, 2024

Thanks @maxdebayser I'm not sure about changing the logitsprocessor interface for this. I think there may be other reasons that we may want to change it, like passing token ids and/or working on batches. I don't think that individual logits processor impls should have to manage a dict like this.

instead of writing a generic code that instantiates processors for each sequence just in case there might be a stateful one

Actually I was suggesting that we add logit processor factories as an additional samplingparams field, so they can be passed as alternative to existing logitsprocessors. This can be an abstract class with a method e.g. get_processor(), In the stateless case the factory can just hold and return a constant logitsprocessor. This may be better than a callable since we could later add a method return_proceessor() to allow for pooling.

Base automatically changed from guided to main May 30, 2024 00:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants