From dc55fc95933df9f484ff1982a13bb876f641a46e Mon Sep 17 00:00:00 2001
From: ZheyuYe
Date: Tue, 14 Jul 2020 05:45:01 +0800
Subject: [PATCH] tiny update on run_squad
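
Log the ranks that skip evaluation and the device used for inference in
run_squad, add trailing commas to the PRETRAINED_URL entries, guard the
MLM checkpoint download on mlm_params_path being present, and register
the XLM-R config under 'xlmr_cfg' instead of 'roberta_cfg'.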
---
 scripts/question_answering/run_squad.py |  3 ++-
 src/gluonnlp/models/roberta.py          | 10 ++++++----
 src/gluonnlp/models/xlmr.py             |  8 ++++----
 3 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/scripts/question_answering/run_squad.py b/scripts/question_answering/run_squad.py
index ef4697bbdc..069b8c2a75 100644
--- a/scripts/question_answering/run_squad.py
+++ b/scripts/question_answering/run_squad.py
@@ -808,9 +808,10 @@ def evaluate(args, last=True):
                                            args.comm_backend, args.gpus)
     # only evaluate once
     if rank != 0:
+        logging.info('Skipping node {}'.format(rank))
         return
     ctx_l = parse_ctx(args.gpus)
-    logging.info('Srarting inference without horovod on the first node')
+    logging.info('Starting inference without horovod on the first node on device {}'.format(str(ctx_l)))
     cfg, tokenizer, qa_net, use_segmentation = get_network(
         args.model_name, ctx_l, args.classifier_dropout)
 
diff --git a/src/gluonnlp/models/roberta.py b/src/gluonnlp/models/roberta.py
index 94b940a1fe..2eb6afec3f 100644
--- a/src/gluonnlp/models/roberta.py
+++ b/src/gluonnlp/models/roberta.py
@@ -62,8 +62,8 @@
         'merges': 'fairseq_roberta_large/gpt2-396d4d8e.merges',
         'vocab': 'fairseq_roberta_large/gpt2-f1335494.vocab',
         'params': 'fairseq_roberta_large/model-6b043b91.params',
-        'mlm_params': 'fairseq_roberta_large/model_mlm-119f38e1.params'
-    },
+        'mlm_params': 'fairseq_roberta_large/model_mlm-119f38e1.params',
+    }
 }
 
 FILE_STATS = load_checksum_stats(os.path.join(get_model_zoo_checksum_dir(), 'roberta.txt'))
@@ -524,11 +524,13 @@ def get_pretrained_roberta(model_name: str = 'fairseq_roberta_base',
     """
     assert model_name in PRETRAINED_URL, '{} is not found. All available are {}'.format(
         model_name, list_pretrained_roberta())
-    cfg_path = PRETRAINED_URL[model_name]['cfg']
+    cfg_path = PRETRAINED_URL[model_name
+                              ]['cfg']
     merges_path = PRETRAINED_URL[model_name]['merges']
     vocab_path = PRETRAINED_URL[model_name]['vocab']
     params_path = PRETRAINED_URL[model_name]['params']
     mlm_params_path = PRETRAINED_URL[model_name]['mlm_params']
+
     local_paths = dict()
     for k, path in [('cfg', cfg_path), ('vocab', vocab_path),
                     ('merges', merges_path)]:
@@ -541,7 +543,7 @@ def get_pretrained_roberta(model_name: str = 'fairseq_roberta_base',
                                      sha1_hash=FILE_STATS[params_path])
     else:
         local_params_path = None
-    if load_mlm:
+    if load_mlm and mlm_params_path is not None:
         local_mlm_params_path = download(url=get_repo_model_zoo_url() + mlm_params_path,
                                          path=os.path.join(root, mlm_params_path),
                                          sha1_hash=FILE_STATS[mlm_params_path])
diff --git a/src/gluonnlp/models/xlmr.py b/src/gluonnlp/models/xlmr.py
index d89ab8daf0..5a40a1e461 100644
--- a/src/gluonnlp/models/xlmr.py
+++ b/src/gluonnlp/models/xlmr.py
@@ -44,19 +44,19 @@
         'cfg': 'fairseq_xlmr_base/model-b893d178.yml',
         'sentencepiece.model': 'fairseq_xlmr_base/sentencepiece-18e17bae.model',
         'params': 'fairseq_xlmr_base/model-3fa134e9.params',
-        'mlm_params': 'model_mlm-86e37954.params'
+        'mlm_params': 'model_mlm-86e37954.params',
     },
     'fairseq_xlmr_large': {
         'cfg': 'fairseq_xlmr_large/model-01fc59fb.yml',
         'sentencepiece.model': 'fairseq_xlmr_large/sentencepiece-18e17bae.model',
         'params': 'fairseq_xlmr_large/model-b62b074c.params',
-        'mlm_params': 'model_mlm-887506c2.params'
+        'mlm_params': 'model_mlm-887506c2.params',
     }
 }
 
 FILE_STATS = load_checksum_stats(os.path.join(get_model_zoo_checksum_dir(), 'xlmr.txt'))
 
-xlmr_cfg_reg = Registry('roberta_cfg')
+xlmr_cfg_reg = Registry('xlmr_cfg')
 
 
 @xlmr_cfg_reg.register()
@@ -139,7 +139,7 @@ def get_pretrained_xlmr(model_name: str = 'fairseq_xlmr_base',
                                      sha1_hash=FILE_STATS[params_path])
     else:
         local_params_path = None
-    if load_mlm:
+    if load_mlm and mlm_params_path is not None:
         local_mlm_params_path = download(url=get_repo_model_zoo_url() + mlm_params_path,
                                          path=os.path.join(root, mlm_params_path),
                                          sha1_hash=FILE_STATS[mlm_params_path])
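
As context for the guard change in get_pretrained_roberta and
get_pretrained_xlmr: 'if load_mlm and mlm_params_path is not None' only has
an effect when the 'mlm_params' lookup can actually yield None, which means
fetching it with .get() rather than direct indexing. Below is a minimal
self-contained sketch of that pattern; the toy PRETRAINED_URL data and the
fetch_paths helper are hypothetical stand-ins, not the GluonNLP API:

    # Sketch of the optional-checkpoint pattern behind the new guard.
    # The data and fetch_paths are toy stand-ins, not GluonNLP code.
    PRETRAINED_URL = {
        'with_mlm': {'params': 'model.params', 'mlm_params': 'model_mlm.params'},
        'without_mlm': {'params': 'model.params'},
    }

    def fetch_paths(model_name, load_mlm=False):
        params_path = PRETRAINED_URL[model_name]['params']
        # .get() yields None for models without an 'mlm_params' entry,
        # where direct indexing would raise KeyError.
        mlm_params_path = PRETRAINED_URL[model_name].get('mlm_params', None)
        local_mlm_params_path = None
        if load_mlm and mlm_params_path is not None:
            # the real code calls download(...) here
            local_mlm_params_path = mlm_params_path
        return params_path, local_mlm_params_path

    print(fetch_paths('with_mlm', load_mlm=True))     # ('model.params', 'model_mlm.params')
    print(fetch_paths('without_mlm', load_mlm=True))  # ('model.params', None)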