
Fix TF s2s models #9478

Merged: 22 commits merged into huggingface:master on Jan 21, 2021
Conversation

@jplu (Contributor) commented on Jan 8, 2021:

What does this PR do?

This PR fixes the Seq2Seq models so that they can be served through TF Serving. The problem was reported by @patrickvonplaten in #9313. Serving failed because we use a model as a layer inside the TFXXXForConditionalGeneration models: when building a graph, TensorFlow's tracing mechanism calls all the layers one by one, and to know which inputs each layer needs, it checks whether the layer has a custom input signature. If not, it falls back to a default signature in which only the first argument is mandatory. Here lies the problem: the Seq2Seq models need two mandatory arguments (input_ids and decoder_input_ids, or inputs_embeds and decoder_inputs_embeds), so the tracing fails.

The fix is to manually set the expected input signature of the base model when instantiating it in __init__. To stay harmonized with the required serving signature, the same signature is used.
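For context, a minimal self-contained sketch of the shape of the fix (the class and layer names below are made up for illustration and the exact TensorSpec layout is an assumption; only the _set_save_spec(inputs=self.serving.input_signature) line is taken from the actual change, and _set_save_spec is a private Keras API):

import tensorflow as tf


class InnerSeq2Seq(tf.keras.Model):
    # Stands in for the base model used as a layer; it needs both input_ids
    # and decoder_input_ids, which breaks the default tracing assumption that
    # only the first argument is mandatory.
    def call(self, inputs):
        return inputs["input_ids"] + inputs["decoder_input_ids"]


class ForConditionalGeneration(tf.keras.Model):
    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)
        self.model = InnerSeq2Seq(name="model")
        # The fix: align the inner model's save spec with the serving signature
        # so the tracer knows both tensors are required.
        self.model._set_save_spec(inputs=self.serving.input_signature)

    def call(self, inputs):
        return self.model(inputs)

    @tf.function(
        input_signature=[
            {
                "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
                "decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
            }
        ]
    )
    def serving(self, inputs):
        return self.call(inputs)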

@jplu changed the title from "Fix tf s2s" to "Fix TF s2s models" on Jan 8, 2021
Comment on lines 1674 to 1685
def log():
    logger.info(
        "Input ids are automatically padded from {} to {} to be a multiple of `config.attention_window`: {}".format(
            seq_len, seq_len + padding_len, attention_window
        )
    )
    return 0

paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]])

def not_log():
    return 1

if input_ids is not None:
    input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id)

_ = tf.cond(tf.math.greater(padding_len, 0), log, not_log)
jplu (Contributor Author):

I'm not really a fan of this... The problem here was that a tensor was being used as a Python bool, which is not allowed in graph mode. So I did this ugly hack to fake a condition just so the logging message gets displayed.

Contributor:

Same here, can we improve it? I'd prefer to always display a log message; for the other case it could simply say:

"Input ids are a multiple of config.attention_window. No padding required"

@jplu (Contributor Author), Jan 8, 2021:

That doesn't change the problem. We are not allowed to do:

if padding_len > 0:
    ...
else:
    ...

Because padding_len > 0 becomes a tensor, and a tensor cannot be used as a Python boolean :(

@@ -2494,6 +2487,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
def __init__(self, config, *inputs, **kwargs):
    super().__init__(config, *inputs, **kwargs)
    self.model = TF{{cookiecutter.camelcase_modelname}}Model(config, name="model")
    self.model._set_save_spec(inputs=self.serving.input_signature)
jplu (Contributor Author):

Here is the single-line fix :)

Contributor:

that's totally fine with me

@jplu (Contributor Author) commented on Jan 8, 2021:

@patrickvonplaten Should I remove the following hack in BART?

if inputs["decoder_input_ids"] is None and inputs["input_ids"] is not None:
            inputs["decoder_input_ids"] = shift_tokens_right(
                inputs["input_ids"], self.config.pad_token_id, self.config.eos_token_id
            )

tensor_name = input.name.split(":")[0]
# Tensor names have always the pattern `tensor/name:id` then we check only the
# `name` part
tensor_name = input.name.split("/")[-1].split(":")[0]
Contributor:

This gives a very different output from what it did before, no?
What changed? Was it incorrect previously?

jplu (Contributor Author):

Oops, good catch! This was for testing something else; I will remove it!

combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
else:
combined_attention_mask = _expand_mask(
combined_attention_mask = tf.cond(
Contributor:

Let's solve this exactly how we solved it for TFBart. TFLED should use the exact same code as TFBart here; I think the Bart way is much cleaner.

jplu (Contributor Author):

Same reason: input_shape[-1] > 1 is a tensor, and we are not allowed to use a tensor as a Python boolean :(

Contributor:

So it doesn't work in TFBart either?

jplu (Contributor Author):

It works in BART because we don't ask AutoGraph to go through the whole model with @tf.function. Here, however, self.led._set_save_spec(inputs=self.serving.input_signature) has the side effect of making the TF tracer treat the model as if there were a tf.function(TFLEDModel.call). In that case AutoGraph runs and then fails with:

tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.

axis=-1,
if inputs["attention_mask"] is None and inputs["input_ids"] is not None:

def attn_mask_from_inp_ids():
Contributor:

same here, we don't need these complicated functions IMO

def not_log():
    return 1

_ = tf.cond(tf.math.greater(padding_len, 0), log, not_log)
Contributor:

I really don't want to have this either. I don't want "fake functions" in TF; I'm sure there are better ways. This worked before, no? Why exactly do we have to change it now?

],
axis=-1,
)
if inputs["attention_mask"] is None and inputs["input_ids"] is not None:
Contributor:

Please implement this the way it's implemented for Bart as well. This is not the cleanest way IMO.

@patrickvonplaten (Contributor) left a comment:

I'm very happy with the "one-line" fix that solves the problem stated in the issue - this looks clean!

However, there are some other things here where I'm not sure if they are part of the fix or why they are done:

  1. The attention logic is changed in TFLongformer. Why is that done? This has nothing to do with s2s, so should not be in this PR, no? Also, if we have to change it -> I think we should change it in a cleaner way.
  2. In general I really don't like the tf.cond(condition, do_fn_one, do_fn_two) design. I understand that it is sometimes necessary, but I really want to keep the usage of this function to a minimum. The functional approach is very different from our general library design and makes the code much harder to read. It always creates an abstraction by having to wrap parts of the code into a function with no args, like def attn_mask_from_inp_ids(), which is not easy to follow and to me always looks like a hack. In Bart we manage to do this part of the code without tf.cond, and Bart has the exact same logic as LED there -> so we can make it easier I think.
  3. I prefer logging a message in both cases over creating a "fake" logging function that doesn't do anything. I think it's fine to log "No padding required" in LED rather than creating a hacky, fake logging function. The more we add such things to the library, the more the community will use such hacky functions as well, which is not the right way IMO.

@jplu (Contributor Author) commented on Jan 8, 2021:

In general I really don't like the tf.cond(condition, do_fn_one, do_fn_two) design. [...] In Bart we manage to do this part of the code without tf.cond [...] so we can make it easier I think.

I understand your point and share your opinion on this. Unfortunately, when it comes to control flow (conditions and loops) in graph mode, there are strict rules that one cannot overcome; tf.cond is essentially mandatory unless AutoGraph handles the conversion for you.

A solution I think should work is to force the layer's call function to run in graph mode with @tf.function, which then takes care of translating all these conditions and loops itself. This works in some cases, so let's see if it works here... Does that sound like a proper solution to you?
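A rough sketch of that idea with a made-up helper (not the actual model code): once the helper is decorated with @tf.function, AutoGraph rewrites the plain if/else on a tensor into graph ops, so the usual Python style can stay:

import tensorflow as tf


@tf.function  # AutoGraph converts the data-dependent if/else below into a graph conditional
def compute_mask(seq_len):
    # `seq_len > 1` is a tensor inside the traced function, but AutoGraph stages
    # the branch for us; both branches must still return tensors of the same dtype.
    if seq_len > 1:
        return tf.ones(tf.stack([seq_len, seq_len]))
    return tf.zeros((1, 1))


print(compute_mask(tf.constant(4)).shape)  # (4, 4)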

@patrickvonplaten (Contributor):

@patrickvonplaten Should I remove the following hack in BART?

if inputs["decoder_input_ids"] is None and inputs["input_ids"] is not None:
            inputs["decoder_input_ids"] = shift_tokens_right(
                inputs["input_ids"], self.config.pad_token_id, self.config.eos_token_id
            )

Please don't; it's needed for some use cases in Bart and for backward compatibility.

@patrickvonplaten (Contributor):

Actually, one thing I'd like to know more in general about our models in TF is the following:

"Can we use normal if-else statements in the forward pass"?

I always thought that the answer is:

"Yes we can as long as the output type and shape of each case is the same"

So for me statements like:

if shape_list(input_ids)[-1] > n:
    attention_mask = tf.zeros(shape_list(input_ids))
else:
    attention_mask = tf.ones(shape_list(input_ids))

(this code snippet doesn't exist -> it's just an example)

are totally fine. Is that assumption correct, @jplu?

Or can we in general never use normal if-else statements in TF's forward pass and have to rely on tf.cond(....)? This would really surprise me, as we have tons of if statements everywhere in the TF code...

@jplu (Contributor Author) commented on Jan 8, 2021:

The general answer is yes, but with some conditions. If you run such a condition in eager mode, it works by default (you can do almost anything in eager mode).

If you run it in graph mode, you have two solutions to make it work:

  1. Either use tf.cond
  2. Or wrap your condition in a function decorated with @tf.function. This has the effect of applying the AutoGraph library to the content of the decorated function; AutoGraph automatically converts if-then clauses, loops, break, return, continue, and more. You can find more information here: https://www.tensorflow.org/guide/function#autograph_transformations

@patrickvonplaten mentioned this pull request on Jan 11, 2021
def __init__(self, config, *inputs, **kwargs):
    super().__init__(config, *inputs, **kwargs)

    self.led = TFLEDMainLayer(config, name="led")
@patrickvonplaten (Contributor), Jan 11, 2021:

Is this backward compatible? Previously the weights of TFLEDModel had the structure:

model = TFLEDModel.from_pretrained(...)
model.encoder....

now it's

model = TFLEDModel.from_pretrained(...)
model.led.encoder....

-> can we load pre-trained weights like this?

jplu (Contributor Author):

Yes, you can test it yourself if you want to be sure:

>>> from transformers import TFLEDModel                   
>>> model = TFLEDModel.from_pretrained("allenai/led-base-16384")
[... TensorFlow GPU/device initialization log output omitted ...]
All model checkpoint layers were used when initializing TFLEDModel.

All the layers of TFLEDModel were initialized from the model checkpoint at allenai/led-base-16384.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFLEDModel for predictions without further training.

"The bare LED Model outputting raw hidden-states without any specific head on top.",
LED_START_DOCSTRING,
)
class TFLEDModel(TFLEDPreTrainedModel):
Contributor:

Also, TFLEDModel never had a TFLEDPreTrainedModel as a layer inside -> so why do we have to change TFLEDModel in the first place? I thought we only needed to change TFLEDForConditionalGeneration.

jplu (Contributor Author):

I don't get what you mean by "TFLEDModel never had a TFLEDPreTrainedModel as a layer inside"?

jplu (Contributor Author):

Ah ok, I just got what you meant! Indeed, TFLEDModel had no led attribute that was a model. Would you prefer to keep TFLEDModel as it was and use the brand-new TFLEDMainLayer only in TFLEDForConditionalGeneration?

Contributor:

If the solution is fully backwards compatible, including things like

TFLEDModel.from_pretrained("allenai/led-base-16384", from_pt=True)

it's totally fine for me!


embed_pos = self.embed_positions(input_shape)
hidden_states = inputs_embeds + embed_pos
hidden_states = inputs["inputs_embeds"] + embed_pos
@patrickvonplaten (Contributor), Jan 19, 2021:

I think we also have to add the TFMainLayer logic in this file.

@patrickvonplaten (Contributor) left a comment:

This change is fine for me! However, I think that we need the same change you've added for TFLed, for:

  • TFBart,
  • TFMBart,
  • TFMarian,
  • TFPegasus,
  • TFBlenderbot
  • TFBlenderbotSmall

Even though TFBart works with a single input, we should have the same structure as you've implemented for TFLED IMO; the default case for TFBart is also two input tensors: input_ids and decoder_input_ids. Passing just a single input, input_ids, to TFBart is very much an edge case.

@jplu (Contributor Author) commented on Jan 19, 2021:

Now that we all agree on a solution, I will apply it for all the models 👍

@sgugger (Collaborator) left a comment:

Agree with the fix, ping me again when it's implemented for the other models :-)

@jplu (Contributor Author) commented on Jan 20, 2021:

Ok, LGTM!! Feel free to merge whenever you like ^^

@sgugger (Collaborator) left a comment:

Nice, thanks for deploying the fix to all the models.

@@ -1929,17 +1909,41 @@ def call(
attentions=all_self_attns,
)

@tf.function
def compute_combined_attns_mask(self, inputs, input_shape, past_key_values_length):
Contributor:

This function should be the same for all Bart-like models I think -> can we use the # Copied from ... statements here?

Contributor:

Ok apparently it's not needed for Bart, which I don't get. BartDecoder and LEDDecoder compute the combined_attns_mask in exactly the same way. Why does Bart get away with:

        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
        else:
            combined_attention_mask = _expand_mask(
                tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
            )
        if inputs["attention_mask"] is not None and input_shape[-1] > 1:
            combined_attention_mask = combined_attention_mask + _expand_mask(
                inputs["attention_mask"], tgt_len=input_shape[-1]
            )
        if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            inputs["encoder_attention_mask"] = _expand_mask(inputs["encoder_attention_mask"], tgt_len=input_shape[-1])

and LED needs a tf.function?

jplu (Contributor Author):

It does work in BART because the condition is not the same. In LED there is:

if inputs["attention_mask"] is None and inputs["input_ids"] is not None and input_shape[-1] > 1:

While in BART it is:

if inputs["attention_mask"] is None and input_shape[-1] > 1:

And the contents of the condition are different in the two models.

Contributor:

Actually, in master for TFLED it's also:
if inputs["attention_mask"] is None and input_shape[-1] > 1: -> see:

if inputs["attention_mask"] is not None and input_shape[-1] > 1:
-> there might have been some merge conflicts...
Do you think we don't need the more difficult approach in LED?

Contributor:

The code of TFLED actually looks exactly the same as for TFBart. I might have done a fix after you started your PR, which probably led to some merge conflicts for TFLED... -> can we maybe try to align this functionality in TFLEDDecoder and TFBartDecoder?

jplu (Contributor Author):

Ok, let's try with a cleaner alignment then :)

@jplu (Contributor Author), Jan 21, 2021:

It works!! Just pushed the update :) Thanks for catching this ^^

Contributor:

Thanks for adapting it!

@LysandreJik (Member) left a comment:

Thanks for taking care of all the models!

Comment on lines -1057 to +1119
pkv = (tf.tuple(output.past_key_values)[1] if self.config.use_cache else None,)
pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
Member:

Isn't this a breaking change with version v4.2.x?

Member:

Maybe that's more of a bugfix? Should we release a patch before the previous version gets too much usage?

@jplu (Contributor Author), Jan 21, 2021:

This is a bug that all the Seq2Seq models have. I don't know how it was introduced, but it was not part of my PR about serving. Basically, it prevents the model from producing a valid saved model, because all the output values must be tensors, not something else (here, a tuple).


def test_saved_model_creation_extended(self):
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
# This test is too long (>30sec) and makes fail the CI
Member:

Is this bound to change? In the meantime, maybe we can put it as a slow test instead of skipping it?

jplu (Contributor Author):

It is kind of useless, because test_saved_model_creation is included in test_saved_model_creation_extended, so it would be like running the same test twice.

Member:

Ok!

@jplu merged commit a7dabfb into huggingface:master on Jan 21, 2021
@jplu deleted the fix-tf-s2s branch on January 21, 2021 at 17:38