
Commit

use regex for tied keys
pglorio committed Jan 24, 2025
1 parent 929ee67 commit 9007a52
Showing 2 changed files with 86 additions and 104 deletions.
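Context for the diff below: previously `get_layers` appended one literal string per tied parameter to `self._tied_weights_keys` (eight keys per shared transformer block, plus several more when adapters are enabled). This commit stores compiled `re.Pattern` objects instead, one per group of tied weights, which assumes the downstream weight-tying logic matches candidate parameter names against the entries rather than comparing them for string equality. A minimal sketch of the difference in what gets stored (layer index 6 is an arbitrary example, not taken from the diff):

    import re

    # Before: literal key strings, one entry per tied parameter
    old_style = [
        "layers.6.shared_transformer.self_attn.q_proj.weight",
        "layers.6.shared_transformer.self_attn.k_proj.weight",
    ]

    # After: a single compiled pattern covering a whole group of tied parameters
    new_style = [
        re.compile(r"^layers\.6\.shared_transformer\.self_attn\.(?:q_proj|k_proj|v_proj|o_proj)\.weight$"),
    ]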
95 changes: 43 additions & 52 deletions src/transformers/models/zamba2/modeling_zamba2.py
@@ -20,6 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
+import re
 from itertools import cycle
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

@@ -1558,81 +1559,71 @@ def get_layers(self, blocks, linear_layers, mamba_layers):
         layers = []
         self._tied_weights_keys = []
         self.first_transformer_layer_id = 0
         for layer_id, layer_type in enumerate(self.layers_block_type):
             if layer_type == "hybrid":
                 if self.first_transformer_layer_id == 0:
                     self.first_transformer_layer_id = layer_id
                 block = next(blocks)
                 # If there are multiple shared blocks, we gather tied keys
                 if self.config.num_mem_blocks * len(self.config.hybrid_layer_ids) > 1:
-                    prefix_name = f"layers.{layer_id}."
-                    tied_keys = [
-                        "shared_transformer.self_attn.q_proj.weight",
-                        "shared_transformer.self_attn.k_proj.weight",
-                        "shared_transformer.self_attn.v_proj.weight",
-                        "shared_transformer.self_attn.o_proj.weight",
-                        "shared_transformer.feed_forward.gate_up_proj.weight",
-                        "shared_transformer.feed_forward.down_proj.weight",
-                        "shared_transformer.input_layernorm.weight",
-                        "shared_transformer.pre_ff_layernorm.weight",
-                    ]
-                    self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]]
+                    # We will incorporate the 'layers.{layer_id}.' prefix into the patterns.
+                    prefix_pattern = rf"^layers\.{layer_id}\.shared_transformer\."
+
+                    # 1) Main shared keys (q/k/v/o_proj, gate_up_proj, down_proj, layernorms)
+                    #    combined into one pattern. You can separate these into multiple regex
+                    #    entries if you prefer finer granularity.
+                    main_keys_pattern = re.compile(
+                        prefix_pattern
+                        + r"(?:"
+                        + r"self_attn\.(?:q_proj|k_proj|v_proj|o_proj)\.weight|"
+                        + r"feed_forward\.(?:gate_up_proj|down_proj)\.weight|"
+                        + r"(?:input_layernorm|pre_ff_layernorm)\.weight"
+                        + r")$"
+                    )
+                    self._tied_weights_keys.append(main_keys_pattern)
+
+                    # 2) If using shared MLP adapter layers, create regex patterns for those.
                     if self.config.use_shared_mlp_adapter:
-                        tied_keys_adapter = []
                         adapter_id = 0
                         for _layer_type in self.layers_block_type:
+                            # Only add keys for the relevant adapter_id / block_id
                             if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id:
-                                tied_keys_adapter.append(
-                                    "shared_transformer.feed_forward.gate_up_proj_adapter_list."
-                                    + str(adapter_id)
-                                    + ".0.weight"
-                                )
-                                tied_keys_adapter.append(
-                                    "shared_transformer.feed_forward.gate_up_proj_adapter_list."
-                                    + str(adapter_id)
-                                    + ".1.weight"
-                                )
+                                # gate_up_proj_adapter_list.X.[0|1].weight
+                                # Instead of storing multiple strings, store a single combined regex
+                                adapter_pattern = re.compile(
+                                    r"^shared_transformer\.feed_forward\.gate_up_proj_adapter_list\."
+                                    + str(adapter_id)
+                                    + r"\.(?:0|1)\.weight$"
+                                )
+                                self._tied_weights_keys.append(adapter_pattern)
                             adapter_id += 1
-                        self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter]
+
+                    # 3) If using shared Attention adapter layers, create regex patterns for those.
                     if self.config.use_shared_attention_adapter:
-                        tied_keys_adapter = []
                         adapter_id = 0
                         for _layer_type in self.layers_block_type:
                             if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id:
-                                tied_keys_adapter.append(
-                                    "shared_transformer.self_attn.linear_q_adapter_list."
-                                    + str(adapter_id)
-                                    + ".0.weight"
-                                )
-                                tied_keys_adapter.append(
-                                    "shared_transformer.self_attn.linear_k_adapter_list."
-                                    + str(adapter_id)
-                                    + ".0.weight"
-                                )
-                                tied_keys_adapter.append(
-                                    "shared_transformer.self_attn.linear_v_adapter_list."
-                                    + str(adapter_id)
-                                    + ".0.weight"
-                                )
-                                tied_keys_adapter.append(
-                                    "shared_transformer.self_attn.linear_q_adapter_list."
-                                    + str(adapter_id)
-                                    + ".1.weight"
-                                )
-                                tied_keys_adapter.append(
-                                    "shared_transformer.self_attn.linear_k_adapter_list."
-                                    + str(adapter_id)
-                                    + ".1.weight"
-                                )
-                                tied_keys_adapter.append(
-                                    "shared_transformer.self_attn.linear_v_adapter_list."
-                                    + str(adapter_id)
-                                    + ".1.weight"
-                                )
+                                # linear_q_adapter_list.X.[0|1].weight
+                                # linear_k_adapter_list.X.[0|1].weight
+                                # linear_v_adapter_list.X.[0|1].weight
+                                # We'll combine them, but if you want separate patterns, split accordingly.
+                                attn_adapter_pattern = re.compile(
+                                    r"^shared_transformer\.self_attn\."
+                                    + r"(?:linear_q_adapter_list|linear_k_adapter_list|linear_v_adapter_list)\."
+                                    + str(adapter_id)
+                                    + r"\.(?:0|1)\.weight$"
+                                )
+                                self._tied_weights_keys.append(attn_adapter_pattern)
                             adapter_id += 1
-                        self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter]
                 # Construct the actual layer
                 layers.append(Zamba2HybridLayer(block, next(linear_layers), next(mamba_layers)))
             else:
                 layers.append(next(mamba_layers))
         return layers


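As a quick sanity check (not part of the commit), the combined `main_keys_pattern` above can be exercised against the same parameter names the old string list enumerated; `layer_id = 6` is an arbitrary value chosen for this sketch:

    import re

    layer_id = 6  # arbitrary example layer index
    prefix_pattern = rf"^layers\.{layer_id}\.shared_transformer\."
    main_keys_pattern = re.compile(
        prefix_pattern
        + r"(?:"
        + r"self_attn\.(?:q_proj|k_proj|v_proj|o_proj)\.weight|"
        + r"feed_forward\.(?:gate_up_proj|down_proj)\.weight|"
        + r"(?:input_layernorm|pre_ff_layernorm)\.weight"
        + r")$"
    )

    # Names the old string list spelled out explicitly all match the single pattern:
    assert main_keys_pattern.match("layers.6.shared_transformer.self_attn.q_proj.weight")
    assert main_keys_pattern.match("layers.6.shared_transformer.feed_forward.down_proj.weight")
    assert main_keys_pattern.match("layers.6.shared_transformer.pre_ff_layernorm.weight")

    # Other layers and other parameters are not captured:
    assert not main_keys_pattern.match("layers.7.shared_transformer.self_attn.q_proj.weight")
    assert not main_keys_pattern.match("layers.6.shared_transformer.self_attn.q_proj.bias")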
95 changes: 43 additions & 52 deletions src/transformers/models/zamba2/modular_zamba2.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
+import re
 from itertools import cycle
 from typing import Callable, Optional, Tuple, Union

@@ -976,81 +977,71 @@ def get_layers(self, blocks, linear_layers, mamba_layers):
         layers = []
         self._tied_weights_keys = []
         self.first_transformer_layer_id = 0
         for layer_id, layer_type in enumerate(self.layers_block_type):
             if layer_type == "hybrid":
                 if self.first_transformer_layer_id == 0:
                     self.first_transformer_layer_id = layer_id
                 block = next(blocks)
                 # If there are multiple shared blocks, we gather tied keys
                 if self.config.num_mem_blocks * len(self.config.hybrid_layer_ids) > 1:
-                    prefix_name = f"layers.{layer_id}."
-                    tied_keys = [
-                        "shared_transformer.self_attn.q_proj.weight",
-                        "shared_transformer.self_attn.k_proj.weight",
-                        "shared_transformer.self_attn.v_proj.weight",
-                        "shared_transformer.self_attn.o_proj.weight",
-                        "shared_transformer.feed_forward.gate_up_proj.weight",
-                        "shared_transformer.feed_forward.down_proj.weight",
-                        "shared_transformer.input_layernorm.weight",
-                        "shared_transformer.pre_ff_layernorm.weight",
-                    ]
-                    self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]]
+                    # We will incorporate the 'layers.{layer_id}.' prefix into the patterns.
+                    prefix_pattern = rf"^layers\.{layer_id}\.shared_transformer\."
+
+                    # 1) Main shared keys (q/k/v/o_proj, gate_up_proj, down_proj, layernorms)
+                    #    combined into one pattern. You can separate these into multiple regex
+                    #    entries if you prefer finer granularity.
+                    main_keys_pattern = re.compile(
+                        prefix_pattern
+                        + r"(?:"
+                        + r"self_attn\.(?:q_proj|k_proj|v_proj|o_proj)\.weight|"
+                        + r"feed_forward\.(?:gate_up_proj|down_proj)\.weight|"
+                        + r"(?:input_layernorm|pre_ff_layernorm)\.weight"
+                        + r")$"
+                    )
+                    self._tied_weights_keys.append(main_keys_pattern)
+
+                    # 2) If using shared MLP adapter layers, create regex patterns for those.
                     if self.config.use_shared_mlp_adapter:
-                        tied_keys_adapter = []
                         adapter_id = 0
                         for _layer_type in self.layers_block_type:
+                            # Only add keys for the relevant adapter_id / block_id
                             if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id:
-                                tied_keys_adapter.append(
-                                    "shared_transformer.feed_forward.gate_up_proj_adapter_list."
-                                    + str(adapter_id)
-                                    + ".0.weight"
-                                )
-                                tied_keys_adapter.append(
-                                    "shared_transformer.feed_forward.gate_up_proj_adapter_list."
-                                    + str(adapter_id)
-                                    + ".1.weight"
-                                )
+                                # gate_up_proj_adapter_list.X.[0|1].weight
+                                # Instead of storing multiple strings, store a single combined regex
+                                adapter_pattern = re.compile(
+                                    r"^shared_transformer\.feed_forward\.gate_up_proj_adapter_list\."
+                                    + str(adapter_id)
+                                    + r"\.(?:0|1)\.weight$"
+                                )
+                                self._tied_weights_keys.append(adapter_pattern)
                             adapter_id += 1
-                        self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter]
+
+                    # 3) If using shared Attention adapter layers, create regex patterns for those.
                     if self.config.use_shared_attention_adapter:
-                        tied_keys_adapter = []
                         adapter_id = 0
                         for _layer_type in self.layers_block_type:
                             if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id:
-                                tied_keys_adapter.append(
-                                    "shared_transformer.self_attn.linear_q_adapter_list."
-                                    + str(adapter_id)
-                                    + ".0.weight"
-                                )
-                                tied_keys_adapter.append(
-                                    "shared_transformer.self_attn.linear_k_adapter_list."
-                                    + str(adapter_id)
-                                    + ".0.weight"
-                                )
-                                tied_keys_adapter.append(
-                                    "shared_transformer.self_attn.linear_v_adapter_list."
-                                    + str(adapter_id)
-                                    + ".0.weight"
-                                )
-                                tied_keys_adapter.append(
-                                    "shared_transformer.self_attn.linear_q_adapter_list."
-                                    + str(adapter_id)
-                                    + ".1.weight"
-                                )
-                                tied_keys_adapter.append(
-                                    "shared_transformer.self_attn.linear_k_adapter_list."
-                                    + str(adapter_id)
-                                    + ".1.weight"
-                                )
-                                tied_keys_adapter.append(
-                                    "shared_transformer.self_attn.linear_v_adapter_list."
-                                    + str(adapter_id)
-                                    + ".1.weight"
-                                )
+                                # linear_q_adapter_list.X.[0|1].weight
+                                # linear_k_adapter_list.X.[0|1].weight
+                                # linear_v_adapter_list.X.[0|1].weight
+                                # We'll combine them, but if you want separate patterns, split accordingly.
+                                attn_adapter_pattern = re.compile(
+                                    r"^shared_transformer\.self_attn\."
+                                    + r"(?:linear_q_adapter_list|linear_k_adapter_list|linear_v_adapter_list)\."
+                                    + str(adapter_id)
+                                    + r"\.(?:0|1)\.weight$"
+                                )
+                                self._tied_weights_keys.append(attn_adapter_pattern)
                             adapter_id += 1
-                        self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter]
                 # Construct the actual layer
                 layers.append(Zamba2HybridLayer(block, next(linear_layers), next(mamba_layers)))
             else:
                 layers.append(next(mamba_layers))
         return layers

     def forward(
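The adapter patterns follow the same idea: where the old code appended separate `.0.weight` and `.1.weight` strings for each adapter index, the new code stores one compiled pattern with a `(?:0|1)` alternation. A standalone check of that equivalence (not part of the commit; adapter index 0 chosen arbitrarily):

    import re

    adapter_id = 0  # arbitrary example adapter index
    adapter_pattern = re.compile(
        r"^shared_transformer\.feed_forward\.gate_up_proj_adapter_list\."
        + str(adapter_id)
        + r"\.(?:0|1)\.weight$"
    )

    # The two keys the old code stored as separate strings both match the one pattern:
    assert adapter_pattern.match("shared_transformer.feed_forward.gate_up_proj_adapter_list.0.0.weight")
    assert adapter_pattern.match("shared_transformer.feed_forward.gate_up_proj_adapter_list.0.1.weight")
    # A different adapter index does not:
    assert not adapter_pattern.match("shared_transformer.feed_forward.gate_up_proj_adapter_list.1.0.weight")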
