Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Back translation transformation #534

Merged
merged 7 commits into from
Oct 15, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions tests/test_augment_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,16 @@ def test_deletion_augmenter():
augmented_text_list = augmenter.augment(s)
augmented_s = "United States"
assert augmented_s in augmented_text_list


def test_back_translation_augmenter():
from textattack.augmentation import Augmenter
from textattack.transformations.sentence_transformations import back_translation

augmenter = Augmenter(
transformation=back_translation.BackTranslation(), transformations_per_example=1
)
s = "What on earth are you doing?"
augmented_text_list = augmenter.augment(s)
augmented_s = "What the hell are you doing?"
assert augmented_s in augmented_text_list
9 changes: 5 additions & 4 deletions textattack/shared/utils/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def s3_url(uri):


def download_from_s3(folder_name, skip_if_cached=True):
"""Folder name will be saved as `<cache_dir>/textattack/<folder_name>`. If it
doesn't exist on disk, the zip file will be downloaded and extracted.
"""Folder name will be saved as `<cache_dir>/textattack/<folder_name>`. If
it doesn't exist on disk, the zip file will be downloaded and extracted.

Args:
folder_name (str): path to folder or file in cache
Expand Down Expand Up @@ -68,8 +68,9 @@ def download_from_s3(folder_name, skip_if_cached=True):


def download_from_url(url, save_path, skip_if_cached=True):
"""Downloaded file will be saved under `<cache_dir>/textattack/<save_path>`. If it
doesn't exist on disk, the zip file will be downloaded and extracted.
"""Downloaded file will be saved under
`<cache_dir>/textattack/<save_path>`. If it doesn't exist on disk, the zip
file will be downloaded and extracted.

Args:
url (str): URL path from which to download.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .sentence_transformation import SentenceTransformation
from .back_translation import BackTranslation
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import random

from transformers import MarianMTModel, MarianTokenizer

from textattack.shared import AttackedText

from .sentence_transformation import SentenceTransformation


class BackTranslation(SentenceTransformation):
"""A type of sentence level transformation that takes in a text input,
translates it into target language and translates it back to source
language.

letters_to_insert (string): letters allowed for insertion into words
(used by some char-based transformations)

src_lang (string): source language
target_lang (string): target language, for the list of supported language check bottom of this page
src_model: translation model from huggingface that translates from source language to target language
target_model: translation model from huggingface that translates from target language to source language
chained_back_translation: run back translation in a chain for more perturbation (for example, en-es-en-fr-en)
"""

def __init__(
self,
src_lang="en",
target_lang="es",
src_model="Helsinki-NLP/opus-mt-ROMANCE-en",
target_model="Helsinki-NLP/opus-mt-en-ROMANCE",
chained_back_translation=0,
):
self.src_lang = src_lang
self.target_lang = target_lang
self.target_model = MarianMTModel.from_pretrained(target_model)
self.target_tokenizer = MarianTokenizer.from_pretrained(target_model)
self.src_model = MarianMTModel.from_pretrained(src_model)
self.src_tokenizer = MarianTokenizer.from_pretrained(src_model)
self.chained_back_translation = chained_back_translation

def translate(self, input, model, tokenizer, lang="es"):
# change the text to model's format
src_texts = []
if lang == "en":
src_texts.append(input[0])
else:
if ">>" and "<<" not in lang:
lang = ">>" + lang + "<<"
src_texts.append(lang + input[0])

# tokenize the input
encoded_input = tokenizer.prepare_seq2seq_batch(src_texts, return_tensors="pt")

# translate the input
translated = model.generate(**encoded_input)
translated_input = tokenizer.batch_decode(translated, skip_special_tokens=True)
return translated_input

def _get_transformations(self, current_text, indices_to_modify):
transformed_texts = []
current_text = current_text.text

# to perform chained back translation, a random list of target languages are selected from the provided model
if self.chained_back_translation:
list_of_target_lang = random.sample(
self.target_tokenizer.supported_language_codes,
self.chained_back_translation,
)
for target_lang in list_of_target_lang:
target_language_text = self.translate(
[current_text],
self.target_model,
self.target_tokenizer,
target_lang,
)
src_language_text = self.translate(
target_language_text,
self.src_model,
self.src_tokenizer,
self.src_lang,
)
current_text = src_language_text[0]
return [AttackedText(current_text)]

# translates source to target language and back to source language (single back translation)
target_language_text = self.translate(
[current_text], self.target_model, self.target_tokenizer, self.target_lang
)
src_language_text = self.translate(
target_language_text, self.src_model, self.src_tokenizer, self.src_lang
)
transformed_texts.append(AttackedText(src_language_text[0]))
return transformed_texts


"""
List of supported languages
['fr',
'es',
'it',
'pt',
'pt_br',
'ro',
'ca',
'gl',
'pt_BR<<',
'la<<',
'wa<<',
'fur<<',
'oc<<',
'fr_CA<<',
'sc<<',
'es_ES',
'es_MX',
'es_AR',
'es_PR',
'es_UY',
'es_CL',
'es_CO',
'es_CR',
'es_GT',
'es_HN',
'es_NI',
'es_PA',
'es_PE',
'es_VE',
'es_DO',
'es_EC',
'es_SV',
'an',
'pt_PT',
'frp',
'lad',
'vec',
'fr_FR',
'co',
'it_IT',
'lld',
'lij',
'lmo',
'nap',
'rm',
'scn',
'mwl']
"""
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""https://github.com/makcedward/nlpaug."""

from textattack.transformations import Transformation


class SentenceTransformation(Transformation):
def _get_transformations(self, current_text, indices_to_modify):
raise NotImplementedError()