[Major] Update with string access and code refactory (#83) #93

Merged
merged 4 commits on Jan 26, 2024
4 changes: 3 additions & 1 deletion .gitattributes
@@ -1,2 +1,4 @@
tutorials/* linguist-vendored
tutorials/basic_tutorials/* linguist-vendored
tutorials/advanced_tutorials/* linguist-vendored
pyvene_101.ipynb linguist-vendored
tests/qa_runbook.ipynb linguist-vendored
pyvene/models/backpack_gpt2/modelings_intervenable_backpack_gpt2.py
@@ -9,64 +9,17 @@
"""


from ..constants import CONST_INPUT_HOOK, CONST_OUTPUT_HOOK, CONST_QKV_INDICES
from ..constants import *


"""gpt2 base model"""
backpack_gpt2_lm_type_to_module_mapping = {
"block_input": ("backpack.gpt2_model.h[%s]", CONST_INPUT_HOOK),
"block_output": ("backpack.gpt2_model.h[%s]", CONST_OUTPUT_HOOK),
"mlp_activation": ("backpack.gpt2_model.h[%s].mlp.act", CONST_OUTPUT_HOOK),
"mlp_output": ("backpack.gpt2_model.h[%s].mlp", CONST_OUTPUT_HOOK),
"mlp_input": ("backpack.gpt2_model.h[%s].mlp", CONST_INPUT_HOOK),
"attention_value_output": ("backpack.gpt2_model.h[%s].attn.c_proj", CONST_INPUT_HOOK),
"head_attention_value_output": ("backpack.gpt2_model.h[%s].attn.c_proj", CONST_INPUT_HOOK),
"attention_output": ("backpack.gpt2_model.h[%s].attn", CONST_OUTPUT_HOOK),
"attention_input": ("backpack.gpt2_model.h[%s].attn", CONST_INPUT_HOOK),
"query_output": ("backpack.gpt2_model.h[%s].attn.c_attn", CONST_OUTPUT_HOOK),
"key_output": ("backpack.gpt2_model.h[%s].attn.c_attn", CONST_OUTPUT_HOOK),
"value_output": ("backpack.gpt2_model.h[%s].attn.c_attn", CONST_OUTPUT_HOOK),
"head_query_output": ("backpack.gpt2_model.h[%s].attn.c_attn", CONST_OUTPUT_HOOK),
"head_key_output": ("backpack.gpt2_model.h[%s].attn.c_attn", CONST_OUTPUT_HOOK),
"head_value_output": ("backpack.gpt2_model.h[%s].attn.c_attn", CONST_OUTPUT_HOOK),
"sense_output": ("backpack.sense_network", CONST_OUTPUT_HOOK),
"sense_block_output": ("backpack.sense_network.block", CONST_OUTPUT_HOOK),
"sense_mlp_input": ("backpack.sense_network.final_mlp", CONST_INPUT_HOOK),
"sense_mlp_output": ("backpack.sense_network.final_mlp", CONST_OUTPUT_HOOK),
"sense_mlp_activation": ("backpack.sense_network.final_mlp.act", CONST_OUTPUT_HOOK),
"sense_weight_input": ("backpack.sense_weight_net", CONST_INPUT_HOOK),
"sense_weight_output": ("backpack.sense_weight_net", CONST_OUTPUT_HOOK),
"sense_network_output": ("backpack.sense_network", CONST_OUTPUT_HOOK),
}


backpack_gpt2_lm_type_to_dimension_mapping = {
"block_input": ("n_embd",),
"block_output": ("n_embd",),
"mlp_activation": (
"n_inner",
"n_embd*4",
),
"mlp_output": ("n_embd",),
"mlp_input": ("n_embd",),
"attention_value_output": ("n_embd",),
"head_attention_value_output": ("n_embd/n_head",),
"attention_output": ("n_embd",),
"attention_input": ("n_embd",),
"query_output": ("n_embd",),
"key_output": ("n_embd",),
"value_output": ("n_embd",),
"head_query_output": ("n_embd/n_head",),
"head_key_output": ("n_embd/n_head",),
"head_value_output": ("n_embd/n_head",),
"sense_output": ("n_embd",),
"sense_block_output": ("n_embd",),
"sense_mlp_input": ("n_embd",),
"sense_mlp_output": ("n_embd",),
"num_senses": ("num_senses",),
"sense_mlp_activation": (
"n_inner",
"n_embd*4",
),
"sense_network_output": ("n_embd",),
}

def create_backpack_gpt2(name="stanfordnlp/backpack-gpt2", cache_dir=None):
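The dimension entries above are plain strings such as "n_embd", "n_embd*4", or "n_embd/n_head"; presumably pyvene resolves them against the model config when it sizes an intervention. A minimal sketch of how such a string could be evaluated (the resolve_dimension helper below is hypothetical, not part of the library):

# Hypothetical sketch only: how a dimension string like "n_embd/n_head" might be
# resolved against a Hugging Face config. pyvene's actual resolver may differ.
def resolve_dimension(expr, config):
    if "*" in expr:                      # e.g. "n_embd*4" -> 768 * 4
        attr, factor = expr.split("*")
        return getattr(config, attr) * int(factor)
    if "/" in expr:                      # e.g. "n_embd/n_head" -> 768 // 12
        num, den = expr.split("/")
        return getattr(config, num) // getattr(config, den)
    return getattr(config, expr)         # e.g. "n_embd" -> 768

# With GPT-2 defaults (n_embd=768, n_head=12): resolve_dimension("n_embd/n_head", cfg) == 64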
2 changes: 1 addition & 1 deletion pyvene/models/blip/modelings_intervenable_blip.py
@@ -9,7 +9,7 @@
"""


from ..constants import CONST_INPUT_HOOK, CONST_OUTPUT_HOOK, CONST_QKV_INDICES
from ..constants import *

"""blip base model"""
blip_type_to_module_mapping = {
91 changes: 10 additions & 81 deletions pyvene/models/constants.py
@@ -1,88 +1,17 @@
CONST_VALID_INTERVENABLE_UNIT = ["pos", "h", "h.pos", "t"]

import torch

CONST_INPUT_HOOK = "register_forward_pre_hook"
CONST_OUTPUT_HOOK = "register_forward_hook"
CONST_GRAD_HOOK = "register_hook"


CONST_TRANSFORMER_TOPOLOGICAL_ORDER = [
"block_input",
"query_output",
"head_query_output",
"key_output",
"head_key_output",
"value_output",
"head_value_output",
"attention_input",
"attention_weight",
"head_attention_value_output",
"attention_value_output",
"attention_output",
"cross_attention_input",
"head_cross_attention_value_output",
"cross_attention_value_output",
"cross_attention_output",
"mlp_input",
"mlp_activation",
"mlp_output",
"block_output",
# special keys for backpack model
"sense_block_output",
"sense_mlp_input",
"sense_mlp_activation",
"sense_mlp_output",
"sense_output",
]


CONST_MLP_TOPOLOGICAL_ORDER = [
"block_input",
"mlp_activation",
"block_output",
]


CONST_GRU_TOPOLOGICAL_ORDER = [
"cell_input",
"x2h_output",
"h2h_output",
"reset_x2h_output",
"update_x2h_output",
"new_x2h_output",
"reset_h2h_output",
"update_h2h_output",
"new_h2h_output",
"reset_gate_input",
"update_gate_input",
"new_gate_input",
"reset_gate_output",
"update_gate_output",
"new_gate_output",
"cell_output",
]


CONST_QKV_INDICES = {
"query_output": 0,
"key_output": 1,
"value_output": 2,
"head_query_output": 0,
"head_key_output": 1,
"head_value_output": 2,
"reset_x2h_output": 0,
"update_x2h_output": 1,
"new_x2h_output": 2,
"reset_h2h_output": 0,
"update_h2h_output": 1,
"new_h2h_output": 2,
}
split_and_select = lambda x, num_slice, selct_index: torch.chunk(x, num_slice, dim=-1)[selct_index]
def split_heads(tensor, num_heads, attn_head_size):
"""Splits hidden_size dim into attn_head_size and num_heads."""
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(new_shape)
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)

CONST_RUN_INDICES = {
"reset_x2h_output": 0,
"update_x2h_output": 1,
"new_x2h_output": 2,
"reset_h2h_output": 0,
"update_h2h_output": 1,
"new_h2h_output": 2,
}
split_half = lambda x, selct_index: torch.chunk(x, 2, dim=-1)[selct_index]
split_three = lambda x, selct_index: torch.chunk(x, 3, dim=-1)[selct_index]
split_head_and_permute = lambda x, num_head: split_heads(x, num_head, x.shape[-1]//num_head)
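These helpers are what the module mappings below attach as trailing tuples: split_three slices a fused projection into its query/key/value (or gate) thirds, and split_head_and_permute reshapes a per-token vector into per-head form. A quick shape check, assuming pyvene is installed from this branch:

import torch
from pyvene.models.constants import split_three, split_head_and_permute

# Fused QKV output of GPT-2's c_attn: (batch, seq, 3 * n_embd)
fused = torch.randn(2, 5, 3 * 768)

query = split_three(fused, 0)   # (2, 5, 768), first third of the last dim
value = split_three(fused, 2)   # (2, 5, 768), last third

# Per-head view with n_head=12: (batch, n_head, seq, head_dim)
heads = split_head_and_permute(query, 12)
print(heads.shape)              # torch.Size([2, 12, 5, 64])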
21 changes: 10 additions & 11 deletions pyvene/models/gpt2/modelings_intervenable_gpt2.py
@@ -9,7 +9,7 @@
"""


from ..constants import CONST_INPUT_HOOK, CONST_OUTPUT_HOOK, CONST_QKV_INDICES
from ..constants import *


"""gpt2 base model"""
@@ -20,20 +20,21 @@
"mlp_output": ("h[%s].mlp", CONST_OUTPUT_HOOK),
"mlp_input": ("h[%s].mlp", CONST_INPUT_HOOK),
"attention_value_output": ("h[%s].attn.c_proj", CONST_INPUT_HOOK),
"head_attention_value_output": ("h[%s].attn.c_proj", CONST_INPUT_HOOK),
"head_attention_value_output": ("h[%s].attn.c_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
"attention_weight": ("h[%s].attn.attn_dropout", CONST_INPUT_HOOK),
"attention_output": ("h[%s].attn", CONST_OUTPUT_HOOK),
"attention_input": ("h[%s].attn", CONST_INPUT_HOOK),
"query_output": ("h[%s].attn.c_attn", CONST_OUTPUT_HOOK),
"key_output": ("h[%s].attn.c_attn", CONST_OUTPUT_HOOK),
"value_output": ("h[%s].attn.c_attn", CONST_OUTPUT_HOOK),
"head_query_output": ("h[%s].attn.c_attn", CONST_OUTPUT_HOOK),
"head_key_output": ("h[%s].attn.c_attn", CONST_OUTPUT_HOOK),
"head_value_output": ("h[%s].attn.c_attn", CONST_OUTPUT_HOOK),
"query_output": ("h[%s].attn.c_attn", CONST_OUTPUT_HOOK, (split_three, 0)),
"key_output": ("h[%s].attn.c_attn", CONST_OUTPUT_HOOK, (split_three, 1)),
"value_output": ("h[%s].attn.c_attn", CONST_OUTPUT_HOOK, (split_three, 2)),
"head_query_output": ("h[%s].attn.c_attn", CONST_OUTPUT_HOOK, (split_three, 0), (split_head_and_permute, "n_head")),
"head_key_output": ("h[%s].attn.c_attn", CONST_OUTPUT_HOOK, (split_three, 1), (split_head_and_permute, "n_head")),
"head_value_output": ("h[%s].attn.c_attn", CONST_OUTPUT_HOOK, (split_three, 2), (split_head_and_permute, "n_head")),
}


gpt2_type_to_dimension_mapping = {
"n_head": ("n_head", ),
"block_input": ("n_embd",),
"block_output": ("n_embd",),
"mlp_activation": (
@@ -44,7 +45,6 @@
"mlp_input": ("n_embd",),
"attention_value_output": ("n_embd",),
"head_attention_value_output": ("n_embd/n_head",),
# attention weight dimension does not really matter
"attention_weight": ("max_position_embeddings", ),
"attention_output": ("n_embd",),
"attention_input": ("n_embd",),
@@ -60,8 +60,7 @@
"""gpt2 model with LM head"""
gpt2_lm_type_to_module_mapping = {}
for k, v in gpt2_type_to_module_mapping.items():
gpt2_lm_type_to_module_mapping[k] = (f"transformer.{v[0]}", v[1])

gpt2_lm_type_to_module_mapping[k] = (f"transformer.{v[0]}", ) + v[1:]

gpt2_lm_type_to_dimension_mapping = gpt2_type_to_dimension_mapping

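The rewritten loop is what lets the new mapping format survive the LM-head wrapper: entries are no longer fixed two-tuples, so slicing with v[1:] keeps any trailing transform tuples while only the module path gets the "transformer." prefix. A toy illustration (strings stand in for the hook constant and the split callable):

# Entry in the new format: (module path, hook, optional transforms...)
v = ("h[%s].attn.c_attn", "register_forward_hook", ("split_three", 0))

old_style = (f"transformer.{v[0]}", v[1])       # silently drops ("split_three", 0)
new_style = (f"transformer.{v[0]}",) + v[1:]    # keeps every trailing element

print(old_style)  # ('transformer.h[%s].attn.c_attn', 'register_forward_hook')
print(new_style)  # ('transformer.h[%s].attn.c_attn', 'register_forward_hook', ('split_three', 0))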
11 changes: 6 additions & 5 deletions pyvene/models/gpt_neo/modelings_intervenable_gpt_neo.py
@@ -9,7 +9,7 @@
"""


from ..constants import CONST_INPUT_HOOK, CONST_OUTPUT_HOOK, CONST_QKV_INDICES
from ..constants import *


"""gpt_neo base model"""
@@ -20,19 +20,20 @@
"mlp_output": ("h[%s].mlp", CONST_OUTPUT_HOOK),
"mlp_input": ("h[%s].mlp", CONST_INPUT_HOOK),
"attention_value_output": ("h[%s].attn.out_proj", CONST_INPUT_HOOK),
"head_attention_value_output": ("h[%s].attn.out_proj", CONST_INPUT_HOOK),
"head_attention_value_output": ("h[%s].attn.out_proj", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
"attention_output": ("h[%s].attn", CONST_OUTPUT_HOOK),
"attention_input": ("h[%s].attn", CONST_INPUT_HOOK),
"query_output": ("h[%s].attn.q_proj", CONST_OUTPUT_HOOK),
"key_output": ("h[%s].attn.k_proj", CONST_OUTPUT_HOOK),
"value_output": ("h[%s].attn.v_proj", CONST_OUTPUT_HOOK),
"head_query_output": ("h[%s].attn.q_proj", CONST_OUTPUT_HOOK),
"head_key_output": ("h[%s].attn.k_proj", CONST_OUTPUT_HOOK),
"head_value_output": ("h[%s].attn.v_proj", CONST_OUTPUT_HOOK),
"head_query_output": ("h[%s].attn.q_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
"head_key_output": ("h[%s].attn.k_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
"head_value_output": ("h[%s].attn.v_proj", CONST_OUTPUT_HOOK, (split_head_and_permute, "n_head")),
}


gpt_neo_type_to_dimension_mapping = {
"n_head": "num_heads",
"block_input": ("hidden_size",),
"block_output": ("hidden_size",),
"mlp_activation": (
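Unlike GPT-2, GPT-Neo keeps separate q_proj/k_proj/v_proj modules, so only the head split is attached. The new "n_head" dimension entry names the config attribute that holds the head count, presumably so that (split_head_and_permute, "n_head") can be resolved per architecture:

from transformers import GPT2Config, GPTNeoConfig

# The "n_head" entry points at the config attribute with the number of attention heads.
print(getattr(GPT2Config(), "n_head"))       # 12
print(getattr(GPTNeoConfig(), "num_heads"))  # 16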
5 changes: 3 additions & 2 deletions pyvene/models/gpt_neox/modelings_intervenable_gpt_neox.py
@@ -9,7 +9,7 @@
"""


from ..constants import CONST_INPUT_HOOK, CONST_OUTPUT_HOOK, CONST_QKV_INDICES
from ..constants import *


"""gpt_neox base model"""
@@ -20,7 +20,7 @@
"mlp_output": ("layers[%s].mlp", CONST_OUTPUT_HOOK),
"mlp_input": ("layers[%s].mlp", CONST_INPUT_HOOK),
"attention_value_output": ("layers[%s].attention.dense", CONST_INPUT_HOOK),
"head_attention_value_output": ("layers[%s].attention.dense", CONST_INPUT_HOOK),
"head_attention_value_output": ("layers[%s].attention.dense", CONST_INPUT_HOOK, (split_head_and_permute, "n_head")),
"attention_output": ("layers[%s].attention", CONST_OUTPUT_HOOK),
"attention_input": ("layers[%s].attention", CONST_INPUT_HOOK),
# 'query_output': ("layers[%s].attention.query_key_value", CONST_OUTPUT_HOOK),
@@ -33,6 +33,7 @@


gpt_neox_type_to_dimension_mapping = {
"n_head": "num_attention_heads",
"block_input": ("hidden_size",),
"block_output": ("hidden_size",),
"mlp_activation": (
16 changes: 8 additions & 8 deletions pyvene/models/gru/modelings_intervenable_gru.py
@@ -9,7 +9,7 @@
"""


from ..constants import CONST_INPUT_HOOK, CONST_OUTPUT_HOOK
from ..constants import *


"""gru base model"""
@@ -23,12 +23,12 @@
"new_gate_output": ("cells[%s].new_act", CONST_OUTPUT_HOOK),
"x2h_output": ("cells[%s].x2h", CONST_OUTPUT_HOOK),
"h2h_output": ("cells[%s].h2h", CONST_OUTPUT_HOOK),
"reset_x2h_output": ("cells[%s].x2h", CONST_OUTPUT_HOOK),
"update_x2h_output": ("cells[%s].x2h", CONST_OUTPUT_HOOK),
"new_x2h_output": ("cells[%s].x2h", CONST_OUTPUT_HOOK),
"reset_h2h_output": ("cells[%s].h2h", CONST_OUTPUT_HOOK),
"update_h2h_output": ("cells[%s].h2h", CONST_OUTPUT_HOOK),
"new_h2h_output": ("cells[%s].h2h", CONST_OUTPUT_HOOK),
"reset_x2h_output": ("cells[%s].x2h", CONST_OUTPUT_HOOK, (split_three, 0)),
"update_x2h_output": ("cells[%s].x2h", CONST_OUTPUT_HOOK, (split_three, 1)),
"new_x2h_output": ("cells[%s].x2h", CONST_OUTPUT_HOOK, (split_three, 2)),
"reset_h2h_output": ("cells[%s].h2h", CONST_OUTPUT_HOOK, (split_three, 0)),
"update_h2h_output": ("cells[%s].h2h", CONST_OUTPUT_HOOK, (split_three, 1)),
"new_h2h_output": ("cells[%s].h2h", CONST_OUTPUT_HOOK, (split_three, 2)),
"cell_output": ("cells[%s]", CONST_OUTPUT_HOOK),
}

@@ -56,7 +56,7 @@
"""mlp model with classification head"""
gru_classifier_type_to_module_mapping = {}
for k, v in gru_type_to_module_mapping.items():
gru_classifier_type_to_module_mapping[k] = (f"gru.{v[0]}", v[1])
gru_classifier_type_to_module_mapping[k] = (f"gru.{v[0]}", ) + v[1:]

gru_classifier_type_to_dimension_mapping = gru_type_to_dimension_mapping

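The GRU follows the same pattern: x2h and h2h each emit the reset, update, and new-gate pre-activations concatenated along the last dimension, and the (split_three, i) transforms replace the old CONST_RUN_INDICES bookkeeping. A shape-only sketch, assuming a hidden size of 64:

import torch
from pyvene.models.constants import split_three

hidden = 64
x2h_out = torch.randn(8, 3 * hidden)    # reset | update | new, concatenated

reset_pre  = split_three(x2h_out, 0)    # (8, 64)
update_pre = split_three(x2h_out, 1)    # (8, 64)
new_pre    = split_three(x2h_out, 2)    # (8, 64)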