Fix sharing of sampling params in multiple seq groups #33
base: main
Conversation
```proto
enum ResponseFormat {
  // Plain text, no constraints
  TEXT = 0;
  // Valid json
  JSON = 1;
}

message StringChoices {
  repeated string choices = 1;
}

// Mutually-exclusive guided decoding options
oneof guided {
  // Output will be in the specified format
  ResponseFormat format = 3;
  // Output will follow the provided JSON schema
  string json_schema = 4;
  // Output will follow the provided regex pattern
  string regex = 5;
  // Output will be exactly one of the specified choices
  StringChoices choice = 6;
  // Output will follow the provided context free grammar
  string grammar = 7;
}
```

Signed-off-by: Nick Hill <nickhill@us.ibm.com>
Thanks @maxdebayser for investigating this! I'm not sure that this is the best fix though. I think we'd want to avoid validating the same set of sampling parameters multiple times. The sampling params are actually already defensively copied per sequence group inside the engine, but logits processors are intentionally omitted from that copy since they can be arbitrary and the cost of cloning on every request can be prohibitively high (depending on the processor implementation). I feel like what should be passed in is something more like a logits processor factory, which is itself stateless and can be included in a sampling params object that's reused. Then, internally, when each seq group is created it would use the factory to create the logits processor.
Thanks for pointing that out, I wasn't aware of that. The repetitive construction of the sampling params will happen anyway in unary requests, but it makes sense not to pay that cost in batch requests. I went back to issue vllm-project/vllm#3087 to understand the history of the problem, and what is kind of weird is that the reporter's code snippet looks like it should also reproduce the new problem, but doesn't:

```python
from vllm import LLM, SamplingParams
from outlines.serve.vllm import JSONLogitsProcessor
from pydantic import BaseModel, conlist
import datetime as dt
import os

MODEL_NAME = os.getenv("MODEL_NAME")

class Output(BaseModel):
    names: conlist(str, max_length=5)
    organizations: conlist(str, max_length=5)
    locations: conlist(str, max_length=5)
    miscellanous: conlist(str, max_length=5)

llm = LLM(MODEL_NAME, max_model_len=2048, gpu_memory_utilization=0.9)
logits_processor = JSONLogitsProcessor(schema=Output, llm=llm.llm_engine)
logits_processor.fsm.vocabulary = list(logits_processor.fsm.vocabulary)

prompt = """
Locate all the names, organizations, locations and other miscellaneous entities in the following sentence:
"Charles went and saw Anna at the coffee shop Starbucks, which was based in a small town in Germany called Essen."
"""

sampling_params = SamplingParams(max_tokens=128, temperature=0, logits_processors=[logits_processor])

t0 = dt.datetime.now()
llm.generate([prompt] * 4, sampling_params=sampling_params)
time_elapsed = (dt.datetime.now() - t0).total_seconds()
print(f"Generation took {time_elapsed:,} seconds.")
```
Actually, the JSONLogitsProcessor in the code above uses an entirely different code path. What is causing the error in our case is the CFGLogitsProcessor in vLLM. It has a CFGFSM that rebuilds internal RegexFSMs based on the string it has seen so far.
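To make the hazard concrete, here is a minimal sketch (not vLLM code; `StatefulProcessor` and its bookkeeping are hypothetical) of a processor that, like the CFGFSM, advances internal state on every call, and so breaks as soon as one instance is shared by two sequences:

```python
class StatefulProcessor:
    """Hypothetical processor that assumes it serves exactly ONE sequence."""

    def __init__(self):
        self.tokens_seen = []  # state advanced once per decoding step

    def __call__(self, token_ids, logits):
        # Append only the tokens not seen yet -- correct for a single
        # sequence, wrong as soon as two sequences share the instance.
        self.tokens_seen.extend(token_ids[len(self.tokens_seen):])
        return logits

proc = StatefulProcessor()
proc([1, 2], None)   # decoding step for sequence A
proc([9, 8], None)   # sequence B reuses the same instance...
# ...and B's tokens are silently dropped, because the shared state
# claims two tokens were already consumed.
print(proc.tokens_seen)   # prints [1, 2]
```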
This allows the implementation of stateful logits processors. By passing the sequence ID, we make the decision of how to manage state the responsibility of the processor's author. Without a sequence ID, we would have to defensively create a copy of the processor for each sequence just to safeguard against accidental state sharing. Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Force-pushed the branch from 00c830f to 2758017
@njhill, here's another fix idea: instead of writing generic code that instantiates processors for each sequence just in case there might be a stateful one, we can make this state explicit with an extra sequence ID argument to the logits processors. This way the code stays as optimal as before, and the author of a logits processor can decide how to optimize for their particular case. I've overwritten my previous commit in this branch with this alternative implementation.
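A rough sketch of that idea (the names `PerSequenceProcessor` and the `seq_id` parameter are hypothetical, not the actual patch), keying internal state by sequence ID so one instance can be shared safely:

```python
class PerSequenceProcessor:
    """Hypothetical processor that keeps one state entry per sequence."""

    def __init__(self):
        self._states = {}  # seq_id -> tokens seen so far

    def __call__(self, seq_id, token_ids, logits):
        # Each sequence advances its own state; the instance itself
        # can be shared freely across sequence groups.
        state = self._states.setdefault(seq_id, [])
        state.extend(token_ids[len(state):])
        # ...here `state` would drive the FSM that masks `logits`...
        return logits

proc = PerSequenceProcessor()
proc(0, [1, 2], None)
proc(1, [9, 8], None)   # no cross-talk with sequence 0
print(proc._states)     # prints {0: [1, 2], 1: [9, 8]}
```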
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Thanks @maxdebayser. I'm not sure about changing the logits processor interface for this. I think there may be other reasons we'd want to change it, like passing token ids and/or operating on batches, and I don't think that individual logits processor implementations should have to manage a dict like this.
Actually, I was suggesting that we add logits processor factories as an additional SamplingParams field, so they can be passed as an alternative to the existing logits processors. This can be an abstract class with a method, e.g. get_processor(). In the stateless case the factory can just hold and return a constant logits processor. This may be better than a callable since we could later add a method return_processor() to allow for pooling.
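As a sketch of that suggestion (all class and method names here are hypothetical, not an existing vLLM API), the factory itself is stateless and safe to embed in a shared sampling params object, while the engine decides per sequence group whether it gets a shared or a fresh processor:

```python
from abc import ABC, abstractmethod

class LogitsProcessorFactory(ABC):
    """Hypothetical abstract factory held in a shared SamplingParams."""

    @abstractmethod
    def get_processor(self):
        """Called by the engine once per sequence group."""

    def return_processor(self, processor):
        """Optional hook for pooling/reuse; default is a no-op."""

class ConstantFactory(LogitsProcessorFactory):
    """Stateless case: hold and return one shared processor."""
    def __init__(self, processor):
        self._processor = processor
    def get_processor(self):
        return self._processor

class FreshFactory(LogitsProcessorFactory):
    """Stateful case: build a new processor per sequence group."""
    def __init__(self, make):
        self._make = make
    def get_processor(self):
        return self._make()

shared = ConstantFactory(lambda logits: logits)
assert shared.get_processor() is shared.get_processor()

fresh = FreshFactory(lambda: (lambda logits: logits))
assert fresh.get_processor() is not fresh.get_processor()
```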
The sampling params object cannot be shared between multiple sequence groups, even though the configuration is identical, because some logits processors are stateful.