Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Pipeline for Text Generation: GenerationPipeline (huggingface#3758)
* Add GenerationPipeline * Fix parameter names * Correct parameter __call__ parameters * Add model type attribute and correct function calls for prepare_input * Take out trailing commas from init attributes * Remove unnecessary tokenization line * Implement support for multiple text inputs * Apply generation support for multiple input text prompts * Take out tensor coersion * Take out batch index * Add text prompt to return sequence * Squeeze token tensore before decoding * Return only a single list of sequences if only one prompt was used * Correct results variable name * Add GenerationPipeline to SUPPORTED_TASKS with the alias , initalized w GPT2 * Registedred AutoModelWithLMHead for both pt and t * Update docstring for GenerationPipeline * Add kwargs parameter to mode.generate * Take out kwargs parameter after all * Add generation pipeline example in pipeline docstring * Fix max length by squeezing tokens tensor * Apply ensure_tensor_on_device to pytorch tensor * Include generation step in torch.no_grad * Take out input from prepare_xlm_input and set 'en' as default xlm_language * Apply framework specific encoding during prepare_input * Format w make style * Move GenerationPipeline import to follow proper import sorting * Take out training comma from generation dict * Apply requested changes * Change name to TextGenerationPipeline * Apply TextGenerationPipeline rename to __init___ * Changing alias to * Set input mapping as input to ensure_tensor_on_device * Fix assertion placement * Add test_text_generation * Add TextGenerationPipeline to PipelineCommonTests * Take out whitespace * Format __init__ w black * Fix __init__ style * Forman __init___ * Add line to end of __init__ * Correct model tokenizer set for test_text_generation * Ensure to return list of list, not list of string (to pass test) * Limit test models to only 3 to limit runtime to address circleCI timeout error * Update src/transformers/pipelines.py Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/pipelines.py Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/pipelines.py Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/pipelines.py Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/pipelines.py Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com> * Update tests/test_pipelines.py Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/pipelines.py Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/pipelines.py Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/pipelines.py Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com> * Remove argument docstring, __init__, add additional __call__ arguments, and reformat results to list of dict * Fix blank result list * Add TextGenerationPipeline to pipelines.rst * Update src/transformers/pipelines.py Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/pipelines.py Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com> * Fix typos from adding PADDING_TEXT_TOKEN_LENGTH * Fix incorrectly moved result list * Update src/transformers/pipelines.py Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/transformers/pipelines.py * Update src/transformers/pipelines.py * Update src/transformers/pipelines.py * Update src/transformers/pipelines.py * Update src/transformers/pipelines.py * Update src/transformers/pipelines.py * Update src/transformers/pipelines.py * Update src/transformers/pipelines.py * Update src/transformers/pipelines.py * Update src/transformers/pipelines.py * Update src/transformers/pipelines.py * Update src/transformers/pipelines.py Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com> * Add back generation line and make style * Take out blank whitespace * Apply new alis, text-generation, to test_pipelines * Fix text generation alias in test * Update src/transformers/pipelines.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Julien Chaumond <chaumond@gmail.com>
- Loading branch information