Generic instrumented Keras LM wrapper for LM salience
PiperOrigin-RevId: 606810304
@@ -0,0 +1,375 @@
"""LIT model wrappers for generic instrumented Keras LMs.""" | ||
|
||
import functools | ||
import inspect | ||
import types | ||
from typing import Sequence | ||
|
||
from lit_nlp.api import model as lit_model | ||
from lit_nlp.api import types as lit_types | ||
from lit_nlp.lib import utils as lit_utils | ||
import numpy as np | ||
import tensorflow as tf | ||
|
||
|
||
_DEFAULT_MAX_LENGTH = 1024 | ||
|
||
|
||


class FieldNames(types.SimpleNamespace):
  PROMPT = "prompt"
  RESPONSE = "response"
  PROMPT_EMBEDDINGS = "prompt_embeddings"
  RESPONSE_EMBEDDINGS = "response_embeddings"
  TARGET = "target"
  TOKENS = "tokens"
  TARGET_MASK = "target_mask"
  GRAD_DOT_INPUT = "grad_dot_input"
  GRAD_NORM = "grad_l2"
  TOKEN_LOSS = "token_loss"


class _KerasBaseModel(lit_model.BatchedModel):
  """Base LIT model wrapper class for Keras on TensorFlow."""

  # TODO(lit-dev): pytype annotations for model= ?
  # Should be keras_nlp.models.generative_task.GenerativeTask
  def __init__(
      self,
      model,
      max_length: int = _DEFAULT_MAX_LENGTH,
      batch_size: int = 16,
  ):
    """Base wrapper for a Keras/TF2 LM supporting the layer_intercept_fn API.

    Model should support the following methods:
    - .generate()
    - .score()*
    - .preprocessor.generate_preprocess()
    - .preprocessor.tokenizer.id_to_token()
    - .backbone.token_embedding()

    * The score function should accept layer_intercept_fn= as a way to
      intercept and manipulate activations between layers. We use this for
      salience, below.

    Args:
      model: pre-loaded Keras LM using the TF backend
      max_length: max sequence length
      batch_size: batch size
    """
    super().__init__()

    self.model = model
    self.batch_size = batch_size
    self.max_length = max_length

    self.encode_inputs = self.model.preprocessor.generate_preprocess

    self.ids_to_tokens = np.vectorize(
        self.model.preprocessor.tokenizer.id_to_token
    )

    # map ids: <tf.int64>[batch_size, num_tokens]
    # to embs: <tf.float32>[batch_size, num_tokens, emb_dim]
    self.embedder = self.model.backbone.token_embedding

  @classmethod
  def from_loaded(cls, existing: "_KerasBaseModel", *args, **kw):
    """Share weights and underlying Keras model with another instance."""
    return cls(model=existing.model, *args, **kw)

  def max_minibatch_size(self) -> int:
    return self.batch_size

  @classmethod
  def init_spec(cls):
    # Cannot initialize from spec, because we need a Keras model object.
    return None

  def input_spec(self):
    return {
        FieldNames.PROMPT: lit_types.TextSegment(),
        FieldNames.TARGET: lit_types.TextSegment(required=False),
    }
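

# Sharing sketch (illustrative, not part of this commit): from_loaded() lets
# the wrapper subclasses below share one loaded Keras model instead of loading
# weights twice. The variable names here are hypothetical.
#
#   gen_model = KerasGenerationModel(model=lm)
#   sal_model = KerasSalienceModel.from_loaded(gen_model)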


class KerasGenerationModel(_KerasBaseModel):
  """LIT model wrapper for generating text with Keras on TensorFlow.

  This class accepts a loaded model and provides the LIT-required functions
  plus additional helper functions for generation tasks.

  This class supports generation and pass-through modes. If a dataset provides
  a pre-populated 'response' column then this model will return that text
  instead of generating new text from the 'prompt'. This allows the same model
  wrapper to be efficiently used to examine saved results from bulk-inference
  pipelines and new generations from, e.g., counterfactually generated
  examples, or novel evaluation datasets.
  """

  def __init__(self, *args, output_embeddings=True, **kw):
    super().__init__(*args, **kw)
    self.output_embeddings = output_embeddings

  def embed_texts(self, texts: Sequence[str]):
    processed_inputs = self.encode_inputs(
        texts, sequence_length=self.max_length
    )
    # <tf.float32>[batch_size, num_tokens, emb_dim]
    embs = self.embedder(processed_inputs["token_ids"])
    # <tf.bool>[batch_size, num_tokens]
    mask = processed_inputs["padding_mask"]
    return embs, mask

  def embed_and_mean_pool(self, texts: Sequence[str]):
    """Return a single vector for each text."""
    embs, mask = self.embed_texts(texts)
    # <tf.float32>[batch_size, num_tokens, 1]
    mask = tf.expand_dims(tf.cast(mask, dtype=tf.float32), axis=2)
    # <tf.float32>[batch_size, 1, emb_dim]
    pooled_embs = tf.reduce_sum(
        mask * embs, axis=1, keepdims=True
    ) / tf.reduce_sum(mask, axis=1, keepdims=True)
    # <tf.float32>[batch_size, emb_dim]
    return tf.squeeze(pooled_embs, axis=1)
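
  # Pooling sketch (illustrative values, not part of this commit): the masked
  # mean averages only over non-padding positions.
  #
  #   embs = tf.ones([1, 4, 8])           # 4 tokens, emb_dim=8
  #   mask = tf.constant([[1, 1, 1, 0]])  # last position is padding
  #   # pooled result has shape [1, 8]; each entry is 3/3 = 1.0, since the
  #   # padded position contributes nothing to numerator or denominator.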

  def predict_minibatch(
      self,
      inputs: list[lit_types.JsonDict],
  ) -> list[lit_types.JsonDict]:
    prompts: Sequence[str] = [ex[FieldNames.PROMPT] for ex in inputs]

    # TODO(lit-dev): support loading cached responses here, since running
    # generation can be expensive.
    full_responses: Sequence[str] = list(
        self.model.generate(prompts, max_length=self.max_length)
    )
    # Model outputs include the prompt, so trim that off and just return the
    # generated portion.
    responses: Sequence[str] = [
        response[len(prompt) :]
        for response, prompt in zip(full_responses, prompts)
    ]

    outputs = [{FieldNames.RESPONSE: response} for response in responses]

    if self.output_embeddings:
      prompt_embeddings = self.embed_and_mean_pool(prompts)
      # TODO(lit-dev): embed prompt + response and trim embedding instead?
      # Or just embed full_response.
      response_embeddings = self.embed_and_mean_pool(responses)

      for i in range(len(inputs)):
        outputs[i][FieldNames.PROMPT_EMBEDDINGS] = prompt_embeddings[i].numpy()
        outputs[i][FieldNames.RESPONSE_EMBEDDINGS] = response_embeddings[
            i
        ].numpy()

    return outputs
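
  # Round-trip sketch (illustrative values, not part of this commit):
  #
  #   [out] = wrapper.predict_minibatch([{"prompt": "The capital of France is"}])
  #   # out["response"] holds only the newly generated continuation; with
  #   # output_embeddings=True, out["prompt_embeddings"] and
  #   # out["response_embeddings"] are <np.float32>[emb_dim] pooled vectors.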

  def output_spec(self) -> lit_types.Spec:
    ret = {
        FieldNames.RESPONSE: lit_types.GeneratedText(parent=FieldNames.TARGET)
    }
    if self.output_embeddings:
      return ret | {
          FieldNames.PROMPT_EMBEDDINGS: lit_types.Embeddings(),
          FieldNames.RESPONSE_EMBEDDINGS: lit_types.Embeddings(),
      }
    return ret


class KerasSalienceModel(_KerasBaseModel):
  """LIT model wrapper for computing salience with Keras on TensorFlow.

  This class accepts a loaded model and provides the LIT-required functions
  plus additional helper functions to convert and clean tokens and to compute
  sequence salience.

  This class does not support generation; use the KerasGenerationModel class
  to generate the text for which this class will compute salience.
  """

  def __init__(self, *args, **kw):
    super().__init__(*args, **kw)

    score_fn = getattr(self.model, "score", None)

    if score_fn is None or not inspect.ismethod(score_fn):
      raise TypeError(
          "Salience is computed via a .score() API, which is not supported by "
          "all GenerativeTask models in KerasNLP. Please provide a model that "
          "supports this API."
      )

  def _pred(self, input_ids, padding_mask, target_masks):
    """Predict a batch of tokenized text."""
    # <tf.float32>[batch_size, num_tokens]; ignore the last one in each row.
    target_ids = tf.roll(input_ids, shift=-1, axis=1)

    ##
    # Process target masks.

    # It doesn't make sense to interpret the first token, since it is not ever
    # predicted. But we need to ensure that mask[0] is zero, so it doesn't
    # cause problems when 'rolled' to the last position below.
    modified_masks = [[0] + list(mask[1:]) for mask in target_masks]
    seq_len = target_ids.shape[1]
    pad_fn = functools.partial(
        lit_utils.pad1d,
        min_len=seq_len,
        max_len=seq_len,
        pad_val=0,
        pad_left=False,
    )
    padded_target_masks = np.stack(
        [pad_fn(mask) for mask in modified_masks],
        axis=0,
    )

    padded_target_masks = tf.constant(padded_target_masks, dtype=tf.float32)
    # Shift masks back so they align with target_ids.
    loss_mask = tf.roll(padded_target_masks, shift=-1, axis=1)

    embeddings = None

    with tf.GradientTape(watch_accessed_variables=True) as tape:

      def layer_intercept_fn(x, i):
        if i == -1:
          nonlocal embeddings, tape
          embeddings = x
          tape.watch(embeddings)
        return x

      # <tf.float32>[batch_size, num_tokens]
      per_token_loss = self.model.score(
          token_ids=input_ids,
          padding_mask=padding_mask,
          scoring_mode="loss",
          layer_intercept_fn=layer_intercept_fn,
          target_ids=target_ids,
      )
      masked_loss = per_token_loss * loss_mask

    # <tf.float32>[batch_size, num_tokens, hdim]
    grads = tape.gradient(masked_loss, embeddings)
    # <tf.float32>[batch_size, num_tokens]
    grad_l2 = tf.norm(grads, axis=2)
    # <tf.float32>[batch_size, num_tokens]
    grad_dot_input = tf.reduce_sum(grads * embeddings, axis=2)

    batched_outputs = {
        "input_ids": input_ids,
        "padding_mask": padding_mask,
        # Gradients are already aligned to input tokens.
        FieldNames.GRAD_NORM: grad_l2,
        FieldNames.GRAD_DOT_INPUT: grad_dot_input,
        # Shift token loss to align with (input) tokens.
        FieldNames.TOKEN_LOSS: tf.roll(per_token_loss, shift=1, axis=1),
    }

    return batched_outputs
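
  # Alignment sketch (illustrative, not part of this commit): tf.roll with
  # shift=-1 moves each row left by one, wrapping position 0 to the end, so
  # targets line up with the inputs that predict them:
  #
  #   ids = tf.constant([[1, 2, 3, 4]])
  #   tf.roll(ids, shift=-1, axis=1)  # -> [[2, 3, 4, 1]]
  #
  # per_token_loss[t] is thus the loss for predicting token t + 1; the final
  # roll with shift=1 above moves each loss back onto the input token it
  # scores.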

  def _postprocess(self, preds):
    """Post-process single-example preds. Operates on numpy arrays."""
    mask = preds.pop("padding_mask").astype(bool)
    ids = preds.pop("input_ids")[mask]
    preds[FieldNames.TOKENS] = self.ids_to_tokens(ids)
    for key in lit_utils.find_spec_keys(
        self.output_spec(), lit_types.TokenScores
    ):
      preds[key] = preds[key][mask]
    # First token (<bos>) is not actually predicted, so return 0 for loss.
    preds[FieldNames.TOKEN_LOSS][0] = 0

    return preds

  def predict_minibatch(self, inputs):
    """Predict on a single minibatch of examples."""
    texts: Sequence[str] = [
        ex[FieldNames.PROMPT] + ex.get(FieldNames.TARGET, "") for ex in inputs
    ]
    preprocessed_texts = self.encode_inputs(
        texts, sequence_length=self.max_length
    )
    sequence_ids = preprocessed_texts["token_ids"]
    padding_mask = preprocessed_texts["padding_mask"]

    target_masks = [ex.get(FieldNames.TARGET_MASK, []) for ex in inputs]

    # Get the predictions.
    batched_outputs = self._pred(sequence_ids, padding_mask, target_masks)
    # Convert to numpy for post-processing.
    detached_outputs = {k: v.numpy() for k, v in batched_outputs.items()}
    # Split up batched outputs, then post-process each example.
    unbatched_outputs = lit_utils.unbatch_preds(detached_outputs)
    return map(self._postprocess, unbatched_outputs)

  def input_spec(self):
    return super().input_spec() | {
        FieldNames.TARGET_MASK: lit_types.TokenScores(align="", required=False),
    }

  def output_spec(self) -> lit_types.Spec:
    return {
        FieldNames.TOKENS: lit_types.Tokens(parent=""),  # All tokens.
        FieldNames.GRAD_DOT_INPUT: lit_types.TokenScores(
            align=FieldNames.TOKENS
        ),
        FieldNames.GRAD_NORM: lit_types.TokenScores(align=FieldNames.TOKENS),
        FieldNames.TOKEN_LOSS: lit_types.TokenScores(align=FieldNames.TOKENS),
    }
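
  # Output sketch (illustrative, not part of this commit): each returned
  # example aligns per-token scores to the token list, e.g. for
  # {"prompt": "2 + 2 = ", "target": "4"}:
  #
  #   {"tokens": [...],          # all tokens, padding removed
  #    "grad_dot_input": [...],  # one score per token
  #    "grad_l2": [...],
  #    "token_loss": [...]}      # token_loss[0] is always 0 (<bos>)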


class KerasTokenizerModel(_KerasBaseModel):
  """LIT model wrapper for tokenizing text with Keras on TensorFlow.

  This class accepts a loaded model and provides the LIT-required functions
  plus additional helper functions to convert and clean tokens.
  """

  def _postprocess(self, preds):
    """Post-process single-example preds. Operates on numpy arrays."""
    # Be sure to cast to bool, otherwise this will select integer positions
    # 0 and 1 rather than acting as a boolean mask.
    mask = preds.pop("padding_mask").astype(bool)
    ids = preds.pop("token_ids")[mask]
    preds[FieldNames.TOKENS] = self.ids_to_tokens(ids)
    return preds

  def predict_minibatch(self, inputs):
    """Tokenize a single minibatch of examples."""
    texts: Sequence[str] = [
        ex[FieldNames.PROMPT] + ex.get(FieldNames.TARGET, "") for ex in inputs
    ]
    preprocessed_texts = self.encode_inputs(
        texts, sequence_length=self.max_length
    )
    batched_outputs = {
        "token_ids": preprocessed_texts["token_ids"],
        "padding_mask": preprocessed_texts["padding_mask"],
    }
    # Convert to numpy for post-processing.
    detached_outputs = {k: v.numpy() for k, v in batched_outputs.items()}
    # Split up batched outputs, then post-process each example.
    unbatched_outputs = lit_utils.unbatch_preds(detached_outputs)
    return map(self._postprocess, unbatched_outputs)

  def output_spec(self) -> lit_types.Spec:
    return {
        FieldNames.TOKENS: lit_types.Tokens(parent=""),  # All tokens.
    }


def initialize_model_group_for_salience(
    name, *args, **kw
) -> dict[str, lit_model.Model]:
  """Creates '{name}' and '_{name}_salience' and '_{name}_tokenizer'."""
  generation_model = KerasGenerationModel(*args, **kw)
  salience_model = KerasSalienceModel(*args, **kw)
  tokenizer_model = KerasTokenizerModel(*args, **kw)
  return {
      name: generation_model,
      f"_{name}_salience": salience_model,
      f"_{name}_tokenizer": tokenizer_model,
  }
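

# Usage sketch (hypothetical, not part of this commit): wiring the model group
# into a LIT server. The Gemma preset name and the datasets dict are
# assumptions for illustration; any KerasNLP causal LM on the TF backend whose
# task class implements .score() with layer_intercept_fn= should work.
#
#   import keras_nlp
#   from lit_nlp import dev_server
#
#   lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
#   models = initialize_model_group_for_salience("gemma", model=lm)
#   lit_demo = dev_server.Server(models, datasets={...})
#   lit_demo.serve()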