Pipeline for Text Generation: GenerationPipeline #3758
Conversation
@enzoampil - Sorry for fiddling with your code so much :D
Maybe we should also add an optional …
OK, added it to the config of Transfo-XL and XLNet. @LysandreJik @thomwolf, we might also want to discuss the default generation params for each model. I think it might, e.g., be better to set …
I don't have any strong opinions on whether we should sample or not; however, whatever the choice, I think we should make it explicit in the pipeline documentation that this can be controlled from the pipeline directly. Maybe a link to the …
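For concreteness, a minimal sketch of controlling sampling directly from the pipeline call, assuming generation keyword arguments are forwarded to `PreTrainedModel.generate()` as discussed above (the alias `text-generation` and the GPT-2 default come from later in this PR; the parameter values here are illustrative, not the defaults under discussion):

```python
from transformers import pipeline

# Uses the default model for the task (GPT-2, per this PR)
generator = pipeline("text-generation")

# Generation parameters are forwarded to PreTrainedModel.generate(),
# so sampling behaviour can be set explicitly per call.
output = generator(
    "The meaning of life is",
    do_sample=True,  # sample instead of greedy decoding
    top_k=50,        # illustrative: sample from the 50 most likely tokens
    max_length=40,   # illustrative: total length of prompt + generated tokens
)
print(output)  # e.g. [{'generated_text': '...'}]
```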
Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>
@patrickvonplaten Ran … Also, not sure if this is specific to this PR, but there are tests that are suddenly returning an error for the lines that contain … Sample error: …
Those tests are probably failing because a new PyTorch version was released. Can you just rebase your branch on master?:
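The commands from the original comment were stripped, but the suggested rebase flow would presumably look something like this (the branch name `generation_pipeline` comes from this PR; the remote name `upstream` is the assumption stated in the next sentence):

```bash
# Fetch the latest master from the main repository and rebase onto it
git fetch upstream
git rebase upstream/master

# Force-push the rebased branch to the PR fork
git push --force-with-lease origin generation_pipeline
```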
(Assuming that you added the master branch as a remote branch "upstream".) The tests should then pass :-)
@patrickvonplaten Apologies, I'm having issues with the rebase suggested above. I initially tried it, but I ended up showing up as a co-committer on the rebased commits, which explains why I performed a … instead. May I please ask for some assistance / advice with this?
Clean implementation, congrats @enzoampil and thanks @patrickvonplaten for the great review
* Add GenerationPipeline
* Fix parameter names
* Correct `__call__` parameters
* Add model type attribute and correct function calls for `prepare_input`
* Take out trailing commas from init attributes
* Remove unnecessary tokenization line
* Implement support for multiple text inputs
* Apply generation support for multiple input text prompts
* Take out tensor coercion
* Take out batch index
* Add text prompt to return sequence
* Squeeze token tensor before decoding
* Return only a single list of sequences if only one prompt was used
* Correct results variable name
* Add GenerationPipeline to SUPPORTED_TASKS with the alias …, initialized with GPT2
* Register AutoModelWithLMHead for both pt and tf
* Update docstring for GenerationPipeline
* Add kwargs parameter to `model.generate`
* Take out kwargs parameter after all
* Add generation pipeline example in pipeline docstring
* Fix max length by squeezing tokens tensor
* Apply `ensure_tensor_on_device` to PyTorch tensor
* Include generation step in `torch.no_grad`
* Take out input from `prepare_xlm_input` and set 'en' as default `xlm_language`
* Apply framework-specific encoding during `prepare_input`
* Format with `make style`
* Move GenerationPipeline import to follow proper import sorting
* Take out trailing comma from generation dict
* Apply requested changes
* Change name to TextGenerationPipeline
* Apply TextGenerationPipeline rename to `__init__`
* Change alias to …
* Set input mapping as input to `ensure_tensor_on_device`
* Fix assertion placement
* Add `test_text_generation`
* Add TextGenerationPipeline to PipelineCommonTests
* Take out whitespace
* Format `__init__` with black
* Fix `__init__` style
* Format `__init__`
* Add line to end of `__init__`
* Correct model tokenizer set for `test_text_generation`
* Ensure to return a list of lists, not a list of strings (to pass test)
* Limit test models to only 3 to limit runtime and address the CircleCI timeout error
* Update src/transformers/pipelines.py (Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>)
* Update src/transformers/pipelines.py (Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>)
* Update src/transformers/pipelines.py (Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>)
* Update src/transformers/pipelines.py (Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>)
* Update src/transformers/pipelines.py (Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>)
* Update tests/test_pipelines.py (Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>)
* Update src/transformers/pipelines.py (Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>)
* Update src/transformers/pipelines.py (Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>)
* Update src/transformers/pipelines.py (Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>)
* Remove argument docstring, `__init__`, add additional `__call__` arguments, and reformat results to list of dict
* Fix blank result list
* Add TextGenerationPipeline to pipelines.rst
* Update src/transformers/pipelines.py (Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>)
* Update src/transformers/pipelines.py (Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>)
* Fix typos from adding PADDING_TEXT_TOKEN_LENGTH
* Fix incorrectly moved result list
* Update src/transformers/pipelines.py (Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>)
* Update src/transformers/pipelines.py
* Update src/transformers/pipelines.py
* Update src/transformers/pipelines.py
* Update src/transformers/pipelines.py
* Update src/transformers/pipelines.py
* Update src/transformers/pipelines.py
* Update src/transformers/pipelines.py
* Update src/transformers/pipelines.py
* Update src/transformers/pipelines.py
* Update src/transformers/pipelines.py
* Update src/transformers/pipelines.py
* Update src/transformers/pipelines.py (Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>)
* Add back generation line and make style
* Take out blank whitespace
* Apply new alias, text-generation, to test_pipelines
* Fix text generation alias in test
* Update src/transformers/pipelines.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Julien Chaumond <chaumond@gmail.com>
Once again, thanks so much! Looking forward to contributing more in the future 😄 @patrickvonplaten @julien-c
This PR implements a text generation pipeline, `GenerationPipeline`, which works on any `ModelWithLMHead`, and resolves issue #3728.

This pipeline predicts the words that will follow a specified text prompt for autoregressive language models. I've registered it to the `pipeline` function using `gpt2` as the default `model_type`.

The implementation is based on the approach taken in run_generation.py, which means the forward pass uses the `PreTrainedModel.generate()` method in modeling_utils.py, as recommended to me by @julien-c and @patrickvonplaten.

Sample code:
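The sample code block itself was lost, so here is a minimal sketch of what invoking the pipeline looks like, assuming the `text-generation` alias and GPT-2 default described in this PR:

```python
from transformers import pipeline

# Instantiate the pipeline with its default model and tokenizer (GPT-2)
text_generator = pipeline("text-generation")

# Generate a continuation for a single prompt; the pipeline returns a
# list of dicts, one per prompt (multiple prompts are also supported)
result = text_generator("The quick brown fox", max_length=50)
print(result)  # e.g. [{'generated_text': 'The quick brown fox ...'}]
```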
Google Colab tutorial here for running `GenerationPipeline` for the following LM models: …
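The list itself is elided above, but since XLNet and Transfo-XL are discussed earlier in this thread alongside the default GPT-2, here is a hedged sketch of pointing the pipeline at a specific checkpoint (`xlnet-base-cased` is an illustrative choice, not necessarily one from the notebook):

```python
from transformers import pipeline

# Any autoregressive model with an LM head should work; XLNet's generation
# defaults (and its padding text) are discussed earlier in this thread.
xlnet_generator = pipeline("text-generation", model="xlnet-base-cased")

print(xlnet_generator("Today the weather is", max_length=30))
```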
For context, I also plan to use the above `GenerationPipeline` for my Humor Generation Bot (issue).

I'm very keen to get feedback on the above, so please let me know if I should change anything, or perform additional steps to bring its quality to an acceptable level.