diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py
index 2148d0ce910f..7c418beae9f4 100644
--- a/src/transformers/models/zamba2/modeling_zamba2.py
+++ b/src/transformers/models/zamba2/modeling_zamba2.py
@@ -20,6 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
+import re
 from itertools import cycle
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
@@ -1558,81 +1559,71 @@ def get_layers(self, blocks, linear_layers, mamba_layers):
         layers = []
         self._tied_weights_keys = []
         self.first_transformer_layer_id = 0
+
         for layer_id, layer_type in enumerate(self.layers_block_type):
             if layer_type == "hybrid":
                 if self.first_transformer_layer_id == 0:
                     self.first_transformer_layer_id = layer_id
+
                 block = next(blocks)
+
+                # If there are multiple shared blocks, we gather tied keys
                 if self.config.num_mem_blocks * len(self.config.hybrid_layer_ids) > 1:
-                    prefix_name = f"layers.{layer_id}."
-                    tied_keys = [
-                        "shared_transformer.self_attn.q_proj.weight",
-                        "shared_transformer.self_attn.k_proj.weight",
-                        "shared_transformer.self_attn.v_proj.weight",
-                        "shared_transformer.self_attn.o_proj.weight",
-                        "shared_transformer.feed_forward.gate_up_proj.weight",
-                        "shared_transformer.feed_forward.down_proj.weight",
-                        "shared_transformer.input_layernorm.weight",
-                        "shared_transformer.pre_ff_layernorm.weight",
-                    ]
-                    self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]]
+                    # We will incorporate the 'layers.{layer_id}.' prefix into the patterns.
+                    prefix_pattern = rf"^layers\.{layer_id}\.shared_transformer\."
+
+                    # 1) Main shared keys (q/k/v/o_proj, gate_up_proj, down_proj, layernorms)
+                    #    combined into one pattern. You can separate these into multiple regex
+                    #    entries if you prefer finer granularity.
+                    main_keys_pattern = re.compile(
+                        prefix_pattern
+                        + r"(?:"
+                        + r"self_attn\.(?:q_proj|k_proj|v_proj|o_proj)\.weight|"
+                        + r"feed_forward\.(?:gate_up_proj|down_proj)\.weight|"
+                        + r"(?:input_layernorm|pre_ff_layernorm)\.weight"
+                        + r")$"
+                    )
+                    self._tied_weights_keys.append(main_keys_pattern)
+
+                    # 2) If using shared MLP adapter layers, create regex patterns for those.
                     if self.config.use_shared_mlp_adapter:
-                        tied_keys_adapter = []
                         adapter_id = 0
                         for _layer_type in self.layers_block_type:
+                            # Only add keys for the relevant adapter_id / block_id
                             if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id:
-                                tied_keys_adapter.append(
-                                    "shared_transformer.feed_forward.gate_up_proj_adapter_list."
+                                # gate_up_proj_adapter_list.X.[0|1].weight
+                                # Instead of storing multiple strings, store a single combined regex
+                                adapter_pattern = re.compile(
+                                    r"^shared_transformer\.feed_forward\.gate_up_proj_adapter_list\."
                                     + str(adapter_id)
-                                    + ".0.weight"
-                                )
-                                tied_keys_adapter.append(
-                                    "shared_transformer.feed_forward.gate_up_proj_adapter_list."
-                                    + str(adapter_id)
-                                    + ".1.weight"
+                                    + r"\.(?:0|1)\.weight$"
                                 )
+                                self._tied_weights_keys.append(adapter_pattern)
                             adapter_id += 1
-                        self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter]
+
+                    # 3) If using shared Attention adapter layers, create regex patterns for those.
                     if self.config.use_shared_attention_adapter:
-                        tied_keys_adapter = []
                         adapter_id = 0
                         for _layer_type in self.layers_block_type:
                             if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id:
-                                tied_keys_adapter.append(
-                                    "shared_transformer.self_attn.linear_q_adapter_list."
-                                    + str(adapter_id)
-                                    + ".0.weight"
-                                )
-                                tied_keys_adapter.append(
-                                    "shared_transformer.self_attn.linear_k_adapter_list."
+                                # linear_q_adapter_list.X.[0|1].weight
+                                # linear_k_adapter_list.X.[0|1].weight
+                                # linear_v_adapter_list.X.[0|1].weight
+                                # We'll combine them, but if you want separate patterns, split accordingly.
+                                attn_adapter_pattern = re.compile(
+                                    r"^shared_transformer\.self_attn\."
+                                    + r"(?:linear_q_adapter_list|linear_k_adapter_list|linear_v_adapter_list)\."
                                     + str(adapter_id)
-                                    + ".0.weight"
-                                )
-                                tied_keys_adapter.append(
-                                    "shared_transformer.self_attn.linear_v_adapter_list."
-                                    + str(adapter_id)
-                                    + ".0.weight"
-                                )
-                                tied_keys_adapter.append(
-                                    "shared_transformer.self_attn.linear_q_adapter_list."
-                                    + str(adapter_id)
-                                    + ".1.weight"
-                                )
-                                tied_keys_adapter.append(
-                                    "shared_transformer.self_attn.linear_k_adapter_list."
-                                    + str(adapter_id)
-                                    + ".1.weight"
-                                )
-                                tied_keys_adapter.append(
-                                    "shared_transformer.self_attn.linear_v_adapter_list."
-                                    + str(adapter_id)
-                                    + ".1.weight"
+                                    + r"\.(?:0|1)\.weight$"
                                 )
+                                self._tied_weights_keys.append(attn_adapter_pattern)
                             adapter_id += 1
-                        self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter]
+
+                # Construct the actual layer
                 layers.append(Zamba2HybridLayer(block, next(linear_layers), next(mamba_layers)))
             else:
                 layers.append(next(mamba_layers))
+
         return layers
 
diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py
index 8a4e379b3ce0..64d8cd8087cf 100644
--- a/src/transformers/models/zamba2/modular_zamba2.py
+++ b/src/transformers/models/zamba2/modular_zamba2.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
+import re
 from itertools import cycle
 from typing import Callable, Optional, Tuple, Union
 
@@ -976,81 +977,71 @@ def get_layers(self, blocks, linear_layers, mamba_layers):
         layers = []
         self._tied_weights_keys = []
         self.first_transformer_layer_id = 0
+
         for layer_id, layer_type in enumerate(self.layers_block_type):
             if layer_type == "hybrid":
                 if self.first_transformer_layer_id == 0:
                     self.first_transformer_layer_id = layer_id
+
                 block = next(blocks)
+
+                # If there are multiple shared blocks, we gather tied keys
                 if self.config.num_mem_blocks * len(self.config.hybrid_layer_ids) > 1:
-                    prefix_name = f"layers.{layer_id}."
-                    tied_keys = [
-                        "shared_transformer.self_attn.q_proj.weight",
-                        "shared_transformer.self_attn.k_proj.weight",
-                        "shared_transformer.self_attn.v_proj.weight",
-                        "shared_transformer.self_attn.o_proj.weight",
-                        "shared_transformer.feed_forward.gate_up_proj.weight",
-                        "shared_transformer.feed_forward.down_proj.weight",
-                        "shared_transformer.input_layernorm.weight",
-                        "shared_transformer.pre_ff_layernorm.weight",
-                    ]
-                    self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]]
+                    # We will incorporate the 'layers.{layer_id}.' prefix into the patterns.
+                    prefix_pattern = rf"^layers\.{layer_id}\.shared_transformer\."
+
+                    # 1) Main shared keys (q/k/v/o_proj, gate_up_proj, down_proj, layernorms)
+                    #    combined into one pattern. You can separate these into multiple regex
+                    #    entries if you prefer finer granularity.
+                    main_keys_pattern = re.compile(
+                        prefix_pattern
+                        + r"(?:"
+                        + r"self_attn\.(?:q_proj|k_proj|v_proj|o_proj)\.weight|"
+                        + r"feed_forward\.(?:gate_up_proj|down_proj)\.weight|"
+                        + r"(?:input_layernorm|pre_ff_layernorm)\.weight"
+                        + r")$"
+                    )
+                    self._tied_weights_keys.append(main_keys_pattern)
+
+                    # 2) If using shared MLP adapter layers, create regex patterns for those.
                     if self.config.use_shared_mlp_adapter:
-                        tied_keys_adapter = []
                         adapter_id = 0
                         for _layer_type in self.layers_block_type:
+                            # Only add keys for the relevant adapter_id / block_id
                             if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id:
-                                tied_keys_adapter.append(
-                                    "shared_transformer.feed_forward.gate_up_proj_adapter_list."
+                                # gate_up_proj_adapter_list.X.[0|1].weight
+                                # Instead of storing multiple strings, store a single combined regex
+                                adapter_pattern = re.compile(
+                                    r"^shared_transformer\.feed_forward\.gate_up_proj_adapter_list\."
                                     + str(adapter_id)
-                                    + ".0.weight"
-                                )
-                                tied_keys_adapter.append(
-                                    "shared_transformer.feed_forward.gate_up_proj_adapter_list."
-                                    + str(adapter_id)
-                                    + ".1.weight"
+                                    + r"\.(?:0|1)\.weight$"
                                 )
+                                self._tied_weights_keys.append(adapter_pattern)
                             adapter_id += 1
-                        self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter]
+
+                    # 3) If using shared Attention adapter layers, create regex patterns for those.
                     if self.config.use_shared_attention_adapter:
-                        tied_keys_adapter = []
                         adapter_id = 0
                         for _layer_type in self.layers_block_type:
                             if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id:
-                                tied_keys_adapter.append(
-                                    "shared_transformer.self_attn.linear_q_adapter_list."
-                                    + str(adapter_id)
-                                    + ".0.weight"
-                                )
-                                tied_keys_adapter.append(
-                                    "shared_transformer.self_attn.linear_k_adapter_list."
+                                # linear_q_adapter_list.X.[0|1].weight
+                                # linear_k_adapter_list.X.[0|1].weight
+                                # linear_v_adapter_list.X.[0|1].weight
+                                # We'll combine them, but if you want separate patterns, split accordingly.
+                                attn_adapter_pattern = re.compile(
+                                    r"^shared_transformer\.self_attn\."
+                                    + r"(?:linear_q_adapter_list|linear_k_adapter_list|linear_v_adapter_list)\."
                                     + str(adapter_id)
-                                    + ".0.weight"
-                                )
-                                tied_keys_adapter.append(
-                                    "shared_transformer.self_attn.linear_v_adapter_list."
-                                    + str(adapter_id)
-                                    + ".0.weight"
-                                )
-                                tied_keys_adapter.append(
-                                    "shared_transformer.self_attn.linear_q_adapter_list."
-                                    + str(adapter_id)
-                                    + ".1.weight"
-                                )
-                                tied_keys_adapter.append(
-                                    "shared_transformer.self_attn.linear_k_adapter_list."
-                                    + str(adapter_id)
-                                    + ".1.weight"
-                                )
-                                tied_keys_adapter.append(
-                                    "shared_transformer.self_attn.linear_v_adapter_list."
-                                    + str(adapter_id)
-                                    + ".1.weight"
+                                    + r"\.(?:0|1)\.weight$"
                                 )
+                                self._tied_weights_keys.append(attn_adapter_pattern)
                             adapter_id += 1
-                        self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter]
+
+                # Construct the actual layer
                 layers.append(Zamba2HybridLayer(block, next(linear_layers), next(mamba_layers)))
             else:
                 layers.append(next(mamba_layers))
+
         return layers
 
     def forward(
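
Note on behavior (not part of the patch): each compiled regex appended to `_tied_weights_keys` stands in for the group of explicit key strings it replaces. The following is a minimal sketch of the expected matching; the layer index 6 and the sample parameter names are illustrative assumptions, not values from the diff.

import re

# Same shape as the "main shared keys" pattern above, with layer_id assumed to be 6.
main_keys_pattern = re.compile(
    r"^layers\.6\.shared_transformer\."
    r"(?:"
    r"self_attn\.(?:q_proj|k_proj|v_proj|o_proj)\.weight|"
    r"feed_forward\.(?:gate_up_proj|down_proj)\.weight|"
    r"(?:input_layernorm|pre_ff_layernorm)\.weight"
    r")$"
)
assert main_keys_pattern.match("layers.6.shared_transformer.self_attn.q_proj.weight")
assert main_keys_pattern.match("layers.6.shared_transformer.pre_ff_layernorm.weight")
assert main_keys_pattern.match("layers.7.shared_transformer.self_attn.q_proj.weight") is None
assert main_keys_pattern.match("layers.6.shared_transformer.self_attn.q_proj.bias") is None

# Same shape as the MLP adapter pattern above, with adapter_id assumed to be 0; it covers
# both the ".0.weight" and ".1.weight" keys that the old code appended as separate strings.
adapter_pattern = re.compile(
    r"^shared_transformer\.feed_forward\.gate_up_proj_adapter_list\." + str(0) + r"\.(?:0|1)\.weight$"
)
assert adapter_pattern.match("shared_transformer.feed_forward.gate_up_proj_adapter_list.0.0.weight")
assert adapter_pattern.match("shared_transformer.feed_forward.gate_up_proj_adapter_list.0.1.weight")

Collapsing the per-adapter strings (q/k/v projections times the .0/.1 sub-modules) into a single alternation keeps the tied-keys list short while covering the same parameter names as the removed string entries.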