From 4bc9cb36b70bc90e3ecbea750a68ec2a95853633 Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Wed, 1 May 2024 12:33:12 +0500 Subject: [PATCH] Fix Marian model conversion (#30173) * fix marian model conversion * uncomment that line * remove unnecessary code * revert tie_weights, doesn't hurt --- .../models/marian/convert_marian_tatoeba_to_pytorch.py | 7 +++++-- .../models/marian/convert_marian_to_pytorch.py | 4 ++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/marian/convert_marian_tatoeba_to_pytorch.py b/src/transformers/models/marian/convert_marian_tatoeba_to_pytorch.py index f6b548c2b07f..40ad3294097c 100644 --- a/src/transformers/models/marian/convert_marian_tatoeba_to_pytorch.py +++ b/src/transformers/models/marian/convert_marian_tatoeba_to_pytorch.py @@ -34,7 +34,6 @@ DEFAULT_REPO = "Tatoeba-Challenge" DEFAULT_MODEL_DIR = os.path.join(DEFAULT_REPO, "models") -LANG_CODE_URL = "https://datahub.io/core/language-codes/r/language-codes-3b2.csv" ISO_URL = "https://cdn-datasets.huggingface.co/language_codes/iso-639-3.csv" ISO_PATH = "lang_code_data/iso-639-3.csv" LANG_CODE_PATH = "lang_code_data/language-codes-3b2.csv" @@ -277,13 +276,17 @@ def write_model_card(self, model_dict, dry_run=False) -> str: json.dump(metadata, writeobj) def download_lang_info(self): + global LANG_CODE_PATH Path(LANG_CODE_PATH).parent.mkdir(exist_ok=True) import wget + from huggingface_hub import hf_hub_download if not os.path.exists(ISO_PATH): wget.download(ISO_URL, ISO_PATH) if not os.path.exists(LANG_CODE_PATH): - wget.download(LANG_CODE_URL, LANG_CODE_PATH) + LANG_CODE_PATH = hf_hub_download( + repo_id="huggingface/language_codes_marianMT", filename="language-codes-3b2.csv", repo_type="dataset" + ) def parse_metadata(self, model_name, repo_path=DEFAULT_MODEL_DIR, method="best"): p = Path(repo_path) / model_name diff --git a/src/transformers/models/marian/convert_marian_to_pytorch.py 
b/src/transformers/models/marian/convert_marian_to_pytorch.py index 79afd50955dd..593162ffe674 100644 --- a/src/transformers/models/marian/convert_marian_to_pytorch.py +++ b/src/transformers/models/marian/convert_marian_to_pytorch.py @@ -622,6 +622,10 @@ def load_marian_model(self) -> MarianMTModel: bias_tensor = nn.Parameter(torch.FloatTensor(self.final_bias)) model.model.decoder.embed_tokens.weight = decoder_wemb_tensor + # handle tied embeddings, otherwise "from_pretrained" loads them incorrectly + if self.cfg["tied-embeddings"]: + model.lm_head.weight.data = model.model.decoder.embed_tokens.weight.data.clone() + model.final_logits_bias = bias_tensor if "Wpos" in state_dict: