Add nllb (#58)
* Add NLLB-200

* Improve readme and docs

* Fix tests

* Bump version

* Bump python requirements

* Add new module

* Improve tests and fix capitalization errors

* Change behavior of _resolve_lang_codes to resolve one at a time

* Add new max token length default to 512, update tests

* Add demo
xhluca authored Jul 18, 2023
1 parent d362ae0 commit 66b63b4
Showing 16 changed files with 1,189 additions and 254 deletions.
43 changes: 33 additions & 10 deletions README.md
@@ -8,7 +8,7 @@
*A deep learning-based translation library built on Huggingface `transformers`*

💻 [GitHub Repository](https://github.com/xhluca/dl-translate)<br>
📚 [Documentation](https://xhluca.github.io/dl-translate) / [Readthedocs](https://dl-translate.readthedocs.io)<br>
📚 [Documentation](https://xhluca.github.io/dl-translate)<br>
🐍 [PyPi project](https://pypi.org/project/dl-translate/)<br>
🧪 [Colab Demo](https://colab.research.google.com/github/xhluca/dl-translate/blob/main/demos/colab_demo.ipynb) / [Kaggle Demo](https://www.kaggle.com/xhlulu/dl-translate-demo/)

@@ -58,24 +58,34 @@ By default, the value will be `device="auto"`, which means it will use a GPU if

### Choosing a different model

Two model families are available at the moment: [m2m100](https://huggingface.co/transformers/model_doc/m2m_100.html) and [mBART-50 Large](https://huggingface.co/transformers/master/model_doc/mbart.html), which respectively allow translation across over 100 languages and 50 languages. By default, the model will select `m2m100`, but you can also explicitly choose the model by specifying the shorthand (`"m2m100"` or `"mbart50"`) or the full repository name (e.g. `"facebook/m2m100_418M"`). For example:
By default, the `m2m100` model will be used. However, there are a few options:

* [mBART-50 Large](https://huggingface.co/transformers/master/model_doc/mbart.html): Allows translations across 50 languages.
* [m2m100](https://huggingface.co/transformers/model_doc/m2m_100.html): Allows translations across 100 languages.
* [nllb-200](https://huggingface.co/docs/transformers/model_doc/nllb) (New in v0.3): Allows translations across 200 languages, and is faster than m2m100 (on an RTX A6000, we can see a speedup of 3x).

Here's an example:
```python
# The following ways are equivalent
mt = dlt.TranslationModel("m2m100") # Default
mt = dlt.TranslationModel("facebook/m2m100_418M")
# The default approach
mt = dlt.TranslationModel("m2m100") # Shorthand
mt = dlt.TranslationModel("facebook/m2m100_418M") # Huggingface repo

# The following ways are equivalent
# If you want to use mBART-50 Large
mt = dlt.TranslationModel("mbart50")
mt = dlt.TranslationModel("facebook/mbart-large-50-many-to-many-mmt")

# Or NLLB-200 (faster and has 200 languages)
mt = dlt.TranslationModel("nllb200")
mt = dlt.TranslationModel("facebook/nllb-200-distilled-600M")
```

Note that the language code will change depending on the model family. To find out the correct language codes, please read the doc page on available languages or run `mt.available_codes()`.
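
As a rough illustration of why codes are not portable across model families, the sketch below looks up the same language in two small excerpts of (language, code) pairs. The `PAIRS` excerpt and the `lookup_code` helper are hypothetical and written only for this example; they are not part of `dlt`, and the real lists should be read from `mt.available_codes()` or `dlt.utils`.

```python
# Illustrative excerpt only: a handful of (language, code) pairs per family.
# The full lists live in the library; this is NOT the library's own data structure.
PAIRS = {
    "mbart50": (("Galician", "gl_ES"), ("Slovene", "sl_SI")),
    "nllb200": (
        ("Galician", "glg_Latn"),
        ("Slovenian", "slv_Latn"),
        ("French", "fra_Latn"),
    ),
}


def lookup_code(language: str, family: str) -> str:
    """Hypothetical helper: return the code for `language` within `family`."""
    lang_to_code = dict(PAIRS[family])
    try:
        return lang_to_code[language]
    except KeyError:
        raise ValueError(
            f"{language!r} is not a language name in the {family!r} excerpt"
        ) from None


# The same language maps to a different code in each family.
print(lookup_code("Galician", "mbart50"))
print(lookup_code("Galician", "nllb200"))
```

Even the language name can differ between families ("Slovene" versus "Slovenian" in these excerpts), so it is safer to check the available languages for the model you actually loaded.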

By default, `dlt.TranslationModel` will download the model from the huggingface repo for [mbart50](https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt) or [m2m100](https://huggingface.co/facebook/m2m100_418M) and cache it. It's possible to load the model from a path or a model with a similar format, but you will need to specify the `model_family`:
By default, `dlt.TranslationModel` will download the model from the huggingface repo for [mbart50](https://huggingface.co/facebook/mbart-large-50-one-to-many-mmt), [m2m100](https://huggingface.co/facebook/m2m100_418M), or [nllb200](https://huggingface.co/facebook/nllb-200-distilled-600M) and cache it. It's possible to load the model from a path or a model with a similar format, but you will need to specify the `model_family`:
```python
mt = dlt.TranslationModel("/path/to/model/directory/", model_family="mbart50")
mt = dlt.TranslationModel("facebook/m2m100_1.2B", model_family="m2m100")
mt = dlt.TranslationModel("facebook/nllb-200-distilled-600M", model_family="nllb200")
```

Notes:
@@ -114,8 +124,8 @@ An alternative to `mt.available_languages()` is the `dlt.utils` module. You can

```python
print(dlt.utils.available_languages('mbart50')) # All languages that you can use
print(dlt.utils.available_codes('mbart50')) # Code corresponding to each language accepted
print(dlt.utils.get_lang_code_map('mbart50')) # Dictionary of lang -> code
print(dlt.utils.available_codes('m2m100')) # Code corresponding to each language accepted
print(dlt.utils.get_lang_code_map('nllb200')) # Dictionary of lang -> code
```
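
The `nllb200` codes follow a `language_Script` pattern (for example `ace_Arab` and `ace_Latn` for the two Acehnese entries), so one language can appear once per writing system. The sketch below, which assumes only that pattern, splits a code into its two parts and groups a small excerpt of pairs by script. `NLLB_EXCERPT`, `split_code` and `group_by_script` are hypothetical names used for illustration and are not provided by `dlt.utils`.

```python
from collections import defaultdict

# Illustrative excerpt of (language, code) pairs in the nllb200 style.
NLLB_EXCERPT = (
    ("Acehnese (Arabic script)", "ace_Arab"),
    ("Acehnese (Latin script)", "ace_Latn"),
    ("French", "fra_Latn"),
    ("Sanskrit", "san_Deva"),
    ("Yue Chinese", "yue_Hant"),
)


def split_code(code: str) -> tuple[str, str]:
    """Split a code such as 'ace_Arab' into ('ace', 'Arab')."""
    language, script = code.split("_", 1)
    return language, script


def group_by_script(pairs) -> dict[str, list[str]]:
    """Group language names by the script part of their code."""
    groups = defaultdict(list)
    for name, code in pairs:
        _, script = split_code(code)
        groups[script].append(name)
    return dict(groups)


for script, names in group_by_script(NLLB_EXCERPT).items():
    print(script, names)
```

In practice the same grouping could be applied to the items of the dictionary returned by `dlt.utils.get_lang_code_map('nllb200')`, for instance to pick the right script variant before calling `mt.translate`.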

### Offline usage
@@ -159,7 +169,7 @@ If you have knowledge of PyTorch and Huggingface Transformers, you can access ad
* **Interacting with underlying model and tokenizer**: When initializing `model`, you can pass in arguments for the underlying BART model and tokenizer with `model_options` and `tokenizer_options` respectively. You can also access the underlying `transformers` with `mt.get_transformers_model()`.
* **Keyword arguments for the `generate()` method**: When running `mt.translate`, you can also give `generation_options` that is passed to the `generate()` method of the underlying transformer model.

For more information, please visit the [advanced section of the user guide](https://xhluca.github.io/dl-translate/#advanced) (also available in the [readthedocs version](https://dl-translate.readthedocs.io/en/latest/#advanced)).
For more information, please visit the [advanced section of the user guide](https://xhluca.github.io/dl-translate/#advanced).

## Acknowledgement

@@ -186,6 +196,19 @@ For more information, please visit the [advanced section of the user guide](http
}
```
3. The [no language left behind](https://arxiv.org/abs/2207.04672) model, which extends NMT to 200+ languages. You can cite it here:
```
@misc{nllbteam2022language,
title={No Language Left Behind: Scaling Human-Centered Machine Translation},
author={NLLB Team and Marta R. Costa-jussà and James Cross and Onur Çelebi and Maha Elbayad and Kenneth Heafield and Kevin Heffernan and Elahe Kalbassi and Janice Lam and Daniel Licht and Jean Maillard and Anna Sun and Skyler Wang and Guillaume Wenzek and Al Youngblood and Bapi Akula and Loic Barrault and Gabriel Mejia Gonzalez and Prangthip Hansanti and John Hoffman and Semarley Jarrett and Kaushik Ram Sadagopan and Dirk Rowe and Shannon Spruit and Chau Tran and Pierre Andrews and Necip Fazil Ayan and Shruti Bhosale and Sergey Edunov and Angela Fan and Cynthia Gao and Vedanuj Goswami and Francisco Guzmán and Philipp Koehn and Alexandre Mourachko and Christophe Ropers and Safiyyah Saleem and Holger Schwenk and Jeff Wang},
year={2022},
eprint={2207.04672},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```
`dlt` is a wrapper with useful `utils` to save you time. For huggingface's `transformers`, the following snippet is shown as an example:
```python
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
1 change: 1 addition & 0 deletions demos/nllb_demo.ipynb
@@ -0,0 +1 @@
{"cells":[{"cell_type":"code","execution_count":1,"metadata":{"_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","execution":{"iopub.execute_input":"2023-07-18T05:15:13.999614Z","iopub.status.busy":"2023-07-18T05:15:13.999228Z","iopub.status.idle":"2023-07-18T05:15:31.978108Z","shell.execute_reply":"2023-07-18T05:15:31.976681Z","shell.execute_reply.started":"2023-07-18T05:15:13.999573Z"},"trusted":true},"outputs":[],"source":["!pip install dl-translate==3.* -q"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2023-07-18T05:15:31.982361Z","iopub.status.busy":"2023-07-18T05:15:31.981992Z","iopub.status.idle":"2023-07-18T05:16:23.731908Z","shell.execute_reply":"2023-07-18T05:16:23.730776Z","shell.execute_reply.started":"2023-07-18T05:15:31.982327Z"},"trusted":true},"outputs":[],"source":["import dl_translate as dlt\n","\n","mt = dlt.TranslationModel(\"nllb200\")"]},{"cell_type":"code","execution_count":null,"metadata":{"execution":{"iopub.execute_input":"2023-07-18T05:16:23.734336Z","iopub.status.busy":"2023-07-18T05:16:23.733295Z","iopub.status.idle":"2023-07-18T05:16:28.025038Z","shell.execute_reply":"2023-07-18T05:16:28.023933Z","shell.execute_reply.started":"2023-07-18T05:16:23.734293Z"},"trusted":true},"outputs":[],"source":["text = \"Meta AI has built a single AI model capable of translating across 200 different languages with state-of-the-art quality.\"\n","\n","# The new translation is much faster than before\n","%time print(mt.translate(text, source=dlt.lang.nllb200.ENGLISH, 
target=dlt.lang.nllb200.FRENCH))"]},{"cell_type":"code","execution_count":4,"metadata":{"execution":{"iopub.execute_input":"2023-07-18T05:16:28.028919Z","iopub.status.busy":"2023-07-18T05:16:28.028286Z","iopub.status.idle":"2023-07-18T05:16:28.717521Z","shell.execute_reply":"2023-07-18T05:16:28.716343Z","shell.execute_reply.started":"2023-07-18T05:16:28.028882Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["मेटाएआई एकमेव एआई मॉडलं निर्मितवान्, यत् 200 भिन्नभाषायां अवधीतमतमतमगुणैः अनुवादं कर्तुं समर्थः अस्ति।\n"]}],"source":["# Sanskrit is now available (previously unavailable)\n","print(mt.translate(text, source=dlt.lang.nllb200.ENGLISH, target=dlt.lang.nllb200.SANSKRIT))"]},{"cell_type":"code","execution_count":5,"metadata":{"execution":{"iopub.execute_input":"2023-07-18T05:16:28.719596Z","iopub.status.busy":"2023-07-18T05:16:28.719227Z","iopub.status.idle":"2023-07-18T05:16:29.443696Z","shell.execute_reply":"2023-07-18T05:16:29.442668Z","shell.execute_reply.started":"2023-07-18T05:16:28.719560Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["Meta AI hà custruitu un solu mudellu d'AI capace di tradurisce in 200 lingue sfarenti cù qualità di u statu di l'arte.\n"]}],"source":["# Sicilian is now available (previously unavailable)\n","print(mt.translate(text, source=dlt.lang.nllb200.ENGLISH, target=dlt.lang.nllb200.SICILIAN))"]},{"cell_type":"code","execution_count":6,"metadata":{"execution":{"iopub.execute_input":"2023-07-18T05:16:29.447147Z","iopub.status.busy":"2023-07-18T05:16:29.445331Z","iopub.status.idle":"2023-07-18T05:16:30.145637Z","shell.execute_reply":"2023-07-18T05:16:30.144623Z","shell.execute_reply.started":"2023-07-18T05:16:29.447108Z"},"trusted":true},"outputs":[{"name":"stdout","output_type":"stream","text":["基於Meta AI 建立咗一個 AI 模型 可以用最先端嘅質量翻譯到 200 個唔同語言\n"]}],"source":["# Yue Chinese is now available (previously unavailable)\n","print(mt.translate(text, source=dlt.lang.nllb200.ENGLISH, 
target=dlt.lang.nllb200.YUE_CHINESE))"]}],"metadata":{"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.12"}},"nbformat":4,"nbformat_minor":4}
206 changes: 206 additions & 0 deletions dl_translate/_pairs.py
@@ -166,3 +166,209 @@
("Galician", "gl_ES"),
("Slovene", "sl_SI"),
)
_PAIRS_NLLB200 = (
("Acehnese (Arabic script)", "ace_Arab"),
("Acehnese (Latin script)", "ace_Latn"),
("Mesopotamian Arabic", "acm_Arab"),
("Ta'izzi-Adeni Arabic", "acq_Arab"),
("Tunisian Arabic", "aeb_Arab"),
("Afrikaans", "afr_Latn"),
("South Levantine Arabic", "ajp_Arab"),
("Akan", "aka_Latn"),
("Amharic", "amh_Ethi"),
("North Levantine Arabic", "apc_Arab"),
("Modern Standard Arabic", "arb_Arab"),
("Modern Standard Arabic (Romanized)", "arb_Latn"),
("Najdi Arabic", "ars_Arab"),
("Moroccan Arabic", "ary_Arab"),
("Egyptian Arabic", "arz_Arab"),
("Assamese", "asm_Beng"),
("Asturian", "ast_Latn"),
("Awadhi", "awa_Deva"),
("Central Aymara", "ayr_Latn"),
("South Azerbaijani", "azb_Arab"),
("North Azerbaijani", "azj_Latn"),
("Bashkir", "bak_Cyrl"),
("Bambara", "bam_Latn"),
("Balinese", "ban_Latn"),
("Belarusian", "bel_Cyrl"),
("Bemba", "bem_Latn"),
("Bengali", "ben_Beng"),
("Bhojpuri", "bho_Deva"),
("Banjar (Arabic script)", "bjn_Arab"),
("Banjar (Latin script)", "bjn_Latn"),
("Standard Tibetan", "bod_Tibt"),
("Bosnian", "bos_Latn"),
("Buginese", "bug_Latn"),
("Bulgarian", "bul_Cyrl"),
("Catalan", "cat_Latn"),
("Cebuano", "ceb_Latn"),
("Czech", "ces_Latn"),
("Chokwe", "cjk_Latn"),
("Central Kurdish", "ckb_Arab"),
("Crimean Tatar", "crh_Latn"),
("Welsh", "cym_Latn"),
("Danish", "dan_Latn"),
("German", "deu_Latn"),
("Southwestern Dinka", "dik_Latn"),
("Dyula", "dyu_Latn"),
("Dzongkha", "dzo_Tibt"),
("Greek", "ell_Grek"),
("English", "eng_Latn"),
("Esperanto", "epo_Latn"),
("Estonian", "est_Latn"),
("Basque", "eus_Latn"),
("Ewe", "ewe_Latn"),
("Faroese", "fao_Latn"),
("Fijian", "fij_Latn"),
("Finnish", "fin_Latn"),
("Fon", "fon_Latn"),
("French", "fra_Latn"),
("Friulian", "fur_Latn"),
("Nigerian Fulfulde", "fuv_Latn"),
("Scottish Gaelic", "gla_Latn"),
("Irish", "gle_Latn"),
("Galician", "glg_Latn"),
("Guarani", "grn_Latn"),
("Gujarati", "guj_Gujr"),
("Haitian Creole", "hat_Latn"),
("Hausa", "hau_Latn"),
("Hebrew", "heb_Hebr"),
("Hindi", "hin_Deva"),
("Chhattisgarhi", "hne_Deva"),
("Croatian", "hrv_Latn"),
("Hungarian", "hun_Latn"),
("Armenian", "hye_Armn"),
("Igbo", "ibo_Latn"),
("Ilocano", "ilo_Latn"),
("Indonesian", "ind_Latn"),
("Icelandic", "isl_Latn"),
("Italian", "ita_Latn"),
("Javanese", "jav_Latn"),
("Japanese", "jpn_Jpan"),
("Kabyle", "kab_Latn"),
("Jingpho", "kac_Latn"),
("Kamba", "kam_Latn"),
("Kannada", "kan_Knda"),
("Kashmiri (Arabic script)", "kas_Arab"),
("Kashmiri (Devanagari script)", "kas_Deva"),
("Georgian", "kat_Geor"),
("Central Kanuri (Arabic script)", "knc_Arab"),
("Central Kanuri (Latin script)", "knc_Latn"),
("Kazakh", "kaz_Cyrl"),
("Kabiyè", "kbp_Latn"),
("Kabuverdianu", "kea_Latn"),
("Khmer", "khm_Khmr"),
("Kikuyu", "kik_Latn"),
("Kinyarwanda", "kin_Latn"),
("Kyrgyz", "kir_Cyrl"),
("Kimbundu", "kmb_Latn"),
("Northern Kurdish", "kmr_Latn"),
("Kikongo", "kon_Latn"),
("Korean", "kor_Hang"),
("Lao", "lao_Laoo"),
("Ligurian", "lij_Latn"),
("Limburgish", "lim_Latn"),
("Lingala", "lin_Latn"),
("Lithuanian", "lit_Latn"),
("Lombard", "lmo_Latn"),
("Latgalian", "ltg_Latn"),
("Luxembourgish", "ltz_Latn"),
("Luba-Kasai", "lua_Latn"),
("Ganda", "lug_Latn"),
("Luo", "luo_Latn"),
("Mizo", "lus_Latn"),
("Standard Latvian", "lvs_Latn"),
("Magahi", "mag_Deva"),
("Maithili", "mai_Deva"),
("Malayalam", "mal_Mlym"),
("Marathi", "mar_Deva"),
("Minangkabau (Arabic script)", "min_Arab"),
("Minangkabau (Latin script)", "min_Latn"),
("Macedonian", "mkd_Cyrl"),
("Plateau Malagasy", "plt_Latn"),
("Maltese", "mlt_Latn"),
("Meitei (Bengali script)", "mni_Beng"),
("Halh Mongolian", "khk_Cyrl"),
("Mossi", "mos_Latn"),
("Maori", "mri_Latn"),
("Burmese", "mya_Mymr"),
("Dutch", "nld_Latn"),
("Norwegian Nynorsk", "nno_Latn"),
("Norwegian Bokmål", "nob_Latn"),
("Nepali", "npi_Deva"),
("Northern Sotho", "nso_Latn"),
("Nuer", "nus_Latn"),
("Nyanja", "nya_Latn"),
("Occitan", "oci_Latn"),
("West Central Oromo", "gaz_Latn"),
("Odia", "ory_Orya"),
("Pangasinan", "pag_Latn"),
("Eastern Panjabi", "pan_Guru"),
("Papiamento", "pap_Latn"),
("Western Persian", "pes_Arab"),
("Polish", "pol_Latn"),
("Portuguese", "por_Latn"),
("Dari", "prs_Arab"),
("Southern Pashto", "pbt_Arab"),
("Ayacucho Quechua", "quy_Latn"),
("Romanian", "ron_Latn"),
("Rundi", "run_Latn"),
("Russian", "rus_Cyrl"),
("Sango", "sag_Latn"),
("Sanskrit", "san_Deva"),
("Santali", "sat_Olck"),
("Sicilian", "scn_Latn"),
("Shan", "shn_Mymr"),
("Sinhala", "sin_Sinh"),
("Slovak", "slk_Latn"),
("Slovenian", "slv_Latn"),
("Samoan", "smo_Latn"),
("Shona", "sna_Latn"),
("Sindhi", "snd_Arab"),
("Somali", "som_Latn"),
("Southern Sotho", "sot_Latn"),
("Spanish", "spa_Latn"),
("Tosk Albanian", "als_Latn"),
("Sardinian", "srd_Latn"),
("Serbian", "srp_Cyrl"),
("Swati", "ssw_Latn"),
("Sundanese", "sun_Latn"),
("Swedish", "swe_Latn"),
("Swahili", "swh_Latn"),
("Silesian", "szl_Latn"),
("Tamil", "tam_Taml"),
("Tatar", "tat_Cyrl"),
("Telugu", "tel_Telu"),
("Tajik", "tgk_Cyrl"),
("Tagalog", "tgl_Latn"),
("Thai", "tha_Thai"),
("Tigrinya", "tir_Ethi"),
("Tamasheq (Latin script)", "taq_Latn"),
("Tamasheq (Tifinagh script)", "taq_Tfng"),
("Tok Pisin", "tpi_Latn"),
("Tswana", "tsn_Latn"),
("Tsonga", "tso_Latn"),
("Turkmen", "tuk_Latn"),
("Tumbuka", "tum_Latn"),
("Turkish", "tur_Latn"),
("Twi", "twi_Latn"),
("Central Atlas Tamazight", "tzm_Tfng"),
("Uyghur", "uig_Arab"),
("Ukrainian", "ukr_Cyrl"),
("Umbundu", "umb_Latn"),
("Urdu", "urd_Arab"),
("Northern Uzbek", "uzn_Latn"),
("Venetian", "vec_Latn"),
("Vietnamese", "vie_Latn"),
("Waray", "war_Latn"),
("Wolof", "wol_Latn"),
("Xhosa", "xho_Latn"),
("Eastern Yiddish", "ydd_Hebr"),
("Yoruba", "yor_Latn"),
("Yue Chinese", "yue_Hant"),
("Chinese (Simplified)", "zho_Hans"),
("Chinese (Traditional)", "zho_Hant"),
("Standard Malay", "zsm_Latn"),
("Zulu", "zul_Latn"),
)