diff --git a/docs/source/bettertransformer/overview.mdx b/docs/source/bettertransformer/overview.mdx
index 4e4fecfae1b..aca1b9797a1 100644
--- a/docs/source/bettertransformer/overview.mdx
+++ b/docs/source/bettertransformer/overview.mdx
@@ -38,12 +38,13 @@ The list of supported model below:
 - [MarkupLM](https://arxiv.org/abs/2110.08518)
 - [RoBERTa](https://arxiv.org/abs/1907.11692)
 - [Splinter](https://arxiv.org/abs/2101.00438)
-- [XLMRoberta](https://arxiv.org/abs/1911.02116)
-- [Whisper](https://cdn.openai.com/papers/whisper.pdf)
+- [ViLT](https://arxiv.org/abs/2102.03334)
 - [ViT](https://arxiv.org/abs/2010.11929)
 - [ViT-MAE](https://arxiv.org/abs/2111.06377)
 - [ViT-MSN](https://arxiv.org/abs/2204.07141)
 - [Wav2Vec2](https://arxiv.org/abs/2006.11477)
+- [Whisper](https://cdn.openai.com/papers/whisper.pdf)
+- [XLMRoberta](https://arxiv.org/abs/1911.02116)
 - [YOLOS](https://arxiv.org/abs/2106.00666)
 
 Let us know by opening an issue in 🤗 Optimum if you want more models to be supported, or check out the contribution guideline if you want to add it by yourself!
diff --git a/optimum/bettertransformer/models/__init__.py b/optimum/bettertransformer/models/__init__.py
index 9dfe8dd863c..448425a268b 100644
--- a/optimum/bettertransformer/models/__init__.py
+++ b/optimum/bettertransformer/models/__init__.py
@@ -18,6 +18,7 @@
     BartEncoderLayerBetterTransformer,
     BertLayerBetterTransformer,
     DistilBertLayerBetterTransformer,
+    ViltLayerBetterTransformer,
     ViTLayerBetterTransformer,
     Wav2Vec2EncoderLayerBetterTransformer,
     WhisperEncoderLayerBetterTransformer,
@@ -65,6 +66,7 @@
     "ViTMAELayer": ViTLayerBetterTransformer,
     "ViTMSNLayer": ViTLayerBetterTransformer,
     "YolosLayer": ViTLayerBetterTransformer,
+    "ViltLayer": ViltLayerBetterTransformer,
 }
 
 
diff --git a/optimum/bettertransformer/models/encoder_models.py b/optimum/bettertransformer/models/encoder_models.py
index c6723d43114..d6a80a4241d 100644
--- a/optimum/bettertransformer/models/encoder_models.py
+++ b/optimum/bettertransformer/models/encoder_models.py
@@ -644,6 +644,102 @@ def forward(self, hidden_states, *_, **__):
         return (hidden_states,)
 
 
+class ViltLayerBetterTransformer(BetterTransformerBaseLayer):
+    def __init__(self, vilt_layer, config):
+        r"""
+        A simple conversion of the ViltLayer to its `BetterTransformer` implementation.
+
+        Args:
+            vilt_layer (`torch.nn.Module`):
+                The original `ViltLayer` where the weights need to be retrieved.
+ """ + super().__init__(config) + # In_proj layer + self.in_proj_weight = nn.Parameter( + torch.cat( + [ + vilt_layer.attention.attention.query.weight, + vilt_layer.attention.attention.key.weight, + vilt_layer.attention.attention.value.weight, + ] + ) + ) + self.in_proj_bias = nn.Parameter( + torch.cat( + [ + vilt_layer.attention.attention.query.bias, + vilt_layer.attention.attention.key.bias, + vilt_layer.attention.attention.value.bias, + ] + ) + ) + + # Out proj layer + self.out_proj_weight = vilt_layer.attention.output.dense.weight + self.out_proj_bias = vilt_layer.attention.output.dense.bias + + # Linear layer 1 + self.linear1_weight = vilt_layer.intermediate.dense.weight + self.linear1_bias = vilt_layer.intermediate.dense.bias + + # Linear layer 2 + self.linear2_weight = vilt_layer.output.dense.weight + self.linear2_bias = vilt_layer.output.dense.bias + + # Layer norm 1 + self.norm1_eps = vilt_layer.layernorm_before.eps + self.norm1_weight = vilt_layer.layernorm_before.weight + self.norm1_bias = vilt_layer.layernorm_before.bias + + # Layer norm 2 + self.norm2_eps = vilt_layer.layernorm_after.eps + self.norm2_weight = vilt_layer.layernorm_after.weight + self.norm2_bias = vilt_layer.layernorm_after.bias + + # Model hyper parameters + self.num_heads = vilt_layer.attention.attention.num_attention_heads + self.embed_dim = int(vilt_layer.attention.attention.attention_head_size * self.num_heads) + + # Last step: set the last layer to `False` -> this will be set to `True` when converting the model + self.is_last_layer = False + self.norm_first = True + + self.validate_bettertransformer() + + def forward(self, hidden_states, *_, **__): + r""" + This is just a wrapper around the forward function proposed in: + https://github.com/huggingface/transformers/pull/19553 + """ + super().forward_checker() + attention_mask = None + + hidden_states = torch._transformer_encoder_layer_fwd( + hidden_states, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj_weight, + self.out_proj_bias, + self.use_gelu, + self.norm_first, + self.norm1_eps, + self.norm1_weight, + self.norm1_bias, + self.norm2_weight, + self.norm2_bias, + self.linear1_weight, + self.linear1_bias, + self.linear2_weight, + self.linear2_bias, + attention_mask, + ) + if hidden_states.is_nested and self.is_last_layer: + hidden_states = hidden_states.to_padded_tensor(0.0) + return (hidden_states,) + + class Wav2Vec2EncoderLayerBetterTransformer(BetterTransformerBaseLayer): def __init__(self, wav2vec2_layer, config): r""" diff --git a/tests/bettertransformer/test_bettertransformer_vision.py b/tests/bettertransformer/test_bettertransformer_vision.py index 7c36cbf8e90..0f860e0a6c9 100644 --- a/tests/bettertransformer/test_bettertransformer_vision.py +++ b/tests/bettertransformer/test_bettertransformer_vision.py @@ -15,7 +15,7 @@ import unittest from PIL import Image -from transformers import AutoFeatureExtractor +from transformers import AutoFeatureExtractor, AutoProcessor import requests from testing_bettertransformer_utils import BetterTransformersTestMixin @@ -30,6 +30,11 @@ ] +ALL_VISION_TEXT_MODELS_TO_TEST = [ + "hf-internal-testing/tiny-vilt-random-vqa", +] + + class BetterTransformersVisionTest(BetterTransformersTestMixin, unittest.TestCase): r""" Testing suite for Vision Models - tests all the tests defined in `BetterTransformersTestMixin` @@ -44,3 +49,20 @@ def prepare_inputs_for_class(self, model_id=None): feature_extractor = 
         feature_extractor = AutoFeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-ViTModel")
         inputs = feature_extractor(images=image, return_tensors="pt")
         return inputs
+
+
+class BetterTransformersViLTTest(BetterTransformersTestMixin, unittest.TestCase):
+    r"""
+    Testing suite for Vision and Text Models - tests all the tests defined in `BetterTransformersTestMixin`
+    """
+    all_models_to_test = ALL_VISION_TEXT_MODELS_TO_TEST
+
+    def prepare_inputs_for_class(self, model_id=None):
+        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        image = Image.open(requests.get(url, stream=True).raw)
+        text = "How many cats are there?"
+
+        # Model takes image and text as input
+        processor = AutoProcessor.from_pretrained(model_id)
+        inputs = processor(image, text, return_tensors="pt")
+        return inputs
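For context, below is a minimal sketch of how the ViLT support added in this diff could be exercised end to end, mirroring `prepare_inputs_for_class` above. The checkpoint is the tiny random test model referenced in the new test class; using `ViltForQuestionAnswering` as the task head and calling `BetterTransformer.transform` directly are assumptions based on the existing Optimum usage pattern, not part of this diff.

```python
import requests
from PIL import Image
from transformers import AutoProcessor, ViltForQuestionAnswering

from optimum.bettertransformer import BetterTransformer

# Tiny random ViLT checkpoint used by the new test class above (assumption: also usable here).
model_id = "hf-internal-testing/tiny-vilt-random-vqa"

processor = AutoProcessor.from_pretrained(model_id)
model = ViltForQuestionAnswering.from_pretrained(model_id)

# Swap each eager ViltLayer for its ViltLayerBetterTransformer counterpart.
model = BetterTransformer.transform(model)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(image, "How many cats are there?", return_tensors="pt")

outputs = model(**inputs)
```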