Commit
Merge pull request #32 from language-brainscore/feature-tokenizer-based-span-extraction-1

tokenizer based span extraction 1,
fixes #19
aalok-sathe authored May 20, 2022
2 parents 1f9c481 + b59693e commit 97bb72d
Showing 5 changed files with 166 additions and 62 deletions.
4 changes: 3 additions & 1 deletion examples/test_mean_froi_pereira2018_firstsessions.py
@@ -86,7 +86,9 @@ def main():
# Initialize brain and ANN encoders
brain_enc = lbs.encoder.BrainEncoder()
ann_enc = lbs.encoder.HuggingFaceEncoder(
model_id="distilgpt2", emb_preproc=tuple(), context_dimension="passage"
model_id="bert-base-uncased", emb_preproc=tuple(), context_dimension="passage",
bidirectional=True,
# model_id="distilgpt2", emb_preproc=tuple(), context_dimension="passage"
)

# Encode
132 changes: 90 additions & 42 deletions langbrainscore/encoder/ann.py
@@ -1,24 +1,26 @@
import typing
from enum import unique

import numpy as np
import torch
from tqdm import tqdm
import xarray as xr
from nltk import edit_distance

from langbrainscore.dataset import Dataset
from langbrainscore.interface import _ModelEncoder, EncoderRepresentations
from langbrainscore.interface import EncoderRepresentations, _ModelEncoder
from langbrainscore.utils.encoder import (
set_case,
aggregate_layers,
cos_sim_matrix,
count_zero_threshold_values,
flatten_activations_per_sample,
repackage_flattened_activations,
get_context_groups,
get_torch_device,
pick_matching_token_ixs,
preprocess_activations,
count_zero_threshold_values,
cos_sim_matrix,
get_index,
repackage_flattened_activations,
)
from langbrainscore.utils.logging import log
from langbrainscore.utils.xarray import copy_metadata
from tqdm import tqdm


class HuggingFaceEncoder(_ModelEncoder):
@@ -30,6 +32,7 @@ def __init__(
bidirectional: bool = False,
emb_aggregation: typing.Union[str, None, typing.Callable] = "last",
emb_preproc: typing.Tuple[str] = (),
include_special_tokens: bool = True,
) -> "HuggingFaceEncoder":

super().__init__(
@@ -38,6 +41,7 @@ def __init__(
_bidirectional=bidirectional,
_emb_aggregation=emb_aggregation,
_emb_preproc=emb_preproc,
_include_special_tokens=include_special_tokens,
)

from transformers import AutoConfig, AutoModel, AutoTokenizer
@@ -90,15 +94,25 @@ def encode(
to_check_in_cache = EncoderRepresentations(
dataset=dataset,
representations=None, # we don't have these yet
model_id=self._model_id,
context_dimension=self._context_dimension,
bidirectional=self._bidirectional,
emb_aggregation=self._emb_aggregation,
emb_preproc=self._emb_preproc,
)
try:
to_check_in_cache.load_cache()
return to_check_in_cache
except FileNotFoundError:
log(
f'unable to load cached representations for "{to_check_in_cache.identifier_string}"',
cmap="WARN",
type="WARN",
)

self.model.eval()
stimuli = dataset.stimuli.values

# Initialize the context group coordinate (obtain embeddings with context)
context_groups = get_context_groups(dataset, self._context_dimension)

@@ -113,9 +127,9 @@
_, unique_ixs = np.unique(context_groups, return_index=True)
# Make sure context group order is preserved
for group in tqdm(context_groups[np.sort(unique_ixs)]):
# Mask based on the context group
mask_context = context_groups == group
stimuli_in_context = stimuli[mask_context]
# Mask based on the context group

# store model states for each stimulus in this context group
states_sentences_across_stimuli = []
@@ -131,8 +145,11 @@
else:
stimuli_directional = stimuli_in_context

stimuli_directional = " ".join( # join the stimuli together within a context group
map(lambda x: x.strip(), stimuli_directional) # Strip out odd spaces between stimuli (but *not* within the stimuli).
# join the stimuli together within a context group
stimuli_directional = " ".join(
map(
lambda x: x.strip(), stimuli_directional
) # Strip out odd spaces between stimuli (but *not* within the stimuli).
)

tokenized_directional_context = self.tokenizer(
@@ -153,32 +170,58 @@
hidden_states = result_model["hidden_states"]

layer_wise_activations = dict()
# Cut the "irrelevant" context from the hidden states

start_of_interest = stimuli_directional.find(stimulus.strip())
char_span_of_interest = slice(
start_of_interest, start_of_interest + len(stimulus.strip())
)
token_span_of_interest = pick_matching_token_ixs(
tokenized_directional_context, char_span_of_interest
)

all_special_ids = set(self.tokenizer.all_special_ids)
insert_first_upto = 0
insert_last_from = tokenized_directional_context.input_ids.shape[-1]
for i, tid in enumerate(tokenized_directional_context.input_ids[0, :]):
if tid.item() in all_special_ids:
insert_first_upto = i + 1
else:
break
for i in range(
1, tokenized_directional_context.input_ids.shape[-1] + 1
):
tid = tokenized_directional_context.input_ids[0, -i]
if tid.item() in all_special_ids:
insert_last_from -= 1
else:
break

for idx_layer, layer in enumerate(hidden_states): # Iterate over layers
layer_wise_activations[idx_layer] = layer[
# batch (singleton)
:,
# n_tokens
slice(
get_index(
self.tokenizer,
tokenized_directional_context.input_ids,
stimulus,
mode="start",
),
get_index(
self.tokenizer,
tokenized_directional_context.input_ids,
stimulus,
mode="stop",
),
this_extracted = layer[
:, # batch (singleton)
token_span_of_interest, # if self._context_dimension is not None else slice(None), # n_tokens
:, # emb_dim (e.g., 768, 1024, etc)
].squeeze(
0
) # collapse batch dim to obtain shape (n_tokens, emb_dim)

if self._include_special_tokens:
this_extracted = torch.cat(
[
layer[:, :insert_first_upto, :].squeeze(0),
this_extracted,
],
axis=0,
)
this_extracted = torch.cat(
[
this_extracted,
layer[:, insert_last_from:, :].squeeze(0),
],
axis=0,
)
if self._context_dimension is not None
else slice(None),
# emb_dim (e.g., 768)
:,
].squeeze() # collapse batch dim to obtain shape (n_tokens, emb_dim)
# ^ do we have to .detach() tensors here?

layer_wise_activations[idx_layer] = this_extracted.detach()

# Aggregate hidden states within a sample
# states_sentences_agg is a dict with key = layer, value = array of emb dimension
@@ -235,7 +278,7 @@ def encode(
"sampleid",
)

return EncoderRepresentations(
to_return = EncoderRepresentations(
dataset=dataset,
representations=encoded_dataset,
context_dimension=self._context_dimension,
@@ -244,6 +287,11 @@
emb_preproc=self._emb_preproc,
)

# if write_cache:
# to_return.to_cache(overwrite=True)

return to_return

def get_special_token_offset(self) -> int:
"""
the offset (no. of tokens in tokenized text) from the start to exclude
@@ -275,7 +323,7 @@ def get_modelcard(self):
"vocab_size",
]

config_specs = {k: d_config[k] for k in config_specs_of_interest}
config_specs = {k: d_config.get(k, None) for k in config_specs_of_interest}

# Evaluate each layer

@@ -292,7 +340,7 @@ def get_explainable_variance(
TODO: move to `langbrainscore.analysis.?` or make @classmethod
"""
n_embd = self.config.n_embd
# n_embd = self.config.n_embd

# Get the PCA explained variance per layer
layer_ids = ann_encoded_dataset.layer.values
@@ -306,7 +354,7 @@
.drop("timeid")
.squeeze()
)
assert layer_dataset.shape[1] == n_embd
# assert layer_dataset.shape[1] == n_embd

# Figure out how many PCs we attempt to fit
n_comp = np.min([layer_dataset.shape[1], layer_dataset.shape[0]])
@@ -350,7 +398,7 @@ def get_layer_sparsity(
TODO: move to `langbrainscore.analysis.?` or make @classmethod
"""
n_embd = self.config.n_embd
# n_embd = self.config.n_embd

# Get the PCA explained variance per layer
layer_ids = ann_encoded_dataset.layer.values
@@ -364,7 +412,7 @@
.drop("timeid")
.squeeze()
)
assert layer_dataset.shape[1] == n_embd
# assert layer_dataset.shape[1] == n_embd

# Get sparsity
zero_values = count_zero_threshold_values(
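Note (illustrative, not part of the diff): the new encode() path slices out the tokens that match the stimulus and, when include_special_tokens is set, re-attaches any leading/trailing special tokens found by the two boundary scans above. A minimal sketch of that scan and re-attachment, assuming a BERT-style tokenizer and a single hidden-state layer:

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased", output_hidden_states=True)

enc = tokenizer("the cat sat on the mat", return_tensors="pt")
with torch.no_grad():
    layer = model(**enc).hidden_states[-1]  # shape (1, n_tokens, emb_dim)

all_special_ids = set(tokenizer.all_special_ids)
n_tokens = enc.input_ids.shape[-1]

# walk forward past leading special tokens (e.g., [CLS])
insert_first_upto = 0
for i, tid in enumerate(enc.input_ids[0]):
    if tid.item() in all_special_ids:
        insert_first_upto = i + 1
    else:
        break

# walk backward past trailing special tokens (e.g., [SEP])
insert_last_from = n_tokens
for i in range(1, n_tokens + 1):
    if enc.input_ids[0, -i].item() in all_special_ids:
        insert_last_from -= 1
    else:
        break

# here the whole non-special span stands in for the span of interest
token_span_of_interest = slice(insert_first_upto, insert_last_from)
this_extracted = layer[:, token_span_of_interest, :].squeeze(0)

# re-attach the special-token rows, as encode() does when include_special_tokens=True
this_extracted = torch.cat([layer[:, :insert_first_upto, :].squeeze(0), this_extracted], dim=0)
this_extracted = torch.cat([this_extracted, layer[:, insert_last_from:, :].squeeze(0)], dim=0)
assert this_extracted.shape[0] == n_tokens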
3 changes: 2 additions & 1 deletion langbrainscore/interface/encoder.py
@@ -73,9 +73,10 @@ class EncoderRepresentations(_Cacheable):
dataset: Dataset # pointer to the dataset these are the EncodedRepresentations of
representations: xr.DataArray # the xarray holding representations

model_id: str = None
context_dimension: str = None
bidirectional: bool = False
emb_case: typing.Union[str, None] = "lower"
# emb_case: typing.Union[str, None] = "lower"
emb_aggregation: typing.Union[str, None, typing.Callable] = "last"
emb_preproc: typing.Tuple[str] = ()

83 changes: 65 additions & 18 deletions langbrainscore/utils/encoder.py
@@ -3,9 +3,10 @@
import numpy as np
import torch
import xarray as xr
from nltk import edit_distance
# from nltk import edit_distance

from langbrainscore.utils.resources import preprocessor_classes
from langbrainscore.utils.logging import log, get_verbosity


def count_zero_threshold_values(
@@ -60,6 +61,8 @@ def aggregate_layers(hidden_states: dict, **kwargs):
for i in hidden_states.keys():
if emb_aggregation == "last":
state = hidden_states[i][-1, :] # get last token
elif emb_aggregation == "first":
state = hidden_states[i][0, :] # get first token
elif emb_aggregation == "mean":
state = torch.mean(hidden_states[i], dim=0) # mean over tokens
elif emb_aggregation == "median":
@@ -158,23 +161,67 @@ def repackage_flattened_activations(

def cos_sim_matrix(A, B):
"""Compute the cosine similarity matrix between two matrices A and B.
1 means the two vectors are identical. 0 means they are orthogonal. -1 means they are opposite."""
1 means the two vectors are identical. 0 means they are orthogonal.
-1 means they are opposite."""
return (A * B).sum(axis=1) / (A * A).sum(axis=1) ** 0.5 / (B * B).sum(axis=1) ** 0.5


def get_index(tokenizer, supstr_tokens, substr, mode):
supstr_tokens = list(supstr_tokens.squeeze())
assert mode in ["start", "stop"]
edit_distances = []
for idx in range(len(supstr_tokens) + 1):
if mode == "start":
candidate_tokens = supstr_tokens[idx:]
else:
candidate_tokens = supstr_tokens[:idx]
candidate = tokenizer.decode(candidate_tokens)
if mode == "start":
comp = candidate[: len(substr)]
else:
comp = candidate[-len(substr) :]
edit_distances.append(edit_distance(comp, substr))
return np.argmin(edit_distances)

def pick_matching_token_ixs(batchencoding: 'transformers.tokenization_utils_base.BatchEncoding',
char_span_of_interest: slice) -> slice:
"""Picks token indices in a tokenized encoded sequence that best correspond to
a substring of interest in the original sequence, given by a char span (slice)
Args:
batchencoding (transformers.tokenization_utils_base.BatchEncoding): the output of a
`tokenizer(text)` call on a single text instance (not a batch, i.e. `tokenizer([text])`).
char_span_of_interest (slice): a `slice` object denoting the character indices in the
original `text` string we want to extract the corresponding tokens for
Returns:
slice: the start and stop indices within an encoded sequence that
best match the `char_span_of_interest`
"""
from transformers import tokenization_utils_base

start_token = 0
end_token = batchencoding.input_ids.shape[-1]
for i, _ in enumerate(batchencoding.input_ids.reshape(-1)):
span = batchencoding[0].token_to_chars(i)

if span is None: # for [CLS], no span is returned
if get_verbosity():
log(f'No span returned for token at {i}: "{batchencoding.tokens()[i]}"',
type='WARN', cmap='WARN')
continue
else:
span = tokenization_utils_base.CharSpan(*span)

if span.start <= char_span_of_interest.start:
start_token = i
if span.end >= char_span_of_interest.stop:
end_token = i+1
break

assert end_token-start_token <= batchencoding.input_ids.shape[-1], f'Extracted span is larger than original span'

return slice(start_token, end_token)



# def get_index(tokenizer, supstr_tokens, substr, mode):
# supstr_tokens = list(supstr_tokens.squeeze())
# assert mode in ["start", "stop"]
# edit_distances = []
# for idx in range(len(supstr_tokens) + 1):
# if mode == "start":
# candidate_tokens = supstr_tokens[idx:]
# else:
# candidate_tokens = supstr_tokens[:idx]
# candidate = tokenizer.decode(candidate_tokens)
# if mode == "start":
# comp = candidate[: len(substr)]
# else:
# comp = candidate[-len(substr) :]
# edit_distances.append(edit_distance(comp, substr))
# return np.argmin(edit_distances)
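
Illustrative usage of the new helper (not part of the diff), assuming a fast Hugging Face tokenizer; this mirrors how encode() locates a stimulus inside its joined context by character offsets and converts that to a token slice:

from transformers import AutoTokenizer
from langbrainscore.utils.encoder import pick_matching_token_ixs

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

context = "The dog barked. The cat sat on the mat."
stimulus = "The cat sat on the mat."

encoded = tokenizer(context, return_tensors="pt")

# locate the stimulus within the joined context by character position
start = context.find(stimulus)
char_span = slice(start, start + len(stimulus))

token_span = pick_matching_token_ixs(encoded, char_span)
print(tokenizer.decode(encoded.input_ids[0, token_span]))
# expected, roughly: "the cat sat on the mat."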
6 changes: 6 additions & 0 deletions langbrainscore/utils/logging.py
@@ -58,3 +58,9 @@ class T:
)
tqdm.write("\n".join(lines), file=stderr)
# print(*lines, sep='\n', file=stderr)


def get_verbosity():
'''returns True if env variable "VERBOSE" is set to 1'''
import os
return os.environ.get('VERBOSE', None) == '1'
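
A quick way to exercise the new toggle (illustrative, assuming the package is importable):

import os
from langbrainscore.utils.logging import get_verbosity

os.environ["VERBOSE"] = "1"
assert get_verbosity() is True   # enabled

os.environ["VERBOSE"] = "0"
assert get_verbosity() is False  # anything other than "1" counts as off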
