[BetterTransformer] Add MobileBERT support for BetterTransformer #506

Open
wants to merge 10 commits into main
2 changes: 2 additions & 0 deletions optimum/bettertransformer/models/__init__.py
@@ -18,6 +18,7 @@
BartEncoderLayerBetterTransformer,
BertLayerBetterTransformer,
DistilBertLayerBetterTransformer,
MobileBertLayer,
ViTLayerBetterTransformer,
Wav2Vec2EncoderLayerBetterTransformer,
WhisperEncoderLayerBetterTransformer,
@@ -31,6 +32,7 @@
"Data2VecTextLayer": BertLayerBetterTransformer,
"CamembertLayer": BertLayerBetterTransformer,
"MarkupLMLayer": BertLayerBetterTransformer,
"MobileBertLayer": MobileBertLayer,
"RobertaLayer": BertLayerBetterTransformer,
"SplinterLayer": BertLayerBetterTransformer,
"ErnieLayer": BertLayerBetterTransformer,
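For context, once `MobileBertLayer` is registered in this mapping, a MobileBERT checkpoint can be converted through the usual BetterTransformer entry point. A minimal usage sketch (the checkpoint name is only an example):

from transformers import AutoModel
from optimum.bettertransformer import BetterTransformer

# Load a MobileBERT model and swap its encoder layers for the BetterTransformer version.
model = AutoModel.from_pretrained("google/mobilebert-uncased")
model = BetterTransformer.transform(model)
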
108 changes: 108 additions & 0 deletions optimum/bettertransformer/models/encoder_models.py
@@ -751,3 +751,111 @@ def forward(self, hidden_states, attention_mask, **__):
if hidden_states.is_nested and self.is_last_layer:
hidden_states = hidden_states.to_padded_tensor(0.0)
return (hidden_states,)


class MobileBertLayer(BetterTransformerBaseLayer):
def __init__(self, mobilebert_layer, config):
r"""
A simple conversion of the MobileBERT layer to its `BetterTransformer` implementation.

Args:
mobilebert_layer (`torch.nn.Module`):
The original MobileBERT layer whose weights need to be retrieved.
"""
super().__init__(config)
# In_proj layer
self.in_proj_weight = nn.Parameter(
torch.cat(
[
mobilebert_layer.attention.self.query.weight,
mobilebert_layer.attention.self.key.weight,
mobilebert_layer.attention.self.value.weight,
]
)
)
self.in_proj_bias = nn.Parameter(
torch.cat(
[
mobilebert_layer.attention.self.query.bias,
mobilebert_layer.attention.self.key.bias,
mobilebert_layer.attention.self.value.bias,
]
)
)

# Out proj layer
self.out_proj_weight = mobilebert_layer.attention.output.dense.weight
self.out_proj_bias = mobilebert_layer.attention.output.dense.bias

# Linear layer 1
self.linear1_weight = mobilebert_layer.intermediate.dense.weight
self.linear1_bias = mobilebert_layer.intermediate.dense.bias

# Linear layer 2
self.linear2_weight = mobilebert_layer.output.dense.weight
self.linear2_bias = mobilebert_layer.output.dense.bias

# Layer norm 1
self.norm1_eps = mobilebert_layer.attention.output.LayerNorm.eps
self.norm1_weight = mobilebert_layer.attention.output.LayerNorm.weight
self.norm1_bias = mobilebert_layer.attention.output.LayerNorm.bias

# Layer norm 2
self.norm2_eps = mobilebert_layer.output.LayerNorm.eps
self.norm2_weight = mobilebert_layer.output.LayerNorm.weight
self.norm2_bias = mobilebert_layer.output.LayerNorm.bias

# Model hyperparameters
self.num_heads = mobilebert_layer.attention.self.num_attention_heads
self.embed_dim = mobilebert_layer.attention.self.all_head_size

# Last step: set the last layer to `False` -> this will be set to `True` when converting the model
self.is_last_layer = False

self.validate_bettertransformer()

def forward(self, hidden_states, attention_mask, *_):
Member:
What is the purpose of *_ at the end of the arguments?

Author:
I was not sure; I saw *_ in every other forward call definition.

Member:
Alright, maybe we should also add **__ or something to handle potential keyword arguments?
@younesbelkada

Author:
I saw some forward definitions having both *_ and **_. We should possibly add more information about this in the published guide.

Member:
Yes, basically it is there to "handle" potential arguments that were passed to the original forward function, by allowing this forward function to receive them but ignore them right after.

Contributor:
Great remark, the guide will be updated ;)
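To make the point above concrete, a minimal sketch of the pattern (the class name and tensor shapes are placeholders, not the actual optimum layer):

import torch
import torch.nn as nn

class ExampleLayer(nn.Module):
    # *_ collects any extra positional arguments and **__ any extra keyword
    # arguments that callers of the original forward may still pass
    # (e.g. head_mask or output_attentions); they are accepted and then ignored.
    def forward(self, hidden_states, attention_mask, *_, **__):
        return (hidden_states,)

layer = ExampleLayer()
out = layer(torch.zeros(1, 4, 8), None, "extra positional", output_attentions=False)
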

r"""
This is just a wrapper around the forward function proposed in:
https://github.com/huggingface/transformers/pull/19553
"""
super().forward_checker()

if hidden_states.is_nested:
attention_mask = None

if attention_mask is not None:
# attention mask comes in with values 0 and -inf. we convert to torch.nn.TransformerEncoder style bool mask
# 0->false->keep this token -inf->true->mask this token
attention_mask = attention_mask.bool()
attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1]))
seqlen = attention_mask.shape[1]
lengths = torch.sum(~attention_mask, 1)
if not all([l == seqlen for l in lengths]):
hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
attention_mask = None

hidden_states = torch._transformer_encoder_layer_fwd(
hidden_states,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj_weight,
self.out_proj_bias,
self.use_gelu,
self.norm_first,
self.norm1_eps,
self.norm1_weight,
self.norm1_bias,
self.norm2_weight,
self.norm2_bias,
self.linear1_weight,
self.linear1_bias,
self.linear2_weight,
self.linear2_bias,
attention_mask,
)
if hidden_states.is_nested and self.is_last_layer:
hidden_states = hidden_states.to_padded_tensor(0.0)
return (hidden_states,)
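
To make the mask handling above concrete, here is a small standalone sketch of the same conversion with toy values (not taken from the PR):

import torch

# Hugging Face-style additive mask: 0.0 = keep the token, -inf = mask it out.
additive_mask = torch.tensor([[0.0, 0.0, 0.0, 0.0],
                              [0.0, 0.0, 0.0, float("-inf")]])

bool_mask = additive_mask.bool()      # False = keep, True = masked
lengths = torch.sum(~bool_mask, 1)    # tensor([4, 3]): lengths differ, so padding can be dropped
hidden_states = torch.zeros(2, 4, 8)  # (batch, seq_len, hidden_size), toy values

# Same private helper as in the diff above: packs the padded batch into a nested tensor.
nested = torch._nested_tensor_from_mask(hidden_states, ~bool_mask)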