diff --git a/docs/classes/models/xmod.rst b/docs/classes/models/xmod.rst new file mode 100644 index 0000000000..c42165661f --- /dev/null +++ b/docs/classes/models/xmod.rst @@ -0,0 +1,23 @@ +X-MOD +===== + +.. note:: + The X-MOD implementation integrated into Transformers already supports adapters. + To make this implementation compatible with Adapters, a few changes were necessary: + + - In Adapters, the X-MOD classes rely on the usual adapter methods instead of the custom methods introduced in Transformers, i.e.: + - ``set_active_adapters()`` instead of ``set_default_language()``. + - ``AdapterSetup`` context instead of ``lang_ids`` parameter. + - We provide dedicated model checkpoints converted for usage with Adapters + - e.g. ``facebook/xmod-base`` is available as ``AdapterHub/xmod-base`` with language adapters split into separate repos (e.g. ``AdapterHub/xmod-base-af_ZA``) for on-demand loading. + +The abstract from the paper is the following: + +*Multilingual pre-trained models are known to suffer from the curse of multilinguality, which causes per-language performance to drop as they cover more languages. We address this issue by introducing language-specific modules, which allows us to grow the total capacity of the model, while keeping the total number of trainable parameters per language constant. In contrast with prior work that learns language-specific components post-hoc, we pre-train the modules of our Cross-lingual Modular (X-MOD) models from the start. Our experiments on natural language inference, named entity recognition and question answering show that our approach not only mitigates the negative interference between languages, but also enables positive transfer, resulting in improved monolingual and cross-lingual performance. Furthermore, our approach enables adding languages post-hoc with no measurable drop in performance, no longer limiting the model usage to the set of pre-trained languages.* + +XmodAdapterModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: adapters.XmodAdapterModel + :members: + :inherited-members: XmodPreTrainedModel diff --git a/docs/index.rst b/docs/index.rst index 3928193e5a..8ce684cf2f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -75,6 +75,7 @@ Currently, we support the PyTorch versions of all models as listed on the `Model classes/models/t5 classes/models/vit classes/models/xlmroberta + classes/models/xmod .. toctree:: :maxdepth: 2 diff --git a/docs/model_overview.md b/docs/model_overview.md index 64cd8be818..70f71c264b 100644 --- a/docs/model_overview.md +++ b/docs/model_overview.md @@ -30,6 +30,7 @@ The table below further shows which model architectures support which adaptation | [T5](classes/models/t5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [ViT](classes/models/vit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [XLM-RoBERTa](classes/models/xlmroberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [X-MOD](classes/models/xmod.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | (*) If the used encoder and decoder model class are supported.
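As a usage reference for the documentation note above, here is a minimal sketch of the Adapters-style workflow that replaces `set_default_language()` and the `lang_ids` argument. The Hub identifiers follow the examples in the note; the tokenizer choice, resolution of the repo via `load_adapter()`, and the sample sentence are assumptions, not part of this diff.

```python
from transformers import AutoTokenizer

from adapters import AdapterSetup, XmodAdapterModel

# Converted base checkpoint; language adapters live in separate repos (see note above)
model = XmodAdapterModel.from_pretrained("AdapterHub/xmod-base")
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")  # assumed tokenizer

# Load one language adapter on demand, e.g. Afrikaans
lang_adapter = model.load_adapter("AdapterHub/xmod-base-af_ZA")

# Instead of set_default_language(): activate the language adapter globally ...
model.set_active_adapters(lang_adapter)

# ... or, instead of passing lang_ids, only for a single forward pass
inputs = tokenizer("Dit is 'n toets.", return_tensors="pt")
with AdapterSetup(lang_adapter):
    outputs = model(**inputs)
```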
diff --git a/src/adapters/__init__.py b/src/adapters/__init__.py index 9cb990f388..8cec0ee232 100644 --- a/src/adapters/__init__.py +++ b/src/adapters/__init__.py @@ -107,6 +107,7 @@ "models.t5": ["T5AdapterModel"], "models.vit": ["ViTAdapterModel"], "models.xlm_roberta": ["XLMRobertaAdapterModel"], + "models.xmod": ["XmodAdapterModel"], "trainer": ["AdapterTrainer", "Seq2SeqAdapterTrainer"], "training": [ "AdapterArguments", @@ -206,6 +207,7 @@ from .models.t5 import T5AdapterModel from .models.vit import ViTAdapterModel from .models.xlm_roberta import XLMRobertaAdapterModel + from .models.xmod import XmodAdapterModel from .trainer import AdapterTrainer, Seq2SeqAdapterTrainer from .training import AdapterArguments, setup_adapter_training from .utils import ( diff --git a/src/adapters/composition.py b/src/adapters/composition.py index 25e1b2b702..e3ee04925f 100644 --- a/src/adapters/composition.py +++ b/src/adapters/composition.py @@ -135,6 +135,7 @@ def __init__( "xlm-roberta", "bert-generation", "llama", + "xmod", ], } diff --git a/src/adapters/configuration/adapter_config.py b/src/adapters/configuration/adapter_config.py index 9d207cf249..3ae6ac3be1 100644 --- a/src/adapters/configuration/adapter_config.py +++ b/src/adapters/configuration/adapter_config.py @@ -162,9 +162,10 @@ class BnConfig(AdapterConfigBase): use_gating (:obj:`bool`, optional): Place a trainable gating module besides the added parameter module to control module activation. This is e.g. used for UniPELT. Defaults to False. - residual_before_ln (:obj:`bool`, optional): - If True, take the residual connection around the adapter bottleneck before the layer normalization. Only - applicable if :obj:`original_ln_before` is True. + residual_before_ln (:obj:`bool` or :obj:`str`, optional): + If True, take the residual connection around the adapter bottleneck before the layer normalization. + If set to "post_add", take the residual connection around the adapter bottleneck after the previous residual connection. + Only applicable if :obj:`original_ln_before` is True. adapter_residual_before_ln (:obj:`bool`, optional): If True, apply the residual connection around the adapter modules before the new layer normalization within the adapter. Only applicable if :obj:`ln_after` is True and :obj:`is_parallel` is False. 
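The new `"post_add"` value documented above matches the X-MOD adapter placement: the adapter's residual is taken right after the preceding residual addition, before the original layer norm is applied. A minimal sketch of a bottleneck config using it follows; the parameter values are illustrative placeholders, only the flag combination mirrors what the conversion script at the end of this diff derives from an X-MOD checkpoint config.

```python
from adapters import SeqBnConfig

# Illustrative X-MOD-style bottleneck config (values not taken from this diff):
# the original add + LayerNorm runs before the adapter, and the adapter's
# residual is captured "post_add", i.e. after the residual addition but
# before that LayerNorm.
xmod_style = SeqBnConfig(
    original_ln_before=True,
    original_ln_after=False,
    residual_before_ln="post_add",
    reduction_factor=2,     # placeholder bottleneck size
    non_linearity="gelu",   # placeholder activation
)
```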
@@ -225,7 +226,7 @@ class BnConfig(AdapterConfigBase): is_parallel: bool = False scaling: Union[float, str] = 1.0 use_gating: bool = False - residual_before_ln: bool = True + residual_before_ln: Union[bool, str] = True adapter_residual_before_ln: bool = False inv_adapter: Optional[str] = None inv_adapter_reduction_factor: Optional[float] = None @@ -267,7 +268,7 @@ class SeqBnConfig(BnConfig): original_ln_before: bool = True original_ln_after: bool = True - residual_before_ln: bool = True + residual_before_ln: Union[bool, str] = True adapter_residual_before_ln: bool = False ln_before: bool = False ln_after: bool = False @@ -306,7 +307,7 @@ class DoubleSeqBnConfig(BnConfig): original_ln_before: bool = False original_ln_after: bool = True - residual_before_ln: bool = True + residual_before_ln: Union[bool, str] = True adapter_residual_before_ln: bool = False ln_before: bool = False ln_after: bool = False diff --git a/src/adapters/head_utils.py b/src/adapters/head_utils.py index 8f6dab42b2..d1425f25d6 100644 --- a/src/adapters/head_utils.py +++ b/src/adapters/head_utils.py @@ -256,6 +256,61 @@ }, "layers": ["lm_head.dense", None, "lm_head.layer_norm", "lm_head.decoder"], }, + # Xmod + "XmodForSequenceClassification": { + "config": { + "head_type": "classification", + "layers": 2, + "activation_function": "tanh", + "use_pooler": False, + }, + "layers": [None, "classifier.dense", None, None, "classifier.out_proj"], + }, + "XmodForMultipleChoice": { + "config": { + "head_type": "multiple_choice", + "layers": 1, + "activation_function": None, + "use_pooler": True, + }, + "layers": [None, "classifier"], + }, + "XmodForTokenClassification": { + "config": { + "head_type": "tagging", + "layers": 1, + "activation_function": None, + }, + "layers": [None, "classifier"], + }, + "XmodForQuestionAnswering": { + "config": { + "head_type": "question_answering", + "layers": 1, + "activation_function": None, + }, + "layers": [None, "qa_outputs"], + }, + "XmodForMaskedLM": { + "config": { + "head_type": "masked_lm", + "layers": 2, + "activation_function": "gelu", + "layer_norm": True, + "bias": True, + }, + "layers": ["lm_head.dense", None, "lm_head.layer_norm", "lm_head.decoder"], + }, + "XmodForCausalLM": { + "config": { + "head_type": "causal_lm", + "layers": 2, + "activation_function": "gelu", + "layer_norm": True, + "bias": True, + }, + "layers": ["lm_head.dense", None, "lm_head.layer_norm", "lm_head.decoder"], + }, # BART "BartForSequenceClassification": { "config": { diff --git a/src/adapters/layer.py b/src/adapters/layer.py index 02b81d3781..32b57915a5 100644 --- a/src/adapters/layer.py +++ b/src/adapters/layer.py @@ -227,7 +227,13 @@ def enable_adapters(self, adapter_setup: AdapterCompositionBlock, unfreeze_adapt for param in self.adapter_fusion_layer[sub_setup.name].parameters(): param.requires_grad = True - def get_adapter(self, adapter_name): + def freeze_adapter(self, adapter_name: str, freeze: bool = True): + if adapter_name in self.adapters: + self.adapters[adapter_name].train(not freeze) + for param in self.adapters[adapter_name].parameters(): + param.requires_grad = not freeze + + def get_adapter(self, adapter_name: str): if adapter_name in self.adapters: return self.adapters[adapter_name] else: diff --git a/src/adapters/lora.py b/src/adapters/lora.py index 643e579c47..3549e7a8fc 100644 --- a/src/adapters/lora.py +++ b/src/adapters/lora.py @@ -173,6 +173,12 @@ def enable_adapters(self, adapter_setup: AdapterCompositionBlock, unfreeze_adapt for param in self.loras[name].parameters(): 
param.requires_grad = True + def freeze_adapter(self, adapter_name: str, freeze: bool = True): + if adapter_name in self.loras: + self.loras[adapter_name].train(not freeze) + for param in self.loras[adapter_name].parameters(): + param.requires_grad = not freeze + def get_adapter(self, adapter_name: str) -> nn.Module: if adapter_name in self.loras: return self.loras[adapter_name] diff --git a/src/adapters/modeling.py b/src/adapters/modeling.py index 3f7b3a03ec..b61419069e 100644 --- a/src/adapters/modeling.py +++ b/src/adapters/modeling.py @@ -145,7 +145,7 @@ def pre_forward( """ query = None - if self.residual_before_ln: + if self.residual_before_ln is True: residual = hidden_states if fusion_config is not None and fusion_config["query_before_ln"]: @@ -153,7 +153,10 @@ def pre_forward( if self.original_ln_before: if layer_norm: - hidden_states = layer_norm(hidden_states + input_tensor) + hidden_states = hidden_states + input_tensor + if self.residual_before_ln == "post_add": + residual = hidden_states + hidden_states = layer_norm(hidden_states) else: hidden_states = hidden_states + input_tensor diff --git a/src/adapters/models/__init__.py b/src/adapters/models/__init__.py index b62291b98f..4df8e98f32 100644 --- a/src/adapters/models/__init__.py +++ b/src/adapters/models/__init__.py @@ -19,6 +19,7 @@ from .llama.mixin_llama import LlamaModelAdapterMixin from .t5.mixin_t5 import T5BlockAdaptersMixin, T5ModelAdaptersMixin, T5ModelAdaptersWithHeadsMixin from .vit.mixin_vit import ViTIntermediateAdaptersMixin, ViTModelAdaptersMixin +from .xmod.mixin_xmod import XmodModelAdaptersMixin # IMPORTANT: Only add classes to this mapping that are not copied into the adapters package @@ -58,6 +59,8 @@ "ViTModel": ViTModelAdaptersMixin, "XLMRobertaLayer": BertLayerAdaptersMixin, "XLMRobertaModel": BertModelAdaptersMixin, + "XmodLayer": BertLayerAdaptersMixin, + "XmodModel": XmodModelAdaptersMixin, "DebertaModel": BertModelAdaptersMixin, "DebertaLayer": BertLayerAdaptersMixin, "DebertaV2Model": BertModelAdaptersMixin, diff --git a/src/adapters/models/auto/adapter_model.py b/src/adapters/models/auto/adapter_model.py index 40abf337c2..75a154c6c4 100644 --- a/src/adapters/models/auto/adapter_model.py +++ b/src/adapters/models/auto/adapter_model.py @@ -26,6 +26,7 @@ ("t5", "T5AdapterModel"), ("vit", "ViTAdapterModel"), ("xlm-roberta", "XLMRobertaAdapterModel"), + ("xmod", "XmodAdapterModel"), ] ) diff --git a/src/adapters/models/xmod/__init__.py b/src/adapters/models/xmod/__init__.py new file mode 100644 index 0000000000..7140f6f465 --- /dev/null +++ b/src/adapters/models/xmod/__init__.py @@ -0,0 +1,39 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2023 The Adapter-Hub Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import TYPE_CHECKING + +from transformers.utils import _LazyModule + + +_import_structure = { + "adapter_model": ["XmodAdapterModel"], +} + + +if TYPE_CHECKING: + from .adapter_model import XmodAdapterModel + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + ) diff --git a/src/adapters/models/xmod/adapter_model.py b/src/adapters/models/xmod/adapter_model.py new file mode 100644 index 0000000000..31ca7acd3b --- /dev/null +++ b/src/adapters/models/xmod/adapter_model.py @@ -0,0 +1,256 @@ +from typing import Optional + +import torch + +from transformers.models.xmod.modeling_xmod import ( + XMOD_INPUTS_DOCSTRING, + XMOD_START_DOCSTRING, + XmodModel, + XmodPreTrainedModel, +) +from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward + +from ...context import AdapterSetup +from ...heads import ( + BertStyleMaskedLMHead, + BiaffineParsingHead, + CausalLMHead, + ClassificationHead, + ModelWithFlexibleHeadsAdaptersMixin, + MultiLabelClassificationHead, + MultipleChoiceHead, + QuestionAnsweringHead, + TaggingHead, +) +from ...model_mixin import EmbeddingAdaptersWrapperMixin +from ...wrappers import init + + +@add_start_docstrings( + """X-MOD Model transformer with the option to add multiple flexible heads on top.""", + XMOD_START_DOCSTRING, +) +class XmodAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, XmodPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.roberta = XmodModel(config) + init(self.roberta) + + self._init_head_modules() + + self.init_weights() + + @add_start_docstrings_to_model_forward(XMOD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + lang_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + head: Optional[str] = None, + output_adapter_gating_scores: Optional[bool] = False, + output_adapter_fusion_attentions: Optional[bool] = False, + **kwargs + ): + # Flatten for multiple choice tasks + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + lang_ids = lang_ids.repeat(input_ids.size(0) * input_ids.size(1)) if lang_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta( + input_ids, + lang_ids=lang_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + 
output_adapter_gating_scores=output_adapter_gating_scores, + output_adapter_fusion_attentions=output_adapter_fusion_attentions, + adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + ) + # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads + if not return_dict: + head_inputs = (outputs[0],) + outputs[2:] + else: + head_inputs = outputs + pooled_output = outputs[1] + + if head or AdapterSetup.get_context_head_setup() or self.active_head: + head_outputs = self.forward_head( + head_inputs, + head_name=head, + attention_mask=attention_mask, + return_dict=return_dict, + pooled_output=pooled_output, + **kwargs, + ) + return head_outputs + else: + # in case no head is used just return the output of the base model (including pooler output) + return outputs + + # Copied from RobertaForCausalLM + def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "past_key_values": past, + "adapter_input_parallelized": model_kwargs.pop("adapter_input_parallelized", False), + } + + head_types = { + "classification": ClassificationHead, + "multilabel_classification": MultiLabelClassificationHead, + "tagging": TaggingHead, + "multiple_choice": MultipleChoiceHead, + "question_answering": QuestionAnsweringHead, + "dependency_parsing": BiaffineParsingHead, + "masked_lm": BertStyleMaskedLMHead, + "causal_lm": CausalLMHead, + } + + def add_classification_head( + self, + head_name, + num_labels=2, + layers=2, + activation_function="tanh", + overwrite_ok=False, + multilabel=False, + id2label=None, + use_pooler=False, + ): + """ + Adds a sequence classification head on top of the model. + + Args: + head_name (str): The name of the head. + num_labels (int, optional): Number of classification labels. Defaults to 2. + layers (int, optional): Number of layers. Defaults to 2. + activation_function (str, optional): Activation function. Defaults to 'tanh'. + overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. + multilabel (bool, optional): Enable multilabel classification setup. Defaults to False. + """ + + if multilabel: + head = MultiLabelClassificationHead( + self, head_name, num_labels, layers, activation_function, id2label, use_pooler + ) + else: + head = ClassificationHead(self, head_name, num_labels, layers, activation_function, id2label, use_pooler) + self.add_prediction_head(head, overwrite_ok) + + def add_multiple_choice_head( + self, + head_name, + num_choices=2, + layers=2, + activation_function="tanh", + overwrite_ok=False, + id2label=None, + use_pooler=False, + ): + """ + Adds a multiple choice head on top of the model. + + Args: + head_name (str): The name of the head. + num_choices (int, optional): Number of choices. Defaults to 2. + layers (int, optional): Number of layers. Defaults to 2. + activation_function (str, optional): Activation function. Defaults to 'tanh'. + overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. 
+ """ + head = MultipleChoiceHead(self, head_name, num_choices, layers, activation_function, id2label, use_pooler) + self.add_prediction_head(head, overwrite_ok) + + def add_tagging_head( + self, head_name, num_labels=2, layers=1, activation_function="tanh", overwrite_ok=False, id2label=None + ): + """ + Adds a token classification head on top of the model. + + Args: + head_name (str): The name of the head. + num_labels (int, optional): Number of classification labels. Defaults to 2. + layers (int, optional): Number of layers. Defaults to 1. + activation_function (str, optional): Activation function. Defaults to 'tanh'. + overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. + """ + head = TaggingHead(self, head_name, num_labels, layers, activation_function, id2label) + self.add_prediction_head(head, overwrite_ok) + + def add_qa_head( + self, head_name, num_labels=2, layers=1, activation_function="tanh", overwrite_ok=False, id2label=None + ): + head = QuestionAnsweringHead(self, head_name, num_labels, layers, activation_function, id2label) + self.add_prediction_head(head, overwrite_ok) + + def add_dependency_parsing_head(self, head_name, num_labels=2, overwrite_ok=False, id2label=None): + """ + Adds a biaffine dependency parsing head on top of the model. The parsing head uses the architecture described + in "Is Supervised Syntactic Parsing Beneficial for Language Understanding? An Empirical Investigation" (Glavaš + & Vulić, 2021) (https://arxiv.org/pdf/2008.06788.pdf). + + Args: + head_name (str): The name of the head. + num_labels (int, optional): Number of labels. Defaults to 2. + overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. + id2label (dict, optional): Mapping from label ids to labels. Defaults to None. + """ + head = BiaffineParsingHead(self, head_name, num_labels, id2label) + self.add_prediction_head(head, overwrite_ok) + + def add_masked_lm_head(self, head_name, activation_function="gelu", overwrite_ok=False): + """ + Adds a masked language modeling head on top of the model. + + Args: + head_name (str): The name of the head. + activation_function (str, optional): Activation function. Defaults to 'gelu'. + overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. + """ + head = BertStyleMaskedLMHead(self, head_name, activation_function=activation_function) + self.add_prediction_head(head, overwrite_ok=overwrite_ok) + + def add_causal_lm_head(self, head_name, activation_function="gelu", overwrite_ok=False): + """ + Adds a causal language modeling head on top of the model. + + Args: + head_name (str): The name of the head. + activation_function (str, optional): Activation function. Defaults to 'gelu'. + overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. 
+ """ + head = CausalLMHead( + self, head_name, layers=2, activation_function=activation_function, layer_norm=True, bias=True + ) + self.add_prediction_head(head, overwrite_ok=overwrite_ok) diff --git a/src/adapters/models/xmod/mixin_xmod.py b/src/adapters/models/xmod/mixin_xmod.py new file mode 100644 index 0000000000..4522f133de --- /dev/null +++ b/src/adapters/models/xmod/mixin_xmod.py @@ -0,0 +1,68 @@ +from typing import Callable, Iterable, Tuple + +import torch.nn as nn + +from transformers.utils import logging + +from ...composition import adjust_tensors_for_parallel_ +from ...model_mixin import EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin +from ...context import ForwardContext + + +logger = logging.get_logger(__name__) + + +class XmodModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelBaseAdaptersMixin): + """Adds adapters to the BertModel module.""" + + def init_adapters(self, model_config, adapters_config): + super().init_adapters(model_config, adapters_config) + + # Set hook for parallel composition + for _, layer in self.iter_layers(): + self._set_layer_hook_for_parallel(layer) + + # Delete original adapter modules + for _, layer in self.iter_layers(): + del layer.output.adapter_modules + + def _set_layer_hook_for_parallel(self, layer: nn.Module): + def hook(module, input): + # hook[1] is lang_ids tensor + adjust_tensors_for_parallel_(input[0], input[2]) + return input + + layer.register_forward_pre_hook(hook) + + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: + for i, layer in enumerate(self.encoder.layer): + yield i, layer + + def hook_after_embeddings(self, hook_fn: Callable): + return self.embeddings.register_forward_hook(hook_fn) + + @ForwardContext.wrap + def forward(self, *args, **kwargs): + if "lang_ids" in kwargs and kwargs["lang_ids"] is not None: + raise ValueError("XmodModel with adapters does not support `lang_ids` as an argument. Use `set_active_adapters` instead.") + else: + kwargs["lang_ids"] = 1 + return super().forward(*args, **kwargs) + + # Override adapter-specific methods in original implementation + + def set_default_language(self, language: str): + raise ValueError("`set_default_language` is not implemented for models using `adapters`. Use `set_active_adapters` instead.") + + def freeze_embeddings_and_language_adapters(self): + """ + Freeze the embeddings and language adapters of the model. Usually, this is applied before the model is + fine-tuned on a downstream task. + """ + # TODO: Replace this by a general method for `adapters`. + logger.info("Freezing embeddings") + for parameter in self.base_model.embeddings.parameters(): + parameter.requires_grad = False + logger.info("Freezing adapters") + for adapter_name in self.adapters_config: + self.apply_to_adapter_layers(lambda i, layer: layer.freeze_adapter(adapter_name)) diff --git a/src/adapters/models/xmod/modeling_xmod.py b/src/adapters/models/xmod/modeling_xmod.py new file mode 100644 index 0000000000..bd3d9a5ad0 --- /dev/null +++ b/src/adapters/models/xmod/modeling_xmod.py @@ -0,0 +1,168 @@ +# coding=utf-8 +# Copyright 2023 Meta AI Team and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch X-MOD model.""" + +import math +from typing import Optional, Tuple + +import torch +import torch.utils.checkpoint +from torch import nn + +from transformers.models.xmod.modeling_xmod import XmodOutput, XmodSelfAttention, XmodSelfOutput + +from ...composition import adjust_tensors_for_parallel +from ..bert.mixin_bert import BertOutputAdaptersMixin, BertSelfAttentionAdaptersMixin, BertSelfOutputAdaptersMixin + + +# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->Xmod +class XmodSelfAttentionWithAdapters(BertSelfAttentionAdaptersMixin, XmodSelfAttention): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + use_cache = past_key_value is not None + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. 
Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + key_layer, value_layer, attention_mask = self.prefix_tuning( + key_layer, value_layer, hidden_states, attention_mask + ) + (query_layer,) = adjust_tensors_for_parallel(key_layer, query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if use_cache: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in XmodModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +class XmodSelfOutputWithAdapters(BertSelfOutputAdaptersMixin, XmodSelfOutput): + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, None) + return hidden_states + + +class XmodOutputWithAdapters(BertOutputAdaptersMixin, XmodOutput): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + if config.adapter_layer_norm: + self.adapter_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + else: + self.adapter_layer_norm = None + self.adapter_reuse_layer_norm = config.adapter_reuse_layer_norm + # Other adapter-specific modules of original module are not created here + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, lang_ids: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + if self.adapter_layer_norm is not None: + layer_norm = self.adapter_layer_norm + elif self.adapter_reuse_layer_norm: + layer_norm = self.LayerNorm + hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, layer_norm) + return hidden_states diff --git a/src/adapters/prefix_tuning.py b/src/adapters/prefix_tuning.py index 1e7e0cf179..af9c57e03f 100644 --- a/src/adapters/prefix_tuning.py +++ b/src/adapters/prefix_tuning.py @@ -350,6 +350,15 @@ def enable_adapters(self, adapter_setup: AdapterCompositionBlock, unfreeze_adapt for param in self.prefix_gates[prefix_tuning_name].parameters(): param.requires_grad = unfreeze_adapters + def freeze_adapter(self, adapter_name: str, freeze: bool = True): + if adapter_name in self.prefixes: + self.pool.get_prefix(adapter_name)[self.location_key].train(not freeze) + for param in self.pool.get_prefix(adapter_name)[self.location_key].parameters(): + param.requires_grad = not freeze + if adapter_name in self.prefix_gates: + for param in self.prefix_gates[adapter_name].parameters(): + param.requires_grad = not freeze + def get_adapter(self, adapter_name): return_dict = nn.ModuleDict() # Make sure to only return params once diff --git a/tests_adapters/models/test_xmod.py b/tests_adapters/models/test_xmod.py new file mode 100644 index 0000000000..fb4c95fdeb --- /dev/null +++ b/tests_adapters/models/test_xmod.py @@ -0,0 +1,12 @@ +# flake8: noqa: F403,F405 +from adapters import XmodAdapterModel +from tests.models.xmod.test_modeling_xmod import * +from transformers.testing_utils import require_torch + +from .base import AdapterModelTesterMixin + + +@require_torch +class XmodAdapterModelTest(AdapterModelTesterMixin, XmodModelTest): + all_model_classes = (XmodAdapterModel,) + 
fx_compatible = False diff --git a/tests_adapters/test_adapter_conversion.py b/tests_adapters/test_adapter_conversion.py index 6368cc14fc..ac57daa315 100644 --- a/tests_adapters/test_adapter_conversion.py +++ b/tests_adapters/test_adapter_conversion.py @@ -19,6 +19,7 @@ BertPreTrainedModel, RobertaPreTrainedModel, XLMRobertaPreTrainedModel, + XmodPreTrainedModel, ) from transformers.testing_utils import require_torch, torch_device @@ -56,7 +57,11 @@ def run_test(self, static_model, input_shape=None, label_dict=None): # HACK for bert-based models if isinstance(static_model, BertPreTrainedModel): unexpected_keys = [k for k in unexpected_keys if "cls.predictions.bias" not in k] - elif isinstance(static_model, RobertaPreTrainedModel) or isinstance(static_model, XLMRobertaPreTrainedModel): + elif ( + isinstance(static_model, RobertaPreTrainedModel) + or isinstance(static_model, XLMRobertaPreTrainedModel) + or isinstance(static_model, XmodPreTrainedModel) + ): unexpected_keys = [k for k in unexpected_keys if "lm_head.bias" not in k] elif isinstance(static_model, AlbertPreTrainedModel): unexpected_keys = [k for k in unexpected_keys if "predictions.bias" not in k] diff --git a/tests_adapters/test_xmod.py b/tests_adapters/test_xmod.py new file mode 100644 index 0000000000..450c84231d --- /dev/null +++ b/tests_adapters/test_xmod.py @@ -0,0 +1,61 @@ +import unittest + +from transformers import XmodConfig +from transformers.testing_utils import require_torch + +from .composition.test_parallel import ParallelAdapterInferenceTestMixin +from .methods import ( + BottleneckAdapterTestMixin, + CompacterTestMixin, + IA3TestMixin, + LoRATestMixin, + PrefixTuningTestMixin, + UniPELTTestMixin, +) +from .test_adapter import AdapterTestBase, make_config +from .test_adapter_backward_compability import CompabilityTestMixin +from .test_adapter_conversion import ModelClassConversionTestMixin +from .test_adapter_fusion_common import AdapterFusionModelTestMixin +from .test_adapter_heads import PredictionHeadModelTestMixin + + +class XmodAdapterTestBase(AdapterTestBase): + config_class = XmodConfig + config = make_config( + XmodConfig, + hidden_size=32, + num_hidden_layers=4, + num_attention_heads=4, + intermediate_size=37, + vocab_size=250002, + max_position_embeddings=512, + default_language="en_XX", + ) + tokenizer_name = "xlm-roberta-base" + + +@require_torch +class XmodAdapterTest( + BottleneckAdapterTestMixin, + CompacterTestMixin, + IA3TestMixin, + LoRATestMixin, + PrefixTuningTestMixin, + UniPELTTestMixin, + AdapterFusionModelTestMixin, + CompabilityTestMixin, + PredictionHeadModelTestMixin, + ParallelAdapterInferenceTestMixin, + XmodAdapterTestBase, + unittest.TestCase, +): + pass + + +@require_torch +class XmodClassConversionTest( + ModelClassConversionTestMixin, + XmodAdapterTestBase, + unittest.TestCase, +): + pass diff --git a/utils/convert_xmod_checkpoint.py b/utils/convert_xmod_checkpoint.py new file mode 100644 index 0000000000..30ca0ede74 --- /dev/null +++ b/utils/convert_xmod_checkpoint.py @@ -0,0 +1,85 @@ +""" +This script can be used to convert an Xmod checkpoint (including adapters) from the HF format to the Adapters format.
+""" +import argparse +import os +import re + +import torch + +from adapters import SeqBnConfig, XmodAdapterModel +from transformers import XmodModel + + +def convert_xmod_checkpoint(model_name: str, output_dir: str): + # Instantiate new model + orig_model = XmodModel.from_pretrained(model_name) + model_config = orig_model.config + new_model = XmodAdapterModel.from_pretrained(model_name) + for lang in model_config.languages: + adapter_config = SeqBnConfig( + reduction_factor=model_config.adapter_reduction_factor, + # selection between (shared) adapter LN and original LN is done in XmodOutput + original_ln_before=model_config.adapter_layer_norm or model_config.adapter_reuse_layer_norm, + original_ln_after=False, + residual_before_ln=False if model_config.ln_before_adapter else "post_add", + non_linearity=model_config.hidden_act, + ) + new_model.add_adapter(lang, adapter_config) + + # Convert state dict + new_state_dict = {} + for k, v in orig_model.state_dict().items(): + if match := re.match(r"(.+)\.adapter_modules\.(?P\w+)\.(?P\w+)\.(.+)", k): + prefix, suffix = match.group(1, 4) + lang = match.group("lang") + layer = match.group("layer") + if layer == "dense1": + new_layer = "adapter_down.0" + elif layer == "dense2": + new_layer = "adapter_up" + else: + raise ValueError(f"Unknown layer {layer}") + new_k = f"{new_model.base_model_prefix}.{prefix}.adapters.{lang}.{new_layer}.{suffix}" + new_state_dict[new_k] = v + else: + new_state_dict[f"{new_model.base_model_prefix}.{k}"] = v + missing_keys, unexpected_keys = new_model.load_state_dict(new_state_dict, strict=False) + print("Missing keys:", missing_keys) + print("Unexpected keys:", unexpected_keys) + + # Check equal outputs + orig_model.eval() + new_model.eval() + inputs = orig_model.dummy_inputs + for lang in model_config.languages: + orig_model.set_default_language(lang) + orig_outputs = orig_model(**inputs) + new_model.set_active_adapters(lang) + new_outputs = new_model(**inputs) + all_close = torch.allclose(orig_outputs.last_hidden_state, new_outputs.last_hidden_state) + check_str = "OK" if all_close else "FAIL" + print(f"{lang:>6}: {check_str}") + + # Save new model & all adapters + os.makedirs(output_dir, exist_ok=True) + new_model.save_all_adapters(output_dir) + # Remove all adapters except for English + for lang in model_config.languages: + if lang != "en_XX": + new_model.delete_adapter(lang) + new_model.save_pretrained(output_dir) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("-n", "--model_name", type=str, required=True) + parser.add_argument("-o", "--output_dir", type=str, required=True) + + args = parser.parse_args() + + convert_xmod_checkpoint(args.model_name, args.output_dir) + + +if __name__ == "__main__": + main()