From 3c874575bf40e8b1fa2280371131a8f29ebb3e98 Mon Sep 17 00:00:00 2001 From: Xingjian Shi Date: Tue, 28 Jul 2020 18:03:21 -0700 Subject: [PATCH] Add layout + compute_layout support: TransformerNMT, BERT, ALBERT, ELECTRA, MobileBERT, RoBERTA, XLMR (#1258) * Add layout support * fix test * Update transformer.py * Update transformer.py * Update README.md * try to add set_layout * update test case * fix * update * update * update * Update bert.py * fix bug * update * Update test_models_bert.py * Update tokenizers.py * add compute layout * Update xlmr.py * Update test_models_bert.py * revise test cases * Update layers.py * move jieba to try import * fix * Update transformer.py * fix * Update bert.py * Update setup.py * Update test_models_bert.py * Update test_models_bert.py * fix * update * Revise * Update electra.py * Update electra.py * Update test_models_electra.py * fix * fix bug * Update test_models_albert.py * add more testcases * fix * Update albert.py * Update albert.py * fix bug * fix testcase * Update test_models_electra.py * Update bert.py * update * Update test_models_electra.py * Update mobilebert.py * Update mobilebert.py * update mobilebert * Update test_models_mobilebert.py * Update mobilebert.py * fix bug * Update roberta.py * fix roberta * update * update * fix import * fix bug * update * reduce test workloads * address comment * address comment --- README.md | 9 +- scripts/conversion_toolkits/README.md | 3 +- setup.py | 2 + src/gluonnlp/attention_cell.py | 79 +++-- src/gluonnlp/data/tokenizers.py | 20 +- src/gluonnlp/layers.py | 5 +- src/gluonnlp/models/albert.py | 318 ++++++++++++++----- src/gluonnlp/models/bert.py | 364 ++++++++++++++++----- src/gluonnlp/models/electra.py | 424 +++++++++++++++++++------ src/gluonnlp/models/mobilebert.py | 434 ++++++++++++++++++-------- src/gluonnlp/models/roberta.py | 309 +++++++++++------- src/gluonnlp/models/transformer.py | 430 ++++++++++++++++++------- src/gluonnlp/models/transformer_xl.py | 9 +- src/gluonnlp/models/xlmr.py | 52 +-- src/gluonnlp/utils/testing.py | 152 ++++++--- tests/test_attention_cell.py | 51 ++- tests/test_models_albert.py | 68 +++- tests/test_models_bert.py | 78 ++++- tests/test_models_electra.py | 59 +++- tests/test_models_mobilebert.py | 78 ++++- tests/test_models_roberta.py | 54 ++++ tests/test_models_transformer.py | 43 ++- tests/test_models_xlmr.py | 4 +- 23 files changed, 2280 insertions(+), 765 deletions(-) diff --git a/README.md b/README.md index 34fc069cbc..65b877451a 100644 --- a/README.md +++ b/README.md @@ -19,12 +19,17 @@ This is a work-in-progress. First of all, install the latest MXNet. 
You may use the following commands: ```bash +# Install the version with CUDA 10.0 +pip install -U --pre "mxnet-cu100>=2.0.0b20200716" -f https://dist.mxnet.io/python # Install the version with CUDA 10.1 -pip install -U --pre mxnet-cu101>=2.0.0b20200716 -f https://dist.mxnet.io/python +pip install -U --pre "mxnet-cu101>=2.0.0b20200716" -f https://dist.mxnet.io/python + +# Install the version with CUDA 10.2 +pip install -U --pre "mxnet-cu102>=2.0.0b20200716" -f https://dist.mxnet.io/python # Install the cpu-only version -pip install -U --pre mxnet>=2.0.0b20200716 -f https://dist.mxnet.io/python +pip install -U --pre "mxnet>=2.0.0b20200716" -f https://dist.mxnet.io/python ``` diff --git a/scripts/conversion_toolkits/README.md b/scripts/conversion_toolkits/README.md index be8bc8eff3..2c29e87db7 100644 --- a/scripts/conversion_toolkits/README.md +++ b/scripts/conversion_toolkits/README.md @@ -75,8 +75,7 @@ Notice: pleas set up the `--electra_path` with the cloned path or get this elect ```bash # Need to use TF 1.13.2 to use contrib layer -pip uninstall tensorflow -pip install tensorflow==1.13.2 +pip install tensorflow==1.13.2 --upgrade --force-reinstall # Actual conversion bash convert_electra.sh diff --git a/setup.py b/setup.py index 29cbc0c029..3de80f5695 100644 --- a/setup.py +++ b/setup.py @@ -55,6 +55,8 @@ def find_version(*file_paths): 'scripts', )), package_dir={"": "src"}, + package_data={'': [os.path.join('models', 'model_zoo_checksums', '*.txt'), + os.path.join('cli', 'data', 'url_checksums', '*.txt')]}, zip_safe=True, include_package_data=True, install_requires=requirements, diff --git a/src/gluonnlp/attention_cell.py b/src/gluonnlp/attention_cell.py index c5288ae087..4773f81d46 100644 --- a/src/gluonnlp/attention_cell.py +++ b/src/gluonnlp/attention_cell.py @@ -33,7 +33,8 @@ def gen_self_attn_mask(F, data, valid_length=None, dtype: type = np.float32, - attn_type: str = 'full'): + attn_type: str = 'full', + layout: str = 'NT'): """Generate the mask used for the encoder, i.e, self-attention. In our implementation, 1 --> not masked, 0 --> masked @@ -100,25 +101,37 @@ def gen_self_attn_mask(F, data, Parameters ---------- - F : - data : - The data. Shape (batch_size, seq_length, C) - valid_length : + F + data + The data. 
+ - layout = 'NT' + Shape (batch_size, seq_length, C) + - layout = 'TN' + Shape (seq_length, batch_size, C) + valid_length Shape (batch_size,) dtype Data type of the mask - attn_type : str + attn_type Can be 'full' or 'causal' + layout + The layout of the data Returns ------- mask Shape (batch_size, seq_length, seq_length) """ + if layout == 'NT': + batch_axis, time_axis = 0, 1 + elif layout == 'TN': + batch_axis, time_axis = 1, 0 + else: + raise NotImplementedError('Unsupported layout={}'.format(layout)) if attn_type == 'full': if valid_length is not None: valid_length = valid_length.astype(dtype) - steps = F.npx.arange_like(data, axis=1) # (seq_length,) + steps = F.npx.arange_like(data, axis=time_axis) # (seq_length,) mask1 = (F.npx.reshape(steps, (1, 1, -1)) < F.npx.reshape(valid_length, (-2, 1, 1))) mask2 = (F.npx.reshape(steps, (1, -1, 1)) @@ -126,12 +139,12 @@ def gen_self_attn_mask(F, data, mask = mask1 * mask2 else: # TODO(sxjscience) optimize - seq_len_ones = F.np.ones_like(F.npx.arange_like(data, axis=1)) # (seq_length,) - batch_ones = F.np.ones_like(F.npx.arange_like(data, axis=0)) # (batch_size,) + seq_len_ones = F.np.ones_like(F.npx.arange_like(data, axis=time_axis)) # (seq_length,) + batch_ones = F.np.ones_like(F.npx.arange_like(data, axis=batch_axis)) # (batch_size,) mask = batch_ones.reshape((-1, 1, 1)) * seq_len_ones.reshape((1, -1, 1))\ * seq_len_ones.reshape((1, 1, -1)) elif attn_type == 'causal': - steps = F.npx.arange_like(data, axis=1) + steps = F.npx.arange_like(data, axis=time_axis) # mask: (seq_length, seq_length) # batch_mask: (batch_size, seq_length) mask = (F.np.expand_dims(steps, axis=0) <= F.np.expand_dims(steps, axis=1)).astype(dtype) @@ -140,7 +153,8 @@ def gen_self_attn_mask(F, data, batch_mask = (F.np.expand_dims(steps, axis=0) < F.np.expand_dims(valid_length, axis=-1)).astype(dtype) mask = mask * F.np.expand_dims(batch_mask, axis=-1) else: - batch_ones = F.np.ones_like(F.npx.arange_like(data, axis=0), dtype=np.float32) # (batch_size,) + batch_ones = F.np.ones_like(F.npx.arange_like(data, axis=batch_axis), + dtype=dtype) # (batch_size,) mask = mask * batch_ones.reshape((-1, 1, 1)) else: raise NotImplementedError @@ -148,7 +162,8 @@ def gen_self_attn_mask(F, data, return mask -def gen_mem_attn_mask(F, mem, mem_valid_length, data, data_valid_length=None, dtype=np.float32): +def gen_mem_attn_mask(F, mem, mem_valid_length, data, data_valid_length=None, + dtype=np.float32, layout: str = 'NT'): """Generate the mask used for the decoder. All query slots are attended to the memory slots. 
In our implementation, 1 --> not masked, 0 --> masked @@ -183,34 +198,48 @@ def gen_mem_attn_mask(F, mem, mem_valid_length, data, data_valid_length=None, dt Parameters ---------- F : - mem : - Shape (batch_size, mem_length, C_mem) + mem + - layout = 'NT' + Shape (batch_size, mem_length, C_mem) + - layout = 'TN' + Shape (mem_length, batch_size, C_mem) mem_valid_length : Shape (batch_size,) - data : - Shape (batch_size, query_length, C_data) + data + - layout = 'NT' + Shape (batch_size, query_length, C_data) + - layout = 'TN' + Shape (query_length, batch_size, C_data) data_valid_length : Shape (batch_size,) - dtype : type + dtype Data type of the mask + layout + Layout of the data + mem tensor Returns ------- mask : Shape (batch_size, query_length, mem_length) """ + if layout == 'NT': + batch_axis, time_axis = 0, 1 + elif layout == 'TN': + batch_axis, time_axis = 1, 0 + else: + raise NotImplementedError('Unsupported layout={}'.format(layout)) mem_valid_length = mem_valid_length.astype(dtype) - mem_steps = F.npx.arange_like(mem, axis=1) # (mem_length,) + mem_steps = F.npx.arange_like(mem, axis=time_axis) # (mem_length,) + data_steps = F.npx.arange_like(data, axis=time_axis) # (query_length,) mem_mask = (F.npx.reshape(mem_steps, (1, 1, -1)) < F.npx.reshape(mem_valid_length, (-2, 1, 1))).astype(dtype) # (B, 1, mem_length) if data_valid_length is not None: data_valid_length = data_valid_length.astype(dtype) - data_steps = F.npx.arange_like(data, axis=1) # (query_length,) data_mask = (F.npx.reshape(data_steps, (1, -1, 1)) < F.npx.reshape(data_valid_length, (-2, 1, 1))).astype(dtype) # (B, query_length, 1) mask = mem_mask * data_mask else: - query_length_ones = F.np.ones_like(F.npx.arange_like(data, axis=1)) # (query_length,) + query_length_ones = F.np.ones_like(data_steps) mask = query_length_ones.reshape((1, -1, 1)) * mem_mask return mask @@ -594,6 +623,7 @@ def __init__(self, query_units=None, num_heads=None, attention_dropout=0.0, self._normalized = normalized self._eps = eps self._dtype = dtype + assert layout in ['NTK', 'NKT', 'TNK'] self._layout = layout self._use_einsum = use_einsum if self._query_units is not None: @@ -604,6 +634,10 @@ def __init__(self, query_units=None, num_heads=None, attention_dropout=0.0, else: self._query_head_units = None + @property + def layout(self): + return self._layout + def hybrid_forward(self, F, query, key, value, mask=None, edge_scores=None): return multi_head_dot_attn(F, query=query, key=key, value=value, mask=mask, edge_scores=edge_scores, @@ -764,6 +798,11 @@ def __init__(self, query_units, else: raise NotImplementedError('method="{}" is currently not supported!'.format(method)) + @property + def layout(self) -> str: + """Layout of the cell""" + return self._layout + def hybrid_forward(self, F, rel_positions, query=None): """ diff --git a/src/gluonnlp/data/tokenizers.py b/src/gluonnlp/data/tokenizers.py index a7aa40ee7b..d9579b2d55 100644 --- a/src/gluonnlp/data/tokenizers.py +++ b/src/gluonnlp/data/tokenizers.py @@ -26,21 +26,20 @@ import json from collections import OrderedDict import abc -import sys import warnings import itertools from typing import NewType import sacremoses -import jieba from uuid import uuid4 from .vocab import Vocab from ..registry import TOKENIZER_REGISTRY -from ..utils.lazy_imports import try_import_subword_nmt, \ - try_import_sentencepiece, \ - try_import_huggingface_tokenizers, \ - try_import_yttm, \ - try_import_spacy, \ - try_import_jieba +from ..utils.lazy_imports import try_import_subword_nmt,\ + 
try_import_sentencepiece,\ + try_import_huggingface_tokenizers,\ + try_import_yttm,\ + try_import_spacy,\ + try_import_jieba + SentencesType = NewType('SentencesType', Union[str, List[str]]) TokensType = NewType('TokensType', Union[List[str], List[List[str]]]) @@ -553,10 +552,10 @@ class JiebaTokenizer(BaseTokenizerWithVocab): """ - def __init__(self, ditionary=None, vocab: Optional[Vocab] = None): + def __init__(self, dictionary=None, vocab: Optional[Vocab] = None): self._vocab = vocab jieba = try_import_jieba() - self._tokenizer = jieba.Tokenizer(ditionary) + self._tokenizer = jieba.Tokenizer(dictionary) self._tokenizer.initialize(self._tokenizer.dictionary) def encode(self, sentences, output_type=str): @@ -626,6 +625,7 @@ def __getstate__(self): return d def __setstate__(self, state): + jieba = try_import_jieba() self._tokenizer = jieba.Tokenizer() for k, v in state.items(): setattr(self._tokenizer, k, v) diff --git a/src/gluonnlp/layers.py b/src/gluonnlp/layers.py index f19553fd5e..a6ea6b181e 100644 --- a/src/gluonnlp/layers.py +++ b/src/gluonnlp/layers.py @@ -356,9 +356,10 @@ def __init__(self, mode='erf'): def hybrid_forward(self, F, x): if self._mode == 'erf': - return x * 0.5 * (1.0 + F.npx.erf(x / math.sqrt(2.0))) + return F.npx.leaky_relu(x, act_type='gelu') elif self._mode == 'tanh': - return 0.5 * x * (1.0 + F.np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * (x ** 3)))) + return 0.5 * x\ + * (1.0 + F.np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * (x ** 3)))) elif self._mode == 'sigmoid': return x * F.npx.sigmoid(1.702 * x) else: diff --git a/src/gluonnlp/models/albert.py b/src/gluonnlp/models/albert.py index 1eb504c643..1b4efa16e2 100644 --- a/src/gluonnlp/models/albert.py +++ b/src/gluonnlp/models/albert.py @@ -25,7 +25,8 @@ """ __all__ = ['AlbertModel', 'AlbertForMLM', 'AlbertForPretrain', - 'list_pretrained_albert', 'get_pretrained_albert'] + 'list_pretrained_albert', 'get_pretrained_albert', + 'albert_cfg_reg'] import os from typing import Tuple @@ -38,16 +39,89 @@ from ..base import get_model_zoo_home_dir, get_repo_model_zoo_url, get_model_zoo_checksum_dir from ..utils.config import CfgNode as CN from ..utils.misc import load_checksum_stats, download +from ..utils.registry import Registry from ..initializer import TruncNorm from ..attention_cell import gen_self_attn_mask from ..layers import get_activation, PositionalEmbedding from ..op import select_vectors_by_position from ..data.tokenizers import SentencepieceTokenizer +albert_cfg_reg = Registry('albert_cfg') + + +@albert_cfg_reg.register() +def google_albert_base(): + cfg = CN() + # Model Parameters + cfg.MODEL = CN() + cfg.MODEL.vocab_size = 30000 + cfg.MODEL.embed_size = 128 + cfg.MODEL.units = 768 + cfg.MODEL.hidden_size = 3072 + cfg.MODEL.max_length = 512 + cfg.MODEL.num_heads = 12 + cfg.MODEL.num_layers = 12 + cfg.MODEL.pos_embed_type = 'learned' + cfg.MODEL.activation = 'gelu(tanh)' + cfg.MODEL.layer_norm_eps = 1E-12 + cfg.MODEL.num_groups = 1 + cfg.MODEL.num_token_types = 2 + cfg.MODEL.hidden_dropout_prob = 0.0 + cfg.MODEL.attention_dropout_prob = 0.0 + cfg.MODEL.dtype = 'float32' + cfg.MODEL.layout = 'NT' + cfg.MODEL.compute_layout = 'auto' + # Hyper-parameters of the Initializers + cfg.INITIALIZER = CN() + cfg.INITIALIZER.embed = ['truncnorm', 0, 0.02] + cfg.INITIALIZER.weight = ['truncnorm', 0, 0.02] # TruncNorm(0, 0.02) + cfg.INITIALIZER.bias = ['zeros'] + # Version of the model. This helps ensure backward compatibility. 
+ # Also, we can not use string here due to https://github.com/rbgirshick/yacs/issues/26 + cfg.VERSION = 1 + cfg.freeze() + return cfg + + +@albert_cfg_reg.register() +def google_albert_large(): + cfg = google_albert_base() + cfg.defrost() + cfg.MODEL.hidden_size = 4096 + cfg.MODEL.num_heads = 16 + cfg.MODEL.num_layers = 24 + cfg.MODEL.units = 1024 + cfg.freeze() + return cfg + + +@albert_cfg_reg.register() +def google_albert_xlarge(): + cfg = google_albert_base() + cfg.defrost() + cfg.MODEL.hidden_size = 8192 + cfg.MODEL.num_heads = 32 + cfg.MODEL.num_layers = 24 + cfg.MODEL.units = 2048 + cfg.freeze() + return cfg + + +@albert_cfg_reg.register() +def google_albert_xxlarge(): + cfg = google_albert_base() + cfg.defrost() + cfg.MODEL.hidden_size = 16384 + cfg.MODEL.num_heads = 64 + cfg.MODEL.num_layers = 12 + cfg.MODEL.units = 4096 + cfg.freeze() + return cfg + PRETRAINED_URL = { 'google_albert_base_v2': { - 'cfg': 'google_albert_base_v2/model-8767fdc9.yml', + 'cfg': google_albert_base(), 'spm_model': 'google_albert_base_v2/spm-65999e5d.model', 'vocab': 'google_albert_base_v2/vocab-2ee53ae7.json', 'params': 'google_albert_base_v2/model-125be477.params', @@ -55,7 +129,7 @@ 'lowercase': True, }, 'google_albert_large_v2': { - 'cfg': 'google_albert_large_v2/model-e2e9b974.yml', + 'cfg': google_albert_large(), 'spm_model': 'google_albert_large_v2/spm-65999e5d.model', 'vocab': 'google_albert_large_v2/vocab-2ee53ae7.json', 'params': 'google_albert_large_v2/model-ad60bcd5.params', @@ -63,7 +137,7 @@ 'lowercase': True, }, 'google_albert_xlarge_v2': { - 'cfg': 'google_albert_xlarge_v2/model-8123bffd.yml', + 'cfg': google_albert_xlarge(), 'spm_model': 'google_albert_xlarge_v2/spm-65999e5d.model', 'vocab': 'google_albert_xlarge_v2/vocab-2ee53ae7.json', 'params': 'google_albert_xlarge_v2/model-4149c9e2.params', @@ -71,7 +145,7 @@ 'lowercase': True, }, 'google_albert_xxlarge_v2': { - 'cfg': 'google_albert_xxlarge_v2/model-07fbeebc.yml', + 'cfg': google_albert_xxlarge(), 'spm_model': 'google_albert_xxlarge_v2/spm-65999e5d.model', 'vocab': 'google_albert_xxlarge_v2/vocab-2ee53ae7.json', 'params': 'google_albert_xxlarge_v2/model-5601a0ed.params', @@ -97,7 +171,8 @@ def __init__(self, units=512, hidden_size=2048, layer_norm_eps=1E-12, weight_initializer=TruncNorm(stdev=0.02), bias_initializer='zeros', - activation='gelu'): + activation='gelu', + layout='NT'): super().__init__() assert units % num_heads == 0,\ 'In AlbertEncoder, The units should be divided exactly ' \ @@ -112,6 +187,8 @@ def __init__(self, units=512, hidden_size=2048, self._output_attention = output_attention self._output_all_encodings = output_all_encodings + self._layout = layout + self.all_encoder_groups = nn.HybridSequential() for group_idx in range(num_groups): @@ -124,7 +201,13 @@ def __init__(self, units=512, hidden_size=2048, layer_norm_eps=layer_norm_eps, weight_initializer=weight_initializer, bias_initializer=bias_initializer, - activation=activation)) + activation=activation, + dtype=dtype, + layout=layout)) + + @property + def layout(self): + return self._layout def hybrid_forward(self, F, data, valid_length): """ @@ -135,18 +218,26 @@ def hybrid_forward(self, F, data, valid_length): Parameters ---------- F - data : - Shape (batch_size, seq_length, C) + data + - layout = 'NT' + Shape (batch_size, seq_length, C) + - layout = 'TN' + Shape (seq_length, batch_size, C) valid_length : Shape (batch_size,) Returns ------- - out : - Shape (batch_size, seq_length, C_out) + out + - layout = 'NT' + Shape (batch_size, seq_length, C_out) + - 
layout = 'TN' + Shape (seq_length, batch_size, C) """ # 1. Embed the data - attn_mask = gen_self_attn_mask(F, data, valid_length, dtype=self._dtype, attn_type='full') + time_axis = 1 if self.layout == 'NT' else 0 + attn_mask = gen_self_attn_mask(F, data, valid_length, dtype=self._dtype, + attn_type='full', layout=self.layout) out = data all_encodings_outputs = [] additional_outputs = [] @@ -159,7 +250,8 @@ def hybrid_forward(self, F, data, valid_length): if self._output_all_encodings: out = F.npx.sequence_mask(out, sequence_length=valid_length, - use_sequence_length=True, axis=1) + use_sequence_length=True, + axis=time_axis) all_encodings_outputs.append(out) if self._output_attention: @@ -168,7 +260,8 @@ def hybrid_forward(self, F, data, valid_length): if not self._output_all_encodings: # if self._output_all_encodings, SequenceMask is already applied above out = F.npx.sequence_mask(out, sequence_length=valid_length, - use_sequence_length=True, axis=1) + use_sequence_length=True, + axis=time_axis) return out, additional_outputs else: return all_encodings_outputs, additional_outputs @@ -195,7 +288,9 @@ def __init__(self, weight_initializer=TruncNorm(stdev=0.02), bias_initializer='zeros', dtype='float32', - use_pooler=True): + use_pooler=True, + layout='NT', + compute_layout='auto'): super().__init__() self._dtype = dtype self.use_pooler = use_pooler @@ -210,6 +305,11 @@ def __init__(self, self.weight_initializer = weight_initializer self.bias_initializer = bias_initializer self.layer_norm_eps = layer_norm_eps + self._layout = layout + if compute_layout is None or compute_layout == 'auto': + self._compute_layout = layout + else: + self._compute_layout = compute_layout # Construct AlbertEncoder self.encoder = AlbertEncoder( units=units, @@ -226,6 +326,7 @@ def __init__(self, weight_initializer=weight_initializer, bias_initializer=bias_initializer, dtype=dtype, + layout=self._compute_layout ) self.encoder.hybridize() # Construct word embedding @@ -257,6 +358,10 @@ def __init__(self, weight_initializer=weight_initializer, bias_initializer=bias_initializer) + @property + def layout(self): + return self._layout + def hybrid_forward(self, F, inputs, token_types, valid_length=None): # pylint: disable=arguments-differ """Generate the representation given the inputs. @@ -266,10 +371,16 @@ def hybrid_forward(self, F, inputs, token_types, valid_length=None): Parameters ---------- F - inputs : - Shape (batch_size, seq_length) - token_types : - Shape (batch_size, seq_length) + inputs + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) + token_types + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) If the inputs contain two sequences, we will set different token types for the first sentence and the second sentence. @@ -279,8 +390,11 @@ def hybrid_forward(self, F, inputs, token_types, valid_length=None): Returns ------- - contextual_embedding : - Shape (batch_size, seq_length, units). + contextual_embedding + - layout = 'NT' + Shape (batch_size, seq_length, units) + - layout = 'TN' + Shape (seq_length, batch_size, units) pooled_output : This is optional. 
Shape (batch_size, units) """ @@ -290,7 +404,13 @@ def hybrid_forward(self, F, inputs, token_types, valid_length=None): if self.embed_size != self.units: prev_out = self.embed_factorized_proj(prev_out) outputs = [] - contextual_embeddings, additional_outputs = self.encoder(prev_out, valid_length) + if self._compute_layout != self._layout: + # Swap input to reflect the compute_layout + contextual_embeddings, additional_outputs = self.encoder(F.np.swapaxes(prev_out, 0, 1), + valid_length) + contextual_embeddings = F.np.swapaxes(contextual_embeddings, 0, 1) + else: + contextual_embeddings, additional_outputs = self.encoder(prev_out, valid_length) outputs.append(contextual_embeddings) if self.use_pooler: pooled_out = self.apply_pooling(contextual_embeddings) @@ -304,24 +424,37 @@ def get_initial_embedding(self, F, inputs, token_types=None): ---------- F inputs - Shape (batch_size, seq_length) + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) token_types - Shape (batch_size, seq_length) + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' If None, it will be initialized as all zero Returns ------- embedding The initial embedding that will be fed into the encoder + - layout = 'NT' + Shape (batch_size, seq_length, C_embed) + - layout = 'TN' + Shape (seq_length, batch_size, C_embed) """ + if self.layout == 'NT': + batch_axis, time_axis = 0, 1 + else: + batch_axis, time_axis = 1, 0 embedding = self.word_embed(inputs) if token_types is None: token_types = F.np.zeros_like(inputs) type_embedding = self.token_type_embed(token_types) embedding = embedding + type_embedding if self.pos_embed_type is not None: - positional_embedding = self.token_pos_embed(F.npx.arange_like(inputs, axis=1)) - positional_embedding = F.np.expand_dims(positional_embedding, axis=0) + positional_embedding = self.token_pos_embed(F.npx.arange_like(inputs, axis=time_axis)) + positional_embedding = F.np.expand_dims(positional_embedding, axis=batch_axis) embedding = embedding + positional_embedding # Extra layer normalization plus dropout embedding = self.embed_layer_norm(embedding) @@ -334,50 +467,34 @@ def apply_pooling(self, sequence): This is used for pre-training or fine-tuning a Bert model. 
Get the first token of the whole sequence which is [CLS] - sequence: - Shape (batch_size, sequence_length, units) - return: + Parameters + ---------- + sequence + - layout = 'NT' + Shape (batch_size, sequence_length, units) + - layout = 'TN' + Shape (sequence_length, batch_size, units) + + Returns + ------- + pooled_out Shape (batch_size, units) """ - outputs = sequence[:, 0, :] + if self.layout == 'NT': + outputs = sequence[:, 0, :] + else: + outputs = sequence[0, :, :] return self.pooler(outputs) @staticmethod def get_cfg(key=None): - if key is None: - cfg = CN() - # Model Parameters - cfg.MODEL = CN() - cfg.MODEL.vocab_size = 30000 - cfg.MODEL.embed_size = 128 - cfg.MODEL.units = 768 - cfg.MODEL.hidden_size = 3072 - cfg.MODEL.max_length = 512 - cfg.MODEL.num_heads = 12 - cfg.MODEL.num_layers = 12 - cfg.MODEL.pos_embed_type = 'learned' - cfg.MODEL.activation = 'gelu' - cfg.MODEL.layer_norm_eps = 1E-12 - cfg.MODEL.num_groups = 1 - cfg.MODEL.num_token_types = 2 - cfg.MODEL.hidden_dropout_prob = 0.0 - cfg.MODEL.attention_dropout_prob = 0.0 - cfg.MODEL.dtype = 'float32' - # Hyper-parameters of the Initializers - cfg.INITIALIZER = CN() - cfg.INITIALIZER.embed = ['truncnorm', 0, 0.02] - cfg.INITIALIZER.weight = ['truncnorm', 0, 0.02] # TruncNorm(0, 0.02) - cfg.INITIALIZER.bias = ['zeros'] - # Version of the model. This helps ensure backward compatibility. - # Also, we can not use string here due to https://github.com/rbgirshick/yacs/issues/26 - cfg.VERSION = 1 + if key is not None: + return albert_cfg_reg.create(key) else: - raise NotImplementedError - cfg.freeze() - return cfg + return google_albert_base() @classmethod - def from_cfg(cls, cfg, use_pooler=True, dtype='float32') -> 'AlbertModel': + def from_cfg(cls, cfg, use_pooler=True, dtype=None) -> 'AlbertModel': """ Parameters @@ -385,6 +502,8 @@ def from_cfg(cls, cfg, use_pooler=True, dtype='float32') -> 'AlbertModel': cfg use_pooler Whether to use pooler + dtype + The dtype of the backbone model Returns ------- @@ -396,6 +515,8 @@ def from_cfg(cls, cfg, use_pooler=True, dtype='float32') -> 'AlbertModel': embed_initializer = mx.init.create(*cfg.INITIALIZER.embed) weight_initializer = mx.init.create(*cfg.INITIALIZER.weight) bias_initializer = mx.init.create(*cfg.INITIALIZER.bias) + if dtype is None: + dtype = cfg.MODEL.dtype return cls(vocab_size=cfg.MODEL.vocab_size, units=cfg.MODEL.units, hidden_size=cfg.MODEL.hidden_size, @@ -411,6 +532,7 @@ def from_cfg(cls, cfg, use_pooler=True, dtype='float32') -> 'AlbertModel': activation=cfg.MODEL.activation, layer_norm_eps=cfg.MODEL.layer_norm_eps, dtype=dtype, + layout=cfg.MODEL.layout, embed_initializer=embed_initializer, weight_initializer=weight_initializer, bias_initializer=bias_initializer, @@ -453,6 +575,10 @@ def __init__(self, backbone_cfg, self.mlm_decoder[-1].weight = self.backbone_model.word_embed.weight self.mlm_decoder.hybridize() + @property + def layout(self): + return self.backbone_model.layout + def hybrid_forward(self, F, inputs, token_types, valid_length, masked_positions): """Getting the scores of the masked positions. @@ -460,10 +586,16 @@ def hybrid_forward(self, F, inputs, token_types, valid_length, Parameters ---------- F - inputs : - Shape (batch_size, seq_length) - token_types : - Shape (batch_size, seq_length) + inputs + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) + token_types + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) The type of the token. 
For example, if the inputs contain two sequences, we will set different token types for the first sentence and the second sentence. valid_length : @@ -476,14 +608,21 @@ def hybrid_forward(self, F, inputs, token_types, valid_length, Returns ------- contextual_embedding - Shape (batch_size, seq_length, units). + - layout = 'NT' + Shape (batch_size, seq_length, units) + - layout = 'TN' + Shape (seq_length, batch_size, units) pooled_out Shape (batch_size, units) mlm_scores : Shape (batch_size, num_masked_positions, vocab_size) """ contextual_embeddings, pooled_out = self.backbone_model(inputs, token_types, valid_length) - mlm_features = select_vectors_by_position(F, contextual_embeddings, masked_positions) + if self.layout == 'NT': + mlm_features = select_vectors_by_position(F, contextual_embeddings, masked_positions) + else: + mlm_features = select_vectors_by_position(F, F.np.swapaxes(contextual_embeddings, 0, 1), + masked_positions) mlm_scores = self.mlm_decoder(mlm_features) return contextual_embeddings, pooled_out, mlm_scores @@ -528,6 +667,10 @@ def __init__(self, backbone_cfg, self.mlm_decoder[-1].weight = self.backbone_model.word_embed.weight self.mlm_decoder.hybridize() + @property + def layout(self): + return self.backbone_model.layout + def hybrid_forward(self, F, inputs, token_types, valid_length, masked_positions): """Generate the representation given the inputs. @@ -537,10 +680,16 @@ def hybrid_forward(self, F, inputs, token_types, valid_length, Parameters ---------- F - inputs : - Shape (batch_size, seq_length) + inputs + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) token_types : - Shape (batch_size, seq_length) + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) If the inputs contain two sequences, we will set different token types for the first sentence and the second sentence. @@ -554,7 +703,10 @@ def hybrid_forward(self, F, inputs, token_types, valid_length, Returns ------- contextual_embedding - Shape (batch_size, seq_length, units). + - layout = 'NT' + Shape (batch_size, seq_length, units). + - layout = 'TN' + Shape (seq_length, batch_size, units). pooled_out Shape (batch_size, units) sop_score : @@ -564,7 +716,11 @@ def hybrid_forward(self, F, inputs, token_types, valid_length, """ contextual_embeddings, pooled_out = self.backbone_model(inputs, token_types, valid_length) sop_score = self.sop_classifier(pooled_out) - mlm_features = select_vectors_by_position(F, contextual_embeddings, masked_positions) + if self.layout == 'NT': + mlm_features = select_vectors_by_position(F, contextual_embeddings, masked_positions) + else: + mlm_features = select_vectors_by_position(F, F.np.swapaxes(contextual_embeddings, 0, 1), + masked_positions) mlm_scores = self.mlm_decoder(mlm_features) return contextual_embeddings, pooled_out, sop_score, mlm_scores @@ -604,15 +760,22 @@ def get_pretrained_albert(model_name: str = 'google_albert_base_v2', assert model_name in PRETRAINED_URL, '{} is not found. 
All available are {}'.format( model_name, list_pretrained_albert()) cfg_path = PRETRAINED_URL[model_name]['cfg'] + if isinstance(cfg_path, CN): + cfg = cfg_path + else: + cfg = None spm_model_path = PRETRAINED_URL[model_name]['spm_model'] vocab_path = PRETRAINED_URL[model_name]['vocab'] params_path = PRETRAINED_URL[model_name]['params'] mlm_params_path = PRETRAINED_URL[model_name]['mlm_params'] local_paths = dict() - for k, path in [('cfg', cfg_path), ('spm_model', spm_model_path), ('vocab', vocab_path)]: - local_paths[k] = download(url=get_repo_model_zoo_url() + path, - path=os.path.join(root, path), - sha1_hash=FILE_STATS[path]) + download_jobs = [('spm_model', spm_model_path), ('vocab', vocab_path)] + if cfg is None: + download_jobs.append(('cfg', cfg_path)) + for key, path in download_jobs: + local_paths[key] = download(url=get_repo_model_zoo_url() + path, + path=os.path.join(root, path), + sha1_hash=FILE_STATS[path]) if load_backbone: local_params_path = download(url=get_repo_model_zoo_url() + params_path, path=os.path.join(root, params_path), @@ -630,7 +793,8 @@ def get_pretrained_albert(model_name: str = 'google_albert_base_v2', tokenizer = SentencepieceTokenizer(local_paths['spm_model'], vocab=local_paths['vocab'], lowercase=do_lower) - cfg = AlbertModel.get_cfg().clone_merge(local_paths['cfg']) + if cfg is None: + cfg = AlbertModel.get_cfg().clone_merge(local_paths['cfg']) return cfg, tokenizer, local_params_path, local_mlm_params_path diff --git a/src/gluonnlp/models/bert.py b/src/gluonnlp/models/bert.py index fd53ae3b5c..84a1d5ee2e 100644 --- a/src/gluonnlp/models/bert.py +++ b/src/gluonnlp/models/bert.py @@ -39,16 +39,108 @@ from ..base import get_model_zoo_home_dir, get_repo_model_zoo_url, get_model_zoo_checksum_dir from ..utils.config import CfgNode as CN from ..utils.misc import load_checksum_stats, download +from ..utils.registry import Registry from ..initializer import TruncNorm from ..attention_cell import MultiHeadAttentionCell, gen_self_attn_mask from ..layers import get_activation, PositionalEmbedding, PositionwiseFFN, InitializerType from ..op import select_vectors_by_position from ..data.tokenizers import HuggingFaceWordPieceTokenizer +bert_cfg_reg = Registry('bert_cfg') + + +@bert_cfg_reg.register() +def google_en_uncased_bert_base(): + cfg = CN() + # Parameters for thr small model + cfg.MODEL = CN() + cfg.MODEL.vocab_size = 30522 + cfg.MODEL.units = 768 + cfg.MODEL.hidden_size = 3072 + cfg.MODEL.max_length = 512 + cfg.MODEL.num_heads = 12 + cfg.MODEL.num_layers = 12 + cfg.MODEL.pos_embed_type = 'learned' + cfg.MODEL.activation = 'gelu' + cfg.MODEL.layer_norm_eps = 1E-12 + cfg.MODEL.num_token_types = 2 + cfg.MODEL.hidden_dropout_prob = 0.1 + cfg.MODEL.attention_dropout_prob = 0.1 + cfg.MODEL.dtype = 'float32' + cfg.MODEL.layout = 'NT' + cfg.MODEL.compute_layout = 'auto' + # Hyper-parameters of the Initializers + cfg.INITIALIZER = CN() + cfg.INITIALIZER.embed = ['truncnorm', 0, 0.02] + cfg.INITIALIZER.weight = ['truncnorm', 0, 0.02] # TruncNorm(0, 0.02) + cfg.INITIALIZER.bias = ['zeros'] + # Version of the model. This helps ensure backward compatibility. 
+ # Also, we can not use string here due to https://github.com/rbgirshick/yacs/issues/26 + cfg.VERSION = 1 + cfg.freeze() + return cfg + + +@bert_cfg_reg.register() +def google_en_uncased_bert_large(): + cfg = google_en_uncased_bert_base() + cfg.defrost() + cfg.MODEL.hidden_size = 4096 + cfg.MODEL.num_heads = 16 + cfg.MODEL.num_layers = 24 + cfg.MODEL.units = 1024 + cfg.freeze() + return cfg + + +@bert_cfg_reg.register() +def google_en_cased_bert_base(): + cfg = google_en_uncased_bert_base() + cfg.defrost() + cfg.MODEL.vocab_size = 28996 + cfg.freeze() + return cfg + + +@bert_cfg_reg.register() +def google_en_cased_bert_large(): + cfg = google_en_uncased_bert_large() + cfg.defrost() + cfg.MODEL.vocab_size = 28996 + cfg.freeze() + return cfg + + +@bert_cfg_reg.register() +def google_zh_bert_base(): + cfg = google_en_uncased_bert_base() + cfg.defrost() + cfg.MODEL.vocab_size = 21128 + cfg.freeze() + return cfg + + +@bert_cfg_reg.register() +def google_multi_cased_bert_base(): + cfg = google_en_uncased_bert_base() + cfg.defrost() + cfg.MODEL.vocab_size = 119547 + cfg.freeze() + return cfg + + +@bert_cfg_reg.register() +def google_multi_cased_bert_large(): + cfg = google_en_uncased_bert_large() + cfg.defrost() + cfg.MODEL.vocab_size = 119547 + cfg.freeze() + return cfg + PRETRAINED_URL = { 'google_en_cased_bert_base': { - 'cfg': 'google_en_cased_bert_base/model-5620839a.yml', + 'cfg': google_en_cased_bert_base(), 'vocab': 'google_en_cased_bert_base/vocab-c1defaaa.json', 'params': 'google_en_cased_bert_base/model-c566c289.params', 'mlm_params': 'google_en_cased_bert_base/model_mlm-bde14bee.params', @@ -56,49 +148,49 @@ }, 'google_en_uncased_bert_base': { - 'cfg': 'google_en_uncased_bert_base/model-4d8422ad.yml', + 'cfg': google_en_uncased_bert_base(), 'vocab': 'google_en_uncased_bert_base/vocab-e6d2b21d.json', 'params': 'google_en_uncased_bert_base/model-3712e50a.params', 'mlm_params': 'google_en_uncased_bert_base/model_mlm-04e88b58.params', 'lowercase': True, }, 'google_en_cased_bert_large': { - 'cfg': 'google_en_cased_bert_large/model-9e127fee.yml', + 'cfg': google_en_cased_bert_large(), 'vocab': 'google_en_cased_bert_large/vocab-c1defaaa.json', 'params': 'google_en_cased_bert_large/model-7aa93704.params', 'mlm_params': 'google_en_cased_bert_large/model_mlm-59ff3f6a.params', 'lowercase': False, }, 'google_en_uncased_bert_large': { - 'cfg': 'google_en_uncased_bert_large/model-d0c37dcc.yml', + 'cfg': google_en_uncased_bert_large(), 'vocab': 'google_en_uncased_bert_large/vocab-e6d2b21d.json', 'params': 'google_en_uncased_bert_large/model-e53bbc57.params', 'mlm_params': 'google_en_uncased_bert_large/model_mlm-44bc70c0.params', 'lowercase': True, }, 'google_zh_bert_base': { - 'cfg': 'google_zh_bert_base/model-9b16bda6.yml', + 'cfg': google_zh_bert_base(), 'vocab': 'google_zh_bert_base/vocab-711c13e4.json', 'params': 'google_zh_bert_base/model-2efbff63.params', 'mlm_params': 'google_zh_bert_base/model_mlm-75339658.params', 'lowercase': False, }, 'google_multi_cased_bert_base': { - 'cfg': 'google_multi_cased_bert_base/model-881ad607.yml', + 'cfg': google_multi_cased_bert_base(), 'vocab': 'google_multi_cased_bert_base/vocab-016e1169.json', 'params': 'google_multi_cased_bert_base/model-c2110078.params', 'mlm_params': 'google_multi_cased_bert_base/model_mlm-4611e7a3.params', 'lowercase': False, }, 'google_en_cased_bert_wwm_large': { - 'cfg': 'google_en_cased_bert_wwm_large/model-9e127fee.yml', + 'cfg': google_en_cased_bert_large(), 'vocab': 'google_en_cased_bert_wwm_large/vocab-c1defaaa.json', 
'params': 'google_en_cased_bert_wwm_large/model-0fe841cf.params', 'mlm_params': None, 'lowercase': False, }, 'google_en_uncased_bert_wwm_large': { - 'cfg': 'google_en_uncased_bert_wwm_large/model-d0c37dcc.yml', + 'cfg': google_en_uncased_bert_large(), 'vocab': 'google_en_uncased_bert_wwm_large/vocab-e6d2b21d.json', 'params': 'google_en_uncased_bert_wwm_large/model-cb3ad3c2.params', 'mlm_params': None, @@ -124,7 +216,8 @@ def __init__(self, units: int = 512, layer_norm_eps: float = 1E-12, weight_initializer: InitializerType = TruncNorm(stdev=0.02), bias_initializer: InitializerType = 'zeros', - activation='gelu'): + activation='gelu', + layout='NT'): super().__init__() assert units % num_heads == 0,\ 'In BertTransformer, The units should be divided exactly ' \ @@ -135,6 +228,7 @@ def __init__(self, units: int = 512, self._num_layers = num_layers self._output_attention = output_attention self._output_all_encodings = output_all_encodings + self._layout = layout self.all_layers = nn.HybridSequential() for layer_idx in range(num_layers): @@ -147,7 +241,13 @@ def __init__(self, units: int = 512, layer_norm_eps=layer_norm_eps, weight_initializer=weight_initializer, bias_initializer=bias_initializer, - activation=activation)) + activation=activation, + layout=layout, + dtype=dtype)) + + @property + def layout(self): + return self._layout def hybrid_forward(self, F, data, valid_length): """ @@ -158,30 +258,41 @@ def hybrid_forward(self, F, data, valid_length): Parameters ---------- F - data : - Shape (batch_size, seq_length, C) - valid_length : + data + - layout = 'NT' + Shape (batch_size, seq_length, C) + - layout = 'TN' + Shape (seq_length, batch_size, C) + valid_length Shape (batch_size,) Returns ------- - out : - Shape (batch_size, seq_length, C_out) + out + - layout = 'NT' + Shape (batch_size, seq_length, C_out) + - layout = 'TN' + Shape (seq_length, batch_size, C_out) """ + if self.layout == 'NT': + time_axis, batch_axis = 1, 0 + else: + time_axis, batch_axis = 0, 1 # 1. 
Embed the data - attn_mask = gen_self_attn_mask(F, data, valid_length, dtype=self._dtype, attn_type='full') + attn_mask = gen_self_attn_mask(F, data, valid_length, dtype=self._dtype, + attn_type='full', layout=self.layout) out = data all_encodings_outputs = [] additional_outputs = [] for layer_idx in range(self._num_layers): layer = self.all_layers[layer_idx] out, attention_weights = layer(out, attn_mask) - # out : [batch_size, seq_len, units] + # out : [batch_size, seq_len, units] or [seq_len, batch_size, units] # attention_weights : [batch_size, num_heads, seq_len, seq_len] if self._output_all_encodings: out = F.npx.sequence_mask(out, sequence_length=valid_length, - use_sequence_length=True, axis=1) + use_sequence_length=True, axis=time_axis) all_encodings_outputs.append(out) if self._output_attention: @@ -190,7 +301,7 @@ def hybrid_forward(self, F, data, valid_length): if not self._output_all_encodings: # if self._output_all_encodings, SequenceMask is already applied above out = F.npx.sequence_mask(out, sequence_length=valid_length, - use_sequence_length=True, axis=1) + use_sequence_length=True, axis=time_axis) return out, additional_outputs else: return all_encodings_outputs, additional_outputs @@ -215,7 +326,9 @@ def __init__(self, weight_initializer=TruncNorm(stdev=0.02), bias_initializer='zeros', dtype='float32', - use_pooler=True): + use_pooler=True, + layout='NT', + compute_layout='auto'): super().__init__() self._dtype = dtype self.use_pooler = use_pooler @@ -229,6 +342,11 @@ def __init__(self, self.weight_initializer = weight_initializer self.bias_initializer = bias_initializer self.layer_norm_eps = layer_norm_eps + self._layout = layout + if compute_layout is None or compute_layout == 'auto': + self._compute_layout = layout + else: + self._compute_layout = compute_layout # Construct BertTransformer self.encoder = BertTransformer( units=units, @@ -244,6 +362,7 @@ def __init__(self, weight_initializer=weight_initializer, bias_initializer=bias_initializer, dtype=dtype, + layout=self._compute_layout ) self.encoder.hybridize() # Construct word embedding @@ -270,6 +389,10 @@ def __init__(self, weight_initializer=weight_initializer, bias_initializer=bias_initializer) + @property + def layout(self): + return self._layout + def hybrid_forward(self, F, inputs, token_types, valid_length): # pylint: disable=arguments-differ """Generate the representation given the inputs. @@ -279,10 +402,16 @@ def hybrid_forward(self, F, inputs, token_types, valid_length): Parameters ---------- F - inputs : - Shape (batch_size, seq_length) - token_types : - Shape (batch_size, seq_length) + inputs + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) + token_types + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (batch_size, seq_length) If the inputs contain two sequences, we will set different token types for the first sentence and the second sentence. @@ -292,16 +421,24 @@ def hybrid_forward(self, F, inputs, token_types, valid_length): Returns ------- - contextual_embedding : - Shape (batch_size, seq_length, units). + contextual_embedding + - layout = 'NT' + Shape (batch_size, seq_length, units). + - layout = 'TN' + Shape (seq_length, batch_size, units). pooled_output : This is optional. 
Shape (batch_size, units) """ initial_embedding = self.get_initial_embedding(F, inputs, token_types) prev_out = initial_embedding outputs = [] - - contextual_embeddings, additional_outputs = self.encoder(prev_out, valid_length) + if self._compute_layout != self._layout: + # Swap the axes if the compute_layout and layout mismatch + contextual_embeddings, additional_outputs = self.encoder(F.np.swapaxes(prev_out, 0, 1), + valid_length) + contextual_embeddings = F.np.swapaxes(contextual_embeddings, 0, 1) + else: + contextual_embeddings, additional_outputs = self.encoder(prev_out, valid_length) outputs.append(contextual_embeddings) if self.use_pooler: pooled_out = self.apply_pooling(contextual_embeddings) @@ -315,24 +452,38 @@ def get_initial_embedding(self, F, inputs, token_types=None): ---------- F inputs - Shape (batch_size, seq_length) + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) token_types - Shape (batch_size, seq_length) + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) If None, it will be initialized as all zero Returns ------- embedding The initial embedding that will be fed into the encoder + - layout = 'NT' + Shape (batch_size, seq_length, C_emb) + - layout = 'TN' + Shape (seq_length, batch_size, C_emb) """ + if self.layout == 'NT': + time_axis, batch_axis = 1, 0 + else: + time_axis, batch_axis = 0, 1 embedding = self.word_embed(inputs) if token_types is None: token_types = F.np.zeros_like(inputs) type_embedding = self.token_type_embed(token_types) embedding = embedding + type_embedding if self.pos_embed_type is not None: - positional_embedding = self.token_pos_embed(F.npx.arange_like(inputs, axis=1)) - positional_embedding = F.np.expand_dims(positional_embedding, axis=0) + positional_embedding = self.token_pos_embed(F.npx.arange_like(inputs, axis=time_axis)) + positional_embedding = F.np.expand_dims(positional_embedding, axis=batch_axis) embedding = embedding + positional_embedding # Extra layer normalization plus dropout embedding = self.embed_layer_norm(embedding) @@ -345,53 +496,52 @@ def apply_pooling(self, sequence): This is used for pre-training or fine-tuning a bert model. Get the first token of the whole sequence which is [CLS] - sequence: - Shape (batch_size, sequence_length, units) + sequence + - layout = 'NT' + Shape (batch_size, sequence_length, units) + - layout = 'TN' + Shape (sequence_length, batch_size, units) return: Shape (batch_size, units) """ - outputs = sequence[:, 0, :] + if self.layout == 'NT': + outputs = sequence[:, 0, :] + else: + outputs = sequence[0, :, :] return self.pooler(outputs) @staticmethod def get_cfg(key=None): - if key is None: - cfg = CN() - # Parameters for thr small model - cfg.MODEL = CN() - cfg.MODEL.vocab_size = 30000 - cfg.MODEL.units = 256 - cfg.MODEL.hidden_size = 1024 - cfg.MODEL.max_length = 512 - cfg.MODEL.num_heads = 4 - cfg.MODEL.num_layers = 12 - cfg.MODEL.pos_embed_type = 'learned' - cfg.MODEL.activation = 'gelu' - cfg.MODEL.layer_norm_eps = 1E-12 - cfg.MODEL.num_token_types = 2 - cfg.MODEL.hidden_dropout_prob = 0.1 - cfg.MODEL.attention_dropout_prob = 0.1 - cfg.MODEL.dtype = 'float32' - # Hyper-parameters of the Initializers - cfg.INITIALIZER = CN() - cfg.INITIALIZER.embed = ['truncnorm', 0, 0.02] - cfg.INITIALIZER.weight = ['truncnorm', 0, 0.02] # TruncNorm(0, 0.02) - cfg.INITIALIZER.bias = ['zeros'] - # Version of the model. This helps ensure backward compatibility. 
- # Also, we can not use string here due to https://github.com/rbgirshick/yacs/issues/26 - cfg.VERSION = 1 + if key is not None: + return bert_cfg_reg.create(key) else: - raise NotImplementedError - cfg.freeze() - return cfg + return google_en_uncased_bert_base() @classmethod - def from_cfg(cls, cfg, use_pooler=True, dtype='float32') -> 'BertModel': + def from_cfg(cls, cfg, use_pooler=True, dtype=None) -> 'BertModel': + """ + + Parameters + ---------- + cfg + Configuration + use_pooler + Whether to output the pooled feature + dtype + data type of the model + + Returns + ------- + ret + The constructed BertModel + """ cfg = BertModel.get_cfg().clone_merge(cfg) assert cfg.VERSION == 1, 'Wrong version!' embed_initializer = mx.init.create(*cfg.INITIALIZER.embed) weight_initializer = mx.init.create(*cfg.INITIALIZER.weight) bias_initializer = mx.init.create(*cfg.INITIALIZER.bias) + if dtype is None: + dtype = cfg.MODEL.dtype return cls(vocab_size=cfg.MODEL.vocab_size, units=cfg.MODEL.units, hidden_size=cfg.MODEL.hidden_size, @@ -408,7 +558,9 @@ def from_cfg(cls, cfg, use_pooler=True, dtype='float32') -> 'BertModel': embed_initializer=embed_initializer, weight_initializer=weight_initializer, bias_initializer=bias_initializer, - use_pooler=use_pooler) + use_pooler=use_pooler, + layout=cfg.MODEL.layout, + compute_layout=cfg.MODEL.compute_layout) @use_np @@ -447,6 +599,10 @@ def __init__(self, backbone_cfg, self.mlm_decoder[-1].weight = self.backbone_model.word_embed.weight self.mlm_decoder.hybridize() + @property + def layout(self): + return self.backbone_model.layout + def hybrid_forward(self, F, inputs, token_types, valid_length, masked_positions): """Getting the scores of the masked positions. @@ -454,10 +610,16 @@ def hybrid_forward(self, F, inputs, token_types, valid_length, Parameters ---------- F - inputs : - Shape (batch_size, seq_length) - token_types : - Shape (batch_size, seq_length) + inputs + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) + token_types + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) If the inputs contain two sequences, we will set different token types for the first sentence and the second sentence. @@ -471,14 +633,21 @@ def hybrid_forward(self, F, inputs, token_types, valid_length, Returns ------- contextual_embedding - Shape (batch_size, seq_length, units). - pooled_out + - layout = 'NT' + Shape (batch_size, seq_length, units). + - layout = 'TN' + Shape (seq_length, batch_size, units) + cfg.MODEL.compute_layout = 'auto' Shape (batch_size, units) mlm_scores : Shape (batch_size, num_masked_positions, vocab_size) """ contextual_embeddings, pooled_out = self.backbone_model(inputs, token_types, valid_length) - mlm_features = select_vectors_by_position(F, contextual_embeddings, masked_positions) + if self.layout == 'NT': + mlm_features = select_vectors_by_position(F, contextual_embeddings, masked_positions) + else: + mlm_features = select_vectors_by_position(F, F.np.swapaxes(contextual_embeddings, 0, 1), + masked_positions) mlm_scores = self.mlm_decoder(mlm_features) return contextual_embeddings, pooled_out, mlm_scores @@ -523,6 +692,10 @@ def __init__(self, backbone_cfg, self.mlm_decoder[-1].weight = self.backbone_model.word_embed.weight self.mlm_decoder.hybridize() + @property + def layout(self): + return self.backbone_model.layout + def hybrid_forward(self, F, inputs, token_types, valid_length, masked_positions): """Generate the representation given the inputs. 
@@ -532,24 +705,33 @@ def hybrid_forward(self, F, inputs, token_types, valid_length, Parameters ---------- F - inputs : - Shape (batch_size, seq_length) - token_types : - Shape (batch_size, seq_length) + inputs + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) + token_types + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) If the inputs contain two sequences, we will set different token types for the first sentence and the second sentence. - valid_length : + valid_length The valid length of each sequence Shape (batch_size,) - masked_positions : + masked_positions The masked position of the sequence Shape (batch_size, num_masked_positions). Returns ------- contextual_embedding - Shape (batch_size, seq_length, units). + - layout = 'NT' + Shape (batch_size, seq_length, units). + - layout = 'TN' + Shape (seq_length, batch_size, units). pooled_out Shape (batch_size, units) nsp_score : @@ -559,7 +741,11 @@ def hybrid_forward(self, F, inputs, token_types, valid_length, """ contextual_embeddings, pooled_out = self.backbone_model(inputs, token_types, valid_length) nsp_score = self.nsp_classifier(pooled_out) - mlm_features = select_vectors_by_position(F, contextual_embeddings, masked_positions) + if self.layout == 'NT': + mlm_features = select_vectors_by_position(F, contextual_embeddings, masked_positions) + else: + mlm_features = select_vectors_by_position(F, F.np.swapaxes(contextual_embeddings, 0, 1), + masked_positions) mlm_scores = self.mlm_decoder(mlm_features) return contextual_embeddings, pooled_out, nsp_score, mlm_scores @@ -599,14 +785,21 @@ def get_pretrained_bert(model_name: str = 'google_en_cased_bert_base', assert model_name in PRETRAINED_URL, '{} is not found. 
All available are {}'.format( model_name, list_pretrained_bert()) cfg_path = PRETRAINED_URL[model_name]['cfg'] + if isinstance(cfg_path, CN): + cfg = cfg_path + else: + cfg = None vocab_path = PRETRAINED_URL[model_name]['vocab'] params_path = PRETRAINED_URL[model_name]['params'] mlm_params_path = PRETRAINED_URL[model_name]['mlm_params'] local_paths = dict() - for k, path in [('cfg', cfg_path), ('vocab', vocab_path)]: - local_paths[k] = download(url=get_repo_model_zoo_url() + path, - path=os.path.join(root, path), - sha1_hash=FILE_STATS[path]) + download_jobs = [('vocab', vocab_path)] + if cfg is None: + download_jobs.append(('cfg', cfg_path)) + for key, path in download_jobs: + local_paths[key] = download(url=get_repo_model_zoo_url() + path, + path=os.path.join(root, path), + sha1_hash=FILE_STATS[path]) if load_backbone: local_params_path = download(url=get_repo_model_zoo_url() + params_path, path=os.path.join(root, params_path), @@ -629,7 +822,8 @@ def get_pretrained_bert(model_name: str = 'google_en_cased_bert_base', sep_token='[SEP]', mask_token='[MASK]', lowercase=do_lower) - cfg = BertModel.get_cfg().clone_merge(local_paths['cfg']) + if cfg is None: + cfg = BertModel.get_cfg().clone_merge(local_paths['cfg']) return cfg, tokenizer, local_params_path, local_mlm_params_path diff --git a/src/gluonnlp/models/electra.py b/src/gluonnlp/models/electra.py index a56d7879dc..b8d4e44029 100644 --- a/src/gluonnlp/models/electra.py +++ b/src/gluonnlp/models/electra.py @@ -43,9 +43,12 @@ from ..initializer import TruncNorm from ..utils.config import CfgNode as CN from ..utils.misc import load_checksum_stats, download +from ..utils.registry import Registry from ..attention_cell import gen_self_attn_mask from ..data.tokenizers import HuggingFaceWordPieceTokenizer +electra_cfg_reg = Registry('electra_cfg') + def get_generator_cfg(model_config): """ @@ -66,9 +69,73 @@ def get_generator_cfg(model_config): return generator_cfg +@electra_cfg_reg.register() +def google_electra_small(): + cfg = CN() + # Model + cfg.MODEL = CN() + cfg.MODEL.vocab_size = 30522 + cfg.MODEL.embed_size = 128 + cfg.MODEL.units = 256 + cfg.MODEL.hidden_size = 1024 + cfg.MODEL.max_length = 512 + cfg.MODEL.num_heads = 4 + cfg.MODEL.num_layers = 12 + cfg.MODEL.pos_embed_type = 'learned' + cfg.MODEL.activation = 'gelu' + cfg.MODEL.layer_norm_eps = 1E-12 + cfg.MODEL.num_token_types = 2 + # Dropout regularization + cfg.MODEL.hidden_dropout_prob = 0.1 + cfg.MODEL.attention_dropout_prob = 0.1 + cfg.MODEL.dtype = 'float32' + # Layout flags + cfg.MODEL.layout = 'NT' + cfg.MODEL.compute_layout = 'auto' + # Generator hyper-parameters + cfg.MODEL.generator_layers_scale = 1.0 + cfg.MODEL.generator_units_scale = 1.0 + # Initializer + cfg.INITIALIZER = CN() + cfg.INITIALIZER.embed = ['truncnorm', 0, 0.02] + cfg.INITIALIZER.weight = ['truncnorm', 0, 0.02] # TruncNorm(0, 0.02) + cfg.INITIALIZER.bias = ['zeros'] + cfg.VERSION = 1 + cfg.freeze() + return cfg + + +@electra_cfg_reg.register() +def google_electra_base(): + cfg = google_electra_small() + cfg.defrost() + cfg.MODEL.embed_size = 768 + cfg.MODEL.units = 768 + cfg.MODEL.hidden_size = 3072 + cfg.MODEL.num_heads = 12 + cfg.MODEL.num_layers = 12 + cfg.MODEL.generator_units_scale = 0.33333 + cfg.freeze() + return cfg + + +@electra_cfg_reg.register() +def google_electra_large(): + cfg = google_electra_small() + cfg.defrost() + cfg.MODEL.embed_size = 1024 + cfg.MODEL.units = 1024 + cfg.MODEL.hidden_size = 4096 + cfg.MODEL.num_heads = 16 + cfg.MODEL.num_layers = 24 + cfg.MODEL.generator_units_scale = 
0.25 + cfg.freeze() + return cfg + + PRETRAINED_URL = { 'google_electra_small': { - 'cfg': 'google_electra_small/model-9ffb21c8.yml', + 'cfg': google_electra_small(), 'vocab': 'google_electra_small/vocab-e6d2b21d.json', 'params': 'google_electra_small/model-2654c8b4.params', 'disc_model': 'google_electra_small/disc_model-137714b6.params', @@ -76,7 +143,7 @@ def get_generator_cfg(model_config): 'lowercase': True, }, 'google_electra_base': { - 'cfg': 'google_electra_base/model-5b35ca0b.yml', + 'cfg': google_electra_base(), 'vocab': 'google_electra_base/vocab-e6d2b21d.json', 'params': 'google_electra_base/model-31c235cc.params', 'disc_model': 'google_electra_base/disc_model-514bd353.params', @@ -84,7 +151,7 @@ def get_generator_cfg(model_config): 'lowercase': True, }, 'google_electra_large': { - 'cfg': 'google_electra_large/model-31b7dfdd.yml', + 'cfg': google_electra_large(), 'vocab': 'google_electra_large/vocab-e6d2b21d.json', 'params': 'google_electra_large/model-9baf9ff5.params', 'disc_model': 'google_electra_large/disc_model-5b820c02.params', @@ -96,6 +163,7 @@ def get_generator_cfg(model_config): FILE_STATS = load_checksum_stats(os.path.join(get_model_zoo_checksum_dir(), 'electra.txt')) +# TODO(sxjscience) Use BertTransformer @use_np class ElectraEncoder(HybridBlock): def __init__(self, units=512, @@ -110,7 +178,35 @@ def __init__(self, units=512, layer_norm_eps=1E-12, weight_initializer=TruncNorm(stdev=0.02), bias_initializer='zeros', - activation='gelu'): + activation='gelu', + layout='NT'): + """ + + Parameters + ---------- + units + The number of units + hidden_size + The hidden size + num_layers + Number of layers + num_heads + Number of heads + attention_dropout_prob + Dropout probability of the attention layer + hidden_dropout_prob + Dropout probability + output_attention + Whether to output the attention weights + dtype + Data type of the weights + output_all_encodings + layer_norm_eps + weight_initializer + bias_initializer + activation + layout + """ super().__init__() assert units % num_heads == 0, \ 'In ElectraEncoder, The units should be divisible ' \ @@ -118,6 +214,7 @@ def __init__(self, units=512, .format(units, num_heads) self._dtype = dtype + self._layout = layout self._num_layers = num_layers self._output_attention = output_attention @@ -134,7 +231,13 @@ def __init__(self, units=512, layer_norm_eps=layer_norm_eps, weight_initializer=weight_initializer, bias_initializer=bias_initializer, - activation=activation)) + activation=activation, + dtype=dtype, + layout=layout)) + + @property + def layout(self): + return self._layout def hybrid_forward(self, F, data, valid_length): """ @@ -145,18 +248,31 @@ def hybrid_forward(self, F, data, valid_length): Parameters ---------- F - data : - Shape (batch_size, seq_length, C) - valid_length : + data + - layout = 'NT' + Shape (batch_size, seq_length, C) + - layout = 'TN' + Shape (seq_length, batch_size, C) + valid_length Shape (batch_size,) Returns ------- - out : - Shape (batch_size, seq_length, C_out) + out + - layout = 'NT' + Shape (batch_size, seq_length, C_out) + - layout = 'TN' + Shape (seq_length, batch_size, C_out) """ + if self.layout == 'NT': + time_axis, batch_axis = 1, 0 + else: + time_axis, batch_axis = 0, 1 # 1. 
Embed the data - attn_mask = gen_self_attn_mask(F, data, valid_length, dtype=self._dtype, attn_type='full') + attn_mask = gen_self_attn_mask(F, data, valid_length, + dtype=self._dtype, + layout=self._layout, + attn_type='full') out = data all_encodings_outputs = [] additional_outputs = [] @@ -168,7 +284,8 @@ def hybrid_forward(self, F, data, valid_length): if self._output_all_encodings: out = F.npx.sequence_mask(out, sequence_length=valid_length, - use_sequence_length=True, axis=1) + use_sequence_length=True, + axis=time_axis) all_encodings_outputs.append(out) if self._output_attention: @@ -177,7 +294,7 @@ def hybrid_forward(self, F, data, valid_length): if not self._output_all_encodings: # if self._output_all_encodings, SequenceMask is already applied above out = F.npx.sequence_mask(out, sequence_length=valid_length, - use_sequence_length=True, axis=1) + use_sequence_length=True, axis=time_axis) return out, additional_outputs else: return all_encodings_outputs, additional_outputs @@ -208,7 +325,9 @@ def __init__(self, weight_initializer=TruncNorm(stdev=0.02), bias_initializer='zeros', dtype='float32', - use_pooler=True): + use_pooler=True, + layout='NT', + compute_layout='auto'): super().__init__() self._dtype = dtype self.use_pooler = use_pooler @@ -223,6 +342,11 @@ def __init__(self, self.weight_initializer = weight_initializer self.bias_initializer = bias_initializer self.layer_norm_eps = layer_norm_eps + self._layout = layout + if compute_layout is None or compute_layout == 'auto': + self._compute_layout = layout + else: + self._compute_layout = compute_layout # Construct ElectraEncoder self.encoder = ElectraEncoder( units=units, @@ -238,6 +362,7 @@ def __init__(self, weight_initializer=weight_initializer, bias_initializer=bias_initializer, dtype=dtype, + layout=self._compute_layout, ) self.encoder.hybridize() @@ -262,6 +387,10 @@ def __init__(self, weight_initializer=weight_initializer, bias_initializer=bias_initializer) + @property + def layout(self): + return self._layout + def hybrid_forward(self, F, inputs, token_types, valid_length=None): # pylint: disable=arguments-differ """Generate the representation given the inputs. @@ -271,22 +400,31 @@ def hybrid_forward(self, F, inputs, token_types, valid_length=None): Parameters ---------- F - inputs : - Shape (batch_size, seq_length) - token_types : - Shape (batch_size, seq_length) + inputs + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) + token_types + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) If the inputs contain two sequences, we will set different token types for the first sentence and the second sentence. - valid_length : + valid_length The valid length of each sequence Shape (batch_size,) Returns ------- - contextual_embedding : - Shape (batch_size, seq_length, units). - pooled_output : + contextual_embedding + - layout = 'NT' + Shape (batch_size, seq_length, units). + - layout = 'TN' + Shape (seq_length, batch_size, units). + pooled_output This is optional. 
Shape (batch_size, units) """ initial_embedding = self.get_initial_embedding(F, inputs, token_types) @@ -295,17 +433,27 @@ def hybrid_forward(self, F, inputs, token_types, valid_length=None): if self.embed_size != self.units: prev_out = self.embed_factorized_proj(prev_out) outputs = [] - contextual_embeddings, additional_outputs = self.encoder(prev_out, valid_length) + if self._compute_layout != self._layout: + # Swap the axes if the compute_layout and layout mismatch + contextual_embeddings, additional_outputs = self.encoder(F.np.swapaxes(prev_out, 0, 1), + valid_length) + contextual_embeddings = F.np.swapaxes(contextual_embeddings, 0, 1) + else: + contextual_embeddings, additional_outputs = self.encoder(prev_out, valid_length) outputs.append(contextual_embeddings) if self.use_pooler: # Here we just get the first token ([CLS]) without any pooling strategy, - # which is slightly different between bert model with the pooled_out + # which is slightly different from bert model with the pooled_out # the attribute name is keeping the same as bert and albert model with defualt # use_pooler=True - pooled_out = contextual_embeddings[:, 0, :] + if self._layout == 'NT': + pooled_out = contextual_embeddings[:, 0, :] + else: + pooled_out = contextual_embeddings[0, :, :] outputs.append(pooled_out) return tuple(outputs) if len(outputs) > 1 else outputs[0] + #TODO(sxjscience) Move to a `common.py` def get_initial_embedding(self, F, inputs, token_types=None): """Get the initial token embeddings that considers the token type and positional embeddings @@ -313,24 +461,38 @@ def get_initial_embedding(self, F, inputs, token_types=None): ---------- F inputs - Shape (batch_size, seq_length) + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) token_types - Shape (batch_size, seq_length) + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) If None, it will be initialized as all zero Returns ------- embedding The initial embedding that will be fed into the encoder + - layout = 'NT' + Shape (batch_size, seq_length, C_embed) + - layout = 'TN' + Shape (seq_length, batch_size, C_embed) """ + if self.layout == 'NT': + time_axis, batch_axis = 1, 0 + else: + time_axis, batch_axis = 0, 1 embedding = self.word_embed(inputs) if token_types is None: token_types = F.np.zeros_like(inputs) type_embedding = self.token_type_embed(token_types) embedding = embedding + type_embedding if self.pos_embed_type is not None: - positional_embedding = self.token_pos_embed(F.npx.arange_like(inputs, axis=1)) - positional_embedding = F.np.expand_dims(positional_embedding, axis=0) + positional_embedding = self.token_pos_embed(F.npx.arange_like(inputs, axis=time_axis)) + positional_embedding = F.np.expand_dims(positional_embedding, axis=batch_axis) embedding = embedding + positional_embedding # Extra layer normalization plus dropout embedding = self.embed_layer_norm(embedding) @@ -339,48 +501,20 @@ def get_initial_embedding(self, F, inputs, token_types=None): @staticmethod def get_cfg(key=None): - if key is None: - cfg = CN() - # Model Parameters for the electra small - cfg.MODEL = CN() - cfg.MODEL.vocab_size = 30522 - cfg.MODEL.embed_size = 128 - cfg.MODEL.units = 256 - cfg.MODEL.hidden_size = 1024 - cfg.MODEL.max_length = 512 - cfg.MODEL.num_heads = 4 - cfg.MODEL.num_layers = 12 - cfg.MODEL.pos_embed_type = 'learned' - # Unlike BERT and ALBERT, which ues gelu(tanh), the gelu(erf) is used in Electra. 
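The `layout`/`compute_layout` pair introduced here can be exercised directly from the registered configurations. A minimal sketch, assuming the `ElectraModel.get_cfg`/`from_cfg` API of this patch (shapes follow the docstrings above; the random inputs are only placeholders):

```python
import mxnet as mx
from gluonnlp.models.electra import ElectraModel

# Start from the registered 'google_electra_small' config and switch the
# data layout to time-major while keeping compute_layout on 'auto'.
cfg = ElectraModel.get_cfg('google_electra_small')
cfg.defrost()
cfg.MODEL.layout = 'TN'
cfg.freeze()

model = ElectraModel.from_cfg(cfg)
model.initialize()

seq_length, batch_size = 16, 2
tokens = mx.np.random.randint(0, cfg.MODEL.vocab_size, (seq_length, batch_size))
token_types = mx.np.zeros_like(tokens)
valid_length = mx.np.array([16, 10], dtype='int32')

contextual_embeddings, pooled_out = model(tokens, token_types, valid_length)
assert contextual_embeddings.shape == (seq_length, batch_size, cfg.MODEL.units)
assert pooled_out.shape == (batch_size, cfg.MODEL.units)
```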
- cfg.MODEL.activation = 'gelu' - cfg.MODEL.layer_norm_eps = 1E-12 - cfg.MODEL.num_token_types = 2 - cfg.MODEL.hidden_dropout_prob = 0.1 - cfg.MODEL.attention_dropout_prob = 0.1 - cfg.MODEL.dtype = 'float32' - cfg.MODEL.generator_layers_scale = 1.0 - # multiplier for units, hidden_size, and num_heads - cfg.MODEL.generator_units_scale = 1.0 - # Hyper-parameters of the Initializers - cfg.INITIALIZER = CN() - cfg.INITIALIZER.embed = ['truncnorm', 0, 0.02] - cfg.INITIALIZER.weight = ['truncnorm', 0, 0.02] # TruncNorm(0, 0.02) - cfg.INITIALIZER.bias = ['zeros'] - # Version of the model. This helps ensure backward compatibility. - # Also, we can not use string here due to https://github.com/rbgirshick/yacs/issues/26 - cfg.VERSION = 1 - cfg.freeze() + if key is not None: + return electra_cfg_reg.create(key) else: - raise NotImplementedError - return cfg + return google_electra_base() @classmethod - def from_cfg(cls, cfg, use_pooler=True, dtype='float32') -> 'ElectraModel': + def from_cfg(cls, cfg, use_pooler=True, dtype=None) -> 'ElectraModel': cfg = ElectraModel.get_cfg().clone_merge(cfg) assert cfg.VERSION == 1, 'Wrong version!' embed_initializer = mx.init.create(*cfg.INITIALIZER.embed) weight_initializer = mx.init.create(*cfg.INITIALIZER.weight) bias_initializer = mx.init.create(*cfg.INITIALIZER.bias) + if dtype is None: + dtype = cfg.MODEL.dtype return cls(vocab_size=cfg.MODEL.vocab_size, units=cfg.MODEL.units, hidden_size=cfg.MODEL.hidden_size, @@ -398,7 +532,9 @@ def from_cfg(cls, cfg, use_pooler=True, dtype='float32') -> 'ElectraModel': embed_initializer=embed_initializer, weight_initializer=weight_initializer, bias_initializer=bias_initializer, - use_pooler=use_pooler) + use_pooler=use_pooler, + layout=cfg.MODEL.layout, + compute_layout=cfg.MODEL.compute_layout) @use_np @@ -447,25 +583,37 @@ def hybrid_forward(self, F, inputs, token_types, valid_length): Parameters ---------- F - inputs : - Shape (batch_size, seq_length) - token_types : - Shape (batch_size, seq_length) + inputs + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) + token_types + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) If the inputs contain two sequences, we will set different token types for the first sentence and the second sentence. - valid_length : + valid_length The valid length of each sequence Shape (batch_size,) Returns ------- contextual_embedding - Shape (batch_size, seq_length, units). + - layout = 'NT' + Shape (batch_size, seq_length, units). + - layout = 'TN' + Shape (seq_length, batch_size, units). 
pooled_out Shape (batch_size, units) rtd_scores - Shape (batch_size, seq_length) + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) """ contextual_embeddings, pooled_out = self.backbone_model(inputs, token_types, valid_length) rtd_scores = self.rtd_encoder(contextual_embeddings).squeeze(-1) @@ -515,8 +663,21 @@ def __init__(self, backbone_cfg, self.mlm_decoder[-1].weight = self.backbone_model.word_embed.weight self.mlm_decoder.hybridize() - def tie_embeddings(self, word_embed_params=None, token_type_embed_params=None, - token_pos_embed_params=None, embed_layer_norm_params=None): + # TODO(sxjscience,zheyu) Should design a better API + def tie_embeddings(self, word_embed_params=None, + token_type_embed_params=None, + token_pos_embed_params=None, + embed_layer_norm_params=None): + """Tie the embedding layers between the backbone and the MLM decoder + + Parameters + ---------- + word_embed_params + token_type_embed_params + token_pos_embed_params + embed_layer_norm_params + + """ self.backbone_model.word_embed.share_parameters(word_embed_params) self.mlm_decoder[-1].share_parameters(word_embed_params) self.backbone_model.token_type_embed.share_parameters(token_type_embed_params) @@ -529,10 +690,16 @@ def hybrid_forward(self, F, inputs, token_types, valid_length, masked_positions) Parameters ---------- F - inputs : - Shape (batch_size, seq_length) - token_types : - Shape (batch_size, seq_length) + inputs + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) + token_types + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) If the inputs contain two sequences, we will set different token types for the first sentence and the second sentence. @@ -546,14 +713,21 @@ def hybrid_forward(self, F, inputs, token_types, valid_length, masked_positions) Returns ------- contextual_embedding - Shape (batch_size, seq_length, units). + - layout = 'NT' + Shape (batch_size, seq_length, units). + - layout = 'TN' + Shape (seq_length, batch_size, units). pooled_out Shape (batch_size, units) mlm_scores : Shape (batch_size, num_masked_positions, vocab_size) """ contextual_embeddings, pooled_out = self.backbone_model(inputs, token_types, valid_length) - mlm_features = select_vectors_by_position(F, contextual_embeddings, masked_positions) + if self.backbone_model.layout == 'NT': + mlm_features = select_vectors_by_position(F, contextual_embeddings, masked_positions) + else: + mlm_features = select_vectors_by_position(F, F.np.swapaxes(contextual_embeddings, 0, 1), + masked_positions) mlm_scores = self.mlm_decoder(mlm_features) return contextual_embeddings, pooled_out, mlm_scores @@ -561,7 +735,7 @@ def hybrid_forward(self, F, inputs, token_types, valid_length, masked_positions) @use_np class ElectraForPretrain(HybridBlock): """ - A integrated model combined with a generator and a discriminator. Generator here + An integrated model combined with a generator and a discriminator. Generator here produces a corrupted tokens playing as fake data to fool a discriminator whose objective is to distinguish whether each token in the input sentence it accepts is the same as the original. 
It is a classification task instead of prediction @@ -612,11 +786,15 @@ def __init__(self, self.disc_cfg = disc_cfg self.vocab_size = disc_cfg.MODEL.vocab_size self.gen_cfg = get_generator_cfg(disc_cfg) - self.discriminator = ElectraDiscriminator(disc_cfg) + self.discriminator = ElectraDiscriminator(disc_cfg, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer) self.disc_backbone = self.discriminator.backbone_model if not uniform_generator and not tied_generator: - self.generator = ElectraGenerator(self.gen_cfg) + self.generator = ElectraGenerator(self.gen_cfg, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer) if tied_embeddings: self.generator.tie_embeddings(self.disc_backbone.word_embed.collect_params(), self.disc_backbone.token_type_embed.collect_params(), @@ -626,7 +804,10 @@ def __init__(self, elif tied_generator: # Reuse the weight of the discriminator backbone model - self.generator = ElectraGenerator(self.gen_cfg) + self.generator = ElectraGenerator(self.gen_cfg, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer) + # TODO(sxjscience, zheyu) Verify self.generator.backbone_model = self.disc_backbone self.generator.hybridize() elif uniform_generator: @@ -650,18 +831,24 @@ def hybrid_forward(self, F, inputs, token_types, valid_length, Parameters ---------- F - inputs : + inputs The masked input - Shape (batch_size, seq_length) - token_types : - Shape (batch_size, seq_length) + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) + token_types + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) If the inputs contain two sequences, we will set different token types for the first sentence and the second sentence. - valid_length : + valid_length The valid length of each sequence Shape (batch_size,) - unmasked_tokens : + unmasked_tokens The original tokens that appear in the unmasked input sequence Shape (batch_size, num_masked_positions). masked_positions : @@ -670,20 +857,26 @@ def hybrid_forward(self, F, inputs, token_types, valid_length, Returns ------- - mlm_scores : + mlm_scores Shape (batch_size, num_masked_positions, vocab_size) - rtd_scores : - Shape (batch_size, seq_length) + rtd_scores + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) replaced_inputs : Shape (batch_size, num_masked_positions) - labels : - Shape (batch_size, seq_length) + labels + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) """ if self._uniform_generator: # generate the corrupt tokens randomly with a mlm_scores vector whose value is all 0 - zero_logits = F.np.zeros(self.vocab_size) - zero_logits = F.np.expand_dims(F.np.expand_dims(zero_logits, axis=0), axis=0) - mlm_scores = F.np.expand_dims(F.np.zeros_like(masked_positions), axis=-1) + zero_logits = F.np.zeros((1, 1, self.vocab_size), dtype=self._dtype) + mlm_scores = F.np.expand_dims(F.np.zeros_like(masked_positions, dtype=self._dtype), + axis=-1) mlm_scores = mlm_scores + zero_logits else: _, _, mlm_scores = self.generator(inputs, token_types, valid_length, masked_positions) @@ -698,12 +891,16 @@ def hybrid_forward(self, F, inputs, token_types, valid_length, def get_corrupted_tokens(self, F, inputs, unmasked_tokens, masked_positions, logits): """ Sample from the generator to create corrupted input. 
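The corrupted input and its detection labels follow the rule `labels = updates_mask * (fake_data != original_data)` used in the forward pass above. A toy, plain-NumPy illustration of that bookkeeping (the actual implementation goes through `updated_vectors_by_position` and additionally resolves duplicate masked positions):

```python
import numpy as np

# fake_data is the input with generator samples written at the masked positions;
# labels mark only the positions whose token actually changed.
inputs = np.array([[11, 12, 13, 14, 15]])        # (batch_size, seq_length), NT layout
masked_positions = np.array([[1, 3]])            # (batch_size, num_masked_positions)
corrupted_tokens = np.array([[12, 99]])          # sampled by the generator

fake_data = inputs.copy()
np.put_along_axis(fake_data, masked_positions, corrupted_tokens, axis=1)
updates_mask = np.zeros_like(inputs)
np.put_along_axis(updates_mask, masked_positions, 1, axis=1)
labels = updates_mask * (fake_data != inputs)

print(fake_data)   # [[11 12 13 99 15]]
print(labels)      # [[0 0 0 1 0]]  -- position 1 drew its original token, so it stays 0
```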
+ Parameters ---------- F inputs The masked input - Shape (batch_size, seq_length) + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) unmasked_tokens The original tokens that appear in the unmasked input sequence Shape (batch_size, num_masked_positions). @@ -715,10 +912,18 @@ def get_corrupted_tokens(self, F, inputs, unmasked_tokens, masked_positions, log Returns ------- + corrupted_tokens + The corrupted tokens fake_data - Shape (batch_size, seq_length) + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) labels - Shape (batch_size, seq_length) + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) """ if self._disallow_correct: @@ -734,6 +939,8 @@ def get_corrupted_tokens(self, F, inputs, unmasked_tokens, masked_positions, log use_np_gumbel=False) corrupted_tokens = F.np.argmax(prob, axis=-1).astype(np.int32) + if self.disc_backbone.layout == 'TN': + inputs = inputs.T # Following the Official electra to deal with duplicate positions as # https://github.com/google-research/electra/issues/41 original_data, updates_mask = updated_vectors_by_position(F, @@ -742,7 +949,10 @@ def get_corrupted_tokens(self, F, inputs, unmasked_tokens, masked_positions, log inputs, corrupted_tokens, masked_positions) labels = updates_mask * F.np.not_equal(fake_data, original_data) - return corrupted_tokens, fake_data, labels + if self.disc_backbone.layout == 'TN': + return corrupted_tokens, fake_data.T, labels.T + else: + return corrupted_tokens, fake_data, labels def list_pretrained_electra(): @@ -787,13 +997,20 @@ def get_pretrained_electra(model_name: str = 'google_electra_small', assert model_name in PRETRAINED_URL, '{} is not found. 
All available are {}'.format( model_name, list_pretrained_electra()) cfg_path = PRETRAINED_URL[model_name]['cfg'] + if isinstance(cfg_path, CN): + cfg = cfg_path + else: + cfg = None vocab_path = PRETRAINED_URL[model_name]['vocab'] params_path = PRETRAINED_URL[model_name]['params'] disc_params_path = PRETRAINED_URL[model_name]['disc_model'] gen_params_path = PRETRAINED_URL[model_name]['gen_model'] local_paths = dict() - for k, path in [('cfg', cfg_path), ('vocab', vocab_path)]: + download_jobs = [('vocab', vocab_path)] + if cfg is None: + download_jobs.append(('cfg', cfg_path)) + for k, path in download_jobs: local_paths[k] = download(url=get_repo_model_zoo_url() + path, path=os.path.join(root, path), sha1_hash=FILE_STATS[path]) @@ -827,7 +1044,8 @@ def get_pretrained_electra(model_name: str = 'google_electra_small', sep_token='[SEP]', mask_token='[MASK]', lowercase=do_lower) - cfg = ElectraModel.get_cfg().clone_merge(local_paths['cfg']) + if cfg is None: + cfg = ElectraModel.get_cfg().clone_merge(local_paths['cfg']) return cfg, tokenizer, local_params_path, (local_disc_params_path, local_gen_params_path) diff --git a/src/gluonnlp/models/mobilebert.py b/src/gluonnlp/models/mobilebert.py index 502d7f4750..5a81de7c64 100644 --- a/src/gluonnlp/models/mobilebert.py +++ b/src/gluonnlp/models/mobilebert.py @@ -41,6 +41,7 @@ from ..initializer import TruncNorm from ..utils.config import CfgNode as CN from ..utils.misc import load_checksum_stats, download +from ..utils.registry import Registry from ..registry import BACKBONE_REGISTRY from ..attention_cell import MultiHeadAttentionCell, gen_self_attn_mask from ..data.tokenizers import HuggingFaceWordPieceTokenizer @@ -48,9 +49,51 @@ __all__ = ['MobileBertModel', 'MobileBertForMLM', 'MobileBertForPretrain', 'list_pretrained_mobilebert', 'get_pretrained_mobilebert'] +mobilebert_cfg_reg = Registry('mobilebert_cfg') + + +@mobilebert_cfg_reg.register() +def google_uncased_mobilebert(): + cfg = CN() + cfg.MODEL = CN() + cfg.MODEL.vocab_size = 30522 + cfg.MODEL.units = 512 + cfg.MODEL.embed_size = 128 + cfg.MODEL.inner_size = 128 + cfg.MODEL.hidden_size = 512 + cfg.MODEL.max_length = 512 + cfg.MODEL.num_heads = 4 + cfg.MODEL.num_layers = 24 + + cfg.MODEL.use_bottleneck = True # Whether to use bottleneck + cfg.MODEL.trigram_embed = True # Trigram embedding + cfg.MODEL.classifier_activation = False # Whether to use an additional pooling layer + cfg.MODEL.bottleneck_strategy = 'qk_sharing' + cfg.MODEL.num_stacked_ffn = 4 + cfg.MODEL.pos_embed_type = 'learned' + cfg.MODEL.activation = 'relu' + cfg.MODEL.num_token_types = 2 + cfg.MODEL.hidden_dropout_prob = 0.0 + cfg.MODEL.attention_dropout_prob = 0.1 + cfg.MODEL.normalization = 'no_norm' + cfg.MODEL.layer_norm_eps = 1E-12 + cfg.MODEL.dtype = 'float32' + # Layout flags + cfg.MODEL.layout = 'NT' + cfg.MODEL.compute_layout = 'auto' + # Initializer + cfg.INITIALIZER = CN() + cfg.INITIALIZER.embed = ['truncnorm', 0, 0.02] + cfg.INITIALIZER.weight = ['truncnorm', 0, 0.02] # TruncNorm(0, 0.02) + cfg.INITIALIZER.bias = ['zeros'] + cfg.VERSION = 1 + cfg.freeze() + return cfg + + PRETRAINED_URL = { 'google_uncased_mobilebert': { - 'cfg': 'google_uncased_mobilebert/model-1c33216b.yml', + 'cfg': google_uncased_mobilebert(), 'vocab': 'google_uncased_mobilebert/vocab-e6d2b21d.json', 'params': 'google_uncased_mobilebert/model-c8346cf2.params', 'mlm_params': 'google_uncased_mobilebert/model_mlm-53948e82.params', @@ -66,7 +109,7 @@ class MobileBertEncoderLayer(HybridBlock): """The Transformer Encoder Layer in Mobile Bert""" # 
TODO(zheyuye), use stacked groups for single ffn layer in transformer.TransformerEncoderLayer - # and revise the other models and scripts, masking sure their are compatible. + # and revise the other models and scripts, making sure they are compatible. def __init__(self, use_bottleneck: bool = True, @@ -85,12 +128,14 @@ def __init__(self, use_qkv_bias: bool = True, weight_initializer: Optional[InitializerType] = None, bias_initializer: Optional[InitializerType] = 'zeros', - dtype='float32'): + dtype='float32', + layout='NT'): """ Parameters ---------- use_bottleneck + Whether to use the bottleneck layer. units size of inter-bottleneck real_units @@ -110,6 +155,9 @@ def __init__(self, weight_initializer bias_initializer dtype + Data type of the block + layout + Layout of the input + output """ super().__init__() self._use_bottleneck = use_bottleneck @@ -119,6 +167,7 @@ def __init__(self, self._num_stacked_ffn = num_stacked_ffn self._bottleneck_strategy = bottleneck_strategy self._dtype = dtype + self._layout = layout assert real_units % num_heads == 0, 'units must be divisive by the number of heads' self.dropout_layer = nn.Dropout(hidden_dropout_prob) if use_bottleneck: @@ -159,24 +208,47 @@ def __init__(self, bias_initializer=bias_initializer, dtype=self._dtype) # The in_units of qkv varies according to the sharing strategy + if self._use_bottleneck: + if self._bottleneck_strategy == 'qk_sharing': + attn_query_in_units = real_units + attn_key_in_units = real_units + attn_value_in_units = units + elif self._bottleneck_strategy == 'from_bottleneck': + attn_query_in_units = real_units + attn_key_in_units = real_units + attn_value_in_units = real_units + elif self._bottleneck_strategy == 'from_input': + attn_query_in_units = units + attn_key_in_units = units + attn_value_in_units = units + else: + raise NotImplementedError + else: + attn_query_in_units = units + attn_key_in_units = units + attn_value_in_units = units self.attn_query = nn.Dense(units=real_units, + in_units=attn_query_in_units, flatten=False, use_bias=use_qkv_bias, weight_initializer=weight_initializer, bias_initializer=bias_initializer, dtype=self._dtype) self.attn_key = nn.Dense(units=real_units, + in_units=attn_key_in_units, flatten=False, use_bias=use_qkv_bias, weight_initializer=weight_initializer, bias_initializer=bias_initializer, dtype=self._dtype) self.attn_value = nn.Dense(units=real_units, + in_units=attn_value_in_units, flatten=False, use_bias=use_qkv_bias, weight_initializer=weight_initializer, bias_initializer=bias_initializer, dtype=self._dtype) + attention_layout = 'NTK' if self._layout == 'NT' else 'TNK' self.attention_cell = \ MultiHeadAttentionCell( query_units=real_units, @@ -184,7 +256,7 @@ def __init__(self, attention_dropout=attention_dropout_prob, scaled=True, dtype=self._dtype, - layout='NTK' + layout=attention_layout ) self.layer_norm = get_layer_norm(normalization=normalization, in_channels=real_units, @@ -209,26 +281,35 @@ def __init__(self, layer_norm_eps=layer_norm_eps, dtype=self._dtype)) + @property + def layout(self): + return self._layout + def hybrid_forward(self, F, data, attn_mask): """ Parameters ---------- F - data : - Shape (batch_size, seq_length, C_in) - attn_mask : + data + - layout = 'NT' + Shape (batch_size, seq_length, C_in) + - layout = 'TN' + Shape (seq_length, batch_size, C_in) + attn_mask + The attention mask Shape (batch_size, seq_length, seq_length) Returns ------- - out : - Shape (batch_size, seq_length, C_out) - attn_weight : + out + - layout = 'NT' + Shape (batch_size, 
seq_length, C_out) + - layout = 'TN' + Shape (seq_length, batch_size, C_out) + attn_weight Shape (batch_size, seq_length, seq_length) """ - # TODO(sxjscience) Cannot use negative axis due to - # https://github.com/apache/incubator-mxnet/issues/18132 if self._use_bottleneck: bn_proj = self.in_bottleneck_proj(data) bn_proj = self.in_bottleneck_ln(bn_proj) @@ -241,7 +322,7 @@ def hybrid_forward(self, F, data, attn_mask): key = qk_shared value = data elif self._bottleneck_strategy == 'from_bottleneck': - # for Mobile mobile bert Tiny + # for Mobile Bert Tiny query = bn_proj key = bn_proj value = bn_proj @@ -298,12 +379,14 @@ def __init__(self, layer_norm_eps: float = 1E-12, weight_initializer: InitializerType = TruncNorm(stdev=0.02), bias_initializer: InitializerType = 'zeros', - dtype='float32'): + dtype='float32', + layout='NT'): super().__init__() self._dtype = dtype self._num_layers = num_layers self._output_attention = output_attention self._output_all_encodings = output_all_encodings + self._layout = layout assert bottleneck_strategy in ['qk_sharing', 'from_bottleneck', 'from_input'], \ 'The bottleneck strategy={} is not supported.'.format(bottleneck_strategy) @@ -329,7 +412,12 @@ def __init__(self, weight_initializer=weight_initializer, bias_initializer=bias_initializer, normalization=normalization, - activation=activation)) + activation=activation, + layout=layout)) + + @property + def layout(self): + return self._layout def hybrid_forward(self, F, data, valid_length): """ @@ -340,18 +428,34 @@ def hybrid_forward(self, F, data, valid_length): Parameters ---------- F - data : - Shape (batch_size, seq_length, C) - valid_length : + data + - layout = 'NT' + Shape (batch_size, seq_length, C) + - layout = 'TN' + Shape (seq_length, batch_size, C) + valid_length Shape (batch_size,) Returns ------- - out : - Shape (batch_size, seq_length, C_out) + out + - layout = 'NT' + Shape (batch_size, seq_length, C_out) + - layout = 'TN' + Shape (seq_length, batch_size, C_out) """ + if self._layout == 'NT': + batch_axis, time_axis = 0, 1 + elif self._layout == 'TN': + batch_axis, time_axis = 1, 0 + else: + raise NotImplementedError('Received layout="{}". ' + 'Only "NT" and "TN" are supported.'.format(self._layout)) # 1. 
Embed the data - attn_mask = gen_self_attn_mask(F, data, valid_length, dtype=self._dtype, attn_type='full') + attn_mask = gen_self_attn_mask(F, data, valid_length, + dtype=self._dtype, + layout=self._layout, + attn_type='full') out = data all_encodings_outputs = [] additional_outputs = [] @@ -364,7 +468,8 @@ def hybrid_forward(self, F, data, valid_length): if self._output_all_encodings: out = F.npx.sequence_mask(out, sequence_length=valid_length, - use_sequence_length=True, axis=1) + use_sequence_length=True, + axis=time_axis) all_encodings_outputs.append(out) if self._output_attention: @@ -373,7 +478,8 @@ def hybrid_forward(self, F, data, valid_length): if not self._output_all_encodings: # if self._output_all_encodings, SequenceMask is already applied above out = F.npx.sequence_mask(out, sequence_length=valid_length, - use_sequence_length=True, axis=1) + use_sequence_length=True, + axis=time_axis) return out, additional_outputs else: return all_encodings_outputs, additional_outputs @@ -406,7 +512,9 @@ def __init__(self, trigram_embed=True, use_pooler=True, classifier_activation=False, - dtype='float32'): + dtype='float32', + layout='NT', + compute_layout='auto'): super().__init__() self._dtype = dtype self.use_bottleneck = use_bottleneck @@ -428,6 +536,12 @@ def __init__(self, self.weight_initializer = weight_initializer self.bias_initializer = bias_initializer self.layer_norm_eps = layer_norm_eps + self._layout = layout + if compute_layout == 'auto' or compute_layout is None: + self._compute_layout = layout + else: + assert compute_layout in ['TN', 'NT'] + self._compute_layout = compute_layout # Construct MobileBertTransformer self.encoder = MobileBertTransformer( units=units, @@ -447,6 +561,7 @@ def __init__(self, weight_initializer=weight_initializer, bias_initializer=bias_initializer, dtype=dtype, + layout=self._compute_layout, ) self.encoder.hybridize() # Construct word embedding @@ -455,7 +570,12 @@ def __init__(self, weight_initializer=embed_initializer, dtype=dtype) if trigram_embed or embed_size != units: + if trigram_embed: + in_units = 3 * embed_size + else: + in_units = embed_size self.embed_factorized_proj = nn.Dense(units=units, + in_units=in_units, flatten=False, weight_initializer=weight_initializer, bias_initializer=bias_initializer) @@ -467,7 +587,8 @@ def __init__(self, # Construct token type embedding self.token_type_embed = nn.Embedding(input_dim=num_token_types, output_dim=units, - weight_initializer=weight_initializer) + weight_initializer=weight_initializer, + dtype=self._dtype) self.token_pos_embed = PositionalEmbedding(units=units, max_length=max_length, dtype=self._dtype, @@ -478,9 +599,18 @@ def __init__(self, in_units=units, flatten=False, activation='tanh', + dtype=self._dtype, weight_initializer=weight_initializer, bias_initializer=bias_initializer) + @property + def layout(self): + return self._layout + + @property + def dtype(self): + return self._dtype + def hybrid_forward(self, F, inputs, token_types, valid_length): # pylint: disable=arguments-differ """Generate the representation given the inputs. 
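With the `layout` flag threaded through `MobileBertTransformer` and `MobileBertModel`, the same batch can be fed batch-major or time-major. A minimal shape-level sketch, assuming the `MobileBertModel.get_cfg`/`from_cfg` API of this patch (the two models below have independent parameters, so only shapes are compared):

```python
import mxnet as mx
from gluonnlp.models.mobilebert import MobileBertModel

batch_size, seq_length = 2, 8
tokens_nt = mx.np.random.randint(0, 30522, (batch_size, seq_length))
token_types_nt = mx.np.zeros_like(tokens_nt)
valid_length = mx.np.array([8, 5], dtype='int32')

cfg = MobileBertModel.get_cfg()          # 'google_uncased_mobilebert', layout='NT'
cfg_tn = cfg.clone()
cfg_tn.defrost()
cfg_tn.MODEL.layout = 'TN'               # tokens become time-major
cfg_tn.freeze()

model_nt = MobileBertModel.from_cfg(cfg)
model_nt.initialize()
model_tn = MobileBertModel.from_cfg(cfg_tn)
model_tn.initialize()

emb_nt, pooled_nt = model_nt(tokens_nt, token_types_nt, valid_length)
emb_tn, pooled_tn = model_tn(tokens_nt.T, token_types_nt.T, valid_length)
assert emb_nt.shape == (batch_size, seq_length, cfg.MODEL.units)
assert emb_tn.shape == (seq_length, batch_size, cfg.MODEL.units)
assert pooled_nt.shape == pooled_tn.shape == (batch_size, cfg.MODEL.units)
```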
@@ -490,11 +620,16 @@ def hybrid_forward(self, F, inputs, token_types, valid_length): Parameters ---------- F - inputs : - Shape (batch_size, seq_length) - token_types : - Shape (batch_size, seq_length) - + inputs + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) + token_types + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) If the inputs contain two sequences, we will set different token types for the first sentence and the second sentence. valid_length : @@ -510,24 +645,34 @@ def hybrid_forward(self, F, inputs, token_types, valid_length): """ embedding = self.get_initial_embedding(F, inputs, token_types) - contextual_embeddings, additional_outputs = self.encoder(embedding, valid_length) - outputs = [] - outputs.append(contextual_embeddings) + if self._compute_layout != self._layout: + contextual_embeddings, additional_outputs = self.encoder(F.np.swapaxes(embedding, 0, 1), + valid_length) + contextual_embeddings = F.np.swapaxes(contextual_embeddings, 0, 1) + else: + contextual_embeddings, additional_outputs = self.encoder(embedding, valid_length) if self.use_pooler: pooled_out = self.apply_pooling(contextual_embeddings) - outputs.append(pooled_out) - return tuple(outputs) if len(outputs) > 1 else outputs[0] + return contextual_embeddings, pooled_out + else: + return contextual_embeddings - def get_initial_embedding(self, F, inputs, token_types=None, trigram_embed=True): + def get_initial_embedding(self, F, inputs, token_types=None): """Get the initial token embeddings that considers the token type and positional embeddings Parameters ---------- F inputs - Shape (batch_size, seq_length) + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) token_types - Shape (batch_size, seq_length) + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) If None, it will be initialized as all zero Returns @@ -535,24 +680,39 @@ def get_initial_embedding(self, F, inputs, token_types=None, trigram_embed=True) embedding The initial embedding that will be fed into the encoder """ + if self._layout == 'NT': + batch_axis, time_axis = 0, 1 + elif self._layout == 'TN': + batch_axis, time_axis = 1, 0 + else: + raise NotImplementedError word_embedding = self.word_embed(inputs) - if trigram_embed: - word_embedding = F.np.concatenate( - [F.np.pad(word_embedding[:, 1:], ((0, 0), (0, 1), (0, 0))), - word_embedding, - F.np.pad(word_embedding[:, :-1], ((0, 0), (1, 0), (0, 0)))], axis=-1) + if self.trigram_embed: + if self._layout == 'NT': + word_embedding = F.np.concatenate( + [F.np.pad(word_embedding[:, 1:], ((0, 0), (0, 1), (0, 0))), + word_embedding, + F.np.pad(word_embedding[:, :-1], ((0, 0), (1, 0), (0, 0)))], axis=-1) + elif self._layout == 'TN': + word_embedding = F.np.concatenate( + [F.np.pad(word_embedding[1:, :], ((0, 1), (0, 0), (0, 0))), + word_embedding, + F.np.pad(word_embedding[:-1, :], ((1, 0), (0, 0), (0, 0)))], axis=-1) + else: + raise NotImplementedError # Projecting the embedding into units only for word embedding - if trigram_embed or self.embed_size != self.units: - embedding = self.embed_factorized_proj(word_embedding) + if self.trigram_embed or self.embed_size != self.units: + word_embedding = self.embed_factorized_proj(word_embedding) if token_types is None: - token_types = F.np.zeros_like(embedding) + token_types = F.np.zeros_like(inputs) type_embedding = self.token_type_embed(token_types) - embedding 
= embedding + type_embedding + embedding = word_embedding + type_embedding if self.pos_embed_type is not None: - positional_embedding = self.token_pos_embed(F.npx.arange_like(embedding, axis=1)) - positional_embedding = F.np.expand_dims(positional_embedding, axis=0) + positional_embedding =\ + self.token_pos_embed(F.npx.arange_like(embedding, axis=time_axis)) + positional_embedding = F.np.expand_dims(positional_embedding, axis=batch_axis) embedding = embedding + positional_embedding # Extra layer normalization plus dropout embedding = self.embed_layer_norm(embedding) @@ -565,12 +725,23 @@ def apply_pooling(self, sequence): This is used for pre-training or fine-tuning a mobile bert model. Get the first token of the whole sequence which is [CLS] - sequence: - Shape (batch_size, sequence_length, units) - return: + Parameters + ---------- + sequence + - layout = 'NT' + Shape (batch_size, sequence_length, units) + - layout = 'TN' + Shape (sequence_length, batch_size, units) + + Returns + ------- + outputs Shape (batch_size, units) """ - outputs = sequence[:, 0, :] + if self._layout == 'NT': + outputs = sequence[:, 0, :] + else: + outputs = sequence[0, :, :] if self.classifier_activation: return self.pooler(outputs) else: @@ -578,53 +749,23 @@ def apply_pooling(self, sequence): @staticmethod def get_cfg(key=None): - if key is None: - cfg = CN() - cfg.MODEL = CN() - cfg.MODEL.vocab_size = 30522 - cfg.MODEL.embed_size = 128 - cfg.MODEL.units = 512 - cfg.MODEL.hidden_size = 512 - cfg.MODEL.inner_size = 128 - cfg.MODEL.max_length = 512 - cfg.MODEL.num_heads = 4 - cfg.MODEL.num_layers = 12 - cfg.MODEL.num_stacked_ffn = 4 - cfg.MODEL.pos_embed_type = 'learned' - cfg.MODEL.activation = 'relu' - cfg.MODEL.normalization = 'no_norm' - cfg.MODEL.layer_norm_eps = 1E-12 - cfg.MODEL.bottleneck_strategy = 'qk_sharing' - cfg.MODEL.num_token_types = 2 - cfg.MODEL.hidden_dropout_prob = 0.0 - cfg.MODEL.attention_dropout_prob = 0.1 - cfg.MODEL.dtype = 'float32' - # Hyper-parameters of the Initializers - cfg.INITIALIZER = CN() - cfg.INITIALIZER.embed = ['truncnorm', 0, 0.02] - cfg.INITIALIZER.weight = ['truncnorm', 0, 0.02] # TruncNorm(0, 0.02) - cfg.INITIALIZER.bias = ['zeros'] - # Version of the model. This helps ensure backward compatibility. - # Also, we can not use string here due to https://github.com/rbgirshick/yacs/issues/26 - cfg.VERSION = 1 + if key is not None: + return mobilebert_cfg_reg.create(key) else: - raise NotImplementedError - cfg.freeze() - return cfg + return google_uncased_mobilebert() @classmethod def from_cfg(cls, cfg, use_pooler=True, - dtype='float32', - use_bottleneck=True, - trigram_embed=True, - classifier_activation=False) -> 'MobileBertModel': + dtype=None) -> 'MobileBertModel': cfg = MobileBertModel.get_cfg().clone_merge(cfg) assert cfg.VERSION == 1, 'Wrong version!' 
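The trigram embedding above concatenates each token's shifted neighbours before the factorized projection, which is why `in_units = 3 * embed_size`. A small NumPy sketch of the 'NT' branch:

```python
import numpy as np

# Each position gets [right neighbour | itself | left neighbour], zero-padded
# at the sequence boundaries, mirroring the F.np.pad / concatenate calls above.
batch_size, seq_length, embed_size = 1, 4, 2
word_embedding = np.arange(batch_size * seq_length * embed_size,
                           dtype=np.float32).reshape(batch_size, seq_length, embed_size)
right = np.pad(word_embedding[:, 1:], ((0, 0), (0, 1), (0, 0)))   # shift left, pad at the end
left = np.pad(word_embedding[:, :-1], ((0, 0), (1, 0), (0, 0)))   # shift right, pad at the front
trigram = np.concatenate([right, word_embedding, left], axis=-1)
print(trigram.shape)   # (1, 4, 6) == (batch_size, seq_length, 3 * embed_size)
```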
embed_initializer = mx.init.create(*cfg.INITIALIZER.embed) weight_initializer = mx.init.create(*cfg.INITIALIZER.weight) bias_initializer = mx.init.create(*cfg.INITIALIZER.bias) + if dtype is None: + dtype = cfg.MODEL.dtype return cls(vocab_size=cfg.MODEL.vocab_size, units=cfg.MODEL.units, hidden_size=cfg.MODEL.hidden_size, @@ -646,17 +787,17 @@ def from_cfg(cls, embed_initializer=embed_initializer, weight_initializer=weight_initializer, bias_initializer=bias_initializer, - use_bottleneck=use_bottleneck, - trigram_embed=trigram_embed, + use_bottleneck=cfg.MODEL.use_bottleneck, + trigram_embed=cfg.MODEL.trigram_embed, use_pooler=use_pooler, - classifier_activation=classifier_activation) + classifier_activation=cfg.MODEL.classifier_activation, + layout=cfg.MODEL.layout, + compute_layout=cfg.MODEL.compute_layout) @use_np class MobileBertForMLM(HybridBlock): def __init__(self, backbone_cfg, - use_bottleneck=True, - trigram_embed=True, weight_initializer=None, bias_initializer=None): """ @@ -668,9 +809,7 @@ def __init__(self, backbone_cfg, bias_initializer """ super().__init__() - self.backbone_model = MobileBertModel.from_cfg(backbone_cfg, - use_bottleneck=use_bottleneck, - trigram_embed=trigram_embed) + self.backbone_model = MobileBertModel.from_cfg(backbone_cfg) if weight_initializer is None: weight_initializer = self.backbone_model.weight_initializer if bias_initializer is None: @@ -680,7 +819,8 @@ def __init__(self, backbone_cfg, self.mlm_decoder.add(nn.Dense(units=self.backbone_model.units, flatten=False, weight_initializer=weight_initializer, - bias_initializer=bias_initializer)) + bias_initializer=bias_initializer, + dtype=self.backbone_model.dtype)) self.mlm_decoder.add(get_activation(self.backbone_model.activation)) # use basic layer normalization for pretaining self.mlm_decoder.add(nn.LayerNorm(epsilon=self.backbone_model.layer_norm_eps)) @@ -692,14 +832,14 @@ def __init__(self, backbone_cfg, units=self.backbone_model.vocab_size, in_units=self.backbone_model.embed_size, flatten=False, + dtype=self.backbone_model.dtype, bias_initializer=bias_initializer) self.embedding_table.weight = self.backbone_model.word_embed.weight if self.backbone_model.embed_size != self.backbone_model.units: self.extra_table = nn.Dense( units=self.backbone_model.vocab_size, use_bias=False, - in_units=self.backbone_model.units - - self.backbone_model.embed_size, + in_units=self.backbone_model.units - self.backbone_model.embed_size, flatten=False) def hybrid_forward(self, F, inputs, token_types, valid_length, @@ -709,30 +849,43 @@ def hybrid_forward(self, F, inputs, token_types, valid_length, Parameters ---------- F - inputs : - Shape (batch_size, seq_length) - token_types : - Shape (batch_size, seq_length) + inputs + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) + token_types + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) The type of the token. For example, if the inputs contain two sequences, we will set different token types for the first sentence and the second sentence. - valid_length : + valid_length The valid length of each sequence Shape (batch_size,) - masked_positions : + masked_positions The masked position of the sequence Shape (batch_size, num_masked_positions). Returns ------- contextual_embedding - Shape (batch_size, seq_length, units). + - layout = 'NT' + Shape (batch_size, seq_length, units). + - layout = 'TN' + Shape (seq_length, batch_size, units). 
pooled_out Shape (batch_size, units) - mlm_scores : + mlm_scores Shape (batch_size, num_masked_positions, vocab_size) """ contextual_embeddings, pooled_out = self.backbone_model(inputs, token_types, valid_length) - mlm_features = select_vectors_by_position(F, contextual_embeddings, masked_positions) + if self.backbone_model.layout == 'TN': + mlm_features = select_vectors_by_position(F, F.np.swapaxes(contextual_embeddings, 0, 1), + masked_positions) + else: + mlm_features = select_vectors_by_position(F, contextual_embeddings, masked_positions) intermediate_output = self.mlm_decoder(mlm_features) if self.backbone_model.embed_size != self.backbone_model.units: scores = self.embedding_table( @@ -748,8 +901,6 @@ def hybrid_forward(self, F, inputs, token_types, valid_length, @use_np class MobileBertForPretrain(HybridBlock): def __init__(self, backbone_cfg, - use_bottleneck=True, - trigram_embed=True, weight_initializer=None, bias_initializer=None): """ @@ -762,22 +913,22 @@ def __init__(self, backbone_cfg, bias_initializer """ super().__init__() - self.backbone_model = MobileBertModel.from_cfg(backbone_cfg, - use_bottleneck=use_bottleneck, - trigram_embed=trigram_embed) + self.backbone_model = MobileBertModel.from_cfg(backbone_cfg) if weight_initializer is None: weight_initializer = self.backbone_model.weight_initializer if bias_initializer is None: bias_initializer = self.backbone_model.bias_initializer # Construct nsp_classifier for next sentence prediction self.nsp_classifier = nn.Dense(units=2, - weight_initializer=weight_initializer) + weight_initializer=weight_initializer, + dtype=self.backbone_model.dtype) self.mlm_decoder = nn.HybridSequential() # Extra non-linear layer self.mlm_decoder.add(nn.Dense(units=self.backbone_model.units, flatten=False, weight_initializer=weight_initializer, - bias_initializer=bias_initializer)) + bias_initializer=bias_initializer, + dtype=self.backbone_model.dtype)) self.mlm_decoder.add(get_activation(self.backbone_model.activation)) # use basic layer normalization for pretaining self.mlm_decoder.add(nn.LayerNorm(epsilon=self.backbone_model.layer_norm_eps)) @@ -789,7 +940,8 @@ def __init__(self, backbone_cfg, units=self.backbone_model.vocab_size, in_units=self.backbone_model.embed_size, flatten=False, - bias_initializer=bias_initializer) + bias_initializer=bias_initializer, + dtype=self.backbone_model.dtype) self.embedding_table.weight = self.backbone_model.word_embed.weight if self.backbone_model.embed_size != self.backbone_model.units: self.extra_table = nn.Dense( @@ -798,7 +950,8 @@ def __init__(self, backbone_cfg, self.backbone_model.embed_size, flatten=False, use_bias=False, - bias_initializer=bias_initializer) + bias_initializer=bias_initializer, + dtype=self.backbone_model.dtype) def hybrid_forward(self, F, inputs, token_types, valid_length, masked_positions): @@ -809,34 +962,47 @@ def hybrid_forward(self, F, inputs, token_types, valid_length, Parameters ---------- F - inputs : - Shape (batch_size, seq_length) - token_types : - Shape (batch_size, seq_length) + inputs + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) + token_types + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) If the inputs contain two sequences, we will set different token types for the first sentence and the second sentence. 
- valid_length : + valid_length The valid length of each sequence Shape (batch_size,) - masked_positions : + masked_positions The masked position of the sequence Shape (batch_size, num_masked_positions). Returns ------- contextual_embedding - Shape (batch_size, seq_length, units). + - layout = 'NT' + Shape (batch_size, seq_length, units). + - layout = 'TN' + Shape (seq_length, batch_size, units). pooled_out Shape (batch_size, units) - nsp_score : + nsp_score Shape (batch_size, 2) - mlm_scores : + mlm_scores Shape (batch_size, num_masked_positions, vocab_size) """ contextual_embeddings, pooled_out = self.backbone_model(inputs, token_types, valid_length) nsp_score = self.nsp_classifier(pooled_out) - mlm_features = select_vectors_by_position(F, contextual_embeddings, masked_positions) + if self.backbone_model.layout == 'NT': + mlm_features = select_vectors_by_position(F, contextual_embeddings, masked_positions) + else: + mlm_features = select_vectors_by_position(F, F.np.swapaxes(contextual_embeddings, 0, 1), + masked_positions) intermediate_output = self.mlm_decoder(mlm_features) if self.backbone_model.embed_size != self.backbone_model.units: scores = self.embedding_table( @@ -884,11 +1050,18 @@ def get_pretrained_mobilebert(model_name: str = 'google_uncased_mobilebert', assert model_name in PRETRAINED_URL, '{} is not found. All available are {}'.format( model_name, list_pretrained_mobilebert()) cfg_path = PRETRAINED_URL[model_name]['cfg'] + if isinstance(cfg_path, CN): + cfg = cfg_path + else: + cfg = None vocab_path = PRETRAINED_URL[model_name]['vocab'] params_path = PRETRAINED_URL[model_name]['params'] mlm_params_path = PRETRAINED_URL[model_name]['mlm_params'] local_paths = dict() - for k, path in [('cfg', cfg_path), ('vocab', vocab_path)]: + download_jobs = [('vocab', vocab_path)] + if cfg is None: + download_jobs.append(('cfg', cfg_path)) + for k, path in download_jobs: local_paths[k] = download(url=get_repo_model_zoo_url() + path, path=os.path.join(root, path), sha1_hash=FILE_STATS[path]) @@ -914,7 +1087,8 @@ def get_pretrained_mobilebert(model_name: str = 'google_uncased_mobilebert', sep_token='[SEP]', mask_token='[MASK]', lowercase=do_lower) - cfg = MobileBertModel.get_cfg().clone_merge(local_paths['cfg']) + if cfg is None: + cfg = MobileBertModel.get_cfg().clone_merge(local_paths['cfg']) return cfg, tokenizer, local_params_path, local_mlm_params_path diff --git a/src/gluonnlp/models/roberta.py b/src/gluonnlp/models/roberta.py index 8400f89fbd..b9af04dafd 100644 --- a/src/gluonnlp/models/roberta.py +++ b/src/gluonnlp/models/roberta.py @@ -42,31 +42,13 @@ from ..layers import PositionalEmbedding, get_activation from ..registry import BACKBONE_REGISTRY from ..utils.misc import download, load_checksum_stats +from ..utils.registry import Registry from .transformer import TransformerEncoderLayer from ..initializer import TruncNorm from ..utils.config import CfgNode as CN from ..attention_cell import gen_self_attn_mask -from ..utils.registry import Registry from ..data.tokenizers import HuggingFaceByteBPETokenizer -PRETRAINED_URL = { - 'fairseq_roberta_base': { - 'cfg': 'fairseq_roberta_base/model-565d1db7.yml', - 'merges': 'fairseq_roberta_base/gpt2-396d4d8e.merges', - 'vocab': 'fairseq_roberta_base/gpt2-f1335494.vocab', - 'params': 'fairseq_roberta_base/model-09a1520a.params', - 'mlm_params': 'fairseq_roberta_base/model_mlm-29889e2b.params', - 'lowercase': False, - }, - 'fairseq_roberta_large': { - 'cfg': 'fairseq_roberta_large/model-6e66dc4a.yml', - 'merges': 
'fairseq_roberta_large/gpt2-396d4d8e.merges', - 'vocab': 'fairseq_roberta_large/gpt2-f1335494.vocab', - 'params': 'fairseq_roberta_large/model-6b043b91.params', - 'mlm_params': 'fairseq_roberta_large/model_mlm-119f38e1.params', - 'lowercase': False, - } -} FILE_STATS = load_checksum_stats(os.path.join(get_model_zoo_checksum_dir(), 'roberta.txt')) roberta_cfg_reg = Registry('roberta_cfg') @@ -90,6 +72,10 @@ def roberta_base(): cfg.MODEL.hidden_dropout_prob = 0.1 cfg.MODEL.attention_dropout_prob = 0.1 cfg.MODEL.dtype = 'float32' + # Layout + cfg.MODEL.layout = 'NT' + cfg.MODEL.compute_layout = 'auto' + # Initialization method cfg.INITIALIZER = CN() cfg.INITIALIZER.embed = ['truncnorm', 0, 0.02] cfg.INITIALIZER.weight = ['truncnorm', 0, 0.02] @@ -111,6 +97,97 @@ def roberta_large(): return cfg +PRETRAINED_URL = { + 'fairseq_roberta_base': { + 'cfg': roberta_base(), + 'merges': 'fairseq_roberta_base/gpt2-396d4d8e.merges', + 'vocab': 'fairseq_roberta_base/gpt2-f1335494.vocab', + 'params': 'fairseq_roberta_base/model-09a1520a.params', + 'mlm_params': 'fairseq_roberta_base/model_mlm-29889e2b.params', + 'lowercase': False, + }, + 'fairseq_roberta_large': { + 'cfg': roberta_large(), + 'merges': 'fairseq_roberta_large/gpt2-396d4d8e.merges', + 'vocab': 'fairseq_roberta_large/gpt2-f1335494.vocab', + 'params': 'fairseq_roberta_large/model-6b043b91.params', + 'mlm_params': 'fairseq_roberta_large/model_mlm-119f38e1.params', + 'lowercase': False, + } +} + + +@use_np +class RobertaEncoder(HybridBlock): + def __init__(self, + units=768, + hidden_size=3072, + num_layers=12, + num_heads=12, + attention_dropout_prob=0.1, + hidden_dropout_prob=0.1, + layer_norm_eps=1E-5, + weight_initializer=TruncNorm(stdev=0.02), + bias_initializer='zeros', + activation='gelu', + dtype='float32', + output_all_encodings=False, + output_attention=False, + layout='NT'): + super().__init__() + self.units = units + self.hidden_size = hidden_size + self.num_layers = num_layers + self.num_heads = num_heads + self.attention_dropout_prob = attention_dropout_prob + self.hidden_dropout_prob = hidden_dropout_prob + self.layer_norm_eps = layer_norm_eps + self.activation = activation + self._dtype = dtype + self._layout = layout + self._output_all_encodings = output_all_encodings + self._output_attention = output_attention + self.all_layers = nn.HybridSequential() + for layer_idx in range(self.num_layers): + self.all_layers.add( + TransformerEncoderLayer( + units=self.units, + hidden_size=self.hidden_size, + num_heads=self.num_heads, + attention_dropout_prob=self.attention_dropout_prob, + hidden_dropout_prob=self.hidden_dropout_prob, + layer_norm_eps=self.layer_norm_eps, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, + activation=self.activation, + dtype=self._dtype, + layout=layout) + ) + + @property + def layout(self): + return self._layout + + def hybrid_forward(self, F, x, valid_length): + atten_mask = gen_self_attn_mask(F, x, valid_length, + layout=self._layout, + dtype=self._dtype, attn_type='full') + all_encodings_outputs = [x] + additional_outputs = [] + for layer_idx in range(self.num_layers): + layer = self.all_layers[layer_idx] + x, attention_weights = layer(x, atten_mask) + if self._output_all_encodings: + all_encodings_outputs.append(x) + if self._output_attention: + additional_outputs.append(attention_weights) + # sequence_mask is not necessary here because masking could be performed in downstream tasks + if self._output_all_encodings: + return all_encodings_outputs, additional_outputs + else: + 
return x, additional_outputs + + @use_np class RobertaModel(HybridBlock): def __init__(self, @@ -133,7 +210,9 @@ def __init__(self, use_pooler=True, classifier_activation=False, encoder_normalize_before=True, - output_all_encodings=False): + output_all_encodings=False, + layout='NT', + compute_layout='auto'): """ Parameters @@ -159,7 +238,13 @@ def __init__(self, classifier_activation Whether to use classification head encoder_normalize_before + Whether to normalize before the output_all_encodings + Whether to output all encodings + layout + The layout + compute_layout + The computation layout """ super().__init__() self._dtype = dtype @@ -181,7 +266,11 @@ def __init__(self, self.encoder_normalize_before = encoder_normalize_before self.weight_initializer = weight_initializer self.bias_initializer = bias_initializer - + self._layout = layout + if compute_layout == 'auto' or compute_layout is None: + self._compute_layout = layout + else: + self._compute_layout = compute_layout self.word_embed = nn.Embedding( input_dim=self.vocab_size, output_dim=self.units, @@ -211,7 +300,8 @@ def __init__(self, bias_initializer=bias_initializer, activation=self.activation, dtype=self._dtype, - output_all_encodings=self._output_all_encodings + output_all_encodings=self._output_all_encodings, + layout=self._compute_layout, ) self.encoder.hybridize() @@ -224,20 +314,26 @@ def __init__(self, weight_initializer=weight_initializer, bias_initializer=bias_initializer) + @property + def layout(self): + return self._layout + def hybrid_forward(self, F, tokens, valid_length): - outputs = [] embedding = self.get_initial_embedding(F, tokens) - - contextual_embeddings, additional_outputs = self.encoder(embedding, valid_length) - outputs.append(contextual_embeddings) - if self._output_all_encodings: - contextual_embeddings = contextual_embeddings[-1] - + if self._layout != self._compute_layout: + contextual_embeddings, additional_outputs = self.encoder(F.np.swapaxes(embedding, 0, 1), + valid_length) + contextual_embeddings = F.np.swapaxes(contextual_embeddings, 0, 1) + else: + contextual_embeddings, additional_outputs = self.encoder(embedding, valid_length) if self.use_pooler: - pooled_out = self.apply_pooling(contextual_embeddings) - outputs.append(pooled_out) - - return tuple(outputs) if len(outputs) > 1 else outputs[0] + if isinstance(contextual_embeddings, list): + pooled_out = self.apply_pooling(contextual_embeddings[-1]) + else: + pooled_out = self.apply_pooling(contextual_embeddings) + return contextual_embeddings, pooled_out + else: + return contextual_embeddings def get_initial_embedding(self, F, inputs): """Get the initial token embeddings that considers the token type and positional embeddings @@ -246,17 +342,28 @@ def get_initial_embedding(self, F, inputs): ---------- F inputs - Shape (batch_size, seq_length) + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) Returns ------- embedding The initial embedding that will be fed into the encoder + - layout = 'NT' + Shape (batch_size, seq_length, C) + - layout = 'TN' + Shape (seq_length, batch_size, C) """ + if self._layout == 'NT': + batch_axis, time_axis = 0, 1 + else: + batch_axis, time_axis = 1, 0 embedding = self.word_embed(inputs) if self.pos_embed_type: - positional_embedding = self.pos_embed(F.npx.arange_like(inputs, axis=1)) - positional_embedding = F.np.expand_dims(positional_embedding, axis=0) + positional_embedding = self.pos_embed(F.npx.arange_like(inputs, axis=time_axis)) + positional_embedding = 
F.np.expand_dims(positional_embedding, axis=batch_axis) embedding = embedding + positional_embedding if self.encoder_normalize_before: embedding = self.embed_ln(embedding) @@ -270,12 +377,25 @@ def apply_pooling(self, sequence): This is used for pre-training or fine-tuning a mobile bert model. Get the first token of the whole sequence which is [CLS] - sequence: - Shape (batch_size, sequence_length, units) - return: + Parameters + ---------- + sequence + - layout = 'NT' + Shape (batch_size, sequence_length, units) + - layout = 'TN' + Shape (sequence_length, batch_size, units) + + Returns + ------- + ret Shape (batch_size, units) """ - outputs = sequence[:, 0, :] + if self._layout == 'NT': + outputs = sequence[:, 0, :] + elif self._layout == 'TN': + outputs = sequence[0, :, :] + else: + raise NotImplementedError if self.classifier_activation: return self.pooler(outputs) else: @@ -283,7 +403,7 @@ def apply_pooling(self, sequence): @staticmethod def get_cfg(key=None): - if key: + if key is not None: return roberta_cfg_reg.create(key) else: return roberta_base() @@ -292,14 +412,14 @@ def get_cfg(key=None): def from_cfg(cls, cfg, use_pooler=True, - dtype='float32', - classifier_activation=False, - encoder_normalize_before=True, + dtype=None, output_all_encodings=False) -> 'RobertaModel': cfg = RobertaModel.get_cfg().clone_merge(cfg) embed_initializer = mx.init.create(*cfg.INITIALIZER.embed) weight_initializer = mx.init.create(*cfg.INITIALIZER.weight) bias_initializer = mx.init.create(*cfg.INITIALIZER.bias) + if dtype is None: + dtype = cfg.MODEL.dtype return cls(vocab_size=cfg.MODEL.vocab_size, units=cfg.MODEL.units, hidden_size=cfg.MODEL.hidden_size, @@ -317,71 +437,9 @@ def from_cfg(cls, bias_initializer=bias_initializer, dtype=dtype, use_pooler=use_pooler, - encoder_normalize_before=encoder_normalize_before, - output_all_encodings=output_all_encodings) - - -@use_np -class RobertaEncoder(HybridBlock): - def __init__(self, - units=768, - hidden_size=3072, - num_layers=12, - num_heads=12, - attention_dropout_prob=0.1, - hidden_dropout_prob=0.1, - layer_norm_eps=1E-5, - weight_initializer=TruncNorm(stdev=0.02), - bias_initializer='zeros', - activation='gelu', - dtype='float32', - output_all_encodings=False, - output_attention=False): - super().__init__() - self.units = units - self.hidden_size = hidden_size - self.num_layers = num_layers - self.num_heads = num_heads - self.attention_dropout_prob = attention_dropout_prob - self.hidden_dropout_prob = hidden_dropout_prob - self.layer_norm_eps = layer_norm_eps - self.activation = activation - self._dtype = dtype - self._output_all_encodings = output_all_encodings - self._output_attention = output_attention - self.all_layers = nn.HybridSequential() - for layer_idx in range(self.num_layers): - self.all_layers.add( - TransformerEncoderLayer( - units=self.units, - hidden_size=self.hidden_size, - num_heads=self.num_heads, - attention_dropout_prob=self.attention_dropout_prob, - hidden_dropout_prob=self.hidden_dropout_prob, - layer_norm_eps=self.layer_norm_eps, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer, - activation=self.activation, - dtype=self._dtype) - ) - - def hybrid_forward(self, F, x, valid_length): - atten_mask = gen_self_attn_mask(F, x, valid_length, - dtype=self._dtype, attn_type='full') - all_encodings_outputs = [x] - additional_outputs = [] - for layer_idx in range(self.num_layers): - layer = self.all_layers[layer_idx] - x, attention_weights = layer(x, atten_mask) - if self._output_all_encodings: - 
all_encodings_outputs.append(x) - if self._output_attention: - additional_outputs.append(attention_weights) - # sequence_mask is not necessary here because masking could be performed in downstream tasks - if self._output_all_encodings: - return all_encodings_outputs, additional_outputs - else: - return x, additional_outputs + output_all_encodings=output_all_encodings, + layout=cfg.MODEL.layout, + compute_layout=cfg.MODEL.compute_layout) @use_np @@ -432,19 +490,25 @@ def hybrid_forward(self, F, inputs, valid_length, masked_positions): Parameters ---------- F - inputs : - Shape (batch_size, seq_length) - valid_length : + inputs + - layout = 'NT' + Shape (batch_size, seq_length) + - layout = 'TN' + Shape (seq_length, batch_size) + valid_length The valid length of each sequence Shape (batch_size,) - masked_positions : + masked_positions The masked position of the sequence Shape (batch_size, num_masked_positions). Returns ------- contextual_embedding - Shape (batch_size, seq_length, units). + - layout = 'NT' + Shape (batch_size, seq_length, units). + - layout = 'TN' + Shape (seq_length, batch_size, units). pooled_out Shape (batch_size, units) mlm_scores : @@ -456,6 +520,8 @@ def hybrid_forward(self, F, inputs, valid_length, masked_positions): contextual_embeddings = all_encodings_outputs[-1] else: contextual_embeddings = all_encodings_outputs + if self.backbone_model.layout == 'TN': + contextual_embeddings = F.np.swapaxes(contextual_embeddings, 0, 1) mlm_features = select_vectors_by_position(F, contextual_embeddings, masked_positions) mlm_scores = self.mlm_decoder(mlm_features) return all_encodings_outputs, pooled_out, mlm_scores @@ -469,7 +535,7 @@ def get_pretrained_roberta(model_name: str = 'fairseq_roberta_base', root: str = get_model_zoo_home_dir(), load_backbone: bool = True, load_mlm: bool = False) \ - -> Tuple[CN, HuggingFaceByteBPETokenizer, str]: + -> Tuple[CN, HuggingFaceByteBPETokenizer, str, str]: """Get the pretrained RoBERTa weights Parameters @@ -497,14 +563,20 @@ def get_pretrained_roberta(model_name: str = 'fairseq_roberta_base', assert model_name in PRETRAINED_URL, '{} is not found. 
All available are {}'.format( model_name, list_pretrained_roberta()) cfg_path = PRETRAINED_URL[model_name]['cfg'] + if isinstance(cfg_path, CN): + cfg = cfg_path + else: + cfg = None merges_path = PRETRAINED_URL[model_name]['merges'] vocab_path = PRETRAINED_URL[model_name]['vocab'] params_path = PRETRAINED_URL[model_name]['params'] mlm_params_path = PRETRAINED_URL[model_name]['mlm_params'] local_paths = dict() - for k, path in [('cfg', cfg_path), ('vocab', vocab_path), - ('merges', merges_path)]: + download_jobs = [('vocab', vocab_path), ('merges', merges_path)] + if cfg is None: + download_jobs.append(('cfg', cfg_path)) + for k, path in download_jobs: local_paths[k] = download(url=get_repo_model_zoo_url() + path, path=os.path.join(root, path), sha1_hash=FILE_STATS[path]) @@ -526,7 +598,8 @@ def get_pretrained_roberta(model_name: str = 'fairseq_roberta_base', merges_file=local_paths['merges'], vocab_file=local_paths['vocab'], lowercase=do_lower) - cfg = RobertaModel.get_cfg().clone_merge(local_paths['cfg']) + if cfg is None: + cfg = RobertaModel.get_cfg().clone_merge(local_paths['cfg']) return cfg, tokenizer, local_params_path, local_mlm_params_path diff --git a/src/gluonnlp/models/transformer.py b/src/gluonnlp/models/transformer.py index 1d0f7c2eb1..da18447f07 100644 --- a/src/gluonnlp/models/transformer.py +++ b/src/gluonnlp/models/transformer.py @@ -1,3 +1,5 @@ +from abc import ABC + import numpy as np import mxnet as mx from mxnet import use_np @@ -31,6 +33,7 @@ def transformer_nmt_base(): cfg.MODEL.attention_dropout = 0.0 cfg.MODEL.activation_dropout = 0.0 cfg.MODEL.dropout = 0.1 + cfg.MODEL.layout = 'NT' cfg.MODEL.dtype = 'float32' # Parameters for the encoder @@ -53,10 +56,6 @@ def transformer_nmt_base(): cfg.MODEL.DECODER.activation = 'relu' cfg.MODEL.DECODER.pre_norm = False - # Parameters for mixture of models - cfg.MODEL.method = 'hMoElp' - cfg.MODEL.num_experts = 3 - # Parameters for the initializer cfg.INITIALIZER = CN() cfg.INITIALIZER.embed = ['xavier', 'gaussian', 'in', 1.0] @@ -141,7 +140,8 @@ def __init__(self, weight_initializer: Optional[InitializerType] = None, bias_initializer: Optional[InitializerType] = 'zeros', activation: str = 'relu', - dtype='float32'): + dtype='float32', + layout='NT'): """ Parameters @@ -165,6 +165,7 @@ def __init__(self, bias_initializer activation dtype + layout """ super().__init__() self._units = units @@ -175,6 +176,9 @@ def __init__(self, self._activation_dropout_prob = activation_dropout_prob self._pre_norm = pre_norm self._dtype = dtype + self._layout = layout + assert layout in ['TN', 'NT'], 'Invalid layout received = {}. 
' \ + 'Only "TN" and "NT" are accepted!'.format(layout) assert self._units % self._num_heads == 0, 'units must be divisive by the number of heads' self.dropout_layer = nn.Dropout(hidden_dropout_prob) self.attn_qkv = nn.Dense(3 * units, @@ -191,6 +195,7 @@ def __init__(self, weight_initializer=weight_initializer, bias_initializer=bias_initializer, dtype=self._dtype) + attention_layout = 'NTK' if self._layout == 'NT' else 'TNK' self.attention_cell =\ MultiHeadAttentionCell( query_units=self._units, @@ -198,7 +203,7 @@ def __init__(self, attention_dropout=self._attention_dropout_prob, scaled=True, dtype=self._dtype, - layout='NTK' + layout=attention_layout ) self.layer_norm = nn.LayerNorm(epsilon=layer_norm_eps, in_channels=units) @@ -213,6 +218,10 @@ def __init__(self, pre_norm=pre_norm, dtype=self._dtype) + @property + def layout(self) -> str: + return self._layout + def hybrid_forward(self, F, data, attn_mask): """ @@ -220,19 +229,23 @@ def hybrid_forward(self, F, data, attn_mask): ---------- F data : - Shape (batch_size, seq_length, C_in) + If layout == 'NT' + Shape (batch_size, seq_length, C_in) + Else + Shape (seq_length, batch_size, C_in) attn_mask : Shape (batch_size, seq_length, seq_length) Returns ------- out : - Shape (batch_size, seq_length, C_out) + If layout == 'NT' + Shape (batch_size, seq_length, C_out) + Else + Shape (seq_length, batch_size, C_out) attn_weight : Shape (batch_size, seq_length, seq_length) """ - # TODO(sxjscience) Cannot use negative axis due to - # https://github.com/apache/incubator-mxnet/issues/18132 if self._pre_norm: data = self.layer_norm(data) query, key, value = F.np.split(self.attn_qkv(data), 3, axis=-1) @@ -256,7 +269,7 @@ def __init__(self, num_layers=6, recurrent=False, activation_dropout=0.0, dropout=0.1, attention_dropout=0.1, layer_norm_eps=1E-5, data_norm=False, pre_norm=False, weight_initializer=None, bias_initializer='zeros', - activation='relu', dtype='float32'): + activation='relu', dtype='float32', layout='NT'): """ Parameters @@ -277,6 +290,8 @@ def __init__(self, num_layers=6, recurrent=False, weight_initializer bias_initializer activation + dtype + layout """ super().__init__() self._dtype = dtype @@ -284,6 +299,9 @@ def __init__(self, num_layers=6, recurrent=False, self._recurrent = recurrent self._data_norm = data_norm self._pre_norm = pre_norm + self._layout = layout + assert layout in ['TN', 'NT'], 'Invalid layout received = {}. ' \ + 'Only "TN" and "NT" are accepted!'.format(layout) self.dropout_layer = nn.Dropout(dropout) if self._pre_norm: self.ln_final = nn.LayerNorm(epsilon=layer_norm_eps, @@ -307,8 +325,13 @@ def __init__(self, num_layers=6, recurrent=False, bias_initializer=bias_initializer, pre_norm=pre_norm, activation=activation, + layout=self._layout, dtype=dtype)) + @property + def layout(self) -> str: + return self._layout + def hybrid_forward(self, F, data, valid_length): """ @@ -316,18 +339,26 @@ def hybrid_forward(self, F, data, valid_length): ---------- F data : - Shape (batch_size, seq_length, C) + - layout = 'NT' + Shape (batch_size, seq_length, C) + - layout = 'TN' + Shape (seq_length, batch_size, C) valid_length : Shape (batch_size,) Returns ------- out : - Shape (batch_size, seq_length, C_out) + - layout = 'NT' + Shape (batch_size, seq_length, C_out) + - layout = 'TN' + Shape (seq_length, batch_size, C_out) """ # 1. 
Embed the data attn_mask = gen_self_attn_mask(F, data, valid_length, - dtype=self._dtype, attn_type='full') + dtype=self._dtype, + layout=self.layout, + attn_type='full') out = self.dropout_layer(data) if self._data_norm: out = self.ln_data(out) @@ -356,7 +387,8 @@ def __init__(self, units: int = 512, pre_norm: bool = False, weight_initializer=None, bias_initializer='zeros', - dtype='float32'): + dtype='float32', + layout='NT'): """ Parameters @@ -377,6 +409,9 @@ def __init__(self, units: int = 512, weight_initializer bias_initializer dtype + Data type + layout + Layout of the input """ super().__init__() self._dtype = dtype @@ -388,6 +423,10 @@ def __init__(self, units: int = 512, self._num_heads = num_heads self._attention_dropout = attention_dropout self._dtype = dtype + self._layout = layout + assert layout in ['TN', 'NT'], 'Invalid layout received = {}. ' \ + 'Only "TN" and "NT" are accepted!'.format(layout) + attention_layout = 'NTK' if layout == 'NT' else 'TNK' self.dropout_layer = nn.Dropout(dropout) if units % num_heads: raise ValueError('In Transformer, units should be divided exactly by the number of ' @@ -402,7 +441,7 @@ def __init__(self, units: int = 512, num_heads=num_heads, attention_dropout=self._attention_dropout, dtype=dtype, - layout='NTK') + layout=attention_layout) self.proj_in = nn.Dense(units=units, in_units=units, flatten=False, use_bias=False, weight_initializer=weight_initializer, bias_initializer=bias_initializer, @@ -430,7 +469,7 @@ def __init__(self, units: int = 512, num_heads=num_heads, attention_dropout=self._attention_dropout, dtype=dtype, - layout='NTK') + layout=attention_layout) self.proj_inter = nn.Dense(units=units, in_units=units, flatten=False, use_bias=False, weight_initializer=weight_initializer, @@ -449,6 +488,10 @@ def __init__(self, units: int = 512, pre_norm=pre_norm, dtype=dtype) + @property + def layout(self) -> str: + return self._layout + def hybrid_forward(self, F, data, mem, self_causal_mask, mem_attn_mask): """ @@ -456,9 +499,15 @@ def hybrid_forward(self, F, data, mem, self_causal_mask, mem_attn_mask): ---------- F data : - Shape (batch_size, seq_length, C_in) + - layout = 'NT' + Shape (batch_size, seq_length, C_in) + - layout = 'TN' + Shape (seq_length, batch_size, C_in) mem : - Shape (batch_size, mem_length, C_mem) + - layout = 'NT' + Shape (batch_size, mem_length, C_mem) + - layout = 'TN' + Shape (mem_length, batch_size, C_mem) self_causal_mask : Shape (batch_size, seq_length, seq_length) Mask for the causal self-attention. @@ -485,11 +534,11 @@ def hybrid_forward(self, F, data, mem, self_causal_mask, mem_attn_mask): Returns ------- out : - Shape (batch_size, seq_length, C_out) + - layout = 'NT' + Shape (batch_size, seq_length, C_out) + - layout = 'TN' + Shape (seq_length, batch_size, C_out) """ - # TODO(szhengac) - # Try the architecture in the "[ECCV2016] Identity Mappings in Deep Residual Networks". - # Shuai proposed to switch the order of the activation layer. # 1. 
Get the causal self-attention value if self._pre_norm: data = self.ln_in(data) @@ -525,22 +574,37 @@ def hybrid_forward(self, F, data, mem, self_causal_mask, mem_attn_mask): @property def state_batch_axis(self): - return 0, 0 + if self.layout == 'NT': + return 0, 0 + else: + return 1, 1 def init_states(self, batch_size, ctx, dtype='float32'): """Initialize the states required for incremental decoding Returns ------- - init_key : - Shape (batch_size, 0, N, C_key) + init_key + - layout = 'NT' + Shape (batch_size, 0, N, C_key) + - layout = 'TN' + Shape (0, batch_size, N, C_key) init_value : - Shape (batch_size, 0, N, C_value) + - layout = 'NT' + Shape (batch_size, 0, N, C_value) + - layout = 'TN' + Shape (0, batch_size, N, C_value) """ - init_key = mx.np.zeros(shape=(batch_size, 0, self._num_heads, - self._units // self._num_heads), ctx=ctx, dtype=dtype) - init_value = mx.np.zeros(shape=(batch_size, 0, self._num_heads, - self._units // self._num_heads), ctx=ctx, dtype=dtype) + if self.layout == 'NT': + init_key = mx.np.zeros(shape=(batch_size, 0, self._num_heads, + self._units // self._num_heads), ctx=ctx, dtype=dtype) + init_value = mx.np.zeros(shape=(batch_size, 0, self._num_heads, + self._units // self._num_heads), ctx=ctx, dtype=dtype) + else: + init_key = mx.np.zeros(shape=(0, batch_size, self._num_heads, + self._units // self._num_heads), ctx=ctx, dtype=dtype) + init_value = mx.np.zeros(shape=(0, batch_size, self._num_heads, + self._units // self._num_heads), ctx=ctx, dtype=dtype) return init_key, init_value def incremental_decode(self, F, data, states, mem, mem_valid_length, mem_attn_mask=None): @@ -550,16 +614,25 @@ def incremental_decode(self, F, data, states, mem, mem_valid_length, mem_attn_ma ---------- F data - Shape (batch_size, 1, C_in) + Shape (batch_size, C_in) states The previous states, contains - - prev_multi_key - Shape (batch_size, prev_seq_length, num_heads, C_key) - - prev_multi_value - Shape (batch_size, prev_seq_length, num_heads, C_value) + 1. layout = 'NT': + - prev_multi_key + Shape (batch_size, prev_seq_length, num_heads, C_key) + - prev_multi_value + Shape (batch_size, prev_seq_length, num_heads, C_value) + 2. layout = 'TN' + - prev_multi_key + Shape (prev_seq_length, batch_size, num_heads, C_key) + - prev_multi_value + Shape (prev_seq_length, batch_size, num_heads, C_value) mem The memory - Shape (batch_size, mem_length, C_mem) + 1. layout = 'NT': + Shape (batch_size, mem_length, C_mem) + 2. layout = 'TN' + Shape (mem_length, batch_size, C_mem) mem_valid_length Valid length of the memory Shape (batch_size,) @@ -570,7 +643,7 @@ def incremental_decode(self, F, data, states, mem, mem_valid_length, mem_attn_ma Returns ------- out - Shape (batch_size, 1, C_out) + Shape (batch_size, C_out) updated_states - new_key Shape (batch_size, prev_seq_length + 1, num_heads, C_key) @@ -579,19 +652,28 @@ def incremental_decode(self, F, data, states, mem, mem_valid_length, mem_attn_ma """ if self._pre_norm: data = self.ln_in(data) - prev_key, prev_value = states # Shape (B, prev_L, #Head, C_K), (B, prev_L, #Head, C_V) + if self.layout == 'NT': + time_axis = 1 + else: + time_axis = 0 + data = F.np.expand_dims(data, axis=time_axis) + # Shape (B, prev_L, #Head, C_K), (B, prev_L, #Head, C_V) + # or (prev_L, B, #Head, C_K), (prev_L, B, #Head, C_V) + prev_key, prev_value = states if mem_attn_mask is None: mem_attn_mask = gen_mem_attn_mask(F, mem, mem_valid_length, data, None, - dtype=self._dtype) + dtype=self._dtype, layout=self.layout) # 1. 
Get the causal self-attention value, we need to attend to both the current data # and the previous stored key/values - step_qkv = self.attn_in_qkv(data) # Shape (B, 1, 3 * num_heads * C_key) + # Shape (B, 1, 3 * num_heads * C_key) + # or (1, B, 3 * num_heads * C_key) + step_qkv = self.attn_in_qkv(data) step_query, step_key, step_value = F.np.split(step_qkv, 3, axis=-1) step_query = F.npx.reshape(step_query, (-2, -2, self._num_heads, -1)) step_key = F.npx.reshape(step_key, (-2, -2, self._num_heads, -1)) step_value = F.npx.reshape(step_value, (-2, -2, self._num_heads, -1)) - new_key = F.np.concatenate([prev_key, step_key], axis=1) - new_value = F.np.concatenate([prev_value, step_value], axis=1) + new_key = F.np.concatenate([prev_key, step_key], axis=time_axis) + new_value = F.np.concatenate([prev_value, step_value], axis=time_axis) out, _ = self.self_attention(step_query, new_key, new_value, None) out = self.proj_in(out) out = self.dropout_layer(out) @@ -616,6 +698,7 @@ def incremental_decode(self, F, data, states, mem, mem_valid_length, mem_attn_ma out = self.ln_inter(out) # 3. Encode the output via an FFN layer out = self.ffn(out) + out = F.npx.reshape(out, (-5, -1)) return out, (new_key, new_value) @@ -626,7 +709,8 @@ def __init__(self, num_layers=6, recurrent=False, num_heads=8, max_shift=None, rel_pos_embed=False, activation_dropout=0.0, dropout=0.1, attention_dropout=0.1, layer_norm_eps=1E-5, data_norm=False, pre_norm=False, weight_initializer=None, bias_initializer=None, - activation='relu', dtype='float32'): + activation='relu', dtype='float32', + layout='NT'): super().__init__() self._dtype = dtype self._units = units @@ -637,6 +721,9 @@ def __init__(self, num_layers=6, recurrent=False, self.rel_pos_embed = rel_pos_embed self._data_norm = data_norm self._pre_norm = pre_norm + self._layout = layout + assert layout in ['TN', 'NT'], 'Invalid layout received = {}. ' \ + 'Only "TN" and "NT" are accepted!'.format(layout) self.dropout_layer = nn.Dropout(dropout) if self._data_norm: self.ln_data = nn.LayerNorm(epsilon=layer_norm_eps, @@ -660,35 +747,53 @@ def __init__(self, num_layers=6, recurrent=False, bias_initializer=bias_initializer, activation=activation, pre_norm=pre_norm, + layout=layout, dtype=dtype)) + @property + def layout(self) -> str: + return self._layout + def hybrid_forward(self, F, data, valid_length, mem_data, mem_valid_length): """ Parameters ---------- F - data : - Shape (batch_size, seq_length, C_in) - valid_length : + data + - layout = 'NT' + Shape (batch_size, seq_length, C_in) + - layout = 'TN' + Shape (seq_length, batch_size, C_in) + valid_length Shape (batch_size,) - mem_data : - Shape (batch_size, mem_length, C_mem) - mem_valid_length : + mem_data + - layout = 'NT' + Shape (batch_size, mem_length, C_mem) + - layout = 'TN' + Shape (mem_length, batch_size, C_mem) + mem_valid_length Shape (batch_size,) + Returns ------- - out : - Shape (batch_size, seq_length, C_out) + out + - layout = 'NT' + Shape (batch_size, seq_length, C_out) + - layout = 'TN' + Shape (seq_length, batch_size, C_out) """ # 1. 
Embed the data out = self.dropout_layer(data) if self._data_norm: out = self.ln_data(out) self_causal_mask = gen_self_attn_mask(F, data, valid_length, - dtype=self._dtype, attn_type='causal') + dtype=self._dtype, + attn_type='causal', + layout=self._layout) mem_attn_mask = gen_mem_attn_mask(F, mem_data, mem_valid_length, data, valid_length, - dtype=self._dtype) + dtype=self._dtype, + layout=self._layout) for i in range(self.num_layers): if self.recurrent: layer = self.layers[0] @@ -710,15 +815,19 @@ def state_batch_axis(self): ret.append(layer.state_batch_axis) return ret - def init_states(self, batch_size, ctx, dtype): + def init_states(self, batch_size, ctx, dtype='float32'): """Initialize the states required for incremental decoding Returns ------- - init_key : - Shape (batch_size, 0, N, C_key) - init_value : - Shape (batch_size, 0, N, C_value) + states + A list of states, each includes: + - init_key : + layout = 'NT': + Shape (batch_size, 0, N, C_key) + - init_value : + layout = 'TN': + Shape (0, batch_size, N, C_value) """ states = [] for i in range(self.num_layers): @@ -738,16 +847,25 @@ def incremental_decode(self, F, data, states, mem, mem_valid_length): ---------- F data - Shape (batch_size, 1, C_in) + Shape (batch_size, C_in) states The previous states, contain a list of - - prev_multi_key - Shape (batch_size, prev_seq_length, num_heads, C_key) - - prev_multi_value - Shape (batch_size, prev_seq_length, num_heads, C_value) + 1. layout = 'NT' + - prev_multi_key + Shape (batch_size, prev_seq_length, num_heads, C_key) + - prev_multi_value + Shape (batch_size, prev_seq_length, num_heads, C_value) + 2. layout = 'TN' + - prev_multi_key + Shape (prev_seq_length, batch_size, num_heads, C_key) + - prev_multi_value + Shape (prev_seq_length, batch_size, num_heads, C_value) mem The memory - Shape (batch_size, mem_length, C_mem) + 1. layout = 'NT' + Shape (batch_size, mem_length, C_mem) + 2. layout = 'TN' + Shape (mem_length, batch_size, C_mem) mem_valid_length Valid length of the memory Shape (batch_size,) @@ -755,20 +873,27 @@ def incremental_decode(self, F, data, states, mem, mem_valid_length): Returns ------- out - Shape (batch_size, 1, C_out) + Shape (batch_size, C_out) new_states The updated states, contain a list of - - new_key - Shape (batch_size, prev_seq_length + 1, num_heads, C_key) - - new_value - Shape (batch_size, prev_seq_length + 1, num_heads, C_value) + 1. layout = 'NT' + - new_key + Shape (batch_size, prev_seq_length + 1, num_heads, C_key) + 2. layout = 'TN' + - new_value + Shape (prev_seq_length + 1, batch_size, num_heads, C_value) """ # 1. 
Embed the data out = self.dropout_layer(data) if self._data_norm: out = self.ln_data(out) - mem_attn_mask = gen_mem_attn_mask(F, mem, mem_valid_length, data, None, - dtype=self._dtype) + time_axis = 0 if self.layout == 'TN' else 1 + # Generate the mem_attn_mask + time_steps = F.npx.arange_like(mem, axis=time_axis) # (mem_length,) + mem_attn_mask = F.np.reshape(time_steps, (1, 1, -1))\ + < F.np.reshape(mem_valid_length, (-1, 1, 1)) + # TODO(sxjscience) Try with boolean masking + mem_attn_mask = mem_attn_mask.astype(self._dtype) new_states = [] for i in range(self.num_layers): if self.recurrent: @@ -815,7 +940,8 @@ def __init__(self, src_vocab_size: int, embed_initializer=mx.init.Xavier('gaussian', 'in', 1), weight_initializer=mx.init.Xavier('uniform', 'avg', 3), bias_initializer='zeros', - dtype='float32'): + dtype='float32', + layout='NT'): """ Parameters @@ -884,6 +1010,8 @@ def __init__(self, src_vocab_size: int, Initializer of the bias dtype Data type of the weights + layout + The layout of the input + target """ super().__init__() assert src_vocab_size > 0 and tgt_vocab_size > 0,\ @@ -900,6 +1028,9 @@ def __init__(self, src_vocab_size: int, self.scaled_embed = scale_embed self.enc_units = enc_units self.dec_units = dec_units + self._layout = layout + assert layout in ['TN', 'NT'], 'Invalid layout received = {}. ' \ + 'Only "TN" and "NT" are accepted!'.format(layout) if max_src_length is not None and max_src_length < 0: max_src_length = None if max_tgt_length is not None and max_tgt_length < 0: @@ -941,7 +1072,8 @@ def __init__(self, src_vocab_size: int, activation=enc_activation, data_norm=data_norm, pre_norm=enc_pre_norm, - dtype=self._dtype) + dtype=self._dtype, + layout=layout) self.decoder = TransformerDecoder(num_layers=dec_num_layers, recurrent=dec_recurrent, units=dec_units, @@ -957,7 +1089,8 @@ def __init__(self, src_vocab_size: int, activation=dec_activation, data_norm=data_norm, pre_norm=dec_pre_norm, - dtype=self._dtype) + dtype=self._dtype, + layout=layout) if tie_weights: self.tgt_final_layer =\ nn.Dense(tgt_vocab_size, flatten=False, @@ -976,6 +1109,10 @@ def __init__(self, src_vocab_size: int, self.encoder.hybridize() self.decoder.hybridize() + @property + def layout(self) -> str: + return self._layout + @property def src_vocab_size(self): return self._src_vocab_size @@ -992,21 +1129,31 @@ def encode(self, F, src_data, src_valid_length): Parameters ---------- F - src_data : - Shape (batch_size, src_length) - src_valid_length : + src_data + - layout = 'NT' + Shape (batch_size, src_length) + - layout = 'TN' + Shape (src_length, batch_size) + src_valid_length Shape (batch_size,) Returns ------- - enc_out : - Shape (batch_size, src_length, C_out) + enc_out + - layout = 'NT' + Shape (batch_size, src_length, C_out) + - layout = 'TN' + Shape (src_length, batch_size, C_out) """ src_data = self.src_embed_layer(src_data) if self.scaled_embed: src_data = src_data * np.sqrt(self.enc_units) if self.pos_embed_type is not None: - src_data = src_data + self.src_pos_embed_layer(F.npx.arange_like(src_data, axis=1)) + if self.layout == 'NT': + src_data = src_data + self.src_pos_embed_layer(F.npx.arange_like(src_data, axis=1)) + else: + src_data = src_data + F.np.expand_dims(self.src_pos_embed_layer( + F.npx.arange_like(src_data, axis=0)), axis=1) enc_out = self.encoder(src_data, src_valid_length) return enc_out @@ -1016,26 +1163,39 @@ def decode_seq(self, F, tgt_data, tgt_valid_length, mem_data, mem_valid_length): Parameters ---------- F - tgt_data : - Shape (batch_size, tgt_length) - 
tgt_valid_length : + tgt_data + - layout = 'NT' + Shape (batch_size, tgt_length) + - layout = 'TN' + Shape (tgt_length, batch_size) + tgt_valid_length Shape (batch_size,) - mem_data : - Shape (batch_size, src_length, C_out) + mem_data + - layout = 'NT' + Shape (batch_size, src_length, C_out) + - layout = 'TN' + Shape (src_length, batch_size, C_out) mem_valid_length : Shape (batch_size,) Returns ------- - dec_out : - Shape (batch_size, tgt_length, tgt_vocab_size) + dec_out + - layout = 'NT' + Shape (batch_size, tgt_length, tgt_vocab_size) + - layout = 'TN' + Shape (tgt_length, batch_size, tgt_vocab_size) """ tgt_data = self.tgt_embed_layer(tgt_data) if self.scaled_embed: tgt_data = tgt_data * np.sqrt(self.dec_units) if self.pos_embed_type is not None: - tgt_data = tgt_data + self.tgt_pos_embed_layer( - F.npx.arange_like(tgt_data, axis=1)) + if self.layout == 'NT': + tgt_data = tgt_data + self.tgt_pos_embed_layer( + F.npx.arange_like(tgt_data, axis=1)) + else: + tgt_data = tgt_data + F.np.expand_dims(self.tgt_pos_embed_layer( + F.npx.arange_like(tgt_data, axis=0)), axis=1) dec_out = self.decoder(tgt_data, tgt_valid_length, mem_data, mem_valid_length) dec_out = self.tgt_final_layer(dec_out) return dec_out @@ -1046,19 +1206,28 @@ def hybrid_forward(self, F, src_data, src_valid_length, tgt_data, tgt_valid_leng Parameters ---------- F - src_data : - Shape (batch_size, src_length) - src_valid_length : + src_data + - layout = 'NT' + Shape (batch_size, src_length) + - layout = 'TN' + Shape (src_length, batch_size) + src_valid_length Shape (batch_size,) - tgt_data : - Shape (batch_size, tgt_length) - tgt_valid_length : + tgt_data + - layout = 'NT' + Shape (batch_size, tgt_length) + - layout = 'TN' + Shape (tgt_length, batch_size) + tgt_valid_length Shape (batch_size,) Returns ------- - out : - Shape (batch_size, tgt_length, tgt_vocab_size) + out + - layout = 'NT' + Shape (batch_size, tgt_length, tgt_vocab_size) + - layout = 'TN' + Shape (tgt_length, batch_size, tgt_vocab_size) """ enc_out = self.encode(F, src_data, src_valid_length) dec_out = self.decode_seq(F, tgt_data, tgt_valid_length, enc_out, src_valid_length) @@ -1073,11 +1242,13 @@ def get_cfg(cls, key=None): return transformer_nmt_cfg_reg.create(key) @classmethod - def from_cfg(cls, cfg): + def from_cfg(cls, cfg, dtype=None): cfg = cls.get_cfg().clone_merge(cfg) embed_initializer = mx.init.create(*cfg.INITIALIZER.embed) weight_initializer = mx.init.create(*cfg.INITIALIZER.weight) bias_initializer = mx.init.create(*cfg.INITIALIZER.bias) + if dtype is None: + dtype = cfg.MODEL.dtype return cls(src_vocab_size=cfg.MODEL.src_vocab_size, tgt_vocab_size=cfg.MODEL.tgt_vocab_size, max_src_length=cfg.MODEL.max_src_length, @@ -1103,10 +1274,11 @@ def from_cfg(cls, cfg): dec_recurrent=cfg.MODEL.DECODER.recurrent, dec_activation=cfg.MODEL.DECODER.activation, dec_pre_norm=cfg.MODEL.DECODER.pre_norm, + layout=cfg.MODEL.layout, embed_initializer=embed_initializer, weight_initializer=weight_initializer, bias_initializer=bias_initializer, - dtype=cfg.MODEL.dtype) + dtype=dtype) @use_np @@ -1140,33 +1312,45 @@ def state_batch_axis(self) -> Tuple[int, int, int, List]: position_batch_axis : int dec_layer_batch_axis : list """ - return 0, 0, 0, self.model.decoder.state_batch_axis + if self.model.layout == 'NT': + return 0, 0, 0, self.model.decoder.state_batch_axis + else: + return 1, 0, 0, self.model.decoder.state_batch_axis def init_states(self, src_data, src_valid_length): # TODO(sxjscience) Revisit here, support auxiliary states? 
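[Editorial note, not part of the diff] To make the incremental-decoding contract in this hunk concrete, here is a hedged greedy-decoding sketch against the inference wrapper defined here. It relies only on the `init_states`/step signatures shown in the diff; `inference_model`, `bos_id`, and `max_steps` are placeholders.

```python
# Hypothetical greedy decoding loop; relies only on the init_states() / step
# signatures shown in this hunk. `inference_model` and `bos_id` are placeholders.
import mxnet as mx

def greedy_decode(inference_model, src_tokens, src_valid_length, bos_id, max_steps=20):
    # src_tokens follows inference_model.model.layout:
    #   'NT' -> (batch_size, src_length), 'TN' -> (src_length, batch_size)
    states = inference_model.init_states(src_tokens, src_valid_length)
    batch_size = src_valid_length.shape[0]
    step_input = mx.np.full((batch_size,), bos_id, dtype='int32')
    decoded = []
    for _ in range(max_steps):
        # Step logits: (batch_size, tgt_vocab_size); states carry enc_out, position, caches.
        logits, states = inference_model(step_input, states)
        step_input = mx.np.argmax(logits, axis=-1).astype('int32')
        decoded.append(step_input)
    return mx.np.stack(decoded, axis=1)  # (batch_size, max_steps)
```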
"""Initialize the states required for sequence sampling Parameters ---------- - src_data : - Shape (batch_size, src_length) - src_valid_length : + src_data + - layout = 'NT' + Shape (batch_size, src_length) + - layout = 'TN' + Shape (src_length, batch_size) + src_valid_length Shape (batch_size,) Returns ------- - enc_out : - Shape (batch_size, src_length, C_mem) - src_valid_length : + enc_out + - layout = 'NT' + Shape (batch_size, src_length, C_mem) + - layout = 'TN' + Shape (src_length, batch_size, C_mem) + src_valid_length Shape (batch_size,) - position : + position Shape (batch_size,) dec_states: list The states of the decoder """ - batch_size = src_data.shape[0] + if self.model.layout == 'NT': + batch_size = src_data.shape[0] + else: + batch_size = src_data.shape[1] ctx = src_data.ctx - enc_out = self.model.encode(mx.nd, src_data, src_valid_length) - position = mx.np.zeros((batch_size, 1), dtype=np.int32, ctx=ctx) + enc_out = self.model.encode(mx, src_data, src_valid_length) + position = mx.np.zeros((batch_size,), dtype=np.int32, ctx=ctx) dtype = enc_out.dtype dec_states = self.model.decoder.init_states(batch_size, ctx, dtype) return enc_out, src_valid_length, position, dec_states @@ -1176,24 +1360,29 @@ def hybrid_forward(self, F, step_data, states): Parameters ---------- - step_data : + step_data Shape (batch_size,) - states : tuple + states It includes : - mem_data : (batch_size, src_length, C_mem) - mem_valid_length : (batch_size,) - position : (batch_size,) - dec_states : list + - layout = 'NT' + mem_data : (batch_size, src_length, C_mem) + mem_valid_length : (batch_size,) + position : (batch_size,) + dec_states : list + - layout = 'TN' + mem_data : (src_length, batch_size, C_mem) + mem_valid_length : (batch_size,) + position : (batch_size,) + dec_states : list Returns ------- - out : + out Shape (batch_size, C) - new_states : tuple + new_states Has the same structure as the states """ mem_data, mem_valid_length, position, dec_states = states # 1. 
Get the embedding - step_data = F.np.expand_dims(step_data, axis=1) step_data = self.model.tgt_embed_layer(step_data) if self.model.scaled_embed: step_data = step_data * np.sqrt(self.model.dec_units) @@ -1203,5 +1392,4 @@ def hybrid_forward(self, F, step_data, states): self.model.decoder.incremental_decode(F, step_data, dec_states, mem_data, mem_valid_length) out = self.model.tgt_final_layer(out) - out = F.npx.reshape(out, (-2, -1)) return out, (mem_data, mem_valid_length, position + 1, new_states) diff --git a/src/gluonnlp/models/transformer_xl.py b/src/gluonnlp/models/transformer_xl.py index a232ec8c37..b6ff44c5df 100644 --- a/src/gluonnlp/models/transformer_xl.py +++ b/src/gluonnlp/models/transformer_xl.py @@ -81,6 +81,10 @@ def __init__(self, units: int = 512, pre_norm=pre_norm, dtype=dtype) + @property + def layout(self): + return self._layout + def hybrid_forward(self, F, data, mem, rel_positions, mask, query_r_bias, query_k_bias): """ @@ -118,7 +122,10 @@ def hybrid_forward(self, F, data, mem, rel_positions, mask, query_r_bias, query_ Returns ------- out - Shape (batch_size, query_length, units) + - layout = 'NT' + Shape (batch_size, query_length, units) + - layout = 'TN' + Shape (query_length, batch_size, units) """ if self._layout == 'NT': context = F.np.concatenate([mem, data], axis=1) diff --git a/src/gluonnlp/models/xlmr.py b/src/gluonnlp/models/xlmr.py index b433d34157..66a3784557 100644 --- a/src/gluonnlp/models/xlmr.py +++ b/src/gluonnlp/models/xlmr.py @@ -39,23 +39,6 @@ from ..data.tokenizers import SentencepieceTokenizer -PRETRAINED_URL = { - 'fairseq_xlmr_base': { - 'cfg': 'fairseq_xlmr_base/model-b893d178.yml', - 'sentencepiece.model': 'fairseq_xlmr_base/sentencepiece-18e17bae.model', - 'params': 'fairseq_xlmr_base/model-3fa134e9.params', - 'mlm_params': 'fairseq_xlmr_base/model_mlm-86e37954.params', - 'lowercase': False, - }, - 'fairseq_xlmr_large': { - 'cfg': 'fairseq_xlmr_large/model-01fc59fb.yml', - 'sentencepiece.model': 'fairseq_xlmr_large/sentencepiece-18e17bae.model', - 'params': 'fairseq_xlmr_large/model-b62b074c.params', - 'mlm_params': 'fairseq_xlmr_large/model_mlm-887506c2.params', - 'lowercase': False, - } -} - FILE_STATS = load_checksum_stats(os.path.join(get_model_zoo_checksum_dir(), 'xlmr.txt')) xlmr_cfg_reg = Registry('xlmr_cfg') @@ -86,10 +69,31 @@ def get_cfg(key=None): return xlmr_cfg_reg.create(key) else: return xlmr_base() + + +PRETRAINED_URL = { + 'fairseq_xlmr_base': { + 'cfg': xlmr_base(), + 'sentencepiece.model': 'fairseq_xlmr_base/sentencepiece-18e17bae.model', + 'params': 'fairseq_xlmr_base/model-3fa134e9.params', + 'mlm_params': 'fairseq_xlmr_base/model_mlm-86e37954.params', + 'lowercase': False, + }, + 'fairseq_xlmr_large': { + 'cfg': xlmr_large(), + 'sentencepiece.model': 'fairseq_xlmr_large/sentencepiece-18e17bae.model', + 'params': 'fairseq_xlmr_large/model-b62b074c.params', + 'mlm_params': 'fairseq_xlmr_large/model_mlm-887506c2.params', + 'lowercase': False, + } +} + + @use_np class XLMRForMLM(RobertaForMLM): pass + def list_pretrained_xlmr(): return sorted(list(PRETRAINED_URL.keys())) @@ -98,7 +102,7 @@ def get_pretrained_xlmr(model_name: str = 'fairseq_xlmr_base', root: str = get_model_zoo_home_dir(), load_backbone: bool = True, load_mlm: bool = False) \ - -> Tuple[CN, SentencepieceTokenizer, str]: + -> Tuple[CN, SentencepieceTokenizer, str, str]: """Get the pretrained XLM-R weights Parameters @@ -126,11 +130,18 @@ def get_pretrained_xlmr(model_name: str = 'fairseq_xlmr_base', assert model_name in PRETRAINED_URL, '{} is not found. 
All available are {}'.format( model_name, list_pretrained_xlmr()) cfg_path = PRETRAINED_URL[model_name]['cfg'] + if isinstance(cfg_path, CN): + cfg = cfg_path + else: + cfg = None sp_model_path = PRETRAINED_URL[model_name]['sentencepiece.model'] params_path = PRETRAINED_URL[model_name]['params'] mlm_params_path = PRETRAINED_URL[model_name]['mlm_params'] local_paths = dict() - for k, path in [('cfg', cfg_path), ('sentencepiece.model', sp_model_path)]: + download_jobs = [('sentencepiece.model', sp_model_path)] + if cfg is None: + download_jobs.append(('cfg', cfg_path)) + for k, path in download_jobs: local_paths[k] = download(url=get_repo_model_zoo_url() + path, path=os.path.join(root, path), sha1_hash=FILE_STATS[path]) @@ -152,7 +163,8 @@ def get_pretrained_xlmr(model_name: str = 'fairseq_xlmr_base', tokenizer = SentencepieceTokenizer( model_path=local_paths['sentencepiece.model'], lowercase=do_lower) - cfg = XLMRModel.get_cfg().clone_merge(local_paths['cfg']) + if cfg is None: + cfg = XLMRModel.get_cfg().clone_merge(local_paths['cfg']) return cfg, tokenizer, local_params_path, local_mlm_params_path diff --git a/src/gluonnlp/utils/testing.py b/src/gluonnlp/utils/testing.py index 00b2d4901d..abae1a804e 100644 --- a/src/gluonnlp/utils/testing.py +++ b/src/gluonnlp/utils/testing.py @@ -3,19 +3,56 @@ from mxnet.util import use_np +def is_match_states_batch_size(states, states_batch_axis, batch_size) -> bool: + """Test whether the generated states have the specified batch size + + Parameters + ---------- + states + The states structure + states_batch_axis + The states batch axis structure + batch_size + The batch size + + Returns + ------- + ret + """ + if states_batch_axis is None: + return True + if isinstance(states_batch_axis, int): + if states.shape[states_batch_axis] == batch_size: + return True + for ele_states_batch_axis, ele_states in zip(states_batch_axis, states): + ret = is_match_states_batch_size(ele_states, ele_states_batch_axis, batch_size) + if ret is False: + return False + return True + + @use_np -def verify_nmt_model(model, batch_size=4, src_seq_length=5, tgt_seq_length=10, - atol=1E-5, rtol=1E-5): +def verify_nmt_model(model, batch_size: int = 4, + src_seq_length: int = 5, + tgt_seq_length: int = 10, + atol: float = 1E-4, + rtol: float = 1E-4): """Verify the correctness of an NMT model. Raise error message if it detects problems. Parameters ---------- - model : - batch_size : - src_seq_length : - tgt_seq_length : - atol : - rtol : + model + The machine translation model + batch_size + The batch size to test the nmt model + src_seq_length + Length of the source sequence + tgt_seq_length + Length of the target sequence + atol + Absolute tolerance. + rtol + Relative tolerance. 
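[Editorial note, not part of the diff] For example, one might smoke-test a time-major translation model with this helper as sketched below; only the vocabulary sizes and layout are set explicitly, and the remaining constructor defaults are assumed to exist as in the transformer hunk above.

```python
# Illustrative smoke test of a TN-layout TransformerModel via verify_nmt_model().
# Constructor defaults other than the vocab sizes and layout are assumed.
import mxnet as mx
from gluonnlp.models.transformer import TransformerModel
from gluonnlp.utils.testing import verify_nmt_model

mx.npx.set_np()

model = TransformerModel(src_vocab_size=32, tgt_vocab_size=32, layout='TN')
model.initialize()
model.hybridize()
verify_nmt_model(model)  # raises AssertionError if layout handling breaks the shape contract
```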
""" src_word_sequence = mx.np.random.randint(0, model.src_vocab_size, (batch_size, src_seq_length)) @@ -23,7 +60,13 @@ def verify_nmt_model(model, batch_size=4, src_seq_length=5, tgt_seq_length=10, src_valid_length = mx.np.random.randint(1, src_seq_length, (batch_size,)) min_tgt_seq_length = max(1, tgt_seq_length - 5) tgt_valid_length = mx.np.random.randint(min_tgt_seq_length, tgt_seq_length, (batch_size,)) - full_out = model(src_word_sequence, src_valid_length, tgt_word_sequence, tgt_valid_length) + + if model.layout == 'NT': + full_out = model(src_word_sequence, src_valid_length, tgt_word_sequence, tgt_valid_length) + else: + full_out = model(src_word_sequence.T, src_valid_length, + tgt_word_sequence.T, tgt_valid_length) + full_out = mx.np.swapaxes(full_out, 0, 1) if full_out.shape != (batch_size, tgt_seq_length, model.tgt_vocab_size): raise AssertionError('The output of NMT model does not match the expected output.' ' Model output shape = {}, Expected (B, T, V) = {}' @@ -31,11 +74,19 @@ def verify_nmt_model(model, batch_size=4, src_seq_length=5, tgt_seq_length=10, (batch_size, tgt_seq_length, model.tgt_vocab_size))) for partial_batch_size in range(1, batch_size + 1): for i in range(1, min_tgt_seq_length): - partial_out = model(src_word_sequence[:partial_batch_size, :], - src_valid_length[:partial_batch_size], - tgt_word_sequence[:partial_batch_size, :(-i)], - tgt_valid_length[:partial_batch_size] - - mx.np.array(i, dtype=tgt_valid_length.dtype)) + if model.layout == 'NT': + partial_out = model(src_word_sequence[:partial_batch_size, :], + src_valid_length[:partial_batch_size], + tgt_word_sequence[:partial_batch_size, :(-i)], + tgt_valid_length[:partial_batch_size] + - mx.np.array(i, dtype=tgt_valid_length.dtype)) + else: + partial_out = model(src_word_sequence[:partial_batch_size, :].T, + src_valid_length[:partial_batch_size], + tgt_word_sequence[:partial_batch_size, :(-i)].T, + tgt_valid_length[:partial_batch_size] + - mx.np.array(i, dtype=tgt_valid_length.dtype)) + partial_out = mx.np.swapaxes(partial_out, 0, 1) # Verify that the partial output matches the full output for b in range(partial_batch_size): partial_vl = tgt_valid_length.asnumpy()[b] - i @@ -45,37 +96,66 @@ def verify_nmt_model(model, batch_size=4, src_seq_length=5, tgt_seq_length=10, @use_np def verify_nmt_inference(train_model, inference_model, - batch_size=4, src_seq_length=5, tgt_seq_length=10, atol=1E-5, rtol=1E-5): + batch_size=4, src_seq_length=5, + tgt_seq_length=10, atol=1E-4, rtol=1E-4): """Verify the correctness of an NMT inference model. Raise error message if it detects any problems. 
Parameters ---------- - train_model : - inference_model : - batch_size : - src_seq_length : - tgt_seq_length : - atol : - rtol : + train_model + inference_model + batch_size + src_seq_length + tgt_seq_length + atol + Absolute tolerance + rtol + Relative tolerance """ - src_word_sequences = mx.np.random.randint(0, train_model.src_vocab_size, - (batch_size, src_seq_length)) - tgt_word_sequences = mx.np.random.randint(0, train_model.tgt_vocab_size, - (batch_size, tgt_seq_length)) + if train_model.layout == 'NT': + src_word_sequences = mx.np.random.randint(0, train_model.src_vocab_size, + (batch_size, src_seq_length)) + tgt_word_sequences = mx.np.random.randint(0, train_model.tgt_vocab_size, + (batch_size, tgt_seq_length)) + else: + src_word_sequences = mx.np.random.randint(0, train_model.src_vocab_size, + (src_seq_length, batch_size)) + tgt_word_sequences = mx.np.random.randint(0, train_model.tgt_vocab_size, + (tgt_seq_length, batch_size)) src_valid_length = mx.np.random.randint(1, src_seq_length, (batch_size,)) min_tgt_seq_length = max(1, tgt_seq_length - 5) tgt_valid_length = mx.np.random.randint(min_tgt_seq_length, tgt_seq_length, (batch_size,)) full_out = train_model(src_word_sequences, src_valid_length, tgt_word_sequences, tgt_valid_length) - for partial_batch_size in range(1, batch_size + 1): - step_out_l = [] - states = inference_model.init_states(src_word_sequences[:partial_batch_size, :], - src_valid_length[:partial_batch_size]) - for i in range(min_tgt_seq_length): - step_out, states = inference_model(tgt_word_sequences[:partial_batch_size, i], states) - step_out_l.append(step_out) - partial_out = mx.np.stack(step_out_l, axis=1) - npt.assert_allclose(full_out[:partial_batch_size, :min_tgt_seq_length].asnumpy(), - partial_out[:partial_batch_size, :].asnumpy(), atol, rtol) + if train_model.layout == 'NT': + for partial_batch_size in range(1, batch_size + 1): + step_out_l = [] + states = inference_model.init_states(src_word_sequences[:partial_batch_size, :], + src_valid_length[:partial_batch_size]) + assert is_match_states_batch_size(states, inference_model.state_batch_axis, + partial_batch_size) + for i in range(min_tgt_seq_length): + step_out, states = inference_model(tgt_word_sequences[:partial_batch_size, i], + states) + step_out_l.append(step_out) + partial_out = mx.np.stack(step_out_l, axis=1) + npt.assert_allclose(full_out[:partial_batch_size, :min_tgt_seq_length].asnumpy(), + partial_out[:partial_batch_size, :].asnumpy(), atol, rtol) + elif train_model.layout == 'TN': + for partial_batch_size in range(1, batch_size + 1): + step_out_l = [] + states = inference_model.init_states(src_word_sequences[:, :partial_batch_size], + src_valid_length[:partial_batch_size]) + assert is_match_states_batch_size(states, inference_model.state_batch_axis, + partial_batch_size) + for i in range(min_tgt_seq_length): + step_out, states = inference_model(tgt_word_sequences[i, :partial_batch_size], + states) + step_out_l.append(step_out) + partial_out = mx.np.stack(step_out_l, axis=0) + npt.assert_allclose(full_out[:min_tgt_seq_length, :partial_batch_size].asnumpy(), + partial_out[:, :partial_batch_size].asnumpy(), atol, rtol) + else: + raise NotImplementedError diff --git a/tests/test_attention_cell.py b/tests/test_attention_cell.py index 489f566beb..3b874b0d55 100644 --- a/tests/test_attention_cell.py +++ b/tests/test_attention_cell.py @@ -173,23 +173,27 @@ def test_dot_product_attention(scaled, normalized): @pytest.mark.seed(123) def test_gen_attn_mask(): class GenSelfAttnMask(HybridBlock): - def 
__init__(self, dtype, attn_type): + def __init__(self, dtype, layout, attn_type): super().__init__() self._dtype = dtype + self._layout = layout self._attn_type = attn_type def hybrid_forward(self, F, data, valid_length): return gen_self_attn_mask(F, data, valid_length, - dtype=self._dtype, attn_type=self._attn_type) + dtype=self._dtype, + layout=self._layout, + attn_type=self._attn_type) class GenMemAttnMask(HybridBlock): - def __init__(self, dtype): + def __init__(self, dtype, layout): super().__init__() self._dtype = dtype + self._layout = layout def hybrid_forward(self, F, mem, mem_valid_length, data, valid_length): return gen_mem_attn_mask(F, mem, mem_valid_length, data, valid_length, - dtype=self._dtype) + dtype=self._dtype, layout=self._layout) batch_size = 4 query_length = 8 @@ -203,11 +207,17 @@ def hybrid_forward(self, F, mem, mem_valid_length, data, valid_length): for hybridize in [False, True]: # Test Full Attention Mask - mask_gen = GenSelfAttnMask(dtype=np.float32, attn_type='full') + mask_gen_nt = GenSelfAttnMask(dtype=np.float32, layout='NT', attn_type='full') + mask_gen_tn = GenSelfAttnMask(dtype=np.float32, layout='TN', attn_type='full') if hybridize: - mask_gen.hybridize() - mask = mask_gen(data, valid_length) - mask = mask.asnumpy() + mask_gen_nt.hybridize() + mask_gen_tn.hybridize() + mask_nt = mask_gen_nt(data, valid_length) + mask_nt = mask_nt.asnumpy() + mask_tn = mask_gen_tn(mx.np.swapaxes(data, 0, 1), valid_length) + mask_tn = mask_tn.asnumpy() + mask = mask_nt + assert_allclose(mask_nt, mask_tn) for b in range(batch_size): v_l = valid_length.asnumpy()[b] for i in range(v_l): @@ -217,11 +227,15 @@ def hybrid_forward(self, F, mem, mem_valid_length, data, valid_length): assert (mask[b, i, :] == 0).all() # Test Causal Attention Mask - mask_gen = GenSelfAttnMask(dtype=np.float32, attn_type='causal') + mask_gen_nt = GenSelfAttnMask(dtype=np.float32, layout='NT', attn_type='causal') + mask_gen_tn = GenSelfAttnMask(dtype=np.float32, layout='TN', attn_type='causal') if hybridize: - mask_gen.hybridize() - mask = mask_gen(data, valid_length) - mask = mask.asnumpy() + mask_gen_nt.hybridize() + mask_gen_tn.hybridize() + mask_nt = mask_gen_nt(data, valid_length) + mask_tn = mask_gen_tn(mx.np.swapaxes(data, 0, 1), valid_length) + assert_allclose(mask_nt.asnumpy(), mask_tn.asnumpy()) + mask = mask_nt.asnumpy() for b in range(batch_size): v_l = valid_length.asnumpy()[b] for i in range(v_l): @@ -231,11 +245,16 @@ def hybrid_forward(self, F, mem, mem_valid_length, data, valid_length): assert (mask[b, i, :] == 0).all() # Test Mem Attention Mask - mask_gen = GenMemAttnMask(dtype=np.float32) + mask_gen_nt = GenMemAttnMask(dtype=np.float32, layout='NT') + mask_gen_tn = GenMemAttnMask(dtype=np.float32, layout='TN') if hybridize: - mask_gen.hybridize() - mask = mask_gen(mem, mem_valid_length, data, valid_length) - mask = mask.asnumpy() + mask_gen_nt.hybridize() + mask_gen_tn.hybridize() + mask_nt = mask_gen_nt(mem, mem_valid_length, data, valid_length) + mask_tn = mask_gen_tn(mx.np.swapaxes(mem, 0, 1), mem_valid_length, + mx.np.swapaxes(data, 0, 1), valid_length) + mask = mask_nt.asnumpy() + assert_allclose(mask_nt.asnumpy(), mask_tn.asnumpy()) for b in range(batch_size): data_v_l = valid_length.asnumpy()[b] mem_v_l = mem_valid_length.asnumpy()[b] diff --git a/tests/test_models_albert.py b/tests/test_models_albert.py index 2fd7bbdba5..f428a85569 100644 --- a/tests/test_models_albert.py +++ b/tests/test_models_albert.py @@ -30,17 +30,36 @@ def get_test_cfg(): return cfg -def 
test_albert_backbone(): +@pytest.mark.parametrize('static_alloc,static_shape', [(False, False), + (True, True)]) +@pytest.mark.parametrize('compute_layout', ['auto', 'NT', 'TN']) +def test_albert_backbone(static_alloc, static_shape, compute_layout): batch_size = 3 cfg = get_test_cfg() + cfg.defrost() + cfg.MODEL.compute_layout = compute_layout + cfg.freeze() model = AlbertModel.from_cfg(cfg, use_pooler=True) model.initialize() - model.hybridize(static_alloc=True, static_shape=True) + model.hybridize(static_alloc=static_alloc, static_shape=static_shape) + cfg_tn = cfg.clone() + cfg_tn.defrost() + cfg_tn.MODEL.layout = 'TN' + cfg_tn.freeze() + model_tn = AlbertModel.from_cfg(cfg_tn, use_pooler=True) + model_tn.share_parameters(model.collect_params()) + model_tn.hybridize(static_alloc=static_alloc, static_shape=static_shape) + for seq_length in [64, 96]: valid_length = mx.np.random.randint(seq_length // 2, seq_length, (batch_size,)) inputs = mx.np.random.randint(0, cfg.MODEL.vocab_size, (batch_size, seq_length)) token_types = mx.np.random.randint(0, cfg.MODEL.num_token_types, (batch_size, seq_length)) contextual_embedding, pooled_out = model(inputs, token_types, valid_length) + contextual_embedding_tn, pooled_out_tn = model_tn(inputs.T, token_types.T, valid_length) + # Verify layout + assert_allclose(np.swapaxes(contextual_embedding_tn.asnumpy(), 0, 1), + contextual_embedding.asnumpy(), 1E-4, 1E-4) + assert_allclose(pooled_out_tn.asnumpy(), pooled_out.asnumpy(), 1E-4, 1E-4) assert contextual_embedding.shape == (batch_size, seq_length, cfg.MODEL.units) assert pooled_out.shape == (batch_size, cfg.MODEL.units) # Ensure the embeddings that exceed valid_length are masked @@ -65,35 +84,72 @@ def test_albert_backbone(): assert_allclose(new_pooled_out_np, pooled_out_np, 1E-4, 1E-4) -def test_albert_for_mlm_model(): +@pytest.mark.parametrize('compute_layout', ['auto', 'NT', 'TN']) +def test_albert_for_mlm_model(compute_layout): batch_size = 3 cfg = get_test_cfg() + cfg.defrost() + cfg.MODEL.compute_layout = compute_layout + cfg.freeze() albert_mlm_model = AlbertForMLM(backbone_cfg=cfg) albert_mlm_model.initialize() albert_mlm_model.hybridize() + cfg_tn = cfg.clone() + cfg_tn.defrost() + cfg_tn.MODEL.layout = 'TN' + cfg_tn.freeze() + albert_mlm_tn_model = AlbertForMLM(backbone_cfg=cfg_tn) + albert_mlm_tn_model.share_parameters(albert_mlm_model.collect_params()) + albert_mlm_tn_model.hybridize() + num_mask = 16 seq_length = 64 inputs = mx.np.random.randint(0, cfg.MODEL.vocab_size, (batch_size, seq_length)) token_types = mx.np.random.randint(0, cfg.MODEL.num_token_types, (batch_size, seq_length)) valid_length = mx.np.random.randint(seq_length // 2, seq_length, (batch_size,)) masked_positions = mx.np.random.randint(0, seq_length // 2, (batch_size, num_mask)) - _, _, mlm_scores = albert_mlm_model(inputs, token_types, valid_length, masked_positions) + contextual_embeddings, pooled_out, mlm_scores = albert_mlm_model(inputs, token_types, valid_length, masked_positions) + contextual_embeddings_tn, pooled_out_tn, mlm_scores_tn = albert_mlm_tn_model(inputs.T, token_types.T, valid_length, masked_positions) + assert_allclose(np.swapaxes(contextual_embeddings_tn.asnumpy(), 0, 1), + contextual_embeddings.asnumpy(), 1E-4, 1E-4) + assert_allclose(pooled_out_tn.asnumpy(), pooled_out.asnumpy(), 1E-4, 1E-4) + assert_allclose(mlm_scores_tn.asnumpy(), mlm_scores.asnumpy(), 1E-4, 1E-4) assert mlm_scores.shape == (batch_size, num_mask, cfg.MODEL.vocab_size) -def test_albert_for_pretrain_model(): 
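[Editorial note, not part of the diff] The backbone tests in this patch (ALBERT here, and BERT/ELECTRA/MobileBERT/RoBERTa below) all follow the same NT-vs-TN parity pattern; the sketch below condenses it into one hedged helper, where `SomeBackbone` is a placeholder for any of those model classes.

```python
# Condensed version of the NT-vs-TN parity pattern used by the new tests.
# `SomeBackbone` stands in for AlbertModel, BertModel, ElectraModel, MobileBertModel, ...
# (RoBERTa-style backbones simply drop the token_types argument.)
import numpy as np
from numpy.testing import assert_allclose

def check_nt_tn_parity(SomeBackbone, cfg, inputs, token_types, valid_length):
    model_nt = SomeBackbone.from_cfg(cfg)                  # cfg.MODEL.layout == 'NT'
    model_nt.initialize()
    model_nt.hybridize()

    cfg_tn = cfg.clone()
    cfg_tn.defrost()
    cfg_tn.MODEL.layout = 'TN'
    cfg_tn.freeze()
    model_tn = SomeBackbone.from_cfg(cfg_tn)
    model_tn.share_parameters(model_nt.collect_params())   # same weights, different layout
    model_tn.hybridize()

    emb_nt, pooled_nt = model_nt(inputs, token_types, valid_length)
    emb_tn, pooled_tn = model_tn(inputs.T, token_types.T, valid_length)
    assert_allclose(np.swapaxes(emb_tn.asnumpy(), 0, 1), emb_nt.asnumpy(), 1E-4, 1E-4)
    assert_allclose(pooled_tn.asnumpy(), pooled_nt.asnumpy(), 1E-4, 1E-4)
```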
+@pytest.mark.parametrize('compute_layout', ['auto', 'NT', 'TN']) +def test_albert_for_pretrain_model(compute_layout): batch_size = 3 cfg = get_test_cfg() + cfg.defrost() + cfg.MODEL.compute_layout = compute_layout + cfg.freeze() albert_pretrain_model = AlbertForPretrain(backbone_cfg=cfg) albert_pretrain_model.initialize() albert_pretrain_model.hybridize() + cfg_tn = cfg.clone() + cfg_tn.defrost() + cfg_tn.MODEL.layout = 'TN' + cfg_tn.freeze() + albert_pretrain_model_tn = AlbertForPretrain(backbone_cfg=cfg_tn) + albert_pretrain_model_tn.share_parameters(albert_pretrain_model.collect_params()) + albert_pretrain_model_tn.hybridize() + num_mask = 16 seq_length = 64 inputs = mx.np.random.randint(0, cfg.MODEL.vocab_size, (batch_size, seq_length)) token_types = mx.np.random.randint(0, cfg.MODEL.num_token_types, (batch_size, seq_length)) valid_length = mx.np.random.randint(seq_length // 2, seq_length, (batch_size,)) masked_positions = mx.np.random.randint(0, seq_length // 2, (batch_size, num_mask)) - _, _, sop_score, mlm_scores = albert_pretrain_model(inputs, token_types, valid_length, masked_positions) + contextual_embeddings, pooled_out, sop_score, mlm_scores =\ + albert_pretrain_model(inputs, token_types, valid_length, masked_positions) + contextual_embeddings_tn, pooled_out_tn, sop_score_tn, mlm_scores_tn = \ + albert_pretrain_model_tn(inputs.T, token_types.T, valid_length, masked_positions) + assert_allclose(np.swapaxes(contextual_embeddings_tn.asnumpy(), 0, 1), + contextual_embeddings.asnumpy(), 1E-4, 1E-4) + assert_allclose(pooled_out_tn.asnumpy(), pooled_out.asnumpy(), 1E-4, 1E-4) + assert_allclose(sop_score.asnumpy(), sop_score_tn.asnumpy(), 1E-4, 1E-4) + assert_allclose(mlm_scores.asnumpy(), mlm_scores_tn.asnumpy(), 1E-4, 1E-4) assert mlm_scores.shape == (batch_size, num_mask, cfg.MODEL.vocab_size) assert sop_score.shape == (batch_size, 2) diff --git a/tests/test_models_bert.py b/tests/test_models_bert.py index cb1feedc66..a0d9a8d742 100644 --- a/tests/test_models_bert.py +++ b/tests/test_models_bert.py @@ -1,5 +1,4 @@ import pytest -import numpy as np from numpy.testing import assert_allclose import mxnet as mx import tempfile @@ -12,6 +11,83 @@ def test_list_pretrained_bert(): assert len(list_pretrained_bert()) > 0 +@pytest.mark.parametrize('compute_layout', ['auto', 'NT', 'TN']) +def test_bert_small_cfg(compute_layout): + cfg = BertModel.get_cfg() + cfg.defrost() + cfg.MODEL.vocab_size = 100 + cfg.MODEL.units = 12 * 8 + cfg.MODEL.hidden_size = 64 + cfg.MODEL.num_layers = 2 + cfg.MODEL.num_heads = 2 + cfg.MODEL.compute_layout = compute_layout + cfg.freeze() + + # Generate TN layout + cfg_tn = cfg.clone() + cfg_tn.defrost() + cfg_tn.MODEL.layout = 'TN' + cfg_tn.freeze() + + # Sample data + batch_size = 4 + sequence_length = 16 + num_mask = 3 + inputs = mx.np.random.randint(0, 10, (batch_size, sequence_length)) + token_types = mx.np.random.randint(0, 2, (batch_size, sequence_length)) + valid_length = mx.np.random.randint(3, sequence_length, (batch_size,)) + masked_positions = mx.np.random.randint(0, 3, (batch_size, num_mask)) + + # Test for BertModel + bert_model = BertModel.from_cfg(cfg) + bert_model.initialize() + bert_model.hybridize() + contextual_embedding, pooled_out = bert_model(inputs, token_types, valid_length) + bert_model_tn = BertModel.from_cfg(cfg_tn) + bert_model_tn.share_parameters(bert_model.collect_params()) + bert_model_tn.hybridize() + contextual_embedding_tn, pooled_out_tn = bert_model_tn(inputs.T, token_types.T, valid_length) + 
assert_allclose(contextual_embedding.asnumpy(), + mx.np.swapaxes(contextual_embedding_tn, 0, 1).asnumpy(), + 1E-4, 1E-4) + assert_allclose(pooled_out.asnumpy(), pooled_out_tn.asnumpy(), 1E-4, 1E-4) + + # Test for BertForMLM + bert_mlm_model = BertForMLM(cfg) + bert_mlm_model.initialize() + bert_mlm_model.hybridize() + contextual_embedding, pooled_out, mlm_score = bert_mlm_model(inputs, token_types, + valid_length, masked_positions) + bert_mlm_model_tn = BertForMLM(cfg_tn) + bert_mlm_model_tn.share_parameters(bert_mlm_model.collect_params()) + bert_mlm_model_tn.hybridize() + contextual_embedding_tn, pooled_out_tn, mlm_score_tn =\ + bert_mlm_model_tn(inputs.T, token_types.T, valid_length, masked_positions) + assert_allclose(contextual_embedding.asnumpy(), + mx.np.swapaxes(contextual_embedding_tn, 0, 1).asnumpy(), + 1E-4, 1E-4) + assert_allclose(pooled_out.asnumpy(), pooled_out_tn.asnumpy(), 1E-4, 1E-4) + assert_allclose(mlm_score.asnumpy(), mlm_score_tn.asnumpy(), 1E-4, 1E-4) + + # Test for BertForPretrain + bert_pretrain_model = BertForPretrain(cfg) + bert_pretrain_model.initialize() + bert_pretrain_model.hybridize() + contextual_embedding, pooled_out, nsp_score, mlm_scores =\ + bert_pretrain_model(inputs, token_types, valid_length, masked_positions) + bert_pretrain_model_tn = BertForPretrain(cfg_tn) + bert_pretrain_model_tn.share_parameters(bert_pretrain_model.collect_params()) + bert_pretrain_model_tn.hybridize() + contextual_embedding_tn, pooled_out_tn, nsp_score_tn, mlm_scores_tn = \ + bert_pretrain_model_tn(inputs.T, token_types.T, valid_length, masked_positions) + assert_allclose(contextual_embedding.asnumpy(), + mx.np.swapaxes(contextual_embedding_tn, 0, 1).asnumpy(), + 1E-4, 1E-4) + assert_allclose(pooled_out.asnumpy(), pooled_out_tn.asnumpy(), 1E-4, 1E-4) + assert_allclose(nsp_score.asnumpy(), nsp_score_tn.asnumpy(), 1E-4, 1E-4) + assert_allclose(mlm_score.asnumpy(), mlm_score_tn.asnumpy(), 1E-4, 1E-4) + + @pytest.mark.remote_required @pytest.mark.parametrize('model_name', list_pretrained_bert()) def test_bert_get_pretrained(model_name): diff --git a/tests/test_models_electra.py b/tests/test_models_electra.py index 8866cd7921..17f9420a07 100644 --- a/tests/test_models_electra.py +++ b/tests/test_models_electra.py @@ -3,14 +3,68 @@ from numpy.testing import assert_allclose import mxnet as mx import tempfile -from gluonnlp.models.electra import ElectraModel, ElectraDiscriminator, ElectraGenerator,\ +from gluonnlp.models.electra import ElectraModel, ElectraDiscriminator,\ + ElectraGenerator,\ list_pretrained_electra, get_pretrained_electra, get_generator_cfg mx.npx.set_np() +def test_list_pretrained_electra(): + assert len(list_pretrained_electra()) > 0 + + +def get_test_cfg(): + cfg = ElectraModel.get_cfg() + cfg.defrost() + cfg.MODEL.vocab_size = 100 + cfg.MODEL.units = 12 * 8 + cfg.MODEL.hidden_size = 128 + cfg.MODEL.num_heads = 2 + cfg.MODEL.num_layers = 2 + cfg.freeze() + return cfg + + +@pytest.mark.parametrize('compute_layout', ['auto', 'NT', 'TN']) +def test_electra_model(compute_layout): + cfg = get_test_cfg() + cfg.defrost() + cfg.MODEL.compute_layout = compute_layout + cfg.freeze() + + # Generate TN layout + cfg_tn = cfg.clone() + cfg_tn.defrost() + cfg_tn.MODEL.layout = 'TN' + cfg_tn.freeze() + + # Sample data + batch_size = 4 + sequence_length = 16 + num_mask = 3 + inputs = mx.np.random.randint(0, 10, (batch_size, sequence_length)) + token_types = mx.np.random.randint(0, 2, (batch_size, sequence_length)) + valid_length = mx.np.random.randint(3, sequence_length, 
(batch_size,)) + masked_positions = mx.np.random.randint(0, 3, (batch_size, num_mask)) + + electra_model = ElectraModel.from_cfg(cfg) + electra_model.initialize() + electra_model.hybridize() + contextual_embedding, pooled_out = electra_model(inputs, token_types, valid_length) + electra_model_tn = ElectraModel.from_cfg(cfg_tn) + electra_model_tn.share_parameters(electra_model.collect_params()) + electra_model_tn.hybridize() + contextual_embedding_tn, pooled_out_tn = electra_model_tn(inputs.T, token_types.T, valid_length) + assert_allclose(contextual_embedding.asnumpy(), + np.swapaxes(contextual_embedding_tn.asnumpy(), 0, 1), + 1E-4, 1E-4) + assert_allclose(pooled_out.asnumpy(), pooled_out_tn.asnumpy(), + 1E-4, 1E-4) + + @pytest.mark.remote_required @pytest.mark.parametrize('model_name', list_pretrained_electra()) -def test_bert_get_pretrained(model_name): +def test_electra_get_pretrained(model_name): assert len(list_pretrained_electra()) > 0 with tempfile.TemporaryDirectory() as root: cfg, tokenizer, backbone_params_path, (disc_params_path, gen_params_path) =\ @@ -34,6 +88,5 @@ def test_bert_get_pretrained(model_name): electra_disc_model.backbone_model.token_pos_embed.collect_params(), electra_disc_model.backbone_model.embed_layer_norm.collect_params()) - electra_gen_model = ElectraGenerator(cfg) electra_gen_model.backbone_model.load_parameters(backbone_params_path) diff --git a/tests/test_models_mobilebert.py b/tests/test_models_mobilebert.py index bfd1e3d882..d7f22ac533 100644 --- a/tests/test_models_mobilebert.py +++ b/tests/test_models_mobilebert.py @@ -12,9 +12,85 @@ def test_list_pretrained_mobilebert(): assert len(list_pretrained_mobilebert()) > 0 +@pytest.mark.parametrize('compute_layout', ['auto', 'TN', 'NT']) +def test_mobilebert_model_small_cfg(compute_layout): + cfg = MobileBertModel.get_cfg() + cfg.defrost() + cfg.MODEL.vocab_size = 100 + cfg.MODEL.num_layers = 2 + cfg.MODEL.hidden_size = 128 + cfg.MODEL.num_heads = 2 + cfg.MODEL.compute_layout = compute_layout + cfg.freeze() + + # Generate TN layout + cfg_tn = cfg.clone() + cfg_tn.defrost() + cfg_tn.MODEL.layout = 'TN' + cfg_tn.freeze() + + batch_size = 4 + sequence_length = 16 + num_mask = 3 + inputs = mx.np.random.randint(0, 10, (batch_size, sequence_length)) + token_types = mx.np.random.randint(0, 2, (batch_size, sequence_length)) + valid_length = mx.np.random.randint(3, sequence_length, (batch_size,)) + masked_positions = mx.np.random.randint(0, 3, (batch_size, num_mask)) + + mobile_bert_model = MobileBertModel.from_cfg(cfg) + mobile_bert_model.initialize() + mobile_bert_model.hybridize() + mobile_bert_model_tn = MobileBertModel.from_cfg(cfg_tn) + mobile_bert_model_tn.share_parameters(mobile_bert_model.collect_params()) + mobile_bert_model_tn.hybridize() + contextual_embedding, pooled_out = mobile_bert_model(inputs, token_types, valid_length) + contextual_embedding_tn, pooled_out_tn = mobile_bert_model_tn(inputs.T, + token_types.T, valid_length) + assert_allclose(contextual_embedding.asnumpy(), + np.swapaxes(contextual_embedding_tn.asnumpy(), 0, 1), + 1E-4, 1E-4) + assert_allclose(pooled_out.asnumpy(), pooled_out_tn.asnumpy(), 1E-4, 1E-4) + + # Test for MobileBertForMLM + mobile_bert_mlm_model = MobileBertForMLM(cfg) + mobile_bert_mlm_model.initialize() + mobile_bert_mlm_model.hybridize() + mobile_bert_mlm_model_tn = MobileBertForMLM(cfg_tn) + mobile_bert_mlm_model_tn.share_parameters(mobile_bert_mlm_model.collect_params()) + mobile_bert_model_tn.hybridize() + contextual_embedding, pooled_out, mlm_scores = 
mobile_bert_mlm_model(inputs, token_types,
+                                                                         valid_length,
+                                                                         masked_positions)
+    contextual_embedding_tn, pooled_out_tn, mlm_scores_tn =\
+        mobile_bert_mlm_model_tn(inputs.T, token_types.T, valid_length, masked_positions)
+    assert_allclose(contextual_embedding.asnumpy(),
+                    np.swapaxes(contextual_embedding_tn.asnumpy(), 0, 1),
+                    1E-4, 1E-4)
+    assert_allclose(pooled_out_tn.asnumpy(), pooled_out.asnumpy(), 1E-4, 1E-4)
+    assert_allclose(mlm_scores_tn.asnumpy(), mlm_scores.asnumpy(), 1E-4, 1E-4)
+
+    # Test for MobileBertForPretrain
+    mobile_bert_pretrain_model = MobileBertForPretrain(cfg)
+    mobile_bert_pretrain_model.initialize()
+    mobile_bert_pretrain_model.hybridize()
+    mobile_bert_pretrain_model_tn = MobileBertForPretrain(cfg_tn)
+    mobile_bert_pretrain_model_tn.share_parameters(mobile_bert_pretrain_model.collect_params())
+    mobile_bert_pretrain_model_tn.hybridize()
+    contextual_embedding, pooled_out, nsp_score, mlm_scores =\
+        mobile_bert_pretrain_model(inputs, token_types, valid_length, masked_positions)
+    contextual_embedding_tn, pooled_out_tn, nsp_score_tn, mlm_scores_tn = \
+        mobile_bert_pretrain_model_tn(inputs.T, token_types.T, valid_length, masked_positions)
+    assert_allclose(contextual_embedding.asnumpy(),
+                    np.swapaxes(contextual_embedding_tn.asnumpy(), 0, 1),
+                    1E-4, 1E-4)
+    assert_allclose(pooled_out.asnumpy(), pooled_out_tn.asnumpy(), 1E-4, 1E-4)
+    assert_allclose(nsp_score.asnumpy(), nsp_score_tn.asnumpy(), 1E-4, 1E-4)
+    assert_allclose(mlm_scores.asnumpy(), mlm_scores_tn.asnumpy(), 1E-4, 1E-4)
+
+
 @pytest.mark.remote_required
 @pytest.mark.parametrize('model_name', list_pretrained_mobilebert())
-def test_bert_get_pretrained(model_name):
+def test_mobilebert_get_pretrained(model_name):
     with tempfile.TemporaryDirectory() as root:
         cfg, tokenizer, backbone_params_path, mlm_params_path =\
             get_pretrained_mobilebert(model_name, load_backbone=True, load_mlm=True, root=root)
diff --git a/tests/test_models_roberta.py b/tests/test_models_roberta.py
index 9511c51472..bedf85f027 100644
--- a/tests/test_models_roberta.py
+++ b/tests/test_models_roberta.py
@@ -2,6 +2,7 @@
 import numpy as np
 import mxnet as mx
 import tempfile
+from numpy.testing import assert_allclose
 from gluonnlp.models.roberta import RobertaModel, RobertaForMLM, \
     list_pretrained_roberta, get_pretrained_roberta
 from gluonnlp.loss import LabelSmoothCrossEntropyLoss
@@ -13,6 +14,59 @@ def test_list_pretrained_roberta():
     assert len(list_pretrained_roberta()) > 0
 
 
+@pytest.mark.parametrize('compute_layout', ['auto', 'TN', 'NT'])
+def test_roberta_small_config(compute_layout):
+    cfg = RobertaModel.get_cfg()
+    cfg.defrost()
+    cfg.MODEL.vocab_size = 1000
+    cfg.MODEL.num_layers = 2
+    cfg.MODEL.hidden_size = 128
+    cfg.MODEL.num_heads = 2
+    cfg.MODEL.compute_layout = compute_layout
+    cfg.freeze()
+
+    # Generate TN layout
+    cfg_tn = cfg.clone()
+    cfg_tn.defrost()
+    cfg_tn.MODEL.layout = 'TN'
+    cfg_tn.freeze()
+
+    batch_size = 4
+    sequence_length = 16
+    num_mask = 3
+    inputs = mx.np.random.randint(0, 10, (batch_size, sequence_length))
+    valid_length = mx.np.random.randint(3, sequence_length, (batch_size,))
+    masked_positions = mx.np.random.randint(0, 3, (batch_size, num_mask))
+
+    roberta_model = RobertaModel.from_cfg(cfg)
+    roberta_model.initialize()
+    roberta_model.hybridize()
+    contextual_embeddings, pooled_out = roberta_model(inputs, valid_length)
+    roberta_model_tn = RobertaModel.from_cfg(cfg_tn)
+    roberta_model_tn.share_parameters(roberta_model.collect_params())
+    roberta_model_tn.hybridize()
+    contextual_embeddings_tn, 
pooled_out_tn = roberta_model_tn(inputs.T, valid_length)
+    assert_allclose(np.swapaxes(contextual_embeddings_tn.asnumpy(), 0, 1),
+                    contextual_embeddings.asnumpy(), 1E-4, 1E-4)
+    assert_allclose(pooled_out_tn.asnumpy(), pooled_out.asnumpy(), 1E-4, 1E-4)
+
+    # Test for RobertaForMLM
+    roberta_mlm_model = RobertaForMLM(cfg)
+    roberta_mlm_model.initialize()
+    roberta_mlm_model.hybridize()
+    contextual_embedding, pooled_out, mlm_scores = roberta_mlm_model(inputs, valid_length,
+                                                                     masked_positions)
+    roberta_mlm_model_tn = RobertaForMLM(cfg_tn)
+    roberta_mlm_model_tn.share_parameters(roberta_mlm_model.collect_params())
+    roberta_mlm_model_tn.hybridize()
+    contextual_embedding_tn, pooled_out_tn, mlm_scores_tn =\
+        roberta_mlm_model_tn(inputs.T, valid_length, masked_positions)
+    assert_allclose(np.swapaxes(contextual_embedding_tn.asnumpy(), 0, 1),
+                    contextual_embedding.asnumpy(), 1E-4, 1E-4)
+    assert_allclose(pooled_out_tn.asnumpy(), pooled_out.asnumpy(), 1E-4, 1E-4)
+    assert_allclose(mlm_scores_tn.asnumpy(), mlm_scores.asnumpy(), 1E-4, 1E-4)
+
+
 @pytest.mark.remote_required
 @pytest.mark.parametrize('model_name', list_pretrained_roberta())
 def test_roberta(model_name):
diff --git a/tests/test_models_transformer.py b/tests/test_models_transformer.py
index b1e772ce73..e9b1cd6184 100644
--- a/tests/test_models_transformer.py
+++ b/tests/test_models_transformer.py
@@ -33,6 +33,23 @@ def test_transformer_encoder_decoder(pre_norm, num_enc_layers, num_dec_layers):
     encoded_mem = enc(src_data, src_valid_length)
     full_decode_out = dec(dst_data, dst_valid_length, encoded_mem, src_valid_length)
 
+    # Test for the TN layout
+    enc_tn = TransformerEncoder(units=units, hidden_size=64, num_layers=num_enc_layers, num_heads=4,
+                                dropout=0.0, pre_norm=pre_norm, layout='TN')
+    enc_tn.share_parameters(enc.collect_params())
+    dec_tn = TransformerDecoder(units=units, hidden_size=64, num_layers=num_dec_layers, num_heads=4,
+                                dropout=0.0, pre_norm=pre_norm, layout='TN')
+    dec_tn.share_parameters(dec.collect_params())
+    enc_tn.hybridize()
+    dec_tn.hybridize()
+    encoded_mem_tn = enc_tn(mx.np.swapaxes(src_data, 0, 1), src_valid_length)
+    full_decode_out_tn = dec_tn(mx.np.swapaxes(dst_data, 0, 1), dst_valid_length,
+                                encoded_mem_tn, src_valid_length)
+    assert_allclose(encoded_mem_tn.asnumpy(),
+                    mx.np.swapaxes(encoded_mem, 0, 1).asnumpy(), 1E-5, 1E-5)
+    assert_allclose(full_decode_out_tn.asnumpy(),
+                    mx.np.swapaxes(full_decode_out, 0, 1).asnumpy(), 1E-5, 1E-5)
+
     # Test the consistency via shifting the data and the valid_length
     for i in range(1, dst_valid_length.asnumpy().min()):
         for partial_decode_out in [dec(dst_data[:, :(-i), :],
@@ -52,11 +69,11 @@ def test_transformer_encoder_decoder(pre_norm, num_enc_layers, num_dec_layers):
     states = dec.layers[0].init_states(batch_size, h_out.ctx, h_out.dtype)
     h_out_from_incremental = []
     for i in range(tgt_seq_length):
-        ele_h_out, states = dec.layers[0].incremental_decode(mx, dst_data[:, i:(i + 1), :], states,
+        ele_h_out, states = dec.layers[0].incremental_decode(mx, dst_data[:, i, :], states,
                                                              encoded_mem, src_valid_length,
                                                              enc_mem_attn_mask)
         h_out_from_incremental.append(ele_h_out)
-    h_out_from_incremental = mx.np.concatenate(h_out_from_incremental, axis=1)
+    h_out_from_incremental = mx.np.stack(h_out_from_incremental, axis=1)
 
     for i in range(batch_size):
         val_length = dst_valid_length[i].asnumpy()
@@ -66,10 +83,10 @@ def test_transformer_encoder_decoder(pre_norm, num_enc_layers, num_dec_layers):
     states = dec.init_states(batch_size, src_data.ctx, src_data.dtype)
final_out_from_incremental = [] for i in range(tgt_seq_length): - ele_final_out, states = dec.incremental_decode(mx, dst_data[:, i:(i + 1), :], + ele_final_out, states = dec.incremental_decode(mx, dst_data[:, i, :], states, encoded_mem, src_valid_length) final_out_from_incremental.append(ele_final_out) - final_out_from_incremental = mx.np.concatenate(final_out_from_incremental, axis=1) + final_out_from_incremental = mx.np.stack(final_out_from_incremental, axis=1) for i in range(batch_size): val_length = dst_valid_length[i].asnumpy() assert_allclose(final_out_from_incremental[i, :val_length, :].asnumpy(), @@ -85,12 +102,13 @@ def test_transformer_encoder_decoder(pre_norm, num_enc_layers, num_dec_layers): (2, 3, 16, 24)]) @pytest.mark.parametrize('enc_recurrent', [False, True]) @pytest.mark.parametrize('dec_recurrent', [False, True]) -@pytest.mark.parametrize('tie_weights', [False, True]) +@pytest.mark.parametrize('tie_weights,layout', [(False, 'NT'), (True, 'NT'), (True, 'TN')]) def test_transformer_nmt_model(train_hybridize, inference_hybridize, enc_pre_norm, dec_pre_norm, enc_units, dec_units, enc_num_layers, dec_num_layers, - enc_recurrent, dec_recurrent, tie_weights): + enc_recurrent, dec_recurrent, tie_weights, + layout): src_seq_length = 20 tgt_seq_length = 15 src_vocab_size = 32 @@ -117,7 +135,8 @@ def test_transformer_nmt_model(train_hybridize, inference_hybridize, dec_recurrent=dec_recurrent, shared_embed=shared_embed, tie_weights=tie_weights, - dropout=0.0) + dropout=0.0, + layout=layout) inference_model = TransformerNMTInference(model=model) model.initialize() if train_hybridize: @@ -136,10 +155,16 @@ def test_transformer_cfg_registry(): def test_transformer_cfg(cfg_key): cfg = TransformerNMTModel.get_cfg(cfg_key) cfg.defrost() - cfg.MODEL.src_vocab_size = 1000 - cfg.MODEL.tgt_vocab_size = 1000 + cfg.MODEL.src_vocab_size = 32 + cfg.MODEL.tgt_vocab_size = 32 cfg.freeze() model = TransformerNMTModel.from_cfg(cfg) model.initialize() model.hybridize() + cfg.defrost() + cfg.MODEL.layout = 'TN' + cfg.freeze() + model_tn = TransformerNMTModel.from_cfg(cfg) + model_tn.share_parameters(model.collect_params()) + model_tn.hybridize() mx.npx.waitall() diff --git a/tests/test_models_xlmr.py b/tests/test_models_xlmr.py index f8f9ec76fe..ff9c41fdfd 100644 --- a/tests/test_models_xlmr.py +++ b/tests/test_models_xlmr.py @@ -2,7 +2,7 @@ import numpy as np import mxnet as mx import tempfile -from gluonnlp.models.xlmr import XLMRModel, XLMRForMLM, \ +from gluonnlp.models.xlmr import XLMRModel, \ list_pretrained_xlmr, get_pretrained_xlmr from gluonnlp.loss import LabelSmoothCrossEntropyLoss @@ -29,7 +29,7 @@ def test_xlmr(): # test forward batch_size = 1 - seq_length = 8 + seq_length = 4 vocab_size = len(tokenizer.vocab) input_ids = mx.np.array( np.random.randint(