Skip to content

Commit

Permalink
Delete RoCBertLayerBetterTransformer
Browse files Browse the repository at this point in the history
Signed-off-by: Shogo Hida <shogo.hida@gmail.com>
  • Loading branch information
shogohida committed Dec 5, 2022
1 parent 747d969 commit cd94151
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 110 deletions.
3 changes: 1 addition & 2 deletions optimum/bettertransformer/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
DistilBertLayerBetterTransformer,
FSMTEncoderLayerBetterTransformer,
MBartEncoderLayerBetterTransformer,
RoCBertLayerBetterTransformer,
ViltLayerBetterTransformer,
ViTLayerBetterTransformer,
Wav2Vec2EncoderLayerBetterTransformer,
Expand All @@ -42,7 +41,7 @@
"LayoutLMLayer": BertLayerBetterTransformer,
"BertGenerationLayer": BertLayerBetterTransformer,
"XLMRobertaLayer": BertLayerBetterTransformer,
"RoCBertLayer": RoCBertLayerBetterTransformer,
"RoCBertLayer": BertLayerBetterTransformer,
# Albert Family
"AlbertLayer": AlbertLayerBetterTransformer,
# Bart family
Expand Down
108 changes: 0 additions & 108 deletions optimum/bettertransformer/models/encoder_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,111 +1068,3 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
if hidden_states.is_nested and self.is_last_layer:
hidden_states = hidden_states.to_padded_tensor(0.0)
return (hidden_states, attention_mask)


class RoCBertLayerBetterTransformer(BetterTransformerBaseLayer):
def __init__(self, rocbert_layer, config):
r"""
A simple conversion of the RoCBERT layer to its `BetterTransformer` implementation.
Args:
rocbert_layer (`torch.nn.Module`):
The original RoCBERT Layer where the weights needs to be retrieved.
"""
super().__init__(config)
# In_proj layer
self.in_proj_weight = nn.Parameter(
torch.cat(
[
rocbert_layer.attention.self.query.weight,
rocbert_layer.attention.self.key.weight,
rocbert_layer.attention.self.value.weight,
]
)
)
self.in_proj_bias = nn.Parameter(
torch.cat(
[
rocbert_layer.attention.self.query.bias,
rocbert_layer.attention.self.key.bias,
rocbert_layer.attention.self.value.bias,
]
)
)

# Out proj layer
self.out_proj_weight = rocbert_layer.attention.output.dense.weight
self.out_proj_bias = rocbert_layer.attention.output.dense.bias

# Linear layer 1
self.linear1_weight = rocbert_layer.intermediate.dense.weight
self.linear1_bias = rocbert_layer.intermediate.dense.bias

# Linear layer 2
self.linear2_weight = rocbert_layer.output.dense.weight
self.linear2_bias = rocbert_layer.output.dense.bias

# Layer norm 1
self.norm1_eps = rocbert_layer.attention.output.LayerNorm.eps
self.norm1_weight = rocbert_layer.attention.output.LayerNorm.weight
self.norm1_bias = rocbert_layer.attention.output.LayerNorm.bias

# Layer norm 2
self.norm2_eps = rocbert_layer.output.LayerNorm.eps
self.norm2_weight = rocbert_layer.output.LayerNorm.weight
self.norm2_bias = rocbert_layer.output.LayerNorm.bias

# Model hyper parameters
self.num_heads = rocbert_layer.attention.self.num_attention_heads
self.embed_dim = rocbert_layer.attention.self.all_head_size

# Last step: set the last layer to `False` -> this will be set to `True` when converting the model
self.is_last_layer = False

self.validate_bettertransformer()

def forward(self, hidden_states, attention_mask, *_):
r"""
This is just a wrapper around the forward function proposed in:
https://github.com/huggingface/transformers/pull/19553
"""
super().forward_checker()

if hidden_states.is_nested:
attention_mask = None

if attention_mask is not None:
# attention mask comes in with values 0 and -inf. we convert to torch.nn.TransformerEncoder style bool mask
# 0->false->keep this token -inf->true->mask this token
attention_mask = attention_mask.bool()
attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1]))
seqlen = attention_mask.shape[1]
lengths = torch.sum(~attention_mask, 1)
if not all([l == seqlen for l in lengths]):
hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
attention_mask = None

hidden_states = torch._transformer_encoder_layer_fwd(
hidden_states,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj_weight,
self.out_proj_bias,
self.use_gelu,
self.norm_first,
self.norm1_eps,
self.norm1_weight,
self.norm1_bias,
self.norm2_weight,
self.norm2_bias,
self.linear1_weight,
self.linear1_bias,
self.linear2_weight,
self.linear2_bias,
attention_mask,
)
if hidden_states.is_nested and self.is_last_layer:
hidden_states = hidden_states.to_padded_tensor(0.0)
return (hidden_states,)

0 comments on commit cd94151

Please sign in to comment.