bug in modeling_tf_wav2vec2 #17233

Closed
2 of 4 tasks
ahmedlone127 opened this issue May 13, 2022 · 5 comments
Labels: bug, WIP

Comments

ahmedlone127 commented May 13, 2022

System Info

- `transformers` version: 4.19.0
- Platform: Linux-5.4.188+-x86_64-with-Ubuntu-18.04-bionic
- Python version: 3.7.13
- Huggingface_hub version: 0.6.0
- PyTorch version (GPU? Yes): 1.11.0+cu113 (True)
- Tensorflow version (GPU? Yes): 2.8.0 (True)
- Flax version (GPU:Yes): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: I am running on Colab, so I think it is parallel

Who can help?

@patrickvonplaten
@Rocketknight1

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction


import os
from transformers import Wav2Vec2Processor, TFWav2Vec2ForCTC
import tensorflow as tf
import numpy as np
import torch
import json
from datasets import load_dataset
import soundfile as sf

Wav2vec2Model = "facebook/wav2vec2-base-960h"
Wav2vec2_EXPORT_PATH = "/content/export_wav2vec2-base-960h"

# load the processor (feature extractor + tokenizer)
processor = Wav2Vec2Processor.from_pretrained(Wav2vec2Model)

# load dummy dataset and read soundfiles
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
input_values = processor(
    ds[0]["audio"]["array"],
    return_tensors="tf",
    padding="longest",
    return_attention_mask=True,
).input_values  # Batch size 1

class MyWav2vec2(TFWav2Vec2ForCTC):
    # Custom serving signature: the key is named "input_ids", but it carries
    # the raw float audio values that the model expects as `input_values`.
    @tf.function(
        input_signature=[
            {
                "input_ids": tf.TensorSpec((None, None), tf.float32, name="serving1_input_ids"),
            }
        ]
    )
    def serving1(self, inputs):
        outputs = self.call(input_values=inputs["input_ids"])
        return self.serving_output(outputs)


# export the model as a SavedModel with the custom signature
mywav2vec2 = MyWav2vec2.from_pretrained(Wav2vec2Model)
tf.saved_model.save(mywav2vec2, Wav2vec2_EXPORT_PATH, signatures={
    "serving1": mywav2vec2.serving1,
})

Error

TypeError                                 Traceback (most recent call last)
<ipython-input-13-06d8d6c67672> in <module>()
      1 jslwav2vec2 = JslWav2vec2.from_pretrained(Wav2vec2Model)
      2 tf.saved_model.save(jslwav2vec2, Wav2vec2_EXPORT_PATH, signatures={
----> 3     "serving1": jslwav2vec2.serving1,
      4     # "serving2": mygpt2.serving2
      5 })

43 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/saved_model/save.py in save(obj, export_dir, signatures, options)
   1332   # pylint: enable=line-too-long
   1333   metrics.IncrementWriteApi(_SAVE_V2_LABEL)
-> 1334   save_and_return_nodes(obj, export_dir, signatures, options)
   1335   metrics.IncrementWrite(write_version="2")
   1336 

/usr/local/lib/python3.7/dist-packages/tensorflow/python/saved_model/save.py in save_and_return_nodes(obj, export_dir, signatures, options, experimental_skip_checkpoint)
   1367 
   1368   _, exported_graph, object_saver, asset_info, saved_nodes, node_paths = (
-> 1369       _build_meta_graph(obj, signatures, options, meta_graph_def))
   1370   saved_model.saved_model_schema_version = (
   1371       constants.SAVED_MODEL_SCHEMA_VERSION)

/usr/local/lib/python3.7/dist-packages/tensorflow/python/saved_model/save.py in _build_meta_graph(obj, signatures, options, meta_graph_def)
   1534 
   1535   with save_context.save_context(options):
-> 1536     return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)

/usr/local/lib/python3.7/dist-packages/tensorflow/python/saved_model/save.py in _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
   1480   signatures, wrapped_functions = (
   1481       signature_serialization.canonicalize_signatures(signatures))
-> 1482   signature_serialization.validate_saveable_view(checkpoint_graph_view)
   1483   signature_map = signature_serialization.create_signature_map(signatures)
   1484   checkpoint_graph_view.set_signature(signature_map)

/usr/local/lib/python3.7/dist-packages/tensorflow/python/saved_model/signature_serialization.py in validate_saveable_view(saveable_view)
    299 def validate_saveable_view(saveable_view):
    300   """Performs signature-related sanity checks on `saveable_view`."""
--> 301   for name, dep in saveable_view.list_children(saveable_view.root):
    302     if name == SIGNATURE_ATTRIBUTE_NAME:
    303       if not isinstance(dep, _SignatureMap):

/usr/local/lib/python3.7/dist-packages/tensorflow/python/saved_model/save.py in list_children(self, obj)
    134               obj,
    135               save_type=base.SaveType.SAVEDMODEL,
--> 136               cache=self._serialization_cache))
    137     for name, child in self._children_cache[obj].items():
    138       yield base.TrackableReference(name, child)

/usr/local/lib/python3.7/dist-packages/tensorflow/python/training/tracking/graph_view.py in list_children(self, obj, save_type, **kwargs)
    254     obj._maybe_initialize_trackable()
    255     children = [base.TrackableReference(name, ref) for name, ref
--> 256                 in obj._trackable_children(save_type, **kwargs).items()]
    257     # pylint: enable=protected-access
    258 

/usr/local/lib/python3.7/dist-packages/tensorflow/python/training/tracking/base.py in _trackable_children(self, save_type, **kwargs)
   1477     elif save_type == SaveType.SAVEDMODEL:
   1478       cache = kwargs["cache"]
-> 1479       return self._get_legacy_saved_model_children(cache)
   1480     else:
   1481       raise ValueError("Unexpected format passed to `_trackable_children`. "

/usr/local/lib/python3.7/dist-packages/tensorflow/python/training/tracking/base.py in _get_legacy_saved_model_children(self, serialization_cache)
   1488 
   1489     # Retrieve functions attached to the object.
-> 1490     functions = self._list_functions_for_serialization(serialization_cache)
   1491 
   1492     # Trace concrete functions to force side-effects:

/usr/local/lib/python3.7/dist-packages/keras/engine/training.py in _list_functions_for_serialization(self, serialization_cache)
   3080     self.train_tf_function = None
   3081     functions = super(
-> 3082         Model, self)._list_functions_for_serialization(serialization_cache)
   3083     self.train_function = train_function
   3084     self.test_function = test_function

/usr/local/lib/python3.7/dist-packages/keras/engine/base_layer.py in _list_functions_for_serialization(self, serialization_cache)
   3167   def _list_functions_for_serialization(self, serialization_cache):
   3168     return (self._trackable_saved_model_saver
-> 3169             .list_functions_for_serialization(serialization_cache))
   3170 
   3171   @property

/usr/local/lib/python3.7/dist-packages/keras/saving/saved_model/base_serialization.py in list_functions_for_serialization(self, serialization_cache)
     91       return {}
     92 
---> 93     fns = self.functions_to_serialize(serialization_cache)
     94 
     95     # The parent AutoTrackable class saves all user-defined tf.functions, and

/usr/local/lib/python3.7/dist-packages/keras/saving/saved_model/layer_serialization.py in functions_to_serialize(self, serialization_cache)
     71   def functions_to_serialize(self, serialization_cache):
     72     return (self._get_serialized_attributes(
---> 73         serialization_cache).functions_to_serialize)
     74 
     75   def _get_serialized_attributes(self, serialization_cache):

/usr/local/lib/python3.7/dist-packages/keras/saving/saved_model/layer_serialization.py in _get_serialized_attributes(self, serialization_cache)
     87 
     88     object_dict, function_dict = self._get_serialized_attributes_internal(
---> 89         serialization_cache)
     90 
     91     serialized_attr.set_and_validate_objects(object_dict)

/usr/local/lib/python3.7/dist-packages/keras/saving/saved_model/model_serialization.py in _get_serialized_attributes_internal(self, serialization_cache)
     55     objects, functions = (
     56         super(ModelSavedModelSaver, self)._get_serialized_attributes_internal(
---> 57             serialization_cache))
     58     functions['_default_save_signature'] = default_signature
     59     return objects, functions

/usr/local/lib/python3.7/dist-packages/keras/saving/saved_model/layer_serialization.py in _get_serialized_attributes_internal(self, serialization_cache)
     96     """Returns dictionary of serialized attributes."""
     97     objects = save_impl.wrap_layer_objects(self.obj, serialization_cache)
---> 98     functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
     99     # Attribute validator requires that the default save signature is added to
    100     # function dict, even if the value is None.

/usr/local/lib/python3.7/dist-packages/keras/saving/saved_model/save_impl.py in wrap_layer_functions(layer, serialization_cache)
    195       for fn in fns.values():
    196         if fn is not None and not isinstance(fn, LayerCall):
--> 197           fn.get_concrete_function()
    198 
    199   # Restore overwritten functions and losses

/usr/lib/python3.7/contextlib.py in __exit__(self, type, value, traceback)
    117         if type is None:
    118             try:
--> 119                 next(self.gen)
    120             except StopIteration:
    121                 return False

/usr/local/lib/python3.7/dist-packages/keras/saving/saved_model/save_impl.py in tracing_scope()
    357       if training is not None:
    358         with backend.deprecated_internal_learning_phase_scope(training):
--> 359           fn.get_concrete_function(*args, **kwargs)
    360       else:
    361         fn.get_concrete_function(*args, **kwargs)

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs)
   1262   def get_concrete_function(self, *args, **kwargs):
   1263     # Implements GenericFunction.get_concrete_function.
-> 1264     concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
   1265     concrete._garbage_collector.release()  # pylint: disable=protected-access
   1266     return concrete

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
   1254       # run the first trace but we should fail if variables are created.
   1255       concrete = self._stateful_fn._get_concrete_function_garbage_collected(  # pylint: disable=protected-access
-> 1256           *args, **kwargs)
   1257       if self._created_variables:
   1258         raise ValueError("Creating variables on a non-first call to a function"

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
   3034       args, kwargs = None, None
   3035     with self._lock:
-> 3036       graph_function, _ = self._maybe_define_function(args, kwargs)
   3037       seen_names = set()
   3038       captured = object_identity.ObjectIdentitySet(

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3290 
   3291           self._function_cache.add_call_context(cache_key.call_context)
-> 3292           graph_function = self._create_graph_function(args, kwargs)
   3293           self._function_cache.add(cache_key, cache_key_deletion_observer,
   3294                                    graph_function)

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3138             arg_names=arg_names,
   3139             override_flat_arg_shapes=override_flat_arg_shapes,
-> 3140             capture_by_value=self._capture_by_value),
   3141         self._function_attributes,
   3142         function_spec=self.function_spec,

/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses)
   1159         _, original_func = tf_decorator.unwrap(python_func)
   1160 
-> 1161       func_outputs = python_func(*func_args, **func_kwargs)
   1162 
   1163       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    675         # the function a weak reference to itself to avoid a reference cycle.
    676         with OptionalXlaContext(compile_with_xla):
--> 677           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    678         return out
    679 

/usr/local/lib/python3.7/dist-packages/keras/saving/saved_model/save_impl.py in wrapper(*args, **kwargs)
    570       with autocast_variable.enable_auto_cast_variables(
    571           layer._compute_dtype_object):  # pylint: disable=protected-access
--> 572         ret = method(*args, **kwargs)
    573     _restore_layer_losses(original_losses)
    574     return ret

/usr/local/lib/python3.7/dist-packages/keras/saving/saved_model/utils.py in wrap_with_training_arg(*args, **kwargs)
    168     return control_flow_util.smart_cond(
    169         training, lambda: replace_training_and_call(True),
--> 170         lambda: replace_training_and_call(False))
    171 
    172   # Create arg spec for decorated function. If 'training' is not defined in the

/usr/local/lib/python3.7/dist-packages/keras/utils/control_flow_util.py in smart_cond(pred, true_fn, false_fn, name)
    104         pred, true_fn=true_fn, false_fn=false_fn, name=name)
    105   return tf.__internal__.smart_cond.smart_cond(
--> 106       pred, true_fn=true_fn, false_fn=false_fn, name=name)
    107 
    108 

/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/smart_cond.py in smart_cond(pred, true_fn, false_fn, name)
     51   if pred_value is not None:
     52     if pred_value:
---> 53       return true_fn()
     54     else:
     55       return false_fn()

/usr/local/lib/python3.7/dist-packages/keras/saving/saved_model/utils.py in <lambda>()
    167 
    168     return control_flow_util.smart_cond(
--> 169         training, lambda: replace_training_and_call(True),
    170         lambda: replace_training_and_call(False))
    171 

/usr/local/lib/python3.7/dist-packages/keras/saving/saved_model/utils.py in replace_training_and_call(training)
    164     def replace_training_and_call(training):
    165       set_training_arg(training, training_arg_index, args, kwargs)
--> 166       return wrapped_call(*args, **kwargs)
    167 
    168     return control_flow_util.smart_cond(

/usr/local/lib/python3.7/dist-packages/keras/saving/saved_model/save_impl.py in call(inputs, *args, **kwargs)
    650     return layer.keras_api.__call__  # pylint: disable=protected-access
    651   def call(inputs, *args, **kwargs):
--> 652     return call_and_return_conditional_losses(inputs, *args, **kwargs)[0]
    653   return _create_call_fn_decorator(layer, call)
    654 

/usr/local/lib/python3.7/dist-packages/keras/saving/saved_model/save_impl.py in __call__(self, *args, **kwargs)
    608   def __call__(self, *args, **kwargs):
    609     self._maybe_trace(args, kwargs)
--> 610     return self.wrapped_call(*args, **kwargs)
    611 
    612   def get_concrete_function(self, *args, **kwargs):

/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/traceback_utils.py in error_handler(*args, **kwargs)
    151     except Exception as e:
    152       filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153       raise e.with_traceback(filtered_tb) from None
    154     finally:
    155       del filtered_tb

/usr/local/lib/python3.7/dist-packages/keras/saving/saved_model/save_impl.py in wrapper(*args, **kwargs)
    570       with autocast_variable.enable_auto_cast_variables(
    571           layer._compute_dtype_object):  # pylint: disable=protected-access
--> 572         ret = method(*args, **kwargs)
    573     _restore_layer_losses(original_losses)
    574     return ret

/usr/local/lib/python3.7/dist-packages/keras/saving/saved_model/utils.py in wrap_with_training_arg(*args, **kwargs)
    168     return control_flow_util.smart_cond(
    169         training, lambda: replace_training_and_call(True),
--> 170         lambda: replace_training_and_call(False))
    171 
    172   # Create arg spec for decorated function. If 'training' is not defined in the

/usr/local/lib/python3.7/dist-packages/keras/utils/control_flow_util.py in smart_cond(pred, true_fn, false_fn, name)
    104         pred, true_fn=true_fn, false_fn=false_fn, name=name)
    105   return tf.__internal__.smart_cond.smart_cond(
--> 106       pred, true_fn=true_fn, false_fn=false_fn, name=name)
    107 
    108 

/usr/local/lib/python3.7/dist-packages/keras/saving/saved_model/utils.py in <lambda>()
    167 
    168     return control_flow_util.smart_cond(
--> 169         training, lambda: replace_training_and_call(True),
    170         lambda: replace_training_and_call(False))
    171 

/usr/local/lib/python3.7/dist-packages/keras/saving/saved_model/utils.py in replace_training_and_call(training)
    164     def replace_training_and_call(training):
    165       set_training_arg(training, training_arg_index, args, kwargs)
--> 166       return wrapped_call(*args, **kwargs)
    167 
    168     return control_flow_util.smart_cond(

/usr/local/lib/python3.7/dist-packages/keras/saving/saved_model/save_impl.py in call_and_return_conditional_losses(*args, **kwargs)
    632   def call_and_return_conditional_losses(*args, **kwargs):
    633     """Returns layer (call_output, conditional losses) tuple."""
--> 634     call_output = layer_call(*args, **kwargs)
    635     if version_utils.is_v1_layer_or_model(layer):
    636       conditional_losses = layer.get_losses_for(

/usr/local/lib/python3.7/dist-packages/transformers/models/wav2vec2/modeling_tf_wav2vec2.py in call(self, input_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict, training, **kwargs)
   1278         mask_time_indices = kwargs.get("mask_time_indices", None)
   1279         if inputs["training"]:
-> 1280             hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
   1281 
   1282         encoder_outputs = self.encoder(

/usr/local/lib/python3.7/dist-packages/transformers/models/wav2vec2/modeling_tf_wav2vec2.py in _mask_hidden_states(self, hidden_states, mask_time_indices)
   1212                 mask_prob=self.config.mask_time_prob,
   1213                 mask_length=self.config.mask_time_length,
-> 1214                 min_masks=2,
   1215             )
   1216             hidden_states = tf.where(

/usr/local/lib/python3.7/dist-packages/transformers/models/wav2vec2/modeling_tf_wav2vec2.py in _compute_mask_indices(shape, mask_prob, mask_length, min_masks)
    264     print(tf.random.uniform((1,)))
    265     print((mask_prob * sequence_length / mask_length + tf.random.uniform((1,)) )[0] )
--> 266     num_masked_spans = int(mask_prob * sequence_length / mask_length + tf.random.uniform((1,)))
    267     num_masked_spans = max(num_masked_spans, min_masks)
    268 

TypeError: int() argument must be a string, a bytes-like object or a number, not 'Tensor'


Expected behavior

I want to be able to export the model so that I can use it in TensorFlow Serving.
Member

gante commented May 16, 2022

Hi @ahmedlone127 👋 The error appears because that line only runs under eager execution (see below), whereas exporting with `tf.saved_model.save` traces the model in graph mode. This is a problem on our side, and we will be fixing it 👍

image
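
To make the eager-versus-graph point concrete, here is a minimal sketch that reproduces the same kind of `TypeError` outside the model and shows one graph-compatible way to write the computation. The function names and the `noise` parameter are hypothetical and used only for illustration; this is not necessarily how the linked PRs changed `_compute_mask_indices`. `autograph=False` is set on the assumption that it mimics a trace in which the library code is not rewritten by AutoGraph.

```python
import tensorflow as tf


def num_masked_spans_eager_only(mask_prob, sequence_length, mask_length):
    # Same pattern as the failing line: int() needs a concrete value,
    # which a symbolic graph tensor cannot provide.
    noise = tf.random.uniform(())
    return int(mask_prob * sequence_length / mask_length + noise)


def num_masked_spans_graph_safe(mask_prob, sequence_length, mask_length,
                                min_masks=0, noise=None):
    # Hypothetical rewrite: stay in tensor ops so tracing succeeds.
    if noise is None:
        noise = tf.random.uniform(())
    length = tf.cast(sequence_length, tf.float32)
    # tf.cast to int32 truncates toward zero, like int() on a positive float.
    spans = tf.cast(mask_prob * length / mask_length + noise, tf.int32)
    return tf.maximum(spans, min_masks)


@tf.function(autograph=False)
def traced_eager_only():
    return num_masked_spans_eager_only(0.5, 100, 10)


@tf.function(autograph=False)
def traced_graph_safe(noise):
    return num_masked_spans_graph_safe(0.5, 100, 10, min_masks=2, noise=noise)


# Tracing the eager-only version fails the same way as the export.
try:
    traced_eager_only()
    eager_only_traces = True
except TypeError as err:
    eager_only_traces = False
    print("tracing failed:", err)

# The tensor-op version traces and runs in graph mode.
print(int(traced_graph_safe(tf.constant(0.5))))
```

The graph-safe variant returns a tensor rather than a Python `int`, so any downstream code that relies on a Python integer (for example, `max(...)` or `range(...)`) would also need tensor equivalents such as `tf.maximum` or `tf.range`.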

gante self-assigned this May 16, 2022
Member

gante commented May 16, 2022

#17285

@ahmedlone127
Author

Thanks a lot, hope this gets a fix soon :)

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

gante added the WIP label Jun 14, 2022
@amyeroberts
Collaborator

Following the merge of #18153, the reproduction snippet runs on main without error.
