Skip to content

Commit

Permalink
Fix Marian model conversion (#30173)
Browse files Browse the repository at this point in the history
* fix marian model coversion

* uncomment that line

* remove unnecessary code

* revert tie_weights, doesn't hurt
  • Loading branch information
zucchini-nlp authored May 1, 2024
1 parent 38a4bf7 commit 4bc9cb3
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@

DEFAULT_REPO = "Tatoeba-Challenge"
DEFAULT_MODEL_DIR = os.path.join(DEFAULT_REPO, "models")
LANG_CODE_URL = "https://datahub.io/core/language-codes/r/language-codes-3b2.csv"
ISO_URL = "https://cdn-datasets.huggingface.co/language_codes/iso-639-3.csv"
ISO_PATH = "lang_code_data/iso-639-3.csv"
LANG_CODE_PATH = "lang_code_data/language-codes-3b2.csv"
Expand Down Expand Up @@ -277,13 +276,17 @@ def write_model_card(self, model_dict, dry_run=False) -> str:
json.dump(metadata, writeobj)

def download_lang_info(self):
global LANG_CODE_PATH
Path(LANG_CODE_PATH).parent.mkdir(exist_ok=True)
import wget
from huggingface_hub import hf_hub_download

if not os.path.exists(ISO_PATH):
wget.download(ISO_URL, ISO_PATH)
if not os.path.exists(LANG_CODE_PATH):
wget.download(LANG_CODE_URL, LANG_CODE_PATH)
LANG_CODE_PATH = hf_hub_download(
repo_id="huggingface/language_codes_marianMT", filename="language-codes-3b2.csv", repo_type="dataset"
)

def parse_metadata(self, model_name, repo_path=DEFAULT_MODEL_DIR, method="best"):
p = Path(repo_path) / model_name
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/models/marian/convert_marian_to_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,10 @@ def load_marian_model(self) -> MarianMTModel:
bias_tensor = nn.Parameter(torch.FloatTensor(self.final_bias))
model.model.decoder.embed_tokens.weight = decoder_wemb_tensor

# handle tied embeddings, otherwise "from_pretrained" loads them incorrectly
if self.cfg["tied-embeddings"]:
model.lm_head.weight.data = model.model.decoder.embed_tokens.weight.data.clone()

model.final_logits_bias = bias_tensor

if "Wpos" in state_dict:
Expand Down

0 comments on commit 4bc9cb3

Please sign in to comment.