The following code breaks:
```python
import torch
import transformers
from transformers import GenerationConfig
from transformers import AutoConfig


def generate_inputs_for_model(model_cls, model):
    eval_context = torch.randint(0, model.config.vocab_size, (4, 2048)).to("cuda")
    return {"input_ids": eval_context}


config = AutoConfig.from_pretrained("t5-small")
model_cls = getattr(transformers, "AutoModelForSeq2SeqLM")
model = model_cls.from_config(config).to("cuda")
example_inputs = generate_inputs_for_model(model_cls, model)
example_inputs = (example_inputs["input_ids"],)

generation_config = GenerationConfig(
    max_new_tokens=256,
    pad_token_id=0,
    eos_token_id=None,
    do_sample=False,
    num_beams=1,
    use_cache=True,
)


class GenerationWrapper(torch.nn.Module):
    def __init__(self, model, generation_config):
        super().__init__()
        self.model = model
        self.generation_config = generation_config

    def forward(self, inputs):
        return self.model.generate(inputs, self.generation_config)


model = GenerationWrapper(model, generation_config)

# torch.compile repro
model_opt = torch.compile(model)
output = model_opt(*example_inputs)

# torch.export repro
torch.export.export(model, args=example_inputs, strict=False)
```
With the following error:

```
ValueError: `decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation.
```
If I manually add `decoder_start_token_id=0` to the `GenerationConfig`, both compile and export work, although very slowly.
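For illustration, here is a hypothetical, self-contained sketch (not the actual `transformers` implementation; the helper name `resolve_decoder_start` is made up) of the fallback logic implied by the error above: encoder-decoder generation needs a first decoder token, taken from `decoder_start_token_id`, else `bos_token_id`, else it raises.

```python
# Hypothetical sketch of how an encoder-decoder generate() could pick its
# first decoder token: prefer decoder_start_token_id, fall back to
# bos_token_id, and raise when neither is set (the error in this issue).
def resolve_decoder_start(decoder_start_token_id=None, bos_token_id=None):
    if decoder_start_token_id is not None:
        return decoder_start_token_id
    if bos_token_id is not None:
        return bos_token_id
    raise ValueError(
        "`decoder_start_token_id` or `bos_token_id` has to be defined "
        "for encoder-decoder generation."
    )
```

Setting `decoder_start_token_id=0` in the config corresponds to the first branch, which is why the workaround unblocks both compile and export.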
Expected `generate` to work as before, without manually specifying `decoder_start_token_id` or `bos_token_id` in the `GenerationConfig`.
Thanks for the issue! cc @ArthurZucker
#33221 seems like it is required quite a lot. cc @gante let's fix the `generate` issues!
cc @zucchini-nlp, who will be converting BART and T5 to be compile-compatible (using EncoderDecoderCache, like we did on Whisper)
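To illustrate the idea behind the `EncoderDecoderCache` approach mentioned above, here is a toy sketch (not the `transformers` API; class and attribute names are invented): cross-attention key/values come from the encoder output and are computed once, while self-attention key/values grow by one entry per decoding step. Keeping both in one structured object gives the compiler a stable cache layout across steps.

```python
# Toy illustration (not the transformers implementation) of an
# encoder-decoder cache:
#   - cross_kv: encoder-derived K/V, fixed for the whole generation
#   - self_kv:  decoder self-attention K/V, appended to at every step
class ToyEncoderDecoderCache:
    def __init__(self, cross_kv):
        self.cross_kv = cross_kv  # computed once from the encoder output
        self.self_kv = []         # grows by one entry per decoder step

    def step(self, new_kv):
        """Record one decoding step's self-attention K/V."""
        self.self_kv.append(new_kv)
        return self.cross_kv, self.self_kv
```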
System Info

transformers version: 4.44.2

Who can help?

No response