diff --git a/examples/paraphraser/README.md b/examples/paraphraser/README.md
new file mode 100644
index 0000000000..3810311f30
--- /dev/null
+++ b/examples/paraphraser/README.md
@@ -0,0 +1,46 @@
+# Paraphrasing with round-trip translation and mixture of experts
+
+Machine translation models can be used to paraphrase text by translating it to
+an intermediate language and back (round-trip translation).
+
+This example shows how to paraphrase text by first passing it to an
+English-French translation model, followed by a French-English [mixture of
+experts translation model](/examples/translation_moe).
+
+##### 0. Setup
+
+Clone fairseq from source and install necessary dependencies:
+```bash
+git clone https://github.com/pytorch/fairseq.git
+cd fairseq
+pip install --editable .
+pip install sacremoses sentencepiece
+```
+
+##### 1. Download models
+```bash
+wget https://dl.fbaipublicfiles.com/fairseq/models/paraphraser.en-fr.tar.gz
+wget https://dl.fbaipublicfiles.com/fairseq/models/paraphraser.fr-en.hMoEup.tar.gz
+tar -xzvf paraphraser.en-fr.tar.gz
+tar -xzvf paraphraser.fr-en.hMoEup.tar.gz
+```
+
+##### 2. Paraphrase
+```bash
+python examples/paraphraser/paraphrase.py \
+    --en2fr paraphraser.en-fr \
+    --fr2en paraphraser.fr-en.hMoEup
+# Example input:
+# The new date for the Games, postponed for a year in response to the coronavirus pandemic, gives athletes time to recalibrate their training schedules.
+# Example outputs:
+# Delayed one year in response to the coronavirus pandemic, the new date of the Games gives athletes time to rebalance their training schedule.
+# The new date of the Games, which was rescheduled one year in response to the coronavirus (CV) pandemic, gives athletes time to rebalance their training schedule.
+# The new date of the Games, postponed one year in response to the coronavirus pandemic, provides athletes with time to rebalance their training schedule.
+# The Games' new date, postponed one year in response to the coronavirus pandemic, gives athletes time to rebalance their training schedule.
+# The new Games date, postponed one year in response to the coronavirus pandemic, gives the athletes time to rebalance their training schedule.
+# The new date of the Games, which was postponed one year in response to the coronavirus pandemic, gives the athletes time to rebalance their training schedule.
+# The new date of the Games, postponed one year in response to the coronavirus pandemic, gives athletes time to rebalance their training schedule.
+# The new date of the Games, postponed one year in response to the coronavirus pandemic, gives athletes time to re-balance their training schedule.
+# The new date of the Games, postponed one year in response to the coronavirus pandemic, gives the athletes time to rebalance their schedule of training.
+# The new date of the Games, postponed one year in response to the pandemic of coronavirus, gives the athletes time to rebalance their training schedule.
+```
diff --git a/examples/paraphraser/paraphrase.py b/examples/paraphraser/paraphrase.py
new file mode 100644
index 0000000000..405df296b3
--- /dev/null
+++ b/examples/paraphraser/paraphrase.py
@@ -0,0 +1,76 @@
+#!/usr/bin/env python3 -u
+
+import argparse
+import fileinput
+import logging
+import os
+import sys
+
+from fairseq.models.transformer import TransformerModel
+
+
+logging.getLogger().setLevel(logging.INFO)
+
+
+def main():
+    parser = argparse.ArgumentParser(description='')
+    parser.add_argument('--en2fr', required=True,
+                        help='path to en2fr model')
+    parser.add_argument('--fr2en', required=True,
+                        help='path to fr2en mixture of experts model')
+    parser.add_argument('--user-dir',
+                        help='path to fairseq examples/translation_moe/src directory')
+    parser.add_argument('--num-experts', type=int, default=10,
+                        help='(keep at 10 unless using a different model)')
+    parser.add_argument('files', nargs='*', default=['-'],
+                        help='input files to paraphrase; "-" for stdin')
+    args = parser.parse_args()
+
+    if args.user_dir is None:
+        args.user_dir = os.path.join(
+            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),  # examples/
+            'translation_moe',
+            'src',
+        )
+    if os.path.exists(args.user_dir):
+        logging.info('found user_dir:' + args.user_dir)
+    else:
+        raise RuntimeError(
+            'cannot find fairseq examples/translation_moe/src '
+            '(tried looking here: {})'.format(args.user_dir)
+        )
+
+    logging.info('loading en2fr model from:' + args.en2fr)
+    en2fr = TransformerModel.from_pretrained(
+        model_name_or_path=args.en2fr,
+        tokenizer='moses',
+        bpe='sentencepiece',
+    ).eval()
+
+    logging.info('loading fr2en model from:' + args.fr2en)
+    fr2en = TransformerModel.from_pretrained(
+        model_name_or_path=args.fr2en,
+        tokenizer='moses',
+        bpe='sentencepiece',
+        user_dir=args.user_dir,
+        task='translation_moe',
+    ).eval()
+
+    def gen_paraphrases(en):
+        fr = en2fr.translate(en)
+        return [
+            fr2en.translate(fr, inference_step_args={'expert': i})
+            for i in range(args.num_experts)
+        ]
+
+    logging.info('Type the input sentence and press return:')
+    for line in fileinput.input(args.files):
+        line = line.strip()
+        if len(line) == 0:
+            continue
+        for paraphrase in gen_paraphrases(line):
+            print(paraphrase)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/fairseq/hub_utils.py b/fairseq/hub_utils.py
index 92a4c9092f..1b1091c881 100644
--- a/fairseq/hub_utils.py
+++ b/fairseq/hub_utils.py
@@ -145,6 +145,7 @@ def generate(
         beam: int = 5,
         verbose: bool = False,
         skip_invalid_size_inputs=False,
+        inference_step_args=None,
         **kwargs
     ) -> List[List[Dict[str, torch.Tensor]]]:
         if torch.is_tensor(tokenized_sentences) and tokenized_sentences.dim() == 1:
@@ -159,10 +160,13 @@ def generate(
             setattr(gen_args, k, v)
         generator = self.task.build_generator(self.models, gen_args)
 
+        inference_step_args = inference_step_args or {}
         results = []
         for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
             batch = utils.apply_to_sample(lambda t: t.to(self.device), batch)
-            translations = self.task.inference_step(generator, self.models, batch)
+            translations = self.task.inference_step(
+                generator, self.models, batch, **inference_step_args
+            )
             for id, hypos in zip(batch["id"].tolist(), translations):
                 results.append((id, hypos))