From 8ed8a721395f302df03e473dd4393cffcb9748af Mon Sep 17 00:00:00 2001 From: ZheyuYe Date: Mon, 13 Jul 2020 22:28:13 +0800 Subject: [PATCH] re-upload roberta --- .../convert_fairseq_roberta.py | 12 +------ .../convert_fairseq_xlmr.py | 35 ++++++++++++------- .../models/model_zoo_checksums/roberta.txt | 22 ++++++------ src/gluonnlp/models/roberta.py | 14 +++----- src/gluonnlp/models/xlmr.py | 8 +++-- 5 files changed, 43 insertions(+), 48 deletions(-) diff --git a/scripts/conversion_toolkits/convert_fairseq_roberta.py b/scripts/conversion_toolkits/convert_fairseq_roberta.py index 2fd40fa0b9..d3923e7104 100644 --- a/scripts/conversion_toolkits/convert_fairseq_roberta.py +++ b/scripts/conversion_toolkits/convert_fairseq_roberta.py @@ -322,17 +322,7 @@ def test_model(fairseq_model, gluon_model, gpu): 1E-3, 1E-3 ) - - gl_mlm_scores = gl_mlm_scores.asnumpy() - fs_mlm_scores = fs_mlm_scores.transpose(0, 1) - fs_mlm_scores = fs_mlm_scores.detach().cpu().numpy() - for j in range(batch_size): - assert_allclose( - gl_mlm_scores[j, :valid_length[j], :], - fs_mlm_scores[j, :valid_length[j], :], - 1E-3, - 1E-3 - ) + #TODO(zheyuye), checking the masking scores def rename(save_dir): """Rename converted files with hash""" diff --git a/scripts/conversion_toolkits/convert_fairseq_xlmr.py b/scripts/conversion_toolkits/convert_fairseq_xlmr.py index 334dda2f1b..34b6cf5341 100644 --- a/scripts/conversion_toolkits/convert_fairseq_xlmr.py +++ b/scripts/conversion_toolkits/convert_fairseq_xlmr.py @@ -6,7 +6,7 @@ import mxnet as mx from gluonnlp.utils.misc import logging_config -from gluonnlp.models.xlmr import XLMRModel as gluon_XLMRModel +from gluonnlp.models.xlmr import XLMRModel, XLMRForMLM from gluonnlp.third_party import sentencepiece_model_pb2 from fairseq.models.roberta import XLMRModel as fairseq_XLMRModel from convert_fairseq_roberta import rename, test_model, test_vocab, convert_config, convert_params @@ -88,23 +88,32 @@ def convert_fairseq_model(args): vocab_size = convert_vocab(args, fairseq_xlmr) gluon_cfg = convert_config(fairseq_xlmr.args, vocab_size, - gluon_XLMRModel.get_cfg().clone()) + XLMRModel.get_cfg().clone()) with open(os.path.join(args.save_dir, 'model.yml'), 'w') as of: of.write(gluon_cfg.dump()) ctx = mx.gpu(args.gpu) if args.gpu is not None else mx.cpu() - gluon_xlmr = convert_params(fairseq_xlmr, - gluon_cfg, - gluon_XLMRModel, - ctx, - gluon_prefix='xlmr_') - if args.test: - test_model(fairseq_xlmr, gluon_xlmr, args.gpu) + for is_mlm in [False, True]: + gluon_xlmr = convert_params(fairseq_roberta, + gluon_cfg, + ctx, + is_mlm=is_mlm, + gluon_prefix='roberta_') + + if is_mlm: + if args.test: + test_model(fairseq_roberta, gluon_xlmr, args.gpu) + + gluon_xlmr.save_parameters(os.path.join(args.save_dir, 'model_mlm.params'), deduplicate=True) + logging.info('Convert the RoBERTa MLM model in {} to {}'. + format(os.path.join(args.fairseq_model_path, 'model.pt'), \ + os.path.join(args.save_dir, 'model_mlm.params'))) + else: + gluon_xlmr.save_parameters(os.path.join(args.save_dir, 'model.params'), deduplicate=True) + logging.info('Convert the RoBERTa backbone model in {} to {}'. + format(os.path.join(args.fairseq_model_path, 'model.pt'), \ + os.path.join(args.save_dir, 'model.params'))) - gluon_xlmr.save_parameters(os.path.join(args.save_dir, 'model.params'), deduplicate=True) - logging.info('Convert the XLM-R model in {} to {}'. - format(os.path.join(args.fairseq_model_path, 'model.pt'), \ - os.path.join(args.save_dir, 'model.params'))) logging.info('Conversion finished!') logging.info('Statistics:') rename(args.save_dir) diff --git a/src/gluonnlp/models/model_zoo_checksums/roberta.txt b/src/gluonnlp/models/model_zoo_checksums/roberta.txt index 4e4f9efe6d..6de6e8ce5f 100644 --- a/src/gluonnlp/models/model_zoo_checksums/roberta.txt +++ b/src/gluonnlp/models/model_zoo_checksums/roberta.txt @@ -1,12 +1,10 @@ -fairseq_roberta_base/model-565d1db7.yml 565d1db71b0452fa2c28f155b8e9d90754f4f40a 401 -fairseq_roberta_base/gpt2-396d4d8e.merges 396d4d8ec90cb02f4d56e049e0e4add868bcd943 456318 -fairseq_roberta_base/gpt2-f1335494.vocab f1335494f47917829e3b1d08e579ff2c3fe4fd60 558231 -fairseq_roberta_base/model-98b4532f.params 98b4532fe59e6fd755422057fde4601b3eb8fbf0 498792661 -fairseq_roberta_large/model-6e66dc4a.yml 6e66dc4a450560a93aaf3d0ba9e0d447495d778a 402 -fairseq_roberta_large/gpt2-396d4d8e.merges 396d4d8ec90cb02f4d56e049e0e4add868bcd943 456318 -fairseq_roberta_large/gpt2-f1335494.vocab f1335494f47917829e3b1d08e579ff2c3fe4fd60 558231 -fairseq_roberta_large/model-e3f578dc.params e3f578dc669cf36fa5b6730b0bbee77c980276d7 1421659773 -fairseq_roberta_large_mnli/model-6e66dc4a.yml 6e66dc4a450560a93aaf3d0ba9e0d447495d778a 402 -fairseq_roberta_large_mnli/gpt2-396d4d8e.merges 396d4d8ec90cb02f4d56e049e0e4add868bcd943 456318 -fairseq_roberta_large_mnli/gpt2-f1335494.vocab f1335494f47917829e3b1d08e579ff2c3fe4fd60 558231 -fairseq_roberta_large_mnli/model-5288bb09.params 5288bb09db89b7900e85c9d673686f748f0abd56 1421659773 \ No newline at end of file +fairseq_roberta_base/model-565d1db7.yml 565d1db71b0452fa2c28f155b8e9d90754f4f40a 401 +fairseq_roberta_base/gpt2-396d4d8e.merges 396d4d8ec90cb02f4d56e049e0e4add868bcd943 456318 +fairseq_roberta_base/gpt2-f1335494.vocab f1335494f47917829e3b1d08e579ff2c3fe4fd60 558231 +fairseq_roberta_base/model-09a1520a.params 09a1520adf652468c07e43a6ed28908418fa58a7 496222787 +fairseq_roberta_base/model_mlm-29889e2b.params 29889e2b4ef20676fda117bb7b754e1693d0df25 498794868 +fairseq_roberta_large/model-6b043b91.params 6b043b91a6a781a12ea643d0644d32300db38ec8 1417251819 +fairseq_roberta_large/gpt2-396d4d8e.merges 396d4d8ec90cb02f4d56e049e0e4add868bcd943 456318 +fairseq_roberta_large/model-6e66dc4a.yml 6e66dc4a450560a93aaf3d0ba9e0d447495d778a 402 +fairseq_roberta_large/gpt2-f1335494.vocab f1335494f47917829e3b1d08e579ff2c3fe4fd60 558231 +fairseq_roberta_large/model_mlm-119f38e1.params 119f38e1249bd28bea7dd2e90c09b8f4b879fa19 1421664140 diff --git a/src/gluonnlp/models/roberta.py b/src/gluonnlp/models/roberta.py index 6f2ae2c6f8..aab99047cb 100644 --- a/src/gluonnlp/models/roberta.py +++ b/src/gluonnlp/models/roberta.py @@ -27,7 +27,7 @@ } """ -__all__ = ['RobertaModel', 'list_pretrained_roberta', 'get_pretrained_roberta'] +__all__ = ['RobertaModel', 'RobertaForMLM', 'list_pretrained_roberta', 'get_pretrained_roberta'] import os from typing import Tuple @@ -54,20 +54,16 @@ 'cfg': 'fairseq_roberta_base/model-565d1db7.yml', 'merges': 'fairseq_roberta_base/gpt2-396d4d8e.merges', 'vocab': 'fairseq_roberta_base/gpt2-f1335494.vocab', - 'params': 'fairseq_roberta_base/model-98b4532f.params' + 'params': 'fairseq_roberta_base/model-09a1520a.params' + 'mlm_params': 'google_uncased_mobilebert/model_mlm-29889e2b.params', }, 'fairseq_roberta_large': { 'cfg': 'fairseq_roberta_large/model-6e66dc4a.yml', 'merges': 'fairseq_roberta_large/gpt2-396d4d8e.merges', 'vocab': 'fairseq_roberta_large/gpt2-f1335494.vocab', - 'params': 'fairseq_roberta_large/model-e3f578dc.params' + 'params': 'fairseq_roberta_large/model-6b043b91.params', + 'mlm_params': 'fairseq_roberta_large/model_mlm-119f38e1.params' }, - 'fairseq_roberta_large_mnli': { - 'cfg': 'fairseq_roberta_large_mnli/model-6e66dc4a.yml', - 'merges': 'fairseq_roberta_large_mnli/gpt2-396d4d8e.merges', - 'vocab': 'fairseq_roberta_large_mnli/gpt2-f1335494.vocab', - 'params': 'fairseq_roberta_large_mnli/model-5288bb09.params' - } } FILE_STATS = load_checksum_stats(os.path.join(get_model_zoo_checksum_dir(), 'roberta.txt')) diff --git a/src/gluonnlp/models/xlmr.py b/src/gluonnlp/models/xlmr.py index 7afca7a1ff..b6f5b156d8 100644 --- a/src/gluonnlp/models/xlmr.py +++ b/src/gluonnlp/models/xlmr.py @@ -25,12 +25,12 @@ } """ -__all__ = ['XLMRModel', 'list_pretrained_xlmr', 'get_pretrained_xlmr'] +__all__ = ['XLMRModel', 'XLMRForMLM', 'list_pretrained_xlmr', 'get_pretrained_xlmr'] from typing import Tuple import os from mxnet import use_np -from .roberta import RobertaModel, roberta_base, roberta_large +from .roberta import RobertaModel, RobertaForMLM roberta_base, roberta_large from ..base import get_model_zoo_home_dir, get_repo_model_zoo_url, get_model_zoo_checksum_dir from ..utils.config import CfgNode as CN from ..utils.registry import Registry @@ -82,7 +82,9 @@ def get_cfg(key=None): return xlmr_cfg_reg.create(key) else: return xlmr_base() - +@use_np +class XLMRForMLM(RobertaForMLM): + super().__init__() def list_pretrained_xlmr(): return sorted(list(PRETRAINED_URL.keys()))