Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pipeline for Text Generation: GenerationPipeline #3758

Merged
merged 82 commits into from
Apr 22, 2020
Merged
Show file tree
Hide file tree
Changes from 59 commits
Commits
Show all changes
82 commits
Select commit Hold shift + click to select a range
55526e4
Add GenerationPipeline
enzoampil Apr 11, 2020
b80a70b
Fix parameter names
enzoampil Apr 11, 2020
b0c63ef
Correct parameter __call__ parameters
enzoampil Apr 11, 2020
e4b2921
Add model type attribute and correct function calls for prepare_input
enzoampil Apr 11, 2020
c012658
Take out trailing commas from init attributes
enzoampil Apr 11, 2020
11caefb
Remove unnecessary tokenization line
enzoampil Apr 11, 2020
3fea3c5
Implement support for multiple text inputs
enzoampil Apr 11, 2020
c947d37
Apply generation support for multiple input text prompts
enzoampil Apr 11, 2020
1a7130e
Take out tensor coercion
enzoampil Apr 11, 2020
bc7b8ac
Take out batch index
enzoampil Apr 11, 2020
84f51f9
Add text prompt to return sequence
enzoampil Apr 11, 2020
0ab3cc9
Squeeze token tensor before decoding
enzoampil Apr 11, 2020
0d62796
Return only a single list of sequences if only one prompt was used
enzoampil Apr 11, 2020
caff86e
Correct results variable name
enzoampil Apr 11, 2020
67aa288
Add GenerationPipeline to SUPPORTED_TASKS with the alias , initalized…
enzoampil Apr 11, 2020
9fff3ad
Registered AutoModelWithLMHead for both pt and tf
enzoampil Apr 11, 2020
d126b01
Update docstring for GenerationPipeline
enzoampil Apr 11, 2020
b2b22c5
Add kwargs parameter to mode.generate
enzoampil Apr 11, 2020
6ffbcd5
Take out kwargs parameter after all
enzoampil Apr 11, 2020
7dc8d1e
Add generation pipeline example in pipeline docstring
enzoampil Apr 11, 2020
d5a8ca2
Fix max length by squeezing tokens tensor
enzoampil Apr 12, 2020
7839811
Apply ensure_tensor_on_device to pytorch tensor
enzoampil Apr 12, 2020
6d7ed48
Include generation step in torch.no_grad
enzoampil Apr 12, 2020
c649625
Take out input from prepare_xlm_input and set 'en' as default xlm_lan…
enzoampil Apr 12, 2020
fa957b2
Apply framework specific encoding during prepare_input
enzoampil Apr 12, 2020
644f6b7
Merge branch 'master' into generation_pipeline
enzoampil Apr 12, 2020
9034314
Format w make style
enzoampil Apr 12, 2020
d15878f
Merge branch 'generation_pipeline' of https://github.com/enzoampil/tr…
enzoampil Apr 12, 2020
fcb4644
Move GenerationPipeline import to follow proper import sorting
enzoampil Apr 12, 2020
7146869
Take out trailing comma from generation dict
enzoampil Apr 12, 2020
38d7935
Apply requested changes
enzoampil Apr 18, 2020
0378fb6
Change name to TextGenerationPipeline
enzoampil Apr 18, 2020
1eec760
Apply TextGenerationPipeline rename to __init___
enzoampil Apr 18, 2020
59554b7
Changing alias to
enzoampil Apr 18, 2020
1ac7eb5
Set input mapping as input to ensure_tensor_on_device
enzoampil Apr 18, 2020
cb42392
Fix assertion placement
enzoampil Apr 18, 2020
9c03121
Add test_text_generation
enzoampil Apr 18, 2020
e98b9ae
Add TextGenerationPipeline to PipelineCommonTests
enzoampil Apr 18, 2020
89b191c
Take out whitespace
enzoampil Apr 18, 2020
51d1748
Format __init__ w black
enzoampil Apr 18, 2020
900474d
Merge branch 'master' into generation_pipeline
enzoampil Apr 18, 2020
1148e77
Fix __init__ style
enzoampil Apr 18, 2020
5388473
Merge branch 'generation_pipeline' of https://github.com/enzoampil/tr…
enzoampil Apr 18, 2020
f75ed8a
Format __init__
enzoampil Apr 18, 2020
215e429
Add line to end of __init__
enzoampil Apr 18, 2020
be15f38
Correct model tokenizer set for test_text_generation
enzoampil Apr 18, 2020
b7505cc
Ensure to return list of list, not list of string (to pass test)
enzoampil Apr 18, 2020
9180e3f
Limit test models to only 3 to limit runtime to address circleCI time…
enzoampil Apr 18, 2020
547cf8a
Update src/transformers/pipelines.py
enzoampil Apr 20, 2020
5c31869
Update src/transformers/pipelines.py
enzoampil Apr 20, 2020
9bb048a
Update src/transformers/pipelines.py
enzoampil Apr 20, 2020
e73e1e5
Update src/transformers/pipelines.py
enzoampil Apr 20, 2020
e921bb3
Update src/transformers/pipelines.py
enzoampil Apr 20, 2020
013e448
Update tests/test_pipelines.py
enzoampil Apr 20, 2020
4a66740
Update src/transformers/pipelines.py
enzoampil Apr 20, 2020
8af21d4
Update src/transformers/pipelines.py
enzoampil Apr 20, 2020
76fc370
Update src/transformers/pipelines.py
enzoampil Apr 20, 2020
13b3b49
Remove argument docstring, __init__, add additional __call__ argument…
enzoampil Apr 21, 2020
f43fe34
Fix blank result list
enzoampil Apr 21, 2020
99b7a7e
Add TextGenerationPipeline to pipelines.rst
enzoampil Apr 21, 2020
ddf76b1
Update src/transformers/pipelines.py
enzoampil Apr 21, 2020
dfc548f
Update src/transformers/pipelines.py
enzoampil Apr 21, 2020
ff4ff2b
Fix typos from adding PADDING_TEXT_TOKEN_LENGTH
enzoampil Apr 21, 2020
182ff34
Fix incorrectly moved result list
enzoampil Apr 21, 2020
29ce6d8
Update src/transformers/pipelines.py
enzoampil Apr 21, 2020
8ed1aa2
Update src/transformers/pipelines.py
patrickvonplaten Apr 21, 2020
65912f2
Update src/transformers/pipelines.py
patrickvonplaten Apr 21, 2020
755ba43
Update src/transformers/pipelines.py
patrickvonplaten Apr 21, 2020
a390257
Update src/transformers/pipelines.py
patrickvonplaten Apr 21, 2020
c8bab99
Update src/transformers/pipelines.py
patrickvonplaten Apr 21, 2020
5f3cd78
Update src/transformers/pipelines.py
patrickvonplaten Apr 21, 2020
b24c091
Update src/transformers/pipelines.py
patrickvonplaten Apr 21, 2020
de369b1
Update src/transformers/pipelines.py
patrickvonplaten Apr 21, 2020
5aad433
Update src/transformers/pipelines.py
patrickvonplaten Apr 21, 2020
83be569
Update src/transformers/pipelines.py
patrickvonplaten Apr 21, 2020
8877fa3
Update src/transformers/pipelines.py
patrickvonplaten Apr 21, 2020
4360886
Update src/transformers/pipelines.py
enzoampil Apr 21, 2020
5cdb831
Add back generation line and make style
enzoampil Apr 22, 2020
adf7c94
Take out blank whitespace
enzoampil Apr 22, 2020
11b1d8b
Apply new alias, text-generation, to test_pipelines
enzoampil Apr 22, 2020
8f8c79b
Fix text generation alias in test
enzoampil Apr 22, 2020
38a0d1e
Update src/transformers/pipelines.py
julien-c Apr 22, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@
QuestionAnsweringPipeline,
SummarizationPipeline,
TextClassificationPipeline,
TextGenerationPipeline,
TokenClassificationPipeline,
TranslationPipeline,
pipeline,
Expand Down
89 changes: 89 additions & 0 deletions src/transformers/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,89 @@ def __call__(self, *args, **kwargs):
return super().__call__(*args, **kwargs).tolist()


class TextGenerationPipeline(Pipeline):
    """
    Language generation pipeline using any ModelWithLMHead head. This pipeline predicts the words that will follow a specified text prompt.

    This language generation pipeline can currently be loaded from the :func:`~transformers.pipeline` method using
    the following task identifier(s):

    - "text_generation", for generating text from a specified prompt.

    The models that this pipeline can use are models that have been trained with an autoregressive language modeling objective,
    which includes the uni-directional models in the library (e.g. gpt2).
    See the list of available community models on
    `huggingface.co/models <https://huggingface.co/models?search=&filter=lm-head>`__.
    """

    # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
    # in https://github.com/rusiaaman/XLNet-gen#methodology
    # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
    PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
remainder of the story. 1883 Western Siberia,
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
Rasputin has a vision and denounces one of the men as a horse thief. Although his
father initially slaps him for making such an accusation, Rasputin watches as the
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""

    def __call__(
        self, *texts, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
    ):
        """
        Generate continuations for one or several text prompts.

        Args:
            *texts: one prompt (str) or several prompts (list of str) to complete.
            return_tensors (bool): if True, include the generated token ids under
                the "generated_token_ids" key of each record.
            return_text (bool): if True, include the decoded string under the
                "generated_text" key of each record.
            clean_up_tokenization_spaces (bool): forwarded to ``tokenizer.decode``.
            **generate_kwargs: forwarded verbatim to ``self.model.generate``.

        Returns:
            A list of dicts (one per generated sequence) when a single prompt is
            given, otherwise a list of such lists, one per prompt.
        """
        text_inputs = self._args_parser(*texts)

        results = []
        for prompt_text in text_inputs:
            # Manage correct placement of the tensors
            with self.device_placement():
                # XLNet and Transfo-XL generate poorly from short prompts, so a long
                # padding text is prepended (see PADDING_TEXT comment above).
                if self.model.__class__.__name__ in ["XLNetLMHeadModel", "TransfoXLLMHeadModel"]:
                    prompt_text = self.PADDING_TEXT + prompt_text
                inputs = self._parse_and_tokenize(prompt_text)

                if self.framework == "pt":
                    inputs = self.ensure_tensor_on_device(**inputs)

                input_ids = inputs["input_ids"]

                # Ensure that batch size = 1 (batch generation not allowed for now)
                assert (
                    input_ids.shape[0] == 1
                ), "Batch generation is currently not supported. See https://github.com/huggingface/transformers/issues/3021 for more information."
                output_sequences = self.model.generate(input_ids=input_ids, **generate_kwargs)  # BS x SL

            result = []
            for generated_sequence in output_sequences:
                generated_sequence = generated_sequence.tolist()
                record = {}
                if return_tensors:
                    record["generated_token_ids"] = generated_sequence
                if return_text:
                    # Decode text
                    text = self.tokenizer.decode(
                        generated_sequence,
                        skip_special_tokens=True,
                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                    )

                    # Remove PADDING prompt of the sequence if XLNet or Transfo-XL model is used
                    # NOTE(review): slicing by len(PADDING_TEXT) assumes decode() reproduces
                    # the padding text byte-for-byte; tokenization round-trips can shift this
                    # offset and leave (or eat) characters — verify against decoded output.
                    if self.model.__class__.__name__ in ["XLNetLMHeadModel", "TransfoXLLMHeadModel"]:
                        text = text[len(self.PADDING_TEXT) :]

                    record["generated_text"] = text

                result.append(record)
            results += [result]

        # Unwrap the outer list when a single prompt was provided.
        if len(results) == 1:
            return results[0]

        return results


class TextClassificationPipeline(Pipeline):
"""
Text classification pipeline using ModelForSequenceClassification head. See the
Expand Down Expand Up @@ -1459,6 +1542,12 @@ def __call__(
"tokenizer": ("t5-base", {"use_fast": False}),
},
},
"text_generation": {
enzoampil marked this conversation as resolved.
Show resolved Hide resolved
"impl": TextGenerationPipeline,
"tf": TFAutoModelWithLMHead if is_tf_available() else None,
"pt": AutoModelWithLMHead if is_torch_available() else None,
"default": {"model": {"pt": "gpt2", "tf": "gpt2"}, "config": None, "tokenizer": "gpt2"},
enzoampil marked this conversation as resolved.
Show resolved Hide resolved
},
}


Expand Down
16 changes: 16 additions & 0 deletions tests/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@
)
}

# (model, tokenizer) identifier pairs exercised by test_text_generation.
# Declared as a list (not a set) so iteration order is deterministic across
# runs and the style matches FILL_MASK_FINETUNED_MODELS in this file.
TEXT_GENERATION_FINETUNED_MODELS = [
    ("gpt2", "gpt2"),
    ("xlnet-base-cased", "xlnet-base-cased"),
]

FILL_MASK_FINETUNED_MODELS = [
(("distilroberta-base", {"use_fast": False}), "distilroberta-base", None),
]
Expand Down Expand Up @@ -293,6 +298,16 @@ def test_tf_translation(self):
nlp, valid_inputs, invalid_inputs, mandatory_keys,
)

@require_torch
def test_text_generation(self):
    """Smoke-test the text_generation task on a small set of PyTorch models."""
    # Both a bare string and a list of strings are accepted inputs.
    valid_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
    invalid_inputs = [None]
    for model_name, tokenizer_name in TEXT_GENERATION_FINETUNED_MODELS:
        nlp = pipeline(task="text_generation", model=model_name, tokenizer=tokenizer_name, framework="pt")
        # No mandatory output keys: generation records vary with the call flags.
        self._test_mono_column_pipeline(nlp, valid_inputs, invalid_inputs, {})


class MultiColumnInputTestCase(unittest.TestCase):
def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):
Expand Down Expand Up @@ -371,6 +386,7 @@ class PipelineCommonTests(unittest.TestCase):
"translation_en_to_fr",
"translation_en_to_de",
"translation_en_to_ro",
"text_generation",
enzoampil marked this conversation as resolved.
Show resolved Hide resolved
)

@slow
Expand Down