Fix TF s2s models #9478

Merged (22 commits), Jan 21, 2021
src/transformers/modeling_tf_utils.py (7 changes: 5 additions & 2 deletions)
@@ -322,6 +322,7 @@ def input_processing(func, config, input_ids, **kwargs):
"""
signature = dict(inspect.signature(func).parameters)
signature.pop("kwargs", None)
signature.pop("self", None)
parameter_names = list(signature.keys())
output = {}
allowed_types = (tf.Tensor, bool, int, ModelOutput, tuple, list, dict, np.ndarray)
@@ -346,6 +347,8 @@ def input_processing(func, config, input_ids, **kwargs):
f"The following keyword arguments are not supported by this model: {list(kwargs['kwargs_call'].keys())}."
)

kwargs.pop("kwargs_call")

for k, v in kwargs.items():
if isinstance(v, allowed_types) or v is None:
output[k] = v
@@ -356,8 +359,8 @@ def input_processing(func, config, input_ids, **kwargs):
for i, input in enumerate(input_ids):
# EagerTensors don't allow to use the .name property so we check for a real Tensor
if type(input) == tf.Tensor:
# Tensor names have always the pattern name:device_id then we check only the
# name and not the device id
# Tensor names have always the pattern `name:id` then we check only the
# `name` part
tensor_name = input.name.split(":")[0]

if tensor_name in parameter_names:
src/transformers/models/bart/modeling_tf_bart.py (166 changes: 114 additions & 52 deletions)
@@ -411,29 +411,6 @@ def dummy_inputs(self):
}
return dummy_inputs

def get_input_embeddings(self):
base_model = getattr(self, self.base_model_prefix, self)

return base_model.shared

def set_input_embeddings(self, value):
base_model = getattr(self, self.base_model_prefix, self)

try:
base_model.shared.weight = value
except AttributeError:
self(self.dummy_inputs)
base_model.shared.weight = value

base_model.shared.vocab_size = shape_list(base_model.shared.weight)[0]

with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
pass

embed_tokens = TFWrappedEmbeddings(base_model.shared, abs_scope_name=shared_abs_scope_name)
base_model.encoder.set_embed_tokens(embed_tokens)
base_model.decoder.set_embed_tokens(embed_tokens)

@tf.function(
input_signature=[
{
@@ -605,6 +582,9 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings
self.layers = [TFBartEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")

def get_embed_tokens(self):
return self.embed_tokens

def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens

@@ -744,6 +724,9 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings

self.dropout = tf.keras.layers.Dropout(config.dropout)

def get_embed_tokens(self):
return self.embed_tokens

def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens

@@ -871,13 +854,15 @@ def call(
hidden_states = self.dropout(hidden_states, training=inputs["training"])

# decoder layers
all_hidden_states = ()
all_self_attns = ()
present_key_values = ()
all_hidden_states = () if inputs["output_hidden_states"] else None
all_self_attns = () if inputs["output_attentions"] else None
present_key_values = () if inputs["use_cache"] else None

for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,)

dropout_probability = random.uniform(0, 1)

if inputs["training"] and (dropout_probability < self.layerdrop):
@@ -901,12 +886,12 @@

if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,)
else:
all_hidden_states = None

all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None
if inputs["output_attentions"]:
all_self_attns = list(all_self_attns)

present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None
if inputs["use_cache"]:
present_key_values = (inputs["encoder_hidden_states"], present_key_values)

if not inputs["return_dict"]:
return hidden_states, present_key_values, all_hidden_states, all_self_attns
@@ -919,16 +904,14 @@
)


@add_start_docstrings(
"The bare BART Model outputting raw hidden-states without any specific head on top.",
BART_START_DOCSTRING,
)
@keras_serializable
class TFBartModel(TFBartPretrainedModel):
base_model_prefix = "model"
class TFBartMainLayer(tf.keras.layers.Layer):
config_class = BartConfig

def __init__(self, config: BartConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
def __init__(self, config: BartConfig, **kwargs):
super().__init__(**kwargs)

self.config = config
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")

with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
@@ -942,19 +925,20 @@ def __init__(self, config: BartConfig, *inputs, **kwargs):
self.encoder = TFBartEncoder(config, embed_tokens, name="encoder")
self.decoder = TFBartDecoder(config, embed_tokens, name="decoder")

def get_encoder(self):
return self.encoder
def get_input_embeddings(self):
return self.shared

def get_decoder(self):
return self.decoder
def set_input_embeddings(self, new_embeddings):
self.shared.weight = new_embeddings
self.shared.vocab_size = self.shared.weight.shape[0]
# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
pass
# Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
self.encoder.set_embed_tokens(embed_tokens)
self.decoder.set_embed_tokens(embed_tokens)

@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="facebook/bart-large",
output_type=TFSeq2SeqModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(
self,
input_ids=None,
@@ -1053,8 +1037,86 @@ def call(
encoder_attentions=inputs["encoder_outputs"].attentions,
)


@add_start_docstrings(
"The bare BART Model outputting raw hidden-states without any specific head on top.",
BART_START_DOCSTRING,
)
class TFBartModel(TFBartPretrainedModel):
def __init__(self, config: BartConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)

self.model = TFBartMainLayer(config, name="model")

def get_encoder(self):
return self.model.encoder

def get_decoder(self):
return self.model.decoder

@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="facebook/bart-large",
output_type=TFSeq2SeqModelOutput,
config_class=_CONFIG_FOR_DOC,
)
def call(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs: Optional[Union[Tuple, TFBaseModelOutput]] = None,
past_key_values=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs
):
inputs = input_processing(
func=self.call,
config=self.config,
input_ids=input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
encoder_outputs=encoder_outputs,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
kwargs_call=kwargs,
)

outputs = self.model(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
decoder_input_ids=inputs["decoder_input_ids"],
decoder_attention_mask=inputs["decoder_attention_mask"],
encoder_outputs=inputs["encoder_outputs"],
past_key_values=inputs["past_key_values"],
inputs_embeds=inputs["inputs_embeds"],
decoder_inputs_embeds=inputs["decoder_inputs_embeds"],
use_cache=inputs["use_cache"],
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
training=inputs["training"],
)

return outputs

def serving_output(self, output):
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,)
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
Comment on lines -1057 to +1119

Member:
Isn't this a breaking change with version v4.2.x?

Member:
Maybe that's more of a bugfix? Should we release a patch before the previous version gets too much usage?

@jplu (Contributor, Author), Jan 21, 2021:
This is a bug that all the Seq2Seq models have. I don't know how it was introduced, but it was not part of my PR about serving. Basically, it prevents the model from being exported as a SavedModel, because all the output values must be tensors, not something else (here, a tuple).
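
For illustration, a minimal standalone sketch of the issue described in the comment above (the dummy cache tensors and variable names here are made up, not taken from the model code): the trailing comma in the old line wraps the value in a Python tuple, which a SavedModel serving signature cannot return, while the fixed line returns the tensor itself.

import tensorflow as tf

# Stand-in for a decoder cache; the real past_key_values structure is more nested.
past_key_values = (tf.zeros((1, 4)), tf.zeros((1, 4)))
use_cache = True

# Old line: the trailing comma creates a 1-element Python tuple, not a tensor.
pkv_old = (tf.tuple(past_key_values)[1] if use_cache else None,)
print(type(pkv_old))  # <class 'tuple'>

# Fixed line: returns the tensor itself (or None when caching is disabled).
pkv_new = tf.tuple(past_key_values)[1] if use_cache else None
print(type(pkv_new))  # an eager tf.Tensor
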

dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
@@ -1083,7 +1145,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel):

def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.model = TFBartModel(config, name="model")
self.model = TFBartMainLayer(config, name="model")
self.use_cache = config.use_cache
# final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
self.final_logits_bias = self.add_weight(
@@ -1199,7 +1261,7 @@ def call(
)

def serving_output(self, output):
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,)
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None